// Source: go-honeybee/ws/connection_test.go (439 lines, 10 KiB, Go)

package ws
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
"time"
)
// Connection state tests
// TestConnectionStateString checks the label returned by ConnectionState.String
// for every declared state, plus an out-of-range value.
func TestConnectionStateString(t *testing.T) {
	tests := []struct {
		in       ConnectionState
		expected string
	}{
		{in: StateDisconnected, expected: "disconnected"},
		{in: StateConnecting, expected: "connecting"},
		{in: StateConnected, expected: "connected"},
		{in: StateClosed, expected: "closed"},
		// A value outside the declared constants should report "unknown".
		{in: ConnectionState(99), expected: "unknown"},
	}
	for i := range tests {
		tt := tests[i]
		t.Run(tt.expected, func(t *testing.T) {
			got := tt.in.String()
			assert.Equal(t, tt.expected, got)
		})
	}
}
// TestConnectionState verifies the state reported by State() across the
// connection lifecycle:
//   - a URL-based connection starts Disconnected;
//   - a socket-based connection starts Connected;
//   - Close moves a connection to Closed.
func TestConnectionState(t *testing.T) {
	// Initial state of a URL-based connection. Bail out on a constructor
	// error instead of calling State() on a nil connection.
	conn, err := NewConnection("ws://test", nil)
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, StateDisconnected, conn.State())

	// A connection wrapping an existing socket starts out Connected.
	conn2, err := NewConnectionFromSocket(NewMockSocket(), nil)
	if !assert.NoError(t, err) {
		return
	}
	// Close conn2 when the test ends so it does not outlive the test.
	defer conn2.Close()
	assert.Equal(t, StateConnected, conn2.State())

	// Close transitions the connection to Closed.
	conn.Close()
	assert.Equal(t, StateClosed, conn.State())
}
// Connection constructor tests
// TestNewConnection checks the behavior of NewConnection.
//
// On invalid input (a non-ws/wss URL, or a config whose retry InitialDelay
// exceeds MaxDelay) it must return an error and a nil connection. Otherwise
// it must return a Disconnected, not-closed connection with its URL, dialer,
// channels and config populated, no socket, and either the supplied config
// or GetDefaultConfig() when config is nil.
func TestNewConnection(t *testing.T) {
	cases := []struct {
		name        string
		url         string
		config      *Config
		wantErr     bool
		wantErrText string
	}{
		{
			name:   "valid url, nil config",
			url:    "ws://example.com",
			config: nil,
		},
		{
			name:   "valid url, valid config",
			url:    "wss://relay.example.com:8080/path",
			config: &Config{ReadTimeout: 30 * time.Second},
		},
		{
			name:        "invalid url",
			url:         "http://example.com",
			config:      nil,
			wantErr:     true,
			wantErrText: "URL must use ws:// or wss:// scheme",
		},
		{
			name: "invalid config",
			url:  "ws://example.com",
			config: &Config{
				Retry: &RetryConfig{
					InitialDelay: 10 * time.Second,
					MaxDelay:     1 * time.Second,
				},
			},
			wantErr:     true,
			wantErrText: "initial delay may not exceed maximum delay",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			conn, err := NewConnection(tc.url, tc.config)
			if tc.wantErr {
				assert.Error(t, err)
				if tc.wantErrText != "" {
					assert.ErrorContains(t, err, tc.wantErrText)
				}
				assert.Nil(t, conn)
				return
			}
			assert.NoError(t, err)
			// Stop the subtest if construction failed. The field checks
			// below would otherwise dereference a nil connection and panic.
			if !assert.NotNil(t, conn) {
				return
			}
			// Verify struct fields
			assert.NotNil(t, conn.url)
			assert.NotNil(t, conn.dialer)
			assert.Nil(t, conn.socket)
			assert.NotNil(t, conn.config)
			assert.NotNil(t, conn.incoming)
			assert.NotNil(t, conn.outgoing)
			assert.NotNil(t, conn.errors)
			assert.NotNil(t, conn.done)
			assert.Equal(t, StateDisconnected, conn.state)
			assert.False(t, conn.closed)
			// A nil config must be replaced by the default configuration.
			if tc.config == nil {
				assert.Equal(t, GetDefaultConfig(), conn.config)
			} else {
				assert.Equal(t, tc.config, conn.config)
			}
		})
	}
}
// TestNewConnectionFromSocket checks the behavior of NewConnectionFromSocket.
//
// It must reject a nil socket or an invalid config, returning an error and a
// nil connection. Otherwise it must wrap the given socket in a Connected,
// not-closed connection that has:
//   - no URL or dialer;
//   - allocated channels;
//   - the supplied config, or GetDefaultConfig() when config is nil;
//   - a close handler installed on the socket when the config has a
//     CloseHandler.
func TestNewConnectionFromSocket(t *testing.T) {
	cases := []struct {
		name        string
		socket      Socket
		config      *Config
		wantErr     bool
		wantErrText string
	}{
		{
			name:        "nil socket",
			socket:      nil,
			config:      nil,
			wantErr:     true,
			wantErrText: "socket cannot be nil",
		},
		{
			name:   "valid socket with nil config",
			socket: NewMockSocket(),
			config: nil,
		},
		{
			name:   "valid socket with valid config",
			socket: NewMockSocket(),
			config: &Config{ReadTimeout: 30 * time.Second},
		},
		{
			name:   "invalid config",
			socket: NewMockSocket(),
			config: &Config{
				Retry: &RetryConfig{
					InitialDelay: 10 * time.Second,
					MaxDelay:     1 * time.Second,
				},
			},
			wantErr:     true,
			wantErrText: "initial delay may not exceed maximum delay",
		},
		{
			name:   "close handler set when provided",
			socket: NewMockSocket(),
			config: &Config{
				CloseHandler: func(code int, text string) error {
					return nil
				},
			},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Record whether a close handler is installed on the socket. The
			// wrapper still forwards to any hook the mock already had.
			closeHandlerSet := false
			if tc.socket != nil {
				mockSocket := tc.socket.(*MockSocket)
				originalSetCloseHandler := mockSocket.SetCloseHandlerFunc
				mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) {
					closeHandlerSet = true
					if originalSetCloseHandler != nil {
						originalSetCloseHandler(h)
					}
				}
			}
			conn, err := NewConnectionFromSocket(tc.socket, tc.config)
			if tc.wantErr {
				assert.Error(t, err)
				if tc.wantErrText != "" {
					assert.ErrorContains(t, err, tc.wantErrText)
				}
				assert.Nil(t, conn)
				return
			}
			assert.NoError(t, err)
			// Stop the subtest if construction failed. The field checks
			// below would otherwise dereference a nil connection and panic.
			if !assert.NotNil(t, conn) {
				return
			}
			// The connection starts out Connected. Close it when the subtest
			// ends so it does not outlive the subtest.
			t.Cleanup(func() { conn.Close() })
			// Verify fields initialized correctly
			assert.Nil(t, conn.url)
			assert.Nil(t, conn.dialer)
			assert.Equal(t, tc.socket, conn.socket)
			assert.NotNil(t, conn.config)
			assert.NotNil(t, conn.incoming)
			assert.NotNil(t, conn.outgoing)
			assert.NotNil(t, conn.errors)
			assert.NotNil(t, conn.done)
			assert.Equal(t, StateConnected, conn.state)
			assert.False(t, conn.closed)
			// A nil config must be replaced by the default configuration.
			if tc.config == nil {
				assert.Equal(t, GetDefaultConfig(), conn.config)
			} else {
				assert.Equal(t, tc.config, conn.config)
			}
			// A configured CloseHandler must be installed on the socket.
			if tc.config != nil && tc.config.CloseHandler != nil {
				assert.True(t, closeHandlerSet, "CloseHandler should be set on socket")
			}
		})
	}
}
// Connect tests: preconditions, dialing, retries, and state transitions
// TestConnect exercises Connection.Connect. It covers:
//   - the precondition checks (no existing socket, connection not closed);
//   - a successful dial through an injected MockDialer;
//   - retry behavior driven by RetryConfig;
//   - the Disconnected -> Connecting -> Connected state sequence;
//   - installation of a configured CloseHandler on the dialed socket.
func TestConnect(t *testing.T) {
	t.Run("connect fails when socket already present", func(t *testing.T) {
		conn, err := NewConnection("ws://test", nil)
		if !assert.NoError(t, err) {
			return
		}
		// A pre-populated socket must make Connect refuse to dial.
		conn.socket = NewMockSocket()
		err = conn.Connect()
		assert.Error(t, err)
		assert.ErrorContains(t, err, "already has socket")
		assert.Equal(t, StateDisconnected, conn.State())
	})
	t.Run("connect fails when connection closed", func(t *testing.T) {
		conn, err := NewConnection("ws://test", nil)
		if !assert.NoError(t, err) {
			return
		}
		conn.Close()
		err = conn.Connect()
		assert.Error(t, err)
		assert.ErrorContains(t, err, "connection is closed")
		assert.Equal(t, StateClosed, conn.State())
	})
	t.Run("connect succeeds and starts goroutines", func(t *testing.T) {
		conn, err := NewConnection("ws://test", nil)
		if !assert.NoError(t, err) {
			return
		}
		// Buffered so the connection's writes do not block on this test.
		// Deliberately never closed: the connection may still write to the
		// socket (e.g. while closing), and a send on a closed channel would
		// panic.
		outgoingData := make(chan mockOutgoingData, 10)
		mockSocket := NewMockSocket()
		mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
			outgoingData <- mockOutgoingData{msgType: msgType, data: data}
			return nil
		}
		conn.dialer = &MockDialer{
			DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
				return mockSocket, nil, nil
			},
		}
		if !assert.NoError(t, conn.Connect()) {
			return
		}
		// Deferred so the connection is closed even if t.Fatal fires below.
		defer conn.Close()
		assert.Equal(t, StateConnected, conn.State())

		testData := []byte("test")
		conn.Send(testData)
		// The select below already waits for the asynchronous write. The
		// timeout only bounds the failure case, so no extra sleep is needed.
		select {
		case msg := <-outgoingData:
			assert.Equal(t, testData, msg.data)
		case <-time.After(time.Second):
			t.Fatal("timeout waiting for message write")
		}
	})
	t.Run("connect retries on dial failure", func(t *testing.T) {
		config := &Config{
			Retry: &RetryConfig{
				MaxRetries:   2,
				InitialDelay: 1 * time.Millisecond,
				MaxDelay:     5 * time.Millisecond,
				JitterFactor: 0.0,
			},
		}
		conn, err := NewConnection("ws://test", config)
		if !assert.NoError(t, err) {
			return
		}
		// Fail the first two dials and succeed on the third.
		attemptCount := 0
		conn.dialer = &MockDialer{
			DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
				attemptCount++
				if attemptCount < 3 {
					return nil, nil, fmt.Errorf("dial failed")
				}
				return NewMockSocket(), nil, nil
			},
		}
		err = conn.Connect()
		defer conn.Close()
		assert.NoError(t, err)
		assert.Equal(t, 3, attemptCount)
		assert.Equal(t, StateConnected, conn.State())
	})
	t.Run("connect fails after max retries", func(t *testing.T) {
		config := &Config{
			Retry: &RetryConfig{
				MaxRetries:   2,
				InitialDelay: 1 * time.Millisecond,
				MaxDelay:     5 * time.Millisecond,
				JitterFactor: 0.0,
			},
		}
		conn, err := NewConnection("ws://test", config)
		if !assert.NoError(t, err) {
			return
		}
		defer conn.Close()
		conn.dialer = &MockDialer{
			DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
				return nil, nil, fmt.Errorf("dial failed")
			},
		}
		err = conn.Connect()
		assert.Error(t, err)
		assert.ErrorContains(t, err, "dial failed")
		assert.Equal(t, StateDisconnected, conn.State())
	})
	t.Run("state transitions during connect", func(t *testing.T) {
		conn, err := NewConnection("ws://test", nil)
		if !assert.NoError(t, err) {
			return
		}
		assert.Equal(t, StateDisconnected, conn.State())
		stateDuringDial := StateDisconnected
		conn.dialer = &MockDialer{
			DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
				// Reads the field directly rather than calling State().
				// NOTE(review): presumably Connect holds the connection's lock
				// while dialing, so State() could deadlock here. Confirm before
				// changing.
				stateDuringDial = conn.state
				return NewMockSocket(), nil, nil
			},
		}
		err = conn.Connect()
		defer conn.Close()
		assert.NoError(t, err)
		assert.Equal(t, StateConnecting, stateDuringDial)
		assert.Equal(t, StateConnected, conn.State())
	})
	t.Run("close handler configured when provided", func(t *testing.T) {
		handlerSet := false
		config := &Config{
			CloseHandler: func(code int, text string) error {
				return nil
			},
		}
		conn, err := NewConnection("ws://test", config)
		if !assert.NoError(t, err) {
			return
		}
		mockSocket := NewMockSocket()
		mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) {
			handlerSet = true
		}
		conn.dialer = &MockDialer{
			DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
				return mockSocket, nil, nil
			},
		}
		err = conn.Connect()
		defer conn.Close()
		assert.NoError(t, err)
		assert.True(t, handlerSet, "close handler should be set on socket")
	})
}
// Connection method tests
// TestConnectionIncoming checks that Incoming() exposes the connection's
// internal incoming channel. A value pushed into conn.incoming must be
// readable from the returned channel.
func TestConnectionIncoming(t *testing.T) {
	conn, err := NewConnection("ws://test", nil)
	if !assert.NoError(t, err) {
		return
	}
	incoming := conn.Incoming()
	assert.NotNil(t, incoming)

	// Round-trip a value through the internal channel and the accessor. The
	// timeouts make a blocked channel (unbuffered, full, or nil) fail the
	// test quickly. A bare send/receive would hang until the go test timeout.
	testData := []byte("test")
	select {
	case conn.incoming <- testData:
	case <-time.After(time.Second):
		t.Fatal("timeout sending on internal incoming channel")
	}
	select {
	case received := <-incoming:
		assert.Equal(t, testData, received)
	case <-time.After(time.Second):
		t.Fatal("timeout receiving from Incoming()")
	}
}
// TestConnectionErrors checks that Errors() exposes the connection's internal
// error channel. An error pushed into conn.errors must be readable from the
// returned channel.
func TestConnectionErrors(t *testing.T) {
	conn, err := NewConnection("ws://test", nil)
	if !assert.NoError(t, err) {
		return
	}
	// Named errCh rather than errors to avoid shadowing the standard
	// library package name.
	errCh := conn.Errors()
	assert.NotNil(t, errCh)

	// Round-trip a value through the internal channel and the accessor. The
	// timeouts make a blocked channel (unbuffered, full, or nil) fail the
	// test quickly. A bare send/receive would hang until the go test timeout.
	testErr := fmt.Errorf("test error")
	select {
	case conn.errors <- testErr:
	case <-time.After(time.Second):
		t.Fatal("timeout sending on internal errors channel")
	}
	select {
	case received := <-errCh:
		assert.Equal(t, testErr, received)
	case <-time.After(time.Second):
		t.Fatal("timeout receiving from Errors()")
	}
}
// Connect() tests