cleanup and refactors

This commit is contained in:
Jay
2026-05-20 22:49:25 -04:00
parent cda6d286ab
commit f1afca7921
10 changed files with 628 additions and 496 deletions
+12
View File
@@ -5,6 +5,12 @@ import (
"time"
)
// ----------------------------------------------------------------------------
// Connection Config
// ----------------------------------------------------------------------------
// Types
type CloseHandler func(code int, text string) error
type ConnectionConfig struct {
@@ -26,6 +32,8 @@ type RetryConfig struct {
type ConnectionOption func(*ConnectionConfig) error
// Constructors
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
conf := GetDefaultConnectionConfig()
if err := applyConnectionOptions(conf, options...); err != nil {
@@ -69,6 +77,8 @@ func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOptio
return nil
}
// Validation
func ValidateConnectionConfig(config *ConnectionConfig) error {
err := validateWriteTimeout(config.WriteTimeout)
if err != nil {
@@ -153,6 +163,8 @@ func validateJitterFactor(value float64) error {
return nil
}
// Options
func WithCloseHandler(handler CloseHandler) ConnectionOption {
return func(c *ConnectionConfig) error {
c.CloseHandler = handler
+263 -230
View File
@@ -12,10 +12,14 @@ import (
"time"
"git.wisehodl.dev/jay/go-honeybee/types"
component "git.wisehodl.dev/jay/go-mana-component"
"git.wisehodl.dev/jay/go-mana-component"
"github.com/gorilla/websocket"
)
// ----------------------------------------------------------------------------
// Types
// ----------------------------------------------------------------------------
type ConnectionState int
const (
@@ -49,6 +53,14 @@ type ConnectionStats struct {
TotalHeartbeats uint64
}
// ----------------------------------------------------------------------------
// Connection
// ----------------------------------------------------------------------------
// ---------------------------/
// Constructors
// -------------------------/
type Connection struct {
url *url.URL
dialer types.Dialer
@@ -95,18 +107,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
ctx = component.MustExtend(ctx, "connection")
}
var logger *slog.Logger
if handler != nil {
c := component.FromContext(ctx)
logger = slog.New(handler).With(slog.Any("component", c))
}
conn := &Connection{
url: url,
dialer: NewDialer(),
socket: nil,
config: config,
logger: logger,
incoming: make(chan []byte, config.IncomingBufferSize),
heartbeat: make(chan struct{}, 1),
errors: make(chan error, config.ErrorsBufferSize),
@@ -117,6 +122,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
done: make(chan struct{}),
}
if handler != nil {
comp := component.FromContext(ctx)
conn.logger = slog.New(handler).With(slog.Any("component", comp))
}
return conn, nil
}
@@ -141,18 +151,11 @@ func NewConnectionFromSocket(
ctx = component.MustExtend(ctx, "connection")
}
var logger *slog.Logger
if handler != nil {
c := component.FromContext(ctx)
logger = slog.New(handler).With(slog.Any("component", c))
}
conn := &Connection{
url: nil,
dialer: nil,
socket: socket,
config: config,
logger: logger,
incoming: make(chan []byte, config.IncomingBufferSize),
heartbeat: make(chan struct{}, 1),
errors: make(chan error, config.ErrorsBufferSize),
@@ -163,17 +166,31 @@ func NewConnectionFromSocket(
done: make(chan struct{}),
}
if handler != nil {
comp := component.FromContext(ctx)
conn.logger = slog.New(handler).With(slog.Any("component", comp))
}
// initialize
if config.CloseHandler != nil {
socket.SetCloseHandler(config.CloseHandler)
}
conn.setupPongHandler()
conn.startPinger()
conn.startReader()
if conn.config.PingInterval > 0 {
conn.wg.Go(conn.startPinger)
}
conn.wg.Go(conn.startReader)
return conn, nil
}
// ---------------------------/
// Methods
// -------------------------/
func (c *Connection) Connect(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
@@ -186,17 +203,20 @@ func (c *Connection) Connect(ctx context.Context) error {
return NewConnectionError(ErrConnectionClosed)
}
// begin connecting
if c.logger != nil {
c.logger.Debug("connecting")
}
c.state = StateConnecting
// obtain socket
retryMgr := NewRetryManager(c.config.Retry)
socket, _, err := AcquireSocket(
ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger)
if err != nil {
// socket acquisition failed
c.state = StateDisconnected
if c.logger != nil {
c.logger.Error("connection failed", "error", err)
@@ -204,231 +224,32 @@ func (c *Connection) Connect(ctx context.Context) error {
return NewConnectionError(err)
}
// got socket
c.socket = socket
c.state = StateConnected
// initialize
if c.config.CloseHandler != nil {
c.socket.SetCloseHandler(c.config.CloseHandler)
}
c.setupPongHandler()
if c.config.PingInterval > 0 {
c.wg.Go(c.startPinger)
}
c.wg.Go(c.startReader)
// connected
c.state = StateConnected
if c.logger != nil {
c.logger.Info("connected")
}
c.setupPongHandler()
c.startPinger()
c.startReader()
return nil
}
func (c *Connection) Close() {
c.shutdownExternal()
}
func (c *Connection) shutdownExternal() {
err := c.shutdownSetClosed(true)
if err != nil {
return
}
c.shutdownInner()
c.shutdownCleanup()
}
func (c *Connection) shutdownInternal() {
err := c.shutdownSetClosed(false)
if err != nil {
return
}
c.shutdownInner()
// defer final cleanup to allow this function to return
// otherwise, a deadlock occurs where startReader triggers a shutdown and
// must wait for itself to exit.
go func() {
c.shutdownCleanup()
}()
}
func (c *Connection) shutdownInner() {
c.shutdownSignalDone()
c.shutdownLogStart()
c.shutdownCloseSocket()
}
func (c *Connection) shutdownCleanup() {
c.cleanupOnce.Do(func() {
c.wg.Wait()
c.shutdownCloseChannels()
c.shutdownLogComplete()
})
}
func (c *Connection) shutdownSetClosed(wait bool) error {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return NewConnectionError(ErrConnectionClosed)
}
c.closed = true
c.state = StateClosed
c.mu.Unlock()
return nil
}
func (c *Connection) shutdownSignalDone() {
c.doneOnce.Do(func() {
close(c.done)
})
}
func (c *Connection) shutdownLogStart() {
if c.logger != nil {
c.logger.Info("closing")
}
}
func (c *Connection) shutdownCloseSocket() {
if c.socket != nil {
// force unblock of any network operations immediately
expired := time.Now().Add(-1 * time.Minute)
c.socket.SetReadDeadline(expired)
c.socket.SetWriteDeadline(expired)
// close socket
err := c.socket.Close()
if err != nil && c.logger != nil {
c.logger.Error("socket close failed", "error", err)
}
}
}
func (c *Connection) shutdownCloseChannels() {
close(c.incoming)
close(c.errors)
}
func (c *Connection) shutdownLogComplete() {
if c.logger != nil {
c.logger.Info("closed")
}
}
func (c *Connection) startReader() {
c.wg.Go(func() {
defer c.shutdownInternal()
for {
select {
case <-c.done:
return
default:
messageType, data, err := c.socket.ReadMessage()
if err != nil {
var wrappedErr error
var closeErr *websocket.CloseError
if errors.As(err, &closeErr) {
switch closeErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
if c.logger != nil {
c.logger.Info("connection closed by peer",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedClean, err)
default:
if c.logger != nil {
c.logger.Error("unexpected close",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err)
}
} else {
isLocalClose := false
select {
case <-c.done:
isLocalClose = true
default:
}
if c.logger != nil {
if isLocalClose {
c.logger.Debug("read loop terminated", "error", err)
} else {
c.logger.Error("read error", "error", err)
}
}
wrappedErr = fmt.Errorf("%w: %w", ErrReadError, err)
}
select {
case <-c.done:
case c.errors <- wrappedErr:
}
return
}
if messageType == websocket.TextMessage ||
messageType == websocket.BinaryMessage {
select {
case <-c.done:
return
case c.incoming <- data:
c.incomingCount.Add(1)
}
}
}
}
})
}
func (c *Connection) setupPongHandler() {
c.socket.SetPongHandler(func(appData string) error {
select {
case c.heartbeat <- struct{}{}:
c.heartbeatCount.Add(1)
default:
}
return nil
})
}
func (c *Connection) startPinger() {
if c.config.PingInterval <= 0 {
return
}
c.wg.Go(func() {
defer c.shutdownInternal()
// Calculate 10% jitter window
jitter := c.config.PingInterval / 10
for {
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
next := c.config.PingInterval + offset
timer := time.NewTimer(next)
select {
case <-c.done:
timer.Stop()
return
case <-timer.C:
deadline := time.Now().Add(c.config.WriteTimeout)
if err := c.socket.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
return
}
}
}
})
}
func (c *Connection) Send(data []byte) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
@@ -437,6 +258,7 @@ func (c *Connection) Send(data []byte) error {
return NewConnectionError(ErrConnectionClosed)
}
// setup
if c.config.WriteTimeout > 0 {
if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
if c.logger != nil {
@@ -446,7 +268,10 @@ func (c *Connection) Send(data []byte) error {
}
}
if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil {
// send
err := c.socket.WriteMessage(websocket.TextMessage, data)
if err != nil {
if c.logger != nil {
c.logger.Error("write error", "error", err)
}
@@ -489,3 +314,211 @@ func (c *Connection) Stats() ConnectionStats {
func (c *Connection) SetDialer(d types.Dialer) {
c.dialer = d
}
// ---------------------------/
// Reader loop
// -------------------------/
func (c *Connection) startReader() {
defer c.shutdownInternal()
for {
select {
case <-c.done:
return
default:
messageType, data, err := c.socket.ReadMessage()
if err != nil {
select {
case <-c.done:
case c.errors <- c.classifyCloseError(err):
}
return
}
if messageType == websocket.TextMessage ||
messageType == websocket.BinaryMessage {
select {
case <-c.done:
return
case c.incoming <- data:
c.incomingCount.Add(1)
}
}
}
}
}
func (c *Connection) classifyCloseError(err error) error {
var classifiedError error
var closeErr *websocket.CloseError
if errors.As(err, &closeErr) {
switch closeErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
if c.logger != nil {
c.logger.Info("connection closed by peer",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedClean, err)
default:
if c.logger != nil {
c.logger.Error("unexpected close",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err)
}
} else {
isLocalClose := false
select {
case <-c.done:
isLocalClose = true
default:
}
if c.logger != nil {
if isLocalClose {
c.logger.Debug("read loop terminated", "error", err)
} else {
c.logger.Error("read error", "error", err)
}
}
classifiedError = fmt.Errorf("%w: %w", ErrReadError, err)
}
return classifiedError
}
// ---------------------------/
// Heartbeat Handling
// -------------------------/
func (c *Connection) setupPongHandler() {
c.socket.SetPongHandler(func(appData string) error {
select {
case c.heartbeat <- struct{}{}:
c.heartbeatCount.Add(1)
default:
}
return nil
})
}
func (c *Connection) startPinger() {
defer c.shutdownInternal()
// Calculate 10% jitter window
jitter := c.config.PingInterval / 10
for {
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
next := c.config.PingInterval + offset
timer := time.NewTimer(next)
select {
case <-c.done:
timer.Stop()
return
case <-timer.C:
deadline := time.Now().Add(c.config.WriteTimeout)
err := c.socket.WriteControl(websocket.PingMessage, nil, deadline)
if err != nil {
return
}
}
}
}
// ---------------------------/
// Shutdown
// -------------------------/
func (c *Connection) Close() {
c.shutdownExternal()
}
func (c *Connection) shutdownExternal() {
// set closed
c.mu.Lock()
if c.closed {
// idempotent shutdown
c.mu.Unlock()
return
}
c.closed = true
c.state = StateClosed
c.mu.Unlock()
// perform shutdown
c.shutdownInner()
c.shutdownCleanup()
}
// shutdownInternal defers final cleanup to allow it to return.
// Otherwise, a deadlock occurs where startReader triggers a shutdown and
// must wait for itself to exit.
func (c *Connection) shutdownInternal() {
// set closed
c.mu.Lock()
if c.closed {
// idempotent shutdown
c.mu.Unlock()
return
}
c.closed = true
c.state = StateClosed
c.mu.Unlock()
// perform shutdown
c.shutdownInner()
// defer cleanup to avoid deadlock
go func() {
c.shutdownCleanup()
}()
}
func (c *Connection) shutdownInner() {
c.doneOnce.Do(func() {
close(c.done)
})
if c.logger != nil {
c.logger.Info("closing")
}
if c.socket != nil {
// force unblock of any network operations immediately
expired := time.Now().Add(-1 * time.Minute)
c.socket.SetReadDeadline(expired)
c.socket.SetWriteDeadline(expired)
// close socket
err := c.socket.Close()
if err != nil && c.logger != nil {
c.logger.Error("socket close failed", "error", err)
}
}
}
func (c *Connection) shutdownCleanup() {
c.cleanupOnce.Do(func() {
c.wg.Wait()
close(c.incoming)
close(c.errors)
if c.logger != nil {
c.logger.Info("closed")
}
})
}
+4
View File
@@ -69,6 +69,7 @@ func AcquireSocket(
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
}
// dial
socket, resp, err := dialer.DialContext(ctx, url, header)
if err == nil {
if logger != nil {
@@ -77,7 +78,9 @@ func AcquireSocket(
return socket, resp, nil
}
// dial failed, retry
if !retryMgr.ShouldRetry() {
// retry policy expired
if logger != nil {
logger.Error("dial failed, max retries reached",
"error", err,
@@ -95,6 +98,7 @@ func AcquireSocket(
"next_delay", delay)
}
// context cancellable backoff
select {
case <-time.After(delay):
case <-ctx.Done():