From b128a021de8bc4c8926be0222c1212bfac9d3fb7 Mon Sep 17 00:00:00 2001 From: Jay Date: Wed, 15 Apr 2026 12:05:08 -0400 Subject: [PATCH] Refactor async assertions, remove manual sleeps and timeouts. --- connection_close_test.go | 44 +++++--- connection_goroutine_test.go | 189 +++++++++++++++++++++-------------- connection_test.go | 17 ++-- logging_test.go | 59 +++++------ mocks_test.go | 12 +++ 5 files changed, 189 insertions(+), 132 deletions(-) diff --git a/connection_close_test.go b/connection_close_test.go index 8a9fb01..1684627 100644 --- a/connection_close_test.go +++ b/connection_close_test.go @@ -1,8 +1,8 @@ package honeybee import ( + "bytes" "fmt" - "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "testing" @@ -100,8 +100,18 @@ func TestConnectedConnectionClose(t *testing.T) { t.Run("blocked on ReadMessage, unblocks on closed", func(t *testing.T) { conn, _, incomingData, _ := setupTestConnection(t, nil) - // Wait for reader to block - time.Sleep(10 * time.Millisecond) + // Send a message to ensure reader loop is blocking + canary := []byte("canary") + incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: canary} + + assert.Eventually(t, func() bool { + select { + case msg := <-conn.Incoming(): + return bytes.Equal(msg, canary) + default: + return false + } + }, testTimeout, testTick) conn.Close() assert.Equal(t, StateClosed, conn.State()) @@ -123,11 +133,14 @@ func TestConnectedConnectionClose(t *testing.T) { assert.ErrorContains(t, err, "connection closed") // wait for background closures - select { - case <-conn.Errors(): - case <-time.After(500 * time.Millisecond): - t.Fatal("timed out waiting for cleanup") - } + assert.Eventually(t, func() bool { + select { + case <-conn.Errors(): + return true + default: + return false + } + }, testTimeout, testTick) close(outgoingData) }) @@ -143,16 +156,17 @@ func TestConnectedConnectionClose(t *testing.T) { conn.Send([]byte(fmt.Sprintf("out-%d", i))) } - time.Sleep(10 * time.Millisecond) - conn.Close() // wait for background closures - select { - case <-conn.Errors(): - case <-time.After(500 * time.Millisecond): - t.Fatal("timed out waiting for cleanup") - } + assert.Eventually(t, func() bool { + select { + case <-conn.Errors(): + return true + default: + return false + } + }, testTimeout, testTick) close(incomingData) close(outgoingData) diff --git a/connection_goroutine_test.go b/connection_goroutine_test.go index a9e52c0..38d689a 100644 --- a/connection_goroutine_test.go +++ b/connection_goroutine_test.go @@ -1,10 +1,12 @@ package honeybee import ( + "bytes" "fmt" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "io" + "strings" "testing" "time" ) @@ -78,12 +80,15 @@ func TestStartReader(t *testing.T) { assert.NoError(t, err) defer conn.Close() - select { - case <-deadlineCalled: - t.Fatal("SetReadDeadline should not be called when timeout is zero") - case <-time.After(100 * time.Millisecond): - } - + assert.Never(t, func() bool { + select { + case <-deadlineCalled: + return true + default: + return false + } + }, negativeTestTimeout, testTick, + "SetReadDeadline should not be called when timeout is zero") }) t.Run("read timeout sets deadline when positive", func(t *testing.T) { @@ -120,17 +125,24 @@ func TestStartReader(t *testing.T) { incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil} - select { - case <-conn.Incoming(): - case <-time.After(100 * time.Millisecond): - } + assert.Eventually(t, func() bool { + select { + case <-conn.Incoming(): + return true + default: + return false + } + }, testTimeout, testTick) - select { - case _, ok := <-deadlineCalled: - assert.True(t, ok, "SetReadDeadline should be called when timeout is positive") - case <-time.After(100 * time.Millisecond): - t.Fatal("SetReadDeadline was never called") - } + assert.Eventually(t, func() bool { + select { + case <-deadlineCalled: + return true + default: + return false + } + }, testTimeout, testTick, + "SetWriteDeadline should be called when timeout is positive") }) t.Run("reader exits on deadline error", func(t *testing.T) { @@ -153,16 +165,19 @@ func TestStartReader(t *testing.T) { assert.NoError(t, err) defer conn.Close() - select { - case err := <-conn.Errors(): - assert.ErrorContains(t, err, "failed to set read deadline") - case <-time.After(100 * time.Millisecond): - t.Fatal("timeout waiting for deadline error") - } - - time.Sleep(10 * time.Millisecond) - assert.Equal(t, StateClosed, conn.State()) + assert.Eventually(t, func() bool { + select { + case err := <-conn.Errors(): + return err != nil && + strings.Contains(err.Error(), "failed to set read deadline") + default: + return false + } + }, testTimeout, testTick) + assert.Eventually(t, func() bool { + return conn.State() == StateClosed + }, testTimeout, testTick) }) t.Run("reader exits on socket read error", func(t *testing.T) { @@ -184,16 +199,18 @@ func TestStartReader(t *testing.T) { assert.NoError(t, err) defer conn.Close() - select { - case err := <-conn.Errors(): - assert.Equal(t, readErr, err) - case <-time.After(100 * time.Millisecond): - t.Fatal("timeout waiting for read error") - } - - time.Sleep(10 * time.Millisecond) - assert.Equal(t, StateClosed, conn.State()) + assert.Eventually(t, func() bool { + select { + case err := <-conn.Errors(): + return err == readErr + default: + return false + } + }, testTimeout, testTick) + assert.Eventually(t, func() bool { + return conn.State() == StateClosed + }, testTimeout, testTick) }) } @@ -263,13 +280,15 @@ func TestStartWriter(t *testing.T) { err = conn.Send([]byte("test")) assert.NoError(t, err) - time.Sleep(20 * time.Millisecond) - - select { - case <-deadlineCalled: - t.Fatal("SetWriteDeadline should not be called when timeout is zero") - case <-time.After(100 * time.Millisecond): - } + assert.Never(t, func() bool { + select { + case <-deadlineCalled: + return true + default: + return false + } + }, negativeTestTimeout, testTick, + "SetWriteDeadline should not be called when timeout is zero") }) t.Run("write timeout sets deadline when positive", func(t *testing.T) { @@ -307,14 +326,15 @@ func TestStartWriter(t *testing.T) { err = conn.Send([]byte("test")) assert.NoError(t, err) - time.Sleep(20 * time.Millisecond) - - select { - case _, ok := <-deadlineCalled: - assert.True(t, ok, "SetWriteDeadline should be called when timeout is positive") - case <-time.After(100 * time.Millisecond): - t.Fatal("SetWriteDeadline was never called") - } + assert.Eventually(t, func() bool { + select { + case <-deadlineCalled: + return true + default: + return false + } + }, testTimeout, testTick, + "SetWriteDeadline should be called when timeout is positive") }) t.Run("writer exits on deadline error", func(t *testing.T) { @@ -340,15 +360,19 @@ func TestStartWriter(t *testing.T) { assert.NoError(t, err) defer conn.Close() - select { - case err := <-conn.Errors(): - assert.ErrorContains(t, err, "failed to set write deadline") - case <-time.After(100 * time.Millisecond): - t.Fatal("timeout waiting for deadline error") - } + assert.Eventually(t, func() bool { + select { + case err := <-conn.Errors(): + return err != nil && + strings.Contains(err.Error(), "failed to set write deadline") + default: + return false + } + }, testTimeout, testTick) - time.Sleep(10 * time.Millisecond) - assert.Equal(t, StateClosed, conn.State()) + assert.Eventually(t, func() bool { + return conn.State() == StateClosed + }, testTimeout, testTick) }) t.Run("writer exits on socket write error", func(t *testing.T) { @@ -366,15 +390,18 @@ func TestStartWriter(t *testing.T) { err = conn.Send([]byte("test")) assert.NoError(t, err) - select { - case err := <-conn.Errors(): - assert.Equal(t, writeErr, err) - case <-time.After(100 * time.Millisecond): - t.Fatal("timeout waiting for write error") - } + assert.Eventually(t, func() bool { + select { + case err := <-conn.Errors(): + return err == writeErr + default: + return false + } + }, testTimeout, testTick) - time.Sleep(10 * time.Millisecond) - assert.Equal(t, StateClosed, conn.State()) + assert.Eventually(t, func() bool { + return conn.State() == StateClosed + }, testTimeout, testTick) }) } @@ -382,23 +409,33 @@ func TestStartWriter(t *testing.T) { func expectIncoming(t *testing.T, conn *Connection, expected []byte) { t.Helper() - - select { - case received := <-conn.Incoming(): - assert.Equal(t, expected, received) - case <-time.After(100 * time.Millisecond): - t.Fatal("timeout waiting for message") - } + assert.Eventually(t, func() bool { + select { + case received := <-conn.Incoming(): + return bytes.Equal(received, expected) + default: + return false + } + }, testTimeout, testTick) } func expectWrite(t *testing.T, outgoingData chan mockOutgoingData, msgType int, expected []byte) { t.Helper() - select { - case call := <-outgoingData: + var call mockOutgoingData + found := assert.Eventually(t, func() bool { + select { + case received := <-outgoingData: + call = received + return true + default: + return false + } + }, testTimeout, testTick) + + if found { + assert.Equal(t, msgType, call.msgType) assert.Equal(t, expected, call.data) - case <-time.After(100 * time.Millisecond): - t.Fatal("timeout waiting for write") } } diff --git a/connection_test.go b/connection_test.go index b662f01..ac12b4e 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1,6 +1,7 @@ package honeybee import ( + "bytes" "fmt" "github.com/stretchr/testify/assert" "net/http" @@ -279,14 +280,14 @@ func TestConnect(t *testing.T) { testData := []byte("test") conn.Send(testData) - time.Sleep(10 * time.Millisecond) - - select { - case msg := <-outgoingData: - assert.Equal(t, testData, msg.data) - case <-time.After(100 * time.Millisecond): - t.Fatal("timeout waiting for message write") - } + assert.Eventually(t, func() bool { + select { + case msg := <-outgoingData: + return bytes.Equal(msg.data, testData) + default: + return false + } + }, testTimeout, testTick) conn.Close() close(outgoingData) diff --git a/logging_test.go b/logging_test.go index 7225359..68056e6 100644 --- a/logging_test.go +++ b/logging_test.go @@ -57,25 +57,6 @@ func assertLogSequence(t *testing.T, records []slog.Record, expected []expectedL } } -func assertLogRecord( - t *testing.T, - records []slog.Record, - level slog.Level, - msgSnippet string, - expectedAttrs ...slog.Attr, -) { - t.Helper() - - record := findLogRecord(records, level, msgSnippet) - if record == nil { - t.Fatalf("no log record found with level %v and message containing %q", level, msgSnippet) - } - - for _, expectedAttr := range expectedAttrs { - assertAttributePresent(t, *record, expectedAttr.Key, expectedAttr.Value.Any()) - } -} - func findLogRecord( records []slog.Record, level slog.Level, msgSnippet string, ) *slog.Record { @@ -287,7 +268,11 @@ func TestCloseLogging(t *testing.T) { conn.Close() - time.Sleep(10 * time.Millisecond) + assert.Eventually(t, func() bool { + return findLogRecord( + mockHandler.GetRecords(), slog.LevelInfo, "closed") != nil + }, testTimeout, testTick) + records := mockHandler.GetRecords() expected := []expectedLog{ @@ -313,7 +298,11 @@ func TestCloseLogging(t *testing.T) { conn.Close() - time.Sleep(10 * time.Millisecond) + assert.Eventually(t, func() bool { + return findLogRecord( + mockHandler.GetRecords(), slog.LevelError, "socket close failed") != nil + }, testTimeout, testTick) + records := mockHandler.GetRecords() expected := []expectedLog{ @@ -341,12 +330,13 @@ func TestReaderLogging(t *testing.T) { conn, err := NewConnectionFromSocket(mockSocket, config, logger) assert.NoError(t, err) - time.Sleep(50 * time.Millisecond) + assert.Eventually(t, func() bool { + return findLogRecord( + mockHandler.GetRecords(), slog.LevelError, "read deadline error") != nil + }, testTimeout, testTick) records := mockHandler.GetRecords() - assertLogRecord(t, records, slog.LevelError, "read deadline error") - record := findLogRecord(records, slog.LevelError, "read deadline error") assert.NotNil(t, record) assertAttributePresent(t, *record, "error", deadlineErr) @@ -367,12 +357,13 @@ func TestReaderLogging(t *testing.T) { conn, err := NewConnectionFromSocket(mockSocket, nil, logger) assert.NoError(t, err) - time.Sleep(50 * time.Millisecond) + assert.Eventually(t, func() bool { + return findLogRecord( + mockHandler.GetRecords(), slog.LevelError, "read error") != nil + }, testTimeout, testTick) records := mockHandler.GetRecords() - assertLogRecord(t, records, slog.LevelError, "read error") - record := findLogRecord(records, slog.LevelError, "read error") assert.NotNil(t, record) assertAttributePresent(t, *record, "error", readErr) @@ -400,12 +391,13 @@ func TestWriterLogging(t *testing.T) { err = conn.Send([]byte("test")) assert.NoError(t, err) - time.Sleep(50 * time.Millisecond) + assert.Eventually(t, func() bool { + return findLogRecord( + mockHandler.GetRecords(), slog.LevelError, "write deadline error") != nil + }, testTimeout, testTick) records := mockHandler.GetRecords() - assertLogRecord(t, records, slog.LevelError, "write deadline error") - record := findLogRecord(records, slog.LevelError, "write deadline error") assert.NotNil(t, record) assertAttributePresent(t, *record, "error", deadlineErr) @@ -429,12 +421,13 @@ func TestWriterLogging(t *testing.T) { err = conn.Send([]byte("test")) assert.NoError(t, err) - time.Sleep(50 * time.Millisecond) + assert.Eventually(t, func() bool { + return findLogRecord( + mockHandler.GetRecords(), slog.LevelError, "write error") != nil + }, testTimeout, testTick) records := mockHandler.GetRecords() - assertLogRecord(t, records, slog.LevelError, "write error") - record := findLogRecord(records, slog.LevelError, "write error") assert.NotNil(t, record) assertAttributePresent(t, *record, "error", writeErr) diff --git a/mocks_test.go b/mocks_test.go index d8569ac..5b522cd 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -12,6 +12,14 @@ import ( "time" ) +// Test Constants + +const ( + testTimeout = 2 * time.Second + testTick = 10 * time.Millisecond + negativeTestTimeout = 100 * time.Millisecond +) + // Dialer Mocks type MockDialer struct { @@ -33,6 +41,7 @@ type MockSocket struct { SetCloseHandlerFunc func(func(int, string) error) closed chan struct{} once sync.Once + mu sync.Mutex } func NewMockSocket() *MockSocket { @@ -119,6 +128,9 @@ func setupTestConnection(t *testing.T, config *Config) ( // Wire WriteMessage to push to outgoingData channel mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + mockSocket.mu.Lock() + defer mockSocket.mu.Unlock() + select { case <-mockSocket.closed: return io.EOF