Added context cancellation tests.
This commit is contained in:
@@ -13,8 +13,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Forwarder
|
|
||||||
|
|
||||||
func TestRunForwarder(t *testing.T) {
|
func TestRunForwarder(t *testing.T) {
|
||||||
t.Run("message passes through to inbox", func(t *testing.T) {
|
t.Run("message passes through to inbox", func(t *testing.T) {
|
||||||
messages := make(chan receivedMessage, 1)
|
messages := make(chan receivedMessage, 1)
|
||||||
@@ -351,4 +349,49 @@ func TestRunDialer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("context cancelled during in-progress dial exits without delivering connection", func(t *testing.T) {
|
||||||
|
w := &Worker{id: "wss://test"}
|
||||||
|
dial := make(chan struct{}, 1)
|
||||||
|
newConn := make(chan *transport.Connection, 1)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
wctx := WorkerContext{
|
||||||
|
Errors: make(chan error, 1),
|
||||||
|
ConnectionConfig: &transport.ConnectionConfig{Retry: nil},
|
||||||
|
Dialer: &honeybeetest.MockDialer{
|
||||||
|
DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
|
||||||
|
// block until context is cancelled
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, nil, ctx.Err()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.runDialer(ctx, wctx, dial, newConn)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
dial <- struct{}{}
|
||||||
|
|
||||||
|
// wait for dialer to block
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
|
|
||||||
|
// no connection was sent
|
||||||
|
assert.Empty(t, newConn)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -405,6 +406,52 @@ func TestConnect(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectContextCancellation(t *testing.T) {
|
||||||
|
t.Run("context cancelled during connect returns before retries exhaust", func(t *testing.T) {
|
||||||
|
config := &ConnectionConfig{
|
||||||
|
Retry: &RetryConfig{
|
||||||
|
MaxRetries: 100,
|
||||||
|
InitialDelay: 500 * time.Millisecond,
|
||||||
|
MaxDelay: 1 * time.Second,
|
||||||
|
JitterFactor: 0.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := NewConnection("ws://test", config, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
dialCount := atomic.Int32{}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
conn.dialer = &honeybeetest.MockDialer{
|
||||||
|
DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
|
||||||
|
dialCount.Add(1)
|
||||||
|
return nil, nil, fmt.Errorf("dial failed")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- conn.Connect(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for first dial
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
return dialCount.Load() >= 1
|
||||||
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
assert.ErrorIs(t, err, context.Canceled)
|
||||||
|
|
||||||
|
// number of dials is fewer than max retry count
|
||||||
|
assert.Less(t, dialCount.Load(), int32(100))
|
||||||
|
case <-time.After(honeybeetest.TestTimeout):
|
||||||
|
t.Fatal("Connect did not return after context cancellation")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Connection method tests
|
// Connection method tests
|
||||||
|
|
||||||
func TestConnectionIncoming(t *testing.T) {
|
func TestConnectionIncoming(t *testing.T) {
|
||||||
|
|||||||
@@ -47,6 +47,12 @@ func AcquireSocket(
|
|||||||
url string,
|
url string,
|
||||||
logger *slog.Logger,
|
logger *slog.Logger,
|
||||||
) (types.Socket, *http.Response, error) {
|
) (types.Socket, *http.Response, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, nil, ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
if retryMgr == nil {
|
if retryMgr == nil {
|
||||||
return nil, nil, NewConnectionError("retry manager cannot be nil")
|
return nil, nil, NewConnectionError("retry manager cannot be nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ package transport
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -62,12 +64,12 @@ func TestAcquireSocket(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
attemptIndex := 0
|
attemptIndex := atomic.Int32{}
|
||||||
mockDialer := &honeybeetest.MockDialer{
|
mockDialer := &honeybeetest.MockDialer{
|
||||||
DialContextFunc: func(context.Context, string, http.Header,
|
DialContextFunc: func(context.Context, string, http.Header,
|
||||||
) (types.Socket, *http.Response, error) {
|
) (types.Socket, *http.Response, error) {
|
||||||
err := tc.mockRuns[attemptIndex]
|
err := tc.mockRuns[attemptIndex.Load()]
|
||||||
attemptIndex++
|
attemptIndex.Add(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -148,3 +150,103 @@ func TestAcquireSocketGuards(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAcquireSocketContextCancellation(t *testing.T) {
|
||||||
|
t.Run("already-canceled context returns immediately without dialing",
|
||||||
|
func(t *testing.T) {
|
||||||
|
dialCalled := atomic.Bool{}
|
||||||
|
mockDialer := &honeybeetest.MockDialer{
|
||||||
|
DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
|
||||||
|
dialCalled.Store(true)
|
||||||
|
return honeybeetest.NewMockSocket(), nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
// cancel before acquiring socket
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
retryMgr := NewRetryManager(GetDefaultRetryConfig())
|
||||||
|
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, context.Canceled)
|
||||||
|
assert.False(t, dialCalled.Load())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context cancelled during sleep returns before next attempt",
|
||||||
|
func(t *testing.T) {
|
||||||
|
dialCount := atomic.Int32{}
|
||||||
|
mockDialer := &honeybeetest.MockDialer{
|
||||||
|
DialContextFunc: func(_ context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
|
||||||
|
dialCount.Add(1)
|
||||||
|
return nil, nil, fmt.Errorf("dial failed")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
retryMgr := NewRetryManager(&RetryConfig{
|
||||||
|
MaxRetries: 10,
|
||||||
|
InitialDelay: 1 * time.Second,
|
||||||
|
MaxDelay: 1 * time.Second,
|
||||||
|
JitterFactor: 0.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for first dial to complete, then cancel during sleep
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
return dialCount.Load() >= 1
|
||||||
|
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
assert.ErrorIs(t, err, context.Canceled)
|
||||||
|
|
||||||
|
// dial count is 2 because the first retry is always immediate
|
||||||
|
assert.Equal(t, int32(2), dialCount.Load())
|
||||||
|
case <-time.After(honeybeetest.TestTimeout):
|
||||||
|
t.Fatal("AcquireSocket did not return after context cancellation")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context cancelled during in-progress dial unblocks and returns",
|
||||||
|
func(t *testing.T) {
|
||||||
|
mockDialer := &honeybeetest.MockDialer{
|
||||||
|
DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
|
||||||
|
// block until context is cancelled
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, nil, ctx.Err()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
retryMgr := NewRetryManager(GetDefaultRetryConfig())
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for dialer to block
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
assert.ErrorIs(t, err, context.Canceled)
|
||||||
|
case <-time.After(honeybeetest.TestTimeout):
|
||||||
|
t.Fatal("AcquireSocket did not return after context cancellation")
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user