implemented ping-pong heartbeats. adjusted logs and defaults.

This commit is contained in:
Jay
2026-04-24 09:59:01 -04:00
parent 3091c5dfd8
commit e32bbc99d8
13 changed files with 293 additions and 30 deletions
+12
View File
@@ -27,11 +27,13 @@ func (m *MockDialer) DialContext(
type MockSocket struct { type MockSocket struct {
WriteMessageFunc func(int, []byte) error WriteMessageFunc func(int, []byte) error
WriteControlFunc func(int, []byte, time.Time) error
SetReadDeadlineFunc func(t time.Time) error SetReadDeadlineFunc func(t time.Time) error
SetWriteDeadlineFunc func(t time.Time) error SetWriteDeadlineFunc func(t time.Time) error
ReadMessageFunc func() (int, []byte, error) ReadMessageFunc func() (int, []byte, error)
CloseFunc func() error CloseFunc func() error
SetCloseHandlerFunc func(func(int, string) error) SetCloseHandlerFunc func(func(int, string) error)
SetPongHandlerFunc func(func(string) error)
Closed chan struct{} Closed chan struct{}
Once sync.Once Once sync.Once
Mu sync.Mutex Mu sync.Mutex
@@ -40,12 +42,14 @@ type MockSocket struct {
func NewMockSocket() *MockSocket { func NewMockSocket() *MockSocket {
return &MockSocket{ return &MockSocket{
WriteMessageFunc: func(int, []byte) error { return nil }, WriteMessageFunc: func(int, []byte) error { return nil },
WriteControlFunc: func(int, []byte, time.Time) error { return nil },
ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil }, ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil },
CloseFunc: func() error { return nil }, CloseFunc: func() error { return nil },
SetReadDeadlineFunc: func(time.Time) error { return nil }, SetReadDeadlineFunc: func(time.Time) error { return nil },
SetWriteDeadlineFunc: func(time.Time) error { return nil }, SetWriteDeadlineFunc: func(time.Time) error { return nil },
SetCloseHandlerFunc: func(func(int, string) error) {}, SetCloseHandlerFunc: func(func(int, string) error) {},
SetPongHandlerFunc: func(func(string) error) {},
Closed: make(chan struct{}), Closed: make(chan struct{}),
} }
@@ -56,6 +60,10 @@ func (m *MockSocket) WriteMessage(t int, d []byte) error {
return m.WriteMessageFunc(t, d) return m.WriteMessageFunc(t, d)
} }
func (m *MockSocket) WriteControl(t int, d []byte, dl time.Time) error {
return m.WriteControlFunc(t, d, dl)
}
func (m *MockSocket) ReadMessage() (int, []byte, error) { func (m *MockSocket) ReadMessage() (int, []byte, error) {
return m.ReadMessageFunc() return m.ReadMessageFunc()
} }
@@ -76,6 +84,10 @@ func (m *MockSocket) SetCloseHandler(h func(code int, text string) error) {
m.SetCloseHandlerFunc(h) m.SetCloseHandlerFunc(h)
} }
func (m *MockSocket) SetPongHandler(h func(s string) error) {
m.SetPongHandlerFunc(h)
}
// Logging mocks // Logging mocks
type MockSlogHandler struct { type MockSlogHandler struct {
+29 -1
View File
@@ -70,13 +70,18 @@ func (w *DefaultWorker) Start(pool PoolPlugin) {
toForwarder := make(chan types.ReceivedMessage, 256) toForwarder := make(chan types.ReceivedMessage, 256)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(4) wg.Add(5)
go func() { go func() {
defer wg.Done() defer wg.Done()
RunReader(w.ctx, pool.OnExit, w.conn, toQueue, w.heartbeat, w.logger) RunReader(w.ctx, pool.OnExit, w.conn, toQueue, w.heartbeat, w.logger)
}() }()
go func() {
defer wg.Done()
RunHeartbeatForwarder(w.ctx, w.conn, w.heartbeat, w.logger)
}()
go func() { go func() {
defer wg.Done() defer wg.Done()
queue.RunQueue(w.id, w.ctx, toQueue, toForwarder, w.config.MaxQueueSize) queue.RunQueue(w.id, w.ctx, toQueue, toForwarder, w.config.MaxQueueSize)
@@ -177,6 +182,29 @@ func RunReader(
} }
} }
func RunHeartbeatForwarder(
ctx context.Context,
conn *transport.Connection,
heartbeat chan<- struct{},
logger *slog.Logger,
) {
for {
select {
case <-ctx.Done():
return
case <-conn.Heartbeat():
select {
case heartbeat <- struct{}{}:
if logger != nil {
logger.Debug("ping-pong heartbeat")
}
case <-ctx.Done():
return
}
}
}
}
func RunForwarder( func RunForwarder(
id string, id string,
ctx context.Context, ctx context.Context,
+33
View File
@@ -229,3 +229,36 @@ func TestWorkerSend(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
}) })
} }
func TestHeartbeatForwarder(t *testing.T) {
t.Run("connection level heartbeat propagates", func(t *testing.T) {
socket, _, _ := honeybeetest.SetupTestSocket(t)
var pongHandler func(string) error
socket.SetPongHandlerFunc = func(h func(string) error) { pongHandler = h }
conn, err := transport.NewConnectionFromSocket(socket, nil, nil)
assert.NoError(t, err)
heartbeat := make(chan struct{}, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go RunHeartbeatForwarder(ctx, conn, heartbeat, nil)
honeybeetest.Eventually(t, func() bool {
return pongHandler != nil
}, "expected Connection to register PongHandler")
if pongHandler == nil {
t.Fatal("pong handler was never set")
}
pongHandler("") // Trigger pong
select {
case <-heartbeat:
case <-time.After(time.Second):
t.Fatal("pong did not propagate to worker heartbeat")
}
})
}
+28 -1
View File
@@ -208,11 +208,15 @@ func (s *Session) Start(
// start session // start session
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(3)
go func() { go func() {
defer wg.Done() defer wg.Done()
RunReader(sctx, onStop, conn, s.messages, s.heartbeat, s.logger) RunReader(sctx, onStop, conn, s.messages, s.heartbeat, s.logger)
}() }()
go func() {
defer wg.Done()
RunHeartbeatForwarder(sctx, conn, s.heartbeat, s.logger)
}()
go func() { go func() {
defer wg.Done() defer wg.Done()
RunStopMonitor(sctx, onStop, conn, s.keepalive, s.logger) RunStopMonitor(sctx, onStop, conn, s.keepalive, s.logger)
@@ -289,6 +293,29 @@ func RunReader(
} }
} }
func RunHeartbeatForwarder(
ctx context.Context,
conn *transport.Connection,
heartbeat chan<- struct{},
logger *slog.Logger,
) {
for {
select {
case <-ctx.Done():
return
case <-conn.Heartbeat():
select {
case heartbeat <- struct{}{}:
if logger != nil {
logger.Debug("ping-pong heartbeat")
}
case <-ctx.Done():
return
}
}
}
}
func RunStopMonitor( func RunStopMonitor(
ctx context.Context, ctx context.Context,
onStop func(), onStop func(),
+33
View File
@@ -144,6 +144,39 @@ func TestRunReader(t *testing.T) {
}) })
} }
func TestHeartbeatForwarder(t *testing.T) {
t.Run("connection level heartbeat propagates", func(t *testing.T) {
socket, _, _ := honeybeetest.SetupTestSocket(t)
var pongHandler func(string) error
socket.SetPongHandlerFunc = func(h func(string) error) { pongHandler = h }
conn, err := transport.NewConnectionFromSocket(socket, nil, nil)
assert.NoError(t, err)
heartbeat := make(chan struct{}, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go RunHeartbeatForwarder(ctx, conn, heartbeat, nil)
honeybeetest.Eventually(t, func() bool {
return pongHandler != nil
}, "expected Connection to register PongHandler")
if pongHandler == nil {
t.Fatal("pong handler was never set")
}
pongHandler("") // Trigger pong
select {
case <-heartbeat:
case <-time.After(time.Second):
t.Fatal("pong did not propagate to worker heartbeat")
}
})
}
func TestRunStopMonitor(t *testing.T) { func TestRunStopMonitor(t *testing.T) {
t.Run("keepalive signal calls conn.Close and cancel", func(t *testing.T) { t.Run("keepalive signal calls conn.Close and cancel", func(t *testing.T) {
conn, _, _, _ := setupTestConnection(t) conn, _, _, _ := setupTestConnection(t)
+22 -1
View File
@@ -10,6 +10,7 @@ type CloseHandler func(code int, text string) error
type ConnectionConfig struct { type ConnectionConfig struct {
CloseHandler CloseHandler CloseHandler CloseHandler
WriteTimeout time.Duration WriteTimeout time.Duration
PingInterval time.Duration
IncomingBufferSize int IncomingBufferSize int
ErrorsBufferSize int ErrorsBufferSize int
LoggingEnabled bool LoggingEnabled bool
@@ -41,6 +42,7 @@ func GetDefaultConnectionConfig() *ConnectionConfig {
return &ConnectionConfig{ return &ConnectionConfig{
CloseHandler: nil, CloseHandler: nil,
WriteTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second,
PingInterval: 20 * time.Second,
IncomingBufferSize: 100, IncomingBufferSize: 100,
ErrorsBufferSize: 10, ErrorsBufferSize: 10,
LoggingEnabled: true, LoggingEnabled: true,
@@ -53,7 +55,7 @@ func GetDefaultRetryConfig() *RetryConfig {
return &RetryConfig{ return &RetryConfig{
MaxRetries: 0, // Infinite retries MaxRetries: 0, // Infinite retries
InitialDelay: 1 * time.Second, InitialDelay: 1 * time.Second,
MaxDelay: 5 * time.Second, MaxDelay: 60 * time.Second,
JitterFactor: 0.5, JitterFactor: 0.5,
} }
} }
@@ -109,6 +111,13 @@ func validateWriteTimeout(value time.Duration) error {
return nil return nil
} }
func validatePingInterval(value time.Duration) error {
if value < 0 {
return InvalidPingInterval
}
return nil
}
func validateBufferSize(value int) error { func validateBufferSize(value int) error {
if value < 1 { if value < 1 {
return InvalidBufferSize return InvalidBufferSize
@@ -163,6 +172,18 @@ func WithWriteTimeout(value time.Duration) ConnectionOption {
} }
} }
// When PingInterval is set to zero, ping frames are disabled.
func WithPingInterval(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error {
err := validatePingInterval(value)
if err != nil {
return err
}
c.PingInterval = value
return nil
}
}
func WithIncomingBufferSize(value int) ConnectionOption { func WithIncomingBufferSize(value int) ConnectionOption {
return func(c *ConnectionConfig) error { return func(c *ConnectionConfig) error {
if err := validateBufferSize(value); err != nil { if err := validateBufferSize(value); err != nil {
+3 -1
View File
@@ -16,6 +16,7 @@ func TestNewConnectionConfig(t *testing.T) {
assert.Equal(t, conf, &ConnectionConfig{ assert.Equal(t, conf, &ConnectionConfig{
CloseHandler: nil, CloseHandler: nil,
WriteTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second,
PingInterval: 20 * time.Second,
IncomingBufferSize: 100, IncomingBufferSize: 100,
ErrorsBufferSize: 10, ErrorsBufferSize: 10,
LoggingEnabled: true, LoggingEnabled: true,
@@ -39,6 +40,7 @@ func TestDefaultConnectionConfig(t *testing.T) {
assert.Equal(t, conf, &ConnectionConfig{ assert.Equal(t, conf, &ConnectionConfig{
CloseHandler: nil, CloseHandler: nil,
WriteTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second,
PingInterval: 20 * time.Second,
IncomingBufferSize: 100, IncomingBufferSize: 100,
ErrorsBufferSize: 10, ErrorsBufferSize: 10,
LoggingEnabled: true, LoggingEnabled: true,
@@ -53,7 +55,7 @@ func TestDefaultRetryConnectionConfig(t *testing.T) {
assert.Equal(t, conf, &RetryConfig{ assert.Equal(t, conf, &RetryConfig{
MaxRetries: 0, MaxRetries: 0,
InitialDelay: 1 * time.Second, InitialDelay: 1 * time.Second,
MaxDelay: 5 * time.Second, MaxDelay: 60 * time.Second,
JitterFactor: 0.5, JitterFactor: 0.5,
}) })
} }
+75 -21
View File
@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math/rand"
"net/url" "net/url"
"sync" "sync"
"time" "time"
@@ -44,9 +45,10 @@ type Connection struct {
config *ConnectionConfig config *ConnectionConfig
logger *slog.Logger logger *slog.Logger
incoming chan []byte incoming chan []byte
errors chan error heartbeat chan struct{}
done chan struct{} errors chan error
done chan struct{}
state ConnectionState state ConnectionState
@@ -73,15 +75,16 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger)
} }
conn := &Connection{ conn := &Connection{
url: url, url: url,
dialer: NewDialer(), dialer: NewDialer(),
socket: nil, socket: nil,
config: config, config: config,
logger: logger, logger: logger,
incoming: make(chan []byte, config.IncomingBufferSize), incoming: make(chan []byte, config.IncomingBufferSize),
errors: make(chan error, config.ErrorsBufferSize), heartbeat: make(chan struct{}, 1),
state: StateDisconnected, errors: make(chan error, config.ErrorsBufferSize),
done: make(chan struct{}), state: StateDisconnected,
done: make(chan struct{}),
} }
return conn, nil return conn, nil
@@ -103,21 +106,24 @@ func NewConnectionFromSocket(
} }
conn := &Connection{ conn := &Connection{
url: nil, url: nil,
dialer: nil, dialer: nil,
socket: socket, socket: socket,
config: config, config: config,
logger: logger, logger: logger,
incoming: make(chan []byte, config.IncomingBufferSize), incoming: make(chan []byte, config.IncomingBufferSize),
errors: make(chan error, config.ErrorsBufferSize), heartbeat: make(chan struct{}, 1),
state: StateConnected, errors: make(chan error, config.ErrorsBufferSize),
done: make(chan struct{}), state: StateConnected,
done: make(chan struct{}),
} }
if config.CloseHandler != nil { if config.CloseHandler != nil {
socket.SetCloseHandler(config.CloseHandler) socket.SetCloseHandler(config.CloseHandler)
} }
conn.setupPongHandler()
conn.startPinger()
conn.startReader() conn.startReader()
return conn, nil return conn, nil
@@ -164,6 +170,8 @@ func (c *Connection) Connect(ctx context.Context) error {
c.logger.Info("connected") c.logger.Info("connected")
} }
c.setupPongHandler()
c.startPinger()
c.startReader() c.startReader()
return nil return nil
@@ -336,6 +344,48 @@ func (c *Connection) startReader() {
}() }()
} }
func (c *Connection) setupPongHandler() {
c.socket.SetPongHandler(func(appData string) error {
select {
case c.heartbeat <- struct{}{}:
default:
}
return nil
})
}
func (c *Connection) startPinger() {
if c.config.PingInterval <= 0 {
return
}
c.wg.Add(1)
go func() {
defer c.wg.Done()
defer c.shutdownInternal()
// Calculate 10% jitter window
jitter := c.config.PingInterval / 10
for {
offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter
next := c.config.PingInterval + offset
timer := time.NewTimer(next)
select {
case <-c.done:
timer.Stop()
return
case <-timer.C:
deadline := time.Now().Add(c.config.WriteTimeout)
if err := c.socket.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
return
}
}
}
}()
}
func (c *Connection) Send(data []byte) error { func (c *Connection) Send(data []byte) error {
c.writeMu.Lock() c.writeMu.Lock()
defer c.writeMu.Unlock() defer c.writeMu.Unlock()
@@ -368,6 +418,10 @@ func (c *Connection) Incoming() <-chan []byte {
return c.incoming return c.incoming
} }
func (c *Connection) Heartbeat() <-chan struct{} {
return c.heartbeat
}
func (c *Connection) Errors() <-chan error { func (c *Connection) Errors() <-chan error {
return c.errors return c.errors
} }
+50
View File
@@ -537,6 +537,56 @@ func TestConnectionErrors(t *testing.T) {
}) })
} }
func TestConnectionHeartbeat(t *testing.T) {
t.Run("pinger sends ping frames", func(t *testing.T) {
pingCount := atomic.Int32{}
socket, _, _ := honeybeetest.SetupTestSocket(t)
socket.WriteControlFunc = func(mt int, d []byte, dl time.Time) error {
if mt == websocket.PingMessage {
pingCount.Add(1)
}
return nil
}
conf, err := NewConnectionConfig(
WithPingInterval(10 * time.Millisecond),
)
assert.NoError(t, err)
conn, _ := NewConnectionFromSocket(socket, conf, nil)
defer conn.Close()
honeybeetest.Eventually(t,
func() bool { return pingCount.Load() >= 2 },
"expected pinger to fire")
})
t.Run("pong handler triggers heartbeat channel", func(t *testing.T) {
var handler func(string) error
socket, _, _ := honeybeetest.SetupTestSocket(t)
socket.SetPongHandlerFunc = func(h func(string) error) { handler = h }
conn, _ := NewConnectionFromSocket(socket, nil, nil)
defer conn.Close()
honeybeetest.Eventually(t, func() bool {
return handler != nil
}, "expected Connection to register PongHandler")
if handler == nil {
t.Fatal("pong handler was never set")
}
handler("") // Simulate inbound pong
select {
case <-conn.Heartbeat():
case <-time.After(time.Second):
t.Fatal("heartbeat not signaled on pong")
}
})
}
// Test helpers // Test helpers
func setupTestConnection(t *testing.T) ( func setupTestConnection(t *testing.T) (
+1
View File
@@ -9,6 +9,7 @@ var (
// Configuration Errors // Configuration Errors
InvalidWriteTimeout = errors.New("write timeout cannot be negative") InvalidWriteTimeout = errors.New("write timeout cannot be negative")
InvalidPingInterval = errors.New("ping interval cannot be negative")
InvalidBufferSize = errors.New("buffer size must be greater than zero") InvalidBufferSize = errors.New("buffer size must be greater than zero")
InvalidRetryMaxRetries = errors.New("max retry count cannot be negative") InvalidRetryMaxRetries = errors.New("max retry count cannot be negative")
InvalidRetryInitialDelay = errors.New("initial delay must be positive") InvalidRetryInitialDelay = errors.New("initial delay must be positive")
+4 -4
View File
@@ -87,9 +87,9 @@ func TestConnectLogging(t *testing.T) {
expected := []honeybeetest.ExpectedLog{ expected := []honeybeetest.ExpectedLog{
log(slog.LevelDebug, "connecting", map[string]any{}), log(slog.LevelDebug, "connecting", map[string]any{}),
log(slog.LevelDebug, "dialing", map[string]any{"attempt": 1}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 1}),
log(slog.LevelDebug, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}), log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}),
log(slog.LevelDebug, "dialing", map[string]any{"attempt": 2}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 2}),
log(slog.LevelDebug, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}), log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}),
log(slog.LevelDebug, "dialing", map[string]any{"attempt": 3}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 3}),
log(slog.LevelError, "dial failed, max retries reached", map[string]any{"attempt": 3, "error": dialErr}), log(slog.LevelError, "dial failed, max retries reached", map[string]any{"attempt": 3, "error": dialErr}),
log(slog.LevelError, "connection failed", map[string]any{"error": dialErr}), log(slog.LevelError, "connection failed", map[string]any{"error": dialErr}),
@@ -136,9 +136,9 @@ func TestConnectLogging(t *testing.T) {
expected := []honeybeetest.ExpectedLog{ expected := []honeybeetest.ExpectedLog{
log(slog.LevelDebug, "connecting", map[string]any{}), log(slog.LevelDebug, "connecting", map[string]any{}),
log(slog.LevelDebug, "dialing", map[string]any{"attempt": 1}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 1}),
log(slog.LevelDebug, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}), log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}),
log(slog.LevelDebug, "dialing", map[string]any{"attempt": 2}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 2}),
log(slog.LevelDebug, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}), log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}),
log(slog.LevelDebug, "dialing", map[string]any{"attempt": 3}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 3}),
log(slog.LevelDebug, "dial successful", map[string]any{"attempt": 3}), log(slog.LevelDebug, "dial successful", map[string]any{"attempt": 3}),
log(slog.LevelInfo, "connected", map[string]any{}), log(slog.LevelInfo, "connected", map[string]any{}),
+1 -1
View File
@@ -88,7 +88,7 @@ func AcquireSocket(
delay := retryMgr.CalculateDelay() delay := retryMgr.CalculateDelay()
if logger != nil { if logger != nil {
logger.Debug("dial failed, retrying", logger.Warn("dial failed, retrying",
"error", err, "error", err,
"attempt", retryMgr.RetryCount()+1, "attempt", retryMgr.RetryCount()+1,
"next_delay", delay) "next_delay", delay)
+2
View File
@@ -15,12 +15,14 @@ type Dialer interface {
type Socket interface { type Socket interface {
WriteMessage(messageType int, data []byte) error WriteMessage(messageType int, data []byte) error
WriteControl(messageType int, data []byte, deadline time.Time) error
ReadMessage() (messageType int, p []byte, err error) ReadMessage() (messageType int, p []byte, err error)
Close() error Close() error
SetReadDeadline(t time.Time) error SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error SetWriteDeadline(t time.Time) error
SetCloseHandler(h func(code int, text string) error) SetCloseHandler(h func(code int, text string) error)
SetPongHandler(h func(appData string) error)
} }
type ReceivedMessage struct { type ReceivedMessage struct {