introduced http header to connection
This commit is contained in:
@@ -2,6 +2,7 @@ package transport
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -9,6 +10,7 @@ type CloseHandler func(code int, text string) error
|
|||||||
|
|
||||||
type ConnectionConfig struct {
|
type ConnectionConfig struct {
|
||||||
CloseHandler CloseHandler
|
CloseHandler CloseHandler
|
||||||
|
RequestHeader http.Header
|
||||||
WriteTimeout time.Duration
|
WriteTimeout time.Duration
|
||||||
PingInterval time.Duration
|
PingInterval time.Duration
|
||||||
IncomingBufferSize int
|
IncomingBufferSize int
|
||||||
@@ -39,8 +41,11 @@ func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetDefaultConnectionConfig() *ConnectionConfig {
|
func GetDefaultConnectionConfig() *ConnectionConfig {
|
||||||
|
header := http.Header{}
|
||||||
|
header.Set("User-Agent", "honeybee/0.1.0")
|
||||||
return &ConnectionConfig{
|
return &ConnectionConfig{
|
||||||
CloseHandler: nil,
|
CloseHandler: nil,
|
||||||
|
RequestHeader: header,
|
||||||
WriteTimeout: 30 * time.Second,
|
WriteTimeout: 30 * time.Second,
|
||||||
PingInterval: 20 * time.Second,
|
PingInterval: 20 * time.Second,
|
||||||
IncomingBufferSize: 100,
|
IncomingBufferSize: 100,
|
||||||
@@ -160,6 +165,13 @@ func WithCloseHandler(handler CloseHandler) ConnectionOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithRequestHeader(header http.Header) ConnectionOption {
|
||||||
|
return func(c *ConnectionConfig) error {
|
||||||
|
c.RequestHeader = header.Clone()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// When WriteTimeout is set to zero, read timeouts are disabled.
|
// When WriteTimeout is set to zero, read timeouts are disabled.
|
||||||
func WithWriteTimeout(value time.Duration) ConnectionOption {
|
func WithWriteTimeout(value time.Duration) ConnectionOption {
|
||||||
return func(c *ConnectionConfig) error {
|
return func(c *ConnectionConfig) error {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package transport
|
|||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -10,20 +11,9 @@ import (
|
|||||||
// Connection Config Tests
|
// Connection Config Tests
|
||||||
|
|
||||||
func TestNewConnectionConfig(t *testing.T) {
|
func TestNewConnectionConfig(t *testing.T) {
|
||||||
conf, err := NewConnectionConfig()
|
_, err := NewConnectionConfig()
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, conf, &ConnectionConfig{
|
|
||||||
CloseHandler: nil,
|
|
||||||
WriteTimeout: 30 * time.Second,
|
|
||||||
PingInterval: 20 * time.Second,
|
|
||||||
IncomingBufferSize: 100,
|
|
||||||
ErrorsBufferSize: 10,
|
|
||||||
LoggingEnabled: true,
|
|
||||||
LogLevel: nil,
|
|
||||||
Retry: GetDefaultRetryConfig(),
|
|
||||||
})
|
|
||||||
|
|
||||||
// errors propagate
|
// errors propagate
|
||||||
_, err = NewConnectionConfig(WithRetryMaxRetries(-1))
|
_, err = NewConnectionConfig(WithRetryMaxRetries(-1))
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -35,10 +25,13 @@ func TestNewConnectionConfig(t *testing.T) {
|
|||||||
// Default Tests
|
// Default Tests
|
||||||
|
|
||||||
func TestDefaultConnectionConfig(t *testing.T) {
|
func TestDefaultConnectionConfig(t *testing.T) {
|
||||||
|
header := http.Header{}
|
||||||
|
header.Set("User-Agent", "honeybee/0.1.0")
|
||||||
conf := GetDefaultConnectionConfig()
|
conf := GetDefaultConnectionConfig()
|
||||||
|
|
||||||
assert.Equal(t, conf, &ConnectionConfig{
|
assert.Equal(t, conf, &ConnectionConfig{
|
||||||
CloseHandler: nil,
|
CloseHandler: nil,
|
||||||
|
RequestHeader: header,
|
||||||
WriteTimeout: 30 * time.Second,
|
WriteTimeout: 30 * time.Second,
|
||||||
PingInterval: 20 * time.Second,
|
PingInterval: 20 * time.Second,
|
||||||
IncomingBufferSize: 100,
|
IncomingBufferSize: 100,
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ func (c *Connection) Connect(ctx context.Context) error {
|
|||||||
|
|
||||||
retryMgr := NewRetryManager(c.config.Retry)
|
retryMgr := NewRetryManager(c.config.Retry)
|
||||||
socket, _, err := AcquireSocket(
|
socket, _, err := AcquireSocket(
|
||||||
ctx, retryMgr, c.dialer, c.url.String(), c.logger)
|
ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.state = StateDisconnected
|
c.state = StateDisconnected
|
||||||
|
|||||||
@@ -404,6 +404,26 @@ func TestConnect(t *testing.T) {
|
|||||||
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("passes headers when configured", func(t *testing.T) {
|
||||||
|
header := http.Header{"X-Custom": []string{"val"}}
|
||||||
|
conf, _ := NewConnectionConfig(WithRequestHeader(header))
|
||||||
|
conn, _ := NewConnection("ws://test", conf, nil)
|
||||||
|
|
||||||
|
dialCalled := false
|
||||||
|
conn.dialer = &honeybeetest.MockDialer{
|
||||||
|
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
|
||||||
|
assert.Equal(t, "val", h.Get("X-Custom"))
|
||||||
|
dialCalled = true
|
||||||
|
return honeybeetest.NewMockSocket(), nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := conn.Connect(context.Background())
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, dialCalled)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectContextCancellation(t *testing.T) {
|
func TestConnectContextCancellation(t *testing.T) {
|
||||||
|
|||||||
+2
-1
@@ -45,6 +45,7 @@ func AcquireSocket(
|
|||||||
retryMgr *RetryManager,
|
retryMgr *RetryManager,
|
||||||
dialer types.Dialer,
|
dialer types.Dialer,
|
||||||
url string,
|
url string,
|
||||||
|
header http.Header,
|
||||||
logger *slog.Logger,
|
logger *slog.Logger,
|
||||||
) (types.Socket, *http.Response, error) {
|
) (types.Socket, *http.Response, error) {
|
||||||
select {
|
select {
|
||||||
@@ -68,7 +69,7 @@ func AcquireSocket(
|
|||||||
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
|
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
socket, resp, err := dialer.DialContext(ctx, url, nil)
|
socket, resp, err := dialer.DialContext(ctx, url, header)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
logger.Debug("dial successful", "attempt", retryMgr.RetryCount()+1)
|
logger.Debug("dial successful", "attempt", retryMgr.RetryCount()+1)
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func TestAcquireSocket(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
socket, _, err := AcquireSocket(
|
socket, _, err := AcquireSocket(
|
||||||
context.Background(), retryMgr, mockDialer, "ws://test", nil)
|
context.Background(), retryMgr, mockDialer, "ws://test", nil, nil)
|
||||||
|
|
||||||
assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
|
assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
@@ -141,7 +141,7 @@ func TestAcquireSocketGuards(t *testing.T) {
|
|||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
socket, resp, err := AcquireSocket(
|
socket, resp, err := AcquireSocket(
|
||||||
context.Background(), tc.retryMgr, tc.dialer, tc.url, nil)
|
context.Background(), tc.retryMgr, tc.dialer, tc.url, nil, nil)
|
||||||
|
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorContains(t, err, tc.wantErr)
|
assert.ErrorContains(t, err, tc.wantErr)
|
||||||
@@ -168,7 +168,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
|
|||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
retryMgr := NewRetryManager(GetDefaultRetryConfig())
|
retryMgr := NewRetryManager(GetDefaultRetryConfig())
|
||||||
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
|
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
|
||||||
|
|
||||||
assert.ErrorIs(t, err, context.Canceled)
|
assert.ErrorIs(t, err, context.Canceled)
|
||||||
assert.False(t, dialCalled.Load())
|
assert.False(t, dialCalled.Load())
|
||||||
@@ -195,7 +195,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
|
|||||||
|
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
|
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
|
||||||
done <- err
|
done <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -233,7 +233,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
|
|||||||
retryMgr := NewRetryManager(GetDefaultRetryConfig())
|
retryMgr := NewRetryManager(GetDefaultRetryConfig())
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
|
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
|
||||||
done <- err
|
done <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -250,3 +250,22 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
|
|||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAcquireSocketPassesHeaders(t *testing.T) {
|
||||||
|
header := http.Header{"User-Agent": []string{"test-agent"}}
|
||||||
|
dialCalled := false
|
||||||
|
|
||||||
|
mockDialer := &honeybeetest.MockDialer{
|
||||||
|
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
|
||||||
|
assert.Equal(t, "test-agent", h.Get("User-Agent"))
|
||||||
|
dialCalled = true
|
||||||
|
return honeybeetest.NewMockSocket(), nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
retryMgr := NewRetryManager(&RetryConfig{MaxRetries: 0})
|
||||||
|
_, _, err := AcquireSocket(context.Background(), retryMgr, mockDialer, "ws://test", header, nil)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, dialCalled)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user