diff --git a/initiatorpool/helper_test.go b/initiatorpool/helper_test.go index a275bae..85bebcd 100644 --- a/initiatorpool/helper_test.go +++ b/initiatorpool/helper_test.go @@ -51,12 +51,3 @@ func setupWorkerTestConnection(t *testing.T) ( assert.NoError(t, err) return } - -func connClosed(conn *transport.Connection) bool { - select { - case _, ok := <-conn.Errors(): - return !ok - default: - return false - } -} diff --git a/initiatorpool/worker_test.go b/initiatorpool/worker_test.go index 139aa14..bd7f612 100644 --- a/initiatorpool/worker_test.go +++ b/initiatorpool/worker_test.go @@ -15,7 +15,7 @@ import ( "time" ) -func TestRunSession(t *testing.T) { +func TestRunSessionDial(t *testing.T) { } @@ -127,14 +127,19 @@ func TestRunReader(t *testing.T) { }() go w.runReader(conn, messages, sessionDone, onStop) - // simulate remote close + // induce connection closure via reader incomingData <- honeybeetest.MockIncomingData{Err: io.EOF} + err := <-conn.Errors() + assert.Equal(t, io.EOF, err) + assert.Eventually(t, func() bool { - return connClosed(conn) + return conn.State() == transport.StateClosed }, honeybeetest.TestTimeout, honeybeetest.TestTick) - assert.True(t, onStopCalled.Load()) + assert.Eventually(t, func() bool { + return onStopCalled.Load() + }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { @@ -157,10 +162,12 @@ func TestRunReader(t *testing.T) { close(sessionDone) assert.Eventually(t, func() bool { - return connClosed(conn) + return conn.State() == transport.StateClosed }, honeybeetest.TestTimeout, honeybeetest.TestTick) - assert.True(t, onStopCalled.Load()) + assert.Eventually(t, func() bool { + return onStopCalled.Load() + }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) } @@ -181,10 +188,12 @@ func TestRunStopMonitor(t *testing.T) { keepalive <- struct{}{} assert.Eventually(t, func() bool { - return connClosed(conn) + return conn.State() == transport.StateClosed }, honeybeetest.TestTimeout, honeybeetest.TestTick) - assert.True(t, onStopCalled.Load()) + assert.Eventually(t, func() bool { + return onStopCalled.Load() + }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) t.Run("ctx.Done calls conn.Close and onStop", func(t *testing.T) { @@ -202,10 +211,12 @@ func TestRunStopMonitor(t *testing.T) { cancel() assert.Eventually(t, func() bool { - return connClosed(conn) + return conn.State() == transport.StateClosed }, honeybeetest.TestTimeout, honeybeetest.TestTick) - assert.True(t, onStopCalled.Load()) + assert.Eventually(t, func() bool { + return onStopCalled.Load() + }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { @@ -224,10 +235,12 @@ func TestRunStopMonitor(t *testing.T) { close(sessionDone) assert.Eventually(t, func() bool { - return connClosed(conn) + return conn.State() == transport.StateClosed }, honeybeetest.TestTimeout, honeybeetest.TestTick) - assert.True(t, onStopCalled.Load()) + assert.Eventually(t, func() bool { + return onStopCalled.Load() + }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) } diff --git a/transport/connection.go b/transport/connection.go index 4ede803..c2e0c75 100644 --- a/transport/connection.go +++ b/transport/connection.go @@ -50,10 +50,12 @@ type Connection struct { state ConnectionState - wg sync.WaitGroup - closed bool - mu sync.RWMutex - writeMu sync.Mutex + wg sync.WaitGroup + closed bool + mu sync.RWMutex + writeMu sync.Mutex + doneOnce sync.Once + cleanupOnce sync.Once } func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) { @@ -167,51 +169,151 @@ func (c *Connection) Connect(ctx context.Context) error { return nil } +func (c *Connection) Close() { + c.shutdownExternal() +} + +func (c *Connection) shutdownExternal() { + err := c.shutdownSetClosed(true) + if err != nil { + return + } + + c.shutdownInner() + c.shutdownCleanup() +} + +func (c *Connection) shutdownInternal() { + err := c.shutdownSetClosed(false) + if err != nil { + return + } + + c.shutdownInner() + + // defer final cleanup to allow this function to return + // otherwise, a deadlock occurs where startReader triggers a shutdown and + // must wait for itself to exit. + go func() { + c.shutdownCleanup() + }() +} + +func (c *Connection) shutdownInner() { + c.shutdownSignalDone() + c.shutdownLogStart() + c.shutdownCloseSocket() +} + +func (c *Connection) shutdownCleanup() { + c.cleanupOnce.Do(func() { + c.wg.Wait() + c.shutdownCloseChannels() + c.shutdownLogComplete() + }) +} + +func (c *Connection) shutdownSetClosed(wait bool) error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return ErrConnectionClosed + } + c.closed = true + c.state = StateClosed + c.mu.Unlock() + return nil +} + +func (c *Connection) shutdownSignalDone() { + c.doneOnce.Do(func() { + close(c.done) + }) +} + +func (c *Connection) shutdownLogStart() { + if c.logger != nil { + c.logger.Info("shutting down") + } +} + +func (c *Connection) shutdownCloseSocket() { + if c.socket != nil { + // force unblock of any network operations immediately + expired := time.Now().Add(-1 * time.Minute) + c.socket.SetReadDeadline(expired) + c.socket.SetWriteDeadline(expired) + + // close socket + err := c.socket.Close() + + if err != nil && c.logger != nil { + c.logger.Error("socket close failed", "error", err) + } + } +} + +func (c *Connection) shutdownCloseChannels() { + close(c.incoming) + close(c.errors) +} + +func (c *Connection) shutdownLogComplete() { + if c.logger != nil { + c.logger.Info("connection closed") + } +} + func (c *Connection) startReader() { c.wg.Add(1) go func() { defer c.wg.Done() + defer c.shutdownInternal() for { - messageType, data, err := c.socket.ReadMessage() - if err != nil { - if c.logger != nil { - var closeErr *websocket.CloseError - if errors.As(err, &closeErr) { - switch closeErr.Code { - case websocket.CloseNormalClosure, websocket.CloseGoingAway: - c.logger.Info("connection closed by peer", - "code", closeErr.Code, - "text", closeErr.Text, - ) - default: - c.logger.Error("unexpected close", - "code", closeErr.Code, - "text", closeErr.Text, - ) - } - } else { - c.logger.Error("read error", "error", err) - } - } - select { - case c.errors <- err: - case <-c.done: - } - c.shutdown() + select { + case <-c.done: return - } + default: + messageType, data, err := c.socket.ReadMessage() + if err != nil { + if c.logger != nil { + var closeErr *websocket.CloseError + if errors.As(err, &closeErr) { + switch closeErr.Code { + case websocket.CloseNormalClosure, websocket.CloseGoingAway: + c.logger.Info("connection closed by peer", + "code", closeErr.Code, + "text", closeErr.Text, + ) + default: + c.logger.Error("unexpected close", + "code", closeErr.Code, + "text", closeErr.Text, + ) + } + } else { + c.logger.Error("read error", "error", err) + } + } - if messageType == websocket.TextMessage || - messageType == websocket.BinaryMessage { - select { - case c.incoming <- data: - case <-c.done: - c.shutdown() + select { + case <-c.done: + case c.errors <- err: + } return } - } + if messageType == websocket.TextMessage || + messageType == websocket.BinaryMessage { + select { + case <-c.done: + return + case c.incoming <- data: + } + } + + } } }() @@ -230,7 +332,7 @@ func (c *Connection) Send(data []byte) error { if c.logger != nil { c.logger.Error("write deadline error", "error", err) } - c.shutdown() + c.shutdownExternal() return fmt.Errorf("failed to set write deadline: %w", err) } } @@ -253,53 +355,6 @@ func (c *Connection) Errors() <-chan error { return c.errors } -func (c *Connection) shutdown() { - c.mu.Lock() - - if c.closed { - c.mu.Unlock() - return - } - - if c.logger != nil { - c.logger.Info("closing", "state", c.state.String()) - } - c.closed = true - c.state = StateClosed - socket := c.socket - close(c.done) - c.mu.Unlock() - - go func() { - if socket != nil { - // force immediate timeout of any blocked network I/O - expired := time.Now().Add(-1 * time.Minute) - socket.SetReadDeadline(expired) - socket.SetWriteDeadline(expired) - err := socket.Close() - - if err != nil { - if c.logger != nil { - c.logger.Error("socket close failed", "error", err) - } - } else { - if c.logger != nil { - c.logger.Info("closed") - } - } - } - - c.wg.Wait() - close(c.incoming) - close(c.errors) - }() - -} - -func (c *Connection) Close() { - c.shutdown() -} - func (c *Connection) State() ConnectionState { c.mu.RLock() defer c.mu.RUnlock() diff --git a/transport/connection_close_test.go b/transport/connection_close_test.go index fa44a99..241f75c 100644 --- a/transport/connection_close_test.go +++ b/transport/connection_close_test.go @@ -58,15 +58,11 @@ func TestDisconnectedConnectionClose(t *testing.T) { conn.Close() - assert.Eventually(t, func() bool { - select { - case _, ok := <-conn.Errors(): - return !ok - default: - return false - } - }, honeybeetest.TestTimeout, honeybeetest.TestTick, - "errors channel should close") + assert.True(t, conn.closed) + _, ok := <-conn.incoming + assert.False(t, ok) + _, ok = <-conn.errors + assert.False(t, ok) }) t.Run("send fails after close", func(t *testing.T) { @@ -103,32 +99,4 @@ func TestConnectedConnectionClose(t *testing.T) { conn.Close() assert.Equal(t, StateClosed, conn.State()) }) - - t.Run("writer active during close exits cleanly", func(t *testing.T) { - conn, _, _, _ := setupTestConnection(t, nil) - - for i := 0; i < 50; i++ { - conn.Send([]byte("message")) - } - - conn.Close() - - err := conn.Send([]byte("late")) - assert.Error(t, err, "Send should fail after close") - assert.ErrorContains(t, err, "connection closed") - }) - - t.Run("both goroutines active during close", func(t *testing.T) { - conn, _, incomingData, _ := setupTestConnection(t, nil) - - for i := 0; i < 10; i++ { - incomingData <- honeybeetest.MockIncomingData{ - MsgType: websocket.TextMessage, - Data: []byte(fmt.Sprintf("in-%d", i)), - } - conn.Send([]byte(fmt.Sprintf("out-%d", i))) - } - - conn.Close() - }) } diff --git a/transport/connection_goroutine_test.go b/transport/connection_goroutine_test.go index 8203c63..8787a93 100644 --- a/transport/connection_goroutine_test.go +++ b/transport/connection_goroutine_test.go @@ -1,10 +1,10 @@ package transport import ( - "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "io" "testing" ) @@ -62,23 +62,12 @@ func TestStartReader(t *testing.T) { return nil } - readErr := fmt.Errorf("read failed") mockSocket.ReadMessageFunc = func() (int, []byte, error) { - return 0, nil, readErr + return 0, nil, io.EOF } conn, err := NewConnectionFromSocket(mockSocket, nil, nil) assert.NoError(t, err) - defer conn.Close() - - assert.Eventually(t, func() bool { - select { - case err := <-conn.Errors(): - return err == readErr - default: - return false - } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) assert.Eventually(t, func() bool { return conn.State() == StateClosed diff --git a/transport/connection_send_test.go b/transport/connection_send_test.go index 211a024..6a17e88 100644 --- a/transport/connection_send_test.go +++ b/transport/connection_send_test.go @@ -50,7 +50,6 @@ func TestConnectionSend(t *testing.T) { for { select { case msg := <-outgoingData: - fmt.Printf("got message %s\n", string(msg.Data)) mu.Lock() messages = append(messages, string(msg.Data)) mu.Unlock() @@ -69,7 +68,6 @@ func TestConnectionSend(t *testing.T) { defer wg.Done() for j := 0; j < 10; j++ { data := []byte(fmt.Sprintf("msg-%d-%d", id, j)) - fmt.Printf("sending message %s\n", string(data)) for { // send and retry until success err := conn.Send(data) diff --git a/transport/logging_test.go b/transport/logging_test.go index a103668..9b92830 100644 --- a/transport/logging_test.go +++ b/transport/logging_test.go @@ -281,7 +281,7 @@ func TestCloseLogging(t *testing.T) { records := mockHandler.GetRecords() expected := []expectedLog{ - {slog.LevelInfo, "closing", map[string]any{"state": "connected"}}, + {slog.LevelInfo, "shutting down", map[string]any{}}, {slog.LevelInfo, "closed", map[string]any{}}, } @@ -311,7 +311,7 @@ func TestCloseLogging(t *testing.T) { records := mockHandler.GetRecords() expected := []expectedLog{ - {slog.LevelInfo, "closing", map[string]any{"state": "connected"}}, + {slog.LevelInfo, "shutting down", map[string]any{}}, {slog.LevelError, "socket close failed", map[string]any{"error": closeErr}}, }