253 lines
6.4 KiB
Go
253 lines
6.4 KiB
Go
package transport
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
|
"git.wisehodl.dev/jay/go-honeybee/types"
|
|
"github.com/stretchr/testify/assert"
|
|
"net/http"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// TestNewDialer checks that the package-level constructor returns a non-nil
// dialer whose concrete type is *GorillaDialer.
func TestNewDialer(t *testing.T) {
	got := NewDialer()
	assert.NotNil(t, got)

	_, isGorilla := got.(*GorillaDialer)
	assert.True(t, isGorilla, "NewDialer should return *GorillaDialer")
}
|
|
|
|
func TestNewGorillaDialer(t *testing.T) {
|
|
dialer := NewGorillaDialer()
|
|
|
|
assert.NotNil(t, dialer)
|
|
assert.NotNil(t, dialer.Dialer)
|
|
assert.Equal(t, 45*time.Second, dialer.Dialer.HandshakeTimeout)
|
|
assert.Equal(t, 1024, dialer.Dialer.ReadBufferSize)
|
|
assert.Equal(t, 1024, dialer.Dialer.WriteBufferSize)
|
|
}
|
|
|
|
func TestAcquireSocket(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
mockRuns []error
|
|
maxRetries int
|
|
wantRetryCount int
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "immediate success",
|
|
mockRuns: []error{nil},
|
|
maxRetries: 3,
|
|
wantRetryCount: 0,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "two failures, success",
|
|
mockRuns: []error{errors.New("1"), errors.New("2"), nil},
|
|
maxRetries: 0,
|
|
wantRetryCount: 2,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "three failures, failure",
|
|
mockRuns: []error{errors.New("1"), errors.New("2"), errors.New("3"), errors.New("4")},
|
|
maxRetries: 3,
|
|
wantRetryCount: 3,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
attemptIndex := atomic.Int32{}
|
|
mockDialer := &honeybeetest.MockDialer{
|
|
DialContextFunc: func(context.Context, string, http.Header,
|
|
) (types.Socket, *http.Response, error) {
|
|
err := tc.mockRuns[attemptIndex.Load()]
|
|
attemptIndex.Add(1)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return honeybeetest.NewMockSocket(), nil, nil
|
|
},
|
|
}
|
|
|
|
retryMgr := NewRetryManager(&RetryConfig{
|
|
MaxRetries: tc.maxRetries,
|
|
InitialDelay: 1 * time.Millisecond,
|
|
MaxDelay: 5 * time.Millisecond,
|
|
JitterFactor: 0.0,
|
|
})
|
|
|
|
socket, _, err := AcquireSocket(
|
|
context.Background(), retryMgr, mockDialer, "ws://test", nil)
|
|
|
|
assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
|
|
if tc.wantErr {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, socket)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, socket)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAcquireSocketGuards verifies that AcquireSocket rejects invalid
// arguments (nil retry manager, nil dialer, empty URL). In each case it must
// return an error containing the expected message, no socket, and no HTTP
// response.
func TestAcquireSocketGuards(t *testing.T) {
	okDialer := &honeybeetest.MockDialer{
		DialContextFunc: func(context.Context, string, http.Header,
		) (types.Socket, *http.Response, error) {
			return honeybeetest.NewMockSocket(), nil, nil
		},
	}
	okRetryMgr := NewRetryManager(GetDefaultRetryConfig())

	tests := []struct {
		name     string
		retryMgr *RetryManager
		dialer   types.Dialer
		url      string
		wantErr  string
	}{
		{name: "nil retry manager", retryMgr: nil, dialer: okDialer, url: "ws://test", wantErr: "retry manager cannot be nil"},
		{name: "nil dialer", retryMgr: okRetryMgr, dialer: nil, url: "ws://test", wantErr: "dialer cannot be nil"},
		{name: "empty URL", retryMgr: okRetryMgr, dialer: okDialer, url: "", wantErr: "URL cannot be empty"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()
			sock, resp, err := AcquireSocket(ctx, tt.retryMgr, tt.dialer, tt.url, nil)

			assert.Error(t, err)
			assert.ErrorContains(t, err, tt.wantErr)
			assert.Nil(t, sock)
			assert.Nil(t, resp)
		})
	}
}
|
|
|
|
// TestAcquireSocketContextCancellation checks that AcquireSocket honours
// context cancellation at three points: before the first dial, while sleeping
// between retries, and while a dial is blocked in progress.
func TestAcquireSocketContextCancellation(t *testing.T) {
	t.Run("already-canceled context returns immediately without dialing",
		func(t *testing.T) {
			dialCalled := atomic.Bool{}
			mockDialer := &honeybeetest.MockDialer{
				DialContextFunc: func(_ context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
					dialCalled.Store(true)
					return honeybeetest.NewMockSocket(), nil, nil
				},
			}

			ctx, cancel := context.WithCancel(context.Background())

			// cancel before acquiring socket
			cancel()

			retryMgr := NewRetryManager(GetDefaultRetryConfig())
			_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)

			assert.ErrorIs(t, err, context.Canceled)
			assert.False(t, dialCalled.Load())
		})

	t.Run("context cancelled during sleep returns before next attempt",
		func(t *testing.T) {
			dialCount := atomic.Int32{}
			mockDialer := &honeybeetest.MockDialer{
				DialContextFunc: func(_ context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
					dialCount.Add(1)
					return nil, nil, fmt.Errorf("dial failed")
				},
			}

			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			// A 1s backoff leaves a wide window to cancel while AcquireSocket
			// is sleeping before its third attempt.
			retryMgr := NewRetryManager(&RetryConfig{
				MaxRetries:   10,
				InitialDelay: 1 * time.Second,
				MaxDelay:     1 * time.Second,
				JitterFactor: 0.0,
			})

			done := make(chan error, 1)
			go func() {
				_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
				done <- err
			}()

			// Wait until the second dial attempt has been made, then cancel
			// during the backoff sleep that follows it.
			assert.Eventually(t, func() bool {
				return dialCount.Load() > 1
			}, honeybeetest.TestTimeout, honeybeetest.TestTick)
			cancel()

			select {
			case err := <-done:
				assert.ErrorIs(t, err, context.Canceled)

				// dial count is 2 because the first retry is always immediate
				assert.Equal(t, int32(2), dialCount.Load())
			case <-time.After(honeybeetest.TestTimeout):
				t.Fatal("AcquireSocket did not return after context cancellation")
			}
		})

	t.Run("context cancelled during in-progress dial unblocks and returns",
		func(t *testing.T) {
			// dialStarted is signalled when the dialer is entered, so the test
			// cancels only once a dial is actually in flight. The buffered,
			// non-blocking send tolerates unexpected extra dial calls.
			dialStarted := make(chan struct{}, 1)
			mockDialer := &honeybeetest.MockDialer{
				DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
					select {
					case dialStarted <- struct{}{}:
					default:
					}

					// block until context is cancelled
					<-ctx.Done()
					return nil, nil, ctx.Err()
				},
			}

			ctx, cancel := context.WithCancel(context.Background())
			// Release the background AcquireSocket even if the test fails
			// before reaching the explicit cancel below.
			defer cancel()

			retryMgr := NewRetryManager(GetDefaultRetryConfig())
			done := make(chan error, 1)
			go func() {
				_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
				done <- err
			}()

			// Wait for the dialer to block rather than guessing with a sleep.
			select {
			case <-dialStarted:
			case <-time.After(honeybeetest.TestTimeout):
				t.Fatal("dialer was never called")
			}
			cancel()

			select {
			case err := <-done:
				assert.ErrorIs(t, err, context.Canceled)
			case <-time.After(honeybeetest.TestTimeout):
				t.Fatal("AcquireSocket did not return after context cancellation")
			}
		})
}
|