From f1afca792130e062a2aa6e4fda75e9fc053a9e56 Mon Sep 17 00:00:00 2001 From: Jay Date: Wed, 20 May 2026 22:49:25 -0400 Subject: [PATCH] cleanup and refactors --- config.go | 30 ++- honeybeetest/helpers.go | 10 + honeybeetest/mocks.go | 10 +- pool.go | 44 ++-- transport/config.go | 12 + transport/connection.go | 493 +++++++++++++++++++++------------------- transport/socket.go | 4 + worker.go | 286 +++++++++++++---------- worker_send_test.go | 117 ---------- worker_test.go | 118 +++++++++- 10 files changed, 628 insertions(+), 496 deletions(-) delete mode 100644 worker_send_test.go diff --git a/config.go b/config.go index 0cc2436..396b1e7 100644 --- a/config.go +++ b/config.go @@ -1,21 +1,15 @@ package honeybee import ( - "context" "git.wisehodl.dev/jay/go-honeybee/transport" - "log/slog" "time" ) -// Types - -type WorkerFactory func( - ctx context.Context, - id string, - handler slog.Handler, -) (Worker, error) - +// ---------------------------------------------------------------------------- // Pool Config +// ---------------------------------------------------------------------------- + +// Types type PoolConfig struct { InboxBufferSize int @@ -27,6 +21,8 @@ type PoolConfig struct { type PoolOption func(*PoolConfig) error +// Constructor + func NewPoolConfig(options ...PoolOption) (*PoolConfig, error) { conf := GetDefaultPoolConfig() if err := applyPoolOptions(conf, options...); err != nil { @@ -57,6 +53,8 @@ func applyPoolOptions(config *PoolConfig, options ...PoolOption) error { return nil } +// Validation + func ValidatePoolConfig(config *PoolConfig) error { var err error @@ -84,6 +82,8 @@ func validateBufferSize(value int) error { return nil } +// Options + func WithInboxBufferSize(value int) PoolOption { return func(c *PoolConfig) error { if err := validateBufferSize(value); err != nil { @@ -133,7 +133,11 @@ func WithWorkerFactory(wf WorkerFactory) PoolOption { } } +// ---------------------------------------------------------------------------- // Worker Config +// ---------------------------------------------------------------------------- + +// Types type WorkerConfig struct { KeepaliveTimeout time.Duration @@ -142,6 +146,8 @@ type WorkerConfig struct { type WorkerOption func(*WorkerConfig) error +// Constructor + func NewWorkerConfig(options ...WorkerOption) (*WorkerConfig, error) { conf := GetDefaultWorkerConfig() if err := applyWorkerOptions(conf, options...); err != nil { @@ -169,6 +175,8 @@ func applyWorkerOptions(config *WorkerConfig, options ...WorkerOption) error { return nil } +// Validation + func ValidateWorkerConfig(config *WorkerConfig) error { err := validateKeepaliveTimeout(config.KeepaliveTimeout) if err != nil { @@ -192,6 +200,8 @@ func validateReconnectDelay(value time.Duration) error { return nil } +// Options + // When KeepaliveTimeout is set to zero, keepalive timeouts are disabled. func WithKeepaliveTimeout(value time.Duration) WorkerOption { return func(c *WorkerConfig) error { diff --git a/honeybeetest/helpers.go b/honeybeetest/helpers.go index ba94fd0..0e3e3bd 100644 --- a/honeybeetest/helpers.go +++ b/honeybeetest/helpers.go @@ -10,7 +10,9 @@ import ( "time" ) +// ---------------------------------------------------------------------------- // Constants +// ---------------------------------------------------------------------------- const ( TestTimeout = 2 * time.Second @@ -18,7 +20,9 @@ const ( NegativeTestTimeout = 100 * time.Millisecond ) +// ---------------------------------------------------------------------------- // Types +// ---------------------------------------------------------------------------- type MockIncomingData struct { MsgType int @@ -37,7 +41,9 @@ type ExpectedLog struct { Attrs map[string]any } +// ---------------------------------------------------------------------------- // Setup +// ---------------------------------------------------------------------------- func SetupTestSocket(t *testing.T) ( socket *MockSocket, @@ -81,7 +87,9 @@ func SetupTestSocket(t *testing.T) ( return } +// ---------------------------------------------------------------------------- // Helpers +// ---------------------------------------------------------------------------- func ExpectIncoming(t *testing.T, incoming <-chan []byte, expected []byte) { t.Helper() @@ -126,7 +134,9 @@ func Never(t *testing.T, condition func() bool, msg string) { assert.Never(t, condition, NegativeTestTimeout, TestTick, msg) } +// ---------------------------------------------------------------------------- // Logging Helpers +// ---------------------------------------------------------------------------- func AssertLogSequence(t *testing.T, records []slog.Record, expected []ExpectedLog) { t.Helper() diff --git a/honeybeetest/mocks.go b/honeybeetest/mocks.go index ffb28e9..17c0ce6 100644 --- a/honeybeetest/mocks.go +++ b/honeybeetest/mocks.go @@ -9,12 +9,16 @@ import ( "time" ) -// Re-exported types for consumer convenience +// ---------------------------------------------------------------------------- +// Re-exports +// ---------------------------------------------------------------------------- type Socket = types.Socket type Dialer = types.Dialer +// ---------------------------------------------------------------------------- // Dialer Mocks +// ---------------------------------------------------------------------------- type MockDialer struct { DialContextFunc func( @@ -28,7 +32,9 @@ func (m *MockDialer) DialContext( return m.DialContextFunc(ctx, url, h) } +// ---------------------------------------------------------------------------- // Socket Mocks +// ---------------------------------------------------------------------------- type MockSocket struct { WriteMessageFunc func(int, []byte) error @@ -93,7 +99,9 @@ func (m *MockSocket) SetPongHandler(h func(s string) error) { m.SetPongHandlerFunc(h) } +// ---------------------------------------------------------------------------- // Logging mocks +// ---------------------------------------------------------------------------- type MockSlogHandler struct { records *[]slog.Record diff --git a/pool.go b/pool.go index f406a8d..fdb39b8 100644 --- a/pool.go +++ b/pool.go @@ -6,7 +6,7 @@ import ( "git.wisehodl.dev/jay/go-honeybee/transport" "git.wisehodl.dev/jay/go-honeybee/types" - component "git.wisehodl.dev/jay/go-mana-component" + "git.wisehodl.dev/jay/go-mana-component" "sync" "sync/atomic" "time" @@ -19,7 +19,9 @@ type Dialer = types.Dialer var NormalizeURL = transport.NormalizeURL +// ---------------------------------------------------------------------------- // Types +// ---------------------------------------------------------------------------- type PoolEventKind string @@ -58,7 +60,9 @@ type PoolPlugin struct { ConnectionConfig *transport.ConnectionConfig } +// ---------------------------------------------------------------------------- // Pool +// ---------------------------------------------------------------------------- type Peer struct { id string @@ -66,24 +70,23 @@ type Peer struct { } type Pool struct { - ctx context.Context - cancel context.CancelFunc - peers map[string]*Peer inbox chan types.InboxMessage events chan PoolEvent - - inboxCounter *atomic.Uint64 - outgoingCount *atomic.Uint64 + closed bool dialer types.Dialer config *PoolConfig handler slog.Handler logger *slog.Logger + ctx context.Context + cancel context.CancelFunc mu sync.RWMutex wg sync.WaitGroup - closed bool + + inboxCounter *atomic.Uint64 + outgoingCount *atomic.Uint64 } func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler, @@ -106,26 +109,29 @@ func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler, return nil, err } - pctx, cancel := context.WithCancel(component.MustNew(ctx, "honeybee", "pool")) + ctx, cancel := context.WithCancel(component.MustNew(ctx, "honeybee", "pool")) var logger *slog.Logger if handler != nil { - c := component.FromContext(pctx) + c := component.FromContext(ctx) logger = slog.New(handler).With(slog.Any("component", c)) } return &Pool{ - ctx: pctx, - cancel: cancel, - peers: make(map[string]*Peer), - inbox: make(chan types.InboxMessage, config.InboxBufferSize), - events: make(chan PoolEvent, config.EventsBufferSize), + peers: make(map[string]*Peer), + inbox: make(chan types.InboxMessage, config.InboxBufferSize), + events: make(chan PoolEvent, config.EventsBufferSize), + + dialer: transport.NewDialer(), + config: config, + handler: handler, + logger: logger, + + ctx: ctx, + cancel: cancel, + inboxCounter: &atomic.Uint64{}, outgoingCount: &atomic.Uint64{}, - dialer: transport.NewDialer(), - config: config, - handler: handler, - logger: logger, }, nil } diff --git a/transport/config.go b/transport/config.go index 6e71047..2242dce 100644 --- a/transport/config.go +++ b/transport/config.go @@ -5,6 +5,12 @@ import ( "time" ) +// ---------------------------------------------------------------------------- +// Connection Config +// ---------------------------------------------------------------------------- + +// Types + type CloseHandler func(code int, text string) error type ConnectionConfig struct { @@ -26,6 +32,8 @@ type RetryConfig struct { type ConnectionOption func(*ConnectionConfig) error +// Constructors + func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) { conf := GetDefaultConnectionConfig() if err := applyConnectionOptions(conf, options...); err != nil { @@ -69,6 +77,8 @@ func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOptio return nil } +// Validation + func ValidateConnectionConfig(config *ConnectionConfig) error { err := validateWriteTimeout(config.WriteTimeout) if err != nil { @@ -153,6 +163,8 @@ func validateJitterFactor(value float64) error { return nil } +// Options + func WithCloseHandler(handler CloseHandler) ConnectionOption { return func(c *ConnectionConfig) error { c.CloseHandler = handler diff --git a/transport/connection.go b/transport/connection.go index 996fa95..e70fd79 100644 --- a/transport/connection.go +++ b/transport/connection.go @@ -12,10 +12,14 @@ import ( "time" "git.wisehodl.dev/jay/go-honeybee/types" - component "git.wisehodl.dev/jay/go-mana-component" + "git.wisehodl.dev/jay/go-mana-component" "github.com/gorilla/websocket" ) +// ---------------------------------------------------------------------------- +// Types +// ---------------------------------------------------------------------------- + type ConnectionState int const ( @@ -49,6 +53,14 @@ type ConnectionStats struct { TotalHeartbeats uint64 } +// ---------------------------------------------------------------------------- +// Connection +// ---------------------------------------------------------------------------- + +// ---------------------------/ +// Constructors +// -------------------------/ + type Connection struct { url *url.URL dialer types.Dialer @@ -95,18 +107,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig, ctx = component.MustExtend(ctx, "connection") } - var logger *slog.Logger - if handler != nil { - c := component.FromContext(ctx) - logger = slog.New(handler).With(slog.Any("component", c)) - } - conn := &Connection{ url: url, dialer: NewDialer(), socket: nil, config: config, - logger: logger, incoming: make(chan []byte, config.IncomingBufferSize), heartbeat: make(chan struct{}, 1), errors: make(chan error, config.ErrorsBufferSize), @@ -117,6 +122,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig, done: make(chan struct{}), } + if handler != nil { + comp := component.FromContext(ctx) + conn.logger = slog.New(handler).With(slog.Any("component", comp)) + } + return conn, nil } @@ -141,18 +151,11 @@ func NewConnectionFromSocket( ctx = component.MustExtend(ctx, "connection") } - var logger *slog.Logger - if handler != nil { - c := component.FromContext(ctx) - logger = slog.New(handler).With(slog.Any("component", c)) - } - conn := &Connection{ url: nil, dialer: nil, socket: socket, config: config, - logger: logger, incoming: make(chan []byte, config.IncomingBufferSize), heartbeat: make(chan struct{}, 1), errors: make(chan error, config.ErrorsBufferSize), @@ -163,17 +166,31 @@ func NewConnectionFromSocket( done: make(chan struct{}), } + if handler != nil { + comp := component.FromContext(ctx) + conn.logger = slog.New(handler).With(slog.Any("component", comp)) + } + + // initialize if config.CloseHandler != nil { socket.SetCloseHandler(config.CloseHandler) } conn.setupPongHandler() - conn.startPinger() - conn.startReader() + + if conn.config.PingInterval > 0 { + conn.wg.Go(conn.startPinger) + } + + conn.wg.Go(conn.startReader) return conn, nil } +// ---------------------------/ +// Methods +// -------------------------/ + func (c *Connection) Connect(ctx context.Context) error { c.mu.Lock() defer c.mu.Unlock() @@ -186,17 +203,20 @@ func (c *Connection) Connect(ctx context.Context) error { return NewConnectionError(ErrConnectionClosed) } + // begin connecting if c.logger != nil { c.logger.Debug("connecting") } c.state = StateConnecting + // obtain socket retryMgr := NewRetryManager(c.config.Retry) socket, _, err := AcquireSocket( ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger) if err != nil { + // socket acquisition failed c.state = StateDisconnected if c.logger != nil { c.logger.Error("connection failed", "error", err) @@ -204,231 +224,32 @@ func (c *Connection) Connect(ctx context.Context) error { return NewConnectionError(err) } + // got socket c.socket = socket - c.state = StateConnected + // initialize if c.config.CloseHandler != nil { c.socket.SetCloseHandler(c.config.CloseHandler) } + c.setupPongHandler() + + if c.config.PingInterval > 0 { + c.wg.Go(c.startPinger) + } + + c.wg.Go(c.startReader) + + // connected + c.state = StateConnected + if c.logger != nil { c.logger.Info("connected") } - c.setupPongHandler() - c.startPinger() - c.startReader() - return nil } -func (c *Connection) Close() { - c.shutdownExternal() -} - -func (c *Connection) shutdownExternal() { - err := c.shutdownSetClosed(true) - if err != nil { - return - } - - c.shutdownInner() - c.shutdownCleanup() -} - -func (c *Connection) shutdownInternal() { - err := c.shutdownSetClosed(false) - if err != nil { - return - } - - c.shutdownInner() - - // defer final cleanup to allow this function to return - // otherwise, a deadlock occurs where startReader triggers a shutdown and - // must wait for itself to exit. - go func() { - c.shutdownCleanup() - }() -} - -func (c *Connection) shutdownInner() { - c.shutdownSignalDone() - c.shutdownLogStart() - c.shutdownCloseSocket() -} - -func (c *Connection) shutdownCleanup() { - c.cleanupOnce.Do(func() { - c.wg.Wait() - c.shutdownCloseChannels() - c.shutdownLogComplete() - }) -} - -func (c *Connection) shutdownSetClosed(wait bool) error { - c.mu.Lock() - if c.closed { - c.mu.Unlock() - return NewConnectionError(ErrConnectionClosed) - } - c.closed = true - c.state = StateClosed - c.mu.Unlock() - return nil -} - -func (c *Connection) shutdownSignalDone() { - c.doneOnce.Do(func() { - close(c.done) - }) -} - -func (c *Connection) shutdownLogStart() { - if c.logger != nil { - c.logger.Info("closing") - } -} - -func (c *Connection) shutdownCloseSocket() { - if c.socket != nil { - // force unblock of any network operations immediately - expired := time.Now().Add(-1 * time.Minute) - c.socket.SetReadDeadline(expired) - c.socket.SetWriteDeadline(expired) - - // close socket - err := c.socket.Close() - - if err != nil && c.logger != nil { - c.logger.Error("socket close failed", "error", err) - } - } -} - -func (c *Connection) shutdownCloseChannels() { - close(c.incoming) - close(c.errors) -} - -func (c *Connection) shutdownLogComplete() { - if c.logger != nil { - c.logger.Info("closed") - } -} - -func (c *Connection) startReader() { - c.wg.Go(func() { - defer c.shutdownInternal() - - for { - select { - case <-c.done: - return - default: - messageType, data, err := c.socket.ReadMessage() - if err != nil { - var wrappedErr error - var closeErr *websocket.CloseError - if errors.As(err, &closeErr) { - switch closeErr.Code { - case websocket.CloseNormalClosure, websocket.CloseGoingAway: - if c.logger != nil { - c.logger.Info("connection closed by peer", - "code", closeErr.Code, - "text", closeErr.Text, - ) - } - wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedClean, err) - default: - if c.logger != nil { - c.logger.Error("unexpected close", - "code", closeErr.Code, - "text", closeErr.Text, - ) - } - wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err) - } - } else { - isLocalClose := false - select { - case <-c.done: - isLocalClose = true - default: - } - if c.logger != nil { - if isLocalClose { - c.logger.Debug("read loop terminated", "error", err) - } else { - c.logger.Error("read error", "error", err) - } - } - wrappedErr = fmt.Errorf("%w: %w", ErrReadError, err) - } - - select { - case <-c.done: - case c.errors <- wrappedErr: - } - return - } - - if messageType == websocket.TextMessage || - messageType == websocket.BinaryMessage { - select { - case <-c.done: - return - case c.incoming <- data: - c.incomingCount.Add(1) - } - } - - } - } - }) -} - -func (c *Connection) setupPongHandler() { - c.socket.SetPongHandler(func(appData string) error { - select { - case c.heartbeat <- struct{}{}: - c.heartbeatCount.Add(1) - default: - } - return nil - }) -} - -func (c *Connection) startPinger() { - if c.config.PingInterval <= 0 { - return - } - - c.wg.Go(func() { - defer c.shutdownInternal() - - // Calculate 10% jitter window - jitter := c.config.PingInterval / 10 - - for { - offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter - next := c.config.PingInterval + offset - timer := time.NewTimer(next) - select { - case <-c.done: - timer.Stop() - return - case <-timer.C: - deadline := time.Now().Add(c.config.WriteTimeout) - if err := c.socket.WriteControl(websocket.PingMessage, nil, deadline); err != nil { - return - } - } - } - }) - -} - func (c *Connection) Send(data []byte) error { c.writeMu.Lock() defer c.writeMu.Unlock() @@ -437,6 +258,7 @@ func (c *Connection) Send(data []byte) error { return NewConnectionError(ErrConnectionClosed) } + // setup if c.config.WriteTimeout > 0 { if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { if c.logger != nil { @@ -446,7 +268,10 @@ func (c *Connection) Send(data []byte) error { } } - if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil { + // send + err := c.socket.WriteMessage(websocket.TextMessage, data) + + if err != nil { if c.logger != nil { c.logger.Error("write error", "error", err) } @@ -489,3 +314,211 @@ func (c *Connection) Stats() ConnectionStats { func (c *Connection) SetDialer(d types.Dialer) { c.dialer = d } + +// ---------------------------/ +// Reader loop +// -------------------------/ + +func (c *Connection) startReader() { + defer c.shutdownInternal() + + for { + select { + case <-c.done: + return + default: + messageType, data, err := c.socket.ReadMessage() + if err != nil { + select { + case <-c.done: + case c.errors <- c.classifyCloseError(err): + } + return + } + + if messageType == websocket.TextMessage || + messageType == websocket.BinaryMessage { + select { + case <-c.done: + return + case c.incoming <- data: + c.incomingCount.Add(1) + } + } + } + } +} + +func (c *Connection) classifyCloseError(err error) error { + var classifiedError error + var closeErr *websocket.CloseError + + if errors.As(err, &closeErr) { + switch closeErr.Code { + case websocket.CloseNormalClosure, websocket.CloseGoingAway: + if c.logger != nil { + c.logger.Info("connection closed by peer", + "code", closeErr.Code, + "text", closeErr.Text, + ) + } + classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedClean, err) + + default: + if c.logger != nil { + c.logger.Error("unexpected close", + "code", closeErr.Code, + "text", closeErr.Text, + ) + } + classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err) + } + + } else { + isLocalClose := false + + select { + case <-c.done: + isLocalClose = true + default: + } + + if c.logger != nil { + if isLocalClose { + c.logger.Debug("read loop terminated", "error", err) + } else { + c.logger.Error("read error", "error", err) + } + } + + classifiedError = fmt.Errorf("%w: %w", ErrReadError, err) + } + + return classifiedError +} + +// ---------------------------/ +// Heartbeat Handling +// -------------------------/ + +func (c *Connection) setupPongHandler() { + c.socket.SetPongHandler(func(appData string) error { + select { + case c.heartbeat <- struct{}{}: + c.heartbeatCount.Add(1) + default: + } + return nil + }) +} + +func (c *Connection) startPinger() { + defer c.shutdownInternal() + + // Calculate 10% jitter window + jitter := c.config.PingInterval / 10 + + for { + offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter + next := c.config.PingInterval + offset + timer := time.NewTimer(next) + select { + case <-c.done: + timer.Stop() + return + case <-timer.C: + deadline := time.Now().Add(c.config.WriteTimeout) + err := c.socket.WriteControl(websocket.PingMessage, nil, deadline) + if err != nil { + return + } + } + } + +} + +// ---------------------------/ +// Shutdown +// -------------------------/ + +func (c *Connection) Close() { + c.shutdownExternal() +} + +func (c *Connection) shutdownExternal() { + // set closed + c.mu.Lock() + if c.closed { + // idempotent shutdown + c.mu.Unlock() + return + } + c.closed = true + c.state = StateClosed + c.mu.Unlock() + + // perform shutdown + c.shutdownInner() + c.shutdownCleanup() +} + +// shutdownInternal defers final cleanup to allow it to return. +// Otherwise, a deadlock occurs where startReader triggers a shutdown and +// must wait for itself to exit. +func (c *Connection) shutdownInternal() { + // set closed + c.mu.Lock() + if c.closed { + // idempotent shutdown + c.mu.Unlock() + return + } + c.closed = true + c.state = StateClosed + c.mu.Unlock() + + // perform shutdown + c.shutdownInner() + + // defer cleanup to avoid deadlock + go func() { + c.shutdownCleanup() + }() +} + +func (c *Connection) shutdownInner() { + c.doneOnce.Do(func() { + close(c.done) + }) + + if c.logger != nil { + c.logger.Info("closing") + } + + if c.socket != nil { + // force unblock of any network operations immediately + expired := time.Now().Add(-1 * time.Minute) + c.socket.SetReadDeadline(expired) + c.socket.SetWriteDeadline(expired) + + // close socket + err := c.socket.Close() + + if err != nil && c.logger != nil { + c.logger.Error("socket close failed", "error", err) + } + } +} + +func (c *Connection) shutdownCleanup() { + c.cleanupOnce.Do(func() { + c.wg.Wait() + + close(c.incoming) + close(c.errors) + + if c.logger != nil { + c.logger.Info("closed") + } + }) +} diff --git a/transport/socket.go b/transport/socket.go index ffa3a13..4c2fe33 100644 --- a/transport/socket.go +++ b/transport/socket.go @@ -69,6 +69,7 @@ func AcquireSocket( logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1) } + // dial socket, resp, err := dialer.DialContext(ctx, url, header) if err == nil { if logger != nil { @@ -77,7 +78,9 @@ func AcquireSocket( return socket, resp, nil } + // dial failed, retry if !retryMgr.ShouldRetry() { + // retry policy expired if logger != nil { logger.Error("dial failed, max retries reached", "error", err, @@ -95,6 +98,7 @@ func AcquireSocket( "next_delay", delay) } + // context cancellable backoff select { case <-time.After(delay): case <-ctx.Done(): diff --git a/worker.go b/worker.go index 31c4b5a..01ab7af 100644 --- a/worker.go +++ b/worker.go @@ -9,10 +9,22 @@ import ( "git.wisehodl.dev/jay/go-honeybee/transport" "git.wisehodl.dev/jay/go-honeybee/types" - component "git.wisehodl.dev/jay/go-mana-component" + "git.wisehodl.dev/jay/go-mana-component" ) +// ---------------------------------------------------------------------------- // Worker +// ---------------------------------------------------------------------------- + +// ---------------------------/ +// Types +// -------------------------/ + +type WorkerFactory func( + ctx context.Context, + id string, + handler slog.Handler, +) (Worker, error) type Worker interface { Start(pool PoolPlugin) @@ -37,19 +49,23 @@ type DefaultWorker struct { id string conn atomic.Pointer[transport.Connection] - heartbeat chan struct{} + sendHeartbeat chan struct{} + + ctx context.Context + cancel context.CancelFunc + config *WorkerConfig + handler slog.Handler + logger *slog.Logger processedCount *atomic.Uint64 outgoingCount *atomic.Uint64 restartCount *atomic.Uint64 - - config *WorkerConfig - ctx context.Context - cancel context.CancelFunc - handler slog.Handler - logger *slog.Logger } +// ---------------------------/ +// Constructor +// -------------------------/ + func NewWorker( ctx context.Context, id string, @@ -77,21 +93,28 @@ func NewWorker( wctx, wcancel := context.WithCancel(ctx) w := &DefaultWorker{ - id: id, - config: config, - heartbeat: make(chan struct{}), + id: id, + + sendHeartbeat: make(chan struct{}), + + ctx: wctx, + cancel: wcancel, + config: config, + handler: handler, + logger: logger, + processedCount: &atomic.Uint64{}, outgoingCount: &atomic.Uint64{}, restartCount: &atomic.Uint64{}, - ctx: wctx, - cancel: wcancel, - handler: handler, - logger: logger, } return w, nil } +// ---------------------------/ +// Session +// -------------------------/ + func (w *DefaultWorker) Start(pool PoolPlugin) { if w.logger != nil { w.logger.Debug("starting") @@ -114,71 +137,19 @@ func (w *DefaultWorker) Start(pool PoolPlugin) { } func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) { - newConn := make(chan *transport.Connection, 1) - - var timer *time.Timer - if w.config.KeepaliveTimeout > 0 { - if w.logger != nil { - w.logger.Debug("keepalive: enabled", "timeout", w.config.KeepaliveTimeout) - } - timer = time.NewTimer(w.config.KeepaliveTimeout) - defer timer.Stop() - } else { - if w.logger != nil { - w.logger.Debug("keepalive: disabled") - } - } - - resetTimer := func() { - if timer == nil { - return - } - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - timer.Reset(w.config.KeepaliveTimeout) - } - - timerC := func() <-chan time.Time { - if timer == nil { - return nil - } - return timer.C - } - + // setup dialer var dialCancel context.CancelFunc + newConn := make(chan *transport.Connection, 1) + spawnDialer := func() { dialCancel = w.spawnDialer(ctx, dialCancel, newConn, pool) } - spawnDial := func() { - if dialCancel != nil { - dialCancel() - } - var dialCtx context.Context - dialCtx, dialCancel = context.WithCancel(ctx) - if w.logger != nil { - w.logger.Debug("session: requesting connection") - } - go func() { - conn, err := connect(w.id, dialCtx, pool, w.handler) - if err != nil { - if w.logger != nil { - w.logger.Warn("dialer: dial failed") - } - return - } - select { - case newConn <- conn: - case <-dialCtx.Done(): - conn.Close() - } - }() - } + // setup heartbeat + timer, timerC, heartbeat := w.setupHeartbeat() + defer timer.Stop() + // main loop for { // spawn initial dial for this reconnect cycle - spawnDial() + spawnDialer() // obtain new connection var conn *transport.Connection @@ -190,23 +161,26 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) { dialCancel() } return - case <-w.heartbeat: - resetTimer() - case <-timerC(): - if w.logger != nil { - w.logger.Info("keepalive: no activity observed") - } - timer.Reset(w.config.KeepaliveTimeout) - spawnDial() + case conn = <-newConn: if w.logger != nil { w.logger.Debug("session: connected") } break preConn + + case <-w.sendHeartbeat: + heartbeat() + + case <-timerC(): + if w.logger != nil { + w.logger.Info("keepalive: no activity observed") + } + timer.Reset(w.config.KeepaliveTimeout) + spawnDialer() } } - // set up new connection + // setup new connection w.conn.Store(conn) pool.Events <- PoolEvent{ID: w.id, Kind: EventConnected, At: time.Now()} @@ -220,14 +194,7 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) { select { case <-ctx.Done(): break conn_loop - case <-w.heartbeat: - resetTimer() - case <-timerC(): - if w.logger != nil { - w.logger.Info("keepalive: no activity observed") - } - timer.Reset(w.config.KeepaliveTimeout) - break conn_loop + case data, ok := <-conn.Incoming(): if !ok { if w.logger != nil { @@ -235,20 +202,34 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) { } break conn_loop } + pool.Inbox <- types.InboxMessage{ - ID: w.id, - Data: data, - ReceivedAt: time.Now(), - } - resetTimer() + ID: w.id, Data: data, ReceivedAt: time.Now()} + + pool.InboxCounter.Add(1) + w.processedCount.Add(1) + + heartbeat() + case <-conn.Heartbeat(): if w.logger != nil { w.logger.Debug("ping-pong heartbeat") } - resetTimer() + heartbeat() + + case <-w.sendHeartbeat: + heartbeat() + + case <-timerC(): + if w.logger != nil { + w.logger.Info("keepalive: no activity observed") + } + timer.Reset(w.config.KeepaliveTimeout) + break conn_loop } } + // session ended conn.Close() if w.logger != nil { @@ -272,6 +253,98 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) { } } +func (w *DefaultWorker) setupHeartbeat() ( + timer *time.Timer, timerC func() <-chan time.Time, heartbeat func(), +) { + if w.config.KeepaliveTimeout > 0 { + if w.logger != nil { + w.logger.Debug("keepalive: enabled", "timeout", w.config.KeepaliveTimeout) + } + timer = time.NewTimer(w.config.KeepaliveTimeout) + } else { + if w.logger != nil { + w.logger.Debug("keepalive: disabled") + } + } + + heartbeat = func() { + if timer == nil { + return + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(w.config.KeepaliveTimeout) + } + + timerC = func() <-chan time.Time { + if timer == nil { + return nil + } + return timer.C + } + + return +} + +func (w *DefaultWorker) spawnDialer( + ctx context.Context, + dialCancel context.CancelFunc, + newConn chan<- *transport.Connection, + pool PoolPlugin, +) context.CancelFunc { + if dialCancel != nil { + dialCancel() + } + + dialCtx, dialCancel := context.WithCancel(ctx) + + if w.logger != nil { + w.logger.Debug("session: requesting connection") + } + + go func() { + conn, err := connect(w.id, dialCtx, pool, w.handler) + + if err != nil { + if w.logger != nil { + w.logger.Warn("dialer: dial failed", "error", err) + } + return + } + + select { + case newConn <- conn: + case <-dialCtx.Done(): + conn.Close() + } + }() + + return dialCancel +} + +func connect( + id string, + ctx context.Context, + pool PoolPlugin, + handler slog.Handler, +) (*transport.Connection, error) { + conn, err := transport.NewConnection(ctx, id, pool.ConnectionConfig, handler) + if err != nil { + return nil, err + } + + conn.SetDialer(pool.Dialer) + return conn, conn.Connect(ctx) +} + +// ---------------------------/ +// Methods +// -------------------------/ + func (w *DefaultWorker) Stop() { if w.logger != nil { w.logger.Debug("shutting down") @@ -291,7 +364,7 @@ func (w *DefaultWorker) Send(data []byte) error { } select { - case w.heartbeat <- struct{}{}: + case w.sendHeartbeat <- struct{}{}: case <-w.ctx.Done(): } @@ -324,18 +397,3 @@ func (w *DefaultWorker) Stats() WorkerStats { TotalSent: w.outgoingCount.Load(), } } - -func connect( - id string, - ctx context.Context, - pool PoolPlugin, - handler slog.Handler, -) (*transport.Connection, error) { - conn, err := transport.NewConnection(ctx, id, pool.ConnectionConfig, handler) - if err != nil { - return nil, err - } - - conn.SetDialer(pool.Dialer) - return conn, conn.Connect(ctx) -} diff --git a/worker_send_test.go b/worker_send_test.go deleted file mode 100644 index 91dada4..0000000 --- a/worker_send_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package honeybee - -import ( - "context" - "fmt" - "git.wisehodl.dev/jay/go-honeybee/honeybeetest" - "github.com/stretchr/testify/assert" - "sync/atomic" - "testing" -) - -func TestWorkerSend(t *testing.T) { - t.Run("data sent to mock socket", func(t *testing.T) { - conn, _, _, outgoingData := setupTestConnection(t) - defer conn.Close() - - ctx, cancel := context.WithCancel(context.Background()) - - heartbeat := make(chan struct{}) - heartbeatCount := atomic.Int32{} - - w := &DefaultWorker{ - ctx: ctx, - cancel: cancel, - id: "wss://test", - heartbeat: heartbeat, - outgoingCount: &atomic.Uint64{}, - } - w.conn.Store(conn) - defer w.cancel() - - go func() { - for range heartbeat { - heartbeatCount.Add(1) - } - }() - - testData := []byte("hello") - err := w.Send(testData) - assert.NoError(t, err) - - // at least one heartbeat was sent - honeybeetest.Eventually(t, func() bool { - return heartbeatCount.Load() >= 1 - }, "expected heartbeats") - - // message was sent by the socket - honeybeetest.Eventually(t, func() bool { - select { - case msg := <-outgoingData: - return string(msg.Data) == "hello" - default: - return false - } - }, "expected message") - }) - - t.Run("sends one heartbeat per successful send", func(t *testing.T) { - conn, _, _, _ := setupTestConnection(t) - defer conn.Close() - - ctx, cancel := context.WithCancel(context.Background()) - - heartbeat := make(chan struct{}) - heartbeatCount := atomic.Int32{} - - w := &DefaultWorker{ - ctx: ctx, - cancel: cancel, - id: "wss://test", - heartbeat: heartbeat, - outgoingCount: &atomic.Uint64{}, - } - w.conn.Store(conn) - defer w.cancel() - - go func() { - for range heartbeat { - heartbeatCount.Add(1) - } - }() - - const count = 3 - for i := range count { - err := w.Send(fmt.Appendf(nil, "msg-%d", i)) - assert.NoError(t, err) - } - - honeybeetest.Eventually(t, func() bool { - return heartbeatCount.Load() == count - }, "expected heartbeats") - }) - - t.Run("returns error if connection is unavailable", func(t *testing.T) { - // no connection available to worker - - ctx, cancel := context.WithCancel(context.Background()) - - heartbeat := make(chan struct{}) - - w := &DefaultWorker{ - ctx: ctx, - cancel: cancel, - id: "wss://test", - heartbeat: heartbeat, - } - defer w.cancel() - - go func() { - for range heartbeat { - } - }() - - err := w.Send([]byte("hello")) - assert.ErrorIs(t, err, ErrConnectionUnavailable) - }) -} diff --git a/worker_test.go b/worker_test.go index 64e47b7..8ba2e55 100644 --- a/worker_test.go +++ b/worker_test.go @@ -3,6 +3,7 @@ package honeybee import ( "context" "errors" + "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/transport" "git.wisehodl.dev/jay/go-honeybee/types" @@ -41,7 +42,7 @@ func makeWorker(t *testing.T, ctx context.Context, cancel context.CancelFunc) *D cancel: cancel, id: "wss://test", config: config, - heartbeat: make(chan struct{}), + sendHeartbeat: make(chan struct{}), processedCount: &atomic.Uint64{}, outgoingCount: &atomic.Uint64{}, restartCount: &atomic.Uint64{}, @@ -134,7 +135,7 @@ func TestWorkerSession(t *testing.T) { cancel: cancel, id: "wss://test", config: config, - heartbeat: make(chan struct{}), + sendHeartbeat: make(chan struct{}), processedCount: &atomic.Uint64{}, outgoingCount: &atomic.Uint64{}, restartCount: &atomic.Uint64{}, @@ -303,7 +304,7 @@ func TestWorkerSession(t *testing.T) { cancel: cancel, id: "wss://test", config: config, - heartbeat: make(chan struct{}), + sendHeartbeat: make(chan struct{}), processedCount: &atomic.Uint64{}, outgoingCount: &atomic.Uint64{}, restartCount: &atomic.Uint64{}, @@ -365,7 +366,7 @@ func TestWorkerSession(t *testing.T) { cancel: cancel, id: "wss://test", config: config, - heartbeat: make(chan struct{}), + sendHeartbeat: make(chan struct{}), processedCount: &atomic.Uint64{}, outgoingCount: &atomic.Uint64{}, restartCount: &atomic.Uint64{}, @@ -431,7 +432,7 @@ func TestWorkerSession(t *testing.T) { cancel: cancel, id: "wss://test", config: config, - heartbeat: make(chan struct{}), + sendHeartbeat: make(chan struct{}), processedCount: &atomic.Uint64{}, outgoingCount: &atomic.Uint64{}, restartCount: &atomic.Uint64{}, @@ -638,3 +639,110 @@ func TestWorkerSession(t *testing.T) { }, "expected wg to drain after parent cancel") }) } + +func TestWorkerSend(t *testing.T) { + t.Run("data sent to mock socket", func(t *testing.T) { + conn, _, _, outgoingData := setupTestConnection(t) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + heartbeat := make(chan struct{}) + heartbeatCount := atomic.Int32{} + + w := &DefaultWorker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + sendHeartbeat: heartbeat, + outgoingCount: &atomic.Uint64{}, + } + w.conn.Store(conn) + defer w.cancel() + + go func() { + for range heartbeat { + heartbeatCount.Add(1) + } + }() + + testData := []byte("hello") + err := w.Send(testData) + assert.NoError(t, err) + + // at least one heartbeat was sent + honeybeetest.Eventually(t, func() bool { + return heartbeatCount.Load() >= 1 + }, "expected heartbeats") + + // message was sent by the socket + honeybeetest.Eventually(t, func() bool { + select { + case msg := <-outgoingData: + return string(msg.Data) == "hello" + default: + return false + } + }, "expected message") + }) + + t.Run("sends one heartbeat per successful send", func(t *testing.T) { + conn, _, _, _ := setupTestConnection(t) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + heartbeat := make(chan struct{}) + heartbeatCount := atomic.Int32{} + + w := &DefaultWorker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + sendHeartbeat: heartbeat, + outgoingCount: &atomic.Uint64{}, + } + w.conn.Store(conn) + defer w.cancel() + + go func() { + for range heartbeat { + heartbeatCount.Add(1) + } + }() + + const count = 3 + for i := range count { + err := w.Send(fmt.Appendf(nil, "msg-%d", i)) + assert.NoError(t, err) + } + + honeybeetest.Eventually(t, func() bool { + return heartbeatCount.Load() == count + }, "expected heartbeats") + }) + + t.Run("returns error if connection is unavailable", func(t *testing.T) { + // no connection available to worker + + ctx, cancel := context.WithCancel(context.Background()) + + heartbeat := make(chan struct{}) + + w := &DefaultWorker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + sendHeartbeat: heartbeat, + } + defer w.cancel() + + go func() { + for range heartbeat { + } + }() + + err := w.Send([]byte("hello")) + assert.ErrorIs(t, err, ErrConnectionUnavailable) + }) +}