diff --git a/initiatorpool/worker_test.go b/initiatorpool/worker_test.go index ad31197..8e31119 100644 --- a/initiatorpool/worker_test.go +++ b/initiatorpool/worker_test.go @@ -13,8 +13,6 @@ import ( "time" ) -// Forwarder - func TestRunForwarder(t *testing.T) { t.Run("message passes through to inbox", func(t *testing.T) { messages := make(chan receivedMessage, 1) @@ -351,4 +349,49 @@ func TestRunDialer(t *testing.T) { } }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) + + t.Run("context cancelled during in-progress dial exits without delivering connection", func(t *testing.T) { + w := &Worker{id: "wss://test"} + dial := make(chan struct{}, 1) + newConn := make(chan *transport.Connection, 1) + ctx, cancel := context.WithCancel(context.Background()) + + wctx := WorkerContext{ + Errors: make(chan error, 1), + ConnectionConfig: &transport.ConnectionConfig{Retry: nil}, + Dialer: &honeybeetest.MockDialer{ + DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { + // block until context is cancelled + select { + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + }, + }, + } + + done := make(chan struct{}) + go func() { + w.runDialer(ctx, wctx, dial, newConn) + close(done) + }() + + dial <- struct{}{} + + // wait for dialer to block + time.Sleep(20 * time.Millisecond) + cancel() + + assert.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + + // no connection was sent + assert.Empty(t, newConn) + }) } diff --git a/transport/connection_test.go b/transport/connection_test.go index e3c06e5..d02f4b0 100644 --- a/transport/connection_test.go +++ b/transport/connection_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "io" "net/http" + "sync/atomic" "testing" "time" ) @@ -405,6 +406,52 @@ func TestConnect(t *testing.T) { }) } +func TestConnectContextCancellation(t *testing.T) { + t.Run("context cancelled during connect returns before retries exhaust", func(t *testing.T) { + config := &ConnectionConfig{ + Retry: &RetryConfig{ + MaxRetries: 100, + InitialDelay: 500 * time.Millisecond, + MaxDelay: 1 * time.Second, + JitterFactor: 0.0, + }, + } + conn, err := NewConnection("ws://test", config, nil) + assert.NoError(t, err) + + dialCount := atomic.Int32{} + ctx, cancel := context.WithCancel(context.Background()) + + conn.dialer = &honeybeetest.MockDialer{ + DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { + dialCount.Add(1) + return nil, nil, fmt.Errorf("dial failed") + }, + } + + done := make(chan error, 1) + go func() { + done <- conn.Connect(ctx) + }() + + // wait for first dial + assert.Eventually(t, func() bool { + return dialCount.Load() >= 1 + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + cancel() + + select { + case err := <-done: + assert.ErrorIs(t, err, context.Canceled) + + // number of dials is fewer than max retry count + assert.Less(t, dialCount.Load(), int32(100)) + case <-time.After(honeybeetest.TestTimeout): + t.Fatal("Connect did not return after context cancellation") + } + }) +} + // Connection method tests func TestConnectionIncoming(t *testing.T) { diff --git a/transport/socket.go b/transport/socket.go index ccea07a..e13debf 100644 --- a/transport/socket.go +++ b/transport/socket.go @@ -47,6 +47,12 @@ func AcquireSocket( url string, logger *slog.Logger, ) (types.Socket, *http.Response, error) { + select { + case <-ctx.Done(): + return nil, nil, ctx.Err() + default: + } + if retryMgr == nil { return nil, nil, NewConnectionError("retry manager cannot be nil") } diff --git a/transport/socket_test.go b/transport/socket_test.go index 4e1918e..ffa2553 100644 --- a/transport/socket_test.go +++ b/transport/socket_test.go @@ -3,10 +3,12 @@ package transport import ( "context" "errors" + "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/types" "github.com/stretchr/testify/assert" "net/http" + "sync/atomic" "testing" "time" ) @@ -62,12 +64,12 @@ func TestAcquireSocket(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - attemptIndex := 0 + attemptIndex := atomic.Int32{} mockDialer := &honeybeetest.MockDialer{ DialContextFunc: func(context.Context, string, http.Header, ) (types.Socket, *http.Response, error) { - err := tc.mockRuns[attemptIndex] - attemptIndex++ + err := tc.mockRuns[attemptIndex.Load()] + attemptIndex.Add(1) if err != nil { return nil, nil, err } @@ -148,3 +150,103 @@ func TestAcquireSocketGuards(t *testing.T) { }) } } + +func TestAcquireSocketContextCancellation(t *testing.T) { + t.Run("already-canceled context returns immediately without dialing", + func(t *testing.T) { + dialCalled := atomic.Bool{} + mockDialer := &honeybeetest.MockDialer{ + DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { + dialCalled.Store(true) + return honeybeetest.NewMockSocket(), nil, nil + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + + // cancel before acquiring socket + cancel() + + retryMgr := NewRetryManager(GetDefaultRetryConfig()) + _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil) + + assert.ErrorIs(t, err, context.Canceled) + assert.False(t, dialCalled.Load()) + }) + + t.Run("context cancelled during sleep returns before next attempt", + func(t *testing.T) { + dialCount := atomic.Int32{} + mockDialer := &honeybeetest.MockDialer{ + DialContextFunc: func(_ context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { + dialCount.Add(1) + return nil, nil, fmt.Errorf("dial failed") + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + + retryMgr := NewRetryManager(&RetryConfig{ + MaxRetries: 10, + InitialDelay: 1 * time.Second, + MaxDelay: 1 * time.Second, + JitterFactor: 0.0, + }) + + done := make(chan error, 1) + go func() { + _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil) + done <- err + }() + + // wait for first dial to complete, then cancel during sleep + assert.Eventually(t, func() bool { + return dialCount.Load() >= 1 + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + cancel() + + select { + case err := <-done: + assert.ErrorIs(t, err, context.Canceled) + + // dial count is 2 because the first retry is always immediate + assert.Equal(t, int32(2), dialCount.Load()) + case <-time.After(honeybeetest.TestTimeout): + t.Fatal("AcquireSocket did not return after context cancellation") + } + }) + + t.Run("context cancelled during in-progress dial unblocks and returns", + func(t *testing.T) { + mockDialer := &honeybeetest.MockDialer{ + DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) { + // block until context is cancelled + select { + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + + retryMgr := NewRetryManager(GetDefaultRetryConfig()) + done := make(chan error, 1) + go func() { + _, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil) + done <- err + }() + + // wait for dialer to block + time.Sleep(20 * time.Millisecond) + cancel() + + select { + case err := <-done: + assert.ErrorIs(t, err, context.Canceled) + case <-time.After(honeybeetest.TestTimeout): + t.Fatal("AcquireSocket did not return after context cancellation") + } + + }) +}