From 9b29592a394c5769fd07915036a47f1bb79927a8 Mon Sep 17 00:00:00 2001 From: Jay Date: Mon, 20 Apr 2026 08:45:04 -0400 Subject: [PATCH] Decoupled worker from goroutines. --- initiatorpool/pool.go | 4 +- initiatorpool/worker.go | 152 ++++--- initiatorpool/worker_dialer_test.go | 30 +- initiatorpool/worker_forwarder_test.go | 12 +- initiatorpool/worker_keepalive_test.go | 19 +- initiatorpool/worker_session_inner_test.go | 159 +++----- initiatorpool/worker_session_test.go | 443 ++++++++++++--------- initiatorpool/worker_start_test.go | 48 +-- responderpool/pool.go | 1 + responderpool/worker.go | 19 +- 10 files changed, 458 insertions(+), 429 deletions(-) diff --git a/initiatorpool/pool.go b/initiatorpool/pool.go index cec20aa..a86d63a 100644 --- a/initiatorpool/pool.go +++ b/initiatorpool/pool.go @@ -16,7 +16,7 @@ type Peer struct { worker Worker } -type WorkerContext struct { +type PoolPlugin struct { Inbox chan<- InboxMessage Events chan<- PoolEvent Errors chan<- error @@ -181,7 +181,7 @@ func (p *Pool) Connect(id string) error { if p.logger != nil { logger = p.logger.With("id", id) } - ctx := WorkerContext{ + ctx := PoolPlugin{ Inbox: p.inbox, Events: p.events, Errors: p.errors, diff --git a/initiatorpool/worker.go b/initiatorpool/worker.go index f5db150..9d6b69f 100644 --- a/initiatorpool/worker.go +++ b/initiatorpool/worker.go @@ -12,7 +12,7 @@ import ( // Worker type Worker interface { - Start(wctx WorkerContext, wg *sync.WaitGroup) + Start(pool PoolPlugin, wg *sync.WaitGroup) Stop() Send(data []byte) error } @@ -48,9 +48,9 @@ func NewWorker( return nil, err } - wctx, cancel := context.WithCancel(ctx) + pool, cancel := context.WithCancel(ctx) w := &DefaultWorker{ - Ctx: wctx, + Ctx: pool, Cancel: cancel, Id: id, Config: config, @@ -61,7 +61,7 @@ func NewWorker( } func (w *DefaultWorker) Start( - wctx WorkerContext, + pool PoolPlugin, wg *sync.WaitGroup, ) { dial := make(chan struct{}, 1) @@ -72,10 +72,34 @@ func (w *DefaultWorker) Start( var owg sync.WaitGroup owg.Add(4) - go func() { defer owg.Done(); w.RunDialer(w.Ctx, wctx, dial, newConn) }() - go func() { defer owg.Done(); w.RunKeepalive(w.Ctx, keepalive) }() - go func() { defer owg.Done(); w.RunForwarder(w.Ctx, messages, wctx.Inbox, w.Config.MaxQueueSize) }() - go func() { defer owg.Done(); w.RunSession(w.Ctx, wctx, messages, dial, keepalive, newConn) }() + go func() { + defer owg.Done() + RunDialer(w.Id, w.Ctx, pool, dial, newConn) + }() + + go func() { + defer owg.Done() + RunKeepalive(w.Ctx, w.Heartbeat, keepalive, w.Config.KeepaliveTimeout) + }() + + go func() { + defer owg.Done() + RunForwarder(w.Id, w.Ctx, messages, pool.Inbox, w.Config.MaxQueueSize) + }() + + go func() { + defer owg.Done() + session := &Session{ + id: w.Id, + connPtr: &w.Conn, + messages: messages, + heartbeat: w.Heartbeat, + dial: dial, + keepalive: keepalive, + newConn: newConn, + } + session.Start(w.Ctx, pool) + }() owg.Wait() wg.Done() @@ -106,20 +130,26 @@ func (w *DefaultWorker) Send(data []byte) error { return nil } -func (w *DefaultWorker) RunSession( +type Session struct { + id string + connPtr *atomic.Pointer[transport.Connection] + + messages chan<- ReceivedMessage + heartbeat chan<- struct{} + dial chan<- struct{} + + keepalive <-chan struct{} + newConn <-chan *transport.Connection +} + +func (s *Session) Start( ctx context.Context, - wctx WorkerContext, - - messages chan<- ReceivedMessage, - dial chan<- struct{}, - - keepalive <-chan struct{}, - newConn <-chan *transport.Connection, + pool PoolPlugin, ) { for { // request new connection select { - case dial <- struct{}{}: + case s.dial <- struct{}{}: default: } @@ -130,45 +160,42 @@ func (w *DefaultWorker) RunSession( select { case <-ctx.Done(): return - case <-keepalive: + case <-s.keepalive: select { - case dial <- struct{}{}: + case s.dial <- struct{}{}: default: } - case conn = <-newConn: + case conn = <-s.newConn: break preConn } } // set up new connection - w.Conn.Store(conn) - wctx.Events <- PoolEvent{ID: w.Id, Kind: EventConnected} + s.connPtr.Store(conn) + pool.Events <- PoolEvent{ID: s.id, Kind: EventConnected} - // set up session - sessionDone := make(chan struct{}) - var once sync.Once - onStop := func() { - once.Do(func() { close(sessionDone) }) - } + // set up session context + sctx, scancel := context.WithCancel(ctx) + onStop := func() { scancel() } // start session var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() - w.RunReader(conn, messages, sessionDone, onStop) + RunReader(sctx, onStop, conn, s.messages, s.heartbeat) }() go func() { defer wg.Done() - w.RunStopMonitor(ctx, conn, keepalive, sessionDone, onStop) + RunStopMonitor(sctx, onStop, conn, s.keepalive) }() // complete session wg.Wait() // tear down connection - w.Conn.Store(nil) - wctx.Events <- PoolEvent{ID: w.Id, Kind: EventDisconnected} + s.connPtr.Store(nil) + pool.Events <- PoolEvent{ID: s.id, Kind: EventDisconnected} // exit if worker is shutting down select { @@ -182,11 +209,12 @@ func (w *DefaultWorker) RunSession( } -func (w *DefaultWorker) RunReader( +func RunReader( + ctx context.Context, + onStop func(), conn *transport.Connection, messages chan<- ReceivedMessage, - sessionDone <-chan struct{}, - onStop func(), + heartbeat chan<- struct{}, ) { defer func() { conn.Close() @@ -195,7 +223,7 @@ func (w *DefaultWorker) RunReader( for { select { - case <-sessionDone: + case <-ctx.Done(): return case data, ok := <-conn.Incoming(): if !ok { @@ -204,27 +232,23 @@ func (w *DefaultWorker) RunReader( } // send message forward - messages <- ReceivedMessage{ - data: data, - receivedAt: time.Now(), - } + messages <- ReceivedMessage{data: data, receivedAt: time.Now()} // send heartbeat select { - case w.Heartbeat <- struct{}{}: - case <-sessionDone: + case heartbeat <- struct{}{}: + case <-ctx.Done(): return } } } } -func (w *DefaultWorker) RunStopMonitor( +func RunStopMonitor( ctx context.Context, + onStop func(), conn *transport.Connection, keepalive <-chan struct{}, - sessionDone <-chan struct{}, - onStop func(), ) { defer func() { conn.Close() @@ -234,11 +258,11 @@ func (w *DefaultWorker) RunStopMonitor( select { case <-ctx.Done(): case <-keepalive: - case <-sessionDone: } } -func (w *DefaultWorker) RunForwarder( +func RunForwarder( + id string, ctx context.Context, messages <-chan ReceivedMessage, inbox chan<- InboxMessage, @@ -271,7 +295,7 @@ func (w *DefaultWorker) RunForwarder( queue.PushBack(msg) // send next message to inbox case out <- InboxMessage{ - ID: w.Id, + ID: id, Data: next.data, ReceivedAt: next.receivedAt, }: @@ -281,12 +305,14 @@ func (w *DefaultWorker) RunForwarder( } } -func (w *DefaultWorker) RunKeepalive( +func RunKeepalive( ctx context.Context, + heartbeat <-chan struct{}, keepalive chan<- struct{}, + timeout time.Duration, ) { // disable keepalive timeout if not configured - if w.Config.KeepaliveTimeout <= 0 { + if timeout <= 0 { // wait for cancel and exit select { case <-ctx.Done(): @@ -294,14 +320,14 @@ func (w *DefaultWorker) RunKeepalive( return } - timer := time.NewTimer(w.Config.KeepaliveTimeout) + timer := time.NewTimer(timeout) defer timer.Stop() for { select { case <-ctx.Done(): return - case <-w.Heartbeat: + case <-heartbeat: // drain the timer channel and reset if !timer.Stop() { select { @@ -309,7 +335,7 @@ func (w *DefaultWorker) RunKeepalive( default: } } - timer.Reset(w.Config.KeepaliveTimeout) + timer.Reset(timeout) // timer completed case <-timer.C: // send keepalive signal, then reset the timer @@ -317,27 +343,29 @@ func (w *DefaultWorker) RunKeepalive( case keepalive <- struct{}{}: default: } - timer.Reset(w.Config.KeepaliveTimeout) + timer.Reset(timeout) } } } -func (w *DefaultWorker) Dial( +func connect( + id string, ctx context.Context, - wctx WorkerContext, + pool PoolPlugin, ) (*transport.Connection, error) { - conn, err := transport.NewConnection(w.Id, wctx.ConnectionConfig, wctx.Logger) + conn, err := transport.NewConnection(id, pool.ConnectionConfig, pool.Logger) if err != nil { return nil, err } - conn.SetDialer(wctx.Dialer) + conn.SetDialer(pool.Dialer) return conn, conn.Connect(ctx) } -func (w *DefaultWorker) RunDialer( +func RunDialer( + id string, ctx context.Context, - wctx WorkerContext, + pool PoolPlugin, dial <-chan struct{}, newConn chan<- *transport.Connection, @@ -360,13 +388,13 @@ func (w *DefaultWorker) RunDialer( }() // dial a new connection - conn, err := w.Dial(ctx, wctx) + conn, err := connect(id, ctx, pool) close(done) // send error if dial failed and continue if err != nil { select { - case wctx.Errors <- err: + case pool.Errors <- err: case <-ctx.Done(): } continue diff --git a/initiatorpool/worker_dialer_test.go b/initiatorpool/worker_dialer_test.go index 6f0480e..c8e22dc 100644 --- a/initiatorpool/worker_dialer_test.go +++ b/initiatorpool/worker_dialer_test.go @@ -16,14 +16,14 @@ import ( func TestRunDialer(t *testing.T) { t.Run("successful dial delivers connection to newConn", func(t *testing.T) { - w := &DefaultWorker{Id: "wss://test"} + url := "wss://test" dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) ctx, cancel := context.WithCancel(context.Background()) defer cancel() mockSocket := honeybeetest.NewMockSocket() - wctx := WorkerContext{ + pool := PoolPlugin{ Errors: make(chan error, 1), Dialer: &honeybeetest.MockDialer{ DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { @@ -32,7 +32,7 @@ func TestRunDialer(t *testing.T) { }, } - go w.RunDialer(ctx, wctx, dial, newConn) + go RunDialer(url, ctx, pool, dial, newConn) dial <- struct{}{} honeybeetest.Eventually(t, func() bool { @@ -47,7 +47,7 @@ func TestRunDialer(t *testing.T) { t.Run("concurrent dial signals are drained; only one connection produced.", func(t *testing.T) { - w := &DefaultWorker{Id: "wss://test"} + url := "wss://test" dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) ctx, cancel := context.WithCancel(context.Background()) @@ -60,7 +60,7 @@ func TestRunDialer(t *testing.T) { connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry started := make(chan struct{}) startOnce := sync.Once{} - wctx := WorkerContext{ + pool := PoolPlugin{ Errors: make(chan error, 1), Dialer: &honeybeetest.MockDialer{ DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { @@ -73,7 +73,7 @@ func TestRunDialer(t *testing.T) { ConnectionConfig: connConfig, } - go w.RunDialer(ctx, wctx, dial, newConn) + go RunDialer(url, ctx, pool, dial, newConn) dial <- struct{}{} // wait for dial to start blocking on gate @@ -111,7 +111,7 @@ func TestRunDialer(t *testing.T) { }) t.Run("dial failure emits error, succeeds on next signal", func(t *testing.T) { - w := &DefaultWorker{Id: "wss://test"} + url := "wss://test" errors := make(chan error, 1) dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) @@ -122,7 +122,7 @@ func TestRunDialer(t *testing.T) { dialCount := atomic.Int32{} mockSocket := honeybeetest.NewMockSocket() connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry - wctx := WorkerContext{ + pool := PoolPlugin{ Errors: errors, Dialer: &honeybeetest.MockDialer{ DialContextFunc: func( @@ -139,7 +139,7 @@ func TestRunDialer(t *testing.T) { ConnectionConfig: connConfig, } - go w.RunDialer(ctx, wctx, dial, newConn) + go RunDialer(url, ctx, pool, dial, newConn) dial <- struct{}{} honeybeetest.Eventually(t, func() bool { @@ -164,16 +164,16 @@ func TestRunDialer(t *testing.T) { }) t.Run("exits on context cancellation", func(t *testing.T) { - w := &DefaultWorker{Id: "wss://test"} + url := "wss://test" dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) ctx, cancel := context.WithCancel(context.Background()) - wctx := WorkerContext{Errors: make(chan error, 1)} + pool := PoolPlugin{Errors: make(chan error, 1)} done := make(chan struct{}) go func() { - w.RunDialer(ctx, wctx, dial, newConn) + RunDialer(url, ctx, pool, dial, newConn) close(done) }() @@ -190,12 +190,12 @@ func TestRunDialer(t *testing.T) { }) t.Run("context cancelled during in-progress dial exits without delivering connection", func(t *testing.T) { - w := &DefaultWorker{Id: "wss://test"} + url := "wss://test" dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) ctx, cancel := context.WithCancel(context.Background()) - wctx := WorkerContext{ + pool := PoolPlugin{ Errors: make(chan error, 1), ConnectionConfig: &transport.ConnectionConfig{Retry: nil}, Dialer: &honeybeetest.MockDialer{ @@ -211,7 +211,7 @@ func TestRunDialer(t *testing.T) { done := make(chan struct{}) go func() { - w.RunDialer(ctx, wctx, dial, newConn) + RunDialer(url, ctx, pool, dial, newConn) close(done) }() diff --git a/initiatorpool/worker_forwarder_test.go b/initiatorpool/worker_forwarder_test.go index fd752db..2a8ab48 100644 --- a/initiatorpool/worker_forwarder_test.go +++ b/initiatorpool/worker_forwarder_test.go @@ -10,13 +10,13 @@ import ( func TestRunForwarder(t *testing.T) { t.Run("message passes through to inbox", func(t *testing.T) { + id := "wss://test" messages := make(chan ReceivedMessage, 1) inbox := make(chan InboxMessage, 1) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - w := &DefaultWorker{Id: "wss://test"} - go w.RunForwarder(ctx, messages, inbox, 0) + go RunForwarder(id, ctx, messages, inbox, 0) messages <- ReceivedMessage{data: []byte("hello"), receivedAt: time.Now()} @@ -31,6 +31,7 @@ func TestRunForwarder(t *testing.T) { }) t.Run("oldest message dropped when queue is full", func(t *testing.T) { + id := "wss://test" messages := make(chan ReceivedMessage, 1) inbox := make(chan InboxMessage, 1) ctx, cancel := context.WithCancel(context.Background()) @@ -47,8 +48,7 @@ func TestRunForwarder(t *testing.T) { } }() - w := &DefaultWorker{Id: "wss://test"} - go w.RunForwarder(ctx, messages, gatedInbox, 2) + go RunForwarder(id, ctx, messages, gatedInbox, 2) // send three messages while the gated inbox is blocked messages <- ReceivedMessage{data: []byte("first"), receivedAt: time.Now()} @@ -78,15 +78,15 @@ func TestRunForwarder(t *testing.T) { }) t.Run("exits on context cancellation", func(t *testing.T) { + id := "wss://test" messages := make(chan ReceivedMessage, 1) inbox := make(chan InboxMessage, 1) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - w := &DefaultWorker{Id: "wss://test"} done := make(chan struct{}) go func() { - w.RunForwarder(ctx, messages, inbox, 0) + RunForwarder(id, ctx, messages, inbox, 0) close(done) }() diff --git a/initiatorpool/worker_keepalive_test.go b/initiatorpool/worker_keepalive_test.go index 14b5b36..f633e04 100644 --- a/initiatorpool/worker_keepalive_test.go +++ b/initiatorpool/worker_keepalive_test.go @@ -11,19 +11,16 @@ func TestRunKeepalive(t *testing.T) { t.Run("heartbeat resets timer, no keepalive signal fired", func(t *testing.T) { heartbeat := make(chan struct{}) keepalive := make(chan struct{}, 1) + timeout := 200 * time.Millisecond ctx, cancel := context.WithCancel(context.Background()) defer cancel() - w := &DefaultWorker{ - Config: &WorkerConfig{KeepaliveTimeout: 200 * time.Millisecond}, - Heartbeat: heartbeat, - } - go w.RunKeepalive(ctx, keepalive) + go RunKeepalive(ctx, heartbeat, keepalive, timeout) // send heartbeats faster than the timeout for i := 0; i < 5; i++ { time.Sleep(20 * time.Millisecond) - w.Heartbeat <- struct{}{} + heartbeat <- struct{}{} } // because the timer is being reset, keepalive signal should not be sent @@ -38,12 +35,13 @@ func TestRunKeepalive(t *testing.T) { }) t.Run("keepalive timeout fires signal", func(t *testing.T) { + heartbeat := make(chan struct{}, 1) keepalive := make(chan struct{}, 1) + timeout := 20 * time.Millisecond ctx, cancel := context.WithCancel(context.Background()) defer cancel() - w := &DefaultWorker{Config: &WorkerConfig{KeepaliveTimeout: 20 * time.Millisecond}} - go w.RunKeepalive(ctx, keepalive) + go RunKeepalive(ctx, heartbeat, keepalive, timeout) // send no heartbeats, wait for timeout and keepalive signal honeybeetest.Eventually(t, func() bool { @@ -57,13 +55,14 @@ func TestRunKeepalive(t *testing.T) { }) t.Run("exits on context cancellation", func(t *testing.T) { + heartbeat := make(chan struct{}, 1) keepalive := make(chan struct{}, 1) + timeout := 20 * time.Second ctx, cancel := context.WithCancel(context.Background()) - w := &DefaultWorker{Config: &WorkerConfig{KeepaliveTimeout: 20 * time.Second}} done := make(chan struct{}) go func() { - w.RunKeepalive(ctx, keepalive) + RunKeepalive(ctx, heartbeat, keepalive, timeout) close(done) }() diff --git a/initiatorpool/worker_session_inner_test.go b/initiatorpool/worker_session_inner_test.go index bcb6f01..b4696d9 100644 --- a/initiatorpool/worker_session_inner_test.go +++ b/initiatorpool/worker_session_inner_test.go @@ -20,22 +20,14 @@ func TestRunReader(t *testing.T) { messages := make(chan ReceivedMessage, 1) heartbeat := make(chan struct{}) - sessionDone := make(chan struct{}) - onStop := func() {} ctx, cancel := context.WithCancel(context.Background()) defer cancel() - w := &DefaultWorker{ - Ctx: ctx, - Cancel: cancel, - Id: "wss://test", - Heartbeat: heartbeat, - } go func() { for range heartbeat { } }() - go w.RunReader(conn, messages, sessionDone, onStop) + go RunReader(ctx, cancel, conn, messages, heartbeat) before := time.Now() incomingData <- honeybeetest.MockIncomingData{ @@ -59,18 +51,9 @@ func TestRunReader(t *testing.T) { messages := make(chan ReceivedMessage, 10) heartbeat := make(chan struct{}) - sessionDone := make(chan struct{}) - onStop := func() {} ctx, cancel := context.WithCancel(context.Background()) defer cancel() - w := &DefaultWorker{ - Ctx: ctx, - Cancel: cancel, - Id: "wss://test", - Heartbeat: heartbeat, - } - received := atomic.Int32{} go func() { for range heartbeat { @@ -81,7 +64,7 @@ func TestRunReader(t *testing.T) { for range messages { } }() - go w.RunReader(conn, messages, sessionDone, onStop) + go RunReader(ctx, cancel, conn, messages, heartbeat) const count = 3 for i := 0; i < count; i++ { @@ -101,16 +84,9 @@ func TestRunReader(t *testing.T) { messages := make(chan ReceivedMessage, 1) heartbeat := make(chan struct{}) - sessionDone := make(chan struct{}) - onStopCalled := atomic.Bool{} - onStop := func() { onStopCalled.Store(true) } - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - w := &DefaultWorker{ - Ctx: ctx, - Id: "wss://test", - Heartbeat: heartbeat, - } go func() { for range heartbeat { } @@ -119,7 +95,7 @@ func TestRunReader(t *testing.T) { for range messages { } }() - go w.RunReader(conn, messages, sessionDone, onStop) + go RunReader(ctx, cancel, conn, messages, heartbeat) // induce connection closure via reader incomingData <- honeybeetest.MockIncomingData{Err: io.EOF} @@ -132,8 +108,13 @@ func TestRunReader(t *testing.T) { }, "expected closed state") honeybeetest.Eventually(t, func() bool { - return onStopCalled.Load() - }, "expected onStop to be called") + select { + case <-ctx.Done(): + return true + default: + return false + } + }, "expected context to cancel") }) t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { @@ -141,66 +122,9 @@ func TestRunReader(t *testing.T) { messages := make(chan ReceivedMessage, 1) heartbeat := make(chan struct{}) - sessionDone := make(chan struct{}) - onStopCalled := atomic.Bool{} - onStop := func() { onStopCalled.Store(true) } - ctx := context.Background() - - w := &DefaultWorker{ - Ctx: ctx, - Id: "wss://test", - Heartbeat: heartbeat, - } - go w.RunReader(conn, messages, sessionDone, onStop) - - close(sessionDone) - - honeybeetest.Eventually(t, func() bool { - return conn.State() == transport.StateClosed - }, "expected closed state") - - honeybeetest.Eventually(t, func() bool { - return onStopCalled.Load() - }, "expected onStop to be called") - }) -} - -func TestRunStopMonitor(t *testing.T) { - t.Run("keepalive signal calls conn.Close and onStop", func(t *testing.T) { - conn, _, _, _ := setupWorkerTestConnection(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - keepalive := make(chan struct{}, 1) - sessionDone := make(chan struct{}) - onStopCalled := atomic.Bool{} - onStop := func() { onStopCalled.Store(true) } - - w := &DefaultWorker{Id: "wss://test"} - go w.RunStopMonitor(ctx, conn, keepalive, sessionDone, onStop) - - keepalive <- struct{}{} - - honeybeetest.Eventually(t, func() bool { - return conn.State() == transport.StateClosed - }, "expected closed state") - - honeybeetest.Eventually(t, func() bool { - return onStopCalled.Load() - }, "expected onStop to be called") - }) - - t.Run("ctx.Done calls conn.Close and onStop", func(t *testing.T) { - conn, _, _, _ := setupWorkerTestConnection(t) ctx, cancel := context.WithCancel(context.Background()) - keepalive := make(chan struct{}) - sessionDone := make(chan struct{}) - onStopCalled := atomic.Bool{} - onStop := func() { onStopCalled.Store(true) } - - w := &DefaultWorker{Id: "wss://test"} - go w.RunStopMonitor(ctx, conn, keepalive, sessionDone, onStop) + go RunReader(ctx, cancel, conn, messages, heartbeat) cancel() @@ -209,31 +133,64 @@ func TestRunStopMonitor(t *testing.T) { }, "expected closed state") honeybeetest.Eventually(t, func() bool { - return onStopCalled.Load() - }, "expected onStop to be called") + select { + case <-ctx.Done(): + return true + default: + return false + } + }, "expected context to cancel") }) +} - t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { +func TestRunStopMonitor(t *testing.T) { + t.Run("keepalive signal calls conn.Close and cancel", func(t *testing.T) { conn, _, _, _ := setupWorkerTestConnection(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - keepalive := make(chan struct{}) - sessionDone := make(chan struct{}) - onStopCalled := atomic.Bool{} - onStop := func() { onStopCalled.Store(true) } + keepalive := make(chan struct{}, 1) - w := &DefaultWorker{Id: "wss://test"} - go w.RunStopMonitor(ctx, conn, keepalive, sessionDone, onStop) + go RunStopMonitor(ctx, cancel, conn, keepalive) - close(sessionDone) + keepalive <- struct{}{} honeybeetest.Eventually(t, func() bool { return conn.State() == transport.StateClosed }, "expected closed state") honeybeetest.Eventually(t, func() bool { - return onStopCalled.Load() - }, "expected onStop to be called") + select { + case <-ctx.Done(): + return true + default: + return false + } + }, "expected context to cancel") + }) + + t.Run("ctx.Done calls conn.Close and cancel", func(t *testing.T) { + conn, _, _, _ := setupWorkerTestConnection(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + keepalive := make(chan struct{}) + + go RunStopMonitor(ctx, cancel, conn, keepalive) + + cancel() + + honeybeetest.Eventually(t, func() bool { + return conn.State() == transport.StateClosed + }, "expected closed state") + + honeybeetest.Eventually(t, func() bool { + select { + case <-ctx.Done(): + return true + default: + return false + } + }, "expected context to cancel") }) } diff --git a/initiatorpool/worker_session_test.go b/initiatorpool/worker_session_test.go index 776c005..209aabe 100644 --- a/initiatorpool/worker_session_test.go +++ b/initiatorpool/worker_session_test.go @@ -5,6 +5,7 @@ import ( "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/transport" + "sync/atomic" "testing" ) @@ -20,145 +21,180 @@ func drainEvent(t *testing.T, events <-chan PoolEvent, kind PoolEventKind) { }, fmt.Sprintf("expected %s event", kind)) } -func TestRunSessionDial(t *testing.T) { - setup := func(t *testing.T) ( - w *DefaultWorker, - ctx context.Context, - cancel context.CancelFunc, - dial chan struct{}, - keepalive chan struct{}, - newConn chan *transport.Connection, - ) { - t.Helper() - ctx, cancel = context.WithCancel(context.Background()) - w = &DefaultWorker{ - Ctx: ctx, - Cancel: cancel, - Id: "wss://test", - Config: GetDefaultWorkerConfig(), - Heartbeat: make(chan struct{}), +type testVars struct { + id string + + dial chan struct{} + keepalive chan struct{} + heartbeat chan struct{} + newConn chan *transport.Connection + messages chan ReceivedMessage + + conn *transport.Connection + mockSocket *honeybeetest.MockSocket + incomingData chan honeybeetest.MockIncomingData + outgoingData chan honeybeetest.MockOutgoingData + + connPtr *atomic.Pointer[transport.Connection] +} + +func setup(t *testing.T) ( + ctx context.Context, + cancel context.CancelFunc, + vars testVars, +) { + t.Helper() + ctx, cancel = context.WithCancel(context.Background()) + conn, mockSocket, incomingData, outgoingData := setupWorkerTestConnection(t) + vars = testVars{ + id: "wss://test", + dial: make(chan struct{}, 1), + keepalive: make(chan struct{}, 1), + heartbeat: make(chan struct{}, 1), + newConn: make(chan *transport.Connection, 1), + messages: make(chan ReceivedMessage, 256), + conn: conn, + mockSocket: mockSocket, + incomingData: incomingData, + outgoingData: outgoingData, + connPtr: &atomic.Pointer[transport.Connection]{}, + } + return +} + +func expectDial(t *testing.T, dial <-chan struct{}) { + t.Helper() + honeybeetest.Eventually(t, func() bool { + select { + case <-dial: + return true + default: + return false } - dial = make(chan struct{}, 1) - keepalive = make(chan struct{}, 1) - newConn = make(chan *transport.Connection, 1) - return - } - - expectDial := func(t *testing.T, dial <-chan struct{}) { - t.Helper() - honeybeetest.Eventually(t, func() bool { - select { - case <-dial: - return true - default: - return false - } - }, "expected dial signal") - } + }, "expected dial signal") +} +func TestRunSessionDial(t *testing.T) { t.Run("fires dial immediately on entry", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn := setup(t) + ctx, cancel, v := setup(t) defer cancel() - messages := make(chan ReceivedMessage, 1) - wctx := WorkerContext{Events: make(chan PoolEvent, 10)} + pool := PoolPlugin{Events: make(chan PoolEvent, 10)} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } - go w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) + go session.Start(ctx, pool) - expectDial(t, dial) + expectDial(t, v.dial) }) t.Run("keepalive fires dial", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn := setup(t) + ctx, cancel, v := setup(t) defer cancel() - messages := make(chan ReceivedMessage, 1) - wctx := WorkerContext{Events: make(chan PoolEvent, 10)} + pool := PoolPlugin{Events: make(chan PoolEvent, 10)} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } - go w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) + go session.Start(ctx, pool) // drain initial dial - expectDial(t, dial) + expectDial(t, v.dial) - keepalive <- struct{}{} - expectDial(t, dial) + v.keepalive <- struct{}{} + expectDial(t, v.dial) }) t.Run("multiple keepalive signals each fire dial", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn := setup(t) + ctx, cancel, v := setup(t) defer cancel() - messages := make(chan ReceivedMessage, 1) - wctx := WorkerContext{Events: make(chan PoolEvent, 10)} + pool := PoolPlugin{Events: make(chan PoolEvent, 10)} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } - go w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) + go session.Start(ctx, pool) // drain initial dial - expectDial(t, dial) + expectDial(t, v.dial) for i := 0; i < 3; i++ { - keepalive <- struct{}{} - expectDial(t, dial) + v.keepalive <- struct{}{} + expectDial(t, v.dial) } }) } func TestRunSessionConnect(t *testing.T) { - setup := func(t *testing.T) ( - w *DefaultWorker, - ctx context.Context, - cancel context.CancelFunc, - dial chan struct{}, - keepalive chan struct{}, - newConn chan *transport.Connection, - messages chan ReceivedMessage, - ) { - t.Helper() - ctx, cancel = context.WithCancel(context.Background()) - w = &DefaultWorker{ - Ctx: ctx, - Cancel: cancel, - Id: "wss://test", - Config: GetDefaultWorkerConfig(), - Heartbeat: make(chan struct{}), - } - dial = make(chan struct{}, 1) - keepalive = make(chan struct{}, 1) - newConn = make(chan *transport.Connection, 1) - messages = make(chan ReceivedMessage, 256) - return - } - - t.Run("w.conn set after newConn received", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn, messages := setup(t) - wctx := WorkerContext{Events: make(chan PoolEvent, 10)} + t.Run("connection pointer set after newConn received", func(t *testing.T) { + ctx, cancel, v := setup(t) defer cancel() - conn, _, _, _ := setupWorkerTestConnection(t) - go w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) + pool := PoolPlugin{Events: make(chan PoolEvent, 10)} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } - newConn <- conn + go session.Start(ctx, pool) + + v.newConn <- v.conn honeybeetest.Eventually(t, func() bool { - return w.Conn.Load() != nil - }, "expected w.conn to be set") + return v.connPtr.Load() != nil + }, "expected connection pointer to be set") }) t.Run("EventConnected emitted", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn, messages := setup(t) - events := make(chan PoolEvent, 10) - wctx := WorkerContext{Events: events} + ctx, cancel, v := setup(t) defer cancel() - conn, _, _, _ := setupWorkerTestConnection(t) - go w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) + events := make(chan PoolEvent, 10) + pool := PoolPlugin{Events: events} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } - newConn <- conn + go session.Start(ctx, pool) + + v.newConn <- v.conn honeybeetest.Eventually(t, func() bool { select { case event := <-events: - return event.ID == w.Id && event.Kind == EventConnected + return event.ID == v.id && event.Kind == EventConnected default: return false } @@ -167,86 +203,91 @@ func TestRunSessionConnect(t *testing.T) { } func TestRunSessionDisconnect(t *testing.T) { - setup := func(t *testing.T) ( - w *DefaultWorker, - ctx context.Context, - cancel context.CancelFunc, - dial chan struct{}, - keepalive chan struct{}, - newConn chan *transport.Connection, - messages chan ReceivedMessage, - conn *transport.Connection, - incomingData chan honeybeetest.MockIncomingData, - ) { - t.Helper() - ctx, cancel = context.WithCancel(context.Background()) - w = &DefaultWorker{ - Ctx: ctx, - Cancel: cancel, - Id: "wss://test", - Config: GetDefaultWorkerConfig(), - Heartbeat: make(chan struct{}), - } - dial = make(chan struct{}, 1) - keepalive = make(chan struct{}, 1) - newConn = make(chan *transport.Connection, 1) - messages = make(chan ReceivedMessage, 256) - conn, _, incomingData, _ = setupWorkerTestConnection(t) - return - } - t.Run("EventDisconnected emitted on connection close", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn, messages, conn, incomingData := setup(t) - events := make(chan PoolEvent, 10) - wctx := WorkerContext{Events: events} + ctx, cancel, v := setup(t) defer cancel() - go w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) - newConn <- conn + events := make(chan PoolEvent, 10) + pool := PoolPlugin{Events: events} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } + + go session.Start(ctx, pool) + + v.newConn <- v.conn drainEvent(t, events, EventConnected) - close(incomingData) + close(v.incomingData) drainEvent(t, events, EventDisconnected) }) - t.Run("w.conn cleared after disconnect", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn, messages, conn, incomingData := setup(t) - events := make(chan PoolEvent, 10) - wctx := WorkerContext{Events: events} + t.Run("connection pointer cleared after disconnect", func(t *testing.T) { + ctx, cancel, v := setup(t) defer cancel() - go w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) - newConn <- conn + events := make(chan PoolEvent, 10) + pool := PoolPlugin{Events: events} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } + + go session.Start(ctx, pool) + + v.newConn <- v.conn drainEvent(t, events, EventConnected) - close(incomingData) + close(v.incomingData) drainEvent(t, events, EventDisconnected) honeybeetest.Eventually(t, func() bool { - return w.Conn.Load() == nil - }, "expected w.conn to be cleared") + return v.connPtr.Load() == nil + }, "expected connection pointer to be nil") }) t.Run("dial fires again after disconnect", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn, messages, conn, incomingData := setup(t) - events := make(chan PoolEvent, 10) - wctx := WorkerContext{Events: events} + ctx, cancel, v := setup(t) defer cancel() - go w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) - newConn <- conn + events := make(chan PoolEvent, 10) + pool := PoolPlugin{Events: events} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } + + go session.Start(ctx, pool) + + v.newConn <- v.conn drainEvent(t, events, EventConnected) // drain the initial dial signal before disconnecting - <-dial + <-v.dial - close(incomingData) + close(v.incomingData) drainEvent(t, events, EventDisconnected) honeybeetest.Eventually(t, func() bool { select { - case <-dial: + case <-v.dial: return true default: return false @@ -255,60 +296,54 @@ func TestRunSessionDisconnect(t *testing.T) { }) t.Run("second connection cycle emits EventConnected", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn, messages, conn, incomingData := setup(t) - events := make(chan PoolEvent, 10) - wctx := WorkerContext{Events: events} + ctx, cancel, v := setup(t) defer cancel() - go w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) - newConn <- conn + events := make(chan PoolEvent, 10) + pool := PoolPlugin{Events: events} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } + + go session.Start(ctx, pool) + + v.newConn <- v.conn drainEvent(t, events, EventConnected) - close(incomingData) + close(v.incomingData) drainEvent(t, events, EventDisconnected) conn2, _, _, _ := setupWorkerTestConnection(t) - newConn <- conn2 - + v.newConn <- conn2 drainEvent(t, events, EventConnected) }) } func TestRunSessionCancellation(t *testing.T) { - setup := func(t *testing.T) ( - w *DefaultWorker, - ctx context.Context, - cancel context.CancelFunc, - dial chan struct{}, - keepalive chan struct{}, - newConn chan *transport.Connection, - messages chan ReceivedMessage, - ) { - t.Helper() - ctx, cancel = context.WithCancel(context.Background()) - w = &DefaultWorker{ - Ctx: ctx, - Cancel: cancel, - Id: "wss://test", - Config: GetDefaultWorkerConfig(), - Heartbeat: make(chan struct{}), - } - dial = make(chan struct{}, 1) - keepalive = make(chan struct{}, 1) - newConn = make(chan *transport.Connection, 1) - messages = make(chan ReceivedMessage, 256) - return - } - t.Run("ctx cancelled pre-connection exits without emitting events", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn, messages := setup(t) + ctx, cancel, v := setup(t) events := make(chan PoolEvent, 10) - wctx := WorkerContext{Events: events} + pool := PoolPlugin{Events: events} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } done := make(chan struct{}) go func() { defer close(done) - w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) + session.Start(ctx, pool) }() cancel() @@ -333,24 +368,29 @@ func TestRunSessionCancellation(t *testing.T) { }) t.Run("ctx cancelled post-connection emits EventDisconnected", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn, messages := setup(t) + ctx, cancel, v := setup(t) events := make(chan PoolEvent, 10) - wctx := WorkerContext{Events: events} - - conn, _, _, _ := setupWorkerTestConnection(t) + pool := PoolPlugin{Events: events} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } done := make(chan struct{}) go func() { defer close(done) - w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) + session.Start(ctx, pool) }() - newConn <- conn - + v.newConn <- v.conn drainEvent(t, events, EventConnected) cancel() - drainEvent(t, events, EventDisconnected) honeybeetest.Eventually(t, func() bool { @@ -363,29 +403,34 @@ func TestRunSessionCancellation(t *testing.T) { }, "expected runSession to exit") }) - t.Run("ctx cancelled post-connection clears w.conn", func(t *testing.T) { - w, ctx, cancel, dial, keepalive, newConn, messages := setup(t) + t.Run("ctx cancelled post-connection clears connection pointer", func(t *testing.T) { + ctx, cancel, v := setup(t) events := make(chan PoolEvent, 10) - wctx := WorkerContext{Events: events} - - conn, _, _, _ := setupWorkerTestConnection(t) + pool := PoolPlugin{Events: events} + session := &Session{ + id: v.id, + connPtr: v.connPtr, + messages: v.messages, + heartbeat: v.heartbeat, + dial: v.dial, + keepalive: v.keepalive, + newConn: v.newConn, + } done := make(chan struct{}) go func() { defer close(done) - w.RunSession(ctx, wctx, messages, dial, keepalive, newConn) + session.Start(ctx, pool) }() - newConn <- conn - + v.newConn <- v.conn drainEvent(t, events, EventConnected) cancel() - drainEvent(t, events, EventDisconnected) honeybeetest.Eventually(t, func() bool { - return w.Conn.Load() == nil - }, "expected w.conn to clear") + return v.connPtr.Load() == nil + }, "expected connection pointer to be nil") }) } diff --git a/initiatorpool/worker_start_test.go b/initiatorpool/worker_start_test.go index 7a82f39..6e748f1 100644 --- a/initiatorpool/worker_start_test.go +++ b/initiatorpool/worker_start_test.go @@ -17,13 +17,13 @@ func makeWorkerContext(t *testing.T) ( inbox chan InboxMessage, events chan PoolEvent, errors chan error, - wctx WorkerContext, + pool PoolPlugin, ) { t.Helper() inbox = make(chan InboxMessage, 256) events = make(chan PoolEvent, 10) errors = make(chan error, 10) - wctx = WorkerContext{ + pool = PoolPlugin{ Inbox: inbox, Events: events, Errors: errors, @@ -56,13 +56,13 @@ func TestWorkerStart(t *testing.T) { defer cancel() w := makeWorker(t, ctx, cancel) - _, events, _, wctx := makeWorkerContext(t) + _, events, _, pool := makeWorkerContext(t) mockSocket := honeybeetest.NewMockSocket() - wctx.Dialer = mockDialer(mockSocket) + pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup wg.Add(1) - go w.Start(wctx, &wg) + go w.Start(pool, &wg) honeybeetest.Eventually(t, func() bool { select { @@ -79,13 +79,13 @@ func TestWorkerStart(t *testing.T) { defer cancel() w := makeWorker(t, ctx, cancel) - _, events, _, wctx := makeWorkerContext(t) + _, events, _, pool := makeWorkerContext(t) _, mockSocket, _, outgoingData := setupWorkerTestConnection(t) - wctx.Dialer = mockDialer(mockSocket) + pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup wg.Add(1) - go w.Start(wctx, &wg) + go w.Start(pool, &wg) honeybeetest.Eventually(t, func() bool { select { @@ -114,7 +114,7 @@ func TestWorkerStart(t *testing.T) { defer cancel() w := makeWorker(t, ctx, cancel) - inbox, events, _, wctx := makeWorkerContext(t) + inbox, events, _, pool := makeWorkerContext(t) incomingData := make(chan honeybeetest.MockIncomingData, 10) mockSocket := honeybeetest.NewMockSocket() @@ -131,11 +131,11 @@ func TestWorkerStart(t *testing.T) { } } - wctx.Dialer = mockDialer(mockSocket) + pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup wg.Add(1) - go w.Start(wctx, &wg) + go w.Start(pool, &wg) honeybeetest.Eventually(t, func() bool { select { @@ -166,13 +166,13 @@ func TestWorkerStart(t *testing.T) { defer cancel() w := makeWorker(t, ctx, cancel) - _, events, _, wctx := makeWorkerContext(t) + _, events, _, pool := makeWorkerContext(t) _, mockSocket, incomingData, _ := setupWorkerTestConnection(t) - wctx.Dialer = mockDialer(mockSocket) + pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup wg.Add(1) - go w.Start(wctx, &wg) + go w.Start(pool, &wg) honeybeetest.Eventually(t, func() bool { select { @@ -209,13 +209,13 @@ func TestWorkerStart(t *testing.T) { defer cancel() w := makeWorker(t, ctx, cancel) - _, events, _, wctx := makeWorkerContext(t) + _, events, _, pool := makeWorkerContext(t) mockSocket := honeybeetest.NewMockSocket() - wctx.Dialer = mockDialer(mockSocket) + pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup wg.Add(1) - go w.Start(wctx, &wg) + go w.Start(pool, &wg) honeybeetest.Eventually(t, func() bool { select { @@ -254,13 +254,13 @@ func TestWorkerStart(t *testing.T) { workerCtx, workerCancel := context.WithCancel(parentCtx) w := makeWorker(t, workerCtx, workerCancel) - _, events, _, wctx := makeWorkerContext(t) + _, events, _, pool := makeWorkerContext(t) mockSocket := honeybeetest.NewMockSocket() - wctx.Dialer = mockDialer(mockSocket) + pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup wg.Add(1) - go w.Start(wctx, &wg) + go w.Start(pool, &wg) honeybeetest.Eventually(t, func() bool { select { @@ -292,9 +292,9 @@ func TestWorkerStart(t *testing.T) { defer cancel() w := makeWorker(t, ctx, cancel) - _, _, errors, wctx := makeWorkerContext(t) - wctx.ConnectionConfig = &transport.ConnectionConfig{Retry: nil} - wctx.Dialer = &honeybeetest.MockDialer{ + _, _, errors, pool := makeWorkerContext(t) + pool.ConnectionConfig = &transport.ConnectionConfig{Retry: nil} + pool.Dialer = &honeybeetest.MockDialer{ DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return nil, nil, fmt.Errorf("dial failed") }, @@ -302,7 +302,7 @@ func TestWorkerStart(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - go w.Start(wctx, &wg) + go w.Start(pool, &wg) honeybeetest.Eventually(t, func() bool { select { diff --git a/responderpool/pool.go b/responderpool/pool.go index 72faedc..2201002 100644 --- a/responderpool/pool.go +++ b/responderpool/pool.go @@ -9,6 +9,7 @@ type PoolEventKind string const ( EventPeerDisconnected PoolEventKind = "disconnected" EventPeerDropped PoolEventKind = "dropped" + EventPeerInactive PoolEventKind = "inactive" EventPeerEvicted PoolEventKind = "evicted" ) diff --git a/responderpool/worker.go b/responderpool/worker.go index bb8016a..a6d8990 100644 --- a/responderpool/worker.go +++ b/responderpool/worker.go @@ -8,7 +8,7 @@ import ( "time" ) -type onExitFunc func(id string, kind PoolEventKind) +type onEventFunc func(kind PoolEventKind) type ReceivedMessage struct { data []byte @@ -17,11 +17,11 @@ type ReceivedMessage struct { func RunReader( ctx context.Context, - id string, + onPeerClose onEventFunc, + conn *transport.Connection, messages chan<- ReceivedMessage, heartbeat chan<- struct{}, - onPeerClose onExitFunc, ) { for { select { @@ -40,7 +40,7 @@ func RunReader( default: } - onPeerClose(id, kind) + onPeerClose(kind) return } @@ -56,8 +56,8 @@ func RunReader( } func RunForwarder( - ctx context.Context, id string, + ctx context.Context, messages <-chan ReceivedMessage, inbox chan<- InboxMessage, maxQueueSize int, @@ -101,10 +101,9 @@ func RunForwarder( func RunWatchdog( ctx context.Context, - id string, - timeout time.Duration, + onTimeout onEventFunc, heartbeat <-chan struct{}, - onTimeout onExitFunc, + timeout time.Duration, ) { // disable watchdog timeout if not configured if timeout <= 0 { @@ -133,8 +132,8 @@ func RunWatchdog( timer.Reset(timeout) // timer completed case <-timer.C: - // evict inactive peer - onTimeout(id, EventPeerEvicted) + // signal peer is inactive + onTimeout(EventPeerInactive) return } }