introduced http header to connection
This commit is contained in:
@@ -2,6 +2,7 @@ package transport
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -9,6 +10,7 @@ type CloseHandler func(code int, text string) error
|
||||
|
||||
type ConnectionConfig struct {
|
||||
CloseHandler CloseHandler
|
||||
RequestHeader http.Header
|
||||
WriteTimeout time.Duration
|
||||
PingInterval time.Duration
|
||||
IncomingBufferSize int
|
||||
@@ -39,8 +41,11 @@ func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error)
|
||||
}
|
||||
|
||||
func GetDefaultConnectionConfig() *ConnectionConfig {
|
||||
header := http.Header{}
|
||||
header.Set("User-Agent", "honeybee/0.1.0")
|
||||
return &ConnectionConfig{
|
||||
CloseHandler: nil,
|
||||
RequestHeader: header,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
PingInterval: 20 * time.Second,
|
||||
IncomingBufferSize: 100,
|
||||
@@ -160,6 +165,13 @@ func WithCloseHandler(handler CloseHandler) ConnectionOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithRequestHeader(header http.Header) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
c.RequestHeader = header.Clone()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// When WriteTimeout is set to zero, read timeouts are disabled.
|
||||
func WithWriteTimeout(value time.Duration) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
|
||||
@@ -3,6 +3,7 @@ package transport
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -10,20 +11,9 @@ import (
|
||||
// Connection Config Tests
|
||||
|
||||
func TestNewConnectionConfig(t *testing.T) {
|
||||
conf, err := NewConnectionConfig()
|
||||
_, err := NewConnectionConfig()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, conf, &ConnectionConfig{
|
||||
CloseHandler: nil,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
PingInterval: 20 * time.Second,
|
||||
IncomingBufferSize: 100,
|
||||
ErrorsBufferSize: 10,
|
||||
LoggingEnabled: true,
|
||||
LogLevel: nil,
|
||||
Retry: GetDefaultRetryConfig(),
|
||||
})
|
||||
|
||||
// errors propagate
|
||||
_, err = NewConnectionConfig(WithRetryMaxRetries(-1))
|
||||
assert.Error(t, err)
|
||||
@@ -35,10 +25,13 @@ func TestNewConnectionConfig(t *testing.T) {
|
||||
// Default Tests
|
||||
|
||||
func TestDefaultConnectionConfig(t *testing.T) {
|
||||
header := http.Header{}
|
||||
header.Set("User-Agent", "honeybee/0.1.0")
|
||||
conf := GetDefaultConnectionConfig()
|
||||
|
||||
assert.Equal(t, conf, &ConnectionConfig{
|
||||
CloseHandler: nil,
|
||||
RequestHeader: header,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
PingInterval: 20 * time.Second,
|
||||
IncomingBufferSize: 100,
|
||||
|
||||
@@ -169,7 +169,7 @@ func (c *Connection) Connect(ctx context.Context) error {
|
||||
|
||||
retryMgr := NewRetryManager(c.config.Retry)
|
||||
socket, _, err := AcquireSocket(
|
||||
ctx, retryMgr, c.dialer, c.url.String(), c.logger)
|
||||
ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger)
|
||||
|
||||
if err != nil {
|
||||
c.state = StateDisconnected
|
||||
|
||||
@@ -404,6 +404,26 @@ func TestConnect(t *testing.T) {
|
||||
|
||||
conn.Close()
|
||||
})
|
||||
|
||||
t.Run("passes headers when configured", func(t *testing.T) {
|
||||
header := http.Header{"X-Custom": []string{"val"}}
|
||||
conf, _ := NewConnectionConfig(WithRequestHeader(header))
|
||||
conn, _ := NewConnection("ws://test", conf, nil)
|
||||
|
||||
dialCalled := false
|
||||
conn.dialer = &honeybeetest.MockDialer{
|
||||
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
|
||||
assert.Equal(t, "val", h.Get("X-Custom"))
|
||||
dialCalled = true
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
err := conn.Connect(context.Background())
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, dialCalled)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectContextCancellation(t *testing.T) {
|
||||
|
||||
+2
-1
@@ -45,6 +45,7 @@ func AcquireSocket(
|
||||
retryMgr *RetryManager,
|
||||
dialer types.Dialer,
|
||||
url string,
|
||||
header http.Header,
|
||||
logger *slog.Logger,
|
||||
) (types.Socket, *http.Response, error) {
|
||||
select {
|
||||
@@ -68,7 +69,7 @@ func AcquireSocket(
|
||||
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
|
||||
}
|
||||
|
||||
socket, resp, err := dialer.DialContext(ctx, url, nil)
|
||||
socket, resp, err := dialer.DialContext(ctx, url, header)
|
||||
if err == nil {
|
||||
if logger != nil {
|
||||
logger.Debug("dial successful", "attempt", retryMgr.RetryCount()+1)
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestAcquireSocket(t *testing.T) {
|
||||
})
|
||||
|
||||
socket, _, err := AcquireSocket(
|
||||
context.Background(), retryMgr, mockDialer, "ws://test", nil)
|
||||
context.Background(), retryMgr, mockDialer, "ws://test", nil, nil)
|
||||
|
||||
assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
|
||||
if tc.wantErr {
|
||||
@@ -141,7 +141,7 @@ func TestAcquireSocketGuards(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
socket, resp, err := AcquireSocket(
|
||||
context.Background(), tc.retryMgr, tc.dialer, tc.url, nil)
|
||||
context.Background(), tc.retryMgr, tc.dialer, tc.url, nil, nil)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, tc.wantErr)
|
||||
@@ -168,7 +168,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
|
||||
cancel()
|
||||
|
||||
retryMgr := NewRetryManager(GetDefaultRetryConfig())
|
||||
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
|
||||
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
|
||||
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
assert.False(t, dialCalled.Load())
|
||||
@@ -195,7 +195,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
|
||||
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
@@ -233,7 +233,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
|
||||
retryMgr := NewRetryManager(GetDefaultRetryConfig())
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
|
||||
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
@@ -250,3 +250,22 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestAcquireSocketPassesHeaders(t *testing.T) {
|
||||
header := http.Header{"User-Agent": []string{"test-agent"}}
|
||||
dialCalled := false
|
||||
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
|
||||
assert.Equal(t, "test-agent", h.Get("User-Agent"))
|
||||
dialCalled = true
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
retryMgr := NewRetryManager(&RetryConfig{MaxRetries: 0})
|
||||
_, _, err := AcquireSocket(context.Background(), retryMgr, mockDialer, "ws://test", header, nil)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, dialCalled)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user