cleanup and refactors

This commit is contained in:
Jay
2026-05-20 22:49:25 -04:00
parent cda6d286ab
commit f1afca7921
10 changed files with 628 additions and 496 deletions
+20 -10
View File
@@ -1,21 +1,15 @@
package honeybee
import (
"context"
"git.wisehodl.dev/jay/go-honeybee/transport"
"log/slog"
"time"
)
// Types
type WorkerFactory func(
ctx context.Context,
id string,
handler slog.Handler,
) (Worker, error)
// ----------------------------------------------------------------------------
// Pool Config
// ----------------------------------------------------------------------------
// Types
type PoolConfig struct {
InboxBufferSize int
@@ -27,6 +21,8 @@ type PoolConfig struct {
type PoolOption func(*PoolConfig) error
// Constructor
func NewPoolConfig(options ...PoolOption) (*PoolConfig, error) {
conf := GetDefaultPoolConfig()
if err := applyPoolOptions(conf, options...); err != nil {
@@ -57,6 +53,8 @@ func applyPoolOptions(config *PoolConfig, options ...PoolOption) error {
return nil
}
// Validation
func ValidatePoolConfig(config *PoolConfig) error {
var err error
@@ -84,6 +82,8 @@ func validateBufferSize(value int) error {
return nil
}
// Options
func WithInboxBufferSize(value int) PoolOption {
return func(c *PoolConfig) error {
if err := validateBufferSize(value); err != nil {
@@ -133,7 +133,11 @@ func WithWorkerFactory(wf WorkerFactory) PoolOption {
}
}
// ----------------------------------------------------------------------------
// Worker Config
// ----------------------------------------------------------------------------
// Types
type WorkerConfig struct {
KeepaliveTimeout time.Duration
@@ -142,6 +146,8 @@ type WorkerConfig struct {
type WorkerOption func(*WorkerConfig) error
// Constructor
func NewWorkerConfig(options ...WorkerOption) (*WorkerConfig, error) {
conf := GetDefaultWorkerConfig()
if err := applyWorkerOptions(conf, options...); err != nil {
@@ -169,6 +175,8 @@ func applyWorkerOptions(config *WorkerConfig, options ...WorkerOption) error {
return nil
}
// Validation
func ValidateWorkerConfig(config *WorkerConfig) error {
err := validateKeepaliveTimeout(config.KeepaliveTimeout)
if err != nil {
@@ -192,6 +200,8 @@ func validateReconnectDelay(value time.Duration) error {
return nil
}
// Options
// When KeepaliveTimeout is set to zero, keepalive timeouts are disabled.
func WithKeepaliveTimeout(value time.Duration) WorkerOption {
return func(c *WorkerConfig) error {
+10
View File
@@ -10,7 +10,9 @@ import (
"time"
)
// ----------------------------------------------------------------------------
// Constants
// ----------------------------------------------------------------------------
const (
TestTimeout = 2 * time.Second
@@ -18,7 +20,9 @@ const (
NegativeTestTimeout = 100 * time.Millisecond
)
// ----------------------------------------------------------------------------
// Types
// ----------------------------------------------------------------------------
type MockIncomingData struct {
MsgType int
@@ -37,7 +41,9 @@ type ExpectedLog struct {
Attrs map[string]any
}
// ----------------------------------------------------------------------------
// Setup
// ----------------------------------------------------------------------------
func SetupTestSocket(t *testing.T) (
socket *MockSocket,
@@ -81,7 +87,9 @@ func SetupTestSocket(t *testing.T) (
return
}
// ----------------------------------------------------------------------------
// Helpers
// ----------------------------------------------------------------------------
func ExpectIncoming(t *testing.T, incoming <-chan []byte, expected []byte) {
t.Helper()
@@ -126,7 +134,9 @@ func Never(t *testing.T, condition func() bool, msg string) {
assert.Never(t, condition, NegativeTestTimeout, TestTick, msg)
}
// ----------------------------------------------------------------------------
// Logging Helpers
// ----------------------------------------------------------------------------
func AssertLogSequence(t *testing.T, records []slog.Record, expected []ExpectedLog) {
t.Helper()
+9 -1
View File
@@ -9,12 +9,16 @@ import (
"time"
)
// Re-exported types for consumer convenience
// ----------------------------------------------------------------------------
// Re-exports
// ----------------------------------------------------------------------------
type Socket = types.Socket
type Dialer = types.Dialer
// ----------------------------------------------------------------------------
// Dialer Mocks
// ----------------------------------------------------------------------------
type MockDialer struct {
DialContextFunc func(
@@ -28,7 +32,9 @@ func (m *MockDialer) DialContext(
return m.DialContextFunc(ctx, url, h)
}
// ----------------------------------------------------------------------------
// Socket Mocks
// ----------------------------------------------------------------------------
type MockSocket struct {
WriteMessageFunc func(int, []byte) error
@@ -93,7 +99,9 @@ func (m *MockSocket) SetPongHandler(h func(s string) error) {
m.SetPongHandlerFunc(h)
}
// ----------------------------------------------------------------------------
// Logging mocks
// ----------------------------------------------------------------------------
type MockSlogHandler struct {
records *[]slog.Record
+20 -14
View File
@@ -6,7 +6,7 @@ import (
"git.wisehodl.dev/jay/go-honeybee/transport"
"git.wisehodl.dev/jay/go-honeybee/types"
component "git.wisehodl.dev/jay/go-mana-component"
"git.wisehodl.dev/jay/go-mana-component"
"sync"
"sync/atomic"
"time"
@@ -19,7 +19,9 @@ type Dialer = types.Dialer
var NormalizeURL = transport.NormalizeURL
// ----------------------------------------------------------------------------
// Types
// ----------------------------------------------------------------------------
type PoolEventKind string
@@ -58,7 +60,9 @@ type PoolPlugin struct {
ConnectionConfig *transport.ConnectionConfig
}
// ----------------------------------------------------------------------------
// Pool
// ----------------------------------------------------------------------------
type Peer struct {
id string
@@ -66,24 +70,23 @@ type Peer struct {
}
type Pool struct {
ctx context.Context
cancel context.CancelFunc
peers map[string]*Peer
inbox chan types.InboxMessage
events chan PoolEvent
inboxCounter *atomic.Uint64
outgoingCount *atomic.Uint64
closed bool
dialer types.Dialer
config *PoolConfig
handler slog.Handler
logger *slog.Logger
ctx context.Context
cancel context.CancelFunc
mu sync.RWMutex
wg sync.WaitGroup
closed bool
inboxCounter *atomic.Uint64
outgoingCount *atomic.Uint64
}
func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler,
@@ -106,26 +109,29 @@ func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler,
return nil, err
}
pctx, cancel := context.WithCancel(component.MustNew(ctx, "honeybee", "pool"))
ctx, cancel := context.WithCancel(component.MustNew(ctx, "honeybee", "pool"))
var logger *slog.Logger
if handler != nil {
c := component.FromContext(pctx)
c := component.FromContext(ctx)
logger = slog.New(handler).With(slog.Any("component", c))
}
return &Pool{
ctx: pctx,
cancel: cancel,
peers: make(map[string]*Peer),
inbox: make(chan types.InboxMessage, config.InboxBufferSize),
events: make(chan PoolEvent, config.EventsBufferSize),
inboxCounter: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{},
dialer: transport.NewDialer(),
config: config,
handler: handler,
logger: logger,
ctx: ctx,
cancel: cancel,
inboxCounter: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{},
}, nil
}
+12
View File
@@ -5,6 +5,12 @@ import (
"time"
)
// ----------------------------------------------------------------------------
// Connection Config
// ----------------------------------------------------------------------------
// Types
type CloseHandler func(code int, text string) error
type ConnectionConfig struct {
@@ -26,6 +32,8 @@ type RetryConfig struct {
type ConnectionOption func(*ConnectionConfig) error
// Constructors
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
conf := GetDefaultConnectionConfig()
if err := applyConnectionOptions(conf, options...); err != nil {
@@ -69,6 +77,8 @@ func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOptio
return nil
}
// Validation
func ValidateConnectionConfig(config *ConnectionConfig) error {
err := validateWriteTimeout(config.WriteTimeout)
if err != nil {
@@ -153,6 +163,8 @@ func validateJitterFactor(value float64) error {
return nil
}
// Options
func WithCloseHandler(handler CloseHandler) ConnectionOption {
return func(c *ConnectionConfig) error {
c.CloseHandler = handler
+263 -230
View File
@@ -12,10 +12,14 @@ import (
"time"
"git.wisehodl.dev/jay/go-honeybee/types"
component "git.wisehodl.dev/jay/go-mana-component"
"git.wisehodl.dev/jay/go-mana-component"
"github.com/gorilla/websocket"
)
// ----------------------------------------------------------------------------
// Types
// ----------------------------------------------------------------------------
type ConnectionState int
const (
@@ -49,6 +53,14 @@ type ConnectionStats struct {
TotalHeartbeats uint64
}
// ----------------------------------------------------------------------------
// Connection
// ----------------------------------------------------------------------------
// ---------------------------/
// Constructors
// -------------------------/
type Connection struct {
url *url.URL
dialer types.Dialer
@@ -95,18 +107,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
ctx = component.MustExtend(ctx, "connection")
}
var logger *slog.Logger
if handler != nil {
c := component.FromContext(ctx)
logger = slog.New(handler).With(slog.Any("component", c))
}
conn := &Connection{
url: url,
dialer: NewDialer(),
socket: nil,
config: config,
logger: logger,
incoming: make(chan []byte, config.IncomingBufferSize),
heartbeat: make(chan struct{}, 1),
errors: make(chan error, config.ErrorsBufferSize),
@@ -117,6 +122,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
done: make(chan struct{}),
}
if handler != nil {
comp := component.FromContext(ctx)
conn.logger = slog.New(handler).With(slog.Any("component", comp))
}
return conn, nil
}
@@ -141,18 +151,11 @@ func NewConnectionFromSocket(
ctx = component.MustExtend(ctx, "connection")
}
var logger *slog.Logger
if handler != nil {
c := component.FromContext(ctx)
logger = slog.New(handler).With(slog.Any("component", c))
}
conn := &Connection{
url: nil,
dialer: nil,
socket: socket,
config: config,
logger: logger,
incoming: make(chan []byte, config.IncomingBufferSize),
heartbeat: make(chan struct{}, 1),
errors: make(chan error, config.ErrorsBufferSize),
@@ -163,17 +166,31 @@ func NewConnectionFromSocket(
done: make(chan struct{}),
}
if handler != nil {
comp := component.FromContext(ctx)
conn.logger = slog.New(handler).With(slog.Any("component", comp))
}
// initialize
if config.CloseHandler != nil {
socket.SetCloseHandler(config.CloseHandler)
}
conn.setupPongHandler()
conn.startPinger()
conn.startReader()
if conn.config.PingInterval > 0 {
conn.wg.Go(conn.startPinger)
}
conn.wg.Go(conn.startReader)
return conn, nil
}
// ---------------------------/
// Methods
// -------------------------/
func (c *Connection) Connect(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
@@ -186,17 +203,20 @@ func (c *Connection) Connect(ctx context.Context) error {
return NewConnectionError(ErrConnectionClosed)
}
// begin connecting
if c.logger != nil {
c.logger.Debug("connecting")
}
c.state = StateConnecting
// obtain socket
retryMgr := NewRetryManager(c.config.Retry)
socket, _, err := AcquireSocket(
ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger)
if err != nil {
// socket acquisition failed
c.state = StateDisconnected
if c.logger != nil {
c.logger.Error("connection failed", "error", err)
@@ -204,231 +224,32 @@ func (c *Connection) Connect(ctx context.Context) error {
return NewConnectionError(err)
}
// got socket
c.socket = socket
c.state = StateConnected
// initialize
if c.config.CloseHandler != nil {
c.socket.SetCloseHandler(c.config.CloseHandler)
}
c.setupPongHandler()
if c.config.PingInterval > 0 {
c.wg.Go(c.startPinger)
}
c.wg.Go(c.startReader)
// connected
c.state = StateConnected
if c.logger != nil {
c.logger.Info("connected")
}
c.setupPongHandler()
c.startPinger()
c.startReader()
return nil
}
func (c *Connection) Close() {
c.shutdownExternal()
}
func (c *Connection) shutdownExternal() {
err := c.shutdownSetClosed(true)
if err != nil {
return
}
c.shutdownInner()
c.shutdownCleanup()
}
func (c *Connection) shutdownInternal() {
err := c.shutdownSetClosed(false)
if err != nil {
return
}
c.shutdownInner()
// defer final cleanup to allow this function to return
// otherwise, a deadlock occurs where startReader triggers a shutdown and
// must wait for itself to exit.
go func() {
c.shutdownCleanup()
}()
}
func (c *Connection) shutdownInner() {
c.shutdownSignalDone()
c.shutdownLogStart()
c.shutdownCloseSocket()
}
func (c *Connection) shutdownCleanup() {
c.cleanupOnce.Do(func() {
c.wg.Wait()
c.shutdownCloseChannels()
c.shutdownLogComplete()
})
}
func (c *Connection) shutdownSetClosed(wait bool) error {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return NewConnectionError(ErrConnectionClosed)
}
c.closed = true
c.state = StateClosed
c.mu.Unlock()
return nil
}
func (c *Connection) shutdownSignalDone() {
c.doneOnce.Do(func() {
close(c.done)
})
}
func (c *Connection) shutdownLogStart() {
if c.logger != nil {
c.logger.Info("closing")
}
}
func (c *Connection) shutdownCloseSocket() {
if c.socket != nil {
// force unblock of any network operations immediately
expired := time.Now().Add(-1 * time.Minute)
c.socket.SetReadDeadline(expired)
c.socket.SetWriteDeadline(expired)
// close socket
err := c.socket.Close()
if err != nil && c.logger != nil {
c.logger.Error("socket close failed", "error", err)
}
}
}
func (c *Connection) shutdownCloseChannels() {
close(c.incoming)
close(c.errors)
}
func (c *Connection) shutdownLogComplete() {
if c.logger != nil {
c.logger.Info("closed")
}
}
func (c *Connection) startReader() {
c.wg.Go(func() {
defer c.shutdownInternal()
for {
select {
case <-c.done:
return
default:
messageType, data, err := c.socket.ReadMessage()
if err != nil {
var wrappedErr error
var closeErr *websocket.CloseError
if errors.As(err, &closeErr) {
switch closeErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
if c.logger != nil {
c.logger.Info("connection closed by peer",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedClean, err)
default:
if c.logger != nil {
c.logger.Error("unexpected close",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err)
}
} else {
isLocalClose := false
select {
case <-c.done:
isLocalClose = true
default:
}
if c.logger != nil {
if isLocalClose {
c.logger.Debug("read loop terminated", "error", err)
} else {
c.logger.Error("read error", "error", err)
}
}
wrappedErr = fmt.Errorf("%w: %w", ErrReadError, err)
}
select {
case <-c.done:
case c.errors <- wrappedErr:
}
return
}
if messageType == websocket.TextMessage ||
messageType == websocket.BinaryMessage {
select {
case <-c.done:
return
case c.incoming <- data:
c.incomingCount.Add(1)
}
}
}
}
})
}
func (c *Connection) setupPongHandler() {
c.socket.SetPongHandler(func(appData string) error {
select {
case c.heartbeat <- struct{}{}:
c.heartbeatCount.Add(1)
default:
}
return nil
})
}
func (c *Connection) startPinger() {
if c.config.PingInterval <= 0 {
return
}
c.wg.Go(func() {
defer c.shutdownInternal()
// Calculate 10% jitter window
jitter := c.config.PingInterval / 10
for {
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
next := c.config.PingInterval + offset
timer := time.NewTimer(next)
select {
case <-c.done:
timer.Stop()
return
case <-timer.C:
deadline := time.Now().Add(c.config.WriteTimeout)
if err := c.socket.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
return
}
}
}
})
}
func (c *Connection) Send(data []byte) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
@@ -437,6 +258,7 @@ func (c *Connection) Send(data []byte) error {
return NewConnectionError(ErrConnectionClosed)
}
// setup
if c.config.WriteTimeout > 0 {
if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
if c.logger != nil {
@@ -446,7 +268,10 @@ func (c *Connection) Send(data []byte) error {
}
}
if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil {
// send
err := c.socket.WriteMessage(websocket.TextMessage, data)
if err != nil {
if c.logger != nil {
c.logger.Error("write error", "error", err)
}
@@ -489,3 +314,211 @@ func (c *Connection) Stats() ConnectionStats {
func (c *Connection) SetDialer(d types.Dialer) {
c.dialer = d
}
// ---------------------------/
// Reader loop
// -------------------------/
func (c *Connection) startReader() {
defer c.shutdownInternal()
for {
select {
case <-c.done:
return
default:
messageType, data, err := c.socket.ReadMessage()
if err != nil {
select {
case <-c.done:
case c.errors <- c.classifyCloseError(err):
}
return
}
if messageType == websocket.TextMessage ||
messageType == websocket.BinaryMessage {
select {
case <-c.done:
return
case c.incoming <- data:
c.incomingCount.Add(1)
}
}
}
}
}
func (c *Connection) classifyCloseError(err error) error {
var classifiedError error
var closeErr *websocket.CloseError
if errors.As(err, &closeErr) {
switch closeErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
if c.logger != nil {
c.logger.Info("connection closed by peer",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedClean, err)
default:
if c.logger != nil {
c.logger.Error("unexpected close",
"code", closeErr.Code,
"text", closeErr.Text,
)
}
classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err)
}
} else {
isLocalClose := false
select {
case <-c.done:
isLocalClose = true
default:
}
if c.logger != nil {
if isLocalClose {
c.logger.Debug("read loop terminated", "error", err)
} else {
c.logger.Error("read error", "error", err)
}
}
classifiedError = fmt.Errorf("%w: %w", ErrReadError, err)
}
return classifiedError
}
// ---------------------------/
// Heartbeat Handling
// -------------------------/
func (c *Connection) setupPongHandler() {
c.socket.SetPongHandler(func(appData string) error {
select {
case c.heartbeat <- struct{}{}:
c.heartbeatCount.Add(1)
default:
}
return nil
})
}
func (c *Connection) startPinger() {
defer c.shutdownInternal()
// Calculate 10% jitter window
jitter := c.config.PingInterval / 10
for {
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
next := c.config.PingInterval + offset
timer := time.NewTimer(next)
select {
case <-c.done:
timer.Stop()
return
case <-timer.C:
deadline := time.Now().Add(c.config.WriteTimeout)
err := c.socket.WriteControl(websocket.PingMessage, nil, deadline)
if err != nil {
return
}
}
}
}
// ---------------------------/
// Shutdown
// -------------------------/
func (c *Connection) Close() {
c.shutdownExternal()
}
func (c *Connection) shutdownExternal() {
// set closed
c.mu.Lock()
if c.closed {
// idempotent shutdown
c.mu.Unlock()
return
}
c.closed = true
c.state = StateClosed
c.mu.Unlock()
// perform shutdown
c.shutdownInner()
c.shutdownCleanup()
}
// shutdownInternal defers final cleanup to allow it to return.
// Otherwise, a deadlock occurs where startReader triggers a shutdown and
// must wait for itself to exit.
func (c *Connection) shutdownInternal() {
// set closed
c.mu.Lock()
if c.closed {
// idempotent shutdown
c.mu.Unlock()
return
}
c.closed = true
c.state = StateClosed
c.mu.Unlock()
// perform shutdown
c.shutdownInner()
// defer cleanup to avoid deadlock
go func() {
c.shutdownCleanup()
}()
}
func (c *Connection) shutdownInner() {
c.doneOnce.Do(func() {
close(c.done)
})
if c.logger != nil {
c.logger.Info("closing")
}
if c.socket != nil {
// force unblock of any network operations immediately
expired := time.Now().Add(-1 * time.Minute)
c.socket.SetReadDeadline(expired)
c.socket.SetWriteDeadline(expired)
// close socket
err := c.socket.Close()
if err != nil && c.logger != nil {
c.logger.Error("socket close failed", "error", err)
}
}
}
func (c *Connection) shutdownCleanup() {
c.cleanupOnce.Do(func() {
c.wg.Wait()
close(c.incoming)
close(c.errors)
if c.logger != nil {
c.logger.Info("closed")
}
})
}
+4
View File
@@ -69,6 +69,7 @@ func AcquireSocket(
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
}
// dial
socket, resp, err := dialer.DialContext(ctx, url, header)
if err == nil {
if logger != nil {
@@ -77,7 +78,9 @@ func AcquireSocket(
return socket, resp, nil
}
// dial failed, retry
if !retryMgr.ShouldRetry() {
// retry policy expired
if logger != nil {
logger.Error("dial failed, max retries reached",
"error", err,
@@ -95,6 +98,7 @@ func AcquireSocket(
"next_delay", delay)
}
// context cancellable backoff
select {
case <-time.After(delay):
case <-ctx.Done():
+169 -111
View File
@@ -9,10 +9,22 @@ import (
"git.wisehodl.dev/jay/go-honeybee/transport"
"git.wisehodl.dev/jay/go-honeybee/types"
component "git.wisehodl.dev/jay/go-mana-component"
"git.wisehodl.dev/jay/go-mana-component"
)
// ----------------------------------------------------------------------------
// Worker
// ----------------------------------------------------------------------------
// ---------------------------/
// Types
// -------------------------/
type WorkerFactory func(
ctx context.Context,
id string,
handler slog.Handler,
) (Worker, error)
type Worker interface {
Start(pool PoolPlugin)
@@ -37,19 +49,23 @@ type DefaultWorker struct {
id string
conn atomic.Pointer[transport.Connection]
heartbeat chan struct{}
sendHeartbeat chan struct{}
ctx context.Context
cancel context.CancelFunc
config *WorkerConfig
handler slog.Handler
logger *slog.Logger
processedCount *atomic.Uint64
outgoingCount *atomic.Uint64
restartCount *atomic.Uint64
config *WorkerConfig
ctx context.Context
cancel context.CancelFunc
handler slog.Handler
logger *slog.Logger
}
// ---------------------------/
// Constructor
// -------------------------/
func NewWorker(
ctx context.Context,
id string,
@@ -78,20 +94,27 @@ func NewWorker(
wctx, wcancel := context.WithCancel(ctx)
w := &DefaultWorker{
id: id,
sendHeartbeat: make(chan struct{}),
ctx: wctx,
cancel: wcancel,
config: config,
heartbeat: make(chan struct{}),
handler: handler,
logger: logger,
processedCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{},
restartCount: &atomic.Uint64{},
ctx: wctx,
cancel: wcancel,
handler: handler,
logger: logger,
}
return w, nil
}
// ---------------------------/
// Session
// -------------------------/
func (w *DefaultWorker) Start(pool PoolPlugin) {
if w.logger != nil {
w.logger.Debug("starting")
@@ -114,71 +137,19 @@ func (w *DefaultWorker) Start(pool PoolPlugin) {
}
func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
newConn := make(chan *transport.Connection, 1)
var timer *time.Timer
if w.config.KeepaliveTimeout > 0 {
if w.logger != nil {
w.logger.Debug("keepalive: enabled", "timeout", w.config.KeepaliveTimeout)
}
timer = time.NewTimer(w.config.KeepaliveTimeout)
defer timer.Stop()
} else {
if w.logger != nil {
w.logger.Debug("keepalive: disabled")
}
}
resetTimer := func() {
if timer == nil {
return
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(w.config.KeepaliveTimeout)
}
timerC := func() <-chan time.Time {
if timer == nil {
return nil
}
return timer.C
}
// setup dialer
var dialCancel context.CancelFunc
newConn := make(chan *transport.Connection, 1)
spawnDialer := func() { dialCancel = w.spawnDialer(ctx, dialCancel, newConn, pool) }
spawnDial := func() {
if dialCancel != nil {
dialCancel()
}
var dialCtx context.Context
dialCtx, dialCancel = context.WithCancel(ctx)
if w.logger != nil {
w.logger.Debug("session: requesting connection")
}
go func() {
conn, err := connect(w.id, dialCtx, pool, w.handler)
if err != nil {
if w.logger != nil {
w.logger.Warn("dialer: dial failed")
}
return
}
select {
case newConn <- conn:
case <-dialCtx.Done():
conn.Close()
}
}()
}
// setup heartbeat
timer, timerC, heartbeat := w.setupHeartbeat()
defer timer.Stop()
// main loop
for {
// spawn initial dial for this reconnect cycle
spawnDial()
spawnDialer()
// obtain new connection
var conn *transport.Connection
@@ -190,19 +161,22 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
dialCancel()
}
return
case <-w.heartbeat:
resetTimer()
case <-timerC():
if w.logger != nil {
w.logger.Info("keepalive: no activity observed")
}
timer.Reset(w.config.KeepaliveTimeout)
spawnDial()
case conn = <-newConn:
if w.logger != nil {
w.logger.Debug("session: connected")
}
break preConn
case <-w.sendHeartbeat:
heartbeat()
case <-timerC():
if w.logger != nil {
w.logger.Info("keepalive: no activity observed")
}
timer.Reset(w.config.KeepaliveTimeout)
spawnDialer()
}
}
@@ -220,14 +194,7 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
select {
case <-ctx.Done():
break conn_loop
case <-w.heartbeat:
resetTimer()
case <-timerC():
if w.logger != nil {
w.logger.Info("keepalive: no activity observed")
}
timer.Reset(w.config.KeepaliveTimeout)
break conn_loop
case data, ok := <-conn.Incoming():
if !ok {
if w.logger != nil {
@@ -235,20 +202,34 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
}
break conn_loop
}
pool.Inbox <- types.InboxMessage{
ID: w.id,
Data: data,
ReceivedAt: time.Now(),
}
resetTimer()
ID: w.id, Data: data, ReceivedAt: time.Now()}
pool.InboxCounter.Add(1)
w.processedCount.Add(1)
heartbeat()
case <-conn.Heartbeat():
if w.logger != nil {
w.logger.Debug("ping-pong heartbeat")
}
resetTimer()
heartbeat()
case <-w.sendHeartbeat:
heartbeat()
case <-timerC():
if w.logger != nil {
w.logger.Info("keepalive: no activity observed")
}
timer.Reset(w.config.KeepaliveTimeout)
break conn_loop
}
}
// session ended
conn.Close()
if w.logger != nil {
@@ -272,6 +253,98 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
}
}
func (w *DefaultWorker) setupHeartbeat() (
timer *time.Timer, timerC func() <-chan time.Time, heartbeat func(),
) {
if w.config.KeepaliveTimeout > 0 {
if w.logger != nil {
w.logger.Debug("keepalive: enabled", "timeout", w.config.KeepaliveTimeout)
}
timer = time.NewTimer(w.config.KeepaliveTimeout)
} else {
if w.logger != nil {
w.logger.Debug("keepalive: disabled")
}
}
heartbeat = func() {
if timer == nil {
return
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(w.config.KeepaliveTimeout)
}
timerC = func() <-chan time.Time {
if timer == nil {
return nil
}
return timer.C
}
return
}
func (w *DefaultWorker) spawnDialer(
ctx context.Context,
dialCancel context.CancelFunc,
newConn chan<- *transport.Connection,
pool PoolPlugin,
) context.CancelFunc {
if dialCancel != nil {
dialCancel()
}
dialCtx, dialCancel := context.WithCancel(ctx)
if w.logger != nil {
w.logger.Debug("session: requesting connection")
}
go func() {
conn, err := connect(w.id, dialCtx, pool, w.handler)
if err != nil {
if w.logger != nil {
w.logger.Warn("dialer: dial failed", "error", err)
}
return
}
select {
case newConn <- conn:
case <-dialCtx.Done():
conn.Close()
}
}()
return dialCancel
}
func connect(
id string,
ctx context.Context,
pool PoolPlugin,
handler slog.Handler,
) (*transport.Connection, error) {
conn, err := transport.NewConnection(ctx, id, pool.ConnectionConfig, handler)
if err != nil {
return nil, err
}
conn.SetDialer(pool.Dialer)
return conn, conn.Connect(ctx)
}
// ---------------------------/
// Methods
// -------------------------/
func (w *DefaultWorker) Stop() {
if w.logger != nil {
w.logger.Debug("shutting down")
@@ -291,7 +364,7 @@ func (w *DefaultWorker) Send(data []byte) error {
}
select {
case w.heartbeat <- struct{}{}:
case w.sendHeartbeat <- struct{}{}:
case <-w.ctx.Done():
}
@@ -324,18 +397,3 @@ func (w *DefaultWorker) Stats() WorkerStats {
TotalSent: w.outgoingCount.Load(),
}
}
func connect(
id string,
ctx context.Context,
pool PoolPlugin,
handler slog.Handler,
) (*transport.Connection, error) {
conn, err := transport.NewConnection(ctx, id, pool.ConnectionConfig, handler)
if err != nil {
return nil, err
}
conn.SetDialer(pool.Dialer)
return conn, conn.Connect(ctx)
}
-117
View File
@@ -1,117 +0,0 @@
package honeybee
import (
"context"
"fmt"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"github.com/stretchr/testify/assert"
"sync/atomic"
"testing"
)
func TestWorkerSend(t *testing.T) {
t.Run("data sent to mock socket", func(t *testing.T) {
conn, _, _, outgoingData := setupTestConnection(t)
defer conn.Close()
ctx, cancel := context.WithCancel(context.Background())
heartbeat := make(chan struct{})
heartbeatCount := atomic.Int32{}
w := &DefaultWorker{
ctx: ctx,
cancel: cancel,
id: "wss://test",
heartbeat: heartbeat,
outgoingCount: &atomic.Uint64{},
}
w.conn.Store(conn)
defer w.cancel()
go func() {
for range heartbeat {
heartbeatCount.Add(1)
}
}()
testData := []byte("hello")
err := w.Send(testData)
assert.NoError(t, err)
// at least one heartbeat was sent
honeybeetest.Eventually(t, func() bool {
return heartbeatCount.Load() >= 1
}, "expected heartbeats")
// message was sent by the socket
honeybeetest.Eventually(t, func() bool {
select {
case msg := <-outgoingData:
return string(msg.Data) == "hello"
default:
return false
}
}, "expected message")
})
t.Run("sends one heartbeat per successful send", func(t *testing.T) {
conn, _, _, _ := setupTestConnection(t)
defer conn.Close()
ctx, cancel := context.WithCancel(context.Background())
heartbeat := make(chan struct{})
heartbeatCount := atomic.Int32{}
w := &DefaultWorker{
ctx: ctx,
cancel: cancel,
id: "wss://test",
heartbeat: heartbeat,
outgoingCount: &atomic.Uint64{},
}
w.conn.Store(conn)
defer w.cancel()
go func() {
for range heartbeat {
heartbeatCount.Add(1)
}
}()
const count = 3
for i := range count {
err := w.Send(fmt.Appendf(nil, "msg-%d", i))
assert.NoError(t, err)
}
honeybeetest.Eventually(t, func() bool {
return heartbeatCount.Load() == count
}, "expected heartbeats")
})
t.Run("returns error if connection is unavailable", func(t *testing.T) {
// no connection available to worker
ctx, cancel := context.WithCancel(context.Background())
heartbeat := make(chan struct{})
w := &DefaultWorker{
ctx: ctx,
cancel: cancel,
id: "wss://test",
heartbeat: heartbeat,
}
defer w.cancel()
go func() {
for range heartbeat {
}
}()
err := w.Send([]byte("hello"))
assert.ErrorIs(t, err, ErrConnectionUnavailable)
})
}
+113 -5
View File
@@ -3,6 +3,7 @@ package honeybee
import (
"context"
"errors"
"fmt"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"git.wisehodl.dev/jay/go-honeybee/transport"
"git.wisehodl.dev/jay/go-honeybee/types"
@@ -41,7 +42,7 @@ func makeWorker(t *testing.T, ctx context.Context, cancel context.CancelFunc) *D
cancel: cancel,
id: "wss://test",
config: config,
heartbeat: make(chan struct{}),
sendHeartbeat: make(chan struct{}),
processedCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{},
restartCount: &atomic.Uint64{},
@@ -134,7 +135,7 @@ func TestWorkerSession(t *testing.T) {
cancel: cancel,
id: "wss://test",
config: config,
heartbeat: make(chan struct{}),
sendHeartbeat: make(chan struct{}),
processedCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{},
restartCount: &atomic.Uint64{},
@@ -303,7 +304,7 @@ func TestWorkerSession(t *testing.T) {
cancel: cancel,
id: "wss://test",
config: config,
heartbeat: make(chan struct{}),
sendHeartbeat: make(chan struct{}),
processedCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{},
restartCount: &atomic.Uint64{},
@@ -365,7 +366,7 @@ func TestWorkerSession(t *testing.T) {
cancel: cancel,
id: "wss://test",
config: config,
heartbeat: make(chan struct{}),
sendHeartbeat: make(chan struct{}),
processedCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{},
restartCount: &atomic.Uint64{},
@@ -431,7 +432,7 @@ func TestWorkerSession(t *testing.T) {
cancel: cancel,
id: "wss://test",
config: config,
heartbeat: make(chan struct{}),
sendHeartbeat: make(chan struct{}),
processedCount: &atomic.Uint64{},
outgoingCount: &atomic.Uint64{},
restartCount: &atomic.Uint64{},
@@ -638,3 +639,110 @@ func TestWorkerSession(t *testing.T) {
}, "expected wg to drain after parent cancel")
})
}
func TestWorkerSend(t *testing.T) {
t.Run("data sent to mock socket", func(t *testing.T) {
conn, _, _, outgoingData := setupTestConnection(t)
defer conn.Close()
ctx, cancel := context.WithCancel(context.Background())
heartbeat := make(chan struct{})
heartbeatCount := atomic.Int32{}
w := &DefaultWorker{
ctx: ctx,
cancel: cancel,
id: "wss://test",
sendHeartbeat: heartbeat,
outgoingCount: &atomic.Uint64{},
}
w.conn.Store(conn)
defer w.cancel()
go func() {
for range heartbeat {
heartbeatCount.Add(1)
}
}()
testData := []byte("hello")
err := w.Send(testData)
assert.NoError(t, err)
// at least one heartbeat was sent
honeybeetest.Eventually(t, func() bool {
return heartbeatCount.Load() >= 1
}, "expected heartbeats")
// message was sent by the socket
honeybeetest.Eventually(t, func() bool {
select {
case msg := <-outgoingData:
return string(msg.Data) == "hello"
default:
return false
}
}, "expected message")
})
t.Run("sends one heartbeat per successful send", func(t *testing.T) {
conn, _, _, _ := setupTestConnection(t)
defer conn.Close()
ctx, cancel := context.WithCancel(context.Background())
heartbeat := make(chan struct{})
heartbeatCount := atomic.Int32{}
w := &DefaultWorker{
ctx: ctx,
cancel: cancel,
id: "wss://test",
sendHeartbeat: heartbeat,
outgoingCount: &atomic.Uint64{},
}
w.conn.Store(conn)
defer w.cancel()
go func() {
for range heartbeat {
heartbeatCount.Add(1)
}
}()
const count = 3
for i := range count {
err := w.Send(fmt.Appendf(nil, "msg-%d", i))
assert.NoError(t, err)
}
honeybeetest.Eventually(t, func() bool {
return heartbeatCount.Load() == count
}, "expected heartbeats")
})
t.Run("returns error if connection is unavailable", func(t *testing.T) {
// no connection available to worker
ctx, cancel := context.WithCancel(context.Background())
heartbeat := make(chan struct{})
w := &DefaultWorker{
ctx: ctx,
cancel: cancel,
id: "wss://test",
sendHeartbeat: heartbeat,
}
defer w.cancel()
go func() {
for range heartbeat {
}
}()
err := w.Send([]byte("hello"))
assert.ErrorIs(t, err, ErrConnectionUnavailable)
})
}