Refactored connection shutdown logic.
This commit is contained in:
@@ -51,12 +51,3 @@ func setupWorkerTestConnection(t *testing.T) (
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func connClosed(conn *transport.Connection) bool {
|
|
||||||
select {
|
|
||||||
case _, ok := <-conn.Errors():
|
|
||||||
return !ok
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRunSession(t *testing.T) {
|
func TestRunSessionDial(t *testing.T) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,14 +127,19 @@ func TestRunReader(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
go w.runReader(conn, messages, sessionDone, onStop)
|
go w.runReader(conn, messages, sessionDone, onStop)
|
||||||
|
|
||||||
// simulate remote close
|
// induce connection closure via reader
|
||||||
incomingData <- honeybeetest.MockIncomingData{Err: io.EOF}
|
incomingData <- honeybeetest.MockIncomingData{Err: io.EOF}
|
||||||
|
|
||||||
|
err := <-conn.Errors()
|
||||||
|
assert.Equal(t, io.EOF, err)
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
return connClosed(conn)
|
return conn.State() == transport.StateClosed
|
||||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
|
|
||||||
assert.True(t, onStopCalled.Load())
|
assert.Eventually(t, func() bool {
|
||||||
|
return onStopCalled.Load()
|
||||||
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) {
|
t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) {
|
||||||
@@ -157,10 +162,12 @@ func TestRunReader(t *testing.T) {
|
|||||||
close(sessionDone)
|
close(sessionDone)
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
return connClosed(conn)
|
return conn.State() == transport.StateClosed
|
||||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
|
|
||||||
assert.True(t, onStopCalled.Load())
|
assert.Eventually(t, func() bool {
|
||||||
|
return onStopCalled.Load()
|
||||||
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,10 +188,12 @@ func TestRunStopMonitor(t *testing.T) {
|
|||||||
keepalive <- struct{}{}
|
keepalive <- struct{}{}
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
return connClosed(conn)
|
return conn.State() == transport.StateClosed
|
||||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
|
|
||||||
assert.True(t, onStopCalled.Load())
|
assert.Eventually(t, func() bool {
|
||||||
|
return onStopCalled.Load()
|
||||||
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ctx.Done calls conn.Close and onStop", func(t *testing.T) {
|
t.Run("ctx.Done calls conn.Close and onStop", func(t *testing.T) {
|
||||||
@@ -202,10 +211,12 @@ func TestRunStopMonitor(t *testing.T) {
|
|||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
return connClosed(conn)
|
return conn.State() == transport.StateClosed
|
||||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
|
|
||||||
assert.True(t, onStopCalled.Load())
|
assert.Eventually(t, func() bool {
|
||||||
|
return onStopCalled.Load()
|
||||||
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) {
|
t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) {
|
||||||
@@ -224,10 +235,12 @@ func TestRunStopMonitor(t *testing.T) {
|
|||||||
close(sessionDone)
|
close(sessionDone)
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
return connClosed(conn)
|
return conn.State() == transport.StateClosed
|
||||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
|
|
||||||
assert.True(t, onStopCalled.Load())
|
assert.Eventually(t, func() bool {
|
||||||
|
return onStopCalled.Load()
|
||||||
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+141
-86
@@ -50,10 +50,12 @@ type Connection struct {
|
|||||||
|
|
||||||
state ConnectionState
|
state ConnectionState
|
||||||
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
closed bool
|
closed bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
writeMu sync.Mutex
|
writeMu sync.Mutex
|
||||||
|
doneOnce sync.Once
|
||||||
|
cleanupOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
|
func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
|
||||||
@@ -167,51 +169,151 @@ func (c *Connection) Connect(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Close() {
|
||||||
|
c.shutdownExternal()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownExternal() {
|
||||||
|
err := c.shutdownSetClosed(true)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.shutdownInner()
|
||||||
|
c.shutdownCleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownInternal() {
|
||||||
|
err := c.shutdownSetClosed(false)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.shutdownInner()
|
||||||
|
|
||||||
|
// defer final cleanup to allow this function to return
|
||||||
|
// otherwise, a deadlock occurs where startReader triggers a shutdown and
|
||||||
|
// must wait for itself to exit.
|
||||||
|
go func() {
|
||||||
|
c.shutdownCleanup()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownInner() {
|
||||||
|
c.shutdownSignalDone()
|
||||||
|
c.shutdownLogStart()
|
||||||
|
c.shutdownCloseSocket()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownCleanup() {
|
||||||
|
c.cleanupOnce.Do(func() {
|
||||||
|
c.wg.Wait()
|
||||||
|
c.shutdownCloseChannels()
|
||||||
|
c.shutdownLogComplete()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownSetClosed(wait bool) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return ErrConnectionClosed
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
c.state = StateClosed
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownSignalDone() {
|
||||||
|
c.doneOnce.Do(func() {
|
||||||
|
close(c.done)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownLogStart() {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Info("shutting down")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownCloseSocket() {
|
||||||
|
if c.socket != nil {
|
||||||
|
// force unblock of any network operations immediately
|
||||||
|
expired := time.Now().Add(-1 * time.Minute)
|
||||||
|
c.socket.SetReadDeadline(expired)
|
||||||
|
c.socket.SetWriteDeadline(expired)
|
||||||
|
|
||||||
|
// close socket
|
||||||
|
err := c.socket.Close()
|
||||||
|
|
||||||
|
if err != nil && c.logger != nil {
|
||||||
|
c.logger.Error("socket close failed", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownCloseChannels() {
|
||||||
|
close(c.incoming)
|
||||||
|
close(c.errors)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownLogComplete() {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Info("connection closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Connection) startReader() {
|
func (c *Connection) startReader() {
|
||||||
c.wg.Add(1)
|
c.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer c.wg.Done()
|
defer c.wg.Done()
|
||||||
|
defer c.shutdownInternal()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
messageType, data, err := c.socket.ReadMessage()
|
select {
|
||||||
if err != nil {
|
case <-c.done:
|
||||||
if c.logger != nil {
|
|
||||||
var closeErr *websocket.CloseError
|
|
||||||
if errors.As(err, &closeErr) {
|
|
||||||
switch closeErr.Code {
|
|
||||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
|
||||||
c.logger.Info("connection closed by peer",
|
|
||||||
"code", closeErr.Code,
|
|
||||||
"text", closeErr.Text,
|
|
||||||
)
|
|
||||||
default:
|
|
||||||
c.logger.Error("unexpected close",
|
|
||||||
"code", closeErr.Code,
|
|
||||||
"text", closeErr.Text,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c.logger.Error("read error", "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case c.errors <- err:
|
|
||||||
case <-c.done:
|
|
||||||
}
|
|
||||||
c.shutdown()
|
|
||||||
return
|
return
|
||||||
}
|
default:
|
||||||
|
messageType, data, err := c.socket.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if c.logger != nil {
|
||||||
|
var closeErr *websocket.CloseError
|
||||||
|
if errors.As(err, &closeErr) {
|
||||||
|
switch closeErr.Code {
|
||||||
|
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||||
|
c.logger.Info("connection closed by peer",
|
||||||
|
"code", closeErr.Code,
|
||||||
|
"text", closeErr.Text,
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
c.logger.Error("unexpected close",
|
||||||
|
"code", closeErr.Code,
|
||||||
|
"text", closeErr.Text,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.logger.Error("read error", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if messageType == websocket.TextMessage ||
|
select {
|
||||||
messageType == websocket.BinaryMessage {
|
case <-c.done:
|
||||||
select {
|
case c.errors <- err:
|
||||||
case c.incoming <- data:
|
}
|
||||||
case <-c.done:
|
|
||||||
c.shutdown()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
if messageType == websocket.TextMessage ||
|
||||||
|
messageType == websocket.BinaryMessage {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
case c.incoming <- data:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -230,7 +332,7 @@ func (c *Connection) Send(data []byte) error {
|
|||||||
if c.logger != nil {
|
if c.logger != nil {
|
||||||
c.logger.Error("write deadline error", "error", err)
|
c.logger.Error("write deadline error", "error", err)
|
||||||
}
|
}
|
||||||
c.shutdown()
|
c.shutdownExternal()
|
||||||
return fmt.Errorf("failed to set write deadline: %w", err)
|
return fmt.Errorf("failed to set write deadline: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -253,53 +355,6 @@ func (c *Connection) Errors() <-chan error {
|
|||||||
return c.errors
|
return c.errors
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) shutdown() {
|
|
||||||
c.mu.Lock()
|
|
||||||
|
|
||||||
if c.closed {
|
|
||||||
c.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.logger != nil {
|
|
||||||
c.logger.Info("closing", "state", c.state.String())
|
|
||||||
}
|
|
||||||
c.closed = true
|
|
||||||
c.state = StateClosed
|
|
||||||
socket := c.socket
|
|
||||||
close(c.done)
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if socket != nil {
|
|
||||||
// force immediate timeout of any blocked network I/O
|
|
||||||
expired := time.Now().Add(-1 * time.Minute)
|
|
||||||
socket.SetReadDeadline(expired)
|
|
||||||
socket.SetWriteDeadline(expired)
|
|
||||||
err := socket.Close()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if c.logger != nil {
|
|
||||||
c.logger.Error("socket close failed", "error", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if c.logger != nil {
|
|
||||||
c.logger.Info("closed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.wg.Wait()
|
|
||||||
close(c.incoming)
|
|
||||||
close(c.errors)
|
|
||||||
}()
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) Close() {
|
|
||||||
c.shutdown()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) State() ConnectionState {
|
func (c *Connection) State() ConnectionState {
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
defer c.mu.RUnlock()
|
defer c.mu.RUnlock()
|
||||||
|
|||||||
@@ -58,15 +58,11 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.True(t, conn.closed)
|
||||||
select {
|
_, ok := <-conn.incoming
|
||||||
case _, ok := <-conn.Errors():
|
assert.False(t, ok)
|
||||||
return !ok
|
_, ok = <-conn.errors
|
||||||
default:
|
assert.False(t, ok)
|
||||||
return false
|
|
||||||
}
|
|
||||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
|
|
||||||
"errors channel should close")
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("send fails after close", func(t *testing.T) {
|
t.Run("send fails after close", func(t *testing.T) {
|
||||||
@@ -103,32 +99,4 @@ func TestConnectedConnectionClose(t *testing.T) {
|
|||||||
conn.Close()
|
conn.Close()
|
||||||
assert.Equal(t, StateClosed, conn.State())
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("writer active during close exits cleanly", func(t *testing.T) {
|
|
||||||
conn, _, _, _ := setupTestConnection(t, nil)
|
|
||||||
|
|
||||||
for i := 0; i < 50; i++ {
|
|
||||||
conn.Send([]byte("message"))
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.Close()
|
|
||||||
|
|
||||||
err := conn.Send([]byte("late"))
|
|
||||||
assert.Error(t, err, "Send should fail after close")
|
|
||||||
assert.ErrorContains(t, err, "connection closed")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("both goroutines active during close", func(t *testing.T) {
|
|
||||||
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
incomingData <- honeybeetest.MockIncomingData{
|
|
||||||
MsgType: websocket.TextMessage,
|
|
||||||
Data: []byte(fmt.Sprintf("in-%d", i)),
|
|
||||||
}
|
|
||||||
conn.Send([]byte(fmt.Sprintf("out-%d", i)))
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.Close()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
package transport
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,23 +62,12 @@ func TestStartReader(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
readErr := fmt.Errorf("read failed")
|
|
||||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||||
return 0, nil, readErr
|
return 0, nil, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, nil, nil)
|
conn, err := NewConnectionFromSocket(mockSocket, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
|
||||||
select {
|
|
||||||
case err := <-conn.Errors():
|
|
||||||
return err == readErr
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
return conn.State() == StateClosed
|
return conn.State() == StateClosed
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ func TestConnectionSend(t *testing.T) {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-outgoingData:
|
case msg := <-outgoingData:
|
||||||
fmt.Printf("got message %s\n", string(msg.Data))
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
messages = append(messages, string(msg.Data))
|
messages = append(messages, string(msg.Data))
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
@@ -69,7 +68,6 @@ func TestConnectionSend(t *testing.T) {
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for j := 0; j < 10; j++ {
|
for j := 0; j < 10; j++ {
|
||||||
data := []byte(fmt.Sprintf("msg-%d-%d", id, j))
|
data := []byte(fmt.Sprintf("msg-%d-%d", id, j))
|
||||||
fmt.Printf("sending message %s\n", string(data))
|
|
||||||
for {
|
for {
|
||||||
// send and retry until success
|
// send and retry until success
|
||||||
err := conn.Send(data)
|
err := conn.Send(data)
|
||||||
|
|||||||
@@ -281,7 +281,7 @@ func TestCloseLogging(t *testing.T) {
|
|||||||
records := mockHandler.GetRecords()
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
expected := []expectedLog{
|
expected := []expectedLog{
|
||||||
{slog.LevelInfo, "closing", map[string]any{"state": "connected"}},
|
{slog.LevelInfo, "shutting down", map[string]any{}},
|
||||||
{slog.LevelInfo, "closed", map[string]any{}},
|
{slog.LevelInfo, "closed", map[string]any{}},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,7 +311,7 @@ func TestCloseLogging(t *testing.T) {
|
|||||||
records := mockHandler.GetRecords()
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
expected := []expectedLog{
|
expected := []expectedLog{
|
||||||
{slog.LevelInfo, "closing", map[string]any{"state": "connected"}},
|
{slog.LevelInfo, "shutting down", map[string]any{}},
|
||||||
{slog.LevelError, "socket close failed", map[string]any{"error": closeErr}},
|
{slog.LevelError, "socket close failed", map[string]any{"error": closeErr}},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user