cleanup and refactors
This commit is contained in:
@@ -1,21 +1,15 @@
|
|||||||
package honeybee
|
package honeybee
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||||
"log/slog"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Types
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
type WorkerFactory func(
|
|
||||||
ctx context.Context,
|
|
||||||
id string,
|
|
||||||
handler slog.Handler,
|
|
||||||
) (Worker, error)
|
|
||||||
|
|
||||||
// Pool Config
|
// Pool Config
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Types
|
||||||
|
|
||||||
type PoolConfig struct {
|
type PoolConfig struct {
|
||||||
InboxBufferSize int
|
InboxBufferSize int
|
||||||
@@ -27,6 +21,8 @@ type PoolConfig struct {
|
|||||||
|
|
||||||
type PoolOption func(*PoolConfig) error
|
type PoolOption func(*PoolConfig) error
|
||||||
|
|
||||||
|
// Constructor
|
||||||
|
|
||||||
func NewPoolConfig(options ...PoolOption) (*PoolConfig, error) {
|
func NewPoolConfig(options ...PoolOption) (*PoolConfig, error) {
|
||||||
conf := GetDefaultPoolConfig()
|
conf := GetDefaultPoolConfig()
|
||||||
if err := applyPoolOptions(conf, options...); err != nil {
|
if err := applyPoolOptions(conf, options...); err != nil {
|
||||||
@@ -57,6 +53,8 @@ func applyPoolOptions(config *PoolConfig, options ...PoolOption) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validation
|
||||||
|
|
||||||
func ValidatePoolConfig(config *PoolConfig) error {
|
func ValidatePoolConfig(config *PoolConfig) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@@ -84,6 +82,8 @@ func validateBufferSize(value int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Options
|
||||||
|
|
||||||
func WithInboxBufferSize(value int) PoolOption {
|
func WithInboxBufferSize(value int) PoolOption {
|
||||||
return func(c *PoolConfig) error {
|
return func(c *PoolConfig) error {
|
||||||
if err := validateBufferSize(value); err != nil {
|
if err := validateBufferSize(value); err != nil {
|
||||||
@@ -133,7 +133,11 @@ func WithWorkerFactory(wf WorkerFactory) PoolOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Worker Config
|
// Worker Config
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Types
|
||||||
|
|
||||||
type WorkerConfig struct {
|
type WorkerConfig struct {
|
||||||
KeepaliveTimeout time.Duration
|
KeepaliveTimeout time.Duration
|
||||||
@@ -142,6 +146,8 @@ type WorkerConfig struct {
|
|||||||
|
|
||||||
type WorkerOption func(*WorkerConfig) error
|
type WorkerOption func(*WorkerConfig) error
|
||||||
|
|
||||||
|
// Constructor
|
||||||
|
|
||||||
func NewWorkerConfig(options ...WorkerOption) (*WorkerConfig, error) {
|
func NewWorkerConfig(options ...WorkerOption) (*WorkerConfig, error) {
|
||||||
conf := GetDefaultWorkerConfig()
|
conf := GetDefaultWorkerConfig()
|
||||||
if err := applyWorkerOptions(conf, options...); err != nil {
|
if err := applyWorkerOptions(conf, options...); err != nil {
|
||||||
@@ -169,6 +175,8 @@ func applyWorkerOptions(config *WorkerConfig, options ...WorkerOption) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validation
|
||||||
|
|
||||||
func ValidateWorkerConfig(config *WorkerConfig) error {
|
func ValidateWorkerConfig(config *WorkerConfig) error {
|
||||||
err := validateKeepaliveTimeout(config.KeepaliveTimeout)
|
err := validateKeepaliveTimeout(config.KeepaliveTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -192,6 +200,8 @@ func validateReconnectDelay(value time.Duration) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Options
|
||||||
|
|
||||||
// When KeepaliveTimeout is set to zero, keepalive timeouts are disabled.
|
// When KeepaliveTimeout is set to zero, keepalive timeouts are disabled.
|
||||||
func WithKeepaliveTimeout(value time.Duration) WorkerOption {
|
func WithKeepaliveTimeout(value time.Duration) WorkerOption {
|
||||||
return func(c *WorkerConfig) error {
|
return func(c *WorkerConfig) error {
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Constants
|
// Constants
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TestTimeout = 2 * time.Second
|
TestTimeout = 2 * time.Second
|
||||||
@@ -18,7 +20,9 @@ const (
|
|||||||
NegativeTestTimeout = 100 * time.Millisecond
|
NegativeTestTimeout = 100 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Types
|
// Types
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
type MockIncomingData struct {
|
type MockIncomingData struct {
|
||||||
MsgType int
|
MsgType int
|
||||||
@@ -37,7 +41,9 @@ type ExpectedLog struct {
|
|||||||
Attrs map[string]any
|
Attrs map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Setup
|
// Setup
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
func SetupTestSocket(t *testing.T) (
|
func SetupTestSocket(t *testing.T) (
|
||||||
socket *MockSocket,
|
socket *MockSocket,
|
||||||
@@ -81,7 +87,9 @@ func SetupTestSocket(t *testing.T) (
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Helpers
|
// Helpers
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
func ExpectIncoming(t *testing.T, incoming <-chan []byte, expected []byte) {
|
func ExpectIncoming(t *testing.T, incoming <-chan []byte, expected []byte) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
@@ -126,7 +134,9 @@ func Never(t *testing.T, condition func() bool, msg string) {
|
|||||||
assert.Never(t, condition, NegativeTestTimeout, TestTick, msg)
|
assert.Never(t, condition, NegativeTestTimeout, TestTick, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Logging Helpers
|
// Logging Helpers
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
func AssertLogSequence(t *testing.T, records []slog.Record, expected []ExpectedLog) {
|
func AssertLogSequence(t *testing.T, records []slog.Record, expected []ExpectedLog) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|||||||
@@ -9,12 +9,16 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Re-exported types for consumer convenience
|
// ----------------------------------------------------------------------------
|
||||||
|
// Re-exports
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
type Socket = types.Socket
|
type Socket = types.Socket
|
||||||
type Dialer = types.Dialer
|
type Dialer = types.Dialer
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Dialer Mocks
|
// Dialer Mocks
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
type MockDialer struct {
|
type MockDialer struct {
|
||||||
DialContextFunc func(
|
DialContextFunc func(
|
||||||
@@ -28,7 +32,9 @@ func (m *MockDialer) DialContext(
|
|||||||
return m.DialContextFunc(ctx, url, h)
|
return m.DialContextFunc(ctx, url, h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Socket Mocks
|
// Socket Mocks
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
type MockSocket struct {
|
type MockSocket struct {
|
||||||
WriteMessageFunc func(int, []byte) error
|
WriteMessageFunc func(int, []byte) error
|
||||||
@@ -93,7 +99,9 @@ func (m *MockSocket) SetPongHandler(h func(s string) error) {
|
|||||||
m.SetPongHandlerFunc(h)
|
m.SetPongHandlerFunc(h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Logging mocks
|
// Logging mocks
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
type MockSlogHandler struct {
|
type MockSlogHandler struct {
|
||||||
records *[]slog.Record
|
records *[]slog.Record
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||||
component "git.wisehodl.dev/jay/go-mana-component"
|
"git.wisehodl.dev/jay/go-mana-component"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -19,7 +19,9 @@ type Dialer = types.Dialer
|
|||||||
|
|
||||||
var NormalizeURL = transport.NormalizeURL
|
var NormalizeURL = transport.NormalizeURL
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Types
|
// Types
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
type PoolEventKind string
|
type PoolEventKind string
|
||||||
|
|
||||||
@@ -58,7 +60,9 @@ type PoolPlugin struct {
|
|||||||
ConnectionConfig *transport.ConnectionConfig
|
ConnectionConfig *transport.ConnectionConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Pool
|
// Pool
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
id string
|
id string
|
||||||
@@ -66,24 +70,23 @@ type Peer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Pool struct {
|
type Pool struct {
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
peers map[string]*Peer
|
peers map[string]*Peer
|
||||||
inbox chan types.InboxMessage
|
inbox chan types.InboxMessage
|
||||||
events chan PoolEvent
|
events chan PoolEvent
|
||||||
|
closed bool
|
||||||
inboxCounter *atomic.Uint64
|
|
||||||
outgoingCount *atomic.Uint64
|
|
||||||
|
|
||||||
dialer types.Dialer
|
dialer types.Dialer
|
||||||
config *PoolConfig
|
config *PoolConfig
|
||||||
handler slog.Handler
|
handler slog.Handler
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
closed bool
|
|
||||||
|
inboxCounter *atomic.Uint64
|
||||||
|
outgoingCount *atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler,
|
func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler,
|
||||||
@@ -106,26 +109,29 @@ func NewPool(ctx context.Context, config *PoolConfig, handler slog.Handler,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pctx, cancel := context.WithCancel(component.MustNew(ctx, "honeybee", "pool"))
|
ctx, cancel := context.WithCancel(component.MustNew(ctx, "honeybee", "pool"))
|
||||||
|
|
||||||
var logger *slog.Logger
|
var logger *slog.Logger
|
||||||
if handler != nil {
|
if handler != nil {
|
||||||
c := component.FromContext(pctx)
|
c := component.FromContext(ctx)
|
||||||
logger = slog.New(handler).With(slog.Any("component", c))
|
logger = slog.New(handler).With(slog.Any("component", c))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Pool{
|
return &Pool{
|
||||||
ctx: pctx,
|
|
||||||
cancel: cancel,
|
|
||||||
peers: make(map[string]*Peer),
|
peers: make(map[string]*Peer),
|
||||||
inbox: make(chan types.InboxMessage, config.InboxBufferSize),
|
inbox: make(chan types.InboxMessage, config.InboxBufferSize),
|
||||||
events: make(chan PoolEvent, config.EventsBufferSize),
|
events: make(chan PoolEvent, config.EventsBufferSize),
|
||||||
inboxCounter: &atomic.Uint64{},
|
|
||||||
outgoingCount: &atomic.Uint64{},
|
|
||||||
dialer: transport.NewDialer(),
|
dialer: transport.NewDialer(),
|
||||||
config: config,
|
config: config,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
|
||||||
|
inboxCounter: &atomic.Uint64{},
|
||||||
|
outgoingCount: &atomic.Uint64{},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Connection Config
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Types
|
||||||
|
|
||||||
type CloseHandler func(code int, text string) error
|
type CloseHandler func(code int, text string) error
|
||||||
|
|
||||||
type ConnectionConfig struct {
|
type ConnectionConfig struct {
|
||||||
@@ -26,6 +32,8 @@ type RetryConfig struct {
|
|||||||
|
|
||||||
type ConnectionOption func(*ConnectionConfig) error
|
type ConnectionOption func(*ConnectionConfig) error
|
||||||
|
|
||||||
|
// Constructors
|
||||||
|
|
||||||
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
|
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
|
||||||
conf := GetDefaultConnectionConfig()
|
conf := GetDefaultConnectionConfig()
|
||||||
if err := applyConnectionOptions(conf, options...); err != nil {
|
if err := applyConnectionOptions(conf, options...); err != nil {
|
||||||
@@ -69,6 +77,8 @@ func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOptio
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validation
|
||||||
|
|
||||||
func ValidateConnectionConfig(config *ConnectionConfig) error {
|
func ValidateConnectionConfig(config *ConnectionConfig) error {
|
||||||
err := validateWriteTimeout(config.WriteTimeout)
|
err := validateWriteTimeout(config.WriteTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -153,6 +163,8 @@ func validateJitterFactor(value float64) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Options
|
||||||
|
|
||||||
func WithCloseHandler(handler CloseHandler) ConnectionOption {
|
func WithCloseHandler(handler CloseHandler) ConnectionOption {
|
||||||
return func(c *ConnectionConfig) error {
|
return func(c *ConnectionConfig) error {
|
||||||
c.CloseHandler = handler
|
c.CloseHandler = handler
|
||||||
|
|||||||
+263
-230
@@ -12,10 +12,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||||
component "git.wisehodl.dev/jay/go-mana-component"
|
"git.wisehodl.dev/jay/go-mana-component"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Types
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
type ConnectionState int
|
type ConnectionState int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -49,6 +53,14 @@ type ConnectionStats struct {
|
|||||||
TotalHeartbeats uint64
|
TotalHeartbeats uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Connection
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ---------------------------/
|
||||||
|
// Constructors
|
||||||
|
// -------------------------/
|
||||||
|
|
||||||
type Connection struct {
|
type Connection struct {
|
||||||
url *url.URL
|
url *url.URL
|
||||||
dialer types.Dialer
|
dialer types.Dialer
|
||||||
@@ -95,18 +107,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
|
|||||||
ctx = component.MustExtend(ctx, "connection")
|
ctx = component.MustExtend(ctx, "connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
var logger *slog.Logger
|
|
||||||
if handler != nil {
|
|
||||||
c := component.FromContext(ctx)
|
|
||||||
logger = slog.New(handler).With(slog.Any("component", c))
|
|
||||||
}
|
|
||||||
|
|
||||||
conn := &Connection{
|
conn := &Connection{
|
||||||
url: url,
|
url: url,
|
||||||
dialer: NewDialer(),
|
dialer: NewDialer(),
|
||||||
socket: nil,
|
socket: nil,
|
||||||
config: config,
|
config: config,
|
||||||
logger: logger,
|
|
||||||
incoming: make(chan []byte, config.IncomingBufferSize),
|
incoming: make(chan []byte, config.IncomingBufferSize),
|
||||||
heartbeat: make(chan struct{}, 1),
|
heartbeat: make(chan struct{}, 1),
|
||||||
errors: make(chan error, config.ErrorsBufferSize),
|
errors: make(chan error, config.ErrorsBufferSize),
|
||||||
@@ -117,6 +122,11 @@ func NewConnection(ctx context.Context, urlStr string, config *ConnectionConfig,
|
|||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if handler != nil {
|
||||||
|
comp := component.FromContext(ctx)
|
||||||
|
conn.logger = slog.New(handler).With(slog.Any("component", comp))
|
||||||
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,18 +151,11 @@ func NewConnectionFromSocket(
|
|||||||
ctx = component.MustExtend(ctx, "connection")
|
ctx = component.MustExtend(ctx, "connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
var logger *slog.Logger
|
|
||||||
if handler != nil {
|
|
||||||
c := component.FromContext(ctx)
|
|
||||||
logger = slog.New(handler).With(slog.Any("component", c))
|
|
||||||
}
|
|
||||||
|
|
||||||
conn := &Connection{
|
conn := &Connection{
|
||||||
url: nil,
|
url: nil,
|
||||||
dialer: nil,
|
dialer: nil,
|
||||||
socket: socket,
|
socket: socket,
|
||||||
config: config,
|
config: config,
|
||||||
logger: logger,
|
|
||||||
incoming: make(chan []byte, config.IncomingBufferSize),
|
incoming: make(chan []byte, config.IncomingBufferSize),
|
||||||
heartbeat: make(chan struct{}, 1),
|
heartbeat: make(chan struct{}, 1),
|
||||||
errors: make(chan error, config.ErrorsBufferSize),
|
errors: make(chan error, config.ErrorsBufferSize),
|
||||||
@@ -163,17 +166,31 @@ func NewConnectionFromSocket(
|
|||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if handler != nil {
|
||||||
|
comp := component.FromContext(ctx)
|
||||||
|
conn.logger = slog.New(handler).With(slog.Any("component", comp))
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize
|
||||||
if config.CloseHandler != nil {
|
if config.CloseHandler != nil {
|
||||||
socket.SetCloseHandler(config.CloseHandler)
|
socket.SetCloseHandler(config.CloseHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.setupPongHandler()
|
conn.setupPongHandler()
|
||||||
conn.startPinger()
|
|
||||||
conn.startReader()
|
if conn.config.PingInterval > 0 {
|
||||||
|
conn.wg.Go(conn.startPinger)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.wg.Go(conn.startReader)
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------/
|
||||||
|
// Methods
|
||||||
|
// -------------------------/
|
||||||
|
|
||||||
func (c *Connection) Connect(ctx context.Context) error {
|
func (c *Connection) Connect(ctx context.Context) error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
@@ -186,17 +203,20 @@ func (c *Connection) Connect(ctx context.Context) error {
|
|||||||
return NewConnectionError(ErrConnectionClosed)
|
return NewConnectionError(ErrConnectionClosed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// begin connecting
|
||||||
if c.logger != nil {
|
if c.logger != nil {
|
||||||
c.logger.Debug("connecting")
|
c.logger.Debug("connecting")
|
||||||
}
|
}
|
||||||
|
|
||||||
c.state = StateConnecting
|
c.state = StateConnecting
|
||||||
|
|
||||||
|
// obtain socket
|
||||||
retryMgr := NewRetryManager(c.config.Retry)
|
retryMgr := NewRetryManager(c.config.Retry)
|
||||||
socket, _, err := AcquireSocket(
|
socket, _, err := AcquireSocket(
|
||||||
ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger)
|
ctx, retryMgr, c.dialer, c.url.String(), c.config.RequestHeader, c.logger)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// socket acquisition failed
|
||||||
c.state = StateDisconnected
|
c.state = StateDisconnected
|
||||||
if c.logger != nil {
|
if c.logger != nil {
|
||||||
c.logger.Error("connection failed", "error", err)
|
c.logger.Error("connection failed", "error", err)
|
||||||
@@ -204,231 +224,32 @@ func (c *Connection) Connect(ctx context.Context) error {
|
|||||||
return NewConnectionError(err)
|
return NewConnectionError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// got socket
|
||||||
c.socket = socket
|
c.socket = socket
|
||||||
c.state = StateConnected
|
|
||||||
|
|
||||||
|
// initialize
|
||||||
if c.config.CloseHandler != nil {
|
if c.config.CloseHandler != nil {
|
||||||
c.socket.SetCloseHandler(c.config.CloseHandler)
|
c.socket.SetCloseHandler(c.config.CloseHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.setupPongHandler()
|
||||||
|
|
||||||
|
if c.config.PingInterval > 0 {
|
||||||
|
c.wg.Go(c.startPinger)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.wg.Go(c.startReader)
|
||||||
|
|
||||||
|
// connected
|
||||||
|
c.state = StateConnected
|
||||||
|
|
||||||
if c.logger != nil {
|
if c.logger != nil {
|
||||||
c.logger.Info("connected")
|
c.logger.Info("connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
c.setupPongHandler()
|
|
||||||
c.startPinger()
|
|
||||||
c.startReader()
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) Close() {
|
|
||||||
c.shutdownExternal()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) shutdownExternal() {
|
|
||||||
err := c.shutdownSetClosed(true)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.shutdownInner()
|
|
||||||
c.shutdownCleanup()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) shutdownInternal() {
|
|
||||||
err := c.shutdownSetClosed(false)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.shutdownInner()
|
|
||||||
|
|
||||||
// defer final cleanup to allow this function to return
|
|
||||||
// otherwise, a deadlock occurs where startReader triggers a shutdown and
|
|
||||||
// must wait for itself to exit.
|
|
||||||
go func() {
|
|
||||||
c.shutdownCleanup()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) shutdownInner() {
|
|
||||||
c.shutdownSignalDone()
|
|
||||||
c.shutdownLogStart()
|
|
||||||
c.shutdownCloseSocket()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) shutdownCleanup() {
|
|
||||||
c.cleanupOnce.Do(func() {
|
|
||||||
c.wg.Wait()
|
|
||||||
c.shutdownCloseChannels()
|
|
||||||
c.shutdownLogComplete()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) shutdownSetClosed(wait bool) error {
|
|
||||||
c.mu.Lock()
|
|
||||||
if c.closed {
|
|
||||||
c.mu.Unlock()
|
|
||||||
return NewConnectionError(ErrConnectionClosed)
|
|
||||||
}
|
|
||||||
c.closed = true
|
|
||||||
c.state = StateClosed
|
|
||||||
c.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) shutdownSignalDone() {
|
|
||||||
c.doneOnce.Do(func() {
|
|
||||||
close(c.done)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) shutdownLogStart() {
|
|
||||||
if c.logger != nil {
|
|
||||||
c.logger.Info("closing")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) shutdownCloseSocket() {
|
|
||||||
if c.socket != nil {
|
|
||||||
// force unblock of any network operations immediately
|
|
||||||
expired := time.Now().Add(-1 * time.Minute)
|
|
||||||
c.socket.SetReadDeadline(expired)
|
|
||||||
c.socket.SetWriteDeadline(expired)
|
|
||||||
|
|
||||||
// close socket
|
|
||||||
err := c.socket.Close()
|
|
||||||
|
|
||||||
if err != nil && c.logger != nil {
|
|
||||||
c.logger.Error("socket close failed", "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) shutdownCloseChannels() {
|
|
||||||
close(c.incoming)
|
|
||||||
close(c.errors)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) shutdownLogComplete() {
|
|
||||||
if c.logger != nil {
|
|
||||||
c.logger.Info("closed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) startReader() {
|
|
||||||
c.wg.Go(func() {
|
|
||||||
defer c.shutdownInternal()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
messageType, data, err := c.socket.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
var wrappedErr error
|
|
||||||
var closeErr *websocket.CloseError
|
|
||||||
if errors.As(err, &closeErr) {
|
|
||||||
switch closeErr.Code {
|
|
||||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
|
||||||
if c.logger != nil {
|
|
||||||
c.logger.Info("connection closed by peer",
|
|
||||||
"code", closeErr.Code,
|
|
||||||
"text", closeErr.Text,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedClean, err)
|
|
||||||
default:
|
|
||||||
if c.logger != nil {
|
|
||||||
c.logger.Error("unexpected close",
|
|
||||||
"code", closeErr.Code,
|
|
||||||
"text", closeErr.Text,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
isLocalClose := false
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
isLocalClose = true
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if c.logger != nil {
|
|
||||||
if isLocalClose {
|
|
||||||
c.logger.Debug("read loop terminated", "error", err)
|
|
||||||
} else {
|
|
||||||
c.logger.Error("read error", "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
wrappedErr = fmt.Errorf("%w: %w", ErrReadError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
case c.errors <- wrappedErr:
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if messageType == websocket.TextMessage ||
|
|
||||||
messageType == websocket.BinaryMessage {
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
return
|
|
||||||
case c.incoming <- data:
|
|
||||||
c.incomingCount.Add(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) setupPongHandler() {
|
|
||||||
c.socket.SetPongHandler(func(appData string) error {
|
|
||||||
select {
|
|
||||||
case c.heartbeat <- struct{}{}:
|
|
||||||
c.heartbeatCount.Add(1)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) startPinger() {
|
|
||||||
if c.config.PingInterval <= 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.wg.Go(func() {
|
|
||||||
defer c.shutdownInternal()
|
|
||||||
|
|
||||||
// Calculate 10% jitter window
|
|
||||||
jitter := c.config.PingInterval / 10
|
|
||||||
|
|
||||||
for {
|
|
||||||
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
|
|
||||||
next := c.config.PingInterval + offset
|
|
||||||
timer := time.NewTimer(next)
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
timer.Stop()
|
|
||||||
return
|
|
||||||
case <-timer.C:
|
|
||||||
deadline := time.Now().Add(c.config.WriteTimeout)
|
|
||||||
if err := c.socket.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) Send(data []byte) error {
|
func (c *Connection) Send(data []byte) error {
|
||||||
c.writeMu.Lock()
|
c.writeMu.Lock()
|
||||||
defer c.writeMu.Unlock()
|
defer c.writeMu.Unlock()
|
||||||
@@ -437,6 +258,7 @@ func (c *Connection) Send(data []byte) error {
|
|||||||
return NewConnectionError(ErrConnectionClosed)
|
return NewConnectionError(ErrConnectionClosed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setup
|
||||||
if c.config.WriteTimeout > 0 {
|
if c.config.WriteTimeout > 0 {
|
||||||
if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||||
if c.logger != nil {
|
if c.logger != nil {
|
||||||
@@ -446,7 +268,10 @@ func (c *Connection) Send(data []byte) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil {
|
// send
|
||||||
|
err := c.socket.WriteMessage(websocket.TextMessage, data)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
if c.logger != nil {
|
if c.logger != nil {
|
||||||
c.logger.Error("write error", "error", err)
|
c.logger.Error("write error", "error", err)
|
||||||
}
|
}
|
||||||
@@ -489,3 +314,211 @@ func (c *Connection) Stats() ConnectionStats {
|
|||||||
func (c *Connection) SetDialer(d types.Dialer) {
|
func (c *Connection) SetDialer(d types.Dialer) {
|
||||||
c.dialer = d
|
c.dialer = d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------/
|
||||||
|
// Reader loop
|
||||||
|
// -------------------------/
|
||||||
|
|
||||||
|
func (c *Connection) startReader() {
|
||||||
|
defer c.shutdownInternal()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
messageType, data, err := c.socket.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
case c.errors <- c.classifyCloseError(err):
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if messageType == websocket.TextMessage ||
|
||||||
|
messageType == websocket.BinaryMessage {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
case c.incoming <- data:
|
||||||
|
c.incomingCount.Add(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) classifyCloseError(err error) error {
|
||||||
|
var classifiedError error
|
||||||
|
var closeErr *websocket.CloseError
|
||||||
|
|
||||||
|
if errors.As(err, &closeErr) {
|
||||||
|
switch closeErr.Code {
|
||||||
|
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Info("connection closed by peer",
|
||||||
|
"code", closeErr.Code,
|
||||||
|
"text", closeErr.Text,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedClean, err)
|
||||||
|
|
||||||
|
default:
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Error("unexpected close",
|
||||||
|
"code", closeErr.Code,
|
||||||
|
"text", closeErr.Text,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
classifiedError = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
isLocalClose := false
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
isLocalClose = true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.logger != nil {
|
||||||
|
if isLocalClose {
|
||||||
|
c.logger.Debug("read loop terminated", "error", err)
|
||||||
|
} else {
|
||||||
|
c.logger.Error("read error", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
classifiedError = fmt.Errorf("%w: %w", ErrReadError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return classifiedError
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------/
|
||||||
|
// Heartbeat Handling
|
||||||
|
// -------------------------/
|
||||||
|
|
||||||
|
func (c *Connection) setupPongHandler() {
|
||||||
|
c.socket.SetPongHandler(func(appData string) error {
|
||||||
|
select {
|
||||||
|
case c.heartbeat <- struct{}{}:
|
||||||
|
c.heartbeatCount.Add(1)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) startPinger() {
|
||||||
|
defer c.shutdownInternal()
|
||||||
|
|
||||||
|
// Calculate 10% jitter window
|
||||||
|
jitter := c.config.PingInterval / 10
|
||||||
|
|
||||||
|
for {
|
||||||
|
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
|
||||||
|
next := c.config.PingInterval + offset
|
||||||
|
timer := time.NewTimer(next)
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
timer.Stop()
|
||||||
|
return
|
||||||
|
case <-timer.C:
|
||||||
|
deadline := time.Now().Add(c.config.WriteTimeout)
|
||||||
|
err := c.socket.WriteControl(websocket.PingMessage, nil, deadline)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------/
|
||||||
|
// Shutdown
|
||||||
|
// -------------------------/
|
||||||
|
|
||||||
|
func (c *Connection) Close() {
|
||||||
|
c.shutdownExternal()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownExternal() {
|
||||||
|
// set closed
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
// idempotent shutdown
|
||||||
|
c.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
c.state = StateClosed
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
// perform shutdown
|
||||||
|
c.shutdownInner()
|
||||||
|
c.shutdownCleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdownInternal defers final cleanup to allow it to return.
|
||||||
|
// Otherwise, a deadlock occurs where startReader triggers a shutdown and
|
||||||
|
// must wait for itself to exit.
|
||||||
|
func (c *Connection) shutdownInternal() {
|
||||||
|
// set closed
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
// idempotent shutdown
|
||||||
|
c.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
c.state = StateClosed
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
// perform shutdown
|
||||||
|
c.shutdownInner()
|
||||||
|
|
||||||
|
// defer cleanup to avoid deadlock
|
||||||
|
go func() {
|
||||||
|
c.shutdownCleanup()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownInner() {
|
||||||
|
c.doneOnce.Do(func() {
|
||||||
|
close(c.done)
|
||||||
|
})
|
||||||
|
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Info("closing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.socket != nil {
|
||||||
|
// force unblock of any network operations immediately
|
||||||
|
expired := time.Now().Add(-1 * time.Minute)
|
||||||
|
c.socket.SetReadDeadline(expired)
|
||||||
|
c.socket.SetWriteDeadline(expired)
|
||||||
|
|
||||||
|
// close socket
|
||||||
|
err := c.socket.Close()
|
||||||
|
|
||||||
|
if err != nil && c.logger != nil {
|
||||||
|
c.logger.Error("socket close failed", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) shutdownCleanup() {
|
||||||
|
c.cleanupOnce.Do(func() {
|
||||||
|
c.wg.Wait()
|
||||||
|
|
||||||
|
close(c.incoming)
|
||||||
|
close(c.errors)
|
||||||
|
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Info("closed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ func AcquireSocket(
|
|||||||
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
|
logger.Debug("dialing", "attempt", retryMgr.RetryCount()+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dial
|
||||||
socket, resp, err := dialer.DialContext(ctx, url, header)
|
socket, resp, err := dialer.DialContext(ctx, url, header)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
@@ -77,7 +78,9 @@ func AcquireSocket(
|
|||||||
return socket, resp, nil
|
return socket, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dial failed, retry
|
||||||
if !retryMgr.ShouldRetry() {
|
if !retryMgr.ShouldRetry() {
|
||||||
|
// retry policy expired
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
logger.Error("dial failed, max retries reached",
|
logger.Error("dial failed, max retries reached",
|
||||||
"error", err,
|
"error", err,
|
||||||
@@ -95,6 +98,7 @@ func AcquireSocket(
|
|||||||
"next_delay", delay)
|
"next_delay", delay)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// context cancellable backoff
|
||||||
select {
|
select {
|
||||||
case <-time.After(delay):
|
case <-time.After(delay):
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|||||||
@@ -9,10 +9,22 @@ import (
|
|||||||
|
|
||||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||||
component "git.wisehodl.dev/jay/go-mana-component"
|
"git.wisehodl.dev/jay/go-mana-component"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
// Worker
|
// Worker
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// ---------------------------/
|
||||||
|
// Types
|
||||||
|
// -------------------------/
|
||||||
|
|
||||||
|
type WorkerFactory func(
|
||||||
|
ctx context.Context,
|
||||||
|
id string,
|
||||||
|
handler slog.Handler,
|
||||||
|
) (Worker, error)
|
||||||
|
|
||||||
type Worker interface {
|
type Worker interface {
|
||||||
Start(pool PoolPlugin)
|
Start(pool PoolPlugin)
|
||||||
@@ -37,19 +49,23 @@ type DefaultWorker struct {
|
|||||||
id string
|
id string
|
||||||
conn atomic.Pointer[transport.Connection]
|
conn atomic.Pointer[transport.Connection]
|
||||||
|
|
||||||
heartbeat chan struct{}
|
sendHeartbeat chan struct{}
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
config *WorkerConfig
|
||||||
|
handler slog.Handler
|
||||||
|
logger *slog.Logger
|
||||||
|
|
||||||
processedCount *atomic.Uint64
|
processedCount *atomic.Uint64
|
||||||
outgoingCount *atomic.Uint64
|
outgoingCount *atomic.Uint64
|
||||||
restartCount *atomic.Uint64
|
restartCount *atomic.Uint64
|
||||||
|
|
||||||
config *WorkerConfig
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
handler slog.Handler
|
|
||||||
logger *slog.Logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------/
|
||||||
|
// Constructor
|
||||||
|
// -------------------------/
|
||||||
|
|
||||||
func NewWorker(
|
func NewWorker(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
id string,
|
id string,
|
||||||
@@ -78,20 +94,27 @@ func NewWorker(
|
|||||||
wctx, wcancel := context.WithCancel(ctx)
|
wctx, wcancel := context.WithCancel(ctx)
|
||||||
w := &DefaultWorker{
|
w := &DefaultWorker{
|
||||||
id: id,
|
id: id,
|
||||||
|
|
||||||
|
sendHeartbeat: make(chan struct{}),
|
||||||
|
|
||||||
|
ctx: wctx,
|
||||||
|
cancel: wcancel,
|
||||||
config: config,
|
config: config,
|
||||||
heartbeat: make(chan struct{}),
|
handler: handler,
|
||||||
|
logger: logger,
|
||||||
|
|
||||||
processedCount: &atomic.Uint64{},
|
processedCount: &atomic.Uint64{},
|
||||||
outgoingCount: &atomic.Uint64{},
|
outgoingCount: &atomic.Uint64{},
|
||||||
restartCount: &atomic.Uint64{},
|
restartCount: &atomic.Uint64{},
|
||||||
ctx: wctx,
|
|
||||||
cancel: wcancel,
|
|
||||||
handler: handler,
|
|
||||||
logger: logger,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return w, nil
|
return w, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------/
|
||||||
|
// Session
|
||||||
|
// -------------------------/
|
||||||
|
|
||||||
func (w *DefaultWorker) Start(pool PoolPlugin) {
|
func (w *DefaultWorker) Start(pool PoolPlugin) {
|
||||||
if w.logger != nil {
|
if w.logger != nil {
|
||||||
w.logger.Debug("starting")
|
w.logger.Debug("starting")
|
||||||
@@ -114,71 +137,19 @@ func (w *DefaultWorker) Start(pool PoolPlugin) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
||||||
newConn := make(chan *transport.Connection, 1)
|
// setup dialer
|
||||||
|
|
||||||
var timer *time.Timer
|
|
||||||
if w.config.KeepaliveTimeout > 0 {
|
|
||||||
if w.logger != nil {
|
|
||||||
w.logger.Debug("keepalive: enabled", "timeout", w.config.KeepaliveTimeout)
|
|
||||||
}
|
|
||||||
timer = time.NewTimer(w.config.KeepaliveTimeout)
|
|
||||||
defer timer.Stop()
|
|
||||||
} else {
|
|
||||||
if w.logger != nil {
|
|
||||||
w.logger.Debug("keepalive: disabled")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resetTimer := func() {
|
|
||||||
if timer == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !timer.Stop() {
|
|
||||||
select {
|
|
||||||
case <-timer.C:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
timer.Reset(w.config.KeepaliveTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
timerC := func() <-chan time.Time {
|
|
||||||
if timer == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return timer.C
|
|
||||||
}
|
|
||||||
|
|
||||||
var dialCancel context.CancelFunc
|
var dialCancel context.CancelFunc
|
||||||
|
newConn := make(chan *transport.Connection, 1)
|
||||||
|
spawnDialer := func() { dialCancel = w.spawnDialer(ctx, dialCancel, newConn, pool) }
|
||||||
|
|
||||||
spawnDial := func() {
|
// setup heartbeat
|
||||||
if dialCancel != nil {
|
timer, timerC, heartbeat := w.setupHeartbeat()
|
||||||
dialCancel()
|
defer timer.Stop()
|
||||||
}
|
|
||||||
var dialCtx context.Context
|
|
||||||
dialCtx, dialCancel = context.WithCancel(ctx)
|
|
||||||
if w.logger != nil {
|
|
||||||
w.logger.Debug("session: requesting connection")
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
conn, err := connect(w.id, dialCtx, pool, w.handler)
|
|
||||||
if err != nil {
|
|
||||||
if w.logger != nil {
|
|
||||||
w.logger.Warn("dialer: dial failed")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case newConn <- conn:
|
|
||||||
case <-dialCtx.Done():
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// main loop
|
||||||
for {
|
for {
|
||||||
// spawn initial dial for this reconnect cycle
|
// spawn initial dial for this reconnect cycle
|
||||||
spawnDial()
|
spawnDialer()
|
||||||
|
|
||||||
// obtain new connection
|
// obtain new connection
|
||||||
var conn *transport.Connection
|
var conn *transport.Connection
|
||||||
@@ -190,19 +161,22 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
|||||||
dialCancel()
|
dialCancel()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-w.heartbeat:
|
|
||||||
resetTimer()
|
|
||||||
case <-timerC():
|
|
||||||
if w.logger != nil {
|
|
||||||
w.logger.Info("keepalive: no activity observed")
|
|
||||||
}
|
|
||||||
timer.Reset(w.config.KeepaliveTimeout)
|
|
||||||
spawnDial()
|
|
||||||
case conn = <-newConn:
|
case conn = <-newConn:
|
||||||
if w.logger != nil {
|
if w.logger != nil {
|
||||||
w.logger.Debug("session: connected")
|
w.logger.Debug("session: connected")
|
||||||
}
|
}
|
||||||
break preConn
|
break preConn
|
||||||
|
|
||||||
|
case <-w.sendHeartbeat:
|
||||||
|
heartbeat()
|
||||||
|
|
||||||
|
case <-timerC():
|
||||||
|
if w.logger != nil {
|
||||||
|
w.logger.Info("keepalive: no activity observed")
|
||||||
|
}
|
||||||
|
timer.Reset(w.config.KeepaliveTimeout)
|
||||||
|
spawnDialer()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,14 +194,7 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
break conn_loop
|
break conn_loop
|
||||||
case <-w.heartbeat:
|
|
||||||
resetTimer()
|
|
||||||
case <-timerC():
|
|
||||||
if w.logger != nil {
|
|
||||||
w.logger.Info("keepalive: no activity observed")
|
|
||||||
}
|
|
||||||
timer.Reset(w.config.KeepaliveTimeout)
|
|
||||||
break conn_loop
|
|
||||||
case data, ok := <-conn.Incoming():
|
case data, ok := <-conn.Incoming():
|
||||||
if !ok {
|
if !ok {
|
||||||
if w.logger != nil {
|
if w.logger != nil {
|
||||||
@@ -235,20 +202,34 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
|||||||
}
|
}
|
||||||
break conn_loop
|
break conn_loop
|
||||||
}
|
}
|
||||||
|
|
||||||
pool.Inbox <- types.InboxMessage{
|
pool.Inbox <- types.InboxMessage{
|
||||||
ID: w.id,
|
ID: w.id, Data: data, ReceivedAt: time.Now()}
|
||||||
Data: data,
|
|
||||||
ReceivedAt: time.Now(),
|
pool.InboxCounter.Add(1)
|
||||||
}
|
w.processedCount.Add(1)
|
||||||
resetTimer()
|
|
||||||
|
heartbeat()
|
||||||
|
|
||||||
case <-conn.Heartbeat():
|
case <-conn.Heartbeat():
|
||||||
if w.logger != nil {
|
if w.logger != nil {
|
||||||
w.logger.Debug("ping-pong heartbeat")
|
w.logger.Debug("ping-pong heartbeat")
|
||||||
}
|
}
|
||||||
resetTimer()
|
heartbeat()
|
||||||
|
|
||||||
|
case <-w.sendHeartbeat:
|
||||||
|
heartbeat()
|
||||||
|
|
||||||
|
case <-timerC():
|
||||||
|
if w.logger != nil {
|
||||||
|
w.logger.Info("keepalive: no activity observed")
|
||||||
|
}
|
||||||
|
timer.Reset(w.config.KeepaliveTimeout)
|
||||||
|
break conn_loop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// session ended
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|
||||||
if w.logger != nil {
|
if w.logger != nil {
|
||||||
@@ -272,6 +253,98 @@ func (w *DefaultWorker) runSession(ctx context.Context, pool PoolPlugin) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *DefaultWorker) setupHeartbeat() (
|
||||||
|
timer *time.Timer, timerC func() <-chan time.Time, heartbeat func(),
|
||||||
|
) {
|
||||||
|
if w.config.KeepaliveTimeout > 0 {
|
||||||
|
if w.logger != nil {
|
||||||
|
w.logger.Debug("keepalive: enabled", "timeout", w.config.KeepaliveTimeout)
|
||||||
|
}
|
||||||
|
timer = time.NewTimer(w.config.KeepaliveTimeout)
|
||||||
|
} else {
|
||||||
|
if w.logger != nil {
|
||||||
|
w.logger.Debug("keepalive: disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
heartbeat = func() {
|
||||||
|
if timer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timer.Reset(w.config.KeepaliveTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
timerC = func() <-chan time.Time {
|
||||||
|
if timer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return timer.C
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DefaultWorker) spawnDialer(
|
||||||
|
ctx context.Context,
|
||||||
|
dialCancel context.CancelFunc,
|
||||||
|
newConn chan<- *transport.Connection,
|
||||||
|
pool PoolPlugin,
|
||||||
|
) context.CancelFunc {
|
||||||
|
if dialCancel != nil {
|
||||||
|
dialCancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
dialCtx, dialCancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
if w.logger != nil {
|
||||||
|
w.logger.Debug("session: requesting connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := connect(w.id, dialCtx, pool, w.handler)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if w.logger != nil {
|
||||||
|
w.logger.Warn("dialer: dial failed", "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case newConn <- conn:
|
||||||
|
case <-dialCtx.Done():
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return dialCancel
|
||||||
|
}
|
||||||
|
|
||||||
|
func connect(
|
||||||
|
id string,
|
||||||
|
ctx context.Context,
|
||||||
|
pool PoolPlugin,
|
||||||
|
handler slog.Handler,
|
||||||
|
) (*transport.Connection, error) {
|
||||||
|
conn, err := transport.NewConnection(ctx, id, pool.ConnectionConfig, handler)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.SetDialer(pool.Dialer)
|
||||||
|
return conn, conn.Connect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------/
|
||||||
|
// Methods
|
||||||
|
// -------------------------/
|
||||||
|
|
||||||
func (w *DefaultWorker) Stop() {
|
func (w *DefaultWorker) Stop() {
|
||||||
if w.logger != nil {
|
if w.logger != nil {
|
||||||
w.logger.Debug("shutting down")
|
w.logger.Debug("shutting down")
|
||||||
@@ -291,7 +364,7 @@ func (w *DefaultWorker) Send(data []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case w.heartbeat <- struct{}{}:
|
case w.sendHeartbeat <- struct{}{}:
|
||||||
case <-w.ctx.Done():
|
case <-w.ctx.Done():
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,18 +397,3 @@ func (w *DefaultWorker) Stats() WorkerStats {
|
|||||||
TotalSent: w.outgoingCount.Load(),
|
TotalSent: w.outgoingCount.Load(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(
|
|
||||||
id string,
|
|
||||||
ctx context.Context,
|
|
||||||
pool PoolPlugin,
|
|
||||||
handler slog.Handler,
|
|
||||||
) (*transport.Connection, error) {
|
|
||||||
conn, err := transport.NewConnection(ctx, id, pool.ConnectionConfig, handler)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.SetDialer(pool.Dialer)
|
|
||||||
return conn, conn.Connect(ctx)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,117 +0,0 @@
|
|||||||
package honeybee
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestWorkerSend(t *testing.T) {
|
|
||||||
t.Run("data sent to mock socket", func(t *testing.T) {
|
|
||||||
conn, _, _, outgoingData := setupTestConnection(t)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
heartbeat := make(chan struct{})
|
|
||||||
heartbeatCount := atomic.Int32{}
|
|
||||||
|
|
||||||
w := &DefaultWorker{
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
id: "wss://test",
|
|
||||||
heartbeat: heartbeat,
|
|
||||||
outgoingCount: &atomic.Uint64{},
|
|
||||||
}
|
|
||||||
w.conn.Store(conn)
|
|
||||||
defer w.cancel()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for range heartbeat {
|
|
||||||
heartbeatCount.Add(1)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
testData := []byte("hello")
|
|
||||||
err := w.Send(testData)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
// at least one heartbeat was sent
|
|
||||||
honeybeetest.Eventually(t, func() bool {
|
|
||||||
return heartbeatCount.Load() >= 1
|
|
||||||
}, "expected heartbeats")
|
|
||||||
|
|
||||||
// message was sent by the socket
|
|
||||||
honeybeetest.Eventually(t, func() bool {
|
|
||||||
select {
|
|
||||||
case msg := <-outgoingData:
|
|
||||||
return string(msg.Data) == "hello"
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}, "expected message")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("sends one heartbeat per successful send", func(t *testing.T) {
|
|
||||||
conn, _, _, _ := setupTestConnection(t)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
heartbeat := make(chan struct{})
|
|
||||||
heartbeatCount := atomic.Int32{}
|
|
||||||
|
|
||||||
w := &DefaultWorker{
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
id: "wss://test",
|
|
||||||
heartbeat: heartbeat,
|
|
||||||
outgoingCount: &atomic.Uint64{},
|
|
||||||
}
|
|
||||||
w.conn.Store(conn)
|
|
||||||
defer w.cancel()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for range heartbeat {
|
|
||||||
heartbeatCount.Add(1)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
const count = 3
|
|
||||||
for i := range count {
|
|
||||||
err := w.Send(fmt.Appendf(nil, "msg-%d", i))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
honeybeetest.Eventually(t, func() bool {
|
|
||||||
return heartbeatCount.Load() == count
|
|
||||||
}, "expected heartbeats")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns error if connection is unavailable", func(t *testing.T) {
|
|
||||||
// no connection available to worker
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
heartbeat := make(chan struct{})
|
|
||||||
|
|
||||||
w := &DefaultWorker{
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
id: "wss://test",
|
|
||||||
heartbeat: heartbeat,
|
|
||||||
}
|
|
||||||
defer w.cancel()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for range heartbeat {
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := w.Send([]byte("hello"))
|
|
||||||
assert.ErrorIs(t, err, ErrConnectionUnavailable)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
+113
-5
@@ -3,6 +3,7 @@ package honeybee
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||||
@@ -41,7 +42,7 @@ func makeWorker(t *testing.T, ctx context.Context, cancel context.CancelFunc) *D
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
id: "wss://test",
|
id: "wss://test",
|
||||||
config: config,
|
config: config,
|
||||||
heartbeat: make(chan struct{}),
|
sendHeartbeat: make(chan struct{}),
|
||||||
processedCount: &atomic.Uint64{},
|
processedCount: &atomic.Uint64{},
|
||||||
outgoingCount: &atomic.Uint64{},
|
outgoingCount: &atomic.Uint64{},
|
||||||
restartCount: &atomic.Uint64{},
|
restartCount: &atomic.Uint64{},
|
||||||
@@ -134,7 +135,7 @@ func TestWorkerSession(t *testing.T) {
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
id: "wss://test",
|
id: "wss://test",
|
||||||
config: config,
|
config: config,
|
||||||
heartbeat: make(chan struct{}),
|
sendHeartbeat: make(chan struct{}),
|
||||||
processedCount: &atomic.Uint64{},
|
processedCount: &atomic.Uint64{},
|
||||||
outgoingCount: &atomic.Uint64{},
|
outgoingCount: &atomic.Uint64{},
|
||||||
restartCount: &atomic.Uint64{},
|
restartCount: &atomic.Uint64{},
|
||||||
@@ -303,7 +304,7 @@ func TestWorkerSession(t *testing.T) {
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
id: "wss://test",
|
id: "wss://test",
|
||||||
config: config,
|
config: config,
|
||||||
heartbeat: make(chan struct{}),
|
sendHeartbeat: make(chan struct{}),
|
||||||
processedCount: &atomic.Uint64{},
|
processedCount: &atomic.Uint64{},
|
||||||
outgoingCount: &atomic.Uint64{},
|
outgoingCount: &atomic.Uint64{},
|
||||||
restartCount: &atomic.Uint64{},
|
restartCount: &atomic.Uint64{},
|
||||||
@@ -365,7 +366,7 @@ func TestWorkerSession(t *testing.T) {
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
id: "wss://test",
|
id: "wss://test",
|
||||||
config: config,
|
config: config,
|
||||||
heartbeat: make(chan struct{}),
|
sendHeartbeat: make(chan struct{}),
|
||||||
processedCount: &atomic.Uint64{},
|
processedCount: &atomic.Uint64{},
|
||||||
outgoingCount: &atomic.Uint64{},
|
outgoingCount: &atomic.Uint64{},
|
||||||
restartCount: &atomic.Uint64{},
|
restartCount: &atomic.Uint64{},
|
||||||
@@ -431,7 +432,7 @@ func TestWorkerSession(t *testing.T) {
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
id: "wss://test",
|
id: "wss://test",
|
||||||
config: config,
|
config: config,
|
||||||
heartbeat: make(chan struct{}),
|
sendHeartbeat: make(chan struct{}),
|
||||||
processedCount: &atomic.Uint64{},
|
processedCount: &atomic.Uint64{},
|
||||||
outgoingCount: &atomic.Uint64{},
|
outgoingCount: &atomic.Uint64{},
|
||||||
restartCount: &atomic.Uint64{},
|
restartCount: &atomic.Uint64{},
|
||||||
@@ -638,3 +639,110 @@ func TestWorkerSession(t *testing.T) {
|
|||||||
}, "expected wg to drain after parent cancel")
|
}, "expected wg to drain after parent cancel")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWorkerSend(t *testing.T) {
|
||||||
|
t.Run("data sent to mock socket", func(t *testing.T) {
|
||||||
|
conn, _, _, outgoingData := setupTestConnection(t)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
heartbeat := make(chan struct{})
|
||||||
|
heartbeatCount := atomic.Int32{}
|
||||||
|
|
||||||
|
w := &DefaultWorker{
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
id: "wss://test",
|
||||||
|
sendHeartbeat: heartbeat,
|
||||||
|
outgoingCount: &atomic.Uint64{},
|
||||||
|
}
|
||||||
|
w.conn.Store(conn)
|
||||||
|
defer w.cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for range heartbeat {
|
||||||
|
heartbeatCount.Add(1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
testData := []byte("hello")
|
||||||
|
err := w.Send(testData)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// at least one heartbeat was sent
|
||||||
|
honeybeetest.Eventually(t, func() bool {
|
||||||
|
return heartbeatCount.Load() >= 1
|
||||||
|
}, "expected heartbeats")
|
||||||
|
|
||||||
|
// message was sent by the socket
|
||||||
|
honeybeetest.Eventually(t, func() bool {
|
||||||
|
select {
|
||||||
|
case msg := <-outgoingData:
|
||||||
|
return string(msg.Data) == "hello"
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}, "expected message")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sends one heartbeat per successful send", func(t *testing.T) {
|
||||||
|
conn, _, _, _ := setupTestConnection(t)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
heartbeat := make(chan struct{})
|
||||||
|
heartbeatCount := atomic.Int32{}
|
||||||
|
|
||||||
|
w := &DefaultWorker{
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
id: "wss://test",
|
||||||
|
sendHeartbeat: heartbeat,
|
||||||
|
outgoingCount: &atomic.Uint64{},
|
||||||
|
}
|
||||||
|
w.conn.Store(conn)
|
||||||
|
defer w.cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for range heartbeat {
|
||||||
|
heartbeatCount.Add(1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
const count = 3
|
||||||
|
for i := range count {
|
||||||
|
err := w.Send(fmt.Appendf(nil, "msg-%d", i))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
honeybeetest.Eventually(t, func() bool {
|
||||||
|
return heartbeatCount.Load() == count
|
||||||
|
}, "expected heartbeats")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error if connection is unavailable", func(t *testing.T) {
|
||||||
|
// no connection available to worker
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
heartbeat := make(chan struct{})
|
||||||
|
|
||||||
|
w := &DefaultWorker{
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
id: "wss://test",
|
||||||
|
sendHeartbeat: heartbeat,
|
||||||
|
}
|
||||||
|
defer w.cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for range heartbeat {
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := w.Send([]byte("hello"))
|
||||||
|
assert.ErrorIs(t, err, ErrConnectionUnavailable)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user