diff --git a/connection_close_test.go b/connection_close_test.go index 1684627..eddb2c9 100644 --- a/connection_close_test.go +++ b/connection_close_test.go @@ -6,7 +6,6 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "testing" - "time" ) func TestDisconnectedConnectionClose(t *testing.T) { @@ -58,29 +57,14 @@ func TestDisconnectedConnectionClose(t *testing.T) { conn.Close() - // Verify incoming channel closed - select { - case _, ok := <-conn.incoming: - assert.False(t, ok, "incoming channel should be closed") - case <-time.After(50 * time.Millisecond): - t.Fatal("timeout waiting for incoming channel closure") - } - - // Verify outgoing channel closed - select { - case _, ok := <-conn.outgoing: - assert.False(t, ok, "outgoing channel should be closed") - case <-time.After(50 * time.Millisecond): - t.Fatal("timeout waiting for outgoing channel closure") - } - - // Verify errors channel closed - select { - case _, ok := <-conn.errors: - assert.False(t, ok, "errors channel should be closed") - case <-time.After(50 * time.Millisecond): - t.Fatal("timeout waiting for errors channel closure") - } + assert.Eventually(t, func() bool { + select { + case _, ok := <-conn.Errors(): + return !ok + default: + return false + } + }, testTimeout, testTick, "errors channel should close") }) t.Run("send fails after close", func(t *testing.T) { @@ -115,12 +99,10 @@ func TestConnectedConnectionClose(t *testing.T) { conn.Close() assert.Equal(t, StateClosed, conn.State()) - - close(incomingData) }) t.Run("writer active during close exits cleanly", func(t *testing.T) { - conn, _, _, outgoingData := setupTestConnection(t, nil) + conn, _, _, _ := setupTestConnection(t, nil) for i := 0; i < 50; i++ { conn.Send([]byte("message")) @@ -131,22 +113,10 @@ func TestConnectedConnectionClose(t *testing.T) { err := conn.Send([]byte("late")) assert.Error(t, err, "Send should fail after close") assert.ErrorContains(t, err, "connection closed") - - // wait for background closures - assert.Eventually(t, func() bool { - select { - case <-conn.Errors(): - return true - default: - return false - } - }, testTimeout, testTick) - - close(outgoingData) }) t.Run("both goroutines active during close", func(t *testing.T) { - conn, _, incomingData, outgoingData := setupTestConnection(t, nil) + conn, _, incomingData, _ := setupTestConnection(t, nil) for i := 0; i < 10; i++ { incomingData <- mockIncomingData{ @@ -157,18 +127,5 @@ func TestConnectedConnectionClose(t *testing.T) { } conn.Close() - - // wait for background closures - assert.Eventually(t, func() bool { - select { - case <-conn.Errors(): - return true - default: - return false - } - }, testTimeout, testTick) - - close(incomingData) - close(outgoingData) }) } diff --git a/connection_goroutine_test.go b/connection_goroutine_test.go index 38d689a..dfc4194 100644 --- a/connection_goroutine_test.go +++ b/connection_goroutine_test.go @@ -125,14 +125,9 @@ func TestStartReader(t *testing.T) { incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil} - assert.Eventually(t, func() bool { - select { - case <-conn.Incoming(): - return true - default: - return false - } - }, testTimeout, testTick) + select { + case <-conn.Incoming(): + } assert.Eventually(t, func() bool { select { diff --git a/connection_send_test.go b/connection_send_test.go index c475b6b..03b0ba0 100644 --- a/connection_send_test.go +++ b/connection_send_test.go @@ -5,7 +5,6 @@ import ( "github.com/stretchr/testify/assert" "sync" "testing" - "time" ) func TestConnectionSend(t *testing.T) { @@ -63,12 +62,11 @@ func TestConnectionSend(t *testing.T) { assert.NoError(t, err) - // Verify data appeared on outgoing channel select { case sent := <-conn.outgoing: assert.Equal(t, tc.data, sent) - case <-time.After(50 * time.Millisecond): - t.Fatal("timeout: data not sent to outgoing channel") + default: + t.Fatal("data not sent to outgoing channel") } }) } diff --git a/connection_test.go b/connection_test.go index ac12b4e..6ee68bc 100644 --- a/connection_test.go +++ b/connection_test.go @@ -290,7 +290,6 @@ func TestConnect(t *testing.T) { }, testTimeout, testTick) conn.Close() - close(outgoingData) }) t.Run("connect retries on dial failure", func(t *testing.T) { diff --git a/mocks_test.go b/mocks_test.go index 5b522cd..77ee5b0 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -128,21 +128,13 @@ func setupTestConnection(t *testing.T, config *Config) ( // Wire WriteMessage to push to outgoingData channel mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { - mockSocket.mu.Lock() - defer mockSocket.mu.Unlock() - select { + case outgoingData <- mockOutgoingData{msgType: msgType, data: data}: + return nil case <-mockSocket.closed: return io.EOF default: - select { - case outgoingData <- mockOutgoingData{msgType: msgType, data: data}: - return nil - case <-mockSocket.closed: - return io.EOF - default: - return fmt.Errorf("mock outgoing chanel unavailable") - } + return fmt.Errorf("mock outgoing chanel unavailable") } }