transport: copy-on-intake in NewConnection/NewPool; add ConnectionConfig.Clone; remove SetDialer; dialer via config

This commit is contained in:
Jay
2026-05-26 14:46:10 -04:00
parent 695389798e
commit d4da16f82a
11 changed files with 180 additions and 150 deletions
+7
View File
@@ -67,6 +67,13 @@ func GetDefaultConnectionConfig() *ConnectionConfig {
}
}
func (c ConnectionConfig) Clone() ConnectionConfig {
if c.RequestHeader != nil {
c.RequestHeader = c.RequestHeader.Clone()
}
return c
}
func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOption) error {
for _, option := range options {
if err := option(config); err != nil {
+20
View File
@@ -255,6 +255,26 @@ func TestValidateConnectionConfig(t *testing.T) {
}
}
func TestConnectionConfigClone(t *testing.T) {
header := http.Header{}
header.Set("X-Test", "val")
orig := ConnectionConfig{
RequestHeader: header,
WriteTimeout: 5 * time.Second,
Retry: RetryConfig{Disabled: true},
}
cloned := orig.Clone()
// values match
assert.Equal(t, orig.WriteTimeout, cloned.WriteTimeout)
assert.Equal(t, "val", cloned.RequestHeader.Get("X-Test"))
// header is a distinct copy
cloned.RequestHeader.Set("X-Test", "mutated")
assert.Equal(t, "val", orig.RequestHeader.Get("X-Test"))
}
func TestWithConnectionDialer(t *testing.T) {
mock := &honeybeetest.MockDialer{}
conf, err := NewConnectionConfig(WithConnectionDialer(mock))
+17 -12
View File
@@ -65,7 +65,7 @@ type Connection struct {
url *url.URL
dialer types.Dialer
socket types.Socket
config *ConnectionConfig
config ConnectionConfig
logger *slog.Logger
incoming chan []byte
@@ -107,14 +107,20 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
ctx = component.MustExtend(ctx, "connection")
}
// Clone config to ensure full ownership of all fields.
cc := config.Clone()
if cc.Dialer == nil {
cc.Dialer = NewDialer()
}
conn := &Connection{
url: url,
dialer: NewDialer(),
dialer: cc.Dialer,
socket: nil,
config: config,
incoming: make(chan []byte, config.IncomingBufferSize),
config: cc,
incoming: make(chan []byte, cc.IncomingBufferSize),
heartbeat: make(chan struct{}, 1),
errors: make(chan error, config.ErrorsBufferSize),
errors: make(chan error, cc.ErrorsBufferSize),
incomingCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{},
heartbeatCount: &atomic.Uint64{},
@@ -151,14 +157,17 @@ func NewConnectionFromSocket(
ctx = component.MustExtend(ctx, "connection")
}
// Clone config to ensure full ownership of all fields.
cc := config.Clone()
conn := &Connection{
url: nil,
dialer: nil,
socket: socket,
config: config,
incoming: make(chan []byte, config.IncomingBufferSize),
config: cc,
incoming: make(chan []byte, cc.IncomingBufferSize),
heartbeat: make(chan struct{}, 1),
errors: make(chan error, config.ErrorsBufferSize),
errors: make(chan error, cc.ErrorsBufferSize),
incomingCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{},
heartbeatCount: &atomic.Uint64{},
@@ -311,10 +320,6 @@ func (c *Connection) Stats() ConnectionStats {
}
}
func (c *Connection) SetDialer(d types.Dialer) {
c.dialer = d
}
// ---------------------------/
// Reader loop
// -------------------------/
+62 -54
View File
@@ -121,9 +121,13 @@ func TestNewConnection(t *testing.T) {
// Verify default config is used if nil is passed
if tc.config == nil {
assert.Equal(t, GetDefaultConnectionConfig(), conn.config)
expected := *GetDefaultConnectionConfig()
expected.Dialer = conn.config.Dialer // dialer resolved at construction
assert.Equal(t, expected, conn.config)
} else {
assert.Equal(t, tc.config, conn.config)
expected := *tc.config
expected.Dialer = conn.config.Dialer
assert.Equal(t, expected, conn.config)
}
})
}
@@ -220,11 +224,19 @@ func TestNewConnectionFromSocket(t *testing.T) {
assert.Equal(t, StateConnected, conn.state)
assert.False(t, conn.closed)
// Verify default config is used if nil is passed
// Verify default config is used if nil is passed.
// CloseHandler is a func; exclude it from the struct comparison
// (identity is verified separately via closeHandlerSet).
gotCfg := conn.config
gotCfg.CloseHandler = nil
if tc.config == nil {
assert.Equal(t, GetDefaultConnectionConfig(), conn.config)
expected := *GetDefaultConnectionConfig()
expected.CloseHandler = nil
assert.Equal(t, expected, gotCfg)
} else {
assert.Equal(t, tc.config, conn.config)
expected := *tc.config
expected.CloseHandler = nil
assert.Equal(t, expected, gotCfg)
}
// Verify close handler was set if provided
@@ -261,9 +273,6 @@ func TestConnect(t *testing.T) {
})
t.Run("connect succeeds and starts goroutines", func(t *testing.T) {
conn, err := NewConnection(context.Background(), "ws://test", nil, nil)
assert.NoError(t, err)
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
mockSocket := honeybeetest.NewMockSocket()
@@ -277,7 +286,9 @@ func TestConnect(t *testing.T) {
return mockSocket, nil, nil
},
}
conn.dialer = mockDialer
conn, err := NewConnection(context.Background(), "ws://test",
&ConnectionConfig{Retry: RetryConfig{Disabled: true}, Dialer: mockDialer}, nil)
assert.NoError(t, err)
err = conn.Connect(context.Background())
assert.NoError(t, err)
@@ -299,17 +310,6 @@ func TestConnect(t *testing.T) {
})
t.Run("connect retries on dial failure", func(t *testing.T) {
config := &ConnectionConfig{
Retry: RetryConfig{
MaxRetries: 2,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
attemptCount := 0
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
@@ -320,7 +320,17 @@ func TestConnect(t *testing.T) {
return honeybeetest.NewMockSocket(), nil, nil
},
}
conn.dialer = mockDialer
config := &ConnectionConfig{
Retry: RetryConfig{
MaxRetries: 2,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
Dialer: mockDialer,
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
err = conn.Connect(context.Background())
assert.NoError(t, err)
@@ -331,6 +341,11 @@ func TestConnect(t *testing.T) {
})
t.Run("connect fails after max retries", func(t *testing.T) {
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return nil, nil, fmt.Errorf("dial failed")
},
}
config := &ConnectionConfig{
Retry: RetryConfig{
MaxRetries: 2,
@@ -338,17 +353,11 @@ func TestConnect(t *testing.T) {
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
Dialer: mockDialer,
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return nil, nil, fmt.Errorf("dial failed")
},
}
conn.dialer = mockDialer
err = conn.Connect(context.Background())
assert.Error(t, err)
assert.ErrorContains(t, err, "dial failed")
@@ -356,18 +365,20 @@ func TestConnect(t *testing.T) {
})
t.Run("state transitions during connect", func(t *testing.T) {
conn, err := NewConnection(context.Background(), "ws://test", nil, nil)
assert.NoError(t, err)
assert.Equal(t, StateDisconnected, conn.State())
stateDuringDial := StateDisconnected
// conn captured after construction; closure safe because dialer runs during Connect
var conn *Connection
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
stateDuringDial = conn.state
return honeybeetest.NewMockSocket(), nil, nil
},
}
conn.dialer = mockDialer
var err error
conn, err = NewConnection(context.Background(), "ws://test",
&ConnectionConfig{Retry: RetryConfig{Disabled: true}, Dialer: mockDialer}, nil)
assert.NoError(t, err)
assert.Equal(t, StateDisconnected, conn.State())
conn.Connect(context.Background())
@@ -379,26 +390,24 @@ func TestConnect(t *testing.T) {
t.Run("close handler configured when provided", func(t *testing.T) {
handlerSet := false
config := &ConnectionConfig{
CloseHandler: func(code int, text string) error {
return nil
},
Retry: RetryConfig{Disabled: true},
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
mockSocket := honeybeetest.NewMockSocket()
mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) {
handlerSet = true
}
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
conn.dialer = mockDialer
config := &ConnectionConfig{
CloseHandler: func(code int, text string) error {
return nil
},
Retry: RetryConfig{Disabled: true},
Dialer: mockDialer,
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
conn.Connect(context.Background())
@@ -409,17 +418,16 @@ func TestConnect(t *testing.T) {
t.Run("passes headers when configured", func(t *testing.T) {
header := http.Header{"X-Custom": []string{"val"}}
conf, _ := NewConnectionConfig(WithRequestHeader(header))
conn, _ := NewConnection(context.Background(), "ws://test", conf, nil)
dialCalled := false
conn.dialer = &honeybeetest.MockDialer{
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
assert.Equal(t, "val", h.Get("X-Custom"))
dialCalled = true
return honeybeetest.NewMockSocket(), nil, nil
},
}
conf, _ := NewConnectionConfig(WithRequestHeader(header), WithConnectionDialer(mockDialer))
conn, _ := NewConnection(context.Background(), "ws://test", conf, nil)
err := conn.Connect(context.Background())
@@ -438,18 +446,18 @@ func TestConnectContextCancellation(t *testing.T) {
JitterFactor: 0.0,
},
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
dialCount := atomic.Int32{}
ctx, cancel := context.WithCancel(context.Background())
conn.dialer = &honeybeetest.MockDialer{
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
dialCount.Add(1)
return nil, nil, fmt.Errorf("dial failed")
},
}
config.Dialer = mockDialer
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
go func() {
+25 -29
View File
@@ -28,16 +28,15 @@ func TestConnectLogging(t *testing.T) {
t.Run("success", func(t *testing.T) {
mockHandler := honeybeetest.NewMockSlogHandler()
conn, err := NewConnection(context.Background(), "ws://test", nil, mockHandler)
assert.NoError(t, err)
mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
conn.dialer = mockDialer
conn, err := NewConnection(context.Background(), "ws://test",
&ConnectionConfig{Retry: RetryConfig{Disabled: true}, Dialer: mockDialer}, mockHandler)
assert.NoError(t, err)
err = conn.Connect(context.Background())
assert.NoError(t, err)
@@ -58,6 +57,12 @@ func TestConnectLogging(t *testing.T) {
t.Run("max retries failure", func(t *testing.T) {
mockHandler := honeybeetest.NewMockSlogHandler()
dialErr := fmt.Errorf("dial error")
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return nil, nil, dialErr
},
}
config := &ConnectionConfig{
Retry: RetryConfig{
MaxRetries: 2,
@@ -65,19 +70,12 @@ func TestConnectLogging(t *testing.T) {
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
Dialer: mockDialer,
}
conn, err := NewConnection(context.Background(), "ws://test", config, mockHandler)
assert.NoError(t, err)
dialErr := fmt.Errorf("dial error")
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return nil, nil, dialErr
},
}
conn.dialer = mockDialer
err = conn.Connect(context.Background())
assert.Error(t, err)
@@ -100,18 +98,6 @@ func TestConnectLogging(t *testing.T) {
t.Run("success after retry", func(t *testing.T) {
mockHandler := honeybeetest.NewMockSlogHandler()
config := &ConnectionConfig{
Retry: RetryConfig{
MaxRetries: 3,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
}
conn, err := NewConnection(context.Background(), "ws://test", config, mockHandler)
assert.NoError(t, err)
attemptCount := 0
dialErr := fmt.Errorf("dial error")
mockDialer := &honeybeetest.MockDialer{
@@ -123,7 +109,18 @@ func TestConnectLogging(t *testing.T) {
return honeybeetest.NewMockSocket(), nil, nil
},
}
conn.dialer = mockDialer
config := &ConnectionConfig{
Retry: RetryConfig{
MaxRetries: 3,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
Dialer: mockDialer,
}
conn, err := NewConnection(context.Background(), "ws://test", config, mockHandler)
assert.NoError(t, err)
err = conn.Connect(context.Background())
assert.NoError(t, err)
@@ -341,16 +338,15 @@ func TestLoggingDisabled(t *testing.T) {
t.Run("nil logger produces no logs", func(t *testing.T) {
mockHandler := honeybeetest.NewMockSlogHandler()
conn, err := NewConnection(context.Background(), "ws://test", nil, nil)
assert.NoError(t, err)
mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
conn.dialer = mockDialer
conn, err := NewConnection(context.Background(), "ws://test",
&ConnectionConfig{Retry: RetryConfig{Disabled: true}, Dialer: mockDialer}, nil)
assert.NoError(t, err)
err = conn.Connect(context.Background())
assert.NoError(t, err)