Refactored connection shutdown logic.

This commit is contained in:
Jay
2026-04-19 09:29:12 -04:00
parent 72f0793047
commit 6998ccf701
7 changed files with 175 additions and 161 deletions
-9
View File
@@ -51,12 +51,3 @@ func setupWorkerTestConnection(t *testing.T) (
assert.NoError(t, err) assert.NoError(t, err)
return return
} }
func connClosed(conn *transport.Connection) bool {
select {
case _, ok := <-conn.Errors():
return !ok
default:
return false
}
}
+25 -12
View File
@@ -15,7 +15,7 @@ import (
"time" "time"
) )
func TestRunSession(t *testing.T) { func TestRunSessionDial(t *testing.T) {
} }
@@ -127,14 +127,19 @@ func TestRunReader(t *testing.T) {
}() }()
go w.runReader(conn, messages, sessionDone, onStop) go w.runReader(conn, messages, sessionDone, onStop)
// simulate remote close // induce connection closure via reader
incomingData <- honeybeetest.MockIncomingData{Err: io.EOF} incomingData <- honeybeetest.MockIncomingData{Err: io.EOF}
err := <-conn.Errors()
assert.Equal(t, io.EOF, err)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
return connClosed(conn) return conn.State() == transport.StateClosed
}, honeybeetest.TestTimeout, honeybeetest.TestTick) }, honeybeetest.TestTimeout, honeybeetest.TestTick)
assert.True(t, onStopCalled.Load()) assert.Eventually(t, func() bool {
return onStopCalled.Load()
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
}) })
t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) {
@@ -157,10 +162,12 @@ func TestRunReader(t *testing.T) {
close(sessionDone) close(sessionDone)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
return connClosed(conn) return conn.State() == transport.StateClosed
}, honeybeetest.TestTimeout, honeybeetest.TestTick) }, honeybeetest.TestTimeout, honeybeetest.TestTick)
assert.True(t, onStopCalled.Load()) assert.Eventually(t, func() bool {
return onStopCalled.Load()
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
}) })
} }
@@ -181,10 +188,12 @@ func TestRunStopMonitor(t *testing.T) {
keepalive <- struct{}{} keepalive <- struct{}{}
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
return connClosed(conn) return conn.State() == transport.StateClosed
}, honeybeetest.TestTimeout, honeybeetest.TestTick) }, honeybeetest.TestTimeout, honeybeetest.TestTick)
assert.True(t, onStopCalled.Load()) assert.Eventually(t, func() bool {
return onStopCalled.Load()
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
}) })
t.Run("ctx.Done calls conn.Close and onStop", func(t *testing.T) { t.Run("ctx.Done calls conn.Close and onStop", func(t *testing.T) {
@@ -202,10 +211,12 @@ func TestRunStopMonitor(t *testing.T) {
cancel() cancel()
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
return connClosed(conn) return conn.State() == transport.StateClosed
}, honeybeetest.TestTimeout, honeybeetest.TestTick) }, honeybeetest.TestTimeout, honeybeetest.TestTick)
assert.True(t, onStopCalled.Load()) assert.Eventually(t, func() bool {
return onStopCalled.Load()
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
}) })
t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) {
@@ -224,10 +235,12 @@ func TestRunStopMonitor(t *testing.T) {
close(sessionDone) close(sessionDone)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
return connClosed(conn) return conn.State() == transport.StateClosed
}, honeybeetest.TestTimeout, honeybeetest.TestTick) }, honeybeetest.TestTimeout, honeybeetest.TestTick)
assert.True(t, onStopCalled.Load()) assert.Eventually(t, func() bool {
return onStopCalled.Load()
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
}) })
} }
+141 -86
View File
@@ -50,10 +50,12 @@ type Connection struct {
state ConnectionState state ConnectionState
wg sync.WaitGroup wg sync.WaitGroup
closed bool closed bool
mu sync.RWMutex mu sync.RWMutex
writeMu sync.Mutex writeMu sync.Mutex
doneOnce sync.Once
cleanupOnce sync.Once
} }
func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) { func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
@@ -167,51 +169,151 @@ func (c *Connection) Connect(ctx context.Context) error {
return nil return nil
} }
func (c *Connection) Close() {
c.shutdownExternal()
}
func (c *Connection) shutdownExternal() {
err := c.shutdownSetClosed(true)
if err != nil {
return
}
c.shutdownInner()
c.shutdownCleanup()
}
func (c *Connection) shutdownInternal() {
err := c.shutdownSetClosed(false)
if err != nil {
return
}
c.shutdownInner()
// defer final cleanup to allow this function to return
// otherwise, a deadlock occurs where startReader triggers a shutdown and
// must wait for itself to exit.
go func() {
c.shutdownCleanup()
}()
}
func (c *Connection) shutdownInner() {
c.shutdownSignalDone()
c.shutdownLogStart()
c.shutdownCloseSocket()
}
func (c *Connection) shutdownCleanup() {
c.cleanupOnce.Do(func() {
c.wg.Wait()
c.shutdownCloseChannels()
c.shutdownLogComplete()
})
}
func (c *Connection) shutdownSetClosed(wait bool) error {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return ErrConnectionClosed
}
c.closed = true
c.state = StateClosed
c.mu.Unlock()
return nil
}
func (c *Connection) shutdownSignalDone() {
c.doneOnce.Do(func() {
close(c.done)
})
}
func (c *Connection) shutdownLogStart() {
if c.logger != nil {
c.logger.Info("shutting down")
}
}
func (c *Connection) shutdownCloseSocket() {
if c.socket != nil {
// force unblock of any network operations immediately
expired := time.Now().Add(-1 * time.Minute)
c.socket.SetReadDeadline(expired)
c.socket.SetWriteDeadline(expired)
// close socket
err := c.socket.Close()
if err != nil && c.logger != nil {
c.logger.Error("socket close failed", "error", err)
}
}
}
func (c *Connection) shutdownCloseChannels() {
close(c.incoming)
close(c.errors)
}
func (c *Connection) shutdownLogComplete() {
if c.logger != nil {
c.logger.Info("connection closed")
}
}
func (c *Connection) startReader() { func (c *Connection) startReader() {
c.wg.Add(1) c.wg.Add(1)
go func() { go func() {
defer c.wg.Done() defer c.wg.Done()
defer c.shutdownInternal()
for { for {
messageType, data, err := c.socket.ReadMessage() select {
if err != nil { case <-c.done:
if c.logger != nil {
var closeErr *websocket.CloseError
if errors.As(err, &closeErr) {
switch closeErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
c.logger.Info("connection closed by peer",
"code", closeErr.Code,
"text", closeErr.Text,
)
default:
c.logger.Error("unexpected close",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
} else {
c.logger.Error("read error", "error", err)
}
}
select {
case c.errors <- err:
case <-c.done:
}
c.shutdown()
return return
} default:
messageType, data, err := c.socket.ReadMessage()
if err != nil {
if c.logger != nil {
var closeErr *websocket.CloseError
if errors.As(err, &closeErr) {
switch closeErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
c.logger.Info("connection closed by peer",
"code", closeErr.Code,
"text", closeErr.Text,
)
default:
c.logger.Error("unexpected close",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
} else {
c.logger.Error("read error", "error", err)
}
}
if messageType == websocket.TextMessage || select {
messageType == websocket.BinaryMessage { case <-c.done:
select { case c.errors <- err:
case c.incoming <- data: }
case <-c.done:
c.shutdown()
return return
} }
}
if messageType == websocket.TextMessage ||
messageType == websocket.BinaryMessage {
select {
case <-c.done:
return
case c.incoming <- data:
}
}
}
} }
}() }()
@@ -230,7 +332,7 @@ func (c *Connection) Send(data []byte) error {
if c.logger != nil { if c.logger != nil {
c.logger.Error("write deadline error", "error", err) c.logger.Error("write deadline error", "error", err)
} }
c.shutdown() c.shutdownExternal()
return fmt.Errorf("failed to set write deadline: %w", err) return fmt.Errorf("failed to set write deadline: %w", err)
} }
} }
@@ -253,53 +355,6 @@ func (c *Connection) Errors() <-chan error {
return c.errors return c.errors
} }
func (c *Connection) shutdown() {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return
}
if c.logger != nil {
c.logger.Info("closing", "state", c.state.String())
}
c.closed = true
c.state = StateClosed
socket := c.socket
close(c.done)
c.mu.Unlock()
go func() {
if socket != nil {
// force immediate timeout of any blocked network I/O
expired := time.Now().Add(-1 * time.Minute)
socket.SetReadDeadline(expired)
socket.SetWriteDeadline(expired)
err := socket.Close()
if err != nil {
if c.logger != nil {
c.logger.Error("socket close failed", "error", err)
}
} else {
if c.logger != nil {
c.logger.Info("closed")
}
}
}
c.wg.Wait()
close(c.incoming)
close(c.errors)
}()
}
func (c *Connection) Close() {
c.shutdown()
}
func (c *Connection) State() ConnectionState { func (c *Connection) State() ConnectionState {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
+5 -37
View File
@@ -58,15 +58,11 @@ func TestDisconnectedConnectionClose(t *testing.T) {
conn.Close() conn.Close()
assert.Eventually(t, func() bool { assert.True(t, conn.closed)
select { _, ok := <-conn.incoming
case _, ok := <-conn.Errors(): assert.False(t, ok)
return !ok _, ok = <-conn.errors
default: assert.False(t, ok)
return false
}
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
"errors channel should close")
}) })
t.Run("send fails after close", func(t *testing.T) { t.Run("send fails after close", func(t *testing.T) {
@@ -103,32 +99,4 @@ func TestConnectedConnectionClose(t *testing.T) {
conn.Close() conn.Close()
assert.Equal(t, StateClosed, conn.State()) assert.Equal(t, StateClosed, conn.State())
}) })
t.Run("writer active during close exits cleanly", func(t *testing.T) {
conn, _, _, _ := setupTestConnection(t, nil)
for i := 0; i < 50; i++ {
conn.Send([]byte("message"))
}
conn.Close()
err := conn.Send([]byte("late"))
assert.Error(t, err, "Send should fail after close")
assert.ErrorContains(t, err, "connection closed")
})
t.Run("both goroutines active during close", func(t *testing.T) {
conn, _, incomingData, _ := setupTestConnection(t, nil)
for i := 0; i < 10; i++ {
incomingData <- honeybeetest.MockIncomingData{
MsgType: websocket.TextMessage,
Data: []byte(fmt.Sprintf("in-%d", i)),
}
conn.Send([]byte(fmt.Sprintf("out-%d", i)))
}
conn.Close()
})
} }
+2 -13
View File
@@ -1,10 +1,10 @@
package transport package transport
import ( import (
"fmt"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io"
"testing" "testing"
) )
@@ -62,23 +62,12 @@ func TestStartReader(t *testing.T) {
return nil return nil
} }
readErr := fmt.Errorf("read failed")
mockSocket.ReadMessageFunc = func() (int, []byte, error) { mockSocket.ReadMessageFunc = func() (int, []byte, error) {
return 0, nil, readErr return 0, nil, io.EOF
} }
conn, err := NewConnectionFromSocket(mockSocket, nil, nil) conn, err := NewConnectionFromSocket(mockSocket, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
defer conn.Close()
assert.Eventually(t, func() bool {
select {
case err := <-conn.Errors():
return err == readErr
default:
return false
}
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
return conn.State() == StateClosed return conn.State() == StateClosed
-2
View File
@@ -50,7 +50,6 @@ func TestConnectionSend(t *testing.T) {
for { for {
select { select {
case msg := <-outgoingData: case msg := <-outgoingData:
fmt.Printf("got message %s\n", string(msg.Data))
mu.Lock() mu.Lock()
messages = append(messages, string(msg.Data)) messages = append(messages, string(msg.Data))
mu.Unlock() mu.Unlock()
@@ -69,7 +68,6 @@ func TestConnectionSend(t *testing.T) {
defer wg.Done() defer wg.Done()
for j := 0; j < 10; j++ { for j := 0; j < 10; j++ {
data := []byte(fmt.Sprintf("msg-%d-%d", id, j)) data := []byte(fmt.Sprintf("msg-%d-%d", id, j))
fmt.Printf("sending message %s\n", string(data))
for { for {
// send and retry until success // send and retry until success
err := conn.Send(data) err := conn.Send(data)
+2 -2
View File
@@ -281,7 +281,7 @@ func TestCloseLogging(t *testing.T) {
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
expected := []expectedLog{ expected := []expectedLog{
{slog.LevelInfo, "closing", map[string]any{"state": "connected"}}, {slog.LevelInfo, "shutting down", map[string]any{}},
{slog.LevelInfo, "closed", map[string]any{}}, {slog.LevelInfo, "closed", map[string]any{}},
} }
@@ -311,7 +311,7 @@ func TestCloseLogging(t *testing.T) {
records := mockHandler.GetRecords() records := mockHandler.GetRecords()
expected := []expectedLog{ expected := []expectedLog{
{slog.LevelInfo, "closing", map[string]any{"state": "connected"}}, {slog.LevelInfo, "shutting down", map[string]any{}},
{slog.LevelError, "socket close failed", map[string]any{"error": closeErr}}, {slog.LevelError, "socket close failed", map[string]any{"error": closeErr}},
} }