cleanup and refactors
This commit is contained in:
@@ -1,21 +1,15 @@
|
||||
package honeybee
|
||||
|
||||
import (
|
||||
"context"
|
||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Types
|
||||
|
||||
type WorkerFactory func(
|
||||
ctx context.Context,
|
||||
id string,
|
||||
handler slog.Handler,
|
||||
) (Worker, error)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Pool Config
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// Types
|
||||
|
||||
type PoolConfig struct {
|
||||
InboxBufferSize int
|
||||
@@ -27,6 +21,8 @@ type PoolConfig struct {
|
||||
|
||||
type PoolOption func(*PoolConfig) error
|
||||
|
||||
// Constructor
|
||||
|
||||
func NewPoolConfig(options ...PoolOption) (*PoolConfig, error) {
|
||||
conf := GetDefaultPoolConfig()
|
||||
if err := applyPoolOptions(conf, options...); err != nil {
|
||||
@@ -57,6 +53,8 @@ func applyPoolOptions(config *PoolConfig, options ...PoolOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validation
|
||||
|
||||
func ValidatePoolConfig(config *PoolConfig) error {
|
||||
var err error
|
||||
|
||||
@@ -84,6 +82,8 @@ func validateBufferSize(value int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Options
|
||||
|
||||
func WithInboxBufferSize(value int) PoolOption {
|
||||
return func(c *PoolConfig) error {
|
||||
if err := validateBufferSize(value); err != nil {
|
||||
@@ -133,7 +133,11 @@ func WithWorkerFactory(wf WorkerFactory) PoolOption {
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Worker Config
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// Types
|
||||
|
||||
type WorkerConfig struct {
|
||||
KeepaliveTimeout time.Duration
|
||||
@@ -142,6 +146,8 @@ type WorkerConfig struct {
|
||||
|
||||
type WorkerOption func(*WorkerConfig) error
|
||||
|
||||
// Constructor
|
||||
|
||||
func NewWorkerConfig(options ...WorkerOption) (*WorkerConfig, error) {
|
||||
conf := GetDefaultWorkerConfig()
|
||||
if err := applyWorkerOptions(conf, options...); err != nil {
|
||||
@@ -169,6 +175,8 @@ func applyWorkerOptions(config *WorkerConfig, options ...WorkerOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validation
|
||||
|
||||
func ValidateWorkerConfig(config *WorkerConfig) error {
|
||||
err := validateKeepaliveTimeout(config.KeepaliveTimeout)
|
||||
if err != nil {
|
||||
@@ -192,6 +200,8 @@ func validateReconnectDelay(value time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Options
|
||||
|
||||
// When KeepaliveTimeout is set to zero, keepalive timeouts are disabled.
|
||||
func WithKeepaliveTimeout(value time.Duration) WorkerOption {
|
||||
return func(c *WorkerConfig) error {
|
||||
|
||||
@@ -10,7 +10,9 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Constants
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
const (
|
||||
TestTimeout = 2 * time.Second
|
||||
@@ -18,7 +20,9 @@ const (
|
||||
NegativeTestTimeout = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Types
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type MockIncomingData struct {
|
||||
MsgType int
|
||||
@@ -37,7 +41,9 @@ type ExpectedLog struct {
|
||||
Attrs map[string]any
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Setup
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func SetupTestSocket(t *testing.T) (
|
||||
socket *MockSocket,
|
||||
@@ -81,7 +87,9 @@ func SetupTestSocket(t *testing.T) (
|
||||
return
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func ExpectIncoming(t *testing.T, incoming <-chan []byte, expected []byte) {
|
||||
t.Helper()
|
||||
@@ -126,7 +134,9 @@ func Never(t *testing.T, condition func() bool, msg string) {
|
||||
assert.Never(t, condition, NegativeTestTimeout, TestTick, msg)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Logging Helpers
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func AssertLogSequence(t *testing.T, records []slog.Record, expected []ExpectedLog) {
|
||||
t.Helper()
|
||||
|
||||
@@ -9,12 +9,16 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Re-exported types for consumer convenience
|
||||
// ----------------------------------------------------------------------------
|
||||
// Re-exports
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Socket = types.Socket
|
||||
type Dialer = types.Dialer
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Dialer Mocks
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type MockDialer struct {
|
||||
DialContextFunc func(
|
||||
@@ -28,7 +32,9 @@ func (m *MockDialer) DialContext(
|
||||
return m.DialContextFunc(ctx, url, h)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Socket Mocks
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type MockSocket struct {
|
||||
WriteMessageFunc func(int, []byte) error
|
||||
@@ -93,7 +99,9 @@ func (m *MockSocket) SetPongHandler(h func(s string) error) {
|
||||
m.SetPongHandlerFunc(h)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Logging mocks
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type MockSlogHandler struct {
|
||||
records *[]slog.Record
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
component "git.wisehodl.dev/jay/go-mana-component"
|
||||
"git.wisehodl.dev/jay/go-mana-component"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -19,7 +19,9 @@ type Dialer = types.Dialer
|
||||
|
||||
var NormalizeURL = transport.NormalizeURL
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Types
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type PoolEventKind string
|
||||
|
||||
@@ -58,7 +60,9 @@ type PoolPlugin struct {
|
||||
ConnectionConfig *transport.ConnectionConfig
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Pool
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Peer struct {
|
||||
id string
|
||||
@@ -66,24 +70,23 @@ type Peer struct {
|
||||
}
|
||||
|
||||
type Pool struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
peers map[string]*Peer
|
||||
inbox chan types.InboxMessage
|
||||
events chan PoolEvent
|
||||
|
||||
inboxCounter *atomic.Uint64
|
||||
outgoingCount *atomic.Uint64
|
||||
closed bool
|
||||
|
||||
dialer types.Dialer
|
||||
config *PoolConfig
|
||||
handler slog.Handler
|
||||
logger *slog.Logger
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
closed bool
|
||||
|
||||
inboxCounter *atomic.Uint64
|
||||
outgoingCount *atomic.Uint64
|
||||
}
|
||||
|
||||
func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler,
|
||||
@@ -106,26 +109,29 @@ func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pctx, cancel := context.WithCancel(component.MustNew(ctx, "honeybee", "pool"))
|
||||
ctx, cancel := context.WithCancel(component.MustNew(ctx, "honeybee", "pool"))
|
||||
|
||||
var logger *slog.Logger
|
||||
if handler != nil {
|
||||
c := component.FromContext(pctx)
|
||||
c := component.FromContext(ctx)
|
||||
logger = slog.New(handler).With(slog.Any("component", c))
|
||||
}
|
||||
|
||||
return &Pool{
|
||||
ctx: pctx,
|
||||
cancel: cancel,
|
||||
peers: make(map[string]*Peer),
|
||||
inbox: make(chan types.InboxMessage, config.InboxBufferSize),
|
||||
events: make(chan PoolEvent, config.EventsBufferSize),
|
||||
peers: make(map[string]*Peer),
|
||||
inbox: make(chan types.InboxMessage, config.InboxBufferSize),
|
||||
events: make(chan PoolEvent, config.EventsBufferSize),
|
||||
|
||||
dialer: transport.NewDialer(),
|
||||
config: config,
|
||||
handler: handler,
|
||||
logger: logger,
|
||||
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
|
||||
inboxCounter: &atomic.Uint64{},
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
dialer: transport.NewDialer(),
|
||||
config: config,
|
||||
handler: handler,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Connection Config
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// Types
|
||||
|
||||
type CloseHandler func(code int, text string) error
|
||||
|
||||
type ConnectionConfig struct {
|
||||
@@ -26,6 +32,8 @@ type RetryConfig struct {
|
||||
|
||||
type ConnectionOption func(*ConnectionConfig) error
|
||||
|
||||
// Constructors
|
||||
|
||||
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
|
||||
conf := GetDefaultConnectionConfig()
|
||||
if err := applyConnectionOptions(conf, options...); err != nil {
|
||||
@@ -69,6 +77,8 @@ func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOptio
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validation
|
||||
|
||||
func ValidateConnectionConfig(config *ConnectionConfig) error {
|
||||
err := validateWriteTimeout(config.WriteTimeout)
|
||||
if err != nil {
|
||||
@@ -153,6 +163,8 @@ func validateJitterFactor(value float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Options
|
||||
|
||||
func WithCloseHandler(handler CloseHandler) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
c.CloseHandler = handler
|
||||
|
||||
+263
-230
@@ -12,10 +12,14 @@ import (
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
component "git.wisehodl.dev/jay/go-mana-component"
|
||||
"git.wisehodl.dev/jay/go-mana-component"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Types
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type ConnectionState int
|
||||
|
||||
const (
|
||||
@@ -49,6 +53,14 @@ type ConnectionStats struct {
|
||||
TotalHeartbeats uint64
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Connection
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// ---------------------------/
|
||||
// Constructors
|
||||
// -------------------------/
|
||||
|
||||
type Connection struct {
|
||||
url *url.URL
|
||||
dialer types.Dialer
|
||||
@@ -95,18 +107,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
|
||||
ctx = component.MustExtend(ctx, "connection")
|
||||
}
|
||||
|
||||
var logger *slog.Logger
|
||||
if handler != nil {
|
||||
c := component.FromContext(ctx)
|
||||
logger = slog.New(handler).With(slog.Any("component", c))
|
||||
}
|
||||
|
||||
conn := &Connection{
|
||||
url: url,
|
||||
dialer: NewDialer(),
|
||||
socket: nil,
|
||||
config: config,
|
||||
logger: logger,
|
||||
incoming: make(chan []byte, config.IncomingBufferSize),
|
||||
heartbeat: make(chan struct{}, 1),
|
||||
errors: make(chan error, config.ErrorsBufferSize),
|
||||
@@ -117,6 +122,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if handler != nil {
|
||||
comp := component.FromContext(ctx)
|
||||
conn.logger = slog.New(handler).With(slog.Any("component", comp))
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -141,18 +151,11 @@ func NewConnectionFromSocket(
|
||||
ctx = component.MustExtend(ctx, "connection")
|
||||
}
|
||||
|
||||
var logger *slog.Logger
|
||||
if handler != nil {
|
||||
c := component.FromContext(ctx)
|
||||
logger = slog.New(handler).With(slog.Any("component", c))
|
||||
}
|
||||
|
||||
conn := &Connection{
|
||||
url: nil,
|
||||
dialer: nil,
|
||||
socket: socket,
|
||||
config: config,
|
||||
logger: logger,
|
||||
incoming: make(chan []byte, config.IncomingBufferSize),
|
||||
heartbeat: make(chan struct{}, 1),
|
||||
errors: make(chan error, config.ErrorsBufferSize),
|
||||
@@ -163,17 +166,31 @@ func NewConnectionFromSocket(
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if handler != nil {
|
||||
comp := component.FromContext(ctx)
|
||||
conn.logger = slog.New(handler).With(slog.Any("component", comp))
|
||||
}
|
||||
|
||||
// initialize
|
||||
if config.CloseHandler != nil {
|
||||
socket.SetCloseHandler(config.CloseHandler)
|
||||
}
|
||||
|
||||
conn.setupPongHandler()
|
||||
conn.startPinger()
|
||||
conn.startReader()
|
||||
|
||||
if conn.config.PingInterval > 0 {
|
||||
conn.wg.Go(conn.startPinger)
|
||||
}
|
||||
|
||||
conn.wg.Go(conn.startReader)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Methods
|
||||
// -------------------------/
|
||||
|
||||
func (c *Connection) Connect(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
@@ -186,17 +203,20 @@ func (c *Connection) Connect(ctx context.Context) error {
|
||||
return NewConnectionError(ErrConnectionClosed)
|
||||
}
|
||||
|
||||
// begin connecting
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("connecting")
|
||||
}
|
||||
|
||||
c.state = StateConnecting
|
||||
|
||||
// obtain socket
|
||||
retryMgr := NewRetryManager(c.config.Retry)
|
||||
socket, _, err := AcquireSocket(
|
||||
ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger)
|
||||
|
||||
if err != nil {
|
||||
// socket acquisition failed
|
||||
c.state = StateDisconnected
|
||||
if c.logger != nil {
|
||||
c.logger.Error("connection failed", "error", err)
|
||||
@@ -204,231 +224,32 @@ func (c *Connection) Connect(ctx context.Context) error {
|
||||
return NewConnectionError(err)
|
||||
}
|
||||
|
||||
// got socket
|
||||
c.socket = socket
|
||||
c.state = StateConnected
|
||||
|
||||
// initialize
|
||||
if c.config.CloseHandler != nil {
|
||||
c.socket.SetCloseHandler(c.config.CloseHandler)
|
||||
}
|
||||
|
||||
c.setupPongHandler()
|
||||
|
||||
if c.config.PingInterval > 0 {
|
||||
c.wg.Go(c.startPinger)
|
||||
}
|
||||
|
||||
c.wg.Go(c.startReader)
|
||||
|
||||
// connected
|
||||
c.state = StateConnected
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("connected")
|
||||
}
|
||||
|
||||
c.setupPongHandler()
|
||||
c.startPinger()
|
||||
c.startReader()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) Close() {
|
||||
c.shutdownExternal()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownExternal() {
|
||||
err := c.shutdownSetClosed(true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.shutdownInner()
|
||||
c.shutdownCleanup()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownInternal() {
|
||||
err := c.shutdownSetClosed(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.shutdownInner()
|
||||
|
||||
// defer final cleanup to allow this function to return
|
||||
// otherwise, a deadlock occurs where startReader triggers a shutdown and
|
||||
// must wait for itself to exit.
|
||||
go func() {
|
||||
c.shutdownCleanup()
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownInner() {
|
||||
c.shutdownSignalDone()
|
||||
c.shutdownLogStart()
|
||||
c.shutdownCloseSocket()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCleanup() {
|
||||
c.cleanupOnce.Do(func() {
|
||||
c.wg.Wait()
|
||||
c.shutdownCloseChannels()
|
||||
c.shutdownLogComplete()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownSetClosed(wait bool) error {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return NewConnectionError(ErrConnectionClosed)
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownSignalDone() {
|
||||
c.doneOnce.Do(func() {
|
||||
close(c.done)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownLogStart() {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closing")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCloseSocket() {
|
||||
if c.socket != nil {
|
||||
// force unblock of any network operations immediately
|
||||
expired := time.Now().Add(-1 * time.Minute)
|
||||
c.socket.SetReadDeadline(expired)
|
||||
c.socket.SetWriteDeadline(expired)
|
||||
|
||||
// close socket
|
||||
err := c.socket.Close()
|
||||
|
||||
if err != nil && c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCloseChannels() {
|
||||
close(c.incoming)
|
||||
close(c.errors)
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownLogComplete() {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) startReader() {
|
||||
c.wg.Go(func() {
|
||||
defer c.shutdownInternal()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
messageType, data, err := c.socket.ReadMessage()
|
||||
if err != nil {
|
||||
var wrappedErr error
|
||||
var closeErr *websocket.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
switch closeErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||
if c.logger != nil {
|
||||
c.logger.Info("connection closed by peer",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedClean, err)
|
||||
default:
|
||||
if c.logger != nil {
|
||||
c.logger.Error("unexpected close",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err)
|
||||
}
|
||||
} else {
|
||||
isLocalClose := false
|
||||
select {
|
||||
case <-c.done:
|
||||
isLocalClose = true
|
||||
default:
|
||||
}
|
||||
if c.logger != nil {
|
||||
if isLocalClose {
|
||||
c.logger.Debug("read loop terminated", "error", err)
|
||||
} else {
|
||||
c.logger.Error("read error", "error", err)
|
||||
}
|
||||
}
|
||||
wrappedErr = fmt.Errorf("%w: %w", ErrReadError, err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
case c.errors <- wrappedErr:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if messageType == websocket.TextMessage ||
|
||||
messageType == websocket.BinaryMessage {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case c.incoming <- data:
|
||||
c.incomingCount.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) setupPongHandler() {
|
||||
c.socket.SetPongHandler(func(appData string) error {
|
||||
select {
|
||||
case c.heartbeat <- struct{}{}:
|
||||
c.heartbeatCount.Add(1)
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) startPinger() {
|
||||
if c.config.PingInterval <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
c.wg.Go(func() {
|
||||
defer c.shutdownInternal()
|
||||
|
||||
// Calculate 10% jitter window
|
||||
jitter := c.config.PingInterval / 10
|
||||
|
||||
for {
|
||||
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
|
||||
next := c.config.PingInterval + offset
|
||||
timer := time.NewTimer(next)
|
||||
select {
|
||||
case <-c.done:
|
||||
timer.Stop()
|
||||
return
|
||||
case <-timer.C:
|
||||
deadline := time.Now().Add(c.config.WriteTimeout)
|
||||
if err := c.socket.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) Send(data []byte) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
@@ -437,6 +258,7 @@ func (c *Connection) Send(data []byte) error {
|
||||
return NewConnectionError(ErrConnectionClosed)
|
||||
}
|
||||
|
||||
// setup
|
||||
if c.config.WriteTimeout > 0 {
|
||||
if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||
if c.logger != nil {
|
||||
@@ -446,7 +268,10 @@ func (c *Connection) Send(data []byte) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
// send
|
||||
err := c.socket.WriteMessage(websocket.TextMessage, data)
|
||||
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("write error", "error", err)
|
||||
}
|
||||
@@ -489,3 +314,211 @@ func (c *Connection) Stats() ConnectionStats {
|
||||
func (c *Connection) SetDialer(d types.Dialer) {
|
||||
c.dialer = d
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Reader loop
|
||||
// -------------------------/
|
||||
|
||||
func (c *Connection) startReader() {
|
||||
defer c.shutdownInternal()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
messageType, data, err := c.socket.ReadMessage()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-c.done:
|
||||
case c.errors <- c.classifyCloseError(err):
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if messageType == websocket.TextMessage ||
|
||||
messageType == websocket.BinaryMessage {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case c.incoming <- data:
|
||||
c.incomingCount.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) classifyCloseError(err error) error {
|
||||
var classifiedError error
|
||||
var closeErr *websocket.CloseError
|
||||
|
||||
if errors.As(err, &closeErr) {
|
||||
switch closeErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||
if c.logger != nil {
|
||||
c.logger.Info("connection closed by peer",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedClean, err)
|
||||
|
||||
default:
|
||||
if c.logger != nil {
|
||||
c.logger.Error("unexpected close",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err)
|
||||
}
|
||||
|
||||
} else {
|
||||
isLocalClose := false
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
isLocalClose = true
|
||||
default:
|
||||
}
|
||||
|
||||
if c.logger != nil {
|
||||
if isLocalClose {
|
||||
c.logger.Debug("read loop terminated", "error", err)
|
||||
} else {
|
||||
c.logger.Error("read error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
classifiedError = fmt.Errorf("%w: %w", ErrReadError, err)
|
||||
}
|
||||
|
||||
return classifiedError
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Heartbeat Handling
|
||||
// -------------------------/
|
||||
|
||||
func (c *Connection) setupPongHandler() {
|
||||
c.socket.SetPongHandler(func(appData string) error {
|
||||
select {
|
||||
case c.heartbeat <- struct{}{}:
|
||||
c.heartbeatCount.Add(1)
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Connection) startPinger() {
|
||||
defer c.shutdownInternal()
|
||||
|
||||
// Calculate 10% jitter window
|
||||
jitter := c.config.PingInterval / 10
|
||||
|
||||
for {
|
||||
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
|
||||
next := c.config.PingInterval + offset
|
||||
timer := time.NewTimer(next)
|
||||
select {
|
||||
case <-c.done:
|
||||
timer.Stop()
|
||||
return
|
||||
case <-timer.C:
|
||||
deadline := time.Now().Add(c.config.WriteTimeout)
|
||||
err := c.socket.WriteControl(websocket.PingMessage, nil, deadline)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Shutdown
|
||||
// -------------------------/
|
||||
|
||||
func (c *Connection) Close() {
|
||||
c.shutdownExternal()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownExternal() {
|
||||
// set closed
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
// idempotent shutdown
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
c.mu.Unlock()
|
||||
|
||||
// perform shutdown
|
||||
c.shutdownInner()
|
||||
c.shutdownCleanup()
|
||||
}
|
||||
|
||||
// shutdownInternal defers final cleanup to allow it to return.
|
||||
// Otherwise, a deadlock occurs where startReader triggers a shutdown and
|
||||
// must wait for itself to exit.
|
||||
func (c *Connection) shutdownInternal() {
|
||||
// set closed
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
// idempotent shutdown
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
c.mu.Unlock()
|
||||
|
||||
// perform shutdown
|
||||
c.shutdownInner()
|
||||
|
||||
// defer cleanup to avoid deadlock
|
||||
go func() {
|
||||
c.shutdownCleanup()
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownInner() {
|
||||
c.doneOnce.Do(func() {
|
||||
close(c.done)
|
||||
})
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closing")
|
||||
}
|
||||
|
||||
if c.socket != nil {
|
||||
// force unblock of any network operations immediately
|
||||
expired := time.Now().Add(-1 * time.Minute)
|
||||
c.socket.SetReadDeadline(expired)
|
||||
c.socket.SetWriteDeadline(expired)
|
||||
|
||||
// close socket
|
||||
err := c.socket.Close()
|
||||
|
||||
if err != nil && c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) shutdownCleanup() {
|
||||
c.cleanupOnce.Do(func() {
|
||||
c.wg.Wait()
|
||||
|
||||
close(c.incoming)
|
||||
close(c.errors)
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -69,6 +69,7 @@ func AcquireSocket(
|
||||
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
|
||||
}
|
||||
|
||||
// dial
|
||||
socket, resp, err := dialer.DialContext(ctx, url, header)
|
||||
if err == nil {
|
||||
if logger != nil {
|
||||
@@ -77,7 +78,9 @@ func AcquireSocket(
|
||||
return socket, resp, nil
|
||||
}
|
||||
|
||||
// dial failed, retry
|
||||
if !retryMgr.ShouldRetry() {
|
||||
// retry policy expired
|
||||
if logger != nil {
|
||||
logger.Error("dial failed, max retries reached",
|
||||
"error", err,
|
||||
@@ -95,6 +98,7 @@ func AcquireSocket(
|
||||
"next_delay", delay)
|
||||
}
|
||||
|
||||
// context cancellable backoff
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
|
||||
@@ -9,10 +9,22 @@ import (
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
component "git.wisehodl.dev/jay/go-mana-component"
|
||||
"git.wisehodl.dev/jay/go-mana-component"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Worker
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// ---------------------------/
|
||||
// Types
|
||||
// -------------------------/
|
||||
|
||||
type WorkerFactory func(
|
||||
ctx context.Context,
|
||||
id string,
|
||||
handler slog.Handler,
|
||||
) (Worker, error)
|
||||
|
||||
type Worker interface {
|
||||
Start(pool PoolPlugin)
|
||||
@@ -37,19 +49,23 @@ type DefaultWorker struct {
|
||||
id string
|
||||
conn atomic.Pointer[transport.Connection]
|
||||
|
||||
heartbeat chan struct{}
|
||||
sendHeartbeat chan struct{}
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config *WorkerConfig
|
||||
handler slog.Handler
|
||||
logger *slog.Logger
|
||||
|
||||
processedCount *atomic.Uint64
|
||||
outgoingCount *atomic.Uint64
|
||||
restartCount *atomic.Uint64
|
||||
|
||||
config *WorkerConfig
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
handler slog.Handler
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Constructor
|
||||
// -------------------------/
|
||||
|
||||
func NewWorker(
|
||||
ctx context.Context,
|
||||
id string,
|
||||
@@ -77,21 +93,28 @@ func NewWorker(
|
||||
|
||||
wctx, wcancel := context.WithCancel(ctx)
|
||||
w := &DefaultWorker{
|
||||
id: id,
|
||||
config: config,
|
||||
heartbeat: make(chan struct{}),
|
||||
id: id,
|
||||
|
||||
sendHeartbeat: make(chan struct{}),
|
||||
|
||||
ctx: wctx,
|
||||
cancel: wcancel,
|
||||
config: config,
|
||||
handler: handler,
|
||||
logger: logger,
|
||||
|
||||
processedCount: &atomic.Uint64{},
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
restartCount: &atomic.Uint64{},
|
||||
ctx: wctx,
|
||||
cancel: wcancel,
|
||||
handler: handler,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Session
|
||||
// -------------------------/
|
||||
|
||||
func (w *DefaultWorker) Start(pool PoolPlugin) {
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("starting")
|
||||
@@ -114,71 +137,19 @@ func (w *DefaultWorker) Start(pool PoolPlugin) {
|
||||
}
|
||||
|
||||
func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
||||
newConn := make(chan *transport.Connection, 1)
|
||||
|
||||
var timer *time.Timer
|
||||
if w.config.KeepaliveTimeout > 0 {
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("keepalive: enabled", "timeout", w.config.KeepaliveTimeout)
|
||||
}
|
||||
timer = time.NewTimer(w.config.KeepaliveTimeout)
|
||||
defer timer.Stop()
|
||||
} else {
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("keepalive: disabled")
|
||||
}
|
||||
}
|
||||
|
||||
resetTimer := func() {
|
||||
if timer == nil {
|
||||
return
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(w.config.KeepaliveTimeout)
|
||||
}
|
||||
|
||||
timerC := func() <-chan time.Time {
|
||||
if timer == nil {
|
||||
return nil
|
||||
}
|
||||
return timer.C
|
||||
}
|
||||
|
||||
// setup dialer
|
||||
var dialCancel context.CancelFunc
|
||||
newConn := make(chan *transport.Connection, 1)
|
||||
spawnDialer := func() { dialCancel = w.spawnDialer(ctx, dialCancel, newConn, pool) }
|
||||
|
||||
spawnDial := func() {
|
||||
if dialCancel != nil {
|
||||
dialCancel()
|
||||
}
|
||||
var dialCtx context.Context
|
||||
dialCtx, dialCancel = context.WithCancel(ctx)
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("session: requesting connection")
|
||||
}
|
||||
go func() {
|
||||
conn, err := connect(w.id, dialCtx, pool, w.handler)
|
||||
if err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Warn("dialer: dial failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case newConn <- conn:
|
||||
case <-dialCtx.Done():
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
// setup heartbeat
|
||||
timer, timerC, heartbeat := w.setupHeartbeat()
|
||||
defer timer.Stop()
|
||||
|
||||
// main loop
|
||||
for {
|
||||
// spawn initial dial for this reconnect cycle
|
||||
spawnDial()
|
||||
spawnDialer()
|
||||
|
||||
// obtain new connection
|
||||
var conn *transport.Connection
|
||||
@@ -190,23 +161,26 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
||||
dialCancel()
|
||||
}
|
||||
return
|
||||
case <-w.heartbeat:
|
||||
resetTimer()
|
||||
case <-timerC():
|
||||
if w.logger != nil {
|
||||
w.logger.Info("keepalive: no activity observed")
|
||||
}
|
||||
timer.Reset(w.config.KeepaliveTimeout)
|
||||
spawnDial()
|
||||
|
||||
case conn = <-newConn:
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("session: connected")
|
||||
}
|
||||
break preConn
|
||||
|
||||
case <-w.sendHeartbeat:
|
||||
heartbeat()
|
||||
|
||||
case <-timerC():
|
||||
if w.logger != nil {
|
||||
w.logger.Info("keepalive: no activity observed")
|
||||
}
|
||||
timer.Reset(w.config.KeepaliveTimeout)
|
||||
spawnDialer()
|
||||
}
|
||||
}
|
||||
|
||||
// set up new connection
|
||||
// setup new connection
|
||||
w.conn.Store(conn)
|
||||
pool.Events <- PoolEvent{ID: w.id, Kind: EventConnected, At: time.Now()}
|
||||
|
||||
@@ -220,14 +194,7 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break conn_loop
|
||||
case <-w.heartbeat:
|
||||
resetTimer()
|
||||
case <-timerC():
|
||||
if w.logger != nil {
|
||||
w.logger.Info("keepalive: no activity observed")
|
||||
}
|
||||
timer.Reset(w.config.KeepaliveTimeout)
|
||||
break conn_loop
|
||||
|
||||
case data, ok := <-conn.Incoming():
|
||||
if !ok {
|
||||
if w.logger != nil {
|
||||
@@ -235,20 +202,34 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
||||
}
|
||||
break conn_loop
|
||||
}
|
||||
|
||||
pool.Inbox <- types.InboxMessage{
|
||||
ID: w.id,
|
||||
Data: data,
|
||||
ReceivedAt: time.Now(),
|
||||
}
|
||||
resetTimer()
|
||||
ID: w.id, Data: data, ReceivedAt: time.Now()}
|
||||
|
||||
pool.InboxCounter.Add(1)
|
||||
w.processedCount.Add(1)
|
||||
|
||||
heartbeat()
|
||||
|
||||
case <-conn.Heartbeat():
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("ping-pong heartbeat")
|
||||
}
|
||||
resetTimer()
|
||||
heartbeat()
|
||||
|
||||
case <-w.sendHeartbeat:
|
||||
heartbeat()
|
||||
|
||||
case <-timerC():
|
||||
if w.logger != nil {
|
||||
w.logger.Info("keepalive: no activity observed")
|
||||
}
|
||||
timer.Reset(w.config.KeepaliveTimeout)
|
||||
break conn_loop
|
||||
}
|
||||
}
|
||||
|
||||
// session ended
|
||||
conn.Close()
|
||||
|
||||
if w.logger != nil {
|
||||
@@ -272,6 +253,98 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *DefaultWorker) setupHeartbeat() (
|
||||
timer *time.Timer, timerC func() <-chan time.Time, heartbeat func(),
|
||||
) {
|
||||
if w.config.KeepaliveTimeout > 0 {
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("keepalive: enabled", "timeout", w.config.KeepaliveTimeout)
|
||||
}
|
||||
timer = time.NewTimer(w.config.KeepaliveTimeout)
|
||||
} else {
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("keepalive: disabled")
|
||||
}
|
||||
}
|
||||
|
||||
heartbeat = func() {
|
||||
if timer == nil {
|
||||
return
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(w.config.KeepaliveTimeout)
|
||||
}
|
||||
|
||||
timerC = func() <-chan time.Time {
|
||||
if timer == nil {
|
||||
return nil
|
||||
}
|
||||
return timer.C
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (w *DefaultWorker) spawnDialer(
|
||||
ctx context.Context,
|
||||
dialCancel context.CancelFunc,
|
||||
newConn chan<- *transport.Connection,
|
||||
pool PoolPlugin,
|
||||
) context.CancelFunc {
|
||||
if dialCancel != nil {
|
||||
dialCancel()
|
||||
}
|
||||
|
||||
dialCtx, dialCancel := context.WithCancel(ctx)
|
||||
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("session: requesting connection")
|
||||
}
|
||||
|
||||
go func() {
|
||||
conn, err := connect(w.id, dialCtx, pool, w.handler)
|
||||
|
||||
if err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Warn("dialer: dial failed", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case newConn <- conn:
|
||||
case <-dialCtx.Done():
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
return dialCancel
|
||||
}
|
||||
|
||||
func connect(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
pool PoolPlugin,
|
||||
handler slog.Handler,
|
||||
) (*transport.Connection, error) {
|
||||
conn, err := transport.NewConnection(ctx, id, pool.ConnectionConfig, handler)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.SetDialer(pool.Dialer)
|
||||
return conn, conn.Connect(ctx)
|
||||
}
|
||||
|
||||
// ---------------------------/
|
||||
// Methods
|
||||
// -------------------------/
|
||||
|
||||
func (w *DefaultWorker) Stop() {
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("shutting down")
|
||||
@@ -291,7 +364,7 @@ func (w *DefaultWorker) Send(data []byte) error {
|
||||
}
|
||||
|
||||
select {
|
||||
case w.heartbeat <- struct{}{}:
|
||||
case w.sendHeartbeat <- struct{}{}:
|
||||
case <-w.ctx.Done():
|
||||
}
|
||||
|
||||
@@ -324,18 +397,3 @@ func (w *DefaultWorker) Stats() WorkerStats {
|
||||
TotalSent: w.outgoingCount.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func connect(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
pool PoolPlugin,
|
||||
handler slog.Handler,
|
||||
) (*transport.Connection, error) {
|
||||
conn, err := transport.NewConnection(ctx, id, pool.ConnectionConfig, handler)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.SetDialer(pool.Dialer)
|
||||
return conn, conn.Connect(ctx)
|
||||
}
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
package honeybee
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWorkerSend(t *testing.T) {
|
||||
t.Run("data sent to mock socket", func(t *testing.T) {
|
||||
conn, _, _, outgoingData := setupTestConnection(t)
|
||||
defer conn.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
heartbeat := make(chan struct{})
|
||||
heartbeatCount := atomic.Int32{}
|
||||
|
||||
w := &DefaultWorker{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
heartbeat: heartbeat,
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
}
|
||||
w.conn.Store(conn)
|
||||
defer w.cancel()
|
||||
|
||||
go func() {
|
||||
for range heartbeat {
|
||||
heartbeatCount.Add(1)
|
||||
}
|
||||
}()
|
||||
|
||||
testData := []byte("hello")
|
||||
err := w.Send(testData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// at least one heartbeat was sent
|
||||
honeybeetest.Eventually(t, func() bool {
|
||||
return heartbeatCount.Load() >= 1
|
||||
}, "expected heartbeats")
|
||||
|
||||
// message was sent by the socket
|
||||
honeybeetest.Eventually(t, func() bool {
|
||||
select {
|
||||
case msg := <-outgoingData:
|
||||
return string(msg.Data) == "hello"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, "expected message")
|
||||
})
|
||||
|
||||
t.Run("sends one heartbeat per successful send", func(t *testing.T) {
|
||||
conn, _, _, _ := setupTestConnection(t)
|
||||
defer conn.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
heartbeat := make(chan struct{})
|
||||
heartbeatCount := atomic.Int32{}
|
||||
|
||||
w := &DefaultWorker{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
heartbeat: heartbeat,
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
}
|
||||
w.conn.Store(conn)
|
||||
defer w.cancel()
|
||||
|
||||
go func() {
|
||||
for range heartbeat {
|
||||
heartbeatCount.Add(1)
|
||||
}
|
||||
}()
|
||||
|
||||
const count = 3
|
||||
for i := range count {
|
||||
err := w.Send(fmt.Appendf(nil, "msg-%d", i))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
honeybeetest.Eventually(t, func() bool {
|
||||
return heartbeatCount.Load() == count
|
||||
}, "expected heartbeats")
|
||||
})
|
||||
|
||||
t.Run("returns error if connection is unavailable", func(t *testing.T) {
|
||||
// no connection available to worker
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
heartbeat := make(chan struct{})
|
||||
|
||||
w := &DefaultWorker{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
heartbeat: heartbeat,
|
||||
}
|
||||
defer w.cancel()
|
||||
|
||||
go func() {
|
||||
for range heartbeat {
|
||||
}
|
||||
}()
|
||||
|
||||
err := w.Send([]byte("hello"))
|
||||
assert.ErrorIs(t, err, ErrConnectionUnavailable)
|
||||
})
|
||||
}
|
||||
+113
-5
@@ -3,6 +3,7 @@ package honeybee
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
@@ -41,7 +42,7 @@ func makeWorker(t *testing.T, ctx context.Context, cancel context.CancelFunc) *D
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
config: config,
|
||||
heartbeat: make(chan struct{}),
|
||||
sendHeartbeat: make(chan struct{}),
|
||||
processedCount: &atomic.Uint64{},
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
restartCount: &atomic.Uint64{},
|
||||
@@ -134,7 +135,7 @@ func TestWorkerSession(t *testing.T) {
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
config: config,
|
||||
heartbeat: make(chan struct{}),
|
||||
sendHeartbeat: make(chan struct{}),
|
||||
processedCount: &atomic.Uint64{},
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
restartCount: &atomic.Uint64{},
|
||||
@@ -303,7 +304,7 @@ func TestWorkerSession(t *testing.T) {
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
config: config,
|
||||
heartbeat: make(chan struct{}),
|
||||
sendHeartbeat: make(chan struct{}),
|
||||
processedCount: &atomic.Uint64{},
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
restartCount: &atomic.Uint64{},
|
||||
@@ -365,7 +366,7 @@ func TestWorkerSession(t *testing.T) {
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
config: config,
|
||||
heartbeat: make(chan struct{}),
|
||||
sendHeartbeat: make(chan struct{}),
|
||||
processedCount: &atomic.Uint64{},
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
restartCount: &atomic.Uint64{},
|
||||
@@ -431,7 +432,7 @@ func TestWorkerSession(t *testing.T) {
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
config: config,
|
||||
heartbeat: make(chan struct{}),
|
||||
sendHeartbeat: make(chan struct{}),
|
||||
processedCount: &atomic.Uint64{},
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
restartCount: &atomic.Uint64{},
|
||||
@@ -638,3 +639,110 @@ func TestWorkerSession(t *testing.T) {
|
||||
}, "expected wg to drain after parent cancel")
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerSend(t *testing.T) {
|
||||
t.Run("data sent to mock socket", func(t *testing.T) {
|
||||
conn, _, _, outgoingData := setupTestConnection(t)
|
||||
defer conn.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
heartbeat := make(chan struct{})
|
||||
heartbeatCount := atomic.Int32{}
|
||||
|
||||
w := &DefaultWorker{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
sendHeartbeat: heartbeat,
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
}
|
||||
w.conn.Store(conn)
|
||||
defer w.cancel()
|
||||
|
||||
go func() {
|
||||
for range heartbeat {
|
||||
heartbeatCount.Add(1)
|
||||
}
|
||||
}()
|
||||
|
||||
testData := []byte("hello")
|
||||
err := w.Send(testData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// at least one heartbeat was sent
|
||||
honeybeetest.Eventually(t, func() bool {
|
||||
return heartbeatCount.Load() >= 1
|
||||
}, "expected heartbeats")
|
||||
|
||||
// message was sent by the socket
|
||||
honeybeetest.Eventually(t, func() bool {
|
||||
select {
|
||||
case msg := <-outgoingData:
|
||||
return string(msg.Data) == "hello"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, "expected message")
|
||||
})
|
||||
|
||||
t.Run("sends one heartbeat per successful send", func(t *testing.T) {
|
||||
conn, _, _, _ := setupTestConnection(t)
|
||||
defer conn.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
heartbeat := make(chan struct{})
|
||||
heartbeatCount := atomic.Int32{}
|
||||
|
||||
w := &DefaultWorker{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
sendHeartbeat: heartbeat,
|
||||
outgoingCount: &atomic.Uint64{},
|
||||
}
|
||||
w.conn.Store(conn)
|
||||
defer w.cancel()
|
||||
|
||||
go func() {
|
||||
for range heartbeat {
|
||||
heartbeatCount.Add(1)
|
||||
}
|
||||
}()
|
||||
|
||||
const count = 3
|
||||
for i := range count {
|
||||
err := w.Send(fmt.Appendf(nil, "msg-%d", i))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
honeybeetest.Eventually(t, func() bool {
|
||||
return heartbeatCount.Load() == count
|
||||
}, "expected heartbeats")
|
||||
})
|
||||
|
||||
t.Run("returns error if connection is unavailable", func(t *testing.T) {
|
||||
// no connection available to worker
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
heartbeat := make(chan struct{})
|
||||
|
||||
w := &DefaultWorker{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
id: "wss://test",
|
||||
sendHeartbeat: heartbeat,
|
||||
}
|
||||
defer w.cancel()
|
||||
|
||||
go func() {
|
||||
for range heartbeat {
|
||||
}
|
||||
}()
|
||||
|
||||
err := w.Send([]byte("hello"))
|
||||
assert.ErrorIs(t, err, ErrConnectionUnavailable)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user