introduced http header to connection

This commit is contained in:
Jay
2026-04-30 14:09:41 -04:00
parent 6332e4438e
commit ab641e8313
6 changed files with 64 additions and 19 deletions
+12
View File
@@ -2,6 +2,7 @@ package transport
import ( import (
"log/slog" "log/slog"
"net/http"
"time" "time"
) )
@@ -9,6 +10,7 @@ type CloseHandler func(code int, text string) error
type ConnectionConfig struct { type ConnectionConfig struct {
CloseHandler CloseHandler CloseHandler CloseHandler
RequestHeader http.Header
WriteTimeout time.Duration WriteTimeout time.Duration
PingInterval time.Duration PingInterval time.Duration
IncomingBufferSize int IncomingBufferSize int
@@ -39,8 +41,11 @@ func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error)
} }
func GetDefaultConnectionConfig() *ConnectionConfig { func GetDefaultConnectionConfig() *ConnectionConfig {
header := http.Header{}
header.Set("User-Agent", "honeybee/0.1.0")
return &ConnectionConfig{ return &ConnectionConfig{
CloseHandler: nil, CloseHandler: nil,
RequestHeader: header,
WriteTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second,
PingInterval: 20 * time.Second, PingInterval: 20 * time.Second,
IncomingBufferSize: 100, IncomingBufferSize: 100,
@@ -160,6 +165,13 @@ func WithCloseHandler(handler CloseHandler) ConnectionOption {
} }
} }
func WithRequestHeader(header http.Header) ConnectionOption {
return func(c *ConnectionConfig) error {
c.RequestHeader = header.Clone()
return nil
}
}
// When WriteTimeout is set to zero, read timeouts are disabled. // When WriteTimeout is set to zero, read timeouts are disabled.
func WithWriteTimeout(value time.Duration) ConnectionOption { func WithWriteTimeout(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error { return func(c *ConnectionConfig) error {
+5 -12
View File
@@ -3,6 +3,7 @@ package transport
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"log/slog" "log/slog"
"net/http"
"testing" "testing"
"time" "time"
) )
@@ -10,20 +11,9 @@ import (
// Connection Config Tests // Connection Config Tests
func TestNewConnectionConfig(t *testing.T) { func TestNewConnectionConfig(t *testing.T) {
conf, err := NewConnectionConfig() _, err := NewConnectionConfig()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, conf, &ConnectionConfig{
CloseHandler: nil,
WriteTimeout: 30 * time.Second,
PingInterval: 20 * time.Second,
IncomingBufferSize: 100,
ErrorsBufferSize: 10,
LoggingEnabled: true,
LogLevel: nil,
Retry: GetDefaultRetryConfig(),
})
// errors propagate // errors propagate
_, err = NewConnectionConfig(WithRetryMaxRetries(-1)) _, err = NewConnectionConfig(WithRetryMaxRetries(-1))
assert.Error(t, err) assert.Error(t, err)
@@ -35,10 +25,13 @@ func TestNewConnectionConfig(t *testing.T) {
// Default Tests // Default Tests
func TestDefaultConnectionConfig(t *testing.T) { func TestDefaultConnectionConfig(t *testing.T) {
header := http.Header{}
header.Set("User-Agent", "honeybee/0.1.0")
conf := GetDefaultConnectionConfig() conf := GetDefaultConnectionConfig()
assert.Equal(t, conf, &ConnectionConfig{ assert.Equal(t, conf, &ConnectionConfig{
CloseHandler: nil, CloseHandler: nil,
RequestHeader: header,
WriteTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second,
PingInterval: 20 * time.Second, PingInterval: 20 * time.Second,
IncomingBufferSize: 100, IncomingBufferSize: 100,
+1 -1
View File
@@ -169,7 +169,7 @@ func (c *Connection) Connect(ctx context.Context) error {
retryMgr := NewRetryManager(c.config.Retry) retryMgr := NewRetryManager(c.config.Retry)
socket, _, err := AcquireSocket( socket, _, err := AcquireSocket(
ctx, retryMgr, c.dialer, c.url.String(), c.logger) ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger)
if err != nil { if err != nil {
c.state = StateDisconnected c.state = StateDisconnected
+20
View File
@@ -404,6 +404,26 @@ func TestConnect(t *testing.T) {
conn.Close() conn.Close()
}) })
t.Run("passes headers when configured", func(t *testing.T) {
header := http.Header{"X-Custom": []string{"val"}}
conf, _ := NewConnectionConfig(WithRequestHeader(header))
conn, _ := NewConnection("ws://test", conf, nil)
dialCalled := false
conn.dialer = &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
assert.Equal(t, "val", h.Get("X-Custom"))
dialCalled = true
return honeybeetest.NewMockSocket(), nil, nil
},
}
err := conn.Connect(context.Background())
assert.NoError(t, err)
assert.True(t, dialCalled)
})
} }
func TestConnectContextCancellation(t *testing.T) { func TestConnectContextCancellation(t *testing.T) {
+2 -1
View File
@@ -45,6 +45,7 @@ func AcquireSocket(
retryMgr *RetryManager, retryMgr *RetryManager,
dialer types.Dialer, dialer types.Dialer,
url string, url string,
header http.Header,
logger *slog.Logger, logger *slog.Logger,
) (types.Socket, *http.Response, error) { ) (types.Socket, *http.Response, error) {
select { select {
@@ -68,7 +69,7 @@ func AcquireSocket(
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1) logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
} }
socket, resp, err := dialer.DialContext(ctx, url, nil) socket, resp, err := dialer.DialContext(ctx, url, header)
if err == nil { if err == nil {
if logger != nil { if logger != nil {
logger.Debug("dial successful", "attempt", retryMgr.RetryCount()+1) logger.Debug("dial successful", "attempt", retryMgr.RetryCount()+1)
+24 -5
View File
@@ -85,7 +85,7 @@ func TestAcquireSocket(t *testing.T) {
}) })
socket, _, err := AcquireSocket( socket, _, err := AcquireSocket(
context.Background(), retryMgr, mockDialer, "ws://test", nil) context.Background(), retryMgr, mockDialer, "ws://test", nil, nil)
assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount()) assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
if tc.wantErr { if tc.wantErr {
@@ -141,7 +141,7 @@ func TestAcquireSocketGuards(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
socket, resp, err := AcquireSocket( socket, resp, err := AcquireSocket(
context.Background(), tc.retryMgr, tc.dialer, tc.url, nil) context.Background(), tc.retryMgr, tc.dialer, tc.url, nil, nil)
assert.Error(t, err) assert.Error(t, err)
assert.ErrorContains(t, err, tc.wantErr) assert.ErrorContains(t, err, tc.wantErr)
@@ -168,7 +168,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
cancel() cancel()
retryMgr := NewRetryManager(GetDefaultRetryConfig()) retryMgr := NewRetryManager(GetDefaultRetryConfig())
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil) _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
assert.ErrorIs(t, err, context.Canceled) assert.ErrorIs(t, err, context.Canceled)
assert.False(t, dialCalled.Load()) assert.False(t, dialCalled.Load())
@@ -195,7 +195,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil) _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
done <- err done <- err
}() }()
@@ -233,7 +233,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
retryMgr := NewRetryManager(GetDefaultRetryConfig()) retryMgr := NewRetryManager(GetDefaultRetryConfig())
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil) _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
done <- err done <- err
}() }()
@@ -250,3 +250,22 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
}) })
} }
func TestAcquireSocketPassesHeaders(t *testing.T) {
header := http.Header{"User-Agent": []string{"test-agent"}}
dialCalled := false
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
assert.Equal(t, "test-agent", h.Get("User-Agent"))
dialCalled = true
return honeybeetest.NewMockSocket(), nil, nil
},
}
retryMgr := NewRetryManager(&RetryConfig{MaxRetries: 0})
_, _, err := AcquireSocket(context.Background(), retryMgr, mockDialer, "ws://test", header, nil)
assert.NoError(t, err)
assert.True(t, dialCalled)
}