From b84daa1f5b3e2f1b46150b1e3116ea3b520e8be3 Mon Sep 17 00:00:00 2001 From: Jay Date: Tue, 14 Apr 2026 22:12:17 -0400 Subject: [PATCH] Made connection closing non-blocking. --- connection.go | 79 ++++++++++++++++++++-------------------- connection_close_test.go | 50 +++++++++++++------------ logging_test.go | 11 +++--- mocks_test.go | 12 +++++- 4 files changed, 81 insertions(+), 71 deletions(-) diff --git a/connection.go b/connection.go index 7862fa3..f6507e6 100644 --- a/connection.go +++ b/connection.go @@ -50,7 +50,6 @@ type Connection struct { state ConnectionState wg sync.WaitGroup - once sync.Once closed bool mu sync.RWMutex } @@ -186,7 +185,7 @@ func (c *Connection) startReader() { case c.errors <- fmt.Errorf("failed to set read deadline: %w", err): case <-c.done: } - c.Close() + c.shutdown() return } } @@ -199,7 +198,7 @@ func (c *Connection) startReader() { case c.errors <- err: case <-c.done: } - c.Close() + c.shutdown() return } @@ -208,7 +207,7 @@ func (c *Connection) startReader() { select { case c.incoming <- data: case <-c.done: - c.Close() + c.shutdown() return } } @@ -238,7 +237,7 @@ func (c *Connection) startWriter() { case c.errors <- fmt.Errorf("failed to set write deadline: %w", err): case <-c.done: } - c.Close() + c.shutdown() return } } @@ -251,7 +250,7 @@ func (c *Connection) startWriter() { case c.errors <- err: case <-c.done: } - c.Close() + c.shutdown() return } } @@ -271,6 +270,8 @@ func (c *Connection) Send(data []byte) error { select { case c.outgoing <- data: return nil + case <-c.done: + return errors.NewConnectionError("connection closing") default: return errors.NewConnectionError("outgoing queue full") } @@ -284,52 +285,52 @@ func (c *Connection) Errors() <-chan error { return c.errors } -// Close shuts down the connection and waits for goroutines to exit. -// If the underlying socket blocks indefinitely on read or write operations, -// Close will also block. This is expected behavior - hung sockets require -// external intervention (timeouts, process termination, etc). -func (c *Connection) Close() error { +func (c *Connection) shutdown() { c.mu.Lock() - alreadyClosed := c.closed - currentState := c.state - if !alreadyClosed { - if c.logger != nil { - c.logger.Info("closing", "state", currentState.String()) - } - c.closed = true - c.state = StateClosed - close(c.done) + if c.closed { + c.mu.Unlock() + return } + if c.logger != nil { + c.logger.Info("closing", "state", c.state.String()) + } + c.closed = true + c.state = StateClosed socket := c.socket + close(c.done) c.mu.Unlock() - if alreadyClosed { - return nil - } + go func() { + if socket != nil { + // force immediate timeout of any blocked network I/O + expired := time.Now().Add(-1 * time.Minute) + socket.SetReadDeadline(expired) + socket.SetWriteDeadline(expired) + err := socket.Close() - var err error - if socket != nil { - err = socket.Close() - if err != nil { - if c.logger != nil { - c.logger.Error("socket close failed", "error", err) - } - } else { - if c.logger != nil { - c.logger.Info("closed") + if err != nil { + if c.logger != nil { + c.logger.Error("socket close failed", "error", err) + } + } else { + if c.logger != nil { + c.logger.Info("closed") + } } } - } - c.wg.Wait() + c.wg.Wait() + close(c.incoming) + close(c.outgoing) + close(c.errors) + }() - close(c.incoming) - close(c.outgoing) - close(c.errors) +} - return err +func (c *Connection) Close() { + c.shutdown() } func (c *Connection) State() ConnectionState { diff --git a/connection_close_test.go b/connection_close_test.go index b9588a2..8a9fb01 100644 --- a/connection_close_test.go +++ b/connection_close_test.go @@ -15,8 +15,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { assert.NoError(t, err) assert.Equal(t, StateDisconnected, conn.State()) - err = conn.Close() - assert.NoError(t, err) + conn.Close() assert.Equal(t, StateClosed, conn.State()) }) @@ -24,12 +23,8 @@ func TestDisconnectedConnectionClose(t *testing.T) { conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) - err = conn.Close() - assert.NoError(t, err) - - // Second close should succeed without error - err = conn.Close() - assert.NoError(t, err) + conn.Close() + conn.Close() assert.Equal(t, StateClosed, conn.State()) }) @@ -38,12 +33,11 @@ func TestDisconnectedConnectionClose(t *testing.T) { assert.NoError(t, err) assert.Nil(t, conn.socket) - err = conn.Close() - assert.NoError(t, err) + conn.Close() assert.Equal(t, StateClosed, conn.State()) }) - t.Run("socket close error propagates", func(t *testing.T) { + t.Run("socket close error does not propagate", func(t *testing.T) { expectedErr := fmt.Errorf("socket close failed") mockSocket := NewMockSocket() mockSocket.CloseFunc = func() error { @@ -54,8 +48,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { assert.NoError(t, err) conn.socket = mockSocket - err = conn.Close() - assert.Equal(t, expectedErr, err) + conn.Close() assert.Equal(t, StateClosed, conn.State()) }) @@ -63,8 +56,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) - err = conn.Close() - assert.NoError(t, err) + conn.Close() // Verify incoming channel closed select { @@ -95,8 +87,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) - err = conn.Close() - assert.NoError(t, err) + conn.Close() err = conn.Send([]byte("test")) assert.Error(t, err) @@ -112,8 +103,7 @@ func TestConnectedConnectionClose(t *testing.T) { // Wait for reader to block time.Sleep(10 * time.Millisecond) - err := conn.Close() - assert.NoError(t, err) + conn.Close() assert.Equal(t, StateClosed, conn.State()) close(incomingData) @@ -126,13 +116,19 @@ func TestConnectedConnectionClose(t *testing.T) { conn.Send([]byte("message")) } - err := conn.Close() - assert.NoError(t, err) + conn.Close() - err = conn.Send([]byte("late")) + err := conn.Send([]byte("late")) assert.Error(t, err, "Send should fail after close") assert.ErrorContains(t, err, "connection closed") + // wait for background closures + select { + case <-conn.Errors(): + case <-time.After(500 * time.Millisecond): + t.Fatal("timed out waiting for cleanup") + } + close(outgoingData) }) @@ -149,8 +145,14 @@ func TestConnectedConnectionClose(t *testing.T) { time.Sleep(10 * time.Millisecond) - err := conn.Close() - assert.NoError(t, err) + conn.Close() + + // wait for background closures + select { + case <-conn.Errors(): + case <-time.After(500 * time.Millisecond): + t.Fatal("timed out waiting for cleanup") + } close(incomingData) close(outgoingData) diff --git a/logging_test.go b/logging_test.go index 8e1b888..7225359 100644 --- a/logging_test.go +++ b/logging_test.go @@ -285,9 +285,9 @@ func TestCloseLogging(t *testing.T) { conn, err := NewConnectionFromSocket(mockSocket, nil, logger) assert.NoError(t, err) - err = conn.Close() - assert.NoError(t, err) + conn.Close() + time.Sleep(10 * time.Millisecond) records := mockHandler.GetRecords() expected := []expectedLog{ @@ -311,9 +311,9 @@ func TestCloseLogging(t *testing.T) { conn, err := NewConnectionFromSocket(mockSocket, nil, logger) assert.NoError(t, err) - err = conn.Close() - assert.Error(t, err) + conn.Close() + time.Sleep(10 * time.Millisecond) records := mockHandler.GetRecords() expected := []expectedLog{ @@ -461,8 +461,7 @@ func TestLoggingDisabled(t *testing.T) { err = conn.Connect() assert.NoError(t, err) - err = conn.Close() - assert.NoError(t, err) + conn.Close() records := mockHandler.GetRecords() assert.Empty(t, records) diff --git a/mocks_test.go b/mocks_test.go index 46a8a77..d8569ac 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -2,6 +2,7 @@ package honeybee import ( "context" + "fmt" "github.com/stretchr/testify/assert" "io" "log/slog" @@ -119,11 +120,18 @@ func setupTestConnection(t *testing.T, config *Config) ( // Wire WriteMessage to push to outgoingData channel mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { select { - case outgoingData <- mockOutgoingData{msgType: msgType, data: data}: case <-mockSocket.closed: return io.EOF + default: + select { + case outgoingData <- mockOutgoingData{msgType: msgType, data: data}: + return nil + case <-mockSocket.closed: + return io.EOF + default: + return fmt.Errorf("mock outgoing chanel unavailable") + } } - return nil } var err error