Made connection closing non-blocking.

This commit is contained in:
Jay
2026-04-14 22:12:17 -04:00
parent 2e4f0257f5
commit b84daa1f5b
4 changed files with 81 additions and 71 deletions

View File

@@ -50,7 +50,6 @@ type Connection struct {
state ConnectionState state ConnectionState
wg sync.WaitGroup wg sync.WaitGroup
once sync.Once
closed bool closed bool
mu sync.RWMutex mu sync.RWMutex
} }
@@ -186,7 +185,7 @@ func (c *Connection) startReader() {
case c.errors <- fmt.Errorf("failed to set read deadline: %w", err): case c.errors <- fmt.Errorf("failed to set read deadline: %w", err):
case <-c.done: case <-c.done:
} }
c.Close() c.shutdown()
return return
} }
} }
@@ -199,7 +198,7 @@ func (c *Connection) startReader() {
case c.errors <- err: case c.errors <- err:
case <-c.done: case <-c.done:
} }
c.Close() c.shutdown()
return return
} }
@@ -208,7 +207,7 @@ func (c *Connection) startReader() {
select { select {
case c.incoming <- data: case c.incoming <- data:
case <-c.done: case <-c.done:
c.Close() c.shutdown()
return return
} }
} }
@@ -238,7 +237,7 @@ func (c *Connection) startWriter() {
case c.errors <- fmt.Errorf("failed to set write deadline: %w", err): case c.errors <- fmt.Errorf("failed to set write deadline: %w", err):
case <-c.done: case <-c.done:
} }
c.Close() c.shutdown()
return return
} }
} }
@@ -251,7 +250,7 @@ func (c *Connection) startWriter() {
case c.errors <- err: case c.errors <- err:
case <-c.done: case <-c.done:
} }
c.Close() c.shutdown()
return return
} }
} }
@@ -271,6 +270,8 @@ func (c *Connection) Send(data []byte) error {
select { select {
case c.outgoing <- data: case c.outgoing <- data:
return nil return nil
case <-c.done:
return errors.NewConnectionError("connection closing")
default: default:
return errors.NewConnectionError("outgoing queue full") return errors.NewConnectionError("outgoing queue full")
} }
@@ -284,34 +285,31 @@ func (c *Connection) Errors() <-chan error {
return c.errors return c.errors
} }
// Close shuts down the connection and waits for goroutines to exit. func (c *Connection) shutdown() {
// If the underlying socket blocks indefinitely on read or write operations,
// Close will also block. This is expected behavior - hung sockets require
// external intervention (timeouts, process termination, etc).
func (c *Connection) Close() error {
c.mu.Lock() c.mu.Lock()
alreadyClosed := c.closed if c.closed {
currentState := c.state c.mu.Unlock()
if !alreadyClosed { return
}
if c.logger != nil { if c.logger != nil {
c.logger.Info("closing", "state", currentState.String()) c.logger.Info("closing", "state", c.state.String())
} }
c.closed = true c.closed = true
c.state = StateClosed c.state = StateClosed
close(c.done)
}
socket := c.socket socket := c.socket
close(c.done)
c.mu.Unlock() c.mu.Unlock()
if alreadyClosed { go func() {
return nil
}
var err error
if socket != nil { if socket != nil {
err = socket.Close() // force immediate timeout of any blocked network I/O
expired := time.Now().Add(-1 * time.Minute)
socket.SetReadDeadline(expired)
socket.SetWriteDeadline(expired)
err := socket.Close()
if err != nil { if err != nil {
if c.logger != nil { if c.logger != nil {
c.logger.Error("socket close failed", "error", err) c.logger.Error("socket close failed", "error", err)
@@ -324,12 +322,15 @@ func (c *Connection) Close() error {
} }
c.wg.Wait() c.wg.Wait()
close(c.incoming) close(c.incoming)
close(c.outgoing) close(c.outgoing)
close(c.errors) close(c.errors)
}()
return err }
func (c *Connection) Close() {
c.shutdown()
} }
func (c *Connection) State() ConnectionState { func (c *Connection) State() ConnectionState {

View File

@@ -15,8 +15,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, StateDisconnected, conn.State()) assert.Equal(t, StateDisconnected, conn.State())
err = conn.Close() conn.Close()
assert.NoError(t, err)
assert.Equal(t, StateClosed, conn.State()) assert.Equal(t, StateClosed, conn.State())
}) })
@@ -24,12 +23,8 @@ func TestDisconnectedConnectionClose(t *testing.T) {
conn, err := NewConnection("ws://test", nil, nil) conn, err := NewConnection("ws://test", nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
err = conn.Close() conn.Close()
assert.NoError(t, err) conn.Close()
// Second close should succeed without error
err = conn.Close()
assert.NoError(t, err)
assert.Equal(t, StateClosed, conn.State()) assert.Equal(t, StateClosed, conn.State())
}) })
@@ -38,12 +33,11 @@ func TestDisconnectedConnectionClose(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, conn.socket) assert.Nil(t, conn.socket)
err = conn.Close() conn.Close()
assert.NoError(t, err)
assert.Equal(t, StateClosed, conn.State()) assert.Equal(t, StateClosed, conn.State())
}) })
t.Run("socket close error propagates", func(t *testing.T) { t.Run("socket close error does not propagate", func(t *testing.T) {
expectedErr := fmt.Errorf("socket close failed") expectedErr := fmt.Errorf("socket close failed")
mockSocket := NewMockSocket() mockSocket := NewMockSocket()
mockSocket.CloseFunc = func() error { mockSocket.CloseFunc = func() error {
@@ -54,8 +48,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
conn.socket = mockSocket conn.socket = mockSocket
err = conn.Close() conn.Close()
assert.Equal(t, expectedErr, err)
assert.Equal(t, StateClosed, conn.State()) assert.Equal(t, StateClosed, conn.State())
}) })
@@ -63,8 +56,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
conn, err := NewConnection("ws://test", nil, nil) conn, err := NewConnection("ws://test", nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
err = conn.Close() conn.Close()
assert.NoError(t, err)
// Verify incoming channel closed // Verify incoming channel closed
select { select {
@@ -95,8 +87,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
conn, err := NewConnection("ws://test", nil, nil) conn, err := NewConnection("ws://test", nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
err = conn.Close() conn.Close()
assert.NoError(t, err)
err = conn.Send([]byte("test")) err = conn.Send([]byte("test"))
assert.Error(t, err) assert.Error(t, err)
@@ -112,8 +103,7 @@ func TestConnectedConnectionClose(t *testing.T) {
// Wait for reader to block // Wait for reader to block
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
err := conn.Close() conn.Close()
assert.NoError(t, err)
assert.Equal(t, StateClosed, conn.State()) assert.Equal(t, StateClosed, conn.State())
close(incomingData) close(incomingData)
@@ -126,13 +116,19 @@ func TestConnectedConnectionClose(t *testing.T) {
conn.Send([]byte("message")) conn.Send([]byte("message"))
} }
err := conn.Close() conn.Close()
assert.NoError(t, err)
err = conn.Send([]byte("late")) err := conn.Send([]byte("late"))
assert.Error(t, err, "Send should fail after close") assert.Error(t, err, "Send should fail after close")
assert.ErrorContains(t, err, "connection closed") assert.ErrorContains(t, err, "connection closed")
// wait for background closures
select {
case <-conn.Errors():
case <-time.After(500 * time.Millisecond):
t.Fatal("timed out waiting for cleanup")
}
close(outgoingData) close(outgoingData)
}) })
@@ -149,8 +145,14 @@ func TestConnectedConnectionClose(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
err := conn.Close() conn.Close()
assert.NoError(t, err)
// wait for background closures
select {
case <-conn.Errors():
case <-time.After(500 * time.Millisecond):
t.Fatal("timed out waiting for cleanup")
}
close(incomingData) close(incomingData)
close(outgoingData) close(outgoingData)

View File

@@ -285,9 +285,9 @@ func TestCloseLogging(t *testing.T) {
conn, err := NewConnectionFromSocket(mockSocket, nil, logger) conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
assert.NoError(t, err) assert.NoError(t, err)
err = conn.Close() conn.Close()
assert.NoError(t, err)
time.Sleep(10 * time.Millisecond)
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
expected := []expectedLog{ expected := []expectedLog{
@@ -311,9 +311,9 @@ func TestCloseLogging(t *testing.T) {
conn, err := NewConnectionFromSocket(mockSocket, nil, logger) conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
assert.NoError(t, err) assert.NoError(t, err)
err = conn.Close() conn.Close()
assert.Error(t, err)
time.Sleep(10 * time.Millisecond)
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
expected := []expectedLog{ expected := []expectedLog{
@@ -461,8 +461,7 @@ func TestLoggingDisabled(t *testing.T) {
err = conn.Connect() err = conn.Connect()
assert.NoError(t, err) assert.NoError(t, err)
err = conn.Close() conn.Close()
assert.NoError(t, err)
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
assert.Empty(t, records) assert.Empty(t, records)

View File

@@ -2,6 +2,7 @@ package honeybee
import ( import (
"context" "context"
"fmt"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io" "io"
"log/slog" "log/slog"
@@ -119,11 +120,18 @@ func setupTestConnection(t *testing.T, config *Config) (
// Wire WriteMessage to push to outgoingData channel // Wire WriteMessage to push to outgoingData channel
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
select { select {
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
case <-mockSocket.closed: case <-mockSocket.closed:
return io.EOF return io.EOF
} default:
select {
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
return nil return nil
case <-mockSocket.closed:
return io.EOF
default:
return fmt.Errorf("mock outgoing chanel unavailable")
}
}
} }
var err error var err error