diff --git a/transport/config.go b/transport/config.go index 95cfc93..3720543 100644 --- a/transport/config.go +++ b/transport/config.go @@ -2,6 +2,7 @@ package transport import ( "log/slog" + "net/http" "time" ) @@ -9,6 +10,7 @@ type CloseHandler func(code int, text string) error type ConnectionConfig struct { CloseHandler CloseHandler + RequestHeader http.Header WriteTimeout time.Duration PingInterval time.Duration IncomingBufferSize int @@ -39,8 +41,11 @@ func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) } func GetDefaultConnectionConfig() *ConnectionConfig { + header := http.Header{} + header.Set("User-Agent", "honeybee/0.1.0") return &ConnectionConfig{ CloseHandler: nil, + RequestHeader: header, WriteTimeout: 30 * time.Second, PingInterval: 20 * time.Second, IncomingBufferSize: 100, @@ -160,6 +165,13 @@ func WithCloseHandler(handler CloseHandler) ConnectionOption { } } +func WithRequestHeader(header http.Header) ConnectionOption { + return func(c *ConnectionConfig) error { + c.RequestHeader = header.Clone() + return nil + } +} + // When WriteTimeout is set to zero, read timeouts are disabled. func WithWriteTimeout(value time.Duration) ConnectionOption { return func(c *ConnectionConfig) error { diff --git a/transport/config_test.go b/transport/config_test.go index e9455e0..0b79fde 100644 --- a/transport/config_test.go +++ b/transport/config_test.go @@ -3,6 +3,7 @@ package transport import ( "github.com/stretchr/testify/assert" "log/slog" + "net/http" "testing" "time" ) @@ -10,20 +11,9 @@ import ( // Connection Config Tests func TestNewConnectionConfig(t *testing.T) { - conf, err := NewConnectionConfig() + _, err := NewConnectionConfig() assert.NoError(t, err) - assert.Equal(t, conf, &ConnectionConfig{ - CloseHandler: nil, - WriteTimeout: 30 * time.Second, - PingInterval: 20 * time.Second, - IncomingBufferSize: 100, - ErrorsBufferSize: 10, - LoggingEnabled: true, - LogLevel: nil, - Retry: GetDefaultRetryConfig(), - }) - // errors propagate _, err = NewConnectionConfig(WithRetryMaxRetries(-1)) assert.Error(t, err) @@ -35,10 +25,13 @@ func TestNewConnectionConfig(t *testing.T) { // Default Tests func TestDefaultConnectionConfig(t *testing.T) { + header := http.Header{} + header.Set("User-Agent", "honeybee/0.1.0") conf := GetDefaultConnectionConfig() assert.Equal(t, conf, &ConnectionConfig{ CloseHandler: nil, + RequestHeader: header, WriteTimeout: 30 * time.Second, PingInterval: 20 * time.Second, IncomingBufferSize: 100, diff --git a/transport/connection.go b/transport/connection.go index 1308675..9999758 100644 --- a/transport/connection.go +++ b/transport/connection.go @@ -169,7 +169,7 @@ func (c *Connection) Connect(ctx context.Context) error { retryMgr := NewRetryManager(c.config.Retry) socket, _, err := AcquireSocket( - ctx, retryMgr, c.dialer, c.url.String(), c.logger) + ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger) if err != nil { c.state = StateDisconnected diff --git a/transport/connection_test.go b/transport/connection_test.go index 399bdac..5b70110 100644 --- a/transport/connection_test.go +++ b/transport/connection_test.go @@ -404,6 +404,26 @@ func TestConnect(t *testing.T) { conn.Close() }) + + t.Run("passes headers when configured", func(t *testing.T) { + header := http.Header{"X-Custom": []string{"val"}} + conf, _ := NewConnectionConfig(WithRequestHeader(header)) + conn, _ := NewConnection("ws://test", conf, nil) + + dialCalled := false + conn.dialer = &honeybeetest.MockDialer{ + DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) { + assert.Equal(t, "val", h.Get("X-Custom")) + dialCalled = true + return honeybeetest.NewMockSocket(), nil, nil + }, + } + + err := conn.Connect(context.Background()) + + assert.NoError(t, err) + assert.True(t, dialCalled) + }) } func TestConnectContextCancellation(t *testing.T) { diff --git a/transport/socket.go b/transport/socket.go index 8df0868..ffa3a13 100644 --- a/transport/socket.go +++ b/transport/socket.go @@ -45,6 +45,7 @@ func AcquireSocket( retryMgr *RetryManager, dialer types.Dialer, url string, + header http.Header, logger *slog.Logger, ) (types.Socket, *http.Response, error) { select { @@ -68,7 +69,7 @@ func AcquireSocket( logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1) } - socket, resp, err := dialer.DialContext(ctx, url, nil) + socket, resp, err := dialer.DialContext(ctx, url, header) if err == nil { if logger != nil { logger.Debug("dial successful", "attempt", retryMgr.RetryCount()+1) diff --git a/transport/socket_test.go b/transport/socket_test.go index d2130a6..5aa4558 100644 --- a/transport/socket_test.go +++ b/transport/socket_test.go @@ -85,7 +85,7 @@ func TestAcquireSocket(t *testing.T) { }) socket, _, err := AcquireSocket( - context.Background(), retryMgr, mockDialer, "ws://test", nil) + context.Background(), retryMgr, mockDialer, "ws://test", nil, nil) assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount()) if tc.wantErr { @@ -141,7 +141,7 @@ func TestAcquireSocketGuards(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { socket, resp, err := AcquireSocket( - context.Background(), tc.retryMgr, tc.dialer, tc.url, nil) + context.Background(), tc.retryMgr, tc.dialer, tc.url, nil, nil) assert.Error(t, err) assert.ErrorContains(t, err, tc.wantErr) @@ -168,7 +168,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) { cancel() retryMgr := NewRetryManager(GetDefaultRetryConfig()) - _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil) + _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil) assert.ErrorIs(t, err, context.Canceled) assert.False(t, dialCalled.Load()) @@ -195,7 +195,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) { done := make(chan error, 1) go func() { - _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil) + _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil) done <- err }() @@ -233,7 +233,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) { retryMgr := NewRetryManager(GetDefaultRetryConfig()) done := make(chan error, 1) go func() { - _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil) + _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil) done <- err }() @@ -250,3 +250,22 @@ func TestAcquireSocketContextCancellation(t *testing.T) { }) } + +func TestAcquireSocketPassesHeaders(t *testing.T) { + header := http.Header{"User-Agent": []string{"test-agent"}} + dialCalled := false + + mockDialer := &honeybeetest.MockDialer{ + DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) { + assert.Equal(t, "test-agent", h.Get("User-Agent")) + dialCalled = true + return honeybeetest.NewMockSocket(), nil, nil + }, + } + + retryMgr := NewRetryManager(&RetryConfig{MaxRetries: 0}) + _, _, err := AcquireSocket(context.Background(), retryMgr, mockDialer, "ws://test", header, nil) + + assert.NoError(t, err) + assert.True(t, dialCalled) +}