cleanup and refactors
This commit is contained in:
@@ -5,6 +5,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Connection Config
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// Types
|
||||
|
||||
type CloseHandler func(code int, text string) error
|
||||
|
||||
type ConnectionConfig struct {
|
||||
@@ -26,6 +32,8 @@ type RetryConfig struct {
|
||||
|
||||
type ConnectionOption func(*ConnectionConfig) error
|
||||
|
||||
// Constructors
|
||||
|
||||
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
|
||||
conf := GetDefaultConnectionConfig()
|
||||
if err := applyConnectionOptions(conf, options...); err != nil {
|
||||
@@ -69,6 +77,8 @@ func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOptio
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validation
|
||||
|
||||
func ValidateConnectionConfig(config *ConnectionConfig) error {
|
||||
err := validateWriteTimeout(config.WriteTimeout)
|
||||
if err != nil {
|
||||
@@ -153,6 +163,8 @@ func validateJitterFactor(value float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Options
|
||||
|
||||
func WithCloseHandler(handler CloseHandler) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
c.CloseHandler = handler
|
||||
|
||||
+263
-230
@@ -12,10 +12,14 @@ import (
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
component "git.wisehodl.dev/jay/go-mana-component"
|
||||
"git.wisehodl.dev/jay/go-mana-component"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Types
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type ConnectionState int
|
||||
|
||||
const (
|
||||
@@ -49,6 +53,14 @@ type ConnectionStats struct {
|
||||
TotalHeartbeats uint64
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Connection
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// ---------------------------/
|
||||
// Constructors
|
||||
// -------------------------/
|
||||
|
||||
type Connection struct {
|
||||
url *url.URL
|
||||
dialer types.Dialer
|
||||
@@ -95,18 +107,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
|
||||
ctx = component.MustExtend(ctx, "connection")
|
||||
}
|
||||
|
||||
var logger *slog.Logger
|
||||
if handler != nil {
|
||||
c := component.FromContext(ctx)
|
||||
logger = slog.New(handler).With(slog.Any("component", c))
|
||||
}
|
||||
|
||||
conn := &Connection{
|
||||
url: url,
|
||||
dialer: NewDialer(),
|
||||
socket: nil,
|
||||
config: config,
|
||||
logger: logger,
|
||||
incoming: make(chan []byte, config.IncomingBufferSize),
|
||||
heartbeat: make(chan struct{}, 1),
|
||||
errors: make(chan error, config.ErrorsBufferSize),
|
||||
@@ -117,6 +122,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if handler != nil {
|
||||
comp := component.FromContext(ctx)
|
||||
conn.logger = slog.New(handler).With(slog.Any("component", comp))
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -141,18 +151,11 @@ func NewConnectionFromSocket(
|
||||
ctx = component.MustExtend(ctx, "connection")
|
||||
}
|
||||
|
||||
var logger *slog.Logger
|
||||
if handler != nil {
|
||||
c := component.FromContext(ctx)
|
||||
logger = slog.New(handler).With(slog.Any("component", c))
|
||||
}
|
||||
|
||||
conn := &Connection{
|
||||
url: nil,
|
||||
dialer: nil,
|
||||
socket: socket,
|
||||
config: config,
|
||||
logger: logger,
|
||||
incoming: make(chan []byte, config.IncomingBufferSize),
|
||||
heartbeat: make(chan struct{}, 1),
|
||||
errors: make(chan error, config.ErrorsBufferSize),
|
||||
@@ -163,17 +166,31 @@ func NewConnectionFromSocket(
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if handler != nil {
|
||||
comp := component.FromContext(ctx)
|
||||
conn.logger = slog.New(handler).With(slog.Any("component", comp))
|
||||
}
|
||||
|
||||
// initialize
|
||||
if config.CloseHandler != nil {
|
||||
socket.SetCloseHandler(config.CloseHandler)
|
||||
}
|
||||
|
||||
conn.setupPongHandler()
|
||||
conn.startPinger()
|
||||
conn.startReader()
|
||||
|
||||
if conn.config.PingInterval > 0 {
|
||||
conn.wg.Go(conn.startPinger)
|
||||
}
|
||||
|
||||
conn.wg.Go(conn.startReader)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Methods
|
||||
// -------------------------/
|
||||
|
||||
func (c *Connection) Connect(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
@@ -186,17 +203,20 @@ func (c *Connection) Connect(ctx context.Context) error {
|
||||
return NewConnectionError(ErrConnectionClosed)
|
||||
}
|
||||
|
||||
// begin connecting
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("connecting")
|
||||
}
|
||||
|
||||
c.state = StateConnecting
|
||||
|
||||
// obtain socket
|
||||
retryMgr := NewRetryManager(c.config.Retry)
|
||||
socket, _, err := AcquireSocket(
|
||||
ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger)
|
||||
|
||||
if err != nil {
|
||||
// socket acquisition failed
|
||||
c.state = StateDisconnected
|
||||
if c.logger != nil {
|
||||
c.logger.Error("connection failed", "error", err)
|
||||
@@ -204,231 +224,32 @@ func (c *Connection) Connect(ctx context.Context) error {
|
||||
return NewConnectionError(err)
|
||||
}
|
||||
|
||||
// got socket
|
||||
c.socket = socket
|
||||
c.state = StateConnected
|
||||
|
||||
// initialize
|
||||
if c.config.CloseHandler != nil {
|
||||
c.socket.SetCloseHandler(c.config.CloseHandler)
|
||||
}
|
||||
|
||||
c.setupPongHandler()
|
||||
|
||||
if c.config.PingInterval > 0 {
|
||||
c.wg.Go(c.startPinger)
|
||||
}
|
||||
|
||||
c.wg.Go(c.startReader)
|
||||
|
||||
// connected
|
||||
c.state = StateConnected
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("connected")
|
||||
}
|
||||
|
||||
c.setupPongHandler()
|
||||
c.startPinger()
|
||||
c.startReader()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) Close() {
|
||||
c.shutdownExternal()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownExternal() {
|
||||
err := c.shutdownSetClosed(true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.shutdownInner()
|
||||
c.shutdownCleanup()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownInternal() {
|
||||
err := c.shutdownSetClosed(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.shutdownInner()
|
||||
|
||||
// defer final cleanup to allow this function to return
|
||||
// otherwise, a deadlock occurs where startReader triggers a shutdown and
|
||||
// must wait for itself to exit.
|
||||
go func() {
|
||||
c.shutdownCleanup()
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownInner() {
|
||||
c.shutdownSignalDone()
|
||||
c.shutdownLogStart()
|
||||
c.shutdownCloseSocket()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCleanup() {
|
||||
c.cleanupOnce.Do(func() {
|
||||
c.wg.Wait()
|
||||
c.shutdownCloseChannels()
|
||||
c.shutdownLogComplete()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownSetClosed(wait bool) error {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return NewConnectionError(ErrConnectionClosed)
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownSignalDone() {
|
||||
c.doneOnce.Do(func() {
|
||||
close(c.done)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownLogStart() {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closing")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCloseSocket() {
|
||||
if c.socket != nil {
|
||||
// force unblock of any network operations immediately
|
||||
expired := time.Now().Add(-1 * time.Minute)
|
||||
c.socket.SetReadDeadline(expired)
|
||||
c.socket.SetWriteDeadline(expired)
|
||||
|
||||
// close socket
|
||||
err := c.socket.Close()
|
||||
|
||||
if err != nil && c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCloseChannels() {
|
||||
close(c.incoming)
|
||||
close(c.errors)
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownLogComplete() {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) startReader() {
|
||||
c.wg.Go(func() {
|
||||
defer c.shutdownInternal()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
messageType, data, err := c.socket.ReadMessage()
|
||||
if err != nil {
|
||||
var wrappedErr error
|
||||
var closeErr *websocket.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
switch closeErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||
if c.logger != nil {
|
||||
c.logger.Info("connection closed by peer",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedClean, err)
|
||||
default:
|
||||
if c.logger != nil {
|
||||
c.logger.Error("unexpected close",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err)
|
||||
}
|
||||
} else {
|
||||
isLocalClose := false
|
||||
select {
|
||||
case <-c.done:
|
||||
isLocalClose = true
|
||||
default:
|
||||
}
|
||||
if c.logger != nil {
|
||||
if isLocalClose {
|
||||
c.logger.Debug("read loop terminated", "error", err)
|
||||
} else {
|
||||
c.logger.Error("read error", "error", err)
|
||||
}
|
||||
}
|
||||
wrappedErr = fmt.Errorf("%w: %w", ErrReadError, err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
case c.errors <- wrappedErr:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if messageType == websocket.TextMessage ||
|
||||
messageType == websocket.BinaryMessage {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case c.incoming <- data:
|
||||
c.incomingCount.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) setupPongHandler() {
|
||||
c.socket.SetPongHandler(func(appData string) error {
|
||||
select {
|
||||
case c.heartbeat <- struct{}{}:
|
||||
c.heartbeatCount.Add(1)
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) startPinger() {
|
||||
if c.config.PingInterval <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
c.wg.Go(func() {
|
||||
defer c.shutdownInternal()
|
||||
|
||||
// Calculate 10% jitter window
|
||||
jitter := c.config.PingInterval / 10
|
||||
|
||||
for {
|
||||
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
|
||||
next := c.config.PingInterval + offset
|
||||
timer := time.NewTimer(next)
|
||||
select {
|
||||
case <-c.done:
|
||||
timer.Stop()
|
||||
return
|
||||
case <-timer.C:
|
||||
deadline := time.Now().Add(c.config.WriteTimeout)
|
||||
if err := c.socket.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) Send(data []byte) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
@@ -437,6 +258,7 @@ func (c *Connection) Send(data []byte) error {
|
||||
return NewConnectionError(ErrConnectionClosed)
|
||||
}
|
||||
|
||||
// setup
|
||||
if c.config.WriteTimeout > 0 {
|
||||
if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||
if c.logger != nil {
|
||||
@@ -446,7 +268,10 @@ func (c *Connection) Send(data []byte) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
// send
|
||||
err := c.socket.WriteMessage(websocket.TextMessage, data)
|
||||
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("write error", "error", err)
|
||||
}
|
||||
@@ -489,3 +314,211 @@ func (c *Connection) Stats() ConnectionStats {
|
||||
func (c *Connection) SetDialer(d types.Dialer) {
|
||||
c.dialer = d
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Reader loop
|
||||
// -------------------------/
|
||||
|
||||
func (c *Connection) startReader() {
|
||||
defer c.shutdownInternal()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
messageType, data, err := c.socket.ReadMessage()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-c.done:
|
||||
case c.errors <- c.classifyCloseError(err):
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if messageType == websocket.TextMessage ||
|
||||
messageType == websocket.BinaryMessage {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case c.incoming <- data:
|
||||
c.incomingCount.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) classifyCloseError(err error) error {
|
||||
var classifiedError error
|
||||
var closeErr *websocket.CloseError
|
||||
|
||||
if errors.As(err, &closeErr) {
|
||||
switch closeErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||
if c.logger != nil {
|
||||
c.logger.Info("connection closed by peer",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedClean, err)
|
||||
|
||||
default:
|
||||
if c.logger != nil {
|
||||
c.logger.Error("unexpected close",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err)
|
||||
}
|
||||
|
||||
} else {
|
||||
isLocalClose := false
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
isLocalClose = true
|
||||
default:
|
||||
}
|
||||
|
||||
if c.logger != nil {
|
||||
if isLocalClose {
|
||||
c.logger.Debug("read loop terminated", "error", err)
|
||||
} else {
|
||||
c.logger.Error("read error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
classifiedError = fmt.Errorf("%w: %w", ErrReadError, err)
|
||||
}
|
||||
|
||||
return classifiedError
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Heartbeat Handling
|
||||
// -------------------------/
|
||||
|
||||
func (c *Connection) setupPongHandler() {
|
||||
c.socket.SetPongHandler(func(appData string) error {
|
||||
select {
|
||||
case c.heartbeat <- struct{}{}:
|
||||
c.heartbeatCount.Add(1)
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) startPinger() {
|
||||
defer c.shutdownInternal()
|
||||
|
||||
// Calculate 10% jitter window
|
||||
jitter := c.config.PingInterval / 10
|
||||
|
||||
for {
|
||||
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
|
||||
next := c.config.PingInterval + offset
|
||||
timer := time.NewTimer(next)
|
||||
select {
|
||||
case <-c.done:
|
||||
timer.Stop()
|
||||
return
|
||||
case <-timer.C:
|
||||
deadline := time.Now().Add(c.config.WriteTimeout)
|
||||
err := c.socket.WriteControl(websocket.PingMessage, nil, deadline)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Shutdown
|
||||
// -------------------------/
|
||||
|
||||
func (c *Connection) Close() {
|
||||
c.shutdownExternal()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownExternal() {
|
||||
// set closed
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
// idempotent shutdown
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
c.mu.Unlock()
|
||||
|
||||
// perform shutdown
|
||||
c.shutdownInner()
|
||||
c.shutdownCleanup()
|
||||
}
|
||||
|
||||
// shutdownInternal defers final cleanup to allow it to return.
|
||||
// Otherwise, a deadlock occurs where startReader triggers a shutdown and
|
||||
// must wait for itself to exit.
|
||||
func (c *Connection) shutdownInternal() {
|
||||
// set closed
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
// idempotent shutdown
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
c.mu.Unlock()
|
||||
|
||||
// perform shutdown
|
||||
c.shutdownInner()
|
||||
|
||||
// defer cleanup to avoid deadlock
|
||||
go func() {
|
||||
c.shutdownCleanup()
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownInner() {
|
||||
c.doneOnce.Do(func() {
|
||||
close(c.done)
|
||||
})
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closing")
|
||||
}
|
||||
|
||||
if c.socket != nil {
|
||||
// force unblock of any network operations immediately
|
||||
expired := time.Now().Add(-1 * time.Minute)
|
||||
c.socket.SetReadDeadline(expired)
|
||||
c.socket.SetWriteDeadline(expired)
|
||||
|
||||
// close socket
|
||||
err := c.socket.Close()
|
||||
|
||||
if err != nil && c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCleanup() {
|
||||
c.cleanupOnce.Do(func() {
|
||||
c.wg.Wait()
|
||||
|
||||
close(c.incoming)
|
||||
close(c.errors)
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -69,6 +69,7 @@ func AcquireSocket(
|
||||
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
|
||||
}
|
||||
|
||||
// dial
|
||||
socket, resp, err := dialer.DialContext(ctx, url, header)
|
||||
if err == nil {
|
||||
if logger != nil {
|
||||
@@ -77,7 +78,9 @@ func AcquireSocket(
|
||||
return socket, resp, nil
|
||||
}
|
||||
|
||||
// dial failed, retry
|
||||
if !retryMgr.ShouldRetry() {
|
||||
// retry policy expired
|
||||
if logger != nil {
|
||||
logger.Error("dial failed, max retries reached",
|
||||
"error", err,
|
||||
@@ -95,6 +98,7 @@ func AcquireSocket(
|
||||
"next_delay", delay)
|
||||
}
|
||||
|
||||
// context cancellable backoff
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
|
||||
Reference in New Issue
Block a user