Simplify tests and avoid unnecessary race conditions.

This commit is contained in:
Jay
2026-04-15 13:47:34 -04:00
parent b128a021de
commit d002c19889
5 changed files with 18 additions and 77 deletions

View File

@@ -6,7 +6,6 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time"
) )
func TestDisconnectedConnectionClose(t *testing.T) { func TestDisconnectedConnectionClose(t *testing.T) {
@@ -58,29 +57,14 @@ func TestDisconnectedConnectionClose(t *testing.T) {
conn.Close() conn.Close()
// Verify incoming channel closed assert.Eventually(t, func() bool {
select { select {
case _, ok := <-conn.incoming: case _, ok := <-conn.Errors():
assert.False(t, ok, "incoming channel should be closed") return !ok
case <-time.After(50 * time.Millisecond): default:
t.Fatal("timeout waiting for incoming channel closure") return false
} }
}, testTimeout, testTick, "errors channel should close")
// Verify outgoing channel closed
select {
case _, ok := <-conn.outgoing:
assert.False(t, ok, "outgoing channel should be closed")
case <-time.After(50 * time.Millisecond):
t.Fatal("timeout waiting for outgoing channel closure")
}
// Verify errors channel closed
select {
case _, ok := <-conn.errors:
assert.False(t, ok, "errors channel should be closed")
case <-time.After(50 * time.Millisecond):
t.Fatal("timeout waiting for errors channel closure")
}
}) })
t.Run("send fails after close", func(t *testing.T) { t.Run("send fails after close", func(t *testing.T) {
@@ -115,12 +99,10 @@ func TestConnectedConnectionClose(t *testing.T) {
conn.Close() conn.Close()
assert.Equal(t, StateClosed, conn.State()) assert.Equal(t, StateClosed, conn.State())
close(incomingData)
}) })
t.Run("writer active during close exits cleanly", func(t *testing.T) { t.Run("writer active during close exits cleanly", func(t *testing.T) {
conn, _, _, outgoingData := setupTestConnection(t, nil) conn, _, _, _ := setupTestConnection(t, nil)
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
conn.Send([]byte("message")) conn.Send([]byte("message"))
@@ -131,22 +113,10 @@ func TestConnectedConnectionClose(t *testing.T) {
err := conn.Send([]byte("late")) err := conn.Send([]byte("late"))
assert.Error(t, err, "Send should fail after close") assert.Error(t, err, "Send should fail after close")
assert.ErrorContains(t, err, "connection closed") assert.ErrorContains(t, err, "connection closed")
// wait for background closures
assert.Eventually(t, func() bool {
select {
case <-conn.Errors():
return true
default:
return false
}
}, testTimeout, testTick)
close(outgoingData)
}) })
t.Run("both goroutines active during close", func(t *testing.T) { t.Run("both goroutines active during close", func(t *testing.T) {
conn, _, incomingData, outgoingData := setupTestConnection(t, nil) conn, _, incomingData, _ := setupTestConnection(t, nil)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
incomingData <- mockIncomingData{ incomingData <- mockIncomingData{
@@ -157,18 +127,5 @@ func TestConnectedConnectionClose(t *testing.T) {
} }
conn.Close() conn.Close()
// wait for background closures
assert.Eventually(t, func() bool {
select {
case <-conn.Errors():
return true
default:
return false
}
}, testTimeout, testTick)
close(incomingData)
close(outgoingData)
}) })
} }

View File

@@ -125,14 +125,9 @@ func TestStartReader(t *testing.T) {
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil} incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil}
assert.Eventually(t, func() bool { select {
select { case <-conn.Incoming():
case <-conn.Incoming(): }
return true
default:
return false
}
}, testTimeout, testTick)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
select { select {

View File

@@ -5,7 +5,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"sync" "sync"
"testing" "testing"
"time"
) )
func TestConnectionSend(t *testing.T) { func TestConnectionSend(t *testing.T) {
@@ -63,12 +62,11 @@ func TestConnectionSend(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// Verify data appeared on outgoing channel
select { select {
case sent := <-conn.outgoing: case sent := <-conn.outgoing:
assert.Equal(t, tc.data, sent) assert.Equal(t, tc.data, sent)
case <-time.After(50 * time.Millisecond): default:
t.Fatal("timeout: data not sent to outgoing channel") t.Fatal("data not sent to outgoing channel")
} }
}) })
} }

View File

@@ -290,7 +290,6 @@ func TestConnect(t *testing.T) {
}, testTimeout, testTick) }, testTimeout, testTick)
conn.Close() conn.Close()
close(outgoingData)
}) })
t.Run("connect retries on dial failure", func(t *testing.T) { t.Run("connect retries on dial failure", func(t *testing.T) {

View File

@@ -128,21 +128,13 @@ func setupTestConnection(t *testing.T, config *Config) (
// Wire WriteMessage to push to outgoingData channel // Wire WriteMessage to push to outgoingData channel
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
mockSocket.mu.Lock()
defer mockSocket.mu.Unlock()
select { select {
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
return nil
case <-mockSocket.closed: case <-mockSocket.closed:
return io.EOF return io.EOF
default: default:
select { return fmt.Errorf("mock outgoing chanel unavailable")
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
return nil
case <-mockSocket.closed:
return io.EOF
default:
return fmt.Errorf("mock outgoing chanel unavailable")
}
} }
} }