Refactored connection shutdown logic.

This commit is contained in:
Jay
2026-04-19 09:29:12 -04:00
parent 72f0793047
commit 6998ccf701
7 changed files with 175 additions and 161 deletions
+141 -86
View File
@@ -50,10 +50,12 @@ type Connection struct {
state ConnectionState
wg sync.WaitGroup
closed bool
mu sync.RWMutex
writeMu sync.Mutex
wg sync.WaitGroup
closed bool
mu sync.RWMutex
writeMu sync.Mutex
doneOnce sync.Once
cleanupOnce sync.Once
}
func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
@@ -167,51 +169,151 @@ func (c *Connection) Connect(ctx context.Context) error {
return nil
}
func (c *Connection) Close() {
c.shutdownExternal()
}
func (c *Connection) shutdownExternal() {
err := c.shutdownSetClosed(true)
if err != nil {
return
}
c.shutdownInner()
c.shutdownCleanup()
}
func (c *Connection) shutdownInternal() {
err := c.shutdownSetClosed(false)
if err != nil {
return
}
c.shutdownInner()
// defer final cleanup to allow this function to return
// otherwise, a deadlock occurs where startReader triggers a shutdown and
// must wait for itself to exit.
go func() {
c.shutdownCleanup()
}()
}
func (c *Connection) shutdownInner() {
c.shutdownSignalDone()
c.shutdownLogStart()
c.shutdownCloseSocket()
}
func (c *Connection) shutdownCleanup() {
c.cleanupOnce.Do(func() {
c.wg.Wait()
c.shutdownCloseChannels()
c.shutdownLogComplete()
})
}
func (c *Connection) shutdownSetClosed(wait bool) error {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return ErrConnectionClosed
}
c.closed = true
c.state = StateClosed
c.mu.Unlock()
return nil
}
func (c *Connection) shutdownSignalDone() {
c.doneOnce.Do(func() {
close(c.done)
})
}
func (c *Connection) shutdownLogStart() {
if c.logger != nil {
c.logger.Info("shutting down")
}
}
func (c *Connection) shutdownCloseSocket() {
if c.socket != nil {
// force unblock of any network operations immediately
expired := time.Now().Add(-1 * time.Minute)
c.socket.SetReadDeadline(expired)
c.socket.SetWriteDeadline(expired)
// close socket
err := c.socket.Close()
if err != nil && c.logger != nil {
c.logger.Error("socket close failed", "error", err)
}
}
}
func (c *Connection) shutdownCloseChannels() {
close(c.incoming)
close(c.errors)
}
func (c *Connection) shutdownLogComplete() {
if c.logger != nil {
c.logger.Info("connection closed")
}
}
func (c *Connection) startReader() {
c.wg.Add(1)
go func() {
defer c.wg.Done()
defer c.shutdownInternal()
for {
messageType, data, err := c.socket.ReadMessage()
if err != nil {
if c.logger != nil {
var closeErr *websocket.CloseError
if errors.As(err, &closeErr) {
switch closeErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
c.logger.Info("connection closed by peer",
"code", closeErr.Code,
"text", closeErr.Text,
)
default:
c.logger.Error("unexpected close",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
} else {
c.logger.Error("read error", "error", err)
}
}
select {
case c.errors <- err:
case <-c.done:
}
c.shutdown()
select {
case <-c.done:
return
}
default:
messageType, data, err := c.socket.ReadMessage()
if err != nil {
if c.logger != nil {
var closeErr *websocket.CloseError
if errors.As(err, &closeErr) {
switch closeErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
c.logger.Info("connection closed by peer",
"code", closeErr.Code,
"text", closeErr.Text,
)
default:
c.logger.Error("unexpected close",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
} else {
c.logger.Error("read error", "error", err)
}
}
if messageType == websocket.TextMessage ||
messageType == websocket.BinaryMessage {
select {
case c.incoming <- data:
case <-c.done:
c.shutdown()
select {
case <-c.done:
case c.errors <- err:
}
return
}
}
if messageType == websocket.TextMessage ||
messageType == websocket.BinaryMessage {
select {
case <-c.done:
return
case c.incoming <- data:
}
}
}
}
}()
@@ -230,7 +332,7 @@ func (c *Connection) Send(data []byte) error {
if c.logger != nil {
c.logger.Error("write deadline error", "error", err)
}
c.shutdown()
c.shutdownExternal()
return fmt.Errorf("failed to set write deadline: %w", err)
}
}
@@ -253,53 +355,6 @@ func (c *Connection) Errors() <-chan error {
return c.errors
}
func (c *Connection) shutdown() {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return
}
if c.logger != nil {
c.logger.Info("closing", "state", c.state.String())
}
c.closed = true
c.state = StateClosed
socket := c.socket
close(c.done)
c.mu.Unlock()
go func() {
if socket != nil {
// force immediate timeout of any blocked network I/O
expired := time.Now().Add(-1 * time.Minute)
socket.SetReadDeadline(expired)
socket.SetWriteDeadline(expired)
err := socket.Close()
if err != nil {
if c.logger != nil {
c.logger.Error("socket close failed", "error", err)
}
} else {
if c.logger != nil {
c.logger.Info("closed")
}
}
}
c.wg.Wait()
close(c.incoming)
close(c.errors)
}()
}
func (c *Connection) Close() {
c.shutdown()
}
func (c *Connection) State() ConnectionState {
c.mu.RLock()
defer c.mu.RUnlock()