diff --git a/config.go b/config.go index 6dcb0e8..0cc2436 100644 --- a/config.go +++ b/config.go @@ -12,7 +12,7 @@ import ( type WorkerFactory func( ctx context.Context, id string, - logger *slog.Logger, + handler slog.Handler, ) (Worker, error) // Pool Config @@ -20,8 +20,6 @@ type WorkerFactory func( type PoolConfig struct { InboxBufferSize int EventsBufferSize int - LoggingEnabled bool - LogLevel *slog.Level ConnectionConfig *transport.ConnectionConfig WorkerFactory WorkerFactory WorkerConfig *WorkerConfig @@ -44,8 +42,6 @@ func GetDefaultPoolConfig() *PoolConfig { return &PoolConfig{ InboxBufferSize: 256, EventsBufferSize: 10, - LoggingEnabled: true, - LogLevel: nil, ConnectionConfig: nil, WorkerFactory: nil, WorkerConfig: nil, @@ -108,21 +104,6 @@ func WithEventsBufferSize(value int) PoolOption { } } -func WithPoolLoggingEnabled(value bool) PoolOption { - return func(c *PoolConfig) error { - c.LoggingEnabled = value - return nil - } -} - -func WithPoolLogLevel(level slog.Level) PoolOption { - return func(c *PoolConfig) error { - l := level - c.LogLevel = &l - return nil - } -} - func WithConnectionConfig(cc *transport.ConnectionConfig) PoolOption { return func(c *PoolConfig) error { err := transport.ValidateConnectionConfig(cc) @@ -157,8 +138,6 @@ func WithWorkerFactory(wf WorkerFactory) PoolOption { type WorkerConfig struct { KeepaliveTimeout time.Duration ReconnectDelay time.Duration - LoggingEnabled bool - LogLevel *slog.Level } type WorkerOption func(*WorkerConfig) error @@ -178,8 +157,6 @@ func GetDefaultWorkerConfig() *WorkerConfig { return &WorkerConfig{ KeepaliveTimeout: 60 * time.Second, ReconnectDelay: 2 * time.Second, - LoggingEnabled: true, - LogLevel: nil, } } @@ -237,18 +214,3 @@ func WithReconnectDelay(value time.Duration) WorkerOption { return nil } } - -func WithWorkerLoggingEnabled(value bool) WorkerOption { - return func(c *WorkerConfig) error { - c.LoggingEnabled = value - return nil - } -} - -func WithWorkerLogLevel(level slog.Level) WorkerOption { - return func(c *WorkerConfig) error { - l := level - c.LogLevel = &l - return nil - } -} diff --git a/config_pool_test.go b/config_pool_test.go index 25b4ba7..8c44c83 100644 --- a/config_pool_test.go +++ b/config_pool_test.go @@ -14,8 +14,6 @@ func TestNewPoolConfig(t *testing.T) { assert.Equal(t, conf, &PoolConfig{ InboxBufferSize: 256, EventsBufferSize: 10, - LoggingEnabled: true, - LogLevel: nil, ConnectionConfig: nil, WorkerConfig: nil, WorkerFactory: nil, @@ -28,8 +26,6 @@ func TestDefaultPoolConfig(t *testing.T) { assert.Equal(t, conf, &PoolConfig{ InboxBufferSize: 256, EventsBufferSize: 10, - LoggingEnabled: true, - LogLevel: nil, ConnectionConfig: nil, WorkerConfig: nil, WorkerFactory: nil, diff --git a/errors.go b/errors.go index 48f9326..55e841d 100644 --- a/errors.go +++ b/errors.go @@ -10,10 +10,9 @@ var ( InvalidBufferSize = errors.New("buffer size must be greater than zero") // Pool errors - ErrInvalidPoolID = errors.New("pool id cannot be empty") - ErrPoolClosed = errors.New("pool is closed") - ErrPeerNotFound = errors.New("peer not found") - ErrPeerExists = errors.New("peer already exists") + ErrPoolClosed = errors.New("pool is closed") + ErrPeerNotFound = errors.New("peer not found") + ErrPeerExists = errors.New("peer already exists") // Worker errors ErrConnectionUnavailable = errors.New("connection unavailable") diff --git a/go.mod b/go.mod index 48a05a7..d834c0e 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module git.wisehodl.dev/jay/go-honeybee -go 1.23.5 +go 1.25.0 require ( github.com/gorilla/websocket v1.5.3 @@ -8,6 +8,7 @@ require ( ) require ( + git.wisehodl.dev/jay/go-mana-component v0.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 4b33f39..4e60a5b 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +git.wisehodl.dev/jay/go-mana-component v0.1.0 h1:wWYN5MzC9Hq3tEt4z7FjrwNuQz3rZY3RWAmgmNE8EZE= +git.wisehodl.dev/jay/go-mana-component v0.1.0/go.mod h1:r2ZaTjKzwV5JJfC5boikxtjAKusPrzlJU/7qul0EUqA= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= diff --git a/helper_test.go b/helper_test.go index cab276b..5774cab 100644 --- a/helper_test.go +++ b/helper_test.go @@ -1,6 +1,7 @@ package honeybee import ( + "context" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/transport" "github.com/stretchr/testify/assert" @@ -18,7 +19,7 @@ func setupTestConnection(t *testing.T) ( socket, incoming, outgoing = honeybeetest.SetupTestSocket(t) var err error - conn, err = transport.NewConnectionFromSocket(socket, nil, nil) + conn, err = transport.NewConnectionFromSocket(context.Background(), socket, nil, nil) assert.NoError(t, err) return } diff --git a/honeybeetest/mocks.go b/honeybeetest/mocks.go index 3dc54c7..ffb28e9 100644 --- a/honeybeetest/mocks.go +++ b/honeybeetest/mocks.go @@ -98,7 +98,7 @@ func (m *MockSocket) SetPongHandler(h func(s string) error) { type MockSlogHandler struct { records *[]slog.Record attrs []slog.Attr - mu sync.RWMutex + mu *sync.RWMutex } func NewMockSlogHandler() *MockSlogHandler { @@ -106,6 +106,7 @@ func NewMockSlogHandler() *MockSlogHandler { return &MockSlogHandler{ records: &records, attrs: make([]slog.Attr, 0), + mu: &sync.RWMutex{}, } } @@ -126,6 +127,7 @@ func (m *MockSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler { defer m.mu.RUnlock() return &MockSlogHandler{ records: m.records, // shared records slice + mu: m.mu, // shared mutex attrs: append(m.attrs, attrs...), } } diff --git a/logging/logging.go b/logging/logging.go deleted file mode 100644 index af63a47..0000000 --- a/logging/logging.go +++ /dev/null @@ -1,111 +0,0 @@ -package logging - -import ( - "context" - "log/slog" -) - -// Constants - -const KEY_MODULE = "module" -const KEY_COMPONENT = "component" -const KEY_POOL_ID = "pool_id" -const KEY_PEER_ID = "peer_id" - -const MODULE_NAME = "honeybee" - -const COMPONENT_OUTBOUND_POOL = "outbound_pool" -const COMPONENT_OUTBOUND_WORKER = "outbound_worker" - -const COMPONENT_INBOUND_POOL = "inbound_pool" -const COMPONENT_INBOUND_WORKER = "inbound_worker" - -const COMPONENT_CONNECTION = "connection" - -// Constructors - -func NewOutboundPoolLogger(handler slog.Handler, poolID string) *slog.Logger { - return newLogger(handler, - KEY_MODULE, MODULE_NAME, - KEY_COMPONENT, COMPONENT_OUTBOUND_POOL, - KEY_POOL_ID, poolID, - ) -} - -func NewOutboundWorkerLogger(handler slog.Handler, poolID string, peerID string) *slog.Logger { - return newLogger(handler, - KEY_MODULE, MODULE_NAME, - KEY_COMPONENT, COMPONENT_OUTBOUND_WORKER, - KEY_POOL_ID, poolID, - KEY_PEER_ID, peerID, - ) -} - -func NewInboundPoolLogger(handler slog.Handler, poolID string) *slog.Logger { - return newLogger(handler, - KEY_MODULE, MODULE_NAME, - KEY_COMPONENT, COMPONENT_INBOUND_POOL, - KEY_POOL_ID, poolID, - ) -} - -func NewInboundWorkerLogger(handler slog.Handler, poolID string, peerID string) *slog.Logger { - return newLogger(handler, - KEY_MODULE, MODULE_NAME, - KEY_COMPONENT, COMPONENT_INBOUND_WORKER, - KEY_POOL_ID, poolID, - KEY_PEER_ID, peerID, - ) -} - -func NewConnectionLogger(handler slog.Handler, poolID string, peerID string) *slog.Logger { - return newLogger(handler, - KEY_MODULE, MODULE_NAME, - KEY_COMPONENT, COMPONENT_CONNECTION, - KEY_POOL_ID, poolID, - KEY_PEER_ID, peerID, - ) -} - -// Helpers - -func newLogger(handler slog.Handler, attrs ...any) *slog.Logger { - return slog.New(handler).With(attrs...) -} - -// Handlers - -type ForcedLevelHandler struct { - level slog.Level - next slog.Handler -} - -func NewForcedLevelHandler(level slog.Level, next slog.Handler) slog.Handler { - return &ForcedLevelHandler{ - level: level, - next: next, - } -} - -func (h *ForcedLevelHandler) Enabled(_ context.Context, l slog.Level) bool { - return l >= h.level -} - -func (h *ForcedLevelHandler) Handle(ctx context.Context, r slog.Record) error { - return h.next.Handle(ctx, r) -} - -func (h *ForcedLevelHandler) WithAttrs(attrs []slog.Attr) slog.Handler { - return &ForcedLevelHandler{level: h.level, next: h.next.WithAttrs(attrs)} -} - -func (h *ForcedLevelHandler) WithGroup(name string) slog.Handler { - return &ForcedLevelHandler{level: h.level, next: h.next.WithGroup(name)} -} - -func WrapOrDefault(level *slog.Level, handler slog.Handler) slog.Handler { - if level != nil { - return NewForcedLevelHandler(*level, handler) - } - return handler -} diff --git a/logging/logging_test.go b/logging/logging_test.go deleted file mode 100644 index dd90ea4..0000000 --- a/logging/logging_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package logging - -import ( - "git.wisehodl.dev/jay/go-honeybee/honeybeetest" - // "github.com/stretchr/testify/assert" - "log/slog" - "testing" -) - -// Helpers - -func log(level slog.Level, msg string, attrs map[string]any) honeybeetest.ExpectedLog { - return honeybeetest.ExpectedLog{Level: level, Msg: msg, Attrs: attrs} -} - -// Tests - -func TestOutboundLogger(t *testing.T) { - const POOL_ID = "pool-1" - const PEER_ID = "wss://test" - - handler := honeybeetest.NewMockSlogHandler() - poolLogger := NewOutboundPoolLogger(handler, POOL_ID) - workerLogger := NewOutboundWorkerLogger(handler, POOL_ID, PEER_ID) - connLogger := NewConnectionLogger(handler, POOL_ID, PEER_ID) - - poolLogger.Info("test") - workerLogger.Info("test") - connLogger.Info("test") - - honeybeetest.Eventually(t, func() bool { - return len(handler.GetRecords()) == 3 - }, "expected a log record") - - records := handler.GetRecords() - - honeybeetest.AssertAttributePresent(t, records[0], KEY_MODULE, MODULE_NAME) - honeybeetest.AssertAttributePresent(t, records[0], KEY_COMPONENT, COMPONENT_OUTBOUND_POOL) - honeybeetest.AssertAttributePresent(t, records[0], KEY_POOL_ID, POOL_ID) - - honeybeetest.AssertAttributePresent(t, records[1], KEY_MODULE, MODULE_NAME) - honeybeetest.AssertAttributePresent(t, records[1], KEY_COMPONENT, COMPONENT_OUTBOUND_WORKER) - honeybeetest.AssertAttributePresent(t, records[1], KEY_POOL_ID, POOL_ID) - honeybeetest.AssertAttributePresent(t, records[1], KEY_PEER_ID, PEER_ID) - - honeybeetest.AssertAttributePresent(t, records[2], KEY_MODULE, MODULE_NAME) - honeybeetest.AssertAttributePresent(t, records[2], KEY_COMPONENT, COMPONENT_CONNECTION) - honeybeetest.AssertAttributePresent(t, records[2], KEY_POOL_ID, POOL_ID) - honeybeetest.AssertAttributePresent(t, records[2], KEY_PEER_ID, PEER_ID) -} - -func TestInboundLogger(t *testing.T) { - const POOL_ID = "pool-1" - const PEER_ID = "peer-1" - - handler := honeybeetest.NewMockSlogHandler() - poolLogger := NewInboundPoolLogger(handler, POOL_ID) - workerLogger := NewInboundWorkerLogger(handler, POOL_ID, PEER_ID) - connLogger := NewConnectionLogger(handler, POOL_ID, PEER_ID) - - poolLogger.Info("test") - workerLogger.Info("test") - connLogger.Info("test") - - honeybeetest.Eventually(t, func() bool { - return len(handler.GetRecords()) == 3 - }, "expected a log record") - - records := handler.GetRecords() - - honeybeetest.AssertAttributePresent(t, records[0], KEY_MODULE, MODULE_NAME) - honeybeetest.AssertAttributePresent(t, records[0], KEY_COMPONENT, COMPONENT_INBOUND_POOL) - honeybeetest.AssertAttributePresent(t, records[0], KEY_POOL_ID, POOL_ID) - - honeybeetest.AssertAttributePresent(t, records[1], KEY_MODULE, MODULE_NAME) - honeybeetest.AssertAttributePresent(t, records[1], KEY_COMPONENT, COMPONENT_INBOUND_WORKER) - honeybeetest.AssertAttributePresent(t, records[1], KEY_POOL_ID, POOL_ID) - honeybeetest.AssertAttributePresent(t, records[1], KEY_PEER_ID, PEER_ID) - - honeybeetest.AssertAttributePresent(t, records[2], KEY_MODULE, MODULE_NAME) - honeybeetest.AssertAttributePresent(t, records[2], KEY_COMPONENT, COMPONENT_CONNECTION) - honeybeetest.AssertAttributePresent(t, records[2], KEY_POOL_ID, POOL_ID) - honeybeetest.AssertAttributePresent(t, records[2], KEY_PEER_ID, PEER_ID) -} diff --git a/pool.go b/pool.go index 7d44e3a..f406a8d 100644 --- a/pool.go +++ b/pool.go @@ -2,10 +2,11 @@ package honeybee import ( "context" - "git.wisehodl.dev/jay/go-honeybee/logging" + "log/slog" + "git.wisehodl.dev/jay/go-honeybee/transport" "git.wisehodl.dev/jay/go-honeybee/types" - "log/slog" + component "git.wisehodl.dev/jay/go-mana-component" "sync" "sync/atomic" "time" @@ -50,13 +51,11 @@ type PeerStats struct { } type PoolPlugin struct { - ID string Inbox chan<- types.InboxMessage Events chan<- PoolEvent InboxCounter *atomic.Uint64 Dialer types.Dialer ConnectionConfig *transport.ConnectionConfig - Handler slog.Handler } // Pool @@ -70,8 +69,6 @@ type Pool struct { ctx context.Context cancel context.CancelFunc - id string - peers map[string]*Peer inbox chan types.InboxMessage events chan PoolEvent @@ -89,12 +86,8 @@ type Pool struct { closed bool } -func NewPool(ctx context.Context, id string, config *PoolConfig, handler slog.Handler, +func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler, ) (*Pool, error) { - if id == "" { - return nil, ErrInvalidPoolID - } - if config == nil { config = GetDefaultPoolConfig() } @@ -104,8 +97,8 @@ func NewPool(ctx context.Context, id string, config *PoolConfig, handler slog.Ha // deadlocks. if config.WorkerFactory == nil { config.WorkerFactory = func( - ctx context.Context, id string, logger *slog.Logger) (Worker, error) { - return NewWorker(ctx, id, config.WorkerConfig, logger) + ctx context.Context, id string, handler slog.Handler) (Worker, error) { + return NewWorker(ctx, id, config.WorkerConfig, handler) } } @@ -113,18 +106,17 @@ func NewPool(ctx context.Context, id string, config *PoolConfig, handler slog.Ha return nil, err } - pctx, cancel := context.WithCancel(ctx) + pctx, cancel := context.WithCancel(component.MustNew(ctx, "honeybee", "pool")) var logger *slog.Logger - if handler != nil && config.LoggingEnabled { - logger = logging.NewOutboundPoolLogger( - logging.WrapOrDefault(config.LogLevel, handler), id) + if handler != nil { + c := component.FromContext(pctx) + logger = slog.New(handler).With(slog.Any("component", c)) } return &Pool{ ctx: pctx, cancel: cancel, - id: id, peers: make(map[string]*Peer), inbox: make(chan types.InboxMessage, config.InboxBufferSize), events: make(chan PoolEvent, config.EventsBufferSize), @@ -254,35 +246,23 @@ func (p *Pool) Connect(id string) error { return NewPoolError(ErrPeerExists) } - var logger *slog.Logger - if p.handler != nil && p.config.WorkerConfig != nil { - if p.config.WorkerConfig.LoggingEnabled { - logger = logging.NewOutboundWorkerLogger( - logging.WrapOrDefault(p.config.WorkerConfig.LogLevel, p.handler), p.id, id) - } - } - // The worker factory must be non-blocking to avoid deadlocks - worker, err := p.config.WorkerFactory(p.ctx, id, logger) + worker, err := p.config.WorkerFactory(p.ctx, id, p.handler) if err != nil { return err } pool := PoolPlugin{ - ID: p.id, Inbox: p.inbox, Events: p.events, InboxCounter: p.inboxCounter, Dialer: p.dialer, ConnectionConfig: p.config.ConnectionConfig, - Handler: p.handler, } - p.wg.Add(1) - go func() { + p.wg.Go(func() { worker.Start(pool) - p.wg.Done() - }() + }) p.peers[id] = &Peer{id: id, worker: worker} diff --git a/pool_test.go b/pool_test.go index e4a9de6..7c57979 100644 --- a/pool_test.go +++ b/pool_test.go @@ -15,7 +15,7 @@ import ( func setupPool(t *testing.T) (*Pool, *honeybeetest.MockDialer) { t.Helper() - pool, err := NewPool(context.Background(), "pool-1", nil, nil) + pool, err := NewPool(context.Background(), nil, nil) assert.NoError(t, err) dialer := &honeybeetest.MockDialer{ DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { @@ -45,11 +45,6 @@ func expectEvent( // Tests -func TestPoolID(t *testing.T) { - _, err := NewPool(context.Background(), "", nil, nil) - assert.ErrorIs(t, err, ErrInvalidPoolID) -} - func TestPoolConnect(t *testing.T) { t.Run("successfully adds connection", func(t *testing.T) { pool, _ := setupPool(t) @@ -90,7 +85,7 @@ func TestPoolConnect(t *testing.T) { func TestPoolClose(t *testing.T) { t.Run("channels close after pool close", func(t *testing.T) { - pool, _ := NewPool(context.Background(), "pool-1", nil, nil) + pool, _ := NewPool(context.Background(), nil, nil) pool.Close() _, ok := <-pool.Inbox() assert.False(t, ok) @@ -99,7 +94,7 @@ func TestPoolClose(t *testing.T) { }) t.Run("connect after close returns error", func(t *testing.T) { - pool, _ := NewPool(context.Background(), "pool-1", nil, nil) + pool, _ := NewPool(context.Background(), nil, nil) pool.Close() err := pool.Connect("wss://test") assert.ErrorIs(t, err, ErrPoolClosed) @@ -157,7 +152,7 @@ func TestPoolSend(t *testing.T) { }, } - pool, err := NewPool(context.Background(), "pool-1", nil, nil) + pool, err := NewPool(context.Background(), nil, nil) assert.NoError(t, err) pool.dialer = mockDialer diff --git a/transport/config.go b/transport/config.go index 3720543..6e71047 100644 --- a/transport/config.go +++ b/transport/config.go @@ -1,7 +1,6 @@ package transport import ( - "log/slog" "net/http" "time" ) @@ -15,8 +14,6 @@ type ConnectionConfig struct { PingInterval time.Duration IncomingBufferSize int ErrorsBufferSize int - LoggingEnabled bool - LogLevel *slog.Level Retry *RetryConfig } @@ -50,8 +47,6 @@ func GetDefaultConnectionConfig() *ConnectionConfig { PingInterval: 20 * time.Second, IncomingBufferSize: 100, ErrorsBufferSize: 10, - LoggingEnabled: true, - LogLevel: nil, Retry: GetDefaultRetryConfig(), } } @@ -216,21 +211,6 @@ func WithErrorsBufferSize(value int) ConnectionOption { } } -func WithLoggingEnabled(value bool) ConnectionOption { - return func(c *ConnectionConfig) error { - c.LoggingEnabled = value - return nil - } -} - -func WithLogLevel(level slog.Level) ConnectionOption { - return func(c *ConnectionConfig) error { - l := level - c.LogLevel = &l - return nil - } -} - func WithoutRetry() ConnectionOption { return func(c *ConnectionConfig) error { c.Retry = nil diff --git a/transport/config_test.go b/transport/config_test.go index 0b79fde..06f6d95 100644 --- a/transport/config_test.go +++ b/transport/config_test.go @@ -2,7 +2,6 @@ package transport import ( "github.com/stretchr/testify/assert" - "log/slog" "net/http" "testing" "time" @@ -36,8 +35,6 @@ func TestDefaultConnectionConfig(t *testing.T) { PingInterval: 20 * time.Second, IncomingBufferSize: 100, ErrorsBufferSize: 10, - LoggingEnabled: true, - LogLevel: nil, Retry: GetDefaultRetryConfig(), }) } @@ -61,8 +58,6 @@ func TestApplyConnectionOptions(t *testing.T) { conf, WithIncomingBufferSize(256), WithErrorsBufferSize(100), - WithLoggingEnabled(false), - WithLogLevel(slog.LevelError), WithRetryMaxRetries(0), WithRetryInitialDelay(3*time.Second), WithRetryJitterFactor(0.5), @@ -71,8 +66,6 @@ func TestApplyConnectionOptions(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 256, conf.IncomingBufferSize) assert.Equal(t, 100, conf.ErrorsBufferSize) - assert.False(t, conf.LoggingEnabled) - assert.Equal(t, slog.LevelError, *conf.LogLevel) assert.Equal(t, 0, conf.Retry.MaxRetries) assert.Equal(t, 3*time.Second, conf.Retry.InitialDelay) assert.Equal(t, 0.5, conf.Retry.JitterFactor) diff --git a/transport/connection.go b/transport/connection.go index 9999758..996fa95 100644 --- a/transport/connection.go +++ b/transport/connection.go @@ -12,6 +12,7 @@ import ( "time" "git.wisehodl.dev/jay/go-honeybee/types" + component "git.wisehodl.dev/jay/go-mana-component" "github.com/gorilla/websocket" ) @@ -74,7 +75,7 @@ type Connection struct { cleanupOnce sync.Once } -func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) { +func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig, handler slog.Handler) (*Connection, error) { if config == nil { config = GetDefaultConnectionConfig() } @@ -88,6 +89,18 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) return nil, err } + if component.FromContext(ctx) == nil { + ctx = component.MustNew(ctx, "honeybee", "connection") + } else { + ctx = component.MustExtend(ctx, "connection") + } + + var logger *slog.Logger + if handler != nil { + c := component.FromContext(ctx) + logger = slog.New(handler).With(slog.Any("component", c)) + } + conn := &Connection{ url: url, dialer: NewDialer(), @@ -108,7 +121,7 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) } func NewConnectionFromSocket( - socket types.Socket, config *ConnectionConfig, logger *slog.Logger, + ctx context.Context, socket types.Socket, config *ConnectionConfig, handler slog.Handler, ) (*Connection, error) { if socket == nil { return nil, NewConnectionError(ErrNilSocket) @@ -122,6 +135,18 @@ func NewConnectionFromSocket( return nil, err } + if component.FromContext(ctx) == nil { + ctx = component.MustNew(ctx, "honeybee", "connection") + } else { + ctx = component.MustExtend(ctx, "connection") + } + + var logger *slog.Logger + if handler != nil { + c := component.FromContext(ctx) + logger = slog.New(handler).With(slog.Any("component", c)) + } + conn := &Connection{ url: nil, dialer: nil, @@ -293,9 +318,7 @@ func (c *Connection) shutdownLogComplete() { } func (c *Connection) startReader() { - c.wg.Add(1) - go func() { - defer c.wg.Done() + c.wg.Go(func() { defer c.shutdownInternal() for { @@ -362,7 +385,7 @@ func (c *Connection) startReader() { } } - }() + }) } func (c *Connection) setupPongHandler() { @@ -381,9 +404,7 @@ func (c *Connection) startPinger() { return } - c.wg.Add(1) - go func() { - defer c.wg.Done() + c.wg.Go(func() { defer c.shutdownInternal() // Calculate 10% jitter window @@ -404,7 +425,7 @@ func (c *Connection) startPinger() { } } } - }() + }) } diff --git a/transport/connection_close_test.go b/transport/connection_close_test.go index 89a591b..ce6e3af 100644 --- a/transport/connection_close_test.go +++ b/transport/connection_close_test.go @@ -2,6 +2,7 @@ package transport import ( "bytes" + "context" "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "github.com/gorilla/websocket" @@ -11,7 +12,7 @@ import ( func TestDisconnectedConnectionClose(t *testing.T) { t.Run("close succeeds on disconnected connection", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) assert.Equal(t, StateDisconnected, conn.State()) @@ -20,7 +21,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { }) t.Run("close is idempotent", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) conn.Close() @@ -29,7 +30,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { }) t.Run("close with nil socket", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) assert.Nil(t, conn.socket) @@ -44,7 +45,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { return expectedErr } - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) conn.socket = mockSocket @@ -53,7 +54,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { }) t.Run("channels close after close", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) conn.Close() @@ -66,7 +67,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { }) t.Run("send fails after close", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) conn.Close() diff --git a/transport/connection_goroutine_test.go b/transport/connection_goroutine_test.go index b7cef5f..e1ea04b 100644 --- a/transport/connection_goroutine_test.go +++ b/transport/connection_goroutine_test.go @@ -1,6 +1,7 @@ package transport import ( + "context" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" @@ -66,7 +67,7 @@ func TestStartReader(t *testing.T) { return 0, nil, io.EOF } - conn, err := NewConnectionFromSocket(mockSocket, nil, nil) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, nil) assert.NoError(t, err) honeybeetest.Eventually(t, func() bool { diff --git a/transport/connection_send_test.go b/transport/connection_send_test.go index fdda050..8f7c096 100644 --- a/transport/connection_send_test.go +++ b/transport/connection_send_test.go @@ -1,6 +1,7 @@ package transport import ( + "context" "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "github.com/gorilla/websocket" @@ -62,12 +63,12 @@ func TestConnectionSend(t *testing.T) { defer close(done) var wg sync.WaitGroup - for i := 0; i < 5; i++ { + for i := range 5 { wg.Add(1) go func(id int) { defer wg.Done() - for j := 0; j < 10; j++ { - data := []byte(fmt.Sprintf("msg-%d-%d", id, j)) + for j := range 10 { + data := fmt.Appendf(nil, "msg-%d-%d", id, j) for { // send and retry until success err := conn.Send(data) @@ -129,7 +130,7 @@ func TestConnectionSend(t *testing.T) { return nil } - conn, err := NewConnectionFromSocket(mockSocket, config, nil) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, config, nil) assert.NoError(t, err) defer conn.Close() @@ -175,7 +176,7 @@ func TestConnectionSend(t *testing.T) { return nil } - conn, err := NewConnectionFromSocket(mockSocket, config, nil) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, config, nil) assert.NoError(t, err) defer conn.Close() @@ -208,7 +209,7 @@ func TestConnectionSend(t *testing.T) { return fmt.Errorf("test error") } - conn, err := NewConnectionFromSocket(mockSocket, config, nil) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, config, nil) assert.NoError(t, err) defer conn.Close() @@ -228,7 +229,7 @@ func TestConnectionSend(t *testing.T) { return writeErr } - conn, err := NewConnectionFromSocket(mockSocket, nil, nil) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, nil) assert.NoError(t, err) defer conn.Close() diff --git a/transport/connection_test.go b/transport/connection_test.go index 5b70110..b2c3de1 100644 --- a/transport/connection_test.go +++ b/transport/connection_test.go @@ -39,11 +39,11 @@ func TestConnectionStateString(t *testing.T) { func TestConnectionState(t *testing.T) { // Test initial state - conn, _ := NewConnection("ws://test", nil, nil) + conn, _ := NewConnection(context.Background(), "ws://test", nil, nil) assert.Equal(t, StateDisconnected, conn.State()) // Test state after FromSocket (should be Connected) - conn2, _ := NewConnectionFromSocket(honeybeetest.NewMockSocket(), nil, nil) + conn2, _ := NewConnectionFromSocket(context.Background(), honeybeetest.NewMockSocket(), nil, nil) assert.Equal(t, StateConnected, conn2.State()) // Test state after close @@ -94,7 +94,7 @@ func TestNewConnection(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - conn, err := NewConnection(tc.url, tc.config, nil) + conn, err := NewConnection(context.Background(), tc.url, tc.config, nil) if tc.wantErr { assert.Error(t, err) @@ -194,7 +194,7 @@ func TestNewConnectionFromSocket(t *testing.T) { } } - conn, err := NewConnectionFromSocket(tc.socket, tc.config, nil) + conn, err := NewConnectionFromSocket(context.Background(), tc.socket, tc.config, nil) if tc.wantErr { assert.Error(t, err) @@ -236,7 +236,7 @@ func TestNewConnectionFromSocket(t *testing.T) { func TestConnect(t *testing.T) { t.Run("connect fails when socket already present", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) conn.socket = honeybeetest.NewMockSocket() @@ -248,7 +248,7 @@ func TestConnect(t *testing.T) { }) t.Run("connect fails when connection closed", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) conn.Close() @@ -260,7 +260,7 @@ func TestConnect(t *testing.T) { }) t.Run("connect succeeds and starts goroutines", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) outgoingData := make(chan honeybeetest.MockOutgoingData, 10) @@ -306,7 +306,7 @@ func TestConnect(t *testing.T) { JitterFactor: 0.0, }, } - conn, err := NewConnection("ws://test", config, nil) + conn, err := NewConnection(context.Background(), "ws://test", config, nil) assert.NoError(t, err) attemptCount := 0 @@ -338,7 +338,7 @@ func TestConnect(t *testing.T) { JitterFactor: 0.0, }, } - conn, err := NewConnection("ws://test", config, nil) + conn, err := NewConnection(context.Background(), "ws://test", config, nil) assert.NoError(t, err) mockDialer := &honeybeetest.MockDialer{ @@ -355,7 +355,7 @@ func TestConnect(t *testing.T) { }) t.Run("state transitions during connect", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) assert.Equal(t, StateDisconnected, conn.State()) @@ -383,7 +383,7 @@ func TestConnect(t *testing.T) { return nil }, } - conn, err := NewConnection("ws://test", config, nil) + conn, err := NewConnection(context.Background(), "ws://test", config, nil) assert.NoError(t, err) mockSocket := honeybeetest.NewMockSocket() @@ -408,7 +408,7 @@ func TestConnect(t *testing.T) { t.Run("passes headers when configured", func(t *testing.T) { header := http.Header{"X-Custom": []string{"val"}} conf, _ := NewConnectionConfig(WithRequestHeader(header)) - conn, _ := NewConnection("ws://test", conf, nil) + conn, _ := NewConnection(context.Background(), "ws://test", conf, nil) dialCalled := false conn.dialer = &honeybeetest.MockDialer{ @@ -436,7 +436,7 @@ func TestConnectContextCancellation(t *testing.T) { JitterFactor: 0.0, }, } - conn, err := NewConnection("ws://test", config, nil) + conn, err := NewConnection(context.Background(), "ws://test", config, nil) assert.NoError(t, err) dialCount := atomic.Int32{} @@ -475,7 +475,7 @@ func TestConnectContextCancellation(t *testing.T) { // Connection method tests func TestConnectionIncoming(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) incoming := conn.Incoming() @@ -498,7 +498,7 @@ func TestConnectionErrors(t *testing.T) { } } - conn, err := NewConnectionFromSocket(mockSocket, nil, nil) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, nil) assert.NoError(t, err) defer conn.Close() @@ -521,7 +521,7 @@ func TestConnectionErrors(t *testing.T) { } } - conn, err := NewConnectionFromSocket(mockSocket, nil, nil) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, nil) assert.NoError(t, err) defer conn.Close() @@ -541,7 +541,7 @@ func TestConnectionErrors(t *testing.T) { return 0, nil, io.EOF } - conn, err := NewConnectionFromSocket(mockSocket, nil, nil) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, nil) assert.NoError(t, err) defer conn.Close() @@ -573,7 +573,7 @@ func TestConnectionHeartbeat(t *testing.T) { ) assert.NoError(t, err) - conn, _ := NewConnectionFromSocket(socket, conf, nil) + conn, _ := NewConnectionFromSocket(context.Background(), socket, conf, nil) defer conn.Close() honeybeetest.Eventually(t, @@ -586,7 +586,7 @@ func TestConnectionHeartbeat(t *testing.T) { socket, _, _ := honeybeetest.SetupTestSocket(t) socket.SetPongHandlerFunc = func(h func(string) error) { handler = h } - conn, _ := NewConnectionFromSocket(socket, nil, nil) + conn, _ := NewConnectionFromSocket(context.Background(), socket, nil, nil) defer conn.Close() honeybeetest.Eventually(t, func() bool { @@ -620,7 +620,7 @@ func setupTestConnection(t *testing.T) ( socket, incoming, outgoing = honeybeetest.SetupTestSocket(t) var err error - conn, err = NewConnectionFromSocket(socket, nil, nil) + conn, err = NewConnectionFromSocket(context.Background(), socket, nil, nil) assert.NoError(t, err) return } diff --git a/transport/logging_test.go b/transport/logging_test.go index 2a6e0e5..47e4c29 100644 --- a/transport/logging_test.go +++ b/transport/logging_test.go @@ -8,6 +8,7 @@ import ( "net/http" "testing" "time" + // slog used for ExpectedLog level constants "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/types" @@ -26,9 +27,8 @@ func log(level slog.Level, msg string, attrs map[string]any) honeybeetest.Expect func TestConnectLogging(t *testing.T) { t.Run("success", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - logger := slog.New(mockHandler) - conn, err := NewConnection("ws://test", nil, logger) + conn, err := NewConnection(context.Background(), "ws://test", nil, mockHandler) assert.NoError(t, err) mockSocket := honeybeetest.NewMockSocket() @@ -57,7 +57,6 @@ func TestConnectLogging(t *testing.T) { t.Run("max retries failure", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - logger := slog.New(mockHandler) config := &ConnectionConfig{ Retry: &RetryConfig{ @@ -68,7 +67,7 @@ func TestConnectLogging(t *testing.T) { }, } - conn, err := NewConnection("ws://test", config, logger) + conn, err := NewConnection(context.Background(), "ws://test", config, mockHandler) assert.NoError(t, err) dialErr := fmt.Errorf("dial error") @@ -100,7 +99,6 @@ func TestConnectLogging(t *testing.T) { t.Run("success after retry", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - logger := slog.New(mockHandler) config := &ConnectionConfig{ Retry: &RetryConfig{ @@ -111,7 +109,7 @@ func TestConnectLogging(t *testing.T) { }, } - conn, err := NewConnection("ws://test", config, logger) + conn, err := NewConnection(context.Background(), "ws://test", config, mockHandler) assert.NoError(t, err) attemptCount := 0 @@ -151,10 +149,9 @@ func TestConnectLogging(t *testing.T) { func TestCloseLogging(t *testing.T) { t.Run("normal close", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - logger := slog.New(mockHandler) mockSocket := honeybeetest.NewMockSocket() - conn, err := NewConnectionFromSocket(mockSocket, nil, logger) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, mockHandler) assert.NoError(t, err) conn.Close() @@ -176,7 +173,6 @@ func TestCloseLogging(t *testing.T) { t.Run("close with socket error", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - logger := slog.New(mockHandler) closeErr := fmt.Errorf("close error") mockSocket := honeybeetest.NewMockSocket() @@ -184,7 +180,7 @@ func TestCloseLogging(t *testing.T) { return closeErr } - conn, err := NewConnectionFromSocket(mockSocket, nil, logger) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, mockHandler) assert.NoError(t, err) conn.Close() @@ -208,7 +204,6 @@ func TestCloseLogging(t *testing.T) { func TestReaderLogging(t *testing.T) { t.Run("clean close by peer", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - logger := slog.New(mockHandler) mockSocket := honeybeetest.NewMockSocket() mockSocket.ReadMessageFunc = func() (int, []byte, error) { @@ -218,7 +213,7 @@ func TestReaderLogging(t *testing.T) { } } - conn, err := NewConnectionFromSocket(mockSocket, nil, logger) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, mockHandler) assert.NoError(t, err) defer conn.Close() @@ -236,7 +231,6 @@ func TestReaderLogging(t *testing.T) { t.Run("unexpected close", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - logger := slog.New(mockHandler) mockSocket := honeybeetest.NewMockSocket() mockSocket.ReadMessageFunc = func() (int, []byte, error) { @@ -246,7 +240,7 @@ func TestReaderLogging(t *testing.T) { } } - conn, err := NewConnectionFromSocket(mockSocket, nil, logger) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, mockHandler) assert.NoError(t, err) defer conn.Close() @@ -264,14 +258,13 @@ func TestReaderLogging(t *testing.T) { t.Run("read error", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - logger := slog.New(mockHandler) mockSocket := honeybeetest.NewMockSocket() mockSocket.ReadMessageFunc = func() (int, []byte, error) { return 0, nil, io.EOF } - conn, err := NewConnectionFromSocket(mockSocket, nil, logger) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, mockHandler) assert.NoError(t, err) defer conn.Close() @@ -285,7 +278,6 @@ func TestReaderLogging(t *testing.T) { func TestWriterLogging(t *testing.T) { t.Run("write deadline error", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - logger := slog.New(mockHandler) config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond} @@ -295,7 +287,7 @@ func TestWriterLogging(t *testing.T) { return deadlineErr } - conn, err := NewConnectionFromSocket(mockSocket, config, logger) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, config, mockHandler) assert.NoError(t, err) err = conn.Send([]byte("test")) @@ -317,7 +309,6 @@ func TestWriterLogging(t *testing.T) { t.Run("write message error", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - logger := slog.New(mockHandler) writeErr := fmt.Errorf("write error") mockSocket := honeybeetest.NewMockSocket() @@ -325,7 +316,7 @@ func TestWriterLogging(t *testing.T) { return writeErr } - conn, err := NewConnectionFromSocket(mockSocket, nil, logger) + conn, err := NewConnectionFromSocket(context.Background(), mockSocket, nil, mockHandler) assert.NoError(t, err) err = conn.Send([]byte("test")) @@ -350,7 +341,7 @@ func TestLoggingDisabled(t *testing.T) { t.Run("nil logger produces no logs", func(t *testing.T) { mockHandler := honeybeetest.NewMockSlogHandler() - conn, err := NewConnection("ws://test", nil, nil) + conn, err := NewConnection(context.Background(), "ws://test", nil, nil) assert.NoError(t, err) mockSocket := honeybeetest.NewMockSocket() diff --git a/transport/retry.go b/transport/retry.go index f12e2b0..a888a62 100644 --- a/transport/retry.go +++ b/transport/retry.go @@ -59,22 +59,16 @@ func (r *RetryManager) CalculateDelay() time.Duration { } // Exponential backoff: InitialDelay * 2^(attempts-1) - shift := r.retryCount - 1 - if shift > 62 { - shift = 62 - } // prevent overflow + shift := min(r.retryCount-1, 62) // prevent overflow backoffMultiplier := float64(int64(1) << shift) baseDelay := float64(r.config.InitialDelay) * backoffMultiplier // Apply jitter: delay * (1 + jitterFactor * (random - 0.5)) random := rand.Float64() jitterMultiplier := 1 + r.config.JitterFactor*(random-0.5) - delay := time.Duration(baseDelay * jitterMultiplier) - - // Cap at MaxDelay - if delay > r.config.MaxDelay { - delay = r.config.MaxDelay - } + delay := min( + // Cap at MaxDelay + time.Duration(baseDelay*jitterMultiplier), r.config.MaxDelay) return delay } diff --git a/transport/watchdog_test.go b/transport/watchdog_test.go index 824b6db..f7c8a23 100644 --- a/transport/watchdog_test.go +++ b/transport/watchdog_test.go @@ -12,15 +12,14 @@ import ( func TestIdleWatchdog(t *testing.T) { t.Run("heartbeat resets timer, onTimeout not called", func(t *testing.T) { activity := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() called := atomic.Bool{} go IdleWatchdog( ctx, activity, 200*time.Millisecond, func() { called.Store(true) }, ) - for i := 0; i < 5; i++ { + for range 5 { time.Sleep(20 * time.Millisecond) activity <- struct{}{} } @@ -32,8 +31,7 @@ func TestIdleWatchdog(t *testing.T) { t.Run("timeout fires onTimeout exactly once", func(t *testing.T) { activity := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() count := atomic.Int32{} done := make(chan struct{}) @@ -123,7 +121,7 @@ func TestIdleWatchdog(t *testing.T) { }() // these must not block - for i := 0; i < 5; i++ { + for range 5 { activity <- struct{}{} } diff --git a/worker.go b/worker.go index c055cc0..bffc816 100644 --- a/worker.go +++ b/worker.go @@ -2,13 +2,14 @@ package honeybee import ( "context" - "git.wisehodl.dev/jay/go-honeybee/logging" - "git.wisehodl.dev/jay/go-honeybee/transport" - "git.wisehodl.dev/jay/go-honeybee/types" "log/slog" "sync" "sync/atomic" "time" + + "git.wisehodl.dev/jay/go-honeybee/transport" + "git.wisehodl.dev/jay/go-honeybee/types" + component "git.wisehodl.dev/jay/go-mana-component" ) // Worker @@ -42,17 +43,18 @@ type DefaultWorker struct { outgoingCount *atomic.Uint64 restartCount *atomic.Uint64 - config *WorkerConfig - ctx context.Context - cancel context.CancelFunc - logger *slog.Logger + config *WorkerConfig + ctx context.Context + cancel context.CancelFunc + handler slog.Handler + logger *slog.Logger } func NewWorker( ctx context.Context, id string, config *WorkerConfig, - logger *slog.Logger, + handler slog.Handler, ) (*DefaultWorker, error) { if config == nil { config = GetDefaultWorkerConfig() @@ -61,6 +63,18 @@ func NewWorker( return nil, err } + if component.FromContext(ctx) == nil { + ctx = component.MustNew(ctx, "honeybee", "worker") + } else { + ctx = component.MustExtend(ctx, "worker") + } + + var logger *slog.Logger + if handler != nil { + c := component.FromContext(ctx) + logger = slog.New(handler).With(slog.Any("component", c), slog.String("peer_id", id)) + } + wctx, wcancel := context.WithCancel(ctx) w := &DefaultWorker{ id: id, @@ -71,6 +85,7 @@ func NewWorker( restartCount: &atomic.Uint64{}, ctx: wctx, cancel: wcancel, + handler: handler, logger: logger, } @@ -91,7 +106,7 @@ func (w *DefaultWorker) Start(pool PoolPlugin) { go func() { defer wg.Done() - RunDialer(w.id, w.ctx, pool, dial, newConn, w.logger) + RunDialer(w.id, w.ctx, pool, dial, newConn, w.handler, w.logger) }() go func() { @@ -447,14 +462,9 @@ func connect( id string, ctx context.Context, pool PoolPlugin, + handler slog.Handler, ) (*transport.Connection, error) { - var logger *slog.Logger - if pool.Handler != nil && pool.ConnectionConfig.LoggingEnabled { - logger = logging.NewConnectionLogger( - logging.WrapOrDefault(pool.ConnectionConfig.LogLevel, pool.Handler), pool.ID, id) - } - - conn, err := transport.NewConnection(id, pool.ConnectionConfig, logger) + conn, err := transport.NewConnection(ctx, id, pool.ConnectionConfig, handler) if err != nil { return nil, err } @@ -471,6 +481,7 @@ func RunDialer( dial <-chan struct{}, newConn chan<- *transport.Connection, + handler slog.Handler, logger *slog.Logger, ) { for { @@ -482,7 +493,7 @@ func RunDialer( logger.Debug("dialer: dialing") } // dial a new connection - conn, err := connect(id, ctx, pool) + conn, err := connect(id, ctx, pool, handler) // send error if dial failed and continue if err != nil { diff --git a/worker_dialer_test.go b/worker_dialer_test.go index 1b26a91..428b4fe 100644 --- a/worker_dialer_test.go +++ b/worker_dialer_test.go @@ -19,8 +19,7 @@ func TestRunDialer(t *testing.T) { url := "wss://test" dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() mockSocket := honeybeetest.NewMockSocket() pool := PoolPlugin{ @@ -31,7 +30,7 @@ func TestRunDialer(t *testing.T) { }, } - go RunDialer(url, ctx, pool, dial, newConn, nil) + go RunDialer(url, ctx, pool, dial, newConn, nil, nil) dial <- struct{}{} honeybeetest.Eventually(t, func() bool { @@ -49,8 +48,7 @@ func TestRunDialer(t *testing.T) { url := "wss://test" dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() gate := make(chan struct{}) dialCount := atomic.Int32{} @@ -71,14 +69,14 @@ func TestRunDialer(t *testing.T) { ConnectionConfig: connConfig, } - go RunDialer(url, ctx, pool, dial, newConn, nil) + go RunDialer(url, ctx, pool, dial, newConn, nil, nil) dial <- struct{}{} // wait for dial to start blocking on gate <-started // flood dial while dialer is blocked - for i := 0; i < 5; i++ { + for range 5 { select { case dial <- struct{}{}: default: @@ -114,8 +112,7 @@ func TestRunDialer(t *testing.T) { url := "wss://test" dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() // use atomic counter to fail first dial and pass second dialCount := atomic.Int32{} @@ -137,7 +134,7 @@ func TestRunDialer(t *testing.T) { ConnectionConfig: connConfig, } - go RunDialer(url, ctx, pool, dial, newConn, nil) + go RunDialer(url, ctx, pool, dial, newConn, nil, nil) dial <- struct{}{} dial <- struct{}{} @@ -161,7 +158,7 @@ func TestRunDialer(t *testing.T) { done := make(chan struct{}) go func() { - RunDialer(url, ctx, pool, dial, newConn, nil) + RunDialer(url, ctx, pool, dial, newConn, nil, nil) close(done) }() @@ -198,7 +195,7 @@ func TestRunDialer(t *testing.T) { done := make(chan struct{}) go func() { - RunDialer(url, ctx, pool, dial, newConn, nil) + RunDialer(url, ctx, pool, dial, newConn, nil, nil) close(done) }() diff --git a/worker_keepalive_test.go b/worker_keepalive_test.go index 54985da..225d31d 100644 --- a/worker_keepalive_test.go +++ b/worker_keepalive_test.go @@ -12,13 +12,12 @@ func TestRunKeepalive(t *testing.T) { heartbeat := make(chan struct{}) keepalive := make(chan struct{}, 1) timeout := 200 * time.Millisecond - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() go RunKeepalive(ctx, heartbeat, keepalive, timeout, nil) // send heartbeats faster than the timeout - for i := 0; i < 5; i++ { + for range 5 { time.Sleep(20 * time.Millisecond) heartbeat <- struct{}{} } @@ -38,8 +37,7 @@ func TestRunKeepalive(t *testing.T) { heartbeat := make(chan struct{}, 1) keepalive := make(chan struct{}, 1) timeout := 20 * time.Millisecond - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() go RunKeepalive(ctx, heartbeat, keepalive, timeout, nil) @@ -80,13 +78,12 @@ func TestRunKeepalive(t *testing.T) { t.Run("disabled keepalive drains heartbeats without blocking", func(t *testing.T) { heartbeat := make(chan struct{}) keepalive := make(chan struct{}, 1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() go RunKeepalive(ctx, heartbeat, keepalive, 0, nil) // these must not block - for i := 0; i < 5; i++ { + for range 5 { heartbeat <- struct{}{} } diff --git a/worker_send_test.go b/worker_send_test.go index 4785a68..91dada4 100644 --- a/worker_send_test.go +++ b/worker_send_test.go @@ -81,8 +81,8 @@ func TestWorkerSend(t *testing.T) { }() const count = 3 - for i := 0; i < count; i++ { - err := w.Send([]byte(fmt.Sprintf("msg-%d", i))) + for i := range count { + err := w.Send(fmt.Appendf(nil, "msg-%d", i)) assert.NoError(t, err) } diff --git a/worker_session_inner_test.go b/worker_session_inner_test.go index 7ed810b..5a71b11 100644 --- a/worker_session_inner_test.go +++ b/worker_session_inner_test.go @@ -68,10 +68,10 @@ func TestRunReader(t *testing.T) { go RunReader("wss://test", ctx, cancel, conn, inbox, heartbeat, nil) const count = 3 - for i := 0; i < count; i++ { + for i := range count { incomingData <- honeybeetest.MockIncomingData{ MsgType: websocket.TextMessage, - Data: []byte(fmt.Sprintf("msg-%d", i)), + Data: fmt.Appendf(nil, "msg-%d", i), } } @@ -150,12 +150,11 @@ func TestHeartbeatForwarder(t *testing.T) { var pongHandler func(string) error socket.SetPongHandlerFunc = func(h func(string) error) { pongHandler = h } - conn, err := transport.NewConnectionFromSocket(socket, nil, nil) + conn, err := transport.NewConnectionFromSocket(context.Background(), socket, nil, nil) assert.NoError(t, err) heartbeat := make(chan struct{}, 1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() go RunHeartbeatForwarder(ctx, conn, heartbeat, nil) diff --git a/worker_session_test.go b/worker_session_test.go index 147cbdd..b5c7192 100644 --- a/worker_session_test.go +++ b/worker_session_test.go @@ -140,7 +140,7 @@ func TestRunSessionDial(t *testing.T) { // drain initial dial expectDial(t, v.dial) - for i := 0; i < 3; i++ { + for range 3 { v.keepalive <- struct{}{} expectDial(t, v.dial) } diff --git a/worker_start_test.go b/worker_start_test.go index f77ca36..e598f57 100644 --- a/worker_start_test.go +++ b/worker_start_test.go @@ -65,11 +65,9 @@ func TestWorkerStart(t *testing.T) { pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup - wg.Add(1) - go func() { + wg.Go(func() { w.Start(pool) - wg.Done() - }() + }) honeybeetest.Eventually(t, func() bool { select { @@ -91,11 +89,9 @@ func TestWorkerStart(t *testing.T) { pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup - wg.Add(1) - go func() { + wg.Go(func() { w.Start(pool) - wg.Done() - }() + }) honeybeetest.Eventually(t, func() bool { select { @@ -144,11 +140,9 @@ func TestWorkerStart(t *testing.T) { pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup - wg.Add(1) - go func() { + wg.Go(func() { w.Start(pool) - wg.Done() - }() + }) honeybeetest.Eventually(t, func() bool { select { @@ -184,11 +178,9 @@ func TestWorkerStart(t *testing.T) { pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup - wg.Add(1) - go func() { + wg.Go(func() { w.Start(pool) - wg.Done() - }() + }) honeybeetest.Eventually(t, func() bool { select { @@ -230,11 +222,9 @@ func TestWorkerStart(t *testing.T) { pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup - wg.Add(1) - go func() { + wg.Go(func() { w.Start(pool) - wg.Done() - }() + }) honeybeetest.Eventually(t, func() bool { select { @@ -278,11 +268,9 @@ func TestWorkerStart(t *testing.T) { pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup - wg.Add(1) - go func() { + wg.Go(func() { w.Start(pool) - wg.Done() - }() + }) honeybeetest.Eventually(t, func() bool { select {