Made connection closing non-blocking.
This commit is contained in:
@@ -50,7 +50,6 @@ type Connection struct {
|
||||
state ConnectionState
|
||||
|
||||
wg sync.WaitGroup
|
||||
once sync.Once
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -186,7 +185,7 @@ func (c *Connection) startReader() {
|
||||
case c.errors <- fmt.Errorf("failed to set read deadline: %w", err):
|
||||
case <-c.done:
|
||||
}
|
||||
c.Close()
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -199,7 +198,7 @@ func (c *Connection) startReader() {
|
||||
case c.errors <- err:
|
||||
case <-c.done:
|
||||
}
|
||||
c.Close()
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -208,7 +207,7 @@ func (c *Connection) startReader() {
|
||||
select {
|
||||
case c.incoming <- data:
|
||||
case <-c.done:
|
||||
c.Close()
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -238,7 +237,7 @@ func (c *Connection) startWriter() {
|
||||
case c.errors <- fmt.Errorf("failed to set write deadline: %w", err):
|
||||
case <-c.done:
|
||||
}
|
||||
c.Close()
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -251,7 +250,7 @@ func (c *Connection) startWriter() {
|
||||
case c.errors <- err:
|
||||
case <-c.done:
|
||||
}
|
||||
c.Close()
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -271,6 +270,8 @@ func (c *Connection) Send(data []byte) error {
|
||||
select {
|
||||
case c.outgoing <- data:
|
||||
return nil
|
||||
case <-c.done:
|
||||
return errors.NewConnectionError("connection closing")
|
||||
default:
|
||||
return errors.NewConnectionError("outgoing queue full")
|
||||
}
|
||||
@@ -284,52 +285,52 @@ func (c *Connection) Errors() <-chan error {
|
||||
return c.errors
|
||||
}
|
||||
|
||||
// Close shuts down the connection and waits for goroutines to exit.
|
||||
// If the underlying socket blocks indefinitely on read or write operations,
|
||||
// Close will also block. This is expected behavior - hung sockets require
|
||||
// external intervention (timeouts, process termination, etc).
|
||||
func (c *Connection) Close() error {
|
||||
func (c *Connection) shutdown() {
|
||||
c.mu.Lock()
|
||||
|
||||
alreadyClosed := c.closed
|
||||
currentState := c.state
|
||||
if !alreadyClosed {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closing", "state", currentState.String())
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
close(c.done)
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closing", "state", c.state.String())
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
socket := c.socket
|
||||
close(c.done)
|
||||
c.mu.Unlock()
|
||||
|
||||
if alreadyClosed {
|
||||
return nil
|
||||
}
|
||||
go func() {
|
||||
if socket != nil {
|
||||
// force immediate timeout of any blocked network I/O
|
||||
expired := time.Now().Add(-1 * time.Minute)
|
||||
socket.SetReadDeadline(expired)
|
||||
socket.SetWriteDeadline(expired)
|
||||
err := socket.Close()
|
||||
|
||||
var err error
|
||||
if socket != nil {
|
||||
err = socket.Close()
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
} else {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closed")
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
} else {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.wg.Wait()
|
||||
c.wg.Wait()
|
||||
close(c.incoming)
|
||||
close(c.outgoing)
|
||||
close(c.errors)
|
||||
}()
|
||||
|
||||
close(c.incoming)
|
||||
close(c.outgoing)
|
||||
close(c.errors)
|
||||
}
|
||||
|
||||
return err
|
||||
func (c *Connection) Close() {
|
||||
c.shutdown()
|
||||
}
|
||||
|
||||
func (c *Connection) State() ConnectionState {
|
||||
|
||||
@@ -15,8 +15,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, StateDisconnected, conn.State())
|
||||
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
@@ -24,12 +23,8 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Second close should succeed without error
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
conn.Close()
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
@@ -38,12 +33,11 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, conn.socket)
|
||||
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
t.Run("socket close error propagates", func(t *testing.T) {
|
||||
t.Run("socket close error does not propagate", func(t *testing.T) {
|
||||
expectedErr := fmt.Errorf("socket close failed")
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket.CloseFunc = func() error {
|
||||
@@ -54,8 +48,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
conn.socket = mockSocket
|
||||
|
||||
err = conn.Close()
|
||||
assert.Equal(t, expectedErr, err)
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
@@ -63,8 +56,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
conn.Close()
|
||||
|
||||
// Verify incoming channel closed
|
||||
select {
|
||||
@@ -95,8 +87,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
conn.Close()
|
||||
|
||||
err = conn.Send([]byte("test"))
|
||||
assert.Error(t, err)
|
||||
@@ -112,8 +103,7 @@ func TestConnectedConnectionClose(t *testing.T) {
|
||||
// Wait for reader to block
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err := conn.Close()
|
||||
assert.NoError(t, err)
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
|
||||
close(incomingData)
|
||||
@@ -126,13 +116,19 @@ func TestConnectedConnectionClose(t *testing.T) {
|
||||
conn.Send([]byte("message"))
|
||||
}
|
||||
|
||||
err := conn.Close()
|
||||
assert.NoError(t, err)
|
||||
conn.Close()
|
||||
|
||||
err = conn.Send([]byte("late"))
|
||||
err := conn.Send([]byte("late"))
|
||||
assert.Error(t, err, "Send should fail after close")
|
||||
assert.ErrorContains(t, err, "connection closed")
|
||||
|
||||
// wait for background closures
|
||||
select {
|
||||
case <-conn.Errors():
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("timed out waiting for cleanup")
|
||||
}
|
||||
|
||||
close(outgoingData)
|
||||
})
|
||||
|
||||
@@ -149,8 +145,14 @@ func TestConnectedConnectionClose(t *testing.T) {
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err := conn.Close()
|
||||
assert.NoError(t, err)
|
||||
conn.Close()
|
||||
|
||||
// wait for background closures
|
||||
select {
|
||||
case <-conn.Errors():
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("timed out waiting for cleanup")
|
||||
}
|
||||
|
||||
close(incomingData)
|
||||
close(outgoingData)
|
||||
|
||||
@@ -285,9 +285,9 @@ func TestCloseLogging(t *testing.T) {
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
conn.Close()
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
expected := []expectedLog{
|
||||
@@ -311,9 +311,9 @@ func TestCloseLogging(t *testing.T) {
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
assert.Error(t, err)
|
||||
conn.Close()
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
expected := []expectedLog{
|
||||
@@ -461,8 +461,7 @@ func TestLoggingDisabled(t *testing.T) {
|
||||
err = conn.Connect()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
conn.Close()
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
assert.Empty(t, records)
|
||||
|
||||
@@ -2,6 +2,7 @@ package honeybee
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -119,11 +120,18 @@ func setupTestConnection(t *testing.T, config *Config) (
|
||||
// Wire WriteMessage to push to outgoingData channel
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
select {
|
||||
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||
case <-mockSocket.closed:
|
||||
return io.EOF
|
||||
default:
|
||||
select {
|
||||
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||
return nil
|
||||
case <-mockSocket.closed:
|
||||
return io.EOF
|
||||
default:
|
||||
return fmt.Errorf("mock outgoing chanel unavailable")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
Reference in New Issue
Block a user