Refactored connection shutdown logic.
This commit is contained in:
+141
-86
@@ -50,10 +50,12 @@ type Connection struct {
|
||||
|
||||
state ConnectionState
|
||||
|
||||
wg sync.WaitGroup
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
writeMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
writeMu sync.Mutex
|
||||
doneOnce sync.Once
|
||||
cleanupOnce sync.Once
|
||||
}
|
||||
|
||||
func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
|
||||
@@ -167,51 +169,151 @@ func (c *Connection) Connect(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) Close() {
|
||||
c.shutdownExternal()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownExternal() {
|
||||
err := c.shutdownSetClosed(true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.shutdownInner()
|
||||
c.shutdownCleanup()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownInternal() {
|
||||
err := c.shutdownSetClosed(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.shutdownInner()
|
||||
|
||||
// defer final cleanup to allow this function to return
|
||||
// otherwise, a deadlock occurs where startReader triggers a shutdown and
|
||||
// must wait for itself to exit.
|
||||
go func() {
|
||||
c.shutdownCleanup()
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownInner() {
|
||||
c.shutdownSignalDone()
|
||||
c.shutdownLogStart()
|
||||
c.shutdownCloseSocket()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCleanup() {
|
||||
c.cleanupOnce.Do(func() {
|
||||
c.wg.Wait()
|
||||
c.shutdownCloseChannels()
|
||||
c.shutdownLogComplete()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownSetClosed(wait bool) error {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownSignalDone() {
|
||||
c.doneOnce.Do(func() {
|
||||
close(c.done)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownLogStart() {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("shutting down")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCloseSocket() {
|
||||
if c.socket != nil {
|
||||
// force unblock of any network operations immediately
|
||||
expired := time.Now().Add(-1 * time.Minute)
|
||||
c.socket.SetReadDeadline(expired)
|
||||
c.socket.SetWriteDeadline(expired)
|
||||
|
||||
// close socket
|
||||
err := c.socket.Close()
|
||||
|
||||
if err != nil && c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCloseChannels() {
|
||||
close(c.incoming)
|
||||
close(c.errors)
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownLogComplete() {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("connection closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) startReader() {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
defer c.shutdownInternal()
|
||||
|
||||
for {
|
||||
messageType, data, err := c.socket.ReadMessage()
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
var closeErr *websocket.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
switch closeErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||
c.logger.Info("connection closed by peer",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
default:
|
||||
c.logger.Error("unexpected close",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
c.logger.Error("read error", "error", err)
|
||||
}
|
||||
}
|
||||
select {
|
||||
case c.errors <- err:
|
||||
case <-c.done:
|
||||
}
|
||||
c.shutdown()
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
default:
|
||||
messageType, data, err := c.socket.ReadMessage()
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
var closeErr *websocket.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
switch closeErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||
c.logger.Info("connection closed by peer",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
default:
|
||||
c.logger.Error("unexpected close",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
c.logger.Error("read error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if messageType == websocket.TextMessage ||
|
||||
messageType == websocket.BinaryMessage {
|
||||
select {
|
||||
case c.incoming <- data:
|
||||
case <-c.done:
|
||||
c.shutdown()
|
||||
select {
|
||||
case <-c.done:
|
||||
case c.errors <- err:
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if messageType == websocket.TextMessage ||
|
||||
messageType == websocket.BinaryMessage {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case c.incoming <- data:
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -230,7 +332,7 @@ func (c *Connection) Send(data []byte) error {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("write deadline error", "error", err)
|
||||
}
|
||||
c.shutdown()
|
||||
c.shutdownExternal()
|
||||
return fmt.Errorf("failed to set write deadline: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -253,53 +355,6 @@ func (c *Connection) Errors() <-chan error {
|
||||
return c.errors
|
||||
}
|
||||
|
||||
func (c *Connection) shutdown() {
|
||||
c.mu.Lock()
|
||||
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closing", "state", c.state.String())
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
socket := c.socket
|
||||
close(c.done)
|
||||
c.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
if socket != nil {
|
||||
// force immediate timeout of any blocked network I/O
|
||||
expired := time.Now().Add(-1 * time.Minute)
|
||||
socket.SetReadDeadline(expired)
|
||||
socket.SetWriteDeadline(expired)
|
||||
err := socket.Close()
|
||||
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
} else {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.wg.Wait()
|
||||
close(c.incoming)
|
||||
close(c.errors)
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) Close() {
|
||||
c.shutdown()
|
||||
}
|
||||
|
||||
func (c *Connection) State() ConnectionState {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
Reference in New Issue
Block a user