Added context cancellation tests.

This commit is contained in:
Jay
2026-04-18 18:01:22 -04:00
parent b4c5c897e8
commit 8d79a002f8
4 changed files with 203 additions and 5 deletions

View File

@@ -13,8 +13,6 @@ import (
"time" "time"
) )
// Forwarder
func TestRunForwarder(t *testing.T) { func TestRunForwarder(t *testing.T) {
t.Run("message passes through to inbox", func(t *testing.T) { t.Run("message passes through to inbox", func(t *testing.T) {
messages := make(chan receivedMessage, 1) messages := make(chan receivedMessage, 1)
@@ -351,4 +349,49 @@ func TestRunDialer(t *testing.T) {
} }
}, honeybeetest.TestTimeout, honeybeetest.TestTick) }, honeybeetest.TestTimeout, honeybeetest.TestTick)
}) })
t.Run("context cancelled during in-progress dial exits without delivering connection", func(t *testing.T) {
w := &Worker{id: "wss://test"}
dial := make(chan struct{}, 1)
newConn := make(chan *transport.Connection, 1)
ctx, cancel := context.WithCancel(context.Background())
wctx := WorkerContext{
Errors: make(chan error, 1),
ConnectionConfig: &transport.ConnectionConfig{Retry: nil},
Dialer: &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
// block until context is cancelled
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
}
},
},
}
done := make(chan struct{})
go func() {
w.runDialer(ctx, wctx, dial, newConn)
close(done)
}()
dial <- struct{}{}
// wait for dialer to block
time.Sleep(20 * time.Millisecond)
cancel()
assert.Eventually(t, func() bool {
select {
case <-done:
return true
default:
return false
}
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
// no connection was sent
assert.Empty(t, newConn)
})
} }

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io" "io"
"net/http" "net/http"
"sync/atomic"
"testing" "testing"
"time" "time"
) )
@@ -405,6 +406,52 @@ func TestConnect(t *testing.T) {
}) })
} }
func TestConnectContextCancellation(t *testing.T) {
t.Run("context cancelled during connect returns before retries exhaust", func(t *testing.T) {
config := &ConnectionConfig{
Retry: &RetryConfig{
MaxRetries: 100,
InitialDelay: 500 * time.Millisecond,
MaxDelay: 1 * time.Second,
JitterFactor: 0.0,
},
}
conn, err := NewConnection("ws://test", config, nil)
assert.NoError(t, err)
dialCount := atomic.Int32{}
ctx, cancel := context.WithCancel(context.Background())
conn.dialer = &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
dialCount.Add(1)
return nil, nil, fmt.Errorf("dial failed")
},
}
done := make(chan error, 1)
go func() {
done <- conn.Connect(ctx)
}()
// wait for first dial
assert.Eventually(t, func() bool {
return dialCount.Load() >= 1
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
cancel()
select {
case err := <-done:
assert.ErrorIs(t, err, context.Canceled)
// number of dials is fewer than max retry count
assert.Less(t, dialCount.Load(), int32(100))
case <-time.After(honeybeetest.TestTimeout):
t.Fatal("Connect did not return after context cancellation")
}
})
}
// Connection method tests // Connection method tests
func TestConnectionIncoming(t *testing.T) { func TestConnectionIncoming(t *testing.T) {

View File

@@ -47,6 +47,12 @@ func AcquireSocket(
url string, url string,
logger *slog.Logger, logger *slog.Logger,
) (types.Socket, *http.Response, error) { ) (types.Socket, *http.Response, error) {
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
default:
}
if retryMgr == nil { if retryMgr == nil {
return nil, nil, NewConnectionError("retry manager cannot be nil") return nil, nil, NewConnectionError("retry manager cannot be nil")
} }

View File

@@ -3,10 +3,12 @@ package transport
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"git.wisehodl.dev/jay/go-honeybee/types" "git.wisehodl.dev/jay/go-honeybee/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"net/http" "net/http"
"sync/atomic"
"testing" "testing"
"time" "time"
) )
@@ -62,12 +64,12 @@ func TestAcquireSocket(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
attemptIndex := 0 attemptIndex := atomic.Int32{}
mockDialer := &honeybeetest.MockDialer{ mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header, DialContextFunc: func(context.Context, string, http.Header,
) (types.Socket, *http.Response, error) { ) (types.Socket, *http.Response, error) {
err := tc.mockRuns[attemptIndex] err := tc.mockRuns[attemptIndex.Load()]
attemptIndex++ attemptIndex.Add(1)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -148,3 +150,103 @@ func TestAcquireSocketGuards(t *testing.T) {
}) })
} }
} }
func TestAcquireSocketContextCancellation(t *testing.T) {
t.Run("already-canceled context returns immediately without dialing",
func(t *testing.T) {
dialCalled := atomic.Bool{}
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
dialCalled.Store(true)
return honeybeetest.NewMockSocket(), nil, nil
},
}
ctx, cancel := context.WithCancel(context.Background())
// cancel before acquiring socket
cancel()
retryMgr := NewRetryManager(GetDefaultRetryConfig())
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
assert.ErrorIs(t, err, context.Canceled)
assert.False(t, dialCalled.Load())
})
t.Run("context cancelled during sleep returns before next attempt",
func(t *testing.T) {
dialCount := atomic.Int32{}
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(_ context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
dialCount.Add(1)
return nil, nil, fmt.Errorf("dial failed")
},
}
ctx, cancel := context.WithCancel(context.Background())
retryMgr := NewRetryManager(&RetryConfig{
MaxRetries: 10,
InitialDelay: 1 * time.Second,
MaxDelay: 1 * time.Second,
JitterFactor: 0.0,
})
done := make(chan error, 1)
go func() {
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
done <- err
}()
// wait for first dial to complete, then cancel during sleep
assert.Eventually(t, func() bool {
return dialCount.Load() >= 1
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
cancel()
select {
case err := <-done:
assert.ErrorIs(t, err, context.Canceled)
// dial count is 2 because the first retry is always immediate
assert.Equal(t, int32(2), dialCount.Load())
case <-time.After(honeybeetest.TestTimeout):
t.Fatal("AcquireSocket did not return after context cancellation")
}
})
t.Run("context cancelled during in-progress dial unblocks and returns",
func(t *testing.T) {
mockDialer := &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
// block until context is cancelled
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
}
},
}
ctx, cancel := context.WithCancel(context.Background())
retryMgr := NewRetryManager(GetDefaultRetryConfig())
done := make(chan error, 1)
go func() {
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
done <- err
}()
// wait for dialer to block
time.Sleep(20 * time.Millisecond)
cancel()
select {
case err := <-done:
assert.ErrorIs(t, err, context.Canceled)
case <-time.After(honeybeetest.TestTimeout):
t.Fatal("AcquireSocket did not return after context cancellation")
}
})
}