Refactored connection shutdown logic.
This commit is contained in:
@@ -51,12 +51,3 @@ func setupWorkerTestConnection(t *testing.T) (
|
||||
assert.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
func connClosed(conn *transport.Connection) bool {
|
||||
select {
|
||||
case _, ok := <-conn.Errors():
|
||||
return !ok
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRunSession(t *testing.T) {
|
||||
func TestRunSessionDial(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
@@ -127,14 +127,19 @@ func TestRunReader(t *testing.T) {
|
||||
}()
|
||||
go w.runReader(conn, messages, sessionDone, onStop)
|
||||
|
||||
// simulate remote close
|
||||
// induce connection closure via reader
|
||||
incomingData <- honeybeetest.MockIncomingData{Err: io.EOF}
|
||||
|
||||
err := <-conn.Errors()
|
||||
assert.Equal(t, io.EOF, err)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return connClosed(conn)
|
||||
return conn.State() == transport.StateClosed
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.True(t, onStopCalled.Load())
|
||||
assert.Eventually(t, func() bool {
|
||||
return onStopCalled.Load()
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
|
||||
t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) {
|
||||
@@ -157,10 +162,12 @@ func TestRunReader(t *testing.T) {
|
||||
close(sessionDone)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return connClosed(conn)
|
||||
return conn.State() == transport.StateClosed
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.True(t, onStopCalled.Load())
|
||||
assert.Eventually(t, func() bool {
|
||||
return onStopCalled.Load()
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -181,10 +188,12 @@ func TestRunStopMonitor(t *testing.T) {
|
||||
keepalive <- struct{}{}
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return connClosed(conn)
|
||||
return conn.State() == transport.StateClosed
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.True(t, onStopCalled.Load())
|
||||
assert.Eventually(t, func() bool {
|
||||
return onStopCalled.Load()
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
|
||||
t.Run("ctx.Done calls conn.Close and onStop", func(t *testing.T) {
|
||||
@@ -202,10 +211,12 @@ func TestRunStopMonitor(t *testing.T) {
|
||||
cancel()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return connClosed(conn)
|
||||
return conn.State() == transport.StateClosed
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.True(t, onStopCalled.Load())
|
||||
assert.Eventually(t, func() bool {
|
||||
return onStopCalled.Load()
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
|
||||
t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) {
|
||||
@@ -224,10 +235,12 @@ func TestRunStopMonitor(t *testing.T) {
|
||||
close(sessionDone)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return connClosed(conn)
|
||||
return conn.State() == transport.StateClosed
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.True(t, onStopCalled.Load())
|
||||
assert.Eventually(t, func() bool {
|
||||
return onStopCalled.Load()
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
+141
-86
@@ -50,10 +50,12 @@ type Connection struct {
|
||||
|
||||
state ConnectionState
|
||||
|
||||
wg sync.WaitGroup
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
writeMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
writeMu sync.Mutex
|
||||
doneOnce sync.Once
|
||||
cleanupOnce sync.Once
|
||||
}
|
||||
|
||||
func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
|
||||
@@ -167,51 +169,151 @@ func (c *Connection) Connect(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) Close() {
|
||||
c.shutdownExternal()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownExternal() {
|
||||
err := c.shutdownSetClosed(true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.shutdownInner()
|
||||
c.shutdownCleanup()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownInternal() {
|
||||
err := c.shutdownSetClosed(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.shutdownInner()
|
||||
|
||||
// defer final cleanup to allow this function to return
|
||||
// otherwise, a deadlock occurs where startReader triggers a shutdown and
|
||||
// must wait for itself to exit.
|
||||
go func() {
|
||||
c.shutdownCleanup()
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownInner() {
|
||||
c.shutdownSignalDone()
|
||||
c.shutdownLogStart()
|
||||
c.shutdownCloseSocket()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCleanup() {
|
||||
c.cleanupOnce.Do(func() {
|
||||
c.wg.Wait()
|
||||
c.shutdownCloseChannels()
|
||||
c.shutdownLogComplete()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownSetClosed(wait bool) error {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownSignalDone() {
|
||||
c.doneOnce.Do(func() {
|
||||
close(c.done)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownLogStart() {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("shutting down")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCloseSocket() {
|
||||
if c.socket != nil {
|
||||
// force unblock of any network operations immediately
|
||||
expired := time.Now().Add(-1 * time.Minute)
|
||||
c.socket.SetReadDeadline(expired)
|
||||
c.socket.SetWriteDeadline(expired)
|
||||
|
||||
// close socket
|
||||
err := c.socket.Close()
|
||||
|
||||
if err != nil && c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCloseChannels() {
|
||||
close(c.incoming)
|
||||
close(c.errors)
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownLogComplete() {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("connection closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) startReader() {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
defer c.shutdownInternal()
|
||||
|
||||
for {
|
||||
messageType, data, err := c.socket.ReadMessage()
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
var closeErr *websocket.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
switch closeErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||
c.logger.Info("connection closed by peer",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
default:
|
||||
c.logger.Error("unexpected close",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
c.logger.Error("read error", "error", err)
|
||||
}
|
||||
}
|
||||
select {
|
||||
case c.errors <- err:
|
||||
case <-c.done:
|
||||
}
|
||||
c.shutdown()
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
default:
|
||||
messageType, data, err := c.socket.ReadMessage()
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
var closeErr *websocket.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
switch closeErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||
c.logger.Info("connection closed by peer",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
default:
|
||||
c.logger.Error("unexpected close",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
c.logger.Error("read error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if messageType == websocket.TextMessage ||
|
||||
messageType == websocket.BinaryMessage {
|
||||
select {
|
||||
case c.incoming <- data:
|
||||
case <-c.done:
|
||||
c.shutdown()
|
||||
select {
|
||||
case <-c.done:
|
||||
case c.errors <- err:
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if messageType == websocket.TextMessage ||
|
||||
messageType == websocket.BinaryMessage {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case c.incoming <- data:
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -230,7 +332,7 @@ func (c *Connection) Send(data []byte) error {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("write deadline error", "error", err)
|
||||
}
|
||||
c.shutdown()
|
||||
c.shutdownExternal()
|
||||
return fmt.Errorf("failed to set write deadline: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -253,53 +355,6 @@ func (c *Connection) Errors() <-chan error {
|
||||
return c.errors
|
||||
}
|
||||
|
||||
func (c *Connection) shutdown() {
|
||||
c.mu.Lock()
|
||||
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closing", "state", c.state.String())
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
socket := c.socket
|
||||
close(c.done)
|
||||
c.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
if socket != nil {
|
||||
// force immediate timeout of any blocked network I/O
|
||||
expired := time.Now().Add(-1 * time.Minute)
|
||||
socket.SetReadDeadline(expired)
|
||||
socket.SetWriteDeadline(expired)
|
||||
err := socket.Close()
|
||||
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
} else {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.wg.Wait()
|
||||
close(c.incoming)
|
||||
close(c.errors)
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) Close() {
|
||||
c.shutdown()
|
||||
}
|
||||
|
||||
func (c *Connection) State() ConnectionState {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
@@ -58,15 +58,11 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
|
||||
conn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case _, ok := <-conn.Errors():
|
||||
return !ok
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
|
||||
"errors channel should close")
|
||||
assert.True(t, conn.closed)
|
||||
_, ok := <-conn.incoming
|
||||
assert.False(t, ok)
|
||||
_, ok = <-conn.errors
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("send fails after close", func(t *testing.T) {
|
||||
@@ -103,32 +99,4 @@ func TestConnectedConnectionClose(t *testing.T) {
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
t.Run("writer active during close exits cleanly", func(t *testing.T) {
|
||||
conn, _, _, _ := setupTestConnection(t, nil)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
conn.Send([]byte("message"))
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
|
||||
err := conn.Send([]byte("late"))
|
||||
assert.Error(t, err, "Send should fail after close")
|
||||
assert.ErrorContains(t, err, "connection closed")
|
||||
})
|
||||
|
||||
t.Run("both goroutines active during close", func(t *testing.T) {
|
||||
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.TextMessage,
|
||||
Data: []byte(fmt.Sprintf("in-%d", i)),
|
||||
}
|
||||
conn.Send([]byte(fmt.Sprintf("out-%d", i)))
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -62,23 +62,12 @@ func TestStartReader(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
|
||||
readErr := fmt.Errorf("read failed")
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
return 0, nil, readErr
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case err := <-conn.Errors():
|
||||
return err == readErr
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return conn.State() == StateClosed
|
||||
|
||||
@@ -50,7 +50,6 @@ func TestConnectionSend(t *testing.T) {
|
||||
for {
|
||||
select {
|
||||
case msg := <-outgoingData:
|
||||
fmt.Printf("got message %s\n", string(msg.Data))
|
||||
mu.Lock()
|
||||
messages = append(messages, string(msg.Data))
|
||||
mu.Unlock()
|
||||
@@ -69,7 +68,6 @@ func TestConnectionSend(t *testing.T) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
data := []byte(fmt.Sprintf("msg-%d-%d", id, j))
|
||||
fmt.Printf("sending message %s\n", string(data))
|
||||
for {
|
||||
// send and retry until success
|
||||
err := conn.Send(data)
|
||||
|
||||
@@ -281,7 +281,7 @@ func TestCloseLogging(t *testing.T) {
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
expected := []expectedLog{
|
||||
{slog.LevelInfo, "closing", map[string]any{"state": "connected"}},
|
||||
{slog.LevelInfo, "shutting down", map[string]any{}},
|
||||
{slog.LevelInfo, "closed", map[string]any{}},
|
||||
}
|
||||
|
||||
@@ -311,7 +311,7 @@ func TestCloseLogging(t *testing.T) {
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
expected := []expectedLog{
|
||||
{slog.LevelInfo, "closing", map[string]any{"state": "connected"}},
|
||||
{slog.LevelInfo, "shutting down", map[string]any{}},
|
||||
{slog.LevelError, "socket close failed", map[string]any{"error": closeErr}},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user