introduced http header to connection

This commit is contained in:
Jay
2026-04-30 14:09:41 -04:00
parent 6332e4438e
commit ab641e8313
6 changed files with 64 additions and 19 deletions
+12
View File
@@ -2,6 +2,7 @@ package transport
import (
"log/slog"
"net/http"
"time"
)
@@ -9,6 +10,7 @@ type CloseHandler func(code int, text string) error
type ConnectionConfig struct {
CloseHandler CloseHandler
RequestHeader http.Header
WriteTimeout time.Duration
PingInterval time.Duration
IncomingBufferSize int
@@ -39,8 +41,11 @@ func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error)
}
func GetDefaultConnectionConfig() *ConnectionConfig {
header := http.Header{}
header.Set("User-Agent", "honeybee/0.1.0")
return &ConnectionConfig{
CloseHandler: nil,
RequestHeader: header,
WriteTimeout: 30 * time.Second,
PingInterval: 20 * time.Second,
IncomingBufferSize: 100,
@@ -160,6 +165,13 @@ func WithCloseHandler(handler CloseHandler) ConnectionOption {
}
}
func WithRequestHeader(header http.Header) ConnectionOption {
return func(c *ConnectionConfig) error {
c.RequestHeader = header.Clone()
return nil
}
}
// When WriteTimeout is set to zero, read timeouts are disabled.
func WithWriteTimeout(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error {
+5 -12
View File
@@ -3,6 +3,7 @@ package transport
import (
"github.com/stretchr/testify/assert"
"log/slog"
"net/http"
"testing"
"time"
)
@@ -10,20 +11,9 @@ import (
// Connection Config Tests
func TestNewConnectionConfig(t *testing.T) {
conf, err := NewConnectionConfig()
_, err := NewConnectionConfig()
assert.NoError(t, err)
assert.Equal(t, conf, &ConnectionConfig{
CloseHandler: nil,
WriteTimeout: 30 * time.Second,
PingInterval: 20 * time.Second,
IncomingBufferSize: 100,
ErrorsBufferSize: 10,
LoggingEnabled: true,
LogLevel: nil,
Retry: GetDefaultRetryConfig(),
})
// errors propagate
_, err = NewConnectionConfig(WithRetryMaxRetries(-1))
assert.Error(t, err)
@@ -35,10 +25,13 @@ func TestNewConnectionConfig(t *testing.T) {
// Default Tests
func TestDefaultConnectionConfig(t *testing.T) {
header := http.Header{}
header.Set("User-Agent", "honeybee/0.1.0")
conf := GetDefaultConnectionConfig()
assert.Equal(t, conf, &ConnectionConfig{
CloseHandler: nil,
RequestHeader: header,
WriteTimeout: 30 * time.Second,
PingInterval: 20 * time.Second,
IncomingBufferSize: 100,
+1 -1
View File
@@ -169,7 +169,7 @@ func (c *Connection) Connect(ctx context.Context) error {
retryMgr := NewRetryManager(c.config.Retry)
socket, _, err := AcquireSocket(
ctx, retryMgr, c.dialer, c.url.String(), c.logger)
ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger)
if err != nil {
c.state = StateDisconnected
+20
View File
@@ -404,6 +404,26 @@ func TestConnect(t *testing.T) {
conn.Close()
})
t.Run("passes headers when configured", func(t *testing.T) {
header := http.Header{"X-Custom": []string{"val"}}
conf, _ := NewConnectionConfig(WithRequestHeader(header))
conn, _ := NewConnection("ws://test", conf, nil)
dialCalled := false
conn.dialer = &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
assert.Equal(t, "val", h.Get("X-Custom"))
dialCalled = true
return honeybeetest.NewMockSocket(), nil, nil
},
}
err := conn.Connect(context.Background())
assert.NoError(t, err)
assert.True(t, dialCalled)
})
}
func TestConnectContextCancellation(t *testing.T) {
+2 -1
View File
@@ -45,6 +45,7 @@ func AcquireSocket(
retryMgr *RetryManager,
dialer types.Dialer,
url string,
header http.Header,
logger *slog.Logger,
) (types.Socket, *http.Response, error) {
select {
@@ -68,7 +69,7 @@ func AcquireSocket(
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
}
socket, resp, err := dialer.DialContext(ctx, url, nil)
socket, resp, err := dialer.DialContext(ctx, url, header)
if err == nil {
if logger != nil {
logger.Debug("dial successful", "attempt", retryMgr.RetryCount()+1)
+24 -5
View File
@@ -85,7 +85,7 @@ func TestAcquireSocket(t *testing.T) {
})
socket, _, err := AcquireSocket(
context.Background(), retryMgr, mockDialer, "ws://test", nil)
context.Background(), retryMgr, mockDialer, "ws://test", nil, nil)
assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
if tc.wantErr {
@@ -141,7 +141,7 @@ func TestAcquireSocketGuards(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
socket, resp, err := AcquireSocket(
context.Background(), tc.retryMgr, tc.dialer, tc.url, nil)
context.Background(), tc.retryMgr, tc.dialer, tc.url, nil, nil)
assert.Error(t, err)
assert.ErrorContains(t, err, tc.wantErr)
@@ -168,7 +168,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
cancel()
retryMgr := NewRetryManager(GetDefaultRetryConfig())
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
assert.ErrorIs(t, err, context.Canceled)
assert.False(t, dialCalled.Load())
@@ -195,7 +195,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
done := make(chan error, 1)
go func() {
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
done <- err
}()
@@ -233,7 +233,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
retryMgr := NewRetryManager(GetDefaultRetryConfig())
done := make(chan error, 1)
go func() {
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
done <- err
}()
@@ -250,3 +250,22 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
})
}
func TestAcquireSocketPassesHeaders(t *testing.T) {
header := http.Header{"User-Agent": []string{"test-agent"}}
dialCalled := false
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
assert.Equal(t, "test-agent", h.Get("User-Agent"))
dialCalled = true
return honeybeetest.NewMockSocket(), nil, nil
},
}
retryMgr := NewRetryManager(&RetryConfig{MaxRetries: 0})
_, _, err := AcquireSocket(context.Background(), retryMgr, mockDialer, "ws://test", header, nil)
assert.NoError(t, err)
assert.True(t, dialCalled)
}