Made connection closing non-blocking.
This commit is contained in:
@@ -50,7 +50,6 @@ type Connection struct {
|
|||||||
state ConnectionState
|
state ConnectionState
|
||||||
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
once sync.Once
|
|
||||||
closed bool
|
closed bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
@@ -186,7 +185,7 @@ func (c *Connection) startReader() {
|
|||||||
case c.errors <- fmt.Errorf("failed to set read deadline: %w", err):
|
case c.errors <- fmt.Errorf("failed to set read deadline: %w", err):
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
}
|
}
|
||||||
c.Close()
|
c.shutdown()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -199,7 +198,7 @@ func (c *Connection) startReader() {
|
|||||||
case c.errors <- err:
|
case c.errors <- err:
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
}
|
}
|
||||||
c.Close()
|
c.shutdown()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,7 +207,7 @@ func (c *Connection) startReader() {
|
|||||||
select {
|
select {
|
||||||
case c.incoming <- data:
|
case c.incoming <- data:
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
c.Close()
|
c.shutdown()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -238,7 +237,7 @@ func (c *Connection) startWriter() {
|
|||||||
case c.errors <- fmt.Errorf("failed to set write deadline: %w", err):
|
case c.errors <- fmt.Errorf("failed to set write deadline: %w", err):
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
}
|
}
|
||||||
c.Close()
|
c.shutdown()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -251,7 +250,7 @@ func (c *Connection) startWriter() {
|
|||||||
case c.errors <- err:
|
case c.errors <- err:
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
}
|
}
|
||||||
c.Close()
|
c.shutdown()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -271,6 +270,8 @@ func (c *Connection) Send(data []byte) error {
|
|||||||
select {
|
select {
|
||||||
case c.outgoing <- data:
|
case c.outgoing <- data:
|
||||||
return nil
|
return nil
|
||||||
|
case <-c.done:
|
||||||
|
return errors.NewConnectionError("connection closing")
|
||||||
default:
|
default:
|
||||||
return errors.NewConnectionError("outgoing queue full")
|
return errors.NewConnectionError("outgoing queue full")
|
||||||
}
|
}
|
||||||
@@ -284,52 +285,52 @@ func (c *Connection) Errors() <-chan error {
|
|||||||
return c.errors
|
return c.errors
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close shuts down the connection and waits for goroutines to exit.
|
func (c *Connection) shutdown() {
|
||||||
// If the underlying socket blocks indefinitely on read or write operations,
|
|
||||||
// Close will also block. This is expected behavior - hung sockets require
|
|
||||||
// external intervention (timeouts, process termination, etc).
|
|
||||||
func (c *Connection) Close() error {
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
|
|
||||||
alreadyClosed := c.closed
|
if c.closed {
|
||||||
currentState := c.state
|
c.mu.Unlock()
|
||||||
if !alreadyClosed {
|
return
|
||||||
if c.logger != nil {
|
|
||||||
c.logger.Info("closing", "state", currentState.String())
|
|
||||||
}
|
|
||||||
c.closed = true
|
|
||||||
c.state = StateClosed
|
|
||||||
close(c.done)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Info("closing", "state", c.state.String())
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
c.state = StateClosed
|
||||||
socket := c.socket
|
socket := c.socket
|
||||||
|
close(c.done)
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
if alreadyClosed {
|
go func() {
|
||||||
return nil
|
if socket != nil {
|
||||||
}
|
// force immediate timeout of any blocked network I/O
|
||||||
|
expired := time.Now().Add(-1 * time.Minute)
|
||||||
|
socket.SetReadDeadline(expired)
|
||||||
|
socket.SetWriteDeadline(expired)
|
||||||
|
err := socket.Close()
|
||||||
|
|
||||||
var err error
|
if err != nil {
|
||||||
if socket != nil {
|
if c.logger != nil {
|
||||||
err = socket.Close()
|
c.logger.Error("socket close failed", "error", err)
|
||||||
if err != nil {
|
}
|
||||||
if c.logger != nil {
|
} else {
|
||||||
c.logger.Error("socket close failed", "error", err)
|
if c.logger != nil {
|
||||||
}
|
c.logger.Info("closed")
|
||||||
} else {
|
}
|
||||||
if c.logger != nil {
|
|
||||||
c.logger.Info("closed")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
c.wg.Wait()
|
c.wg.Wait()
|
||||||
|
close(c.incoming)
|
||||||
|
close(c.outgoing)
|
||||||
|
close(c.errors)
|
||||||
|
}()
|
||||||
|
|
||||||
close(c.incoming)
|
}
|
||||||
close(c.outgoing)
|
|
||||||
close(c.errors)
|
|
||||||
|
|
||||||
return err
|
func (c *Connection) Close() {
|
||||||
|
c.shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) State() ConnectionState {
|
func (c *Connection) State() ConnectionState {
|
||||||
|
|||||||
@@ -15,8 +15,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, StateDisconnected, conn.State())
|
assert.Equal(t, StateDisconnected, conn.State())
|
||||||
|
|
||||||
err = conn.Close()
|
conn.Close()
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, StateClosed, conn.State())
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -24,12 +23,8 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
conn, err := NewConnection("ws://test", nil, nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = conn.Close()
|
conn.Close()
|
||||||
assert.NoError(t, err)
|
conn.Close()
|
||||||
|
|
||||||
// Second close should succeed without error
|
|
||||||
err = conn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, StateClosed, conn.State())
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -38,12 +33,11 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Nil(t, conn.socket)
|
assert.Nil(t, conn.socket)
|
||||||
|
|
||||||
err = conn.Close()
|
conn.Close()
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, StateClosed, conn.State())
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("socket close error propagates", func(t *testing.T) {
|
t.Run("socket close error does not propagate", func(t *testing.T) {
|
||||||
expectedErr := fmt.Errorf("socket close failed")
|
expectedErr := fmt.Errorf("socket close failed")
|
||||||
mockSocket := NewMockSocket()
|
mockSocket := NewMockSocket()
|
||||||
mockSocket.CloseFunc = func() error {
|
mockSocket.CloseFunc = func() error {
|
||||||
@@ -54,8 +48,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
conn.socket = mockSocket
|
conn.socket = mockSocket
|
||||||
|
|
||||||
err = conn.Close()
|
conn.Close()
|
||||||
assert.Equal(t, expectedErr, err)
|
|
||||||
assert.Equal(t, StateClosed, conn.State())
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -63,8 +56,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
conn, err := NewConnection("ws://test", nil, nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = conn.Close()
|
conn.Close()
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify incoming channel closed
|
// Verify incoming channel closed
|
||||||
select {
|
select {
|
||||||
@@ -95,8 +87,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
conn, err := NewConnection("ws://test", nil, nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = conn.Close()
|
conn.Close()
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
err = conn.Send([]byte("test"))
|
err = conn.Send([]byte("test"))
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -112,8 +103,7 @@ func TestConnectedConnectionClose(t *testing.T) {
|
|||||||
// Wait for reader to block
|
// Wait for reader to block
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
err := conn.Close()
|
conn.Close()
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, StateClosed, conn.State())
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
|
||||||
close(incomingData)
|
close(incomingData)
|
||||||
@@ -126,13 +116,19 @@ func TestConnectedConnectionClose(t *testing.T) {
|
|||||||
conn.Send([]byte("message"))
|
conn.Send([]byte("message"))
|
||||||
}
|
}
|
||||||
|
|
||||||
err := conn.Close()
|
conn.Close()
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
err = conn.Send([]byte("late"))
|
err := conn.Send([]byte("late"))
|
||||||
assert.Error(t, err, "Send should fail after close")
|
assert.Error(t, err, "Send should fail after close")
|
||||||
assert.ErrorContains(t, err, "connection closed")
|
assert.ErrorContains(t, err, "connection closed")
|
||||||
|
|
||||||
|
// wait for background closures
|
||||||
|
select {
|
||||||
|
case <-conn.Errors():
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Fatal("timed out waiting for cleanup")
|
||||||
|
}
|
||||||
|
|
||||||
close(outgoingData)
|
close(outgoingData)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -149,8 +145,14 @@ func TestConnectedConnectionClose(t *testing.T) {
|
|||||||
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
err := conn.Close()
|
conn.Close()
|
||||||
assert.NoError(t, err)
|
|
||||||
|
// wait for background closures
|
||||||
|
select {
|
||||||
|
case <-conn.Errors():
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Fatal("timed out waiting for cleanup")
|
||||||
|
}
|
||||||
|
|
||||||
close(incomingData)
|
close(incomingData)
|
||||||
close(outgoingData)
|
close(outgoingData)
|
||||||
|
|||||||
@@ -285,9 +285,9 @@ func TestCloseLogging(t *testing.T) {
|
|||||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = conn.Close()
|
conn.Close()
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
records := mockHandler.GetRecords()
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
expected := []expectedLog{
|
expected := []expectedLog{
|
||||||
@@ -311,9 +311,9 @@ func TestCloseLogging(t *testing.T) {
|
|||||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = conn.Close()
|
conn.Close()
|
||||||
assert.Error(t, err)
|
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
records := mockHandler.GetRecords()
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
expected := []expectedLog{
|
expected := []expectedLog{
|
||||||
@@ -461,8 +461,7 @@ func TestLoggingDisabled(t *testing.T) {
|
|||||||
err = conn.Connect()
|
err = conn.Connect()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = conn.Close()
|
conn.Close()
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
records := mockHandler.GetRecords()
|
records := mockHandler.GetRecords()
|
||||||
assert.Empty(t, records)
|
assert.Empty(t, records)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package honeybee
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -119,11 +120,18 @@ func setupTestConnection(t *testing.T, config *Config) (
|
|||||||
// Wire WriteMessage to push to outgoingData channel
|
// Wire WriteMessage to push to outgoingData channel
|
||||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||||
select {
|
select {
|
||||||
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
|
||||||
case <-mockSocket.closed:
|
case <-mockSocket.closed:
|
||||||
return io.EOF
|
return io.EOF
|
||||||
|
default:
|
||||||
|
select {
|
||||||
|
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||||
|
return nil
|
||||||
|
case <-mockSocket.closed:
|
||||||
|
return io.EOF
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("mock outgoing chanel unavailable")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|||||||
Reference in New Issue
Block a user