From e32bbc99d886816eba96b805d8d3b7274ab07995 Mon Sep 17 00:00:00 2001 From: Jay Date: Fri, 24 Apr 2026 09:59:01 -0400 Subject: [PATCH] implemented ping-pong heartbeats. adjusted logs and defaults. --- honeybeetest/mocks.go | 12 ++++ inbound/worker.go | 30 ++++++++- inbound/worker_test.go | 33 +++++++++ outbound/worker.go | 29 +++++++- outbound/worker_session_inner_test.go | 33 +++++++++ transport/config.go | 23 ++++++- transport/config_test.go | 4 +- transport/connection.go | 96 +++++++++++++++++++++------ transport/connection_test.go | 50 ++++++++++++++ transport/errors.go | 1 + transport/logging_test.go | 8 +-- transport/socket.go | 2 +- types/types.go | 2 + 13 files changed, 293 insertions(+), 30 deletions(-) diff --git a/honeybeetest/mocks.go b/honeybeetest/mocks.go index a29420c..0fb53e5 100644 --- a/honeybeetest/mocks.go +++ b/honeybeetest/mocks.go @@ -27,11 +27,13 @@ func (m *MockDialer) DialContext( type MockSocket struct { WriteMessageFunc func(int, []byte) error + WriteControlFunc func(int, []byte, time.Time) error SetReadDeadlineFunc func(t time.Time) error SetWriteDeadlineFunc func(t time.Time) error ReadMessageFunc func() (int, []byte, error) CloseFunc func() error SetCloseHandlerFunc func(func(int, string) error) + SetPongHandlerFunc func(func(string) error) Closed chan struct{} Once sync.Once Mu sync.Mutex @@ -40,12 +42,14 @@ type MockSocket struct { func NewMockSocket() *MockSocket { return &MockSocket{ WriteMessageFunc: func(int, []byte) error { return nil }, + WriteControlFunc: func(int, []byte, time.Time) error { return nil }, ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil }, CloseFunc: func() error { return nil }, SetReadDeadlineFunc: func(time.Time) error { return nil }, SetWriteDeadlineFunc: func(time.Time) error { return nil }, SetCloseHandlerFunc: func(func(int, string) error) {}, + SetPongHandlerFunc: func(func(string) error) {}, Closed: make(chan struct{}), } @@ -56,6 +60,10 @@ func (m *MockSocket) WriteMessage(t int, d []byte) error { return m.WriteMessageFunc(t, d) } +func (m *MockSocket) WriteControl(t int, d []byte, dl time.Time) error { + return m.WriteControlFunc(t, d, dl) +} + func (m *MockSocket) ReadMessage() (int, []byte, error) { return m.ReadMessageFunc() } @@ -76,6 +84,10 @@ func (m *MockSocket) SetCloseHandler(h func(code int, text string) error) { m.SetCloseHandlerFunc(h) } +func (m *MockSocket) SetPongHandler(h func(s string) error) { + m.SetPongHandlerFunc(h) +} + // Logging mocks type MockSlogHandler struct { diff --git a/inbound/worker.go b/inbound/worker.go index 00f6666..dca42c4 100644 --- a/inbound/worker.go +++ b/inbound/worker.go @@ -70,13 +70,18 @@ func (w *DefaultWorker) Start(pool PoolPlugin) { toForwarder := make(chan types.ReceivedMessage, 256) var wg sync.WaitGroup - wg.Add(4) + wg.Add(5) go func() { defer wg.Done() RunReader(w.ctx, pool.OnExit, w.conn, toQueue, w.heartbeat, w.logger) }() + go func() { + defer wg.Done() + RunHeartbeatForwarder(w.ctx, w.conn, w.heartbeat, w.logger) + }() + go func() { defer wg.Done() queue.RunQueue(w.id, w.ctx, toQueue, toForwarder, w.config.MaxQueueSize) @@ -177,6 +182,29 @@ func RunReader( } } +func RunHeartbeatForwarder( + ctx context.Context, + conn *transport.Connection, + heartbeat chan<- struct{}, + logger *slog.Logger, +) { + for { + select { + case <-ctx.Done(): + return + case <-conn.Heartbeat(): + select { + case heartbeat <- struct{}{}: + if logger != nil { + logger.Debug("ping-pong heartbeat") + } + case <-ctx.Done(): + return + } + } + } +} + func RunForwarder( id string, ctx context.Context, diff --git a/inbound/worker_test.go b/inbound/worker_test.go index afe9683..c9a2f71 100644 --- a/inbound/worker_test.go +++ b/inbound/worker_test.go @@ -229,3 +229,36 @@ func TestWorkerSend(t *testing.T) { assert.Error(t, err) }) } + +func TestHeartbeatForwarder(t *testing.T) { + t.Run("connection level heartbeat propagates", func(t *testing.T) { + socket, _, _ := honeybeetest.SetupTestSocket(t) + var pongHandler func(string) error + socket.SetPongHandlerFunc = func(h func(string) error) { pongHandler = h } + + conn, err := transport.NewConnectionFromSocket(socket, nil, nil) + assert.NoError(t, err) + + heartbeat := make(chan struct{}, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go RunHeartbeatForwarder(ctx, conn, heartbeat, nil) + + honeybeetest.Eventually(t, func() bool { + return pongHandler != nil + }, "expected Connection to register PongHandler") + + if pongHandler == nil { + t.Fatal("pong handler was never set") + } + + pongHandler("") // Trigger pong + + select { + case <-heartbeat: + case <-time.After(time.Second): + t.Fatal("pong did not propagate to worker heartbeat") + } + }) +} diff --git a/outbound/worker.go b/outbound/worker.go index 7a4d85b..0dd87ae 100644 --- a/outbound/worker.go +++ b/outbound/worker.go @@ -208,11 +208,15 @@ func (s *Session) Start( // start session var wg sync.WaitGroup - wg.Add(2) + wg.Add(3) go func() { defer wg.Done() RunReader(sctx, onStop, conn, s.messages, s.heartbeat, s.logger) }() + go func() { + defer wg.Done() + RunHeartbeatForwarder(sctx, conn, s.heartbeat, s.logger) + }() go func() { defer wg.Done() RunStopMonitor(sctx, onStop, conn, s.keepalive, s.logger) @@ -289,6 +293,29 @@ func RunReader( } } +func RunHeartbeatForwarder( + ctx context.Context, + conn *transport.Connection, + heartbeat chan<- struct{}, + logger *slog.Logger, +) { + for { + select { + case <-ctx.Done(): + return + case <-conn.Heartbeat(): + select { + case heartbeat <- struct{}{}: + if logger != nil { + logger.Debug("ping-pong heartbeat") + } + case <-ctx.Done(): + return + } + } + } +} + func RunStopMonitor( ctx context.Context, onStop func(), diff --git a/outbound/worker_session_inner_test.go b/outbound/worker_session_inner_test.go index cfa1111..f92689d 100644 --- a/outbound/worker_session_inner_test.go +++ b/outbound/worker_session_inner_test.go @@ -144,6 +144,39 @@ func TestRunReader(t *testing.T) { }) } +func TestHeartbeatForwarder(t *testing.T) { + t.Run("connection level heartbeat propagates", func(t *testing.T) { + socket, _, _ := honeybeetest.SetupTestSocket(t) + var pongHandler func(string) error + socket.SetPongHandlerFunc = func(h func(string) error) { pongHandler = h } + + conn, err := transport.NewConnectionFromSocket(socket, nil, nil) + assert.NoError(t, err) + + heartbeat := make(chan struct{}, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go RunHeartbeatForwarder(ctx, conn, heartbeat, nil) + + honeybeetest.Eventually(t, func() bool { + return pongHandler != nil + }, "expected Connection to register PongHandler") + + if pongHandler == nil { + t.Fatal("pong handler was never set") + } + + pongHandler("") // Trigger pong + + select { + case <-heartbeat: + case <-time.After(time.Second): + t.Fatal("pong did not propagate to worker heartbeat") + } + }) +} + func TestRunStopMonitor(t *testing.T) { t.Run("keepalive signal calls conn.Close and cancel", func(t *testing.T) { conn, _, _, _ := setupTestConnection(t) diff --git a/transport/config.go b/transport/config.go index 9aa9328..384aedf 100644 --- a/transport/config.go +++ b/transport/config.go @@ -10,6 +10,7 @@ type CloseHandler func(code int, text string) error type ConnectionConfig struct { CloseHandler CloseHandler WriteTimeout time.Duration + PingInterval time.Duration IncomingBufferSize int ErrorsBufferSize int LoggingEnabled bool @@ -41,6 +42,7 @@ func GetDefaultConnectionConfig() *ConnectionConfig { return &ConnectionConfig{ CloseHandler: nil, WriteTimeout: 30 * time.Second, + PingInterval: 20 * time.Second, IncomingBufferSize: 100, ErrorsBufferSize: 10, LoggingEnabled: true, @@ -53,7 +55,7 @@ func GetDefaultRetryConfig() *RetryConfig { return &RetryConfig{ MaxRetries: 0, // Infinite retries InitialDelay: 1 * time.Second, - MaxDelay: 5 * time.Second, + MaxDelay: 60 * time.Second, JitterFactor: 0.5, } } @@ -109,6 +111,13 @@ func validateWriteTimeout(value time.Duration) error { return nil } +func validatePingInterval(value time.Duration) error { + if value < 0 { + return InvalidPingInterval + } + return nil +} + func validateBufferSize(value int) error { if value < 1 { return InvalidBufferSize @@ -163,6 +172,18 @@ func WithWriteTimeout(value time.Duration) ConnectionOption { } } +// When PingInterval is set to zero, ping frames are disabled. +func WithPingInterval(value time.Duration) ConnectionOption { + return func(c *ConnectionConfig) error { + err := validatePingInterval(value) + if err != nil { + return err + } + c.PingInterval = value + return nil + } +} + func WithIncomingBufferSize(value int) ConnectionOption { return func(c *ConnectionConfig) error { if err := validateBufferSize(value); err != nil { diff --git a/transport/config_test.go b/transport/config_test.go index d72994f..86c24e8 100644 --- a/transport/config_test.go +++ b/transport/config_test.go @@ -16,6 +16,7 @@ func TestNewConnectionConfig(t *testing.T) { assert.Equal(t, conf, &ConnectionConfig{ CloseHandler: nil, WriteTimeout: 30 * time.Second, + PingInterval: 20 * time.Second, IncomingBufferSize: 100, ErrorsBufferSize: 10, LoggingEnabled: true, @@ -39,6 +40,7 @@ func TestDefaultConnectionConfig(t *testing.T) { assert.Equal(t, conf, &ConnectionConfig{ CloseHandler: nil, WriteTimeout: 30 * time.Second, + PingInterval: 20 * time.Second, IncomingBufferSize: 100, ErrorsBufferSize: 10, LoggingEnabled: true, @@ -53,7 +55,7 @@ func TestDefaultRetryConnectionConfig(t *testing.T) { assert.Equal(t, conf, &RetryConfig{ MaxRetries: 0, InitialDelay: 1 * time.Second, - MaxDelay: 5 * time.Second, + MaxDelay: 60 * time.Second, JitterFactor: 0.5, }) } diff --git a/transport/connection.go b/transport/connection.go index 7b45f9f..62b8983 100644 --- a/transport/connection.go +++ b/transport/connection.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "math/rand" "net/url" "sync" "time" @@ -44,9 +45,10 @@ type Connection struct { config *ConnectionConfig logger *slog.Logger - incoming chan []byte - errors chan error - done chan struct{} + incoming chan []byte + heartbeat chan struct{} + errors chan error + done chan struct{} state ConnectionState @@ -73,15 +75,16 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) } conn := &Connection{ - url: url, - dialer: NewDialer(), - socket: nil, - config: config, - logger: logger, - incoming: make(chan []byte, config.IncomingBufferSize), - errors: make(chan error, config.ErrorsBufferSize), - state: StateDisconnected, - done: make(chan struct{}), + url: url, + dialer: NewDialer(), + socket: nil, + config: config, + logger: logger, + incoming: make(chan []byte, config.IncomingBufferSize), + heartbeat: make(chan struct{}, 1), + errors: make(chan error, config.ErrorsBufferSize), + state: StateDisconnected, + done: make(chan struct{}), } return conn, nil @@ -103,21 +106,24 @@ func NewConnectionFromSocket( } conn := &Connection{ - url: nil, - dialer: nil, - socket: socket, - config: config, - logger: logger, - incoming: make(chan []byte, config.IncomingBufferSize), - errors: make(chan error, config.ErrorsBufferSize), - state: StateConnected, - done: make(chan struct{}), + url: nil, + dialer: nil, + socket: socket, + config: config, + logger: logger, + incoming: make(chan []byte, config.IncomingBufferSize), + heartbeat: make(chan struct{}, 1), + errors: make(chan error, config.ErrorsBufferSize), + state: StateConnected, + done: make(chan struct{}), } if config.CloseHandler != nil { socket.SetCloseHandler(config.CloseHandler) } + conn.setupPongHandler() + conn.startPinger() conn.startReader() return conn, nil @@ -164,6 +170,8 @@ func (c *Connection) Connect(ctx context.Context) error { c.logger.Info("connected") } + c.setupPongHandler() + c.startPinger() c.startReader() return nil @@ -336,6 +344,48 @@ func (c *Connection) startReader() { }() } +func (c *Connection) setupPongHandler() { + c.socket.SetPongHandler(func(appData string) error { + select { + case c.heartbeat <- struct{}{}: + default: + } + return nil + }) +} + +func (c *Connection) startPinger() { + if c.config.PingInterval <= 0 { + return + } + + c.wg.Add(1) + go func() { + defer c.wg.Done() + defer c.shutdownInternal() + + // Calculate 10% jitter window + jitter := c.config.PingInterval / 10 + + for { + offset := time.Duration(rand.Int63n(int64(jitter*2))) - jitter + next := c.config.PingInterval + offset + timer := time.NewTimer(next) + select { + case <-c.done: + timer.Stop() + return + case <-timer.C: + deadline := time.Now().Add(c.config.WriteTimeout) + if err := c.socket.WriteControl(websocket.PingMessage, nil, deadline); err != nil { + return + } + } + } + }() + +} + func (c *Connection) Send(data []byte) error { c.writeMu.Lock() defer c.writeMu.Unlock() @@ -368,6 +418,10 @@ func (c *Connection) Incoming() <-chan []byte { return c.incoming } +func (c *Connection) Heartbeat() <-chan struct{} { + return c.heartbeat +} + func (c *Connection) Errors() <-chan error { return c.errors } diff --git a/transport/connection_test.go b/transport/connection_test.go index 3ad58f8..399bdac 100644 --- a/transport/connection_test.go +++ b/transport/connection_test.go @@ -537,6 +537,56 @@ func TestConnectionErrors(t *testing.T) { }) } +func TestConnectionHeartbeat(t *testing.T) { + t.Run("pinger sends ping frames", func(t *testing.T) { + pingCount := atomic.Int32{} + socket, _, _ := honeybeetest.SetupTestSocket(t) + socket.WriteControlFunc = func(mt int, d []byte, dl time.Time) error { + if mt == websocket.PingMessage { + pingCount.Add(1) + } + return nil + } + + conf, err := NewConnectionConfig( + WithPingInterval(10 * time.Millisecond), + ) + assert.NoError(t, err) + + conn, _ := NewConnectionFromSocket(socket, conf, nil) + defer conn.Close() + + honeybeetest.Eventually(t, + func() bool { return pingCount.Load() >= 2 }, + "expected pinger to fire") + }) + + t.Run("pong handler triggers heartbeat channel", func(t *testing.T) { + var handler func(string) error + socket, _, _ := honeybeetest.SetupTestSocket(t) + socket.SetPongHandlerFunc = func(h func(string) error) { handler = h } + + conn, _ := NewConnectionFromSocket(socket, nil, nil) + defer conn.Close() + + honeybeetest.Eventually(t, func() bool { + return handler != nil + }, "expected Connection to register PongHandler") + + if handler == nil { + t.Fatal("pong handler was never set") + } + + handler("") // Simulate inbound pong + + select { + case <-conn.Heartbeat(): + case <-time.After(time.Second): + t.Fatal("heartbeat not signaled on pong") + } + }) +} + // Test helpers func setupTestConnection(t *testing.T) ( diff --git a/transport/errors.go b/transport/errors.go index c31c0cf..58c972d 100644 --- a/transport/errors.go +++ b/transport/errors.go @@ -9,6 +9,7 @@ var ( // Configuration Errors InvalidWriteTimeout = errors.New("write timeout cannot be negative") + InvalidPingInterval = errors.New("ping interval cannot be negative") InvalidBufferSize = errors.New("buffer size must be greater than zero") InvalidRetryMaxRetries = errors.New("max retry count cannot be negative") InvalidRetryInitialDelay = errors.New("initial delay must be positive") diff --git a/transport/logging_test.go b/transport/logging_test.go index 486f925..2a6e0e5 100644 --- a/transport/logging_test.go +++ b/transport/logging_test.go @@ -87,9 +87,9 @@ func TestConnectLogging(t *testing.T) { expected := []honeybeetest.ExpectedLog{ log(slog.LevelDebug, "connecting", map[string]any{}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 1}), - log(slog.LevelDebug, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}), + log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 2}), - log(slog.LevelDebug, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}), + log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 3}), log(slog.LevelError, "dial failed, max retries reached", map[string]any{"attempt": 3, "error": dialErr}), log(slog.LevelError, "connection failed", map[string]any{"error": dialErr}), @@ -136,9 +136,9 @@ func TestConnectLogging(t *testing.T) { expected := []honeybeetest.ExpectedLog{ log(slog.LevelDebug, "connecting", map[string]any{}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 1}), - log(slog.LevelDebug, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}), + log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 2}), - log(slog.LevelDebug, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}), + log(slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}), log(slog.LevelDebug, "dialing", map[string]any{"attempt": 3}), log(slog.LevelDebug, "dial successful", map[string]any{"attempt": 3}), log(slog.LevelInfo, "connected", map[string]any{}), diff --git a/transport/socket.go b/transport/socket.go index 986c11f..8df0868 100644 --- a/transport/socket.go +++ b/transport/socket.go @@ -88,7 +88,7 @@ func AcquireSocket( delay := retryMgr.CalculateDelay() if logger != nil { - logger.Debug("dial failed, retrying", + logger.Warn("dial failed, retrying", "error", err, "attempt", retryMgr.RetryCount()+1, "next_delay", delay) diff --git a/types/types.go b/types/types.go index bd139da..0ff4f46 100644 --- a/types/types.go +++ b/types/types.go @@ -15,12 +15,14 @@ type Dialer interface { type Socket interface { WriteMessage(messageType int, data []byte) error + WriteControl(messageType int, data []byte, deadline time.Time) error ReadMessage() (messageType int, p []byte, err error) Close() error SetReadDeadline(t time.Time) error SetWriteDeadline(t time.Time) error SetCloseHandler(h func(code int, text string) error) + SetPongHandler(h func(appData string) error) } type ReceivedMessage struct {