package transport

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"sync/atomic"
	"testing"
	"time"

	"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
	"git.wisehodl.dev/jay/go-honeybee/types"
	"github.com/stretchr/testify/assert"
)

// TestNewDialer checks that NewDialer returns a non-nil dialer whose
// concrete type is *GorillaDialer.
func TestNewDialer(t *testing.T) {
	dialer := NewDialer()
	assert.NotNil(t, dialer)

	_, ok := dialer.(*GorillaDialer)
	assert.True(t, ok, "NewDialer should return *GorillaDialer")
}

// TestNewGorillaDialer checks the handshake timeout and buffer sizes that
// NewGorillaDialer configures on the wrapped Dialer.
func TestNewGorillaDialer(t *testing.T) {
	dialer := NewGorillaDialer()
	assert.NotNil(t, dialer)
	assert.NotNil(t, dialer.Dialer)

	assert.Equal(t, 45*time.Second, dialer.Dialer.HandshakeTimeout)
	assert.Equal(t, 1024, dialer.Dialer.ReadBufferSize)
	assert.Equal(t, 1024, dialer.Dialer.WriteBufferSize)
}

// TestAcquireSocket drives AcquireSocket against a mock dialer that returns a
// scripted sequence of results, then checks the retry count recorded by the
// retry manager and whether a socket or an error came back.
func TestAcquireSocket(t *testing.T) {
	cases := []struct {
		name string
		// mockRuns is the error returned by each successive dial attempt;
		// a nil entry makes that attempt succeed.
		mockRuns       []error
		maxRetries     int
		wantRetryCount int
		wantErr        bool
	}{
		{
			name:           "immediate success",
			mockRuns:       []error{nil},
			maxRetries:     3,
			wantRetryCount: 0,
			wantErr:        false,
		},
		{
			name:     "two failures, success",
			mockRuns: []error{errors.New("1"), errors.New("2"), nil},
			// NOTE(review): two retries are expected with MaxRetries == 0, so
			// zero presumably means "no limit" -- confirm against RetryConfig.
			maxRetries:     0,
			wantRetryCount: 2,
			wantErr:        false,
		},
		{
			name: "three failures, failure",
			mockRuns: []error{
				errors.New("1"), errors.New("2"), errors.New("3"), errors.New("4"),
			},
			maxRetries:     3,
			wantRetryCount: 3,
			wantErr:        true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var attempts atomic.Int32
			mockDialer := &honeybeetest.MockDialer{
				DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
					// Claim this attempt's slot in a single atomic step
					// instead of a separate load followed by an add.
					idx := attempts.Add(1) - 1
					if err := tc.mockRuns[idx]; err != nil {
						return nil, nil, err
					}
					return honeybeetest.NewMockSocket(), nil, nil
				},
			}

			retryMgr := NewRetryManager(&RetryConfig{
				MaxRetries:   tc.maxRetries,
				InitialDelay: 1 * time.Millisecond,
				MaxDelay:     5 * time.Millisecond,
				JitterFactor: 0.0,
			})

			socket, _, err := AcquireSocket(
				context.Background(), retryMgr, mockDialer, "ws://test", nil)

			assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
			if tc.wantErr {
				assert.Error(t, err)
				assert.Nil(t, socket)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, socket)
			}
		})
	}
}

// TestAcquireSocketGuards checks that AcquireSocket rejects a nil retry
// manager, a nil dialer, and an empty URL with a descriptive error, and
// returns neither a socket nor a response in those cases.
func TestAcquireSocketGuards(t *testing.T) {
	validDialer := &honeybeetest.MockDialer{
		DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
			return honeybeetest.NewMockSocket(), nil, nil
		},
	}
	validRetryMgr := NewRetryManager(GetDefaultRetryConfig())

	cases := []struct {
		name     string
		retryMgr *RetryManager
		dialer   types.Dialer
		url      string
		wantErr  string
	}{
		{
			name:     "nil retry manager",
			retryMgr: nil,
			dialer:   validDialer,
			url:      "ws://test",
			wantErr:  "retry manager cannot be nil",
		},
		{
			name:     "nil dialer",
			retryMgr: validRetryMgr,
			dialer:   nil,
			url:      "ws://test",
			wantErr:  "dialer cannot be nil",
		},
		{
			name:     "empty URL",
			retryMgr: validRetryMgr,
			dialer:   validDialer,
			url:      "",
			wantErr:  "URL cannot be empty",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			socket, resp, err := AcquireSocket(
				context.Background(), tc.retryMgr, tc.dialer, tc.url, nil)

			// ErrorContains already fails on a nil error, so a separate
			// assert.Error would be redundant.
			assert.ErrorContains(t, err, tc.wantErr)
			assert.Nil(t, socket)
			assert.Nil(t, resp)
		})
	}
}

// TestAcquireSocketContextCancellation checks how AcquireSocket reacts to
// context cancellation at three points: before the call, while waiting
// between retries, and while a dial is in progress.
func TestAcquireSocketContextCancellation(t *testing.T) {
	t.Run("already-canceled context returns immediately without dialing", func(t *testing.T) {
		var dialCalled atomic.Bool
		mockDialer := &honeybeetest.MockDialer{
			DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
				dialCalled.Store(true)
				return honeybeetest.NewMockSocket(), nil, nil
			},
		}

		ctx, cancel := context.WithCancel(context.Background())
		// cancel before acquiring socket
		cancel()

		retryMgr := NewRetryManager(GetDefaultRetryConfig())
		_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)

		assert.ErrorIs(t, err, context.Canceled)
		assert.False(t, dialCalled.Load())
	})

	t.Run("context cancelled during sleep returns before next attempt", func(t *testing.T) {
		var dialCount atomic.Int32
		mockDialer := &honeybeetest.MockDialer{
			DialContextFunc: func(_ context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
				dialCount.Add(1)
				return nil, nil, fmt.Errorf("dial failed")
			},
		}

		ctx, cancel := context.WithCancel(context.Background())
		// Release the context (and the background goroutine) even if the
		// test exits early; calling cancel more than once is harmless.
		defer cancel()

		retryMgr := NewRetryManager(&RetryConfig{
			MaxRetries:   10,
			InitialDelay: 1 * time.Second,
			MaxDelay:     1 * time.Second,
			JitterFactor: 0.0,
		})

		// Buffered so the goroutine can always deliver its result and exit.
		done := make(chan error, 1)
		go func() {
			_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
			done <- err
		}()

		// wait for first two dials to complete, then cancel during sleep
		assert.Eventually(t, func() bool {
			return dialCount.Load() > 1
		}, honeybeetest.TestTimeout, honeybeetest.TestTick)
		cancel()

		select {
		case err := <-done:
			assert.ErrorIs(t, err, context.Canceled)
			// dial count is 2 because the first retry is always immediate
			assert.Equal(t, int32(2), dialCount.Load())
		case <-time.After(honeybeetest.TestTimeout):
			t.Fatal("AcquireSocket did not return after context cancellation")
		}
	})

	t.Run("context cancelled during in-progress dial unblocks and returns", func(t *testing.T) {
		// dialStarted receives a value once the mock dialer has been entered,
		// so the test can cancel while the dial is in flight without relying
		// on a fixed sleep.
		dialStarted := make(chan struct{}, 1)
		mockDialer := &honeybeetest.MockDialer{
			DialContextFunc: func(ctx context.Context, _ string, _ http.Header) (types.Socket, *http.Response, error) {
				// Non-blocking send: only the first entry needs to signal.
				select {
				case dialStarted <- struct{}{}:
				default:
				}
				// block until context is cancelled
				<-ctx.Done()
				return nil, nil, ctx.Err()
			},
		}

		ctx, cancel := context.WithCancel(context.Background())
		// Ensures the blocked dialer is released if the test fails before
		// reaching the explicit cancel below.
		defer cancel()

		retryMgr := NewRetryManager(GetDefaultRetryConfig())

		done := make(chan error, 1)
		go func() {
			_, _, err := AcquireSocket(ctx, retryMgr, mockDialer, "ws://test", nil)
			done <- err
		}()

		// wait for dialer to block
		select {
		case <-dialStarted:
		case <-time.After(honeybeetest.TestTimeout):
			t.Fatal("dialer was never invoked")
		}
		cancel()

		select {
		case err := <-done:
			assert.ErrorIs(t, err, context.Canceled)
		case <-time.After(honeybeetest.TestTimeout):
			t.Fatal("AcquireSocket did not return after context cancellation")
		}
	})
}