Made connection closing non-blocking.

This commit is contained in:
Jay
2026-04-14 22:12:17 -04:00
parent 2e4f0257f5
commit b84daa1f5b
4 changed files with 81 additions and 71 deletions

View File

@@ -50,7 +50,6 @@ type Connection struct {
state ConnectionState
wg sync.WaitGroup
once sync.Once
closed bool
mu sync.RWMutex
}
@@ -186,7 +185,7 @@ func (c *Connection) startReader() {
case c.errors <- fmt.Errorf("failed to set read deadline: %w", err):
case <-c.done:
}
c.Close()
c.shutdown()
return
}
}
@@ -199,7 +198,7 @@ func (c *Connection) startReader() {
case c.errors <- err:
case <-c.done:
}
c.Close()
c.shutdown()
return
}
@@ -208,7 +207,7 @@ func (c *Connection) startReader() {
select {
case c.incoming <- data:
case <-c.done:
c.Close()
c.shutdown()
return
}
}
@@ -238,7 +237,7 @@ func (c *Connection) startWriter() {
case c.errors <- fmt.Errorf("failed to set write deadline: %w", err):
case <-c.done:
}
c.Close()
c.shutdown()
return
}
}
@@ -251,7 +250,7 @@ func (c *Connection) startWriter() {
case c.errors <- err:
case <-c.done:
}
c.Close()
c.shutdown()
return
}
}
@@ -271,6 +270,8 @@ func (c *Connection) Send(data []byte) error {
select {
case c.outgoing <- data:
return nil
case <-c.done:
return errors.NewConnectionError("connection closing")
default:
return errors.NewConnectionError("outgoing queue full")
}
@@ -284,34 +285,31 @@ func (c *Connection) Errors() <-chan error {
return c.errors
}
// Close shuts down the connection and waits for goroutines to exit.
// If the underlying socket blocks indefinitely on read or write operations,
// Close will also block. This is expected behavior - hung sockets require
// external intervention (timeouts, process termination, etc).
func (c *Connection) Close() error {
func (c *Connection) shutdown() {
c.mu.Lock()
alreadyClosed := c.closed
currentState := c.state
if !alreadyClosed {
if c.closed {
c.mu.Unlock()
return
}
if c.logger != nil {
c.logger.Info("closing", "state", currentState.String())
c.logger.Info("closing", "state", c.state.String())
}
c.closed = true
c.state = StateClosed
close(c.done)
}
socket := c.socket
close(c.done)
c.mu.Unlock()
if alreadyClosed {
return nil
}
var err error
go func() {
if socket != nil {
err = socket.Close()
// force immediate timeout of any blocked network I/O
expired := time.Now().Add(-1 * time.Minute)
socket.SetReadDeadline(expired)
socket.SetWriteDeadline(expired)
err := socket.Close()
if err != nil {
if c.logger != nil {
c.logger.Error("socket close failed", "error", err)
@@ -324,12 +322,15 @@ func (c *Connection) Close() error {
}
c.wg.Wait()
close(c.incoming)
close(c.outgoing)
close(c.errors)
}()
return err
}
func (c *Connection) Close() {
c.shutdown()
}
func (c *Connection) State() ConnectionState {

View File

@@ -15,8 +15,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, StateDisconnected, conn.State())
err = conn.Close()
assert.NoError(t, err)
conn.Close()
assert.Equal(t, StateClosed, conn.State())
})
@@ -24,12 +23,8 @@ func TestDisconnectedConnectionClose(t *testing.T) {
conn, err := NewConnection("ws://test", nil, nil)
assert.NoError(t, err)
err = conn.Close()
assert.NoError(t, err)
// Second close should succeed without error
err = conn.Close()
assert.NoError(t, err)
conn.Close()
conn.Close()
assert.Equal(t, StateClosed, conn.State())
})
@@ -38,12 +33,11 @@ func TestDisconnectedConnectionClose(t *testing.T) {
assert.NoError(t, err)
assert.Nil(t, conn.socket)
err = conn.Close()
assert.NoError(t, err)
conn.Close()
assert.Equal(t, StateClosed, conn.State())
})
t.Run("socket close error propagates", func(t *testing.T) {
t.Run("socket close error does not propagate", func(t *testing.T) {
expectedErr := fmt.Errorf("socket close failed")
mockSocket := NewMockSocket()
mockSocket.CloseFunc = func() error {
@@ -54,8 +48,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
assert.NoError(t, err)
conn.socket = mockSocket
err = conn.Close()
assert.Equal(t, expectedErr, err)
conn.Close()
assert.Equal(t, StateClosed, conn.State())
})
@@ -63,8 +56,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
conn, err := NewConnection("ws://test", nil, nil)
assert.NoError(t, err)
err = conn.Close()
assert.NoError(t, err)
conn.Close()
// Verify incoming channel closed
select {
@@ -95,8 +87,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
conn, err := NewConnection("ws://test", nil, nil)
assert.NoError(t, err)
err = conn.Close()
assert.NoError(t, err)
conn.Close()
err = conn.Send([]byte("test"))
assert.Error(t, err)
@@ -112,8 +103,7 @@ func TestConnectedConnectionClose(t *testing.T) {
// Wait for reader to block
time.Sleep(10 * time.Millisecond)
err := conn.Close()
assert.NoError(t, err)
conn.Close()
assert.Equal(t, StateClosed, conn.State())
close(incomingData)
@@ -126,13 +116,19 @@ func TestConnectedConnectionClose(t *testing.T) {
conn.Send([]byte("message"))
}
err := conn.Close()
assert.NoError(t, err)
conn.Close()
err = conn.Send([]byte("late"))
err := conn.Send([]byte("late"))
assert.Error(t, err, "Send should fail after close")
assert.ErrorContains(t, err, "connection closed")
// wait for background closures
select {
case <-conn.Errors():
case <-time.After(500 * time.Millisecond):
t.Fatal("timed out waiting for cleanup")
}
close(outgoingData)
})
@@ -149,8 +145,14 @@ func TestConnectedConnectionClose(t *testing.T) {
time.Sleep(10 * time.Millisecond)
err := conn.Close()
assert.NoError(t, err)
conn.Close()
// wait for background closures
select {
case <-conn.Errors():
case <-time.After(500 * time.Millisecond):
t.Fatal("timed out waiting for cleanup")
}
close(incomingData)
close(outgoingData)

View File

@@ -285,9 +285,9 @@ func TestCloseLogging(t *testing.T) {
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
assert.NoError(t, err)
err = conn.Close()
assert.NoError(t, err)
conn.Close()
time.Sleep(10 * time.Millisecond)
records := mockHandler.GetRecords()
expected := []expectedLog{
@@ -311,9 +311,9 @@ func TestCloseLogging(t *testing.T) {
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
assert.NoError(t, err)
err = conn.Close()
assert.Error(t, err)
conn.Close()
time.Sleep(10 * time.Millisecond)
records := mockHandler.GetRecords()
expected := []expectedLog{
@@ -461,8 +461,7 @@ func TestLoggingDisabled(t *testing.T) {
err = conn.Connect()
assert.NoError(t, err)
err = conn.Close()
assert.NoError(t, err)
conn.Close()
records := mockHandler.GetRecords()
assert.Empty(t, records)

View File

@@ -2,6 +2,7 @@ package honeybee
import (
"context"
"fmt"
"github.com/stretchr/testify/assert"
"io"
"log/slog"
@@ -119,11 +120,18 @@ func setupTestConnection(t *testing.T, config *Config) (
// Wire WriteMessage to push to outgoingData channel
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
select {
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
case <-mockSocket.closed:
return io.EOF
}
default:
select {
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
return nil
case <-mockSocket.closed:
return io.EOF
default:
return fmt.Errorf("mock outgoing chanel unavailable")
}
}
}
var err error