From cda6d286ab062accd3e07885127b7cad28700d88 Mon Sep 17 00:00:00 2001 From: Jay Date: Wed, 20 May 2026 14:01:01 -0400 Subject: [PATCH] refactor(worker): collapse session goroutines into single runSession loop Replace the five-goroutine session model (RunDialer, RunKeepalive, RunReader, RunHeartbeatForwarder, RunStopMonitor, Session) with a single DefaultWorker.runSession method containing two select loops: one pre-connection and one connected. Ephemeral dial goroutines replace RunDialer; the keepalive timer and heartbeat reset are inlined. No exported building-block symbols remain. Consolidate worker_dialer_test.go, worker_session_test.go, and worker_start_test.go into worker_test.go. Add seven new behavioral tests covering dial failure, keepalive-driven dial replacement, pre-connection stop, message delivery with timestamp, sustained activity and pong resetting the keepalive timer, keepalive-triggered reconnect, and nil connection pointer after disconnect. Update EXTEND.md and README.md to remove references to the deleted building blocks and document the single worker replacement pattern --- AGENTS.md | 4 + EXTEND.md | 22 +- README.md | 4 +- worker.go | 512 +++++++++------------------- worker_dialer_test.go | 220 ------------ worker_keepalive_test.go | 99 ------ worker_session_inner_test.go | 229 ------------- worker_session_test.go | 441 ------------------------ worker_start_test.go | 299 ---------------- worker_test.go | 640 +++++++++++++++++++++++++++++++++++ 10 files changed, 811 insertions(+), 1659 deletions(-) create mode 100644 AGENTS.md delete mode 100644 worker_dialer_test.go delete mode 100644 worker_keepalive_test.go delete mode 100644 worker_session_inner_test.go delete mode 100644 worker_session_test.go delete mode 100644 worker_start_test.go create mode 100644 worker_test.go diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..04b80f6 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,4 @@ +# go-honeybee + +## Build +- Run `go fmt` on every edited file before staging. diff --git a/EXTEND.md b/EXTEND.md index f6cc64f..804e655 100644 --- a/EXTEND.md +++ b/EXTEND.md @@ -65,27 +65,11 @@ The pool calls the factory under its write lock when `Connect` is called. The fa The factory is set via `honeybee.WithWorkerFactory` on the pool config. -### Building Blocks +### Replacing the Worker -**`RunDialer(id, ctx, pool, dial, newConn, logger)`** Listens on `dial` for connection requests. On each signal, calls `connect` to dial a new `*transport.Connection`. While a dial is in progress, drains additional `dial` signals so that at most one dial runs at a time. On failure, logs the error and waits for the next `dial` signal. On success, sends the connection on `newConn`. Exits when `ctx` is cancelled. +Satisfy the `Worker` interface and register your implementation via `honeybee.WithWorkerFactory`. Your worker is responsible for the full connection lifecycle: dialing and redialing, managing connection state, forwarding received messages to `pool.Inbox`, emitting `EventConnected` and `EventDisconnected` to `pool.Events`, and incrementing `pool.InboxCounter` for each message forwarded. -**`RunKeepalive(ctx, heartbeat, keepalive, timeout, logger)`** Monitors `heartbeat`. Resets a timer on each signal. When the timer fires, sends a signal on `keepalive` to notify the session that the connection should be replaced. When `timeout` is zero, keepalive is disabled: it drains `heartbeat` without acting and exits when `ctx` is cancelled. - -**`RunReader(id, ctx, onStop, conn, inbox, heartbeat, logger)`** Reads from `conn.Incoming()` until the channel closes or `ctx` is cancelled. Builds an `InboxMessage` inline with the peer ID and writes it directly to `inbox`. Sends a signal on `heartbeat` for each message. On exit, calls `conn.Close()` and then `onStop`. - -**`RunHeartbeatForwarder(ctx, conn, heartbeat, logger)`** Reads from `conn.Heartbeat()` and forwards each signal to `heartbeat`. Propagates pong replies into the worker's heartbeat channel so pongs reset the keepalive timer alongside data messages and sends. - -**`RunStopMonitor(ctx, onStop, conn, keepalive, logger)`** Waits for either `ctx.Done` or a signal on `keepalive`. On either, calls `conn.Close()` and then `onStop`. This is how a keepalive expiry propagates into a session tear-down. - -**`Session`** The coordination struct that ties the above blocks together for one connection lifecycle. `Session.Start` runs a loop: request a dial, wait for a connection, run `RunReader`, `RunHeartbeatForwarder`, and `RunStopMonitor` concurrently, wait for them to finish, emit `EventDisconnected`, sleep for `ReconnectDelay`, then repeat. `Session` is exported so it can be embedded or used directly in a custom worker. - -### Replacement Patterns - -**Swap one block.** The most common case is replacing `RunReader` to intercept or annotate inbound messages, or replacing `RunKeepalive` with a different activity metric. Reuse `Session` for the connection lifecycle and substitute the one goroutine you need to change. - -**Replace the session loop.** Construct your own loop using `RunDialer`, `RunKeepalive`, and the session-level blocks. This gives you control over reconnection logic, back-off behavior, or multi-connection strategies while keeping the lower-level I/O blocks intact. - -**Implement from scratch.** Satisfy the `Worker` interface directly. You are responsible for dialing, managing connection state, forwarding messages to `pool.Inbox`, emitting `honeybee.EventConnected` and `honeybee.EventDisconnected` to `pool.Events`, and incrementing `pool.InboxCounter`. +`DefaultWorker`'s source is the authoritative reference for how those responsibilities are met. ## Factory Constraints diff --git a/README.md b/README.md index 75f1ff4..b7b76a5 100644 --- a/README.md +++ b/README.md @@ -266,9 +266,9 @@ connStats := conn.Stats() // conn is a *transport.Connection ## Extending Pools -The pool owns peer registration, event plumbing, and lifecycle. The worker owns what happens on the wire. The default worker can be replaced entirely or composed from the exported `Run*` building blocks that Honeybee provides. +The pool owns peer registration, event plumbing, and lifecycle. The worker owns what happens on the wire. The default worker can be replaced entirely via `WorkerFactory`. -See EXTEND.md for the worker interface contract, the `PoolPlugin` fields, and the available building blocks for the pool worker. +See EXTEND.md for the worker interface contract, the `PoolPlugin` fields, and extension patterns. ## Configuration diff --git a/worker.go b/worker.go index bffc816..31c4b5a 100644 --- a/worker.go +++ b/worker.go @@ -97,39 +97,10 @@ func (w *DefaultWorker) Start(pool PoolPlugin) { w.logger.Debug("starting") } - dial := make(chan struct{}, 1) - newConn := make(chan *transport.Connection, 1) - keepalive := make(chan struct{}, 1) - var wg sync.WaitGroup - wg.Add(3) - - go func() { - defer wg.Done() - RunDialer(w.id, w.ctx, pool, dial, newConn, w.handler, w.logger) - }() - - go func() { - defer wg.Done() - RunKeepalive(w.ctx, w.heartbeat, keepalive, w.config.KeepaliveTimeout, w.logger) - }() - - go func() { - defer wg.Done() - session := &Session{ - id: w.id, - connPtr: &w.conn, - poolInbox: pool.Inbox, - heartbeat: w.heartbeat, - dial: dial, - keepalive: keepalive, - newConn: newConn, - reconnectDelay: w.config.ReconnectDelay, - restartCount: w.restartCount, - logger: w.logger, - } - session.Start(w.ctx, pool) - }() + wg.Go(func() { + w.runSession(w.ctx, pool) + }) if w.logger != nil { w.logger.Info("started") @@ -142,6 +113,165 @@ func (w *DefaultWorker) Start(pool PoolPlugin) { } } +func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) { + newConn := make(chan *transport.Connection, 1) + + var timer *time.Timer + if w.config.KeepaliveTimeout > 0 { + if w.logger != nil { + w.logger.Debug("keepalive: enabled", "timeout", w.config.KeepaliveTimeout) + } + timer = time.NewTimer(w.config.KeepaliveTimeout) + defer timer.Stop() + } else { + if w.logger != nil { + w.logger.Debug("keepalive: disabled") + } + } + + resetTimer := func() { + if timer == nil { + return + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(w.config.KeepaliveTimeout) + } + + timerC := func() <-chan time.Time { + if timer == nil { + return nil + } + return timer.C + } + + var dialCancel context.CancelFunc + + spawnDial := func() { + if dialCancel != nil { + dialCancel() + } + var dialCtx context.Context + dialCtx, dialCancel = context.WithCancel(ctx) + if w.logger != nil { + w.logger.Debug("session: requesting connection") + } + go func() { + conn, err := connect(w.id, dialCtx, pool, w.handler) + if err != nil { + if w.logger != nil { + w.logger.Warn("dialer: dial failed") + } + return + } + select { + case newConn <- conn: + case <-dialCtx.Done(): + conn.Close() + } + }() + } + + for { + // spawn initial dial for this reconnect cycle + spawnDial() + + // obtain new connection + var conn *transport.Connection + preConn: + for { + select { + case <-ctx.Done(): + if dialCancel != nil { + dialCancel() + } + return + case <-w.heartbeat: + resetTimer() + case <-timerC(): + if w.logger != nil { + w.logger.Info("keepalive: no activity observed") + } + timer.Reset(w.config.KeepaliveTimeout) + spawnDial() + case conn = <-newConn: + if w.logger != nil { + w.logger.Debug("session: connected") + } + break preConn + } + } + + // set up new connection + w.conn.Store(conn) + pool.Events <- PoolEvent{ID: w.id, Kind: EventConnected, At: time.Now()} + + if w.logger != nil { + w.logger.Info("session: started") + } + + // run session loop + conn_loop: + for { + select { + case <-ctx.Done(): + break conn_loop + case <-w.heartbeat: + resetTimer() + case <-timerC(): + if w.logger != nil { + w.logger.Info("keepalive: no activity observed") + } + timer.Reset(w.config.KeepaliveTimeout) + break conn_loop + case data, ok := <-conn.Incoming(): + if !ok { + if w.logger != nil { + w.logger.Debug("reader: disconnected") + } + break conn_loop + } + pool.Inbox <- types.InboxMessage{ + ID: w.id, + Data: data, + ReceivedAt: time.Now(), + } + resetTimer() + case <-conn.Heartbeat(): + if w.logger != nil { + w.logger.Debug("ping-pong heartbeat") + } + resetTimer() + } + } + + conn.Close() + + if w.logger != nil { + w.logger.Info("session: ended") + } + + // tear down connection + w.conn.Store(nil) + pool.Events <- PoolEvent{ID: w.id, Kind: EventDisconnected, At: time.Now()} + + // exit if worker is shutting down + select { + case <-ctx.Done(): + return + default: + } + + // refresh session + time.Sleep(w.config.ReconnectDelay) + w.restartCount.Add(1) + } +} + func (w *DefaultWorker) Stop() { if w.logger != nil { w.logger.Debug("shutting down") @@ -195,269 +325,6 @@ func (w *DefaultWorker) Stats() WorkerStats { } } -type Session struct { - id string - connPtr *atomic.Pointer[transport.Connection] - - poolInbox chan<- types.InboxMessage - heartbeat chan<- struct{} - dial chan<- struct{} - - keepalive <-chan struct{} - newConn <-chan *transport.Connection - - reconnectDelay time.Duration - restartCount *atomic.Uint64 - - logger *slog.Logger -} - -func (s *Session) Start( - ctx context.Context, - pool PoolPlugin, -) { - for { - if s.logger != nil { - s.logger.Debug("session: requesting connection") - } - - // request new connection - select { - case s.dial <- struct{}{}: - default: - } - - // obtain new connection - var conn *transport.Connection - preConn: - for { - select { - case <-ctx.Done(): - return - case <-s.keepalive: - select { - case s.dial <- struct{}{}: - if s.logger != nil { - s.logger.Debug("session: requesting connection") - } - - default: - } - case conn = <-s.newConn: - if s.logger != nil { - s.logger.Debug("session: connected") - } - break preConn - } - } - - // set up new connection - s.connPtr.Store(conn) - pool.Events <- PoolEvent{ID: s.id, Kind: EventConnected, At: time.Now()} - - // set up session context - sctx, scancel := context.WithCancel(ctx) - onStop := func() { scancel() } - - // start session - var wg sync.WaitGroup - wg.Add(3) - go func() { - defer wg.Done() - RunReader(s.id, sctx, onStop, conn, s.poolInbox, s.heartbeat, s.logger) - }() - go func() { - defer wg.Done() - RunHeartbeatForwarder(sctx, conn, s.heartbeat, s.logger) - }() - go func() { - defer wg.Done() - RunStopMonitor(sctx, onStop, conn, s.keepalive, s.logger) - }() - - if s.logger != nil { - s.logger.Info("session: started") - } - - // complete session - wg.Wait() - - if s.logger != nil { - s.logger.Info("session: ended") - } - - // tear down connection - s.connPtr.Store(nil) - pool.Events <- PoolEvent{ID: s.id, Kind: EventDisconnected, At: time.Now()} - - // exit if worker is shutting down - select { - case <-ctx.Done(): - return - default: - } - - // refresh session - time.Sleep(s.reconnectDelay) - s.restartCount.Add(1) - } - -} - -func RunReader( - id string, - ctx context.Context, - onStop func(), - conn *transport.Connection, - poolInbox chan<- types.InboxMessage, - heartbeat chan<- struct{}, - logger *slog.Logger, -) { - defer func() { - if logger != nil { - logger.Debug("reader: stopping") - } - - conn.Close() - onStop() - }() - - for { - select { - case <-ctx.Done(): - return - case data, ok := <-conn.Incoming(): - if !ok { - // connection has closed - if logger != nil { - logger.Debug("reader: disconnected") - } - return - } - - // send message forward - poolInbox <- types.InboxMessage{ - ID: id, - Data: data, - ReceivedAt: time.Now(), - } - - // send heartbeat - select { - case heartbeat <- struct{}{}: - case <-ctx.Done(): - return - } - } - } -} - -func RunHeartbeatForwarder( - ctx context.Context, - conn *transport.Connection, - heartbeat chan<- struct{}, - logger *slog.Logger, -) { - for { - select { - case <-ctx.Done(): - return - case <-conn.Heartbeat(): - select { - case heartbeat <- struct{}{}: - if logger != nil { - logger.Debug("ping-pong heartbeat") - } - case <-ctx.Done(): - return - } - } - } -} - -func RunStopMonitor( - ctx context.Context, - onStop func(), - conn *transport.Connection, - keepalive <-chan struct{}, - logger *slog.Logger, -) { - defer func() { - if logger != nil { - logger.Debug("stop monitor: stopping") - } - - conn.Close() - onStop() - }() - - select { - case <-ctx.Done(): - case <-keepalive: - if logger != nil { - logger.Debug("stop monitor: stopping: keepalive") - } - } -} - -func RunKeepalive( - ctx context.Context, - heartbeat <-chan struct{}, - keepalive chan<- struct{}, - timeout time.Duration, - logger *slog.Logger, -) { - // disable keepalive timeout if not configured - if timeout <= 0 { - if logger != nil { - logger.Debug("keepalive: disabled") - } - // drain heartbeats - // wait for cancel and exit - for { - select { - case <-heartbeat: - case <-ctx.Done(): - return - } - } - } - - if logger != nil { - logger.Debug("keepalive: enabled", "timeout", timeout) - } - - timer := time.NewTimer(timeout) - defer timer.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-heartbeat: - // drain the timer channel and reset - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - timer.Reset(timeout) - // timer completed - case <-timer.C: - // send keepalive signal, then reset the timer - if logger != nil { - logger.Info("keepalive: no activity observed") - } - select { - case keepalive <- struct{}{}: - default: - } - timer.Reset(timeout) - } - } -} - func connect( id string, ctx context.Context, @@ -472,58 +339,3 @@ func connect( conn.SetDialer(pool.Dialer) return conn, conn.Connect(ctx) } - -func RunDialer( - id string, - ctx context.Context, - pool PoolPlugin, - - dial <-chan struct{}, - newConn chan<- *transport.Connection, - - handler slog.Handler, - logger *slog.Logger, -) { - for { - select { - case <-ctx.Done(): - return - case <-dial: - if logger != nil { - logger.Debug("dialer: dialing") - } - // dial a new connection - conn, err := connect(id, ctx, pool, handler) - - // send error if dial failed and continue - if err != nil { - if logger != nil { - logger.Warn("dialer: dial failed") - } - continue - } - - if logger != nil { - logger.Debug("dialer: connected") - } - - // drain any redundant signals that arrived during the dial - for { - select { - case <-dial: - default: - goto drained - } - } - drained: - - // send the new connection or close and exit - select { - case newConn <- conn: - case <-ctx.Done(): - conn.Close() - return - } - } - } -} diff --git a/worker_dialer_test.go b/worker_dialer_test.go deleted file mode 100644 index 428b4fe..0000000 --- a/worker_dialer_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package honeybee - -import ( - "context" - "fmt" - "git.wisehodl.dev/jay/go-honeybee/honeybeetest" - "git.wisehodl.dev/jay/go-honeybee/transport" - "git.wisehodl.dev/jay/go-honeybee/types" - "github.com/stretchr/testify/assert" - "net/http" - "sync" - "sync/atomic" - "testing" - "time" -) - -func TestRunDialer(t *testing.T) { - t.Run("successful dial delivers connection to newConn", func(t *testing.T) { - url := "wss://test" - dial := make(chan struct{}, 1) - newConn := make(chan *transport.Connection, 1) - ctx := t.Context() - - mockSocket := honeybeetest.NewMockSocket() - pool := PoolPlugin{ - Dialer: &honeybeetest.MockDialer{ - DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { - return mockSocket, nil, nil - }, - }, - } - - go RunDialer(url, ctx, pool, dial, newConn, nil, nil) - dial <- struct{}{} - - honeybeetest.Eventually(t, func() bool { - select { - case <-newConn: - return true - default: - return false - } - }, "expected new connection") - }) - - t.Run("concurrent dial signals are drained; only one connection produced.", - func(t *testing.T) { - url := "wss://test" - dial := make(chan struct{}, 1) - newConn := make(chan *transport.Connection, 1) - ctx := t.Context() - - gate := make(chan struct{}) - dialCount := atomic.Int32{} - - mockSocket := honeybeetest.NewMockSocket() - connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry - started := make(chan struct{}) - startOnce := sync.Once{} - pool := PoolPlugin{ - Dialer: &honeybeetest.MockDialer{ - DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { - dialCount.Add(1) - startOnce.Do(func() { close(started) }) - <-gate - return mockSocket, nil, nil - }, - }, - ConnectionConfig: connConfig, - } - - go RunDialer(url, ctx, pool, dial, newConn, nil, nil) - dial <- struct{}{} - - // wait for dial to start blocking on gate - <-started - - // flood dial while dialer is blocked - for range 5 { - select { - case dial <- struct{}{}: - default: - } - } - - close(gate) - - // connection is cleared to connect - honeybeetest.Eventually(t, func() bool { - select { - case <-newConn: - return true - default: - return false - } - }, "expected new connection") - - // number of dials < number of dial requests - honeybeetest.Never(t, func() bool { - return dialCount.Load() >= 5 - }, "expected fewer dials than requests") - - // dial channel still writable - select { - case dial <- struct{}{}: - default: - t.Fatal("dial channel should still accept sends") - } - }) - - t.Run("dial failure emits error, succeeds on next signal", func(t *testing.T) { - url := "wss://test" - dial := make(chan struct{}, 1) - newConn := make(chan *transport.Connection, 1) - ctx := t.Context() - - // use atomic counter to fail first dial and pass second - dialCount := atomic.Int32{} - mockSocket := honeybeetest.NewMockSocket() - connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry - pool := PoolPlugin{ - Dialer: &honeybeetest.MockDialer{ - DialContextFunc: func( - context.Context, string, http.Header, - ) (types.Socket, *http.Response, error) { - if dialCount.Add(1) == 1 { - // fail first - return nil, nil, fmt.Errorf("dial failed") - } - // pass second - return mockSocket, nil, nil - }, - }, - ConnectionConfig: connConfig, - } - - go RunDialer(url, ctx, pool, dial, newConn, nil, nil) - dial <- struct{}{} - dial <- struct{}{} - - honeybeetest.Eventually(t, func() bool { - select { - case <-newConn: - return true - default: - return false - } - }, "expected new connection") - }) - - t.Run("exits on context cancellation", func(t *testing.T) { - url := "wss://test" - dial := make(chan struct{}, 1) - newConn := make(chan *transport.Connection, 1) - ctx, cancel := context.WithCancel(context.Background()) - - pool := PoolPlugin{} - - done := make(chan struct{}) - go func() { - RunDialer(url, ctx, pool, dial, newConn, nil, nil) - close(done) - }() - - cancel() - - honeybeetest.Eventually(t, func() bool { - select { - case <-done: - return true - default: - return false - } - }, "expected done signal") - }) - - t.Run("context cancelled during in-progress dial exits without delivering connection", func(t *testing.T) { - url := "wss://test" - dial := make(chan struct{}, 1) - newConn := make(chan *transport.Connection, 1) - ctx, cancel := context.WithCancel(context.Background()) - - pool := PoolPlugin{ - ConnectionConfig: &transport.ConnectionConfig{Retry: nil}, - Dialer: &honeybeetest.MockDialer{ - DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { - // block until context is cancelled - select { - case <-ctx.Done(): - return nil, nil, ctx.Err() - } - }, - }, - } - - done := make(chan struct{}) - go func() { - RunDialer(url, ctx, pool, dial, newConn, nil, nil) - close(done) - }() - - dial <- struct{}{} - - // wait for dialer to block - time.Sleep(20 * time.Millisecond) - cancel() - - honeybeetest.Eventually(t, func() bool { - select { - case <-done: - return true - default: - return false - } - }, "expected done signal") - - // no connection was sent - assert.Empty(t, newConn) - }) -} diff --git a/worker_keepalive_test.go b/worker_keepalive_test.go deleted file mode 100644 index 225d31d..0000000 --- a/worker_keepalive_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package honeybee - -import ( - "context" - "git.wisehodl.dev/jay/go-honeybee/honeybeetest" - "testing" - "time" -) - -func TestRunKeepalive(t *testing.T) { - t.Run("heartbeat resets timer, no keepalive signal fired", func(t *testing.T) { - heartbeat := make(chan struct{}) - keepalive := make(chan struct{}, 1) - timeout := 200 * time.Millisecond - ctx := t.Context() - - go RunKeepalive(ctx, heartbeat, keepalive, timeout, nil) - - // send heartbeats faster than the timeout - for range 5 { - time.Sleep(20 * time.Millisecond) - heartbeat <- struct{}{} - } - - // because the timer is being reset, keepalive signal should not be sent - honeybeetest.Never(t, func() bool { - select { - case <-keepalive: - return true - default: - return false - } - }, "unexpected keepalive signal") - }) - - t.Run("keepalive timeout fires signal", func(t *testing.T) { - heartbeat := make(chan struct{}, 1) - keepalive := make(chan struct{}, 1) - timeout := 20 * time.Millisecond - ctx := t.Context() - - go RunKeepalive(ctx, heartbeat, keepalive, timeout, nil) - - // send no heartbeats, wait for timeout and keepalive signal - honeybeetest.Eventually(t, func() bool { - select { - case <-keepalive: - return true - default: - return false - } - }, "expected keepalive signal") - }) - - t.Run("exits on context cancellation", func(t *testing.T) { - heartbeat := make(chan struct{}, 1) - keepalive := make(chan struct{}, 1) - timeout := 20 * time.Second - ctx, cancel := context.WithCancel(context.Background()) - - done := make(chan struct{}) - go func() { - RunKeepalive(ctx, heartbeat, keepalive, timeout, nil) - close(done) - }() - - cancel() - honeybeetest.Eventually(t, func() bool { - select { - case <-done: - return true - default: - return false - } - }, "expected done signal") - }) - - t.Run("disabled keepalive drains heartbeats without blocking", func(t *testing.T) { - heartbeat := make(chan struct{}) - keepalive := make(chan struct{}, 1) - ctx := t.Context() - - go RunKeepalive(ctx, heartbeat, keepalive, 0, nil) - - // these must not block - for range 5 { - heartbeat <- struct{}{} - } - - honeybeetest.Never(t, func() bool { - select { - case <-keepalive: - return true - default: - return false - } - }, "keepalive signal should not fire when disabled") - }) -} diff --git a/worker_session_inner_test.go b/worker_session_inner_test.go deleted file mode 100644 index 5a71b11..0000000 --- a/worker_session_inner_test.go +++ /dev/null @@ -1,229 +0,0 @@ -package honeybee - -import ( - "context" - "fmt" - "git.wisehodl.dev/jay/go-honeybee/honeybeetest" - "git.wisehodl.dev/jay/go-honeybee/transport" - "git.wisehodl.dev/jay/go-honeybee/types" - "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" - "io" - "sync/atomic" - "testing" - "time" -) - -func TestRunReader(t *testing.T) { - t.Run("message arrives with correct data and non-zero receivedAt", func(t *testing.T) { - conn, _, incomingData, _ := setupTestConnection(t) - defer conn.Close() - - inbox := make(chan types.InboxMessage, 1) - heartbeat := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for range heartbeat { - } - }() - go RunReader("wss://test", ctx, cancel, conn, inbox, heartbeat, nil) - - before := time.Now() - incomingData <- honeybeetest.MockIncomingData{ - MsgType: websocket.TextMessage, - Data: []byte("hello"), - } - - honeybeetest.Eventually(t, func() bool { - select { - case msg := <-inbox: - return string(msg.Data) == "hello" && msg.ReceivedAt.After(before) - default: - return false - } - }, "expected message") - }) - - t.Run("heartbeat receives one signal per message", func(t *testing.T) { - conn, _, incomingData, _ := setupTestConnection(t) - defer conn.Close() - - inbox := make(chan types.InboxMessage, 10) - heartbeat := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - received := atomic.Int32{} - go func() { - for range heartbeat { - received.Add(1) - } - }() - go func() { - for range inbox { - } - }() - go RunReader("wss://test", ctx, cancel, conn, inbox, heartbeat, nil) - - const count = 3 - for i := range count { - incomingData <- honeybeetest.MockIncomingData{ - MsgType: websocket.TextMessage, - Data: fmt.Appendf(nil, "msg-%d", i), - } - } - - honeybeetest.Eventually(t, func() bool { - return received.Load() == count - }, fmt.Sprintf("expected %d messages", count)) - }) - - t.Run("incoming channel close calls conn.Close and onStop", func(t *testing.T) { - conn, _, incomingData, _ := setupTestConnection(t) - - inbox := make(chan types.InboxMessage, 1) - heartbeat := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for range heartbeat { - } - }() - go func() { - for range inbox { - } - }() - go RunReader("wss://test", ctx, cancel, conn, inbox, heartbeat, nil) - - // induce connection closure via reader - incomingData <- honeybeetest.MockIncomingData{Err: io.EOF} - - err := <-conn.Errors() - assert.ErrorIs(t, err, io.EOF) - - honeybeetest.Eventually(t, func() bool { - return conn.State() == transport.StateClosed - }, "expected closed state") - - honeybeetest.Eventually(t, func() bool { - select { - case <-ctx.Done(): - return true - default: - return false - } - }, "expected context to cancel") - }) - - t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { - conn, _, _, _ := setupTestConnection(t) - - inbox := make(chan types.InboxMessage, 1) - heartbeat := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) - - go RunReader("wss://test", ctx, cancel, conn, inbox, heartbeat, nil) - - cancel() - - honeybeetest.Eventually(t, func() bool { - return conn.State() == transport.StateClosed - }, "expected closed state") - - honeybeetest.Eventually(t, func() bool { - select { - case <-ctx.Done(): - return true - default: - return false - } - }, "expected context to cancel") - }) -} - -func TestHeartbeatForwarder(t *testing.T) { - t.Run("connection level heartbeat propagates", func(t *testing.T) { - socket, _, _ := honeybeetest.SetupTestSocket(t) - var pongHandler func(string) error - socket.SetPongHandlerFunc = func(h func(string) error) { pongHandler = h } - - conn, err := transport.NewConnectionFromSocket(context.Background(), socket, nil, nil) - assert.NoError(t, err) - - heartbeat := make(chan struct{}, 1) - ctx := t.Context() - - go RunHeartbeatForwarder(ctx, conn, heartbeat, nil) - - honeybeetest.Eventually(t, func() bool { - return pongHandler != nil - }, "expected Connection to register PongHandler") - - if pongHandler == nil { - t.Fatal("pong handler was never set") - } - - pongHandler("") // Trigger pong - - select { - case <-heartbeat: - case <-time.After(time.Second): - t.Fatal("pong did not propagate to worker heartbeat") - } - }) -} - -func TestRunStopMonitor(t *testing.T) { - t.Run("keepalive signal calls conn.Close and cancel", func(t *testing.T) { - conn, _, _, _ := setupTestConnection(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - keepalive := make(chan struct{}, 1) - - go RunStopMonitor(ctx, cancel, conn, keepalive, nil) - - keepalive <- struct{}{} - - honeybeetest.Eventually(t, func() bool { - return conn.State() == transport.StateClosed - }, "expected closed state") - - honeybeetest.Eventually(t, func() bool { - select { - case <-ctx.Done(): - return true - default: - return false - } - }, "expected context to cancel") - }) - - t.Run("ctx.Done calls conn.Close and cancel", func(t *testing.T) { - conn, _, _, _ := setupTestConnection(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - keepalive := make(chan struct{}) - - go RunStopMonitor(ctx, cancel, conn, keepalive, nil) - - cancel() - - honeybeetest.Eventually(t, func() bool { - return conn.State() == transport.StateClosed - }, "expected closed state") - - honeybeetest.Eventually(t, func() bool { - select { - case <-ctx.Done(): - return true - default: - return false - } - }, "expected context to cancel") - }) -} diff --git a/worker_session_test.go b/worker_session_test.go deleted file mode 100644 index b5c7192..0000000 --- a/worker_session_test.go +++ /dev/null @@ -1,441 +0,0 @@ -package honeybee - -import ( - "context" - "fmt" - "git.wisehodl.dev/jay/go-honeybee/honeybeetest" - "git.wisehodl.dev/jay/go-honeybee/transport" - "git.wisehodl.dev/jay/go-honeybee/types" - "sync/atomic" - "testing" -) - -func drainEvent(t *testing.T, events <-chan PoolEvent, kind PoolEventKind) { - t.Helper() - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.Kind == kind - default: - return false - } - }, fmt.Sprintf("expected %s event", kind)) -} - -type testVars struct { - id string - - dial chan struct{} - keepalive chan struct{} - heartbeat chan struct{} - newConn chan *transport.Connection - poolInbox chan types.InboxMessage - - conn *transport.Connection - mockSocket *honeybeetest.MockSocket - incomingData chan honeybeetest.MockIncomingData - outgoingData chan honeybeetest.MockOutgoingData - - connPtr *atomic.Pointer[transport.Connection] -} - -func setup(t *testing.T) ( - ctx context.Context, - cancel context.CancelFunc, - vars testVars, -) { - t.Helper() - ctx, cancel = context.WithCancel(context.Background()) - conn, mockSocket, incomingData, outgoingData := setupTestConnection(t) - vars = testVars{ - id: "wss://test", - dial: make(chan struct{}, 1), - keepalive: make(chan struct{}, 1), - heartbeat: make(chan struct{}, 1), - newConn: make(chan *transport.Connection, 1), - poolInbox: make(chan types.InboxMessage, 256), - conn: conn, - mockSocket: mockSocket, - incomingData: incomingData, - outgoingData: outgoingData, - connPtr: &atomic.Pointer[transport.Connection]{}, - } - return -} - -func expectDial(t *testing.T, dial <-chan struct{}) { - t.Helper() - honeybeetest.Eventually(t, func() bool { - select { - case <-dial: - return true - default: - return false - } - }, "expected dial signal") -} - -func TestRunSessionDial(t *testing.T) { - t.Run("fires dial immediately on entry", func(t *testing.T) { - ctx, cancel, v := setup(t) - defer cancel() - - pool := PoolPlugin{Events: make(chan PoolEvent, 10)} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - } - - go session.Start(ctx, pool) - - expectDial(t, v.dial) - }) - - t.Run("keepalive fires dial", func(t *testing.T) { - ctx, cancel, v := setup(t) - defer cancel() - - pool := PoolPlugin{Events: make(chan PoolEvent, 10)} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - } - - go session.Start(ctx, pool) - - // drain initial dial - expectDial(t, v.dial) - - v.keepalive <- struct{}{} - expectDial(t, v.dial) - }) - - t.Run("multiple keepalive signals each fire dial", func(t *testing.T) { - ctx, cancel, v := setup(t) - defer cancel() - - pool := PoolPlugin{Events: make(chan PoolEvent, 10)} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - } - - go session.Start(ctx, pool) - - // drain initial dial - expectDial(t, v.dial) - - for range 3 { - v.keepalive <- struct{}{} - expectDial(t, v.dial) - } - }) -} - -func TestRunSessionConnect(t *testing.T) { - t.Run("connection pointer set after newConn received", func(t *testing.T) { - ctx, cancel, v := setup(t) - defer cancel() - - pool := PoolPlugin{Events: make(chan PoolEvent, 10)} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - } - - go session.Start(ctx, pool) - - v.newConn <- v.conn - - honeybeetest.Eventually(t, func() bool { - return v.connPtr.Load() != nil - }, "expected connection pointer to be set") - }) - - t.Run("EventConnected emitted", func(t *testing.T) { - ctx, cancel, v := setup(t) - defer cancel() - - events := make(chan PoolEvent, 10) - pool := PoolPlugin{Events: events} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - } - - go session.Start(ctx, pool) - - v.newConn <- v.conn - - honeybeetest.Eventually(t, func() bool { - select { - case event := <-events: - return event.ID == v.id && event.Kind == EventConnected - default: - return false - } - }, "expected EventConnected") - }) -} - -func TestRunSessionDisconnect(t *testing.T) { - t.Run("EventDisconnected emitted on connection close", func(t *testing.T) { - ctx, cancel, v := setup(t) - defer cancel() - - events := make(chan PoolEvent, 10) - pool := PoolPlugin{Events: events} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - restartCount: &atomic.Uint64{}, - } - - go session.Start(ctx, pool) - - v.newConn <- v.conn - drainEvent(t, events, EventConnected) - - close(v.incomingData) - - drainEvent(t, events, EventDisconnected) - }) - - t.Run("connection pointer cleared after disconnect", func(t *testing.T) { - ctx, cancel, v := setup(t) - defer cancel() - - events := make(chan PoolEvent, 10) - pool := PoolPlugin{Events: events} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - restartCount: &atomic.Uint64{}, - } - - go session.Start(ctx, pool) - - v.newConn <- v.conn - drainEvent(t, events, EventConnected) - - close(v.incomingData) - drainEvent(t, events, EventDisconnected) - - honeybeetest.Eventually(t, func() bool { - return v.connPtr.Load() == nil - }, "expected connection pointer to be nil") - }) - - t.Run("dial fires again after disconnect", func(t *testing.T) { - ctx, cancel, v := setup(t) - defer cancel() - - events := make(chan PoolEvent, 10) - pool := PoolPlugin{Events: events} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - restartCount: &atomic.Uint64{}, - } - - go session.Start(ctx, pool) - - v.newConn <- v.conn - drainEvent(t, events, EventConnected) - - // drain the initial dial signal before disconnecting - <-v.dial - - close(v.incomingData) - drainEvent(t, events, EventDisconnected) - - honeybeetest.Eventually(t, func() bool { - select { - case <-v.dial: - return true - default: - return false - } - }, "expected dial signal after disconnect") - }) - - t.Run("second connection cycle emits EventConnected", func(t *testing.T) { - ctx, cancel, v := setup(t) - defer cancel() - - events := make(chan PoolEvent, 10) - pool := PoolPlugin{Events: events} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - restartCount: &atomic.Uint64{}, - } - - go session.Start(ctx, pool) - - v.newConn <- v.conn - drainEvent(t, events, EventConnected) - - close(v.incomingData) - drainEvent(t, events, EventDisconnected) - - conn2, _, _, _ := setupTestConnection(t) - v.newConn <- conn2 - drainEvent(t, events, EventConnected) - }) -} - -func TestRunSessionCancellation(t *testing.T) { - t.Run("ctx cancelled pre-connection exits without emitting events", func(t *testing.T) { - ctx, cancel, v := setup(t) - events := make(chan PoolEvent, 10) - pool := PoolPlugin{Events: events} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - } - - done := make(chan struct{}) - go func() { - defer close(done) - session.Start(ctx, pool) - }() - - cancel() - - honeybeetest.Eventually(t, func() bool { - select { - case <-done: - return true - default: - return false - } - }, "expected runSession to exit") - - honeybeetest.Never(t, func() bool { - select { - case <-events: - return true - default: - return false - } - }, "expected no events emitted") - }) - - t.Run("ctx cancelled post-connection emits EventDisconnected", func(t *testing.T) { - ctx, cancel, v := setup(t) - events := make(chan PoolEvent, 10) - pool := PoolPlugin{Events: events} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - } - - done := make(chan struct{}) - go func() { - defer close(done) - session.Start(ctx, pool) - }() - - v.newConn <- v.conn - drainEvent(t, events, EventConnected) - - cancel() - drainEvent(t, events, EventDisconnected) - - honeybeetest.Eventually(t, func() bool { - select { - case <-done: - return true - default: - return false - } - }, "expected runSession to exit") - }) - - t.Run("ctx cancelled post-connection clears connection pointer", func(t *testing.T) { - ctx, cancel, v := setup(t) - events := make(chan PoolEvent, 10) - pool := PoolPlugin{Events: events} - session := &Session{ - id: v.id, - connPtr: v.connPtr, - poolInbox: v.poolInbox, - heartbeat: v.heartbeat, - dial: v.dial, - keepalive: v.keepalive, - newConn: v.newConn, - } - - done := make(chan struct{}) - go func() { - defer close(done) - session.Start(ctx, pool) - }() - - v.newConn <- v.conn - drainEvent(t, events, EventConnected) - - cancel() - drainEvent(t, events, EventDisconnected) - - honeybeetest.Eventually(t, func() bool { - return v.connPtr.Load() == nil - }, "expected connection pointer to be nil") - }) -} diff --git a/worker_start_test.go b/worker_start_test.go deleted file mode 100644 index e598f57..0000000 --- a/worker_start_test.go +++ /dev/null @@ -1,299 +0,0 @@ -package honeybee - -import ( - "context" - "git.wisehodl.dev/jay/go-honeybee/honeybeetest" - "git.wisehodl.dev/jay/go-honeybee/types" - "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" - "net/http" - "sync" - "sync/atomic" - "testing" - "time" -) - -func makeWorkerContext(t *testing.T) ( - inbox chan types.InboxMessage, - events chan PoolEvent, - pool PoolPlugin, -) { - t.Helper() - inbox = make(chan types.InboxMessage, 256) - events = make(chan PoolEvent, 10) - pool = PoolPlugin{ - Inbox: inbox, - Events: events, - InboxCounter: &atomic.Uint64{}, - } - return -} - -func makeWorker(t *testing.T, ctx context.Context, cancel context.CancelFunc) *DefaultWorker { - t.Helper() - config, _ := NewWorkerConfig( - WithReconnectDelay(0 * time.Second), - ) - return &DefaultWorker{ - ctx: ctx, - cancel: cancel, - id: "wss://test", - config: config, - heartbeat: make(chan struct{}), - processedCount: &atomic.Uint64{}, - outgoingCount: &atomic.Uint64{}, - restartCount: &atomic.Uint64{}, - } -} - -func mockDialer(socket *honeybeetest.MockSocket) *honeybeetest.MockDialer { - return &honeybeetest.MockDialer{ - DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { - return socket, nil, nil - }, - } -} - -func TestWorkerStart(t *testing.T) { - t.Run("EventConnected emitted after dial succeeds", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - w := makeWorker(t, ctx, cancel) - _, events, pool := makeWorkerContext(t) - mockSocket := honeybeetest.NewMockSocket() - pool.Dialer = mockDialer(mockSocket) - - var wg sync.WaitGroup - wg.Go(func() { - w.Start(pool) - }) - - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.ID == w.id && e.Kind == EventConnected - default: - return false - } - }, "expected EventConnected") - }) - - t.Run("Send delivers data to socket", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - w := makeWorker(t, ctx, cancel) - _, events, pool := makeWorkerContext(t) - _, mockSocket, _, outgoingData := setupTestConnection(t) - pool.Dialer = mockDialer(mockSocket) - - var wg sync.WaitGroup - wg.Go(func() { - w.Start(pool) - }) - - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.Kind == EventConnected - default: - return false - } - }, "expected EventConnected") - - err := w.Send([]byte("hello")) - assert.NoError(t, err) - - honeybeetest.Eventually(t, func() bool { - select { - case msg := <-outgoingData: - return string(msg.Data) == "hello" - default: - return false - } - }, "expected data on socket") - }) - - t.Run("socket data arrives on Inbox", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - w := makeWorker(t, ctx, cancel) - inbox, events, pool := makeWorkerContext(t) - - incomingData := make(chan honeybeetest.MockIncomingData, 10) - mockSocket := honeybeetest.NewMockSocket() - - mockSocket.CloseFunc = func() error { - mockSocket.Once.Do(func() { close(mockSocket.Closed) }) - return nil - } - - mockSocket.ReadMessageFunc = func() (int, []byte, error) { - select { - case data := <-incomingData: - return data.MsgType, data.Data, data.Err - } - } - - pool.Dialer = mockDialer(mockSocket) - - var wg sync.WaitGroup - wg.Go(func() { - w.Start(pool) - }) - - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.Kind == EventConnected - default: - return false - } - }, "expected EventConnected") - - incomingData <- honeybeetest.MockIncomingData{ - MsgType: websocket.TextMessage, - Data: []byte("hello"), - } - - honeybeetest.Eventually(t, func() bool { - select { - case msg := <-inbox: - return msg.ID == w.id && string(msg.Data) == "hello" - default: - return false - } - }, "expected message on Inbox") - }) - - t.Run("socket close produces EventDisconnected then EventConnected", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - w := makeWorker(t, ctx, cancel) - _, events, pool := makeWorkerContext(t) - _, mockSocket, incomingData, _ := setupTestConnection(t) - pool.Dialer = mockDialer(mockSocket) - - var wg sync.WaitGroup - wg.Go(func() { - w.Start(pool) - }) - - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.Kind == EventConnected - default: - return false - } - }, "expected EventConnected") - - close(incomingData) - - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.Kind == EventDisconnected - default: - return false - } - }, "expected EventDisconnected") - - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.Kind == EventConnected - default: - return false - } - }, "expected second EventConnected") - }) - - t.Run("Stop produces EventDisconnected and wg drains", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - w := makeWorker(t, ctx, cancel) - _, events, pool := makeWorkerContext(t) - mockSocket := honeybeetest.NewMockSocket() - pool.Dialer = mockDialer(mockSocket) - - var wg sync.WaitGroup - wg.Go(func() { - w.Start(pool) - }) - - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.Kind == EventConnected - default: - return false - } - }, "expected EventConnected") - - w.Stop() - - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.Kind == EventDisconnected - default: - return false - } - }, "expected EventDisconnected") - - done := make(chan struct{}) - go func() { wg.Wait(); close(done) }() - honeybeetest.Eventually(t, func() bool { - select { - case <-done: - return true - default: - return false - } - }, "expected wg to drain") - }) - - t.Run("parent context cancel exits cleanly and wg drains", func(t *testing.T) { - parentCtx, parentCancel := context.WithCancel(context.Background()) - workerCtx, workerCancel := context.WithCancel(parentCtx) - - w := makeWorker(t, workerCtx, workerCancel) - _, events, pool := makeWorkerContext(t) - mockSocket := honeybeetest.NewMockSocket() - pool.Dialer = mockDialer(mockSocket) - - var wg sync.WaitGroup - wg.Go(func() { - w.Start(pool) - }) - - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.Kind == EventConnected - default: - return false - } - }, "expected EventConnected") - - // drain events after parent cancel — we don't assert what they are, - // only that the worker exits - parentCancel() - - done := make(chan struct{}) - go func() { wg.Wait(); close(done) }() - honeybeetest.Eventually(t, func() bool { - select { - case <-done: - return true - default: - return false - } - }, "expected wg to drain after parent cancel") - }) -} diff --git a/worker_test.go b/worker_test.go new file mode 100644 index 0000000..64e47b7 --- /dev/null +++ b/worker_test.go @@ -0,0 +1,640 @@ +package honeybee + +import ( + "context" + "errors" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "git.wisehodl.dev/jay/go-honeybee/transport" + "git.wisehodl.dev/jay/go-honeybee/types" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "net/http" + "sync" + "sync/atomic" + "testing" + "time" +) + +func makeWorkerContext(t *testing.T) ( + inbox chan types.InboxMessage, + events chan PoolEvent, + pool PoolPlugin, +) { + t.Helper() + inbox = make(chan types.InboxMessage, 256) + events = make(chan PoolEvent, 10) + pool = PoolPlugin{ + Inbox: inbox, + Events: events, + InboxCounter: &atomic.Uint64{}, + } + return +} + +func makeWorker(t *testing.T, ctx context.Context, cancel context.CancelFunc) *DefaultWorker { + t.Helper() + config, _ := NewWorkerConfig( + WithReconnectDelay(0 * time.Second), + ) + return &DefaultWorker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + config: config, + heartbeat: make(chan struct{}), + processedCount: &atomic.Uint64{}, + outgoingCount: &atomic.Uint64{}, + restartCount: &atomic.Uint64{}, + } +} + +func mockDialer(socket *honeybeetest.MockSocket) *honeybeetest.MockDialer { + return &honeybeetest.MockDialer{ + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { + return socket, nil, nil + }, + } +} + +func TestWorkerSession(t *testing.T) { + t.Run("EventConnected emitted after dial succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, pool := makeWorkerContext(t) + mockSocket := honeybeetest.NewMockSocket() + pool.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Go(func() { + w.Start(pool) + }) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.ID == w.id && e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + }) + + t.Run("dial failure exhausted - session stays alive, no events emitted", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, pool := makeWorkerContext(t) + + pool.Dialer = &honeybeetest.MockDialer{ + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { + return nil, nil, errors.New("connection refused") + }, + } + cc, _ := transport.NewConnectionConfig(transport.WithoutRetry()) + pool.ConnectionConfig = cc + + var wg sync.WaitGroup + wg.Go(func() { w.Start(pool) }) + + honeybeetest.Never(t, func() bool { + select { + case <-events: + return true + default: + return false + } + }, "expected no events when dial fails") + + // worker goroutine is still running + assert.False(t, func() bool { + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + return true + case <-time.After(50 * time.Millisecond): + return false + } + }(), "expected worker to still be running after dial failure") + }) + + t.Run("keepalive fires before connection - dial is cancelled and replaced", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + config, _ := NewWorkerConfig( + WithReconnectDelay(0), + WithKeepaliveTimeout(20*time.Millisecond), + ) + w := &DefaultWorker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + config: config, + heartbeat: make(chan struct{}), + processedCount: &atomic.Uint64{}, + outgoingCount: &atomic.Uint64{}, + restartCount: &atomic.Uint64{}, + } + _, _, pool := makeWorkerContext(t) + + var dialCount atomic.Uint64 + pool.Dialer = &honeybeetest.MockDialer{ + DialContextFunc: func(dialCtx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { + dialCount.Add(1) + <-dialCtx.Done() + return nil, nil, dialCtx.Err() + }, + } + cc, _ := transport.NewConnectionConfig(transport.WithoutRetry()) + pool.ConnectionConfig = cc + + var wg sync.WaitGroup + wg.Go(func() { w.Start(pool) }) + + // keepalive fires after 20ms; a second dial goroutine must be spawned + honeybeetest.Eventually(t, func() bool { + return dialCount.Load() >= 2 + }, "expected at least two dial attempts after keepalive fired") + }) + + t.Run("Stop before connection established - exits cleanly, no events", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, pool := makeWorkerContext(t) + + pool.Dialer = &honeybeetest.MockDialer{ + DialContextFunc: func(dialCtx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { + <-dialCtx.Done() + return nil, nil, dialCtx.Err() + }, + } + cc, _ := transport.NewConnectionConfig(transport.WithoutRetry()) + pool.ConnectionConfig = cc + + var wg sync.WaitGroup + wg.Go(func() { w.Start(pool) }) + + w.Stop() + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + honeybeetest.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, "expected Start to return after Stop") + + assert.Empty(t, events, "expected no events when stopped before connection") + }) + + t.Run("Send delivers data to socket", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, pool := makeWorkerContext(t) + _, mockSocket, _, outgoingData := setupTestConnection(t) + pool.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Go(func() { + w.Start(pool) + }) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + err := w.Send([]byte("hello")) + assert.NoError(t, err) + + honeybeetest.Eventually(t, func() bool { + select { + case msg := <-outgoingData: + return string(msg.Data) == "hello" + default: + return false + } + }, "expected data on socket") + }) + + t.Run("socket data arrives on Inbox", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + inbox, events, pool := makeWorkerContext(t) + + incomingData := make(chan honeybeetest.MockIncomingData, 10) + mockSocket := honeybeetest.NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.Once.Do(func() { close(mockSocket.Closed) }) + return nil + } + + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + select { + case data := <-incomingData: + return data.MsgType, data.Data, data.Err + } + } + + pool.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Go(func() { + w.Start(pool) + }) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + incomingData <- honeybeetest.MockIncomingData{ + MsgType: websocket.TextMessage, + Data: []byte("hello"), + } + + var received types.InboxMessage + honeybeetest.Eventually(t, func() bool { + select { + case msg := <-inbox: + received = msg + return true + default: + return false + } + }, "expected message on Inbox") + assert.Equal(t, w.id, received.ID) + assert.Equal(t, []byte("hello"), received.Data) + assert.False(t, received.ReceivedAt.IsZero(), "expected non-zero ReceivedAt") + }) + + t.Run("sustained incoming messages reset keepalive - no disconnect", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + config, _ := NewWorkerConfig( + WithReconnectDelay(0), + WithKeepaliveTimeout(60*time.Millisecond), + ) + w := &DefaultWorker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + config: config, + heartbeat: make(chan struct{}), + processedCount: &atomic.Uint64{}, + outgoingCount: &atomic.Uint64{}, + restartCount: &atomic.Uint64{}, + } + _, events, pool := makeWorkerContext(t) + _, mockSocket, incomingData, _ := setupTestConnection(t) + pool.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Go(func() { w.Start(pool) }) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + // send messages every 20ms for 100ms — well within the 60ms timeout each cycle + go func() { + ticker := time.NewTicker(20 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + select { + case incomingData <- honeybeetest.MockIncomingData{MsgType: websocket.TextMessage, Data: []byte("ping")}: + case <-ctx.Done(): + return + } + case <-ctx.Done(): + return + } + } + }() + + honeybeetest.Never(t, func() bool { + select { + case e := <-events: + return e.Kind == EventDisconnected + default: + return false + } + }, "expected no EventDisconnected while messages are arriving") + }) + + t.Run("pong heartbeat resets keepalive - no disconnect", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + config, _ := NewWorkerConfig( + WithReconnectDelay(0), + WithKeepaliveTimeout(60*time.Millisecond), + ) + w := &DefaultWorker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + config: config, + heartbeat: make(chan struct{}), + processedCount: &atomic.Uint64{}, + outgoingCount: &atomic.Uint64{}, + restartCount: &atomic.Uint64{}, + } + _, events, pool := makeWorkerContext(t) + + // socket whose pong handler fires every 20ms; no incoming messages + var pongHandler func(string) error + mockSocket, incomingData, _ := honeybeetest.SetupTestSocket(t) + mockSocket.SetPongHandlerFunc = func(h func(string) error) { pongHandler = h } + pool.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Go(func() { w.Start(pool) }) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + // fire pong every 20ms — well within the 60ms keepalive window + go func() { + ticker := time.NewTicker(20 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if pongHandler != nil { + _ = pongHandler("") + } + case <-ctx.Done(): + return + } + } + }() + + honeybeetest.Never(t, func() bool { + select { + case e := <-events: + return e.Kind == EventDisconnected + default: + return false + } + }, "expected no EventDisconnected while pongs are arriving") + + _ = incomingData // kept open to prevent reader EOF + }) + + t.Run("keepalive fires while connected - EventDisconnected emitted and redial begins", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + config, _ := NewWorkerConfig( + WithReconnectDelay(0), + WithKeepaliveTimeout(30*time.Millisecond), + ) + w := &DefaultWorker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + config: config, + heartbeat: make(chan struct{}), + processedCount: &atomic.Uint64{}, + outgoingCount: &atomic.Uint64{}, + restartCount: &atomic.Uint64{}, + } + _, events, pool := makeWorkerContext(t) + _, mockSocket, _, _ := setupTestConnection(t) + pool.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Go(func() { w.Start(pool) }) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + // no activity — keepalive fires after 30ms + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventDisconnected + default: + return false + } + }, "expected EventDisconnected after keepalive timeout") + + // session must redial — a second EventConnected follows + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected after redial") + }) + + t.Run("socket close produces EventDisconnected then EventConnected", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, pool := makeWorkerContext(t) + _, mockSocket, incomingData, _ := setupTestConnection(t) + pool.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Go(func() { + w.Start(pool) + }) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + close(incomingData) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventDisconnected + default: + return false + } + }, "expected EventDisconnected") + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected second EventConnected") + }) + + t.Run("connection pointer is nil between disconnect and reconnect", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, pool := makeWorkerContext(t) + _, mockSocket, incomingData, _ := setupTestConnection(t) + pool.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Go(func() { w.Start(pool) }) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + close(incomingData) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventDisconnected + default: + return false + } + }, "expected EventDisconnected") + + // conn.Store(nil) happens before EventDisconnected is sent + assert.Nil(t, w.conn.Load(), "expected connection pointer to be nil after disconnect") + }) + + t.Run("Stop produces EventDisconnected and wg drains", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, pool := makeWorkerContext(t) + mockSocket := honeybeetest.NewMockSocket() + pool.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Go(func() { + w.Start(pool) + }) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + w.Stop() + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventDisconnected + default: + return false + } + }, "expected EventDisconnected") + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + honeybeetest.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, "expected wg to drain") + }) + + t.Run("parent context cancel exits cleanly and wg drains", func(t *testing.T) { + parentCtx, parentCancel := context.WithCancel(context.Background()) + workerCtx, workerCancel := context.WithCancel(parentCtx) + + w := makeWorker(t, workerCtx, workerCancel) + _, events, pool := makeWorkerContext(t) + mockSocket := honeybeetest.NewMockSocket() + pool.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Go(func() { + w.Start(pool) + }) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + // drain events after parent cancel — we don't assert what they are, + // only that the worker exits + parentCancel() + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + honeybeetest.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, "expected wg to drain after parent cancel") + }) +}