Simplify tests and avoid unnecessary race conditions.
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
|||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDisconnectedConnectionClose(t *testing.T) {
|
func TestDisconnectedConnectionClose(t *testing.T) {
|
||||||
@@ -58,29 +57,14 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|
||||||
// Verify incoming channel closed
|
assert.Eventually(t, func() bool {
|
||||||
select {
|
select {
|
||||||
case _, ok := <-conn.incoming:
|
case _, ok := <-conn.Errors():
|
||||||
assert.False(t, ok, "incoming channel should be closed")
|
return !ok
|
||||||
case <-time.After(50 * time.Millisecond):
|
default:
|
||||||
t.Fatal("timeout waiting for incoming channel closure")
|
return false
|
||||||
}
|
}
|
||||||
|
}, testTimeout, testTick, "errors channel should close")
|
||||||
// Verify outgoing channel closed
|
|
||||||
select {
|
|
||||||
case _, ok := <-conn.outgoing:
|
|
||||||
assert.False(t, ok, "outgoing channel should be closed")
|
|
||||||
case <-time.After(50 * time.Millisecond):
|
|
||||||
t.Fatal("timeout waiting for outgoing channel closure")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify errors channel closed
|
|
||||||
select {
|
|
||||||
case _, ok := <-conn.errors:
|
|
||||||
assert.False(t, ok, "errors channel should be closed")
|
|
||||||
case <-time.After(50 * time.Millisecond):
|
|
||||||
t.Fatal("timeout waiting for errors channel closure")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("send fails after close", func(t *testing.T) {
|
t.Run("send fails after close", func(t *testing.T) {
|
||||||
@@ -115,12 +99,10 @@ func TestConnectedConnectionClose(t *testing.T) {
|
|||||||
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
assert.Equal(t, StateClosed, conn.State())
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
|
||||||
close(incomingData)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("writer active during close exits cleanly", func(t *testing.T) {
|
t.Run("writer active during close exits cleanly", func(t *testing.T) {
|
||||||
conn, _, _, outgoingData := setupTestConnection(t, nil)
|
conn, _, _, _ := setupTestConnection(t, nil)
|
||||||
|
|
||||||
for i := 0; i < 50; i++ {
|
for i := 0; i < 50; i++ {
|
||||||
conn.Send([]byte("message"))
|
conn.Send([]byte("message"))
|
||||||
@@ -131,22 +113,10 @@ func TestConnectedConnectionClose(t *testing.T) {
|
|||||||
err := conn.Send([]byte("late"))
|
err := conn.Send([]byte("late"))
|
||||||
assert.Error(t, err, "Send should fail after close")
|
assert.Error(t, err, "Send should fail after close")
|
||||||
assert.ErrorContains(t, err, "connection closed")
|
assert.ErrorContains(t, err, "connection closed")
|
||||||
|
|
||||||
// wait for background closures
|
|
||||||
assert.Eventually(t, func() bool {
|
|
||||||
select {
|
|
||||||
case <-conn.Errors():
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}, testTimeout, testTick)
|
|
||||||
|
|
||||||
close(outgoingData)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("both goroutines active during close", func(t *testing.T) {
|
t.Run("both goroutines active during close", func(t *testing.T) {
|
||||||
conn, _, incomingData, outgoingData := setupTestConnection(t, nil)
|
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
incomingData <- mockIncomingData{
|
incomingData <- mockIncomingData{
|
||||||
@@ -157,18 +127,5 @@ func TestConnectedConnectionClose(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|
||||||
// wait for background closures
|
|
||||||
assert.Eventually(t, func() bool {
|
|
||||||
select {
|
|
||||||
case <-conn.Errors():
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}, testTimeout, testTick)
|
|
||||||
|
|
||||||
close(incomingData)
|
|
||||||
close(outgoingData)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -125,14 +125,9 @@ func TestStartReader(t *testing.T) {
|
|||||||
|
|
||||||
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil}
|
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil}
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
select {
|
||||||
select {
|
case <-conn.Incoming():
|
||||||
case <-conn.Incoming():
|
}
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}, testTimeout, testTick)
|
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConnectionSend(t *testing.T) {
|
func TestConnectionSend(t *testing.T) {
|
||||||
@@ -63,12 +62,11 @@ func TestConnectionSend(t *testing.T) {
|
|||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Verify data appeared on outgoing channel
|
|
||||||
select {
|
select {
|
||||||
case sent := <-conn.outgoing:
|
case sent := <-conn.outgoing:
|
||||||
assert.Equal(t, tc.data, sent)
|
assert.Equal(t, tc.data, sent)
|
||||||
case <-time.After(50 * time.Millisecond):
|
default:
|
||||||
t.Fatal("timeout: data not sent to outgoing channel")
|
t.Fatal("data not sent to outgoing channel")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -290,7 +290,6 @@ func TestConnect(t *testing.T) {
|
|||||||
}, testTimeout, testTick)
|
}, testTimeout, testTick)
|
||||||
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
close(outgoingData)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("connect retries on dial failure", func(t *testing.T) {
|
t.Run("connect retries on dial failure", func(t *testing.T) {
|
||||||
|
|||||||
@@ -128,21 +128,13 @@ func setupTestConnection(t *testing.T, config *Config) (
|
|||||||
|
|
||||||
// Wire WriteMessage to push to outgoingData channel
|
// Wire WriteMessage to push to outgoingData channel
|
||||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||||
mockSocket.mu.Lock()
|
|
||||||
defer mockSocket.mu.Unlock()
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||||
|
return nil
|
||||||
case <-mockSocket.closed:
|
case <-mockSocket.closed:
|
||||||
return io.EOF
|
return io.EOF
|
||||||
default:
|
default:
|
||||||
select {
|
return fmt.Errorf("mock outgoing chanel unavailable")
|
||||||
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
|
||||||
return nil
|
|
||||||
case <-mockSocket.closed:
|
|
||||||
return io.EOF
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("mock outgoing chanel unavailable")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user