Simplify tests and avoid unnecessary race conditions.
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
@@ -58,29 +57,14 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
|
||||
conn.Close()
|
||||
|
||||
// Verify incoming channel closed
|
||||
select {
|
||||
case _, ok := <-conn.incoming:
|
||||
assert.False(t, ok, "incoming channel should be closed")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
t.Fatal("timeout waiting for incoming channel closure")
|
||||
}
|
||||
|
||||
// Verify outgoing channel closed
|
||||
select {
|
||||
case _, ok := <-conn.outgoing:
|
||||
assert.False(t, ok, "outgoing channel should be closed")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
t.Fatal("timeout waiting for outgoing channel closure")
|
||||
}
|
||||
|
||||
// Verify errors channel closed
|
||||
select {
|
||||
case _, ok := <-conn.errors:
|
||||
assert.False(t, ok, "errors channel should be closed")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
t.Fatal("timeout waiting for errors channel closure")
|
||||
}
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case _, ok := <-conn.Errors():
|
||||
return !ok
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick, "errors channel should close")
|
||||
})
|
||||
|
||||
t.Run("send fails after close", func(t *testing.T) {
|
||||
@@ -115,12 +99,10 @@ func TestConnectedConnectionClose(t *testing.T) {
|
||||
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
|
||||
close(incomingData)
|
||||
})
|
||||
|
||||
t.Run("writer active during close exits cleanly", func(t *testing.T) {
|
||||
conn, _, _, outgoingData := setupTestConnection(t, nil)
|
||||
conn, _, _, _ := setupTestConnection(t, nil)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
conn.Send([]byte("message"))
|
||||
@@ -131,22 +113,10 @@ func TestConnectedConnectionClose(t *testing.T) {
|
||||
err := conn.Send([]byte("late"))
|
||||
assert.Error(t, err, "Send should fail after close")
|
||||
assert.ErrorContains(t, err, "connection closed")
|
||||
|
||||
// wait for background closures
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-conn.Errors():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
|
||||
close(outgoingData)
|
||||
})
|
||||
|
||||
t.Run("both goroutines active during close", func(t *testing.T) {
|
||||
conn, _, incomingData, outgoingData := setupTestConnection(t, nil)
|
||||
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
incomingData <- mockIncomingData{
|
||||
@@ -157,18 +127,5 @@ func TestConnectedConnectionClose(t *testing.T) {
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
|
||||
// wait for background closures
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-conn.Errors():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
|
||||
close(incomingData)
|
||||
close(outgoingData)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -125,14 +125,9 @@ func TestStartReader(t *testing.T) {
|
||||
|
||||
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil}
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-conn.Incoming():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
select {
|
||||
case <-conn.Incoming():
|
||||
}
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConnectionSend(t *testing.T) {
|
||||
@@ -63,12 +62,11 @@ func TestConnectionSend(t *testing.T) {
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify data appeared on outgoing channel
|
||||
select {
|
||||
case sent := <-conn.outgoing:
|
||||
assert.Equal(t, tc.data, sent)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
t.Fatal("timeout: data not sent to outgoing channel")
|
||||
default:
|
||||
t.Fatal("data not sent to outgoing channel")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -290,7 +290,6 @@ func TestConnect(t *testing.T) {
|
||||
}, testTimeout, testTick)
|
||||
|
||||
conn.Close()
|
||||
close(outgoingData)
|
||||
})
|
||||
|
||||
t.Run("connect retries on dial failure", func(t *testing.T) {
|
||||
|
||||
@@ -128,21 +128,13 @@ func setupTestConnection(t *testing.T, config *Config) (
|
||||
|
||||
// Wire WriteMessage to push to outgoingData channel
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
mockSocket.mu.Lock()
|
||||
defer mockSocket.mu.Unlock()
|
||||
|
||||
select {
|
||||
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||
return nil
|
||||
case <-mockSocket.closed:
|
||||
return io.EOF
|
||||
default:
|
||||
select {
|
||||
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||
return nil
|
||||
case <-mockSocket.closed:
|
||||
return io.EOF
|
||||
default:
|
||||
return fmt.Errorf("mock outgoing chanel unavailable")
|
||||
}
|
||||
return fmt.Errorf("mock outgoing chanel unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user