10 Commits

19 changed files with 449 additions and 311 deletions
+16 -6
View File
@@ -17,7 +17,7 @@ Logging can be controlled independently at the pool, worker, and connection leve
| Connection | `ErrorsBufferSize` | 10 | — | Must be positive | | Connection | `ErrorsBufferSize` | 10 | — | Must be positive |
| Connection | `LoggingEnabled` | true | `false` | | | Connection | `LoggingEnabled` | true | `false` | |
| Connection | `LogLevel` | nil | — | nil defers to handler's own filter | | Connection | `LogLevel` | nil | — | nil defers to handler's own filter |
| Retry | enabled | yes | `WithoutRetry()` | Governs `Connect()` only | | Retry | enabled | yes | `WithRetryDisabled()` | Governs `Connect()` only |
| Retry | `MaxRetries` | 0 | — | 0 means infinite | | Retry | `MaxRetries` | 0 | — | 0 means infinite |
| Retry | `InitialDelay` | 1s | — | Must be positive | | Retry | `InitialDelay` | 1s | — | Must be positive |
| Retry | `MaxDelay` | 60s | — | Must be ≥ InitialDelay | | Retry | `MaxDelay` | 60s | — | Must be ≥ InitialDelay |
@@ -61,11 +61,16 @@ Sets the capacity of the channel that buffers inbound messages between the reade
**`WithErrorsBufferSize(int)`** **`WithErrorsBufferSize(int)`**
Sets the capacity of the channel that carries connection-level errors to the consumer. Must be at least 1. Sets the capacity of the channel that carries connection-level errors to the consumer. Must be at least 1.
### Dialer
**`WithConnectionDialer(types.Dialer)`**
Overrides the dialer used to establish the WebSocket connection. When not set, the connection uses the default dialer. Useful in tests or when routing connections through a custom transport.
### Retry ### Retry
The retry policy governs the `Connect()` call only. It does not affect worker reconnection, which is controlled by `ReconnectDelay` on the worker config. The retry policy governs the `Connect()` call only. It does not affect worker reconnection, which is controlled by `ReconnectDelay` on the worker config.
**`WithoutRetry()`** **`WithRetryDisabled()`**
Disables retry entirely. `Connect()` returns on the first dial failure. Disables retry entirely. `Connect()` returns on the first dial failure.
**`WithRetryMaxRetries(int)`** **`WithRetryMaxRetries(int)`**
@@ -124,13 +129,18 @@ Enables or disables worker-level logging.
**`honeybee.WithWorkerLogLevel(slog.Level)`** **`honeybee.WithWorkerLogLevel(slog.Level)`**
Overrides the minimum log level for worker-scoped records only. Overrides the minimum log level for worker-scoped records only.
### Per-connection
**`honeybee.WithDialer(types.Dialer)`**
Overrides the dialer for a single `Connect` call. Passed as a variadic option: `pool.Connect(id, honeybee.WithDialer(d))`. When provided, it takes precedence over the dialer resolved from `ConnectionConfig`. Existing callers that pass no options are unaffected.
### Wiring ### Wiring
**`honeybee.WithConnectionConfig(*transport.ConnectionConfig)`** **`honeybee.WithConnectionConfig(transport.ConnectionConfig)`**
Supplies a connection config used when dialing each peer. Supplies a connection config used when dialing each peer. Accepted by value; the pool stores its own copy.
**`honeybee.WithWorkerConfig(*honeybee.WorkerConfig)`** **`honeybee.WithWorkerConfig(honeybee.WorkerConfig)`**
Supplies a worker config applied to every worker the pool creates. Supplies a worker config applied to every worker the pool creates. Accepted by value; the pool stores its own copy.
**`honeybee.WithWorkerFactory(honeybee.WorkerFactory)`** **`honeybee.WithWorkerFactory(honeybee.WorkerFactory)`**
Replaces the default worker constructor. See [EXTEND.md](EXTEND.md) for the factory contract. Replaces the default worker constructor. See [EXTEND.md](EXTEND.md) for the factory contract.
+1 -2
View File
@@ -35,8 +35,7 @@ type PoolPlugin struct {
Inbox chan<- honeybee.InboxMessage Inbox chan<- honeybee.InboxMessage
Events chan<- honeybee.PoolEvent Events chan<- honeybee.PoolEvent
InboxCounter *atomic.Uint64 InboxCounter *atomic.Uint64
Dialer honeybee.Dialer ConnectionConfig transport.ConnectionConfig
ConnectionConfig *transport.ConnectionConfig
Handler slog.Handler Handler slog.Handler
} }
``` ```
+1 -1
View File
@@ -240,7 +240,7 @@ Connections send periodic WebSocket ping frames and listen for the corresponding
Pong-derived heartbeats reset the keepalive timer alongside data messages and sends. A peer that sends no data but responds to pings will not be disconnected and reconnected by the keepalive mechanism. Pong-derived heartbeats reset the keepalive timer alongside data messages and sends. A peer that sends no data but responds to pings will not be disconnected and reconnected by the keepalive mechanism.
The ping interval is configured via `transport.WithPingInterval` on the `transport.ConnectionConfig`. Import `git.wisehodl.dev/jay/go-honeybee/transport` to construct a `ConnectionConfig`, then pass it to the pool via `honeybee.WithConnectionConfig`, or supply it directly to `NewConnection` and `NewConnectionFromSocket`. The default is 20 seconds. Set to zero to disable pings entirely, in which case only data messages and outbound sends generate heartbeats. The ping interval is configured via `transport.WithPingInterval` on the `transport.ConnectionConfig`. Import `git.wisehodl.dev/jay/go-honeybee/transport` to construct a `ConnectionConfig`, then pass it to the pool by value via `honeybee.WithConnectionConfig`, or supply it directly to `NewConnection` and `NewConnectionFromSocket`. The default is 20 seconds. Set to zero to disable pings entirely, in which case only data messages and outbound sends generate heartbeats.
## Statistics ## Statistics
+10 -14
View File
@@ -14,9 +14,9 @@ import (
type PoolConfig struct { type PoolConfig struct {
InboxBufferSize int InboxBufferSize int
EventsBufferSize int EventsBufferSize int
ConnectionConfig *transport.ConnectionConfig ConnectionConfig transport.ConnectionConfig
WorkerFactory WorkerFactory WorkerFactory WorkerFactory
WorkerConfig *WorkerConfig WorkerConfig WorkerConfig
} }
type PoolOption func(*PoolConfig) error type PoolOption func(*PoolConfig) error
@@ -38,9 +38,9 @@ func GetDefaultPoolConfig() *PoolConfig {
return &PoolConfig{ return &PoolConfig{
InboxBufferSize: 256, InboxBufferSize: 256,
EventsBufferSize: 10, EventsBufferSize: 10,
ConnectionConfig: nil, ConnectionConfig: *transport.GetDefaultConnectionConfig(),
WorkerFactory: nil, WorkerFactory: nil,
WorkerConfig: nil, WorkerConfig: *GetDefaultWorkerConfig(),
} }
} }
@@ -58,19 +58,15 @@ func applyPoolOptions(config *PoolConfig, options ...PoolOption) error {
func ValidatePoolConfig(config *PoolConfig) error { func ValidatePoolConfig(config *PoolConfig) error {
var err error var err error
if config.ConnectionConfig != nil { err = transport.ValidateConnectionConfig(&config.ConnectionConfig)
err = transport.ValidateConnectionConfig(config.ConnectionConfig)
if err != nil { if err != nil {
return err return err
} }
}
if config.WorkerConfig != nil { err = ValidateWorkerConfig(&config.WorkerConfig)
err = ValidateWorkerConfig(config.WorkerConfig)
if err != nil { if err != nil {
return err return err
} }
}
return nil return nil
} }
@@ -104,9 +100,9 @@ func WithEventsBufferSize(value int) PoolOption {
} }
} }
func WithConnectionConfig(cc *transport.ConnectionConfig) PoolOption { func WithConnectionConfig(cc transport.ConnectionConfig) PoolOption {
return func(c *PoolConfig) error { return func(c *PoolConfig) error {
err := transport.ValidateConnectionConfig(cc) err := transport.ValidateConnectionConfig(&cc)
if err != nil { if err != nil {
return err return err
} }
@@ -115,9 +111,9 @@ func WithConnectionConfig(cc *transport.ConnectionConfig) PoolOption {
} }
} }
func WithWorkerConfig(wc *WorkerConfig) PoolOption { func WithWorkerConfig(wc WorkerConfig) PoolOption {
return func(c *PoolConfig) error { return func(c *PoolConfig) error {
err := ValidateWorkerConfig(wc) err := ValidateWorkerConfig(&wc)
if err != nil { if err != nil {
return err return err
} }
+27 -13
View File
@@ -14,8 +14,8 @@ func TestNewPoolConfig(t *testing.T) {
assert.Equal(t, conf, &PoolConfig{ assert.Equal(t, conf, &PoolConfig{
InboxBufferSize: 256, InboxBufferSize: 256,
EventsBufferSize: 10, EventsBufferSize: 10,
ConnectionConfig: nil, ConnectionConfig: *transport.GetDefaultConnectionConfig(),
WorkerConfig: nil, WorkerConfig: *GetDefaultWorkerConfig(),
WorkerFactory: nil, WorkerFactory: nil,
}) })
} }
@@ -26,8 +26,8 @@ func TestDefaultPoolConfig(t *testing.T) {
assert.Equal(t, conf, &PoolConfig{ assert.Equal(t, conf, &PoolConfig{
InboxBufferSize: 256, InboxBufferSize: 256,
EventsBufferSize: 10, EventsBufferSize: 10,
ConnectionConfig: nil, ConnectionConfig: *transport.GetDefaultConnectionConfig(),
WorkerConfig: nil, WorkerConfig: *GetDefaultWorkerConfig(),
WorkerFactory: nil, WorkerFactory: nil,
}) })
} }
@@ -36,7 +36,9 @@ func TestApplyPoolOptions(t *testing.T) {
conf := &PoolConfig{} conf := &PoolConfig{}
err := applyPoolOptions( err := applyPoolOptions(
conf, conf,
WithConnectionConfig(&transport.ConnectionConfig{}), WithConnectionConfig(transport.ConnectionConfig{
Retry: transport.RetryConfig{Disabled: true},
}),
) )
assert.NoError(t, err) assert.NoError(t, err)
@@ -57,15 +59,21 @@ func TestWithBufferSizes(t *testing.T) {
func TestWithConnectionConfig(t *testing.T) { func TestWithConnectionConfig(t *testing.T) {
conf := &PoolConfig{} conf := &PoolConfig{}
opt := WithConnectionConfig(&transport.ConnectionConfig{WriteTimeout: 1 * time.Second}) opt := WithConnectionConfig(transport.ConnectionConfig{
WriteTimeout: 1 * time.Second,
Retry: transport.RetryConfig{Disabled: true},
})
err := applyPoolOptions(conf, opt) err := applyPoolOptions(conf, opt)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, conf.ConnectionConfig)
assert.Equal(t, 1*time.Second, conf.ConnectionConfig.WriteTimeout) assert.Equal(t, 1*time.Second, conf.ConnectionConfig.WriteTimeout)
// invalid config is rejected // invalid config is rejected
conf = &PoolConfig{} conf = &PoolConfig{}
opt = WithConnectionConfig(&transport.ConnectionConfig{WriteTimeout: -1 * time.Second}) opt = WithConnectionConfig(
transport.ConnectionConfig{
WriteTimeout: -1 * time.Second,
Retry: transport.RetryConfig{Disabled: true},
})
err = applyPoolOptions(conf, opt) err = applyPoolOptions(conf, opt)
assert.Error(t, err) assert.Error(t, err)
} }
@@ -78,8 +86,12 @@ func TestValidatePoolConfig(t *testing.T) {
wantErrText string wantErrText string
}{ }{
{ {
name: "valid empty", name: "valid empty (retry disabled)",
conf: *&PoolConfig{}, conf: PoolConfig{
ConnectionConfig: transport.ConnectionConfig{
Retry: transport.RetryConfig{Disabled: true},
},
},
}, },
{ {
name: "valid defaults", name: "valid defaults",
@@ -88,14 +100,16 @@ func TestValidatePoolConfig(t *testing.T) {
{ {
name: "valid complete", name: "valid complete",
conf: PoolConfig{ conf: PoolConfig{
ConnectionConfig: &transport.ConnectionConfig{}, ConnectionConfig: transport.ConnectionConfig{
Retry: transport.RetryConfig{Disabled: true},
},
}, },
}, },
{ {
name: "invalid connection config", name: "invalid connection config",
conf: PoolConfig{ conf: PoolConfig{
ConnectionConfig: &transport.ConnectionConfig{ ConnectionConfig: transport.ConnectionConfig{
Retry: &transport.RetryConfig{ Retry: transport.RetryConfig{
InitialDelay: 10 * time.Second, InitialDelay: 10 * time.Second,
MaxDelay: 1 * time.Second, MaxDelay: 1 * time.Second,
}, },
+45 -19
View File
@@ -56,8 +56,7 @@ type PoolPlugin struct {
Inbox chan<- types.InboxMessage Inbox chan<- types.InboxMessage
Events chan<- PoolEvent Events chan<- PoolEvent
InboxCounter *atomic.Uint64 InboxCounter *atomic.Uint64
Dialer types.Dialer ConnectionConfig transport.ConnectionConfig
ConnectionConfig *transport.ConnectionConfig
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@@ -101,7 +100,8 @@ func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler,
if config.WorkerFactory == nil { if config.WorkerFactory == nil {
config.WorkerFactory = func( config.WorkerFactory = func(
ctx context.Context, id string, handler slog.Handler) (Worker, error) { ctx context.Context, id string, handler slog.Handler) (Worker, error) {
return NewWorker(ctx, id, config.WorkerConfig, handler) wc := config.WorkerConfig
return NewWorker(ctx, id, &wc, handler)
} }
} }
@@ -117,12 +117,19 @@ func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler,
logger = slog.New(handler).With(slog.Any("component", c)) logger = slog.New(handler).With(slog.Any("component", c))
} }
var dialer types.Dialer
if config.ConnectionConfig.Dialer != nil {
dialer = config.ConnectionConfig.Dialer
} else {
dialer = transport.NewDialer()
}
return &Pool{ return &Pool{
peers: make(map[string]*Peer), peers: make(map[string]*Peer),
inbox: make(chan types.InboxMessage, config.InboxBufferSize), inbox: make(chan types.InboxMessage, config.InboxBufferSize),
events: make(chan PoolEvent, config.EventsBufferSize), events: make(chan PoolEvent, config.EventsBufferSize),
dialer: transport.NewDialer(), dialer: dialer,
config: config, config: config,
handler: handler, handler: handler,
logger: logger, logger: logger,
@@ -194,16 +201,9 @@ func (p *Pool) PeerStats(id string) (PeerStats, error) {
}, nil }, nil
} }
func (p *Pool) SetDialer(d types.Dialer) {
if d == nil {
panic("dialer cannot be nil")
}
p.dialer = d
}
func (p *Pool) Close() { func (p *Pool) Close() {
if p.logger != nil { if p.logger != nil {
p.logger.Debug("closing") p.logger.Info("closing")
} }
p.mu.Lock() p.mu.Lock()
@@ -231,9 +231,24 @@ func (p *Pool) Close() {
}() }()
} }
func (p *Pool) Connect(id string) error { // ConnectOption configures a single Connect call.
type ConnectOption func(*connectOptions)
type connectOptions struct {
dialer types.Dialer
}
// WithDialer returns a ConnectOption that overrides the pool dialer for this
// connection only.
func WithDialer(d types.Dialer) ConnectOption {
return func(o *connectOptions) {
o.dialer = d
}
}
func (p *Pool) Connect(id string, opts ...ConnectOption) error {
if p.logger != nil { if p.logger != nil {
p.logger.Debug("connecting to peer", "peer", id) p.logger.Info("connecting", "peer", id)
} }
id, err := transport.NormalizeURL(id) id, err := transport.NormalizeURL(id)
@@ -258,12 +273,23 @@ func (p *Pool) Connect(id string) error {
return err return err
} }
o := &connectOptions{}
for _, opt := range opts {
opt(o)
}
effectiveDialer := p.dialer
if o.dialer != nil {
effectiveDialer = o.dialer
}
cc := p.config.ConnectionConfig.Clone()
cc.Dialer = effectiveDialer
pool := PoolPlugin{ pool := PoolPlugin{
Inbox: p.inbox, Inbox: p.inbox,
Events: p.events, Events: p.events,
InboxCounter: p.inboxCounter, InboxCounter: p.inboxCounter,
Dialer: p.dialer, ConnectionConfig: cc,
ConnectionConfig: p.config.ConnectionConfig,
} }
p.wg.Go(func() { p.wg.Go(func() {
@@ -273,7 +299,7 @@ func (p *Pool) Connect(id string) error {
p.peers[id] = &Peer{id: id, worker: worker} p.peers[id] = &Peer{id: id, worker: worker}
if p.logger != nil { if p.logger != nil {
p.logger.Info("registered peer", "peer", id) p.logger.Debug("registered peer", "peer", id)
} }
return nil return nil
@@ -281,7 +307,7 @@ func (p *Pool) Connect(id string) error {
func (p *Pool) Remove(id string) error { func (p *Pool) Remove(id string) error {
if p.logger != nil { if p.logger != nil {
p.logger.Debug("disconnecting from peer", "peer", id) p.logger.Info("disconnecting", "peer", id)
} }
id, err := transport.NormalizeURL(id) id, err := transport.NormalizeURL(id)
@@ -305,7 +331,7 @@ func (p *Pool) Remove(id string) error {
peer.worker.Stop() peer.worker.Stop()
if p.logger != nil { if p.logger != nil {
p.logger.Info("disconnected from peer", "peer", id) p.logger.Debug("disconnected from peer", "peer", id)
} }
return nil return nil
+63 -5
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"git.wisehodl.dev/jay/go-honeybee/transport"
"git.wisehodl.dev/jay/go-honeybee/types" "git.wisehodl.dev/jay/go-honeybee/types"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -15,14 +16,20 @@ import (
func setupPool(t *testing.T) (*Pool, *honeybeetest.MockDialer) { func setupPool(t *testing.T) (*Pool, *honeybeetest.MockDialer) {
t.Helper() t.Helper()
pool, err := NewPool(context.Background(), nil, nil)
assert.NoError(t, err)
dialer := &honeybeetest.MockDialer{ dialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return honeybeetest.NewMockSocket(), nil, nil return honeybeetest.NewMockSocket(), nil, nil
}, },
} }
pool.dialer = dialer cc := *transport.GetDefaultConnectionConfig()
cc.Dialer = dialer
pool, err := NewPool(context.Background(), &PoolConfig{
InboxBufferSize: 256,
EventsBufferSize: 10,
ConnectionConfig: cc,
WorkerConfig: *GetDefaultWorkerConfig(),
}, nil)
assert.NoError(t, err)
return pool, dialer return pool, dialer
} }
@@ -83,6 +90,51 @@ func TestPoolConnect(t *testing.T) {
}) })
} }
func TestPoolConnectWithDialer(t *testing.T) {
t.Run("per-call dialer is used instead of pool dialer", func(t *testing.T) {
perCallUsed := false
perCallDialer := &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
perCallUsed = true
return honeybeetest.NewMockSocket(), nil, nil
},
}
// pool dialer should NOT be called
poolDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
t.Error("pool dialer should not be called when per-call dialer is provided")
return nil, nil, fmt.Errorf("unexpected call")
},
}
cc := *transport.GetDefaultConnectionConfig()
cc.Dialer = poolDialer
pool, err := NewPool(context.Background(), &PoolConfig{
InboxBufferSize: 256,
EventsBufferSize: 10,
ConnectionConfig: cc,
WorkerConfig: *GetDefaultWorkerConfig(),
}, nil)
assert.NoError(t, err)
err = pool.Connect("wss://test", WithDialer(perCallDialer))
assert.NoError(t, err)
honeybeetest.Eventually(t, func() bool {
select {
case e := <-pool.events:
return e.ID == "wss://test" && e.Kind == EventConnected
default:
return false
}
}, "expected connected event")
assert.True(t, perCallUsed, "per-call dialer was not used")
pool.Close()
})
}
func TestPoolClose(t *testing.T) { func TestPoolClose(t *testing.T) {
t.Run("channels close after pool close", func(t *testing.T) { t.Run("channels close after pool close", func(t *testing.T) {
pool, _ := NewPool(context.Background(), nil, nil) pool, _ := NewPool(context.Background(), nil, nil)
@@ -152,9 +204,15 @@ func TestPoolSend(t *testing.T) {
}, },
} }
pool, err := NewPool(context.Background(), nil, nil) cc := *transport.GetDefaultConnectionConfig()
cc.Dialer = mockDialer
pool, err := NewPool(context.Background(), &PoolConfig{
InboxBufferSize: 256,
EventsBufferSize: 10,
ConnectionConfig: cc,
WorkerConfig: *GetDefaultWorkerConfig(),
}, nil)
assert.NoError(t, err) assert.NoError(t, err)
pool.dialer = mockDialer
err = pool.Connect("wss://test") err = pool.Connect("wss://test")
assert.NoError(t, err) assert.NoError(t, err)
+23 -26
View File
@@ -1,6 +1,7 @@
package transport package transport
import ( import (
"git.wisehodl.dev/jay/go-honeybee/types"
"net/http" "net/http"
"time" "time"
) )
@@ -20,10 +21,12 @@ type ConnectionConfig struct {
PingInterval time.Duration PingInterval time.Duration
IncomingBufferSize int IncomingBufferSize int
ErrorsBufferSize int ErrorsBufferSize int
Retry *RetryConfig Retry RetryConfig
Dialer types.Dialer
} }
type RetryConfig struct { type RetryConfig struct {
Disabled bool
MaxRetries int MaxRetries int
InitialDelay time.Duration InitialDelay time.Duration
MaxDelay time.Duration MaxDelay time.Duration
@@ -55,19 +58,22 @@ func GetDefaultConnectionConfig() *ConnectionConfig {
PingInterval: 20 * time.Second, PingInterval: 20 * time.Second,
IncomingBufferSize: 100, IncomingBufferSize: 100,
ErrorsBufferSize: 10, ErrorsBufferSize: 10,
Retry: GetDefaultRetryConfig(), Retry: RetryConfig{
}
}
func GetDefaultRetryConfig() *RetryConfig {
return &RetryConfig{
MaxRetries: 0, // Infinite retries MaxRetries: 0, // Infinite retries
InitialDelay: 1 * time.Second, InitialDelay: 1 * time.Second,
MaxDelay: 60 * time.Second, MaxDelay: 60 * time.Second,
JitterFactor: 0.2, JitterFactor: 0.2,
},
} }
} }
func (c ConnectionConfig) Clone() ConnectionConfig {
if c.RequestHeader != nil {
c.RequestHeader = c.RequestHeader.Clone()
}
return c
}
func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOption) error { func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOption) error {
for _, option := range options { for _, option := range options {
if err := option(config); err != nil { if err := option(config); err != nil {
@@ -85,7 +91,7 @@ func ValidateConnectionConfig(config *ConnectionConfig) error {
return err return err
} }
if config.Retry != nil { if !config.Retry.Disabled {
err = validateMaxRetries(config.Retry.MaxRetries) err = validateMaxRetries(config.Retry.MaxRetries)
if err != nil { if err != nil {
return err return err
@@ -223,19 +229,22 @@ func WithErrorsBufferSize(value int) ConnectionOption {
} }
} }
func WithoutRetry() ConnectionOption { func WithConnectionDialer(d types.Dialer) ConnectionOption {
return func(c *ConnectionConfig) error { return func(c *ConnectionConfig) error {
c.Retry = nil c.Dialer = d
return nil
}
}
func WithRetryDisabled() ConnectionOption {
return func(c *ConnectionConfig) error {
c.Retry.Disabled = true
return nil return nil
} }
} }
func WithRetryMaxRetries(value int) ConnectionOption { func WithRetryMaxRetries(value int) ConnectionOption {
return func(c *ConnectionConfig) error { return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateMaxRetries(value) err := validateMaxRetries(value)
if err != nil { if err != nil {
return err return err
@@ -248,10 +257,6 @@ func WithRetryMaxRetries(value int) ConnectionOption {
func WithRetryInitialDelay(value time.Duration) ConnectionOption { func WithRetryInitialDelay(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error { return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateInitialDelay(value) err := validateInitialDelay(value)
if err != nil { if err != nil {
return err return err
@@ -264,10 +269,6 @@ func WithRetryInitialDelay(value time.Duration) ConnectionOption {
func WithRetryMaxDelay(value time.Duration) ConnectionOption { func WithRetryMaxDelay(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error { return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateMaxDelay(value) err := validateMaxDelay(value)
if err != nil { if err != nil {
return err return err
@@ -280,10 +281,6 @@ func WithRetryMaxDelay(value time.Duration) ConnectionOption {
func WithRetryJitterFactor(value float64) ConnectionOption { func WithRetryJitterFactor(value float64) ConnectionOption {
return func(c *ConnectionConfig) error { return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateJitterFactor(value) err := validateJitterFactor(value)
if err != nil { if err != nil {
return err return err
+35 -13
View File
@@ -1,6 +1,7 @@
package transport package transport
import ( import (
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"net/http" "net/http"
"testing" "testing"
@@ -35,18 +36,12 @@ func TestDefaultConnectionConfig(t *testing.T) {
PingInterval: 20 * time.Second, PingInterval: 20 * time.Second,
IncomingBufferSize: 100, IncomingBufferSize: 100,
ErrorsBufferSize: 10, ErrorsBufferSize: 10,
Retry: GetDefaultRetryConfig(), Retry: RetryConfig{
})
}
func TestDefaultRetryConnectionConfig(t *testing.T) {
conf := GetDefaultRetryConfig()
assert.Equal(t, conf, &RetryConfig{
MaxRetries: 0, MaxRetries: 0,
InitialDelay: 1 * time.Second, InitialDelay: 1 * time.Second,
MaxDelay: 60 * time.Second, MaxDelay: 60 * time.Second,
JitterFactor: 0.2, JitterFactor: 0.2,
},
}) })
} }
@@ -114,10 +109,10 @@ func TestWithWriteTimeout(t *testing.T) {
func TestWithRetry(t *testing.T) { func TestWithRetry(t *testing.T) {
t.Run("without retry", func(t *testing.T) { t.Run("without retry", func(t *testing.T) {
conf := GetDefaultConnectionConfig() conf := GetDefaultConnectionConfig()
opt := WithoutRetry() opt := WithRetryDisabled()
err := applyConnectionOptions(conf, opt) err := applyConnectionOptions(conf, opt)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, conf.Retry) assert.True(t, conf.Retry.Disabled)
}) })
t.Run("with attempts", func(t *testing.T) { t.Run("with attempts", func(t *testing.T) {
@@ -209,7 +204,7 @@ func TestValidateConnectionConfig(t *testing.T) {
}{ }{
{ {
name: "valid empty", name: "valid empty",
conf: *&ConnectionConfig{}, conf: ConnectionConfig{Retry: RetryConfig{Disabled: true}},
}, },
{ {
name: "valid defaults", name: "valid defaults",
@@ -220,7 +215,7 @@ func TestValidateConnectionConfig(t *testing.T) {
conf: ConnectionConfig{ conf: ConnectionConfig{
CloseHandler: (func(code int, text string) error { return nil }), CloseHandler: (func(code int, text string) error { return nil }),
WriteTimeout: time.Duration(30), WriteTimeout: time.Duration(30),
Retry: &RetryConfig{ Retry: RetryConfig{
MaxRetries: 0, MaxRetries: 0,
InitialDelay: 2 * time.Second, InitialDelay: 2 * time.Second,
MaxDelay: 10 * time.Second, MaxDelay: 10 * time.Second,
@@ -231,7 +226,7 @@ func TestValidateConnectionConfig(t *testing.T) {
{ {
name: "invalid - initial delay > max delay", name: "invalid - initial delay > max delay",
conf: ConnectionConfig{ conf: ConnectionConfig{
Retry: &RetryConfig{ Retry: RetryConfig{
InitialDelay: 10 * time.Second, InitialDelay: 10 * time.Second,
MaxDelay: 1 * time.Second, MaxDelay: 1 * time.Second,
}, },
@@ -259,3 +254,30 @@ func TestValidateConnectionConfig(t *testing.T) {
}) })
} }
} }
func TestConnectionConfigClone(t *testing.T) {
header := http.Header{}
header.Set("X-Test", "val")
orig := ConnectionConfig{
RequestHeader: header,
WriteTimeout: 5 * time.Second,
Retry: RetryConfig{Disabled: true},
}
cloned := orig.Clone()
// values match
assert.Equal(t, orig.WriteTimeout, cloned.WriteTimeout)
assert.Equal(t, "val", cloned.RequestHeader.Get("X-Test"))
// header is a distinct copy
cloned.RequestHeader.Set("X-Test", "mutated")
assert.Equal(t, "val", orig.RequestHeader.Get("X-Test"))
}
func TestWithConnectionDialer(t *testing.T) {
mock := &honeybeetest.MockDialer{}
conf, err := NewConnectionConfig(WithConnectionDialer(mock))
assert.NoError(t, err)
assert.Equal(t, mock, conf.Dialer)
}
+23 -18
View File
@@ -65,7 +65,7 @@ type Connection struct {
url *url.URL url *url.URL
dialer types.Dialer dialer types.Dialer
socket types.Socket socket types.Socket
config *ConnectionConfig config ConnectionConfig
logger *slog.Logger logger *slog.Logger
incoming chan []byte incoming chan []byte
@@ -107,14 +107,20 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
ctx = component.MustExtend(ctx, "connection") ctx = component.MustExtend(ctx, "connection")
} }
// Clone config to ensure full ownership of all fields.
cc := config.Clone()
if cc.Dialer == nil {
cc.Dialer = NewDialer()
}
conn := &Connection{ conn := &Connection{
url: url, url: url,
dialer: NewDialer(), dialer: cc.Dialer,
socket: nil, socket: nil,
config: config, config: cc,
incoming: make(chan []byte, config.IncomingBufferSize), incoming: make(chan []byte, cc.IncomingBufferSize),
heartbeat: make(chan struct{}, 1), heartbeat: make(chan struct{}, 1),
errors: make(chan error, config.ErrorsBufferSize), errors: make(chan error, cc.ErrorsBufferSize),
incomingCount: &atomic.Uint64{}, incomingCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{}, outgoingCount: &atomic.Uint64{},
heartbeatCount: &atomic.Uint64{}, heartbeatCount: &atomic.Uint64{},
@@ -151,14 +157,17 @@ func NewConnectionFromSocket(
ctx = component.MustExtend(ctx, "connection") ctx = component.MustExtend(ctx, "connection")
} }
// Clone config to ensure full ownership of all fields.
cc := config.Clone()
conn := &Connection{ conn := &Connection{
url: nil, url: nil,
dialer: nil, dialer: nil,
socket: socket, socket: socket,
config: config, config: cc,
incoming: make(chan []byte, config.IncomingBufferSize), incoming: make(chan []byte, cc.IncomingBufferSize),
heartbeat: make(chan struct{}, 1), heartbeat: make(chan struct{}, 1),
errors: make(chan error, config.ErrorsBufferSize), errors: make(chan error, cc.ErrorsBufferSize),
incomingCount: &atomic.Uint64{}, incomingCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{}, outgoingCount: &atomic.Uint64{},
heartbeatCount: &atomic.Uint64{}, heartbeatCount: &atomic.Uint64{},
@@ -219,7 +228,7 @@ func (c *Connection) Connect(ctx context.Context) error {
// socket acquisition failed // socket acquisition failed
c.state = StateDisconnected c.state = StateDisconnected
if c.logger != nil { if c.logger != nil {
c.logger.Error("connection failed", "error", err) c.logger.Warn("connection failed", "error", err)
} }
return NewConnectionError(err) return NewConnectionError(err)
} }
@@ -244,7 +253,7 @@ func (c *Connection) Connect(ctx context.Context) error {
c.state = StateConnected c.state = StateConnected
if c.logger != nil { if c.logger != nil {
c.logger.Info("connected") c.logger.Debug("connected")
} }
return nil return nil
@@ -311,10 +320,6 @@ func (c *Connection) Stats() ConnectionStats {
} }
} }
func (c *Connection) SetDialer(d types.Dialer) {
c.dialer = d
}
// ---------------------------/ // ---------------------------/
// Reader loop // Reader loop
// -------------------------/ // -------------------------/
@@ -357,7 +362,7 @@ func (c *Connection) classifyCloseError(err error) error {
switch closeErr.Code { switch closeErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway: case websocket.CloseNormalClosure, websocket.CloseGoingAway:
if c.logger != nil { if c.logger != nil {
c.logger.Info("connection closed by peer", c.logger.Debug("connection closed by peer",
"code", closeErr.Code, "code", closeErr.Code,
"text", closeErr.Text, "text", closeErr.Text,
) )
@@ -366,7 +371,7 @@ func (c *Connection) classifyCloseError(err error) error {
default: default:
if c.logger != nil { if c.logger != nil {
c.logger.Error("unexpected close", c.logger.Warn("unexpected close",
"code", closeErr.Code, "code", closeErr.Code,
"text", closeErr.Text, "text", closeErr.Text,
) )
@@ -492,7 +497,7 @@ func (c *Connection) shutdownInner() {
}) })
if c.logger != nil { if c.logger != nil {
c.logger.Info("closing") c.logger.Debug("closing")
} }
if c.socket != nil { if c.socket != nil {
@@ -518,7 +523,7 @@ func (c *Connection) shutdownCleanup() {
close(c.errors) close(c.errors)
if c.logger != nil { if c.logger != nil {
c.logger.Info("closed") c.logger.Debug("closed")
} }
}) })
} }
+3 -3
View File
@@ -102,7 +102,7 @@ func TestConnectionSend(t *testing.T) {
}) })
t.Run("write timeout disabled when zero", func(t *testing.T) { t.Run("write timeout disabled when zero", func(t *testing.T) {
config := &ConnectionConfig{WriteTimeout: 0} config := &ConnectionConfig{WriteTimeout: 0, Retry: RetryConfig{Disabled: true}}
outgoingData := make(chan honeybeetest.MockOutgoingData, 10) outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
@@ -148,7 +148,7 @@ func TestConnectionSend(t *testing.T) {
}) })
t.Run("write timeout sets deadline when positive", func(t *testing.T) { t.Run("write timeout sets deadline when positive", func(t *testing.T) {
config := &ConnectionConfig{WriteTimeout: 30 * time.Millisecond} config := &ConnectionConfig{WriteTimeout: 30 * time.Millisecond, Retry: RetryConfig{Disabled: true}}
outgoingData := make(chan honeybeetest.MockOutgoingData, 10) outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
@@ -194,7 +194,7 @@ func TestConnectionSend(t *testing.T) {
}) })
t.Run("send fails on deadline error", func(t *testing.T) { t.Run("send fails on deadline error", func(t *testing.T) {
config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond} config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond, Retry: RetryConfig{Disabled: true}}
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
+73 -63
View File
@@ -69,7 +69,7 @@ func TestNewConnection(t *testing.T) {
{ {
name: "valid url, valid config", name: "valid url, valid config",
url: "wss://relay.example.com:8080/path", url: "wss://relay.example.com:8080/path",
config: &ConnectionConfig{WriteTimeout: 30 * time.Second}, config: &ConnectionConfig{WriteTimeout: 30 * time.Second, Retry: RetryConfig{Disabled: true}},
}, },
{ {
name: "invalid url", name: "invalid url",
@@ -82,7 +82,7 @@ func TestNewConnection(t *testing.T) {
name: "invalid config", name: "invalid config",
url: "ws://example.com", url: "ws://example.com",
config: &ConnectionConfig{ config: &ConnectionConfig{
Retry: &RetryConfig{ Retry: RetryConfig{
InitialDelay: 10 * time.Second, InitialDelay: 10 * time.Second,
MaxDelay: 1 * time.Second, MaxDelay: 1 * time.Second,
}, },
@@ -121,9 +121,13 @@ func TestNewConnection(t *testing.T) {
// Verify default config is used if nil is passed // Verify default config is used if nil is passed
if tc.config == nil { if tc.config == nil {
assert.Equal(t, GetDefaultConnectionConfig(), conn.config) expected := *GetDefaultConnectionConfig()
expected.Dialer = conn.config.Dialer // dialer resolved at construction
assert.Equal(t, expected, conn.config)
} else { } else {
assert.Equal(t, tc.config, conn.config) expected := *tc.config
expected.Dialer = conn.config.Dialer
assert.Equal(t, expected, conn.config)
} }
}) })
} }
@@ -152,13 +156,13 @@ func TestNewConnectionFromSocket(t *testing.T) {
{ {
name: "valid socket with valid config", name: "valid socket with valid config",
socket: honeybeetest.NewMockSocket(), socket: honeybeetest.NewMockSocket(),
config: &ConnectionConfig{WriteTimeout: 30 * time.Second}, config: &ConnectionConfig{WriteTimeout: 30 * time.Second, Retry: RetryConfig{Disabled: true}},
}, },
{ {
name: "invalid config", name: "invalid config",
socket: honeybeetest.NewMockSocket(), socket: honeybeetest.NewMockSocket(),
config: &ConnectionConfig{ config: &ConnectionConfig{
Retry: &RetryConfig{ Retry: RetryConfig{
InitialDelay: 10 * time.Second, InitialDelay: 10 * time.Second,
MaxDelay: 1 * time.Second, MaxDelay: 1 * time.Second,
}, },
@@ -173,6 +177,7 @@ func TestNewConnectionFromSocket(t *testing.T) {
CloseHandler: func(code int, text string) error { CloseHandler: func(code int, text string) error {
return nil return nil
}, },
Retry: RetryConfig{Disabled: true},
}, },
}, },
} }
@@ -219,11 +224,19 @@ func TestNewConnectionFromSocket(t *testing.T) {
assert.Equal(t, StateConnected, conn.state) assert.Equal(t, StateConnected, conn.state)
assert.False(t, conn.closed) assert.False(t, conn.closed)
// Verify default config is used if nil is passed // Verify default config is used if nil is passed.
// CloseHandler is a func; exclude it from the struct comparison
// (identity is verified separately via closeHandlerSet).
gotCfg := conn.config
gotCfg.CloseHandler = nil
if tc.config == nil { if tc.config == nil {
assert.Equal(t, GetDefaultConnectionConfig(), conn.config) expected := *GetDefaultConnectionConfig()
expected.CloseHandler = nil
assert.Equal(t, expected, gotCfg)
} else { } else {
assert.Equal(t, tc.config, conn.config) expected := *tc.config
expected.CloseHandler = nil
assert.Equal(t, expected, gotCfg)
} }
// Verify close handler was set if provided // Verify close handler was set if provided
@@ -260,9 +273,6 @@ func TestConnect(t *testing.T) {
}) })
t.Run("connect succeeds and starts goroutines", func(t *testing.T) { t.Run("connect succeeds and starts goroutines", func(t *testing.T) {
conn, err := NewConnection(context.Background(), "ws://test", nil, nil)
assert.NoError(t, err)
outgoingData := make(chan honeybeetest.MockOutgoingData, 10) outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
@@ -276,7 +286,9 @@ func TestConnect(t *testing.T) {
return mockSocket, nil, nil return mockSocket, nil, nil
}, },
} }
conn.dialer = mockDialer conn, err := NewConnection(context.Background(), "ws://test",
&ConnectionConfig{Retry: RetryConfig{Disabled: true}, Dialer: mockDialer}, nil)
assert.NoError(t, err)
err = conn.Connect(context.Background()) err = conn.Connect(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
@@ -298,17 +310,6 @@ func TestConnect(t *testing.T) {
}) })
t.Run("connect retries on dial failure", func(t *testing.T) { t.Run("connect retries on dial failure", func(t *testing.T) {
config := &ConnectionConfig{
Retry: &RetryConfig{
MaxRetries: 2,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
attemptCount := 0 attemptCount := 0
mockDialer := &honeybeetest.MockDialer{ mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
@@ -319,7 +320,17 @@ func TestConnect(t *testing.T) {
return honeybeetest.NewMockSocket(), nil, nil return honeybeetest.NewMockSocket(), nil, nil
}, },
} }
conn.dialer = mockDialer config := &ConnectionConfig{
Retry: RetryConfig{
MaxRetries: 2,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
Dialer: mockDialer,
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
err = conn.Connect(context.Background()) err = conn.Connect(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
@@ -330,23 +341,22 @@ func TestConnect(t *testing.T) {
}) })
t.Run("connect fails after max retries", func(t *testing.T) { t.Run("connect fails after max retries", func(t *testing.T) {
config := &ConnectionConfig{
Retry: &RetryConfig{
MaxRetries: 2,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
mockDialer := &honeybeetest.MockDialer{ mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return nil, nil, fmt.Errorf("dial failed") return nil, nil, fmt.Errorf("dial failed")
}, },
} }
conn.dialer = mockDialer config := &ConnectionConfig{
Retry: RetryConfig{
MaxRetries: 2,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
Dialer: mockDialer,
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
err = conn.Connect(context.Background()) err = conn.Connect(context.Background())
assert.Error(t, err) assert.Error(t, err)
@@ -355,18 +365,20 @@ func TestConnect(t *testing.T) {
}) })
t.Run("state transitions during connect", func(t *testing.T) { t.Run("state transitions during connect", func(t *testing.T) {
conn, err := NewConnection(context.Background(), "ws://test", nil, nil)
assert.NoError(t, err)
assert.Equal(t, StateDisconnected, conn.State())
stateDuringDial := StateDisconnected stateDuringDial := StateDisconnected
// conn captured after construction; closure safe because dialer runs during Connect
var conn *Connection
mockDialer := &honeybeetest.MockDialer{ mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
stateDuringDial = conn.state stateDuringDial = conn.state
return honeybeetest.NewMockSocket(), nil, nil return honeybeetest.NewMockSocket(), nil, nil
}, },
} }
conn.dialer = mockDialer var err error
conn, err = NewConnection(context.Background(), "ws://test",
&ConnectionConfig{Retry: RetryConfig{Disabled: true}, Dialer: mockDialer}, nil)
assert.NoError(t, err)
assert.Equal(t, StateDisconnected, conn.State())
conn.Connect(context.Background()) conn.Connect(context.Background())
@@ -378,25 +390,24 @@ func TestConnect(t *testing.T) {
t.Run("close handler configured when provided", func(t *testing.T) { t.Run("close handler configured when provided", func(t *testing.T) {
handlerSet := false handlerSet := false
config := &ConnectionConfig{
CloseHandler: func(code int, text string) error {
return nil
},
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) { mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) {
handlerSet = true handlerSet = true
} }
mockDialer := &honeybeetest.MockDialer{ mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil return mockSocket, nil, nil
}, },
} }
conn.dialer = mockDialer config := &ConnectionConfig{
CloseHandler: func(code int, text string) error {
return nil
},
Retry: RetryConfig{Disabled: true},
Dialer: mockDialer,
}
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
conn.Connect(context.Background()) conn.Connect(context.Background())
@@ -407,17 +418,16 @@ func TestConnect(t *testing.T) {
t.Run("passes headers when configured", func(t *testing.T) { t.Run("passes headers when configured", func(t *testing.T) {
header := http.Header{"X-Custom": []string{"val"}} header := http.Header{"X-Custom": []string{"val"}}
conf, _ := NewConnectionConfig(WithRequestHeader(header))
conn, _ := NewConnection(context.Background(), "ws://test", conf, nil)
dialCalled := false dialCalled := false
conn.dialer = &honeybeetest.MockDialer{ mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
assert.Equal(t, "val", h.Get("X-Custom")) assert.Equal(t, "val", h.Get("X-Custom"))
dialCalled = true dialCalled = true
return honeybeetest.NewMockSocket(), nil, nil return honeybeetest.NewMockSocket(), nil, nil
}, },
} }
conf, _ := NewConnectionConfig(WithRequestHeader(header), WithConnectionDialer(mockDialer))
conn, _ := NewConnection(context.Background(), "ws://test", conf, nil)
err := conn.Connect(context.Background()) err := conn.Connect(context.Background())
@@ -429,25 +439,25 @@ func TestConnect(t *testing.T) {
func TestConnectContextCancellation(t *testing.T) { func TestConnectContextCancellation(t *testing.T) {
t.Run("context cancelled during connect returns before retries exhaust", func(t *testing.T) { t.Run("context cancelled during connect returns before retries exhaust", func(t *testing.T) {
config := &ConnectionConfig{ config := &ConnectionConfig{
Retry: &RetryConfig{ Retry: RetryConfig{
MaxRetries: 100, MaxRetries: 100,
InitialDelay: 500 * time.Millisecond, InitialDelay: 500 * time.Millisecond,
MaxDelay: 1 * time.Second, MaxDelay: 1 * time.Second,
JitterFactor: 0.0, JitterFactor: 0.0,
}, },
} }
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
dialCount := atomic.Int32{} dialCount := atomic.Int32{}
ctx, cancel := context.WithCancel(context.Background()) mockDialer := &honeybeetest.MockDialer{
conn.dialer = &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
dialCount.Add(1) dialCount.Add(1)
return nil, nil, fmt.Errorf("dial failed") return nil, nil, fmt.Errorf("dial failed")
}, },
} }
config.Dialer = mockDialer
conn, err := NewConnection(context.Background(), "ws://test", config, nil)
assert.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
+43 -47
View File
@@ -28,16 +28,15 @@ func TestConnectLogging(t *testing.T) {
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
mockHandler := honeybeetest.NewMockSlogHandler() mockHandler := honeybeetest.NewMockSlogHandler()
conn, err := NewConnection(context.Background(), "ws://test", nil, mockHandler)
assert.NoError(t, err)
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{ mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil return mockSocket, nil, nil
}, },
} }
conn.dialer = mockDialer conn, err := NewConnection(context.Background(), "ws://test",
&ConnectionConfig{Retry: RetryConfig{Disabled: true}, Dialer: mockDialer}, mockHandler)
assert.NoError(t, err)
err = conn.Connect(context.Background()) err = conn.Connect(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
@@ -49,7 +48,7 @@ func TestConnectLogging(t *testing.T) {
log(slog.LevelDebug, "connecting", map[string]any{}), log(slog.LevelDebug, "connecting", map[string]any{}),
log(slog.LevelDebug, "dialing", map[string]any{"attempt": 1}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 1}),
log(slog.LevelDebug, "dial successful", map[string]any{"attempt": 1}), log(slog.LevelDebug, "dial successful", map[string]any{"attempt": 1}),
log(slog.LevelInfo, "connected", map[string]any{}), log(slog.LevelDebug, "connected", map[string]any{}),
} }
honeybeetest.AssertLogSequence(t, records, expected) honeybeetest.AssertLogSequence(t, records, expected)
@@ -58,25 +57,24 @@ func TestConnectLogging(t *testing.T) {
t.Run("max retries failure", func(t *testing.T) { t.Run("max retries failure", func(t *testing.T) {
mockHandler := honeybeetest.NewMockSlogHandler() mockHandler := honeybeetest.NewMockSlogHandler()
config := &ConnectionConfig{
Retry: &RetryConfig{
MaxRetries: 2,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
}
conn, err := NewConnection(context.Background(), "ws://test", config, mockHandler)
assert.NoError(t, err)
dialErr := fmt.Errorf("dial error") dialErr := fmt.Errorf("dial error")
mockDialer := &honeybeetest.MockDialer{ mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return nil, nil, dialErr return nil, nil, dialErr
}, },
} }
conn.dialer = mockDialer config := &ConnectionConfig{
Retry: RetryConfig{
MaxRetries: 2,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
Dialer: mockDialer,
}
conn, err := NewConnection(context.Background(), "ws://test", config, mockHandler)
assert.NoError(t, err)
err = conn.Connect(context.Background()) err = conn.Connect(context.Background())
assert.Error(t, err) assert.Error(t, err)
@@ -90,8 +88,8 @@ func TestConnectLogging(t *testing.T) {
log(slog.LevelDebug, "dialing", map[string]any{"attempt": 2}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 2}),
log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}), log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}),
log(slog.LevelDebug, "dialing", map[string]any{"attempt": 3}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 3}),
log(slog.LevelError, "dial failed, max retries reached", map[string]any{"attempt": 3, "error": dialErr}), log(slog.LevelDebug, "dial failed, max retries reached", map[string]any{"attempt": 3, "error": dialErr}),
log(slog.LevelError, "connection failed", map[string]any{"error": dialErr}), log(slog.LevelWarn, "connection failed", map[string]any{"error": dialErr}),
} }
honeybeetest.AssertLogSequence(t, records, expected) honeybeetest.AssertLogSequence(t, records, expected)
@@ -100,18 +98,6 @@ func TestConnectLogging(t *testing.T) {
t.Run("success after retry", func(t *testing.T) { t.Run("success after retry", func(t *testing.T) {
mockHandler := honeybeetest.NewMockSlogHandler() mockHandler := honeybeetest.NewMockSlogHandler()
config := &ConnectionConfig{
Retry: &RetryConfig{
MaxRetries: 3,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
}
conn, err := NewConnection(context.Background(), "ws://test", config, mockHandler)
assert.NoError(t, err)
attemptCount := 0 attemptCount := 0
dialErr := fmt.Errorf("dial error") dialErr := fmt.Errorf("dial error")
mockDialer := &honeybeetest.MockDialer{ mockDialer := &honeybeetest.MockDialer{
@@ -123,7 +109,18 @@ func TestConnectLogging(t *testing.T) {
return honeybeetest.NewMockSocket(), nil, nil return honeybeetest.NewMockSocket(), nil, nil
}, },
} }
conn.dialer = mockDialer config := &ConnectionConfig{
Retry: RetryConfig{
MaxRetries: 3,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
JitterFactor: 0.0,
},
Dialer: mockDialer,
}
conn, err := NewConnection(context.Background(), "ws://test", config, mockHandler)
assert.NoError(t, err)
err = conn.Connect(context.Background()) err = conn.Connect(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
@@ -139,7 +136,7 @@ func TestConnectLogging(t *testing.T) {
log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}), log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}),
log(slog.LevelDebug, "dialing", map[string]any{"attempt": 3}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 3}),
log(slog.LevelDebug, "dial successful", map[string]any{"attempt": 3}), log(slog.LevelDebug, "dial successful", map[string]any{"attempt": 3}),
log(slog.LevelInfo, "connected", map[string]any{}), log(slog.LevelDebug, "connected", map[string]any{}),
} }
honeybeetest.AssertLogSequence(t, records, expected) honeybeetest.AssertLogSequence(t, records, expected)
@@ -158,14 +155,14 @@ func TestCloseLogging(t *testing.T) {
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
return honeybeetest.FindLogRecord( return honeybeetest.FindLogRecord(
mockHandler.GetRecords(), slog.LevelInfo, "closed") != nil mockHandler.GetRecords(), slog.LevelDebug, "closed") != nil
}, "expected log") }, "expected log")
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
expected := []honeybeetest.ExpectedLog{ expected := []honeybeetest.ExpectedLog{
log(slog.LevelInfo, "closing", map[string]any{}), log(slog.LevelDebug, "closing", map[string]any{}),
log(slog.LevelInfo, "closed", map[string]any{}), log(slog.LevelDebug, "closed", map[string]any{}),
} }
honeybeetest.AssertLogSequence(t, records, expected) honeybeetest.AssertLogSequence(t, records, expected)
@@ -193,7 +190,7 @@ func TestCloseLogging(t *testing.T) {
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
expected := []honeybeetest.ExpectedLog{ expected := []honeybeetest.ExpectedLog{
log(slog.LevelInfo, "closing", map[string]any{}), log(slog.LevelDebug, "closing", map[string]any{}),
log(slog.LevelError, "socket close failed", map[string]any{"error": closeErr}), log(slog.LevelError, "socket close failed", map[string]any{"error": closeErr}),
} }
@@ -219,10 +216,10 @@ func TestReaderLogging(t *testing.T) {
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
return honeybeetest.FindLogRecord( return honeybeetest.FindLogRecord(
mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer") != nil mockHandler.GetRecords(), slog.LevelDebug, "connection closed by peer") != nil
}, "expected log") }, "expected log")
record := honeybeetest.FindLogRecord(mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer") record := honeybeetest.FindLogRecord(mockHandler.GetRecords(), slog.LevelDebug, "connection closed by peer")
assert.NotNil(t, record) assert.NotNil(t, record)
honeybeetest.AssertAttributePresent(t, *record, "code", websocket.CloseNormalClosure) honeybeetest.AssertAttributePresent(t, *record, "code", websocket.CloseNormalClosure)
honeybeetest.AssertAttributePresent(t, *record, "text", "goodbye") honeybeetest.AssertAttributePresent(t, *record, "text", "goodbye")
@@ -246,10 +243,10 @@ func TestReaderLogging(t *testing.T) {
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
return honeybeetest.FindLogRecord( return honeybeetest.FindLogRecord(
mockHandler.GetRecords(), slog.LevelError, "unexpected close") != nil mockHandler.GetRecords(), slog.LevelWarn, "unexpected close") != nil
}, "expected log") }, "expected log")
record := honeybeetest.FindLogRecord(mockHandler.GetRecords(), slog.LevelError, "unexpected close") record := honeybeetest.FindLogRecord(mockHandler.GetRecords(), slog.LevelWarn, "unexpected close")
assert.NotNil(t, record) assert.NotNil(t, record)
honeybeetest.AssertAttributePresent(t, *record, "code", websocket.CloseProtocolError) honeybeetest.AssertAttributePresent(t, *record, "code", websocket.CloseProtocolError)
honeybeetest.AssertAttributePresent(t, *record, "text", "bad protocol") honeybeetest.AssertAttributePresent(t, *record, "text", "bad protocol")
@@ -279,7 +276,7 @@ func TestWriterLogging(t *testing.T) {
t.Run("write deadline error", func(t *testing.T) { t.Run("write deadline error", func(t *testing.T) {
mockHandler := honeybeetest.NewMockSlogHandler() mockHandler := honeybeetest.NewMockSlogHandler()
config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond} config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond, Retry: RetryConfig{Disabled: true}}
deadlineErr := fmt.Errorf("deadline error") deadlineErr := fmt.Errorf("deadline error")
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
@@ -341,16 +338,15 @@ func TestLoggingDisabled(t *testing.T) {
t.Run("nil logger produces no logs", func(t *testing.T) { t.Run("nil logger produces no logs", func(t *testing.T) {
mockHandler := honeybeetest.NewMockSlogHandler() mockHandler := honeybeetest.NewMockSlogHandler()
conn, err := NewConnection(context.Background(), "ws://test", nil, nil)
assert.NoError(t, err)
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{ mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil return mockSocket, nil, nil
}, },
} }
conn.dialer = mockDialer conn, err := NewConnection(context.Background(), "ws://test",
&ConnectionConfig{Retry: RetryConfig{Disabled: true}, Dialer: mockDialer}, nil)
assert.NoError(t, err)
err = conn.Connect(context.Background()) err = conn.Connect(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
+6 -6
View File
@@ -7,16 +7,16 @@ import (
) )
type RetryManager struct { type RetryManager struct {
config *RetryConfig config RetryConfig
retryCount int retryCount int
saturation int saturation int
} }
func NewRetryManager(config *RetryConfig) *RetryManager { func NewRetryManager(config RetryConfig) *RetryManager {
// saturationCount: retry count at which base delay meets or exceeds MaxDelay. // saturationCount: retry count at which base delay meets or exceeds MaxDelay.
// Conservative by two to preserve jitter variance near the boundary. // Conservative by two to preserve jitter variance near the boundary.
saturation := 0 saturation := 0
if config != nil && if !config.Disabled &&
config.InitialDelay > 0 && config.InitialDelay > 0 &&
config.InitialDelay <= config.MaxDelay { config.InitialDelay <= config.MaxDelay {
ratio := float64(config.MaxDelay) / float64(config.InitialDelay) ratio := float64(config.MaxDelay) / float64(config.InitialDelay)
@@ -31,7 +31,7 @@ func NewRetryManager(config *RetryConfig) *RetryManager {
} }
func (r *RetryManager) ShouldRetry() bool { func (r *RetryManager) ShouldRetry() bool {
if r.config == nil { if r.config.Disabled {
return false return false
} }
@@ -43,7 +43,7 @@ func (r *RetryManager) ShouldRetry() bool {
} }
func (r *RetryManager) CalculateDelay() time.Duration { func (r *RetryManager) CalculateDelay() time.Duration {
if r.config == nil { if r.config.Disabled {
return time.Second return time.Second
} }
@@ -54,7 +54,7 @@ func (r *RetryManager) CalculateDelay() time.Duration {
// if saturation is reached, calculated backoff will always be higher than // if saturation is reached, calculated backoff will always be higher than
// the maximum delay // the maximum delay
if r.config != nil && r.retryCount >= r.saturation { if r.retryCount >= r.saturation {
return r.config.MaxDelay return r.config.MaxDelay
} }
+13 -13
View File
@@ -7,7 +7,7 @@ import (
) )
func TestNewRetryManager(t *testing.T) { func TestNewRetryManager(t *testing.T) {
config := &RetryConfig{ config := RetryConfig{
MaxRetries: 0, MaxRetries: 0,
} }
@@ -16,14 +16,14 @@ func TestNewRetryManager(t *testing.T) {
assert.Equal(t, config, mgr.config) assert.Equal(t, config, mgr.config)
assert.Equal(t, 0, mgr.retryCount) assert.Equal(t, 0, mgr.retryCount)
// Should accept nil config // Should accept a disabled config
mgr = NewRetryManager(nil) mgr = NewRetryManager(RetryConfig{Disabled: true})
assert.Nil(t, mgr.config) assert.True(t, mgr.config.Disabled)
assert.Equal(t, 0, mgr.retryCount) assert.Equal(t, 0, mgr.retryCount)
} }
func TestRecordRetry(t *testing.T) { func TestRecordRetry(t *testing.T) {
mgr := NewRetryManager(nil) mgr := NewRetryManager(RetryConfig{Disabled: true})
assert.Equal(t, mgr.retryCount, 0) assert.Equal(t, mgr.retryCount, 0)
mgr.RecordRetry() mgr.RecordRetry()
@@ -34,13 +34,13 @@ func TestRecordRetry(t *testing.T) {
} }
func TestShouldRetry(t *testing.T) { func TestShouldRetry(t *testing.T) {
// never retry if config is nil // never retry if config is disabled
mgr := NewRetryManager(nil) mgr := NewRetryManager(RetryConfig{Disabled: true})
assert.False(t, mgr.ShouldRetry()) assert.False(t, mgr.ShouldRetry())
// always retry if max attempt count is zero // always retry if max attempt count is zero
mgr = &RetryManager{ mgr = &RetryManager{
config: &RetryConfig{ config: RetryConfig{
MaxRetries: 0, MaxRetries: 0,
}, },
retryCount: 1000, retryCount: 1000,
@@ -49,7 +49,7 @@ func TestShouldRetry(t *testing.T) {
// retry if below max attempt count // retry if below max attempt count
mgr = &RetryManager{ mgr = &RetryManager{
config: &RetryConfig{ config: RetryConfig{
MaxRetries: 10, MaxRetries: 10,
}, },
retryCount: 5, retryCount: 5,
@@ -58,7 +58,7 @@ func TestShouldRetry(t *testing.T) {
// do not retry if above max attempt count // do not retry if above max attempt count
mgr = &RetryManager{ mgr = &RetryManager{
config: &RetryConfig{ config: RetryConfig{
MaxRetries: 10, MaxRetries: 10,
}, },
retryCount: 11, retryCount: 11,
@@ -68,12 +68,12 @@ func TestShouldRetry(t *testing.T) {
func TestCalculateDelayDisabled(t *testing.T) { func TestCalculateDelayDisabled(t *testing.T) {
// default delay if retry is disabled // default delay if retry is disabled
mgr := NewRetryManager(nil) mgr := NewRetryManager(RetryConfig{Disabled: true})
assert.Equal(t, time.Second, mgr.CalculateDelay()) assert.Equal(t, time.Second, mgr.CalculateDelay())
} }
func TestCalculateDelayWithoutJitter(t *testing.T) { func TestCalculateDelayWithoutJitter(t *testing.T) {
mgr := NewRetryManager(&RetryConfig{ mgr := NewRetryManager(RetryConfig{
MaxRetries: 0, MaxRetries: 0,
InitialDelay: 1 * time.Second, InitialDelay: 1 * time.Second,
MaxDelay: 5 * time.Second, MaxDelay: 5 * time.Second,
@@ -105,7 +105,7 @@ func TestCalculateDelayWithoutJitter(t *testing.T) {
} }
func TestCalculateDelayWithJitter(t *testing.T) { func TestCalculateDelayWithJitter(t *testing.T) {
mgr := NewRetryManager(&RetryConfig{ mgr := NewRetryManager(RetryConfig{
MaxRetries: 0, MaxRetries: 0,
InitialDelay: 1 * time.Second, InitialDelay: 1 * time.Second,
MaxDelay: 5 * time.Second, MaxDelay: 5 * time.Second,
+1 -1
View File
@@ -82,7 +82,7 @@ func AcquireSocket(
if !retryMgr.ShouldRetry() { if !retryMgr.ShouldRetry() {
// retry policy expired // retry policy expired
if logger != nil { if logger != nil {
logger.Error("dial failed, max retries reached", logger.Debug("dial failed, max retries reached",
"error", err, "error", err,
"attempt", retryMgr.RetryCount()+1) "attempt", retryMgr.RetryCount()+1)
} }
+7 -6
View File
@@ -77,7 +77,7 @@ func TestAcquireSocket(t *testing.T) {
}, },
} }
retryMgr := NewRetryManager(&RetryConfig{ retryMgr := NewRetryManager(RetryConfig{
MaxRetries: tc.maxRetries, MaxRetries: tc.maxRetries,
InitialDelay: 1 * time.Millisecond, InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond, MaxDelay: 5 * time.Millisecond,
@@ -106,7 +106,8 @@ func TestAcquireSocketGuards(t *testing.T) {
return honeybeetest.NewMockSocket(), nil, nil return honeybeetest.NewMockSocket(), nil, nil
}, },
} }
validRetryMgr := NewRetryManager(GetDefaultRetryConfig()) validRetryConfig := GetDefaultConnectionConfig().Retry
validRetryMgr := NewRetryManager(validRetryConfig)
cases := []struct { cases := []struct {
name string name string
@@ -167,7 +168,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
// cancel before acquiring socket // cancel before acquiring socket
cancel() cancel()
retryMgr := NewRetryManager(GetDefaultRetryConfig()) retryMgr := NewRetryManager(GetDefaultConnectionConfig().Retry)
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil) _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
assert.ErrorIs(t, err, context.Canceled) assert.ErrorIs(t, err, context.Canceled)
@@ -186,7 +187,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
retryMgr := NewRetryManager(&RetryConfig{ retryMgr := NewRetryManager(RetryConfig{
MaxRetries: 10, MaxRetries: 10,
InitialDelay: 1 * time.Second, InitialDelay: 1 * time.Second,
MaxDelay: 1 * time.Second, MaxDelay: 1 * time.Second,
@@ -230,7 +231,7 @@ func TestAcquireSocketContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
retryMgr := NewRetryManager(GetDefaultRetryConfig()) retryMgr := NewRetryManager(GetDefaultConnectionConfig().Retry)
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil) _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil, nil)
@@ -263,7 +264,7 @@ func TestAcquireSocketPassesHeaders(t *testing.T) {
}, },
} }
retryMgr := NewRetryManager(&RetryConfig{MaxRetries: 0}) retryMgr := NewRetryManager(RetryConfig{MaxRetries: 0, InitialDelay: 1 * time.Millisecond, MaxDelay: 5 * time.Millisecond})
_, _, err := AcquireSocket(context.Background(), retryMgr, mockDialer, "ws://test", header, nil) _, _, err := AcquireSocket(context.Background(), retryMgr, mockDialer, "ws://test", header, nil)
assert.NoError(t, err) assert.NoError(t, err)
+24 -21
View File
@@ -2,6 +2,7 @@ package honeybee
import ( import (
"context" "context"
"fmt"
"log/slog" "log/slog"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -94,7 +95,6 @@ func NewWorker(
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
config: config, config: config,
handler: handler,
processedCount: &atomic.Uint64{}, processedCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{}, outgoingCount: &atomic.Uint64{},
@@ -103,7 +103,8 @@ func NewWorker(
if handler != nil { if handler != nil {
comp := component.FromContext(ctx) comp := component.FromContext(ctx)
w.logger = slog.New(handler).With(slog.Any("component", comp), slog.String("peer_id", id)) w.handler = handler.WithAttrs([]slog.Attr{slog.String("peer", id)})
w.logger = slog.New(w.handler).With(slog.Any("component", comp))
} }
return w, nil return w, nil
@@ -124,13 +125,13 @@ func (w *DefaultWorker) Start(pool PoolPlugin) {
}) })
if w.logger != nil { if w.logger != nil {
w.logger.Info("started") w.logger.Debug("started")
} }
wg.Wait() wg.Wait()
if w.logger != nil { if w.logger != nil {
w.logger.Info("stopped") w.logger.Debug("stopped")
} }
} }
@@ -162,7 +163,7 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
case conn = <-newConn: case conn = <-newConn:
if w.logger != nil { if w.logger != nil {
w.logger.Debug("session: connected") w.logger.Info("connected")
} }
break preConn break preConn
@@ -171,7 +172,7 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
case <-inactive(): case <-inactive():
if w.logger != nil { if w.logger != nil {
w.logger.Info("keepalive: no activity observed") w.logger.Warn("keepalive: no activity observed")
} }
timer.Reset(w.config.KeepaliveTimeout) timer.Reset(w.config.KeepaliveTimeout)
spawnDialer() spawnDialer()
@@ -183,7 +184,7 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
pool.Events <- PoolEvent{ID: w.id, Kind: EventConnected, At: time.Now()} pool.Events <- PoolEvent{ID: w.id, Kind: EventConnected, At: time.Now()}
if w.logger != nil { if w.logger != nil {
w.logger.Info("session: started") w.logger.Debug("session: started")
} }
// run session loop // run session loop
@@ -195,8 +196,14 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
case data, ok := <-conn.Incoming(): case data, ok := <-conn.Incoming():
if !ok { if !ok {
var reason error
select {
case reason = <-conn.Errors():
default:
reason = fmt.Errorf("unknown")
}
if w.logger != nil { if w.logger != nil {
w.logger.Debug("reader: disconnected") w.logger.Info("websocket: closed", "reason", reason)
} }
break conn_loop break conn_loop
} }
@@ -210,9 +217,6 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
heartbeat() heartbeat()
case <-conn.Heartbeat(): case <-conn.Heartbeat():
if w.logger != nil {
w.logger.Debug("ping-pong heartbeat")
}
heartbeat() heartbeat()
case <-w.sendHeartbeat: case <-w.sendHeartbeat:
@@ -220,7 +224,7 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
case <-inactive(): case <-inactive():
if w.logger != nil { if w.logger != nil {
w.logger.Info("keepalive: no activity observed") w.logger.Warn("keepalive: no activity observed")
} }
timer.Reset(w.config.KeepaliveTimeout) timer.Reset(w.config.KeepaliveTimeout)
break conn_loop break conn_loop
@@ -231,7 +235,10 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
conn.Close() conn.Close()
if w.logger != nil { if w.logger != nil {
w.logger.Info("session: ended") w.logger.Info("disconnected")
}
if w.logger != nil {
w.logger.Debug("session: ended")
} }
// tear down connection // tear down connection
@@ -301,16 +308,13 @@ func (w *DefaultWorker) spawnDialer(
dialCtx, dialCancel := context.WithCancel(ctx) dialCtx, dialCancel := context.WithCancel(ctx)
if w.logger != nil { if w.logger != nil {
w.logger.Debug("session: requesting connection") w.logger.Debug("session: dialing")
} }
go func() { go func() {
conn, err := connect(w.id, dialCtx, pool, w.handler) conn, err := connect(w.id, dialCtx, pool, w.handler)
if err != nil { if err != nil {
if w.logger != nil {
w.logger.Warn("dialer: dial failed", "error", err)
}
return return
} }
@@ -330,12 +334,11 @@ func connect(
pool PoolPlugin, pool PoolPlugin,
handler slog.Handler, handler slog.Handler,
) (*transport.Connection, error) { ) (*transport.Connection, error) {
conn, err := transport.NewConnection(ctx, id, pool.ConnectionConfig, handler) cc := pool.ConnectionConfig
conn, err := transport.NewConnection(ctx, id, &cc, handler)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn.SetDialer(pool.Dialer)
return conn, conn.Connect(ctx) return conn, conn.Connect(ctx)
} }
@@ -345,7 +348,7 @@ func connect(
func (w *DefaultWorker) Stop() { func (w *DefaultWorker) Stop() {
if w.logger != nil { if w.logger != nil {
w.logger.Debug("shutting down") w.logger.Info("shutting down")
} }
w.cancel() w.cancel()
} }
+20 -19
View File
@@ -28,6 +28,7 @@ func makeWorkerContext(t *testing.T) (
Inbox: inbox, Inbox: inbox,
Events: events, Events: events,
InboxCounter: &atomic.Uint64{}, InboxCounter: &atomic.Uint64{},
ConnectionConfig: *transport.GetDefaultConnectionConfig(),
} }
return return
} }
@@ -65,7 +66,7 @@ func TestWorkerSession(t *testing.T) {
w := makeWorker(t, ctx, cancel) w := makeWorker(t, ctx, cancel)
_, events, pool := makeWorkerContext(t) _, events, pool := makeWorkerContext(t)
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
pool.Dialer = mockDialer(mockSocket) pool.ConnectionConfig.Dialer = mockDialer(mockSocket)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { wg.Go(func() {
@@ -89,13 +90,13 @@ func TestWorkerSession(t *testing.T) {
w := makeWorker(t, ctx, cancel) w := makeWorker(t, ctx, cancel)
_, events, pool := makeWorkerContext(t) _, events, pool := makeWorkerContext(t)
pool.Dialer = &honeybeetest.MockDialer{ cc, _ := transport.NewConnectionConfig(transport.WithRetryDisabled())
pool.ConnectionConfig = *cc
pool.ConnectionConfig.Dialer = &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return nil, nil, errors.New("connection refused") return nil, nil, errors.New("connection refused")
}, },
} }
cc, _ := transport.NewConnectionConfig(transport.WithoutRetry())
pool.ConnectionConfig = cc
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { w.Start(pool) }) wg.Go(func() { w.Start(pool) })
@@ -143,15 +144,15 @@ func TestWorkerSession(t *testing.T) {
_, _, pool := makeWorkerContext(t) _, _, pool := makeWorkerContext(t)
var dialCount atomic.Uint64 var dialCount atomic.Uint64
pool.Dialer = &honeybeetest.MockDialer{ cc, _ := transport.NewConnectionConfig(transport.WithRetryDisabled())
pool.ConnectionConfig = *cc
pool.ConnectionConfig.Dialer = &honeybeetest.MockDialer{
DialContextFunc: func(dialCtx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(dialCtx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
dialCount.Add(1) dialCount.Add(1)
<-dialCtx.Done() <-dialCtx.Done()
return nil, nil, dialCtx.Err() return nil, nil, dialCtx.Err()
}, },
} }
cc, _ := transport.NewConnectionConfig(transport.WithoutRetry())
pool.ConnectionConfig = cc
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { w.Start(pool) }) wg.Go(func() { w.Start(pool) })
@@ -169,14 +170,14 @@ func TestWorkerSession(t *testing.T) {
w := makeWorker(t, ctx, cancel) w := makeWorker(t, ctx, cancel)
_, events, pool := makeWorkerContext(t) _, events, pool := makeWorkerContext(t)
pool.Dialer = &honeybeetest.MockDialer{ cc, _ := transport.NewConnectionConfig(transport.WithRetryDisabled())
pool.ConnectionConfig = *cc
pool.ConnectionConfig.Dialer = &honeybeetest.MockDialer{
DialContextFunc: func(dialCtx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { DialContextFunc: func(dialCtx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
<-dialCtx.Done() <-dialCtx.Done()
return nil, nil, dialCtx.Err() return nil, nil, dialCtx.Err()
}, },
} }
cc, _ := transport.NewConnectionConfig(transport.WithoutRetry())
pool.ConnectionConfig = cc
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { w.Start(pool) }) wg.Go(func() { w.Start(pool) })
@@ -204,7 +205,7 @@ func TestWorkerSession(t *testing.T) {
w := makeWorker(t, ctx, cancel) w := makeWorker(t, ctx, cancel)
_, events, pool := makeWorkerContext(t) _, events, pool := makeWorkerContext(t)
_, mockSocket, _, outgoingData := setupTestConnection(t) _, mockSocket, _, outgoingData := setupTestConnection(t)
pool.Dialer = mockDialer(mockSocket) pool.ConnectionConfig.Dialer = mockDialer(mockSocket)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { wg.Go(func() {
@@ -255,7 +256,7 @@ func TestWorkerSession(t *testing.T) {
} }
} }
pool.Dialer = mockDialer(mockSocket) pool.ConnectionConfig.Dialer = mockDialer(mockSocket)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { wg.Go(func() {
@@ -311,7 +312,7 @@ func TestWorkerSession(t *testing.T) {
} }
_, events, pool := makeWorkerContext(t) _, events, pool := makeWorkerContext(t)
_, mockSocket, incomingData, _ := setupTestConnection(t) _, mockSocket, incomingData, _ := setupTestConnection(t)
pool.Dialer = mockDialer(mockSocket) pool.ConnectionConfig.Dialer = mockDialer(mockSocket)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { w.Start(pool) }) wg.Go(func() { w.Start(pool) })
@@ -377,7 +378,7 @@ func TestWorkerSession(t *testing.T) {
var pongHandler func(string) error var pongHandler func(string) error
mockSocket, incomingData, _ := honeybeetest.SetupTestSocket(t) mockSocket, incomingData, _ := honeybeetest.SetupTestSocket(t)
mockSocket.SetPongHandlerFunc = func(h func(string) error) { pongHandler = h } mockSocket.SetPongHandlerFunc = func(h func(string) error) { pongHandler = h }
pool.Dialer = mockDialer(mockSocket) pool.ConnectionConfig.Dialer = mockDialer(mockSocket)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { w.Start(pool) }) wg.Go(func() { w.Start(pool) })
@@ -439,7 +440,7 @@ func TestWorkerSession(t *testing.T) {
} }
_, events, pool := makeWorkerContext(t) _, events, pool := makeWorkerContext(t)
_, mockSocket, _, _ := setupTestConnection(t) _, mockSocket, _, _ := setupTestConnection(t)
pool.Dialer = mockDialer(mockSocket) pool.ConnectionConfig.Dialer = mockDialer(mockSocket)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { w.Start(pool) }) wg.Go(func() { w.Start(pool) })
@@ -481,7 +482,7 @@ func TestWorkerSession(t *testing.T) {
w := makeWorker(t, ctx, cancel) w := makeWorker(t, ctx, cancel)
_, events, pool := makeWorkerContext(t) _, events, pool := makeWorkerContext(t)
_, mockSocket, incomingData, _ := setupTestConnection(t) _, mockSocket, incomingData, _ := setupTestConnection(t)
pool.Dialer = mockDialer(mockSocket) pool.ConnectionConfig.Dialer = mockDialer(mockSocket)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { wg.Go(func() {
@@ -525,7 +526,7 @@ func TestWorkerSession(t *testing.T) {
w := makeWorker(t, ctx, cancel) w := makeWorker(t, ctx, cancel)
_, events, pool := makeWorkerContext(t) _, events, pool := makeWorkerContext(t)
_, mockSocket, incomingData, _ := setupTestConnection(t) _, mockSocket, incomingData, _ := setupTestConnection(t)
pool.Dialer = mockDialer(mockSocket) pool.ConnectionConfig.Dialer = mockDialer(mockSocket)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { w.Start(pool) }) wg.Go(func() { w.Start(pool) })
@@ -561,7 +562,7 @@ func TestWorkerSession(t *testing.T) {
w := makeWorker(t, ctx, cancel) w := makeWorker(t, ctx, cancel)
_, events, pool := makeWorkerContext(t) _, events, pool := makeWorkerContext(t)
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
pool.Dialer = mockDialer(mockSocket) pool.ConnectionConfig.Dialer = mockDialer(mockSocket)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { wg.Go(func() {
@@ -607,7 +608,7 @@ func TestWorkerSession(t *testing.T) {
w := makeWorker(t, workerCtx, workerCancel) w := makeWorker(t, workerCtx, workerCancel)
_, events, pool := makeWorkerContext(t) _, events, pool := makeWorkerContext(t)
mockSocket := honeybeetest.NewMockSocket() mockSocket := honeybeetest.NewMockSocket()
pool.Dialer = mockDialer(mockSocket) pool.ConnectionConfig.Dialer = mockDialer(mockSocket)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Go(func() { wg.Go(func() {