From 45b1a31cbbb42b6c9bdbe9143afc074bd42addf6 Mon Sep 17 00:00:00 2001 From: Jay Date: Sun, 19 Apr 2026 09:58:41 -0400 Subject: [PATCH] Wrote session dial tests. Use new helpers. --- honeybeetest/helpers.go | 10 ++ initiatorpool/pool_test.go | 10 +- initiatorpool/worker_test.go | 179 ++++++++++++++++++------- transport/connection_close_test.go | 4 +- transport/connection_goroutine_test.go | 4 +- transport/connection_send_test.go | 19 ++- transport/connection_test.go | 8 +- transport/logging_test.go | 28 ++-- transport/socket_test.go | 4 +- 9 files changed, 175 insertions(+), 91 deletions(-) diff --git a/honeybeetest/helpers.go b/honeybeetest/helpers.go index 4caa402..994aa1c 100644 --- a/honeybeetest/helpers.go +++ b/honeybeetest/helpers.go @@ -62,3 +62,13 @@ func ExpectWrite(t *testing.T, outgoingData chan MockOutgoingData, msgType int, assert.Equal(t, expected, call.Data) } } + +func Eventually(t *testing.T, condition func() bool, msg string) { + t.Helper() + assert.Eventually(t, condition, TestTimeout, TestTick, msg) +} + +func Never(t *testing.T, condition func() bool, msg string) { + t.Helper() + assert.Never(t, condition, NegativeTestTimeout, TestTick, msg) +} diff --git a/initiatorpool/pool_test.go b/initiatorpool/pool_test.go index b887a0b..f6b18da 100644 --- a/initiatorpool/pool_test.go +++ b/initiatorpool/pool_test.go @@ -31,14 +31,14 @@ func _TestPoolConnect(t *testing.T) { err = pool.Connect("wss://test") assert.NoError(t, err) - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case event := <-pool.events: return event.ID == "wss://test" && event.Kind == EventConnected default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected event") _, exists := pool.peers["wss://test"] assert.True(t, exists) @@ -214,14 +214,12 @@ func expectEvent( expectedKind PoolEventKind, ) { t.Helper() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case e := <-events: return e.ID == expectedURL && e.Kind == expectedKind default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick, - fmt.Sprintf("expected event: URL=%q, Kind=%q", - expectedURL, expectedKind)) + }, fmt.Sprintf("expected event: URL=%q, Kind=%q", expectedURL, expectedKind)) } diff --git a/initiatorpool/worker_test.go b/initiatorpool/worker_test.go index bd7f612..f1ddefb 100644 --- a/initiatorpool/worker_test.go +++ b/initiatorpool/worker_test.go @@ -16,7 +16,86 @@ import ( ) func TestRunSessionDial(t *testing.T) { + setup := func(t *testing.T) ( + w *Worker, + ctx context.Context, + cancel context.CancelFunc, + dial chan struct{}, + keepalive chan struct{}, + newConn chan *transport.Connection, + ) { + t.Helper() + ctx, cancel = context.WithCancel(context.Background()) + w = &Worker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + config: GetDefaultWorkerConfig(), + heartbeat: make(chan struct{}), + } + dial = make(chan struct{}, 1) + keepalive = make(chan struct{}, 1) + newConn = make(chan *transport.Connection, 1) + return + } + expectDial := func(t *testing.T, dial <-chan struct{}) { + t.Helper() + honeybeetest.Eventually(t, func() bool { + select { + case <-dial: + return true + default: + return false + } + }, "expected dial signal") + } + + t.Run("fires dial immediately on entry", func(t *testing.T) { + w, ctx, cancel, dial, keepalive, newConn := setup(t) + defer cancel() + + messages := make(chan receivedMessage, 1) + wctx := WorkerContext{Events: make(chan PoolEvent, 10)} + + go w.runSession(ctx, wctx, messages, dial, keepalive, newConn) + + expectDial(t, dial) + }) + + t.Run("keepalive fires dial", func(t *testing.T) { + w, ctx, cancel, dial, keepalive, newConn := setup(t) + defer cancel() + + messages := make(chan receivedMessage, 1) + wctx := WorkerContext{Events: make(chan PoolEvent, 10)} + + go w.runSession(ctx, wctx, messages, dial, keepalive, newConn) + + // drain initial dial + expectDial(t, dial) + + keepalive <- struct{}{} + expectDial(t, dial) + }) + + t.Run("multiple keepalive signals each fire dial", func(t *testing.T) { + w, ctx, cancel, dial, keepalive, newConn := setup(t) + defer cancel() + + messages := make(chan receivedMessage, 1) + wctx := WorkerContext{Events: make(chan PoolEvent, 10)} + + go w.runSession(ctx, wctx, messages, dial, keepalive, newConn) + + // drain initial dial + expectDial(t, dial) + + for i := 0; i < 3; i++ { + keepalive <- struct{}{} + expectDial(t, dial) + } + }) } func TestRunReader(t *testing.T) { @@ -49,14 +128,14 @@ func TestRunReader(t *testing.T) { Data: []byte("hello"), } - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case msg := <-messages: return string(msg.data) == "hello" && msg.receivedAt.After(before) default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected message") }) t.Run("heartbeat receives one signal per message", func(t *testing.T) { @@ -97,9 +176,9 @@ func TestRunReader(t *testing.T) { } } - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return received.Load() == count - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, fmt.Sprintf("expected %d messages", count)) }) t.Run("incoming channel close calls conn.Close and onStop", func(t *testing.T) { @@ -133,13 +212,13 @@ func TestRunReader(t *testing.T) { err := <-conn.Errors() assert.Equal(t, io.EOF, err) - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return conn.State() == transport.StateClosed - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected closed state") - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return onStopCalled.Load() - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected onStop to be called") }) t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { @@ -161,13 +240,13 @@ func TestRunReader(t *testing.T) { close(sessionDone) - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return conn.State() == transport.StateClosed - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected closed state") - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return onStopCalled.Load() - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected onStop to be called") }) } @@ -187,13 +266,13 @@ func TestRunStopMonitor(t *testing.T) { keepalive <- struct{}{} - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return conn.State() == transport.StateClosed - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected closed state") - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return onStopCalled.Load() - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected onStop to be called") }) t.Run("ctx.Done calls conn.Close and onStop", func(t *testing.T) { @@ -210,13 +289,13 @@ func TestRunStopMonitor(t *testing.T) { cancel() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return conn.State() == transport.StateClosed - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected closed state") - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return onStopCalled.Load() - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected onStop to be called") }) t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { @@ -234,13 +313,13 @@ func TestRunStopMonitor(t *testing.T) { close(sessionDone) - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return conn.State() == transport.StateClosed - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected closed state") - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return onStopCalled.Load() - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected onStop to be called") }) } @@ -256,14 +335,14 @@ func TestRunForwarder(t *testing.T) { messages <- receivedMessage{data: []byte("hello"), receivedAt: time.Now()} - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case msg := <-inbox: return string(msg.Data) == "hello" && msg.ID == "wss://test" default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected message") }) t.Run("oldest message dropped when queue is full", func(t *testing.T) { @@ -299,14 +378,14 @@ func TestRunForwarder(t *testing.T) { // receive messages from the inbox var received []string - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case msg := <-inbox: received = append(received, string(msg.Data)) default: } return len(received) == 2 - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected messages") // first message was dropped assert.Equal(t, []string{"second", "third"}, received) @@ -327,14 +406,14 @@ func TestRunForwarder(t *testing.T) { }() cancel() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case <-done: return true default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected done signal") }) } @@ -358,14 +437,14 @@ func TestRunKeepalive(t *testing.T) { } // because the timer is being reset, keepalive signal should not be sent - assert.Never(t, func() bool { + honeybeetest.Never(t, func() bool { select { case <-keepalive: return true default: return false } - }, honeybeetest.NegativeTestTimeout, honeybeetest.TestTick) + }, "unexpected keepalive signal") }) t.Run("keepalive timeout fires signal", func(t *testing.T) { @@ -377,14 +456,14 @@ func TestRunKeepalive(t *testing.T) { go w.runKeepalive(ctx, keepalive) // send no heartbeats, wait for timeout and keepalive signal - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case <-keepalive: return true default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected keepalive signal") }) t.Run("exits on context cancellation", func(t *testing.T) { @@ -399,14 +478,14 @@ func TestRunKeepalive(t *testing.T) { }() cancel() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case <-done: return true default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected done signal") }) } @@ -431,14 +510,14 @@ func TestRunDialer(t *testing.T) { go w.runDialer(ctx, wctx, dial, newConn) dial <- struct{}{} - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case <-newConn: return true default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected new connection") }) t.Run("concurrent dial signals are drained; only one connection produced.", @@ -483,14 +562,14 @@ func TestRunDialer(t *testing.T) { close(gate) // connection is cleared to connect - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case <-newConn: return true default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected new connection") // connection was only dialed once assert.Equal(t, int32(1), dialCount.Load()) @@ -535,25 +614,25 @@ func TestRunDialer(t *testing.T) { go w.runDialer(ctx, wctx, dial, newConn) dial <- struct{}{} - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case err := <-errors: return err != nil default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected error") dial <- struct{}{} - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case <-newConn: return true default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected new connection") }) t.Run("exits on context cancellation", func(t *testing.T) { @@ -572,14 +651,14 @@ func TestRunDialer(t *testing.T) { cancel() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case <-done: return true default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected done signal") }) t.Run("context cancelled during in-progress dial exits without delivering connection", func(t *testing.T) { @@ -614,14 +693,14 @@ func TestRunDialer(t *testing.T) { time.Sleep(20 * time.Millisecond) cancel() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case <-done: return true default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected done signal") // no connection was sent assert.Empty(t, newConn) @@ -661,14 +740,14 @@ func TestWorkerSend(t *testing.T) { assert.Equal(t, 1, int(heartbeatCount.Load())) // message was sent by the socket - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case msg := <-outgoingData: return string(msg.Data) == "hello" default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected message") }) t.Run("sends one heartbeat per successful send", func(t *testing.T) { diff --git a/transport/connection_close_test.go b/transport/connection_close_test.go index 241f75c..e6e5405 100644 --- a/transport/connection_close_test.go +++ b/transport/connection_close_test.go @@ -87,14 +87,14 @@ func TestConnectedConnectionClose(t *testing.T) { incomingData <- honeybeetest.MockIncomingData{ MsgType: websocket.TextMessage, Data: canary} - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case msg := <-conn.Incoming(): return bytes.Equal(msg, canary) default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected canary message") conn.Close() assert.Equal(t, StateClosed, conn.State()) diff --git a/transport/connection_goroutine_test.go b/transport/connection_goroutine_test.go index 8787a93..ad696d5 100644 --- a/transport/connection_goroutine_test.go +++ b/transport/connection_goroutine_test.go @@ -69,8 +69,8 @@ func TestStartReader(t *testing.T) { conn, err := NewConnectionFromSocket(mockSocket, nil, nil) assert.NoError(t, err) - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return conn.State() == StateClosed - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected closed state") }) } diff --git a/transport/connection_send_test.go b/transport/connection_send_test.go index 6a17e88..7cf64f1 100644 --- a/transport/connection_send_test.go +++ b/transport/connection_send_test.go @@ -83,12 +83,11 @@ func TestConnectionSend(t *testing.T) { wg.Wait() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { mu.Lock() defer mu.Unlock() return len(messages) == 50 - }, honeybeetest.TestTimeout, honeybeetest.TestTick, - "should have received 50 messages") + }, "should have received 50 messages") }) @@ -137,15 +136,14 @@ func TestConnectionSend(t *testing.T) { err = conn.Send([]byte("test")) assert.NoError(t, err) - assert.Never(t, func() bool { + honeybeetest.Never(t, func() bool { select { case <-deadlineCalled: return true default: return false } - }, honeybeetest.NegativeTestTimeout, honeybeetest.TestTick, - "SetWriteDeadline should not be called when timeout is zero") + }, "SetWriteDeadline should not be called when timeout is zero") }) t.Run("write timeout sets deadline when positive", func(t *testing.T) { @@ -184,15 +182,14 @@ func TestConnectionSend(t *testing.T) { err = conn.Send([]byte("test")) assert.NoError(t, err) - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case <-deadlineCalled: return true default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick, - "SetWriteDeadline should be called when timeout is positive") + }, "SetWriteDeadline should be called when timeout is positive") }) t.Run("send fails on deadline error", func(t *testing.T) { @@ -218,9 +215,9 @@ func TestConnectionSend(t *testing.T) { err = conn.Send([]byte("test")) assert.ErrorContains(t, err, "failed to set write deadline: test error") - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return conn.State() == StateClosed - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected closed state") }) t.Run("send fails on socket write error", func(t *testing.T) { diff --git a/transport/connection_test.go b/transport/connection_test.go index 25d0ad9..b0fb660 100644 --- a/transport/connection_test.go +++ b/transport/connection_test.go @@ -283,14 +283,14 @@ func TestConnect(t *testing.T) { testData := []byte("test") conn.Send(testData) - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { select { case msg := <-outgoingData: return bytes.Equal(msg.Data, testData) default: return false } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected message") conn.Close() }) @@ -433,9 +433,9 @@ func TestConnectContextCancellation(t *testing.T) { }() // wait for first dial - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return dialCount.Load() >= 1 - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected dial") cancel() select { diff --git a/transport/logging_test.go b/transport/logging_test.go index 9b92830..fecdf1e 100644 --- a/transport/logging_test.go +++ b/transport/logging_test.go @@ -273,10 +273,10 @@ func TestCloseLogging(t *testing.T) { conn.Close() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelInfo, "closed") != nil - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected log") records := mockHandler.GetRecords() @@ -303,10 +303,10 @@ func TestCloseLogging(t *testing.T) { conn.Close() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelError, "socket close failed") != nil - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected log") records := mockHandler.GetRecords() @@ -336,10 +336,10 @@ func TestReaderLogging(t *testing.T) { assert.NoError(t, err) defer conn.Close() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer") != nil - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected log") record := findLogRecord(mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer") assert.NotNil(t, record) @@ -364,10 +364,10 @@ func TestReaderLogging(t *testing.T) { assert.NoError(t, err) defer conn.Close() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelError, "unexpected close") != nil - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected log") record := findLogRecord(mockHandler.GetRecords(), slog.LevelError, "unexpected close") assert.NotNil(t, record) @@ -389,10 +389,10 @@ func TestReaderLogging(t *testing.T) { assert.NoError(t, err) defer conn.Close() - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelError, "read error") != nil - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected log") }) } @@ -415,10 +415,10 @@ func TestWriterLogging(t *testing.T) { err = conn.Send([]byte("test")) assert.ErrorContains(t, err, "failed to set write deadline: deadline error") - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelError, "write deadline error") != nil - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected log") records := mockHandler.GetRecords() @@ -445,10 +445,10 @@ func TestWriterLogging(t *testing.T) { err = conn.Send([]byte("test")) assert.ErrorContains(t, err, "write error") - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelError, "write error") != nil - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected log") records := mockHandler.GetRecords() diff --git a/transport/socket_test.go b/transport/socket_test.go index 1c5dc95..d2130a6 100644 --- a/transport/socket_test.go +++ b/transport/socket_test.go @@ -200,9 +200,9 @@ func TestAcquireSocketContextCancellation(t *testing.T) { }() // wait for first two dials to complete, then cancel during sleep - assert.Eventually(t, func() bool { + honeybeetest.Eventually(t, func() bool { return dialCount.Load() > 1 - }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }, "expected dials") cancel() select {