Refactor async assertions, remove manual sleeps and timeouts.

This commit is contained in:
Jay
2026-04-15 12:05:08 -04:00
parent fdae43e715
commit b128a021de
5 changed files with 189 additions and 132 deletions

View File

@@ -1,8 +1,8 @@
package honeybee package honeybee
import ( import (
"bytes"
"fmt" "fmt"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
@@ -100,8 +100,18 @@ func TestConnectedConnectionClose(t *testing.T) {
t.Run("blocked on ReadMessage, unblocks on closed", func(t *testing.T) { t.Run("blocked on ReadMessage, unblocks on closed", func(t *testing.T) {
conn, _, incomingData, _ := setupTestConnection(t, nil) conn, _, incomingData, _ := setupTestConnection(t, nil)
// Wait for reader to block // Send a message to ensure reader loop is blocking
time.Sleep(10 * time.Millisecond) canary := []byte("canary")
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: canary}
assert.Eventually(t, func() bool {
select {
case msg := <-conn.Incoming():
return bytes.Equal(msg, canary)
default:
return false
}
}, testTimeout, testTick)
conn.Close() conn.Close()
assert.Equal(t, StateClosed, conn.State()) assert.Equal(t, StateClosed, conn.State())
@@ -123,11 +133,14 @@ func TestConnectedConnectionClose(t *testing.T) {
assert.ErrorContains(t, err, "connection closed") assert.ErrorContains(t, err, "connection closed")
// wait for background closures // wait for background closures
select { assert.Eventually(t, func() bool {
case <-conn.Errors(): select {
case <-time.After(500 * time.Millisecond): case <-conn.Errors():
t.Fatal("timed out waiting for cleanup") return true
} default:
return false
}
}, testTimeout, testTick)
close(outgoingData) close(outgoingData)
}) })
@@ -143,16 +156,17 @@ func TestConnectedConnectionClose(t *testing.T) {
conn.Send([]byte(fmt.Sprintf("out-%d", i))) conn.Send([]byte(fmt.Sprintf("out-%d", i)))
} }
time.Sleep(10 * time.Millisecond)
conn.Close() conn.Close()
// wait for background closures // wait for background closures
select { assert.Eventually(t, func() bool {
case <-conn.Errors(): select {
case <-time.After(500 * time.Millisecond): case <-conn.Errors():
t.Fatal("timed out waiting for cleanup") return true
} default:
return false
}
}, testTimeout, testTick)
close(incomingData) close(incomingData)
close(outgoingData) close(outgoingData)

View File

@@ -1,10 +1,12 @@
package honeybee package honeybee
import ( import (
"bytes"
"fmt" "fmt"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io" "io"
"strings"
"testing" "testing"
"time" "time"
) )
@@ -78,12 +80,15 @@ func TestStartReader(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
defer conn.Close() defer conn.Close()
select { assert.Never(t, func() bool {
case <-deadlineCalled: select {
t.Fatal("SetReadDeadline should not be called when timeout is zero") case <-deadlineCalled:
case <-time.After(100 * time.Millisecond): return true
} default:
return false
}
}, negativeTestTimeout, testTick,
"SetReadDeadline should not be called when timeout is zero")
}) })
t.Run("read timeout sets deadline when positive", func(t *testing.T) { t.Run("read timeout sets deadline when positive", func(t *testing.T) {
@@ -120,17 +125,24 @@ func TestStartReader(t *testing.T) {
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil} incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil}
select { assert.Eventually(t, func() bool {
case <-conn.Incoming(): select {
case <-time.After(100 * time.Millisecond): case <-conn.Incoming():
} return true
default:
return false
}
}, testTimeout, testTick)
select { assert.Eventually(t, func() bool {
case _, ok := <-deadlineCalled: select {
assert.True(t, ok, "SetReadDeadline should be called when timeout is positive") case <-deadlineCalled:
case <-time.After(100 * time.Millisecond): return true
t.Fatal("SetReadDeadline was never called") default:
} return false
}
}, testTimeout, testTick,
"SetWriteDeadline should be called when timeout is positive")
}) })
t.Run("reader exits on deadline error", func(t *testing.T) { t.Run("reader exits on deadline error", func(t *testing.T) {
@@ -153,16 +165,19 @@ func TestStartReader(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
defer conn.Close() defer conn.Close()
select { assert.Eventually(t, func() bool {
case err := <-conn.Errors(): select {
assert.ErrorContains(t, err, "failed to set read deadline") case err := <-conn.Errors():
case <-time.After(100 * time.Millisecond): return err != nil &&
t.Fatal("timeout waiting for deadline error") strings.Contains(err.Error(), "failed to set read deadline")
} default:
return false
time.Sleep(10 * time.Millisecond) }
assert.Equal(t, StateClosed, conn.State()) }, testTimeout, testTick)
assert.Eventually(t, func() bool {
return conn.State() == StateClosed
}, testTimeout, testTick)
}) })
t.Run("reader exits on socket read error", func(t *testing.T) { t.Run("reader exits on socket read error", func(t *testing.T) {
@@ -184,16 +199,18 @@ func TestStartReader(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
defer conn.Close() defer conn.Close()
select { assert.Eventually(t, func() bool {
case err := <-conn.Errors(): select {
assert.Equal(t, readErr, err) case err := <-conn.Errors():
case <-time.After(100 * time.Millisecond): return err == readErr
t.Fatal("timeout waiting for read error") default:
} return false
}
time.Sleep(10 * time.Millisecond) }, testTimeout, testTick)
assert.Equal(t, StateClosed, conn.State())
assert.Eventually(t, func() bool {
return conn.State() == StateClosed
}, testTimeout, testTick)
}) })
} }
@@ -263,13 +280,15 @@ func TestStartWriter(t *testing.T) {
err = conn.Send([]byte("test")) err = conn.Send([]byte("test"))
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(20 * time.Millisecond) assert.Never(t, func() bool {
select {
select { case <-deadlineCalled:
case <-deadlineCalled: return true
t.Fatal("SetWriteDeadline should not be called when timeout is zero") default:
case <-time.After(100 * time.Millisecond): return false
} }
}, negativeTestTimeout, testTick,
"SetWriteDeadline should not be called when timeout is zero")
}) })
t.Run("write timeout sets deadline when positive", func(t *testing.T) { t.Run("write timeout sets deadline when positive", func(t *testing.T) {
@@ -307,14 +326,15 @@ func TestStartWriter(t *testing.T) {
err = conn.Send([]byte("test")) err = conn.Send([]byte("test"))
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(20 * time.Millisecond) assert.Eventually(t, func() bool {
select {
select { case <-deadlineCalled:
case _, ok := <-deadlineCalled: return true
assert.True(t, ok, "SetWriteDeadline should be called when timeout is positive") default:
case <-time.After(100 * time.Millisecond): return false
t.Fatal("SetWriteDeadline was never called") }
} }, testTimeout, testTick,
"SetWriteDeadline should be called when timeout is positive")
}) })
t.Run("writer exits on deadline error", func(t *testing.T) { t.Run("writer exits on deadline error", func(t *testing.T) {
@@ -340,15 +360,19 @@ func TestStartWriter(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
defer conn.Close() defer conn.Close()
select { assert.Eventually(t, func() bool {
case err := <-conn.Errors(): select {
assert.ErrorContains(t, err, "failed to set write deadline") case err := <-conn.Errors():
case <-time.After(100 * time.Millisecond): return err != nil &&
t.Fatal("timeout waiting for deadline error") strings.Contains(err.Error(), "failed to set write deadline")
} default:
return false
}
}, testTimeout, testTick)
time.Sleep(10 * time.Millisecond) assert.Eventually(t, func() bool {
assert.Equal(t, StateClosed, conn.State()) return conn.State() == StateClosed
}, testTimeout, testTick)
}) })
t.Run("writer exits on socket write error", func(t *testing.T) { t.Run("writer exits on socket write error", func(t *testing.T) {
@@ -366,15 +390,18 @@ func TestStartWriter(t *testing.T) {
err = conn.Send([]byte("test")) err = conn.Send([]byte("test"))
assert.NoError(t, err) assert.NoError(t, err)
select { assert.Eventually(t, func() bool {
case err := <-conn.Errors(): select {
assert.Equal(t, writeErr, err) case err := <-conn.Errors():
case <-time.After(100 * time.Millisecond): return err == writeErr
t.Fatal("timeout waiting for write error") default:
} return false
}
}, testTimeout, testTick)
time.Sleep(10 * time.Millisecond) assert.Eventually(t, func() bool {
assert.Equal(t, StateClosed, conn.State()) return conn.State() == StateClosed
}, testTimeout, testTick)
}) })
} }
@@ -382,23 +409,33 @@ func TestStartWriter(t *testing.T) {
func expectIncoming(t *testing.T, conn *Connection, expected []byte) { func expectIncoming(t *testing.T, conn *Connection, expected []byte) {
t.Helper() t.Helper()
assert.Eventually(t, func() bool {
select { select {
case received := <-conn.Incoming(): case received := <-conn.Incoming():
assert.Equal(t, expected, received) return bytes.Equal(received, expected)
case <-time.After(100 * time.Millisecond): default:
t.Fatal("timeout waiting for message") return false
} }
}, testTimeout, testTick)
} }
func expectWrite(t *testing.T, outgoingData chan mockOutgoingData, msgType int, expected []byte) { func expectWrite(t *testing.T, outgoingData chan mockOutgoingData, msgType int, expected []byte) {
t.Helper() t.Helper()
select { var call mockOutgoingData
case call := <-outgoingData: found := assert.Eventually(t, func() bool {
select {
case received := <-outgoingData:
call = received
return true
default:
return false
}
}, testTimeout, testTick)
if found {
assert.Equal(t, msgType, call.msgType) assert.Equal(t, msgType, call.msgType)
assert.Equal(t, expected, call.data) assert.Equal(t, expected, call.data)
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout waiting for write")
} }
} }

View File

@@ -1,6 +1,7 @@
package honeybee package honeybee
import ( import (
"bytes"
"fmt" "fmt"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"net/http" "net/http"
@@ -279,14 +280,14 @@ func TestConnect(t *testing.T) {
testData := []byte("test") testData := []byte("test")
conn.Send(testData) conn.Send(testData)
time.Sleep(10 * time.Millisecond) assert.Eventually(t, func() bool {
select {
select { case msg := <-outgoingData:
case msg := <-outgoingData: return bytes.Equal(msg.data, testData)
assert.Equal(t, testData, msg.data) default:
case <-time.After(100 * time.Millisecond): return false
t.Fatal("timeout waiting for message write") }
} }, testTimeout, testTick)
conn.Close() conn.Close()
close(outgoingData) close(outgoingData)

View File

@@ -57,25 +57,6 @@ func assertLogSequence(t *testing.T, records []slog.Record, expected []expectedL
} }
} }
func assertLogRecord(
t *testing.T,
records []slog.Record,
level slog.Level,
msgSnippet string,
expectedAttrs ...slog.Attr,
) {
t.Helper()
record := findLogRecord(records, level, msgSnippet)
if record == nil {
t.Fatalf("no log record found with level %v and message containing %q", level, msgSnippet)
}
for _, expectedAttr := range expectedAttrs {
assertAttributePresent(t, *record, expectedAttr.Key, expectedAttr.Value.Any())
}
}
func findLogRecord( func findLogRecord(
records []slog.Record, level slog.Level, msgSnippet string, records []slog.Record, level slog.Level, msgSnippet string,
) *slog.Record { ) *slog.Record {
@@ -287,7 +268,11 @@ func TestCloseLogging(t *testing.T) {
conn.Close() conn.Close()
time.Sleep(10 * time.Millisecond) assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelInfo, "closed") != nil
}, testTimeout, testTick)
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
expected := []expectedLog{ expected := []expectedLog{
@@ -313,7 +298,11 @@ func TestCloseLogging(t *testing.T) {
conn.Close() conn.Close()
time.Sleep(10 * time.Millisecond) assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelError, "socket close failed") != nil
}, testTimeout, testTick)
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
expected := []expectedLog{ expected := []expectedLog{
@@ -341,12 +330,13 @@ func TestReaderLogging(t *testing.T) {
conn, err := NewConnectionFromSocket(mockSocket, config, logger) conn, err := NewConnectionFromSocket(mockSocket, config, logger)
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(50 * time.Millisecond) assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelError, "read deadline error") != nil
}, testTimeout, testTick)
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
assertLogRecord(t, records, slog.LevelError, "read deadline error")
record := findLogRecord(records, slog.LevelError, "read deadline error") record := findLogRecord(records, slog.LevelError, "read deadline error")
assert.NotNil(t, record) assert.NotNil(t, record)
assertAttributePresent(t, *record, "error", deadlineErr) assertAttributePresent(t, *record, "error", deadlineErr)
@@ -367,12 +357,13 @@ func TestReaderLogging(t *testing.T) {
conn, err := NewConnectionFromSocket(mockSocket, nil, logger) conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(50 * time.Millisecond) assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelError, "read error") != nil
}, testTimeout, testTick)
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
assertLogRecord(t, records, slog.LevelError, "read error")
record := findLogRecord(records, slog.LevelError, "read error") record := findLogRecord(records, slog.LevelError, "read error")
assert.NotNil(t, record) assert.NotNil(t, record)
assertAttributePresent(t, *record, "error", readErr) assertAttributePresent(t, *record, "error", readErr)
@@ -400,12 +391,13 @@ func TestWriterLogging(t *testing.T) {
err = conn.Send([]byte("test")) err = conn.Send([]byte("test"))
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(50 * time.Millisecond) assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelError, "write deadline error") != nil
}, testTimeout, testTick)
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
assertLogRecord(t, records, slog.LevelError, "write deadline error")
record := findLogRecord(records, slog.LevelError, "write deadline error") record := findLogRecord(records, slog.LevelError, "write deadline error")
assert.NotNil(t, record) assert.NotNil(t, record)
assertAttributePresent(t, *record, "error", deadlineErr) assertAttributePresent(t, *record, "error", deadlineErr)
@@ -429,12 +421,13 @@ func TestWriterLogging(t *testing.T) {
err = conn.Send([]byte("test")) err = conn.Send([]byte("test"))
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(50 * time.Millisecond) assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelError, "write error") != nil
}, testTimeout, testTick)
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
assertLogRecord(t, records, slog.LevelError, "write error")
record := findLogRecord(records, slog.LevelError, "write error") record := findLogRecord(records, slog.LevelError, "write error")
assert.NotNil(t, record) assert.NotNil(t, record)
assertAttributePresent(t, *record, "error", writeErr) assertAttributePresent(t, *record, "error", writeErr)

View File

@@ -12,6 +12,14 @@ import (
"time" "time"
) )
// Test Constants
const (
testTimeout = 2 * time.Second
testTick = 10 * time.Millisecond
negativeTestTimeout = 100 * time.Millisecond
)
// Dialer Mocks // Dialer Mocks
type MockDialer struct { type MockDialer struct {
@@ -33,6 +41,7 @@ type MockSocket struct {
SetCloseHandlerFunc func(func(int, string) error) SetCloseHandlerFunc func(func(int, string) error)
closed chan struct{} closed chan struct{}
once sync.Once once sync.Once
mu sync.Mutex
} }
func NewMockSocket() *MockSocket { func NewMockSocket() *MockSocket {
@@ -119,6 +128,9 @@ func setupTestConnection(t *testing.T, config *Config) (
// Wire WriteMessage to push to outgoingData channel // Wire WriteMessage to push to outgoingData channel
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
mockSocket.mu.Lock()
defer mockSocket.mu.Unlock()
select { select {
case <-mockSocket.closed: case <-mockSocket.closed:
return io.EOF return io.EOF