From b4c5c897e86c6dbff1025131378a44aa6c1aea4c Mon Sep 17 00:00:00 2001 From: Jay Date: Sat, 18 Apr 2026 17:11:22 -0400 Subject: [PATCH] Injected context cancellation for dial and retry cancellation. --- honeybeetest/mocks.go | 10 ++- initiatorpool/config.go | 3 +- initiatorpool/pool.go | 38 ++++----- initiatorpool/pool_test.go | 28 ++++--- initiatorpool/worker.go | 79 +++++++++--------- initiatorpool/worker_test.go | 157 ++++++++++------------------------- transport/connection.go | 14 ++-- transport/connection_test.go | 25 +++--- transport/logging_test.go | 17 ++-- transport/socket.go | 22 +++-- transport/socket_test.go | 13 ++- types/types.go | 6 +- 12 files changed, 182 insertions(+), 230 deletions(-) diff --git a/honeybeetest/mocks.go b/honeybeetest/mocks.go index 51a92c8..dc7af91 100644 --- a/honeybeetest/mocks.go +++ b/honeybeetest/mocks.go @@ -12,11 +12,15 @@ import ( // Dialer Mocks type MockDialer struct { - DialFunc func(string, http.Header) (types.Socket, *http.Response, error) + DialContextFunc func( + context.Context, string, http.Header, + ) (types.Socket, *http.Response, error) } -func (m *MockDialer) Dial(url string, h http.Header) (types.Socket, *http.Response, error) { - return m.DialFunc(url, h) +func (m *MockDialer) DialContext( + ctx context.Context, url string, h http.Header, +) (types.Socket, *http.Response, error) { + return m.DialContextFunc(ctx, url, h) } // Socket Mocks diff --git a/initiatorpool/config.go b/initiatorpool/config.go index b2ae5a2..2c7f038 100644 --- a/initiatorpool/config.go +++ b/initiatorpool/config.go @@ -1,13 +1,14 @@ package initiatorpool import ( + "context" "git.wisehodl.dev/jay/go-honeybee/transport" "time" ) // Types -type WorkerFactory func(id string, stop <-chan struct{}) (*Worker, error) +type WorkerFactory func(ctx context.Context, id string) (*Worker, error) // Pool Config diff --git a/initiatorpool/pool.go b/initiatorpool/pool.go index d7607d1..faecbc6 100644 --- a/initiatorpool/pool.go +++ b/initiatorpool/pool.go @@ -1,6 +1,7 @@ package initiatorpool import ( + "context" "git.wisehodl.dev/jay/go-honeybee/transport" "git.wisehodl.dev/jay/go-honeybee/types" "log/slog" @@ -13,14 +14,12 @@ import ( type Peer struct { id string worker *Worker - stop chan struct{} } type WorkerContext struct { Inbox chan<- InboxMessage Events chan<- PoolEvent Errors chan<- error - PoolDone <-chan struct{} Logger *slog.Logger Dialer types.Dialer ConnectionConfig *transport.ConnectionConfig @@ -47,11 +46,13 @@ type PoolEvent struct { // Pool type Pool struct { + ctx context.Context + cancel context.CancelFunc + peers map[string]*Peer inbox chan InboxMessage events chan PoolEvent errors chan error - done chan struct{} dialer types.Dialer config *PoolConfig @@ -62,7 +63,8 @@ type Pool struct { closed bool } -func NewPool(config *PoolConfig, logger *slog.Logger) (*Pool, error) { +func NewPool(ctx context.Context, config *PoolConfig, logger *slog.Logger, +) (*Pool, error) { if config == nil { config = GetDefaultPoolConfig() } @@ -71,8 +73,9 @@ func NewPool(config *PoolConfig, logger *slog.Logger) (*Pool, error) { // The factory function should be non-blocking or else Connect() may cause // deadlocks. if config.WorkerFactory == nil { - config.WorkerFactory = func(id string, stop <-chan struct{}) (*Worker, error) { - return NewWorker(id, stop, config.WorkerConfig) + config.WorkerFactory = func( + ctx context.Context, id string) (*Worker, error) { + return NewWorker(ctx, id, config.WorkerConfig) } } @@ -80,12 +83,15 @@ func NewPool(config *PoolConfig, logger *slog.Logger) (*Pool, error) { return nil, err } + pctx, cancel := context.WithCancel(ctx) + p := &Pool{ + ctx: pctx, + cancel: cancel, peers: make(map[string]*Peer), inbox: make(chan InboxMessage, 256), events: make(chan PoolEvent, 10), errors: make(chan error, 10), - done: make(chan struct{}), dialer: transport.NewDialer(), config: config, logger: logger, @@ -125,17 +131,12 @@ func (p *Pool) Close() { } p.closed = true - close(p.done) + p.cancel() - peers := p.peers p.peers = make(map[string]*Peer) p.mu.Unlock() - for _, p := range peers { - close(p.stop) - } - go func() { p.wg.Wait() close(p.inbox) @@ -162,13 +163,9 @@ func (p *Pool) Connect(id string) error { return NewPoolError("connection already exists") } - // Create new worker - stop := make(chan struct{}) - // The worker factory must be non-blocking to avoid deadlocks - worker, err := p.config.WorkerFactory(id, stop) + worker, err := p.config.WorkerFactory(p.ctx, id) if err != nil { - close(stop) return err } @@ -180,7 +177,6 @@ func (p *Pool) Connect(id string) error { Inbox: p.inbox, Events: p.events, Errors: p.errors, - PoolDone: p.done, Logger: logger, Dialer: p.dialer, ConnectionConfig: p.config.ConnectionConfig, @@ -189,7 +185,7 @@ func (p *Pool) Connect(id string) error { p.wg.Add(1) go worker.Start(ctx, &p.wg) - p.peers[id] = &Peer{id: id, worker: worker, stop: stop} + p.peers[id] = &Peer{id: id, worker: worker} return nil } @@ -214,7 +210,7 @@ func (p *Pool) Remove(id string) error { delete(p.peers, id) p.mu.Unlock() - close(peer.stop) + peer.worker.Stop() return nil } diff --git a/initiatorpool/pool_test.go b/initiatorpool/pool_test.go index 7f5aac1..ccc62fa 100644 --- a/initiatorpool/pool_test.go +++ b/initiatorpool/pool_test.go @@ -1,6 +1,7 @@ package initiatorpool import ( + "context" "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/transport" @@ -17,12 +18,12 @@ func _TestPoolConnect(t *testing.T) { t.Run("successfully adds connection", func(t *testing.T) { mockSocket := honeybeetest.NewMockSocket() mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewPool(nil, nil) + pool, err := NewPool(context.Background(), nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -48,12 +49,12 @@ func _TestPoolConnect(t *testing.T) { t.Run("does not add duplicate", func(t *testing.T) { mockSocket := honeybeetest.NewMockSocket() mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewPool(nil, nil) + pool, err := NewPool(context.Background(), nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -74,6 +75,7 @@ func _TestPoolConnect(t *testing.T) { t.Run("fails to add connection", func(t *testing.T) { pool, err := NewPool( + context.Background(), &PoolConfig{ ConnectionConfig: &transport.ConnectionConfig{ Retry: &transport.RetryConfig{ @@ -84,7 +86,7 @@ func _TestPoolConnect(t *testing.T) { }, nil) assert.NoError(t, err) pool.dialer = &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return nil, nil, fmt.Errorf("dial failed") }, } @@ -111,12 +113,12 @@ func _TestPoolRemove(t *testing.T) { t.Run("removes known url", func(t *testing.T) { mockSocket := honeybeetest.NewMockSocket() mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewPool(nil, nil) + pool, err := NewPool(context.Background(), nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -139,12 +141,12 @@ func _TestPoolRemove(t *testing.T) { t.Run("unknown url returns error", func(t *testing.T) { mockSocket := honeybeetest.NewMockSocket() mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewPool(nil, nil) + pool, err := NewPool(context.Background(), nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -156,12 +158,12 @@ func _TestPoolRemove(t *testing.T) { t.Run("closed pool returns error", func(t *testing.T) { mockSocket := honeybeetest.NewMockSocket() mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewPool(nil, nil) + pool, err := NewPool(context.Background(), nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -184,12 +186,12 @@ func _TestPoolSend(t *testing.T) { return nil } mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewPool(nil, nil) + pool, err := NewPool(context.Background(), nil, nil) assert.NoError(t, err) pool.dialer = mockDialer diff --git a/initiatorpool/worker.go b/initiatorpool/worker.go index 81bbecf..ed495df 100644 --- a/initiatorpool/worker.go +++ b/initiatorpool/worker.go @@ -2,6 +2,7 @@ package initiatorpool import ( "container/list" + "context" "git.wisehodl.dev/jay/go-honeybee/transport" "sync" "time" @@ -15,15 +16,16 @@ type receivedMessage struct { } type Worker struct { + ctx context.Context + cancel context.CancelFunc id string - stop <-chan struct{} config *WorkerConfig outbound chan []byte } func NewWorker( + ctx context.Context, id string, - stop <-chan struct{}, config *WorkerConfig, ) (*Worker, error) { @@ -36,9 +38,11 @@ func NewWorker( return nil, err } + wctx, cancel := context.WithCancel(ctx) w := &Worker{ + ctx: wctx, + cancel: cancel, id: id, - stop: stop, outbound: make(chan []byte, 64), config: config, } @@ -50,7 +54,7 @@ func (w *Worker) Send(data []byte) error { select { case w.outbound <- data: return nil - case <-w.stop: + case <-w.ctx.Done(): return NewWorkerError(w.id, "worker is stopped") default: return NewWorkerError(w.id, "outbound queue full") @@ -63,7 +67,14 @@ func (w *Worker) Start( ) { } +func (w *Worker) Stop() { + w.cancel() +} + func (w *Worker) runSession( + ctx context.Context, + wctx WorkerContext, + messages chan<- receivedMessage, heartbeat chan<- struct{}, dial chan<- struct{}, @@ -71,11 +82,6 @@ func (w *Worker) runSession( keepalive <-chan struct{}, outbound <-chan []byte, newConn <-chan *transport.Connection, - - ctx WorkerContext, - - workerStop <-chan struct{}, - poolDone <-chan struct{}, ) { } @@ -98,20 +104,18 @@ func (w *Worker) runWriter( } func (w *Worker) runStopMonitor( + ctx context.Context, conn *transport.Connection, keepalive <-chan struct{}, - workerStop <-chan struct{}, - poolDone <-chan struct{}, sessionDone <-chan struct{}, onStop func(), ) { } func (w *Worker) runForwarder( + ctx context.Context, messages <-chan receivedMessage, inbox chan<- InboxMessage, - stop <-chan struct{}, - poolDone <-chan struct{}, maxQueueSize int, ) { queue := list.New() @@ -129,9 +133,7 @@ func (w *Worker) runForwarder( } select { - case <-stop: - return - case <-poolDone: + case <-ctx.Done(): return case msg := <-messages: // limit queue size if maximum is configured @@ -154,17 +156,15 @@ func (w *Worker) runForwarder( } func (w *Worker) runKeepalive( + ctx context.Context, heartbeat <-chan struct{}, keepalive chan<- struct{}, - stop <-chan struct{}, - poolDone <-chan struct{}, ) { // disable keepalive timeout if not configured if w.config.KeepaliveTimeout <= 0 { - // wait for stop signal and exit + // wait for cancel and exit select { - case <-stop: - case <-poolDone: + case <-ctx.Done(): } return } @@ -174,9 +174,7 @@ func (w *Worker) runKeepalive( for { select { - case <-stop: - return - case <-poolDone: + case <-ctx.Done(): return case <-heartbeat: // drain the timer channel and reset @@ -199,28 +197,29 @@ func (w *Worker) runKeepalive( } } -func (w *Worker) dial(ctx WorkerContext) (*transport.Connection, error) { - conn, err := transport.NewConnection(w.id, ctx.ConnectionConfig, ctx.Logger) +func (w *Worker) dial( + ctx context.Context, + wctx WorkerContext, +) (*transport.Connection, error) { + conn, err := transport.NewConnection(w.id, wctx.ConnectionConfig, wctx.Logger) if err != nil { return nil, err } - conn.SetDialer(ctx.Dialer) - return conn, conn.Connect() + conn.SetDialer(wctx.Dialer) + return conn, conn.Connect(ctx) } func (w *Worker) runDialer( + ctx context.Context, + wctx WorkerContext, + dial <-chan struct{}, newConn chan<- *transport.Connection, - ctx WorkerContext, - stop <-chan struct{}, - poolDone <-chan struct{}, ) { for { select { - case <-stop: - return - case <-poolDone: + case <-ctx.Done(): return case <-dial: // drain dial signals while connection is being established @@ -236,15 +235,14 @@ func (w *Worker) runDialer( }() // dial a new connection - conn, err := w.dial(ctx) + conn, err := w.dial(ctx, wctx) close(done) // send error if dial failed and continue if err != nil { select { - case ctx.Errors <- err: - case <-stop: - case <-poolDone: + case wctx.Errors <- err: + case <-ctx.Done(): } continue } @@ -252,10 +250,7 @@ func (w *Worker) runDialer( // send the new connection or close and exit select { case newConn <- conn: - case <-stop: - conn.Close() - return - case <-poolDone: + case <-ctx.Done(): conn.Close() return } diff --git a/initiatorpool/worker_test.go b/initiatorpool/worker_test.go index c563394..ad31197 100644 --- a/initiatorpool/worker_test.go +++ b/initiatorpool/worker_test.go @@ -1,6 +1,7 @@ package initiatorpool import ( + "context" "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/transport" @@ -18,11 +19,11 @@ func TestRunForwarder(t *testing.T) { t.Run("message passes through to inbox", func(t *testing.T) { messages := make(chan receivedMessage, 1) inbox := make(chan InboxMessage, 1) - stop := make(chan struct{}) - defer close(stop) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() w := &Worker{id: "wss://test"} - go w.runForwarder(messages, inbox, stop, nil, 0) + go w.runForwarder(ctx, messages, inbox, 0) messages <- receivedMessage{data: []byte("hello"), receivedAt: time.Now()} @@ -39,8 +40,8 @@ func TestRunForwarder(t *testing.T) { t.Run("oldest message dropped when queue is full", func(t *testing.T) { messages := make(chan receivedMessage, 1) inbox := make(chan InboxMessage, 1) - stop := make(chan struct{}) - defer close(stop) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() gate := make(chan struct{}) gatedInbox := make(chan InboxMessage) @@ -54,7 +55,7 @@ func TestRunForwarder(t *testing.T) { }() w := &Worker{id: "wss://test"} - go w.runForwarder(messages, gatedInbox, stop, nil, 2) + go w.runForwarder(ctx, messages, gatedInbox, 2) // send three messages while the gated inbox is blocked messages <- receivedMessage{data: []byte("first"), receivedAt: time.Now()} @@ -83,42 +84,20 @@ func TestRunForwarder(t *testing.T) { }) - t.Run("exits on stop", func(t *testing.T) { + t.Run("exits on context cancellation", func(t *testing.T) { messages := make(chan receivedMessage, 1) inbox := make(chan InboxMessage, 1) - stop := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() w := &Worker{id: "wss://test"} done := make(chan struct{}) go func() { - w.runForwarder(messages, inbox, stop, nil, 0) + w.runForwarder(ctx, messages, inbox, 0) close(done) }() - close(stop) - assert.Eventually(t, func() bool { - select { - case <-done: - return true - default: - return false - } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) - }) - - t.Run("exits on pool done", func(t *testing.T) { - messages := make(chan receivedMessage, 1) - inbox := make(chan InboxMessage, 1) - poolDone := make(chan struct{}) - - w := &Worker{id: "wss://test"} - done := make(chan struct{}) - go func() { - w.runForwarder(messages, inbox, nil, poolDone, 0) - close(done) - }() - - close(poolDone) + cancel() assert.Eventually(t, func() bool { select { case <-done: @@ -134,11 +113,11 @@ func TestRunKeepalive(t *testing.T) { t.Run("heartbeat resets timer, no keepalive signal fired", func(t *testing.T) { heartbeat := make(chan struct{}, 3) keepalive := make(chan struct{}, 1) - stop := make(chan struct{}) - defer close(stop) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 100 * time.Millisecond}} - go w.runKeepalive(heartbeat, keepalive, stop, nil) + go w.runKeepalive(ctx, heartbeat, keepalive) // send heartbeats faster than the timeout for i := 0; i < 5; i++ { @@ -160,11 +139,11 @@ func TestRunKeepalive(t *testing.T) { t.Run("keepalive timeout fires signal", func(t *testing.T) { heartbeat := make(chan struct{}) keepalive := make(chan struct{}, 1) - stop := make(chan struct{}) - defer close(stop) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 20 * time.Millisecond}} - go w.runKeepalive(heartbeat, keepalive, stop, nil) + go w.runKeepalive(ctx, heartbeat, keepalive) // send no heartbeats, wait for timeout and keepalive signal assert.Eventually(t, func() bool { @@ -177,42 +156,19 @@ func TestRunKeepalive(t *testing.T) { }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) - t.Run("exits on stop", func(t *testing.T) { + t.Run("exits on context cancellation", func(t *testing.T) { heartbeat := make(chan struct{}) keepalive := make(chan struct{}, 1) - stop := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 20 * time.Second}} done := make(chan struct{}) go func() { - w.runKeepalive(heartbeat, keepalive, stop, nil) + w.runKeepalive(ctx, heartbeat, keepalive) close(done) }() - close(stop) - assert.Eventually(t, func() bool { - select { - case <-done: - return true - default: - return false - } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) - }) - - t.Run("exits on stop", func(t *testing.T) { - heartbeat := make(chan struct{}) - keepalive := make(chan struct{}, 1) - poolDone := make(chan struct{}) - - w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 20 * time.Second}} - done := make(chan struct{}) - go func() { - w.runKeepalive(heartbeat, keepalive, nil, poolDone) - close(done) - }() - - close(poolDone) + cancel() assert.Eventually(t, func() bool { select { case <-done: @@ -229,20 +185,20 @@ func TestRunDialer(t *testing.T) { w := &Worker{id: "wss://test"} dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) - stop := make(chan struct{}) - defer close(stop) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() mockSocket := honeybeetest.NewMockSocket() - ctx := WorkerContext{ + wctx := WorkerContext{ Errors: make(chan error, 1), Dialer: &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, }, } - go w.runDialer(dial, newConn, ctx, stop, nil) + go w.runDialer(ctx, wctx, dial, newConn) dial <- struct{}{} assert.Eventually(t, func() bool { @@ -260,18 +216,18 @@ func TestRunDialer(t *testing.T) { w := &Worker{id: "wss://test"} dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) - stop := make(chan struct{}) - defer close(stop) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() gate := make(chan struct{}) dialCount := atomic.Int32{} mockSocket := honeybeetest.NewMockSocket() connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry - ctx := WorkerContext{ + wctx := WorkerContext{ Errors: make(chan error, 1), Dialer: &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { dialCount.Add(1) <-gate return mockSocket, nil, nil @@ -280,7 +236,7 @@ func TestRunDialer(t *testing.T) { ConnectionConfig: connConfig, } - go w.runDialer(dial, newConn, ctx, stop, nil) + go w.runDialer(ctx, wctx, dial, newConn) dial <- struct{}{} // wait for dial to start blocking on gate @@ -322,17 +278,19 @@ func TestRunDialer(t *testing.T) { errors := make(chan error, 1) dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) - stop := make(chan struct{}) - defer close(stop) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // use atomic counter to fail first dial and pass second dialCount := atomic.Int32{} mockSocket := honeybeetest.NewMockSocket() connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry - ctx := WorkerContext{ + wctx := WorkerContext{ Errors: errors, Dialer: &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func( + context.Context, string, http.Header, + ) (types.Socket, *http.Response, error) { if dialCount.Add(1) == 1 { // fail first return nil, nil, fmt.Errorf("dial failed") @@ -344,7 +302,7 @@ func TestRunDialer(t *testing.T) { ConnectionConfig: connConfig, } - go w.runDialer(dial, newConn, ctx, stop, nil) + go w.runDialer(ctx, wctx, dial, newConn) dial <- struct{}{} assert.Eventually(t, func() bool { @@ -368,21 +326,21 @@ func TestRunDialer(t *testing.T) { }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) - t.Run("exits on stop", func(t *testing.T) { + t.Run("exits on context cancellation", func(t *testing.T) { w := &Worker{id: "wss://test"} dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) - stop := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) - ctx := WorkerContext{Errors: make(chan error, 1)} + wctx := WorkerContext{Errors: make(chan error, 1)} done := make(chan struct{}) go func() { - w.runDialer(dial, newConn, ctx, stop, nil) + w.runDialer(ctx, wctx, dial, newConn) close(done) }() - close(stop) + cancel() assert.Eventually(t, func() bool { select { @@ -393,31 +351,4 @@ func TestRunDialer(t *testing.T) { } }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) - - t.Run("exits on pool done", func(t *testing.T) { - w := &Worker{id: "wss://test"} - dial := make(chan struct{}, 1) - newConn := make(chan *transport.Connection, 1) - poolDone := make(chan struct{}) - - ctx := WorkerContext{Errors: make(chan error, 1)} - - done := make(chan struct{}) - go func() { - w.runDialer(dial, newConn, ctx, nil, poolDone) - close(done) - }() - - close(poolDone) - - assert.Eventually(t, func() bool { - select { - case <-done: - return true - default: - return false - } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) - }) - } diff --git a/transport/connection.go b/transport/connection.go index 369a7d4..4867ad2 100644 --- a/transport/connection.go +++ b/transport/connection.go @@ -1,6 +1,7 @@ package transport import ( + "context" "errors" "fmt" "log/slog" @@ -64,13 +65,13 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) return nil, err } - parsedURL, err := ParseURL(urlStr) + url, err := ParseURL(urlStr) if err != nil { return nil, err } conn := &Connection{ - url: parsedURL, + url: url, dialer: NewDialer(), socket: nil, config: config, @@ -85,7 +86,9 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) return conn, nil } -func NewConnectionFromSocket(socket types.Socket, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) { +func NewConnectionFromSocket( + socket types.Socket, config *ConnectionConfig, logger *slog.Logger, +) (*Connection, error) { if socket == nil { return nil, NewConnectionError("socket cannot be nil") } @@ -121,7 +124,7 @@ func NewConnectionFromSocket(socket types.Socket, config *ConnectionConfig, logg return conn, nil } -func (c *Connection) Connect() error { +func (c *Connection) Connect(ctx context.Context) error { c.mu.Lock() defer c.mu.Unlock() @@ -140,7 +143,8 @@ func (c *Connection) Connect() error { c.state = StateConnecting retryMgr := NewRetryManager(c.config.Retry) - socket, _, err := AcquireSocket(retryMgr, c.dialer, c.url.String(), c.logger) + socket, _, err := AcquireSocket( + ctx, retryMgr, c.dialer, c.url.String(), c.logger) if err != nil { c.state = StateDisconnected diff --git a/transport/connection_test.go b/transport/connection_test.go index eb30c62..e3c06e5 100644 --- a/transport/connection_test.go +++ b/transport/connection_test.go @@ -2,6 +2,7 @@ package transport import ( "bytes" + "context" "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/types" @@ -239,7 +240,7 @@ func TestConnect(t *testing.T) { conn.socket = honeybeetest.NewMockSocket() - err = conn.Connect() + err = conn.Connect(context.Background()) assert.Error(t, err) assert.ErrorContains(t, err, "already has socket") assert.Equal(t, StateDisconnected, conn.State()) @@ -251,7 +252,7 @@ func TestConnect(t *testing.T) { conn.Close() - err = conn.Connect() + err = conn.Connect(context.Background()) assert.Error(t, err) assert.ErrorContains(t, err, "connection is closed") assert.Equal(t, StateClosed, conn.State()) @@ -270,13 +271,13 @@ func TestConnect(t *testing.T) { } mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } conn.dialer = mockDialer - err = conn.Connect() + err = conn.Connect(context.Background()) assert.NoError(t, err) assert.Equal(t, StateConnected, conn.State()) @@ -309,7 +310,7 @@ func TestConnect(t *testing.T) { attemptCount := 0 mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { attemptCount++ if attemptCount < 3 { return nil, nil, fmt.Errorf("dial failed") @@ -319,7 +320,7 @@ func TestConnect(t *testing.T) { } conn.dialer = mockDialer - err = conn.Connect() + err = conn.Connect(context.Background()) assert.NoError(t, err) assert.Equal(t, 3, attemptCount) assert.Equal(t, StateConnected, conn.State()) @@ -340,13 +341,13 @@ func TestConnect(t *testing.T) { assert.NoError(t, err) mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return nil, nil, fmt.Errorf("dial failed") }, } conn.dialer = mockDialer - err = conn.Connect() + err = conn.Connect(context.Background()) assert.Error(t, err) assert.ErrorContains(t, err, "dial failed") assert.Equal(t, StateDisconnected, conn.State()) @@ -359,14 +360,14 @@ func TestConnect(t *testing.T) { stateDuringDial := StateDisconnected mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { stateDuringDial = conn.state return honeybeetest.NewMockSocket(), nil, nil }, } conn.dialer = mockDialer - conn.Connect() + conn.Connect(context.Background()) assert.Equal(t, StateConnecting, stateDuringDial) assert.Equal(t, StateConnected, conn.State()) @@ -390,13 +391,13 @@ func TestConnect(t *testing.T) { } mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } conn.dialer = mockDialer - conn.Connect() + conn.Connect(context.Background()) assert.True(t, handlerSet, "close handler should be set on socket") diff --git a/transport/logging_test.go b/transport/logging_test.go index ee42381..e55f926 100644 --- a/transport/logging_test.go +++ b/transport/logging_test.go @@ -1,6 +1,7 @@ package transport import ( + "context" "fmt" "io" "log/slog" @@ -146,13 +147,13 @@ func TestConnectLogging(t *testing.T) { mockSocket := honeybeetest.NewMockSocket() mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } conn.dialer = mockDialer - err = conn.Connect() + err = conn.Connect(context.Background()) assert.NoError(t, err) defer conn.Close() @@ -186,13 +187,13 @@ func TestConnectLogging(t *testing.T) { dialErr := fmt.Errorf("dial error") mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return nil, nil, dialErr }, } conn.dialer = mockDialer - err = conn.Connect() + err = conn.Connect(context.Background()) assert.Error(t, err) records := mockHandler.GetRecords() @@ -230,7 +231,7 @@ func TestConnectLogging(t *testing.T) { attemptCount := 0 dialErr := fmt.Errorf("dial error") mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { attemptCount++ if attemptCount < 3 { return nil, nil, dialErr @@ -240,7 +241,7 @@ func TestConnectLogging(t *testing.T) { } conn.dialer = mockDialer - err = conn.Connect() + err = conn.Connect(context.Background()) assert.NoError(t, err) defer conn.Close() @@ -468,13 +469,13 @@ func TestLoggingDisabled(t *testing.T) { mockSocket := honeybeetest.NewMockSocket() mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } conn.dialer = mockDialer - err = conn.Connect() + err = conn.Connect(context.Background()) assert.NoError(t, err) conn.Close() diff --git a/transport/socket.go b/transport/socket.go index 42ffbe9..ccea07a 100644 --- a/transport/socket.go +++ b/transport/socket.go @@ -1,6 +1,7 @@ package transport import ( + "context" "log/slog" "net/http" "time" @@ -28,19 +29,22 @@ func NewGorillaDialer() *GorillaDialer { } // Returns the Socket interface -func (d *GorillaDialer) Dial( - urlStr string, requestHeader http.Header, +func (d *GorillaDialer) DialContext( + ctx context.Context, + url string, + header http.Header, ) ( types.Socket, *http.Response, error, ) { - conn, resp, err := d.Dialer.Dial(urlStr, requestHeader) + conn, resp, err := d.Dialer.DialContext(ctx, url, header) return conn, resp, err } func AcquireSocket( + ctx context.Context, retryMgr *RetryManager, dialer types.Dialer, - urlStr string, + url string, logger *slog.Logger, ) (types.Socket, *http.Response, error) { if retryMgr == nil { @@ -49,7 +53,7 @@ func AcquireSocket( if dialer == nil { return nil, nil, NewConnectionError("dialer cannot be nil") } - if urlStr == "" { + if url == "" { return nil, nil, NewConnectionError("URL cannot be empty") } @@ -58,7 +62,7 @@ func AcquireSocket( logger.Info("dialing", "attempt", retryMgr.RetryCount()+1) } - socket, resp, err := dialer.Dial(urlStr, nil) + socket, resp, err := dialer.DialContext(ctx, url, nil) if err == nil { if logger != nil { logger.Info("dial successful", "attempt", retryMgr.RetryCount()+1) @@ -84,7 +88,11 @@ func AcquireSocket( "next_delay", delay) } - time.Sleep(delay) + select { + case <-time.After(delay): + case <-ctx.Done(): + return nil, nil, ctx.Err() + } retryMgr.RecordRetry() } } diff --git a/transport/socket_test.go b/transport/socket_test.go index 21ba921..4e1918e 100644 --- a/transport/socket_test.go +++ b/transport/socket_test.go @@ -1,6 +1,7 @@ package transport import ( + "context" "errors" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/types" @@ -63,7 +64,8 @@ func TestAcquireSocket(t *testing.T) { t.Run(tc.name, func(t *testing.T) { attemptIndex := 0 mockDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header, + ) (types.Socket, *http.Response, error) { err := tc.mockRuns[attemptIndex] attemptIndex++ if err != nil { @@ -80,7 +82,8 @@ func TestAcquireSocket(t *testing.T) { JitterFactor: 0.0, }) - socket, _, err := AcquireSocket(retryMgr, mockDialer, "ws://test", nil) + socket, _, err := AcquireSocket( + context.Background(), retryMgr, mockDialer, "ws://test", nil) assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount()) if tc.wantErr { @@ -96,7 +99,8 @@ func TestAcquireSocket(t *testing.T) { func TestAcquireSocketGuards(t *testing.T) { validDialer := &honeybeetest.MockDialer{ - DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + DialContextFunc: func(context.Context, string, http.Header, + ) (types.Socket, *http.Response, error) { return honeybeetest.NewMockSocket(), nil, nil }, } @@ -134,7 +138,8 @@ func TestAcquireSocketGuards(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - socket, resp, err := AcquireSocket(tc.retryMgr, tc.dialer, tc.url, nil) + socket, resp, err := AcquireSocket( + context.Background(), tc.retryMgr, tc.dialer, tc.url, nil) assert.Error(t, err) assert.ErrorContains(t, err, tc.wantErr) diff --git a/types/types.go b/types/types.go index f1e6fc4..1dfa361 100644 --- a/types/types.go +++ b/types/types.go @@ -1,12 +1,16 @@ package types import ( + "context" "net/http" "time" ) type Dialer interface { - Dial(urlStr string, requestHeader http.Header) (Socket, *http.Response, error) + DialContext(ctx context.Context, + url string, + header http.Header, + ) (Socket, *http.Response, error) } type Socket interface {