diff --git a/c2p b/c2p index 27ad39c..1c38321 100755 --- a/c2p +++ b/c2p @@ -1,2 +1,2 @@ #!/bin/bash -code2prompt -c -i c2p -i go.sum -i LICENSE +code2prompt -c -e c2p -e go.sum -e LICENSE diff --git a/errors/errors.go b/errors/errors.go new file mode 100644 index 0000000..564c8fc --- /dev/null +++ b/errors/errors.go @@ -0,0 +1,25 @@ +package errors + +import "errors" +import "fmt" + +var ( + // URL Errors + InvalidProtocol = errors.New("URL must use ws:// or wss:// scheme") + + // Configuration Errors + InvalidReadTimeout = errors.New("read timeout must be positive") + InvalidWriteTimeout = errors.New("write timeout must be positive") + InvalidRetryMaxRetries = errors.New("max retry count cannot be negative") + InvalidRetryInitialDelay = errors.New("initial delay must be positive") + InvalidRetryMaxDelay = errors.New("max delay must be positive") + InvalidRetryJitterFactor = errors.New("jitter factor must be between 0.0 and 1.0") +) + +func NewConfigError(text string) error { + return fmt.Errorf("configuration error: %s", text) +} + +func NewConnectionError(text string) error { + return fmt.Errorf("connection error: %s", text) +} diff --git a/go.mod b/go.mod index 7966348..48a05a7 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,12 @@ module git.wisehodl.dev/jay/go-honeybee go 1.23.5 require ( - github.com/gorilla/websocket v1.5.3 // indirect - github.com/stretchr/testify v1.11.1 // indirect + github.com/gorilla/websocket v1.5.3 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 33c3566..4b33f39 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/honeybee.go b/honeybee.go new file mode 100644 index 0000000..72aab9a --- /dev/null +++ b/honeybee.go @@ -0,0 +1 @@ +package honeybee diff --git a/ws/config.go b/ws/config.go new file mode 100644 index 0000000..8e0f27e --- /dev/null +++ b/ws/config.go @@ -0,0 +1,163 @@ +package ws + +import ( + "git.wisehodl.dev/jay/go-honeybee/errors" + "time" +) + +type CloseHandler func(code int, text string) error + +type Config struct { + CloseHandler CloseHandler + ReadTimeout time.Duration + WriteTimeout time.Duration + Retry *RetryConfig +} + +type RetryConfig struct { + MaxRetries int + InitialDelay time.Duration + MaxDelay time.Duration + JitterFactor float64 +} + +type ConfigOption func(*Config) error + +func NewConfig(options ...ConfigOption) (*Config, error) { + conf := GetDefaultConfig() + if err := SetConfig(conf, options...); err != nil { + return nil, err + } + if err := ValidateConfig(conf); err != nil { + return nil, err + } + return conf, nil +} + +func GetDefaultConfig() *Config { + return &Config{} +} + +func GetDefaultRetryConfig() *RetryConfig { + return &RetryConfig{ + MaxRetries: 0, // Infinite retries + InitialDelay: 1 * time.Second, + MaxDelay: 5 * time.Second, + JitterFactor: 0.5, + } +} + +func SetConfig(config *Config, options ...ConfigOption) error { + for _, option := range options { + if err := option(config); err != nil { + return err + } + } + return nil +} + +func ValidateConfig(config *Config) error { + if config.Retry != nil { + if config.Retry.InitialDelay > config.Retry.MaxDelay { + return errors.NewConfigError("initial delay may not exceed maximum delay") + } + } + + return nil +} + +// Configuration Options + +func WithCloseHandler(handler CloseHandler) ConfigOption { + return func(c *Config) error { + c.CloseHandler = handler + return nil + } +} + +// When ReadTimeout is set to zero, read timeouts are disabled. +func WithReadTimeout(value time.Duration) ConfigOption { + return func(c *Config) error { + if value < 0 { + return errors.InvalidReadTimeout + } + c.ReadTimeout = value + return nil + } +} + +// When WriteTimeout is set to zero, read timeouts are disabled. +func WithWriteTimeout(value time.Duration) ConfigOption { + return func(c *Config) error { + if value < 0 { + return errors.InvalidWriteTimeout + } + c.WriteTimeout = value + return nil + } +} + +// WithRetry enables retry with default parameters (infinite retries, +// 1s initial delay, 5s max delay, 0.5 jitter factor). +// +// If passed after granular retry options (WithRetryMaxRetries, etc.), +// it will overwrite them. Use either WithRetry alone or the granular +// options; not both. +func WithRetry() ConfigOption { + return func(c *Config) error { + c.Retry = GetDefaultRetryConfig() + return nil + } +} + +func WithRetryMaxRetries(value int) ConfigOption { + return func(c *Config) error { + if c.Retry == nil { + c.Retry = GetDefaultRetryConfig() + } + if value < 0 { + return errors.InvalidRetryMaxRetries + } + c.Retry.MaxRetries = value + return nil + } +} + +func WithRetryInitialDelay(value time.Duration) ConfigOption { + return func(c *Config) error { + if c.Retry == nil { + c.Retry = GetDefaultRetryConfig() + } + if value <= 0 { + return errors.InvalidRetryInitialDelay + } + c.Retry.InitialDelay = value + return nil + } +} + +func WithRetryMaxDelay(value time.Duration) ConfigOption { + return func(c *Config) error { + if c.Retry == nil { + c.Retry = GetDefaultRetryConfig() + } + if value <= 0 { + return errors.InvalidRetryMaxDelay + } + c.Retry.MaxDelay = value + return nil + } +} + +func WithRetryJitterFactor(value float64) ConfigOption { + return func(c *Config) error { + if c.Retry == nil { + c.Retry = GetDefaultRetryConfig() + } + if value < 0.0 || value > 1.0 { + return errors.InvalidRetryJitterFactor + } + c.Retry.JitterFactor = value + return nil + } +} diff --git a/ws/config_test.go b/ws/config_test.go new file mode 100644 index 0000000..6e7bf0e --- /dev/null +++ b/ws/config_test.go @@ -0,0 +1,274 @@ +package ws + +import ( + "git.wisehodl.dev/jay/go-honeybee/errors" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +// Config Tests + +func TestNewConfig(t *testing.T) { + conf, err := NewConfig(WithRetry()) + + assert.NoError(t, err) + assert.Equal(t, conf, &Config{ + Retry: GetDefaultRetryConfig(), + }) + + // errors propagate + _, err = NewConfig(WithRetryMaxRetries(-1)) + assert.Error(t, err) + + _, err = NewConfig(WithRetryInitialDelay(10), WithRetryMaxDelay(1)) + assert.Error(t, err) +} + +// Default Config Tests + +func TestDefaultConfig(t *testing.T) { + conf := GetDefaultConfig() + + assert.Nil(t, conf.CloseHandler) + assert.Nil(t, conf.Retry) +} + +func TestDefaultRetryConfig(t *testing.T) { + conf := GetDefaultRetryConfig() + + assert.Equal(t, conf, &RetryConfig{ + MaxRetries: 0, + InitialDelay: 1 * time.Second, + MaxDelay: 5 * time.Second, + JitterFactor: 0.5, + }) +} + +// Config Builder Tests + +func TestSetConfig(t *testing.T) { + conf := GetDefaultConfig() + err := SetConfig( + conf, + WithRetryMaxRetries(0), + WithRetryInitialDelay(3*time.Second), + WithRetryJitterFactor(0.5), + ) + + assert.NoError(t, err) + assert.Equal(t, 0, conf.Retry.MaxRetries) + assert.Equal(t, 3*time.Second, conf.Retry.InitialDelay) + assert.Equal(t, 0.5, conf.Retry.JitterFactor) + + // errors propagate + err = SetConfig( + conf, + WithRetryMaxRetries(-10), + ) + + assert.ErrorIs(t, err, errors.InvalidRetryMaxRetries) +} + +// Config Option Tests + +func TestWithCloseHandler(t *testing.T) { + conf := GetDefaultConfig() + opt := WithCloseHandler(func(code int, text string) error { return nil }) + err := SetConfig(conf, opt) + assert.NoError(t, err) + assert.Nil(t, conf.CloseHandler(0, "")) +} + +func TestWithReadTimeout(t *testing.T) { + conf := GetDefaultConfig() + opt := WithReadTimeout(30) + err := SetConfig(conf, opt) + assert.NoError(t, err) + assert.Equal(t, conf.ReadTimeout, time.Duration(30)) + + // zero allowed + conf = GetDefaultConfig() + opt = WithReadTimeout(0) + err = SetConfig(conf, opt) + assert.NoError(t, err) + assert.Equal(t, conf.ReadTimeout, time.Duration(0)) + + // negative disallowed + conf = GetDefaultConfig() + opt = WithReadTimeout(-30) + err = SetConfig(conf, opt) + assert.ErrorIs(t, err, errors.InvalidReadTimeout) + assert.ErrorContains(t, err, "read timeout must be positive") +} + +func TestWithWriteTimeout(t *testing.T) { + conf := GetDefaultConfig() + opt := WithWriteTimeout(30) + err := SetConfig(conf, opt) + assert.NoError(t, err) + assert.Equal(t, conf.WriteTimeout, time.Duration(30)) + + // zero allowed + conf = GetDefaultConfig() + opt = WithWriteTimeout(0) + err = SetConfig(conf, opt) + assert.NoError(t, err) + assert.Equal(t, conf.WriteTimeout, time.Duration(0)) + + // negative disallowed + conf = GetDefaultConfig() + opt = WithWriteTimeout(-30) + err = SetConfig(conf, opt) + assert.ErrorIs(t, err, errors.InvalidWriteTimeout) + assert.ErrorContains(t, err, "write timeout must be positive") +} + +func TestWithRetry(t *testing.T) { + conf := GetDefaultConfig() + opt := WithRetry() + err := SetConfig(conf, opt) + assert.NoError(t, err) + assert.NotNil(t, conf.Retry) + assert.Equal(t, conf.Retry, GetDefaultRetryConfig()) +} + +func TestWithRetryAttempts(t *testing.T) { + conf := GetDefaultConfig() + opt := WithRetryMaxRetries(3) + err := SetConfig(conf, opt) + assert.NoError(t, err) + assert.Equal(t, 3, conf.Retry.MaxRetries) + + // zero allowed + opt = WithRetryMaxRetries(0) + err = SetConfig(conf, opt) + assert.NoError(t, err) + + // negative disallowed + opt = WithRetryMaxRetries(-10) + err = SetConfig(conf, opt) + assert.ErrorIs(t, err, errors.InvalidRetryMaxRetries) + assert.ErrorContains(t, err, "max retry count cannot be negative") +} + +func TestWithRetryInitialDelay(t *testing.T) { + conf := GetDefaultConfig() + opt := WithRetryInitialDelay(10 * time.Second) + err := SetConfig(conf, opt) + assert.NoError(t, err) + assert.Equal(t, 10*time.Second, conf.Retry.InitialDelay) + + // zero disallowed + opt = WithRetryInitialDelay(0 * time.Second) + err = SetConfig(conf, opt) + assert.ErrorIs(t, err, errors.InvalidRetryInitialDelay) + assert.ErrorContains(t, err, "initial delay must be positive") + + // negative disallowed + opt = WithRetryInitialDelay(-10 * time.Second) + err = SetConfig(conf, opt) + assert.ErrorIs(t, err, errors.InvalidRetryInitialDelay) +} + +func TestWithRetryMaxDelay(t *testing.T) { + conf := GetDefaultConfig() + opt := WithRetryMaxDelay(10 * time.Second) + err := SetConfig(conf, opt) + assert.NoError(t, err) + assert.Equal(t, 10*time.Second, conf.Retry.MaxDelay) + + // zero disallowed + opt = WithRetryMaxDelay(0 * time.Second) + err = SetConfig(conf, opt) + assert.ErrorIs(t, err, errors.InvalidRetryMaxDelay) + assert.ErrorContains(t, err, "max delay must be positive") + + // negative disallowed + opt = WithRetryMaxDelay(-10 * time.Second) + err = SetConfig(conf, opt) + assert.ErrorIs(t, err, errors.InvalidRetryMaxDelay) +} + +func TestWithRetryJitterFactor(t *testing.T) { + conf := GetDefaultConfig() + + opt := WithRetryJitterFactor(0.2) + err := SetConfig(conf, opt) + assert.NoError(t, err) + assert.Equal(t, 0.2, conf.Retry.JitterFactor) + + // negative disallowed + opt = WithRetryJitterFactor(-1) + err = SetConfig(conf, opt) + assert.ErrorIs(t, err, errors.InvalidRetryJitterFactor) + assert.ErrorContains(t, err, "jitter factor must be between 0.0 and 1.0") + + // >1 disallowed + opt = WithRetryJitterFactor(1.1) + err = SetConfig(conf, opt) + assert.ErrorIs(t, err, errors.InvalidRetryJitterFactor) +} + +// Config Validation Tests + +func TestValidateConfig(t *testing.T) { + cases := []struct { + name string + conf Config + wantErr error + wantErrText string + }{ + { + name: "valid empty", + conf: *GetDefaultConfig(), + }, + { + name: "valid defaults", + conf: *GetDefaultConfig(), + }, + { + name: "valid complete", + conf: Config{ + CloseHandler: (func(code int, text string) error { return nil }), + ReadTimeout: time.Duration(30), + WriteTimeout: time.Duration(30), + Retry: &RetryConfig{ + MaxRetries: 0, + InitialDelay: 2 * time.Second, + MaxDelay: 10 * time.Second, + JitterFactor: 0.2, + }, + }, + }, + { + name: "invalid - initial delay > max delay", + conf: Config{ + Retry: &RetryConfig{ + InitialDelay: 10 * time.Second, + MaxDelay: 1 * time.Second, + }, + }, + wantErrText: "initial delay may not exceed maximum delay", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateConfig(&tc.conf) + + if tc.wantErr != nil || tc.wantErrText != "" { + if tc.wantErr != nil { + assert.ErrorIs(t, err, tc.wantErr) + } + + if tc.wantErrText != "" { + assert.ErrorContains(t, err, tc.wantErrText) + } + return + } + + assert.NoError(t, err) + }) + } +} diff --git a/ws/connection.go b/ws/connection.go new file mode 100644 index 0000000..e437d60 --- /dev/null +++ b/ws/connection.go @@ -0,0 +1,371 @@ +package ws + +import ( + "fmt" + "net/http" + "net/url" + "sync" + "time" + + "git.wisehodl.dev/jay/go-honeybee/errors" + "github.com/gorilla/websocket" +) + +type Dialer interface { + Dial(urlStr string, requestHeader http.Header) (Socket, *http.Response, error) +} + +func NewDialer() Dialer { + return NewGorillaDialer() +} + +type GorillaDialer struct { + *websocket.Dialer +} + +func NewGorillaDialer() *GorillaDialer { + return &GorillaDialer{ + Dialer: &websocket.Dialer{ + HandshakeTimeout: 45 * time.Second, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + } +} + +// Returns the Socket interface +func (d *GorillaDialer) Dial( + urlStr string, requestHeader http.Header, +) ( + Socket, *http.Response, error, +) { + conn, resp, err := d.Dialer.Dial(urlStr, requestHeader) + return conn, resp, err +} + +type Socket interface { + WriteMessage(messageType int, data []byte) error + ReadMessage() (messageType int, p []byte, err error) + Close() error + + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error + SetCloseHandler(h func(code int, text string) error) +} + +func AcquireSocket( + retryMgr *RetryManager, + dialer Dialer, + urlStr string, +) (Socket, *http.Response, error) { + if retryMgr == nil { + return nil, nil, errors.NewConnectionError("retry manager cannot be nil") + } + if dialer == nil { + return nil, nil, errors.NewConnectionError("dialer cannot be nil") + } + if urlStr == "" { + return nil, nil, errors.NewConnectionError("URL cannot be empty") + } + + for { + socket, resp, err := dialer.Dial(urlStr, nil) + if err == nil { + return socket, resp, nil + } + + if !retryMgr.ShouldRetry() { + return nil, nil, err + } + + delay := retryMgr.CalculateDelay() + time.Sleep(delay) + retryMgr.RecordRetry() + } +} + +type ConnectionState int + +const ( + StateDisconnected ConnectionState = iota + StateConnecting + StateConnected + StateClosed +) + +func (s ConnectionState) String() string { + switch s { + case StateDisconnected: + return "disconnected" + case StateConnecting: + return "connecting" + case StateConnected: + return "connected" + case StateClosed: + return "closed" + default: + return "unknown" + } +} + +type Connection struct { + url *url.URL + dialer Dialer + socket Socket + config *Config + + incoming chan []byte + outgoing chan []byte + errors chan error + done chan struct{} + + state ConnectionState + + wg sync.WaitGroup + once sync.Once + closed bool + mu sync.RWMutex +} + +func NewConnection(urlStr string, config *Config) (*Connection, error) { + if config == nil { + config = GetDefaultConfig() + } + + if err := ValidateConfig(config); err != nil { + return nil, err + } + + parsedURL, err := ParseURL(urlStr) + if err != nil { + return nil, err + } + + return &Connection{ + url: parsedURL, + dialer: NewDialer(), + socket: nil, + config: config, + incoming: make(chan []byte, 100), + outgoing: make(chan []byte, 100), + errors: make(chan error, 10), + state: StateDisconnected, + done: make(chan struct{}), + }, nil +} + +func NewConnectionFromSocket(socket Socket, config *Config) (*Connection, error) { + if socket == nil { + return nil, errors.NewConnectionError("socket cannot be nil") + } + + if config == nil { + config = GetDefaultConfig() + } + + if err := ValidateConfig(config); err != nil { + return nil, err + } + + conn := &Connection{ + url: nil, + dialer: nil, + socket: socket, + config: config, + incoming: make(chan []byte, 100), + outgoing: make(chan []byte, 100), + errors: make(chan error, 10), + state: StateConnected, + done: make(chan struct{}), + } + + if config.CloseHandler != nil { + socket.SetCloseHandler(config.CloseHandler) + } + + conn.startReader() + conn.startWriter() + + return conn, nil +} + +func (c *Connection) Connect() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.socket != nil { + return errors.NewConnectionError("connection already has socket") + } + + if c.closed { + return errors.NewConnectionError("connection is closed") + } + + c.state = StateConnecting + + retryMgr := NewRetryManager(c.config.Retry) + socket, _, err := AcquireSocket(retryMgr, c.dialer, c.url.String()) + + if err != nil { + c.state = StateDisconnected + return err + } + + c.socket = socket + c.state = StateConnected + + if c.config.CloseHandler != nil { + c.socket.SetCloseHandler(c.config.CloseHandler) + } + + c.startReader() + c.startWriter() + + return nil +} + +func (c *Connection) startReader() { + c.wg.Add(1) + go func() { + defer c.wg.Done() + + for { + select { + case <-c.done: + return + default: + if c.config.ReadTimeout > 0 { + if err := c.socket.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil { + select { + case c.errors <- fmt.Errorf("failed to set read deadline: %w", err): + case <-c.done: + } + c.Close() + return + } + } + messageType, data, err := c.socket.ReadMessage() + if err != nil { + select { + case c.errors <- err: + case <-c.done: + } + c.Close() + return + } + + if messageType == websocket.TextMessage || + messageType == websocket.BinaryMessage { + select { + case c.incoming <- data: + case <-c.done: + c.Close() + return + } + } + + } + } + }() + +} + +func (c *Connection) startWriter() { + c.wg.Add(1) + go func() { + defer c.wg.Done() + + for { + select { + case <-c.done: + return + case data := <-c.outgoing: + if c.config.WriteTimeout > 0 { + if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { + select { + case c.errors <- fmt.Errorf("failed to set write deadline: %w", err): + case <-c.done: + } + c.Close() + return + } + } + + if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil { + select { + case c.errors <- err: + case <-c.done: + } + c.Close() + return + } + } + } + }() + +} + +func (c *Connection) Send(data []byte) error { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.closed { + return errors.NewConnectionError("connection closed") + } + + select { + case c.outgoing <- data: + return nil + default: + return errors.NewConnectionError("outgoing queue full") + } +} + +func (c *Connection) Incoming() <-chan []byte { + return c.incoming +} + +func (c *Connection) Errors() <-chan error { + return c.errors +} + +// Close shuts down the connection and waits for goroutines to exit. +// If the underlying socket blocks indefinitely on read or write operations, +// Close will also block. This is expected behavior - hung sockets require +// external intervention (timeouts, process termination, etc). +func (c *Connection) Close() error { + c.mu.Lock() + + alreadyClosed := c.closed + if !alreadyClosed { + c.closed = true + c.state = StateClosed + close(c.done) + } + + socket := c.socket + c.mu.Unlock() + + if alreadyClosed { + return nil + } + + var err error + if socket != nil { + err = socket.Close() + } + + c.wg.Wait() + + close(c.incoming) + close(c.outgoing) + close(c.errors) + + return err +} + +func (c *Connection) State() ConnectionState { + c.mu.RLock() + defer c.mu.RUnlock() + return c.state +} diff --git a/ws/connection_close_test.go b/ws/connection_close_test.go new file mode 100644 index 0000000..37550e1 --- /dev/null +++ b/ws/connection_close_test.go @@ -0,0 +1,158 @@ +package ws + +import ( + "fmt" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestDisconnectedConnectionClose(t *testing.T) { + t.Run("close succeeds on disconnected connection", func(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + assert.Equal(t, StateDisconnected, conn.State()) + + err = conn.Close() + assert.NoError(t, err) + assert.Equal(t, StateClosed, conn.State()) + }) + + t.Run("close is idempotent", func(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + + err = conn.Close() + assert.NoError(t, err) + + // Second close should succeed without error + err = conn.Close() + assert.NoError(t, err) + assert.Equal(t, StateClosed, conn.State()) + }) + + t.Run("close with nil socket", func(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + assert.Nil(t, conn.socket) + + err = conn.Close() + assert.NoError(t, err) + assert.Equal(t, StateClosed, conn.State()) + }) + + t.Run("socket close error propagates", func(t *testing.T) { + expectedErr := fmt.Errorf("socket close failed") + mockSocket := NewMockSocket() + mockSocket.CloseFunc = func() error { + return expectedErr + } + + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + conn.socket = mockSocket + + err = conn.Close() + assert.Equal(t, expectedErr, err) + assert.Equal(t, StateClosed, conn.State()) + }) + + t.Run("channels close after close", func(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + + err = conn.Close() + assert.NoError(t, err) + + // Verify incoming channel closed + select { + case _, ok := <-conn.incoming: + assert.False(t, ok, "incoming channel should be closed") + case <-time.After(50 * time.Millisecond): + t.Fatal("timeout waiting for incoming channel closure") + } + + // Verify outgoing channel closed + select { + case _, ok := <-conn.outgoing: + assert.False(t, ok, "outgoing channel should be closed") + case <-time.After(50 * time.Millisecond): + t.Fatal("timeout waiting for outgoing channel closure") + } + + // Verify errors channel closed + select { + case _, ok := <-conn.errors: + assert.False(t, ok, "errors channel should be closed") + case <-time.After(50 * time.Millisecond): + t.Fatal("timeout waiting for errors channel closure") + } + }) + + t.Run("send fails after close", func(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + + err = conn.Close() + assert.NoError(t, err) + + err = conn.Send([]byte("test")) + assert.Error(t, err) + assert.ErrorContains(t, err, "connection closed") + }) + +} + +func TestConnectedConnectionClose(t *testing.T) { + t.Run("blocked on ReadMessage, unblocks on closed", func(t *testing.T) { + conn, _, incomingData, _ := setupTestConnection(t, nil) + + // Wait for reader to block + time.Sleep(10 * time.Millisecond) + + err := conn.Close() + assert.NoError(t, err) + assert.Equal(t, StateClosed, conn.State()) + + close(incomingData) + }) + + t.Run("writer active during close exits cleanly", func(t *testing.T) { + conn, _, _, outgoingData := setupTestConnection(t, nil) + + for i := 0; i < 50; i++ { + conn.Send([]byte("message")) + } + + err := conn.Close() + assert.NoError(t, err) + + err = conn.Send([]byte("late")) + assert.Error(t, err, "Send should fail after close") + assert.ErrorContains(t, err, "connection closed") + + close(outgoingData) + }) + + t.Run("both goroutines active during close", func(t *testing.T) { + conn, _, incomingData, outgoingData := setupTestConnection(t, nil) + + for i := 0; i < 10; i++ { + incomingData <- mockIncomingData{ + msgType: websocket.TextMessage, + data: []byte(fmt.Sprintf("in-%d", i)), + } + conn.Send([]byte(fmt.Sprintf("out-%d", i))) + } + + time.Sleep(10 * time.Millisecond) + + err := conn.Close() + assert.NoError(t, err) + + close(incomingData) + close(outgoingData) + }) +} diff --git a/ws/connection_goroutine_test.go b/ws/connection_goroutine_test.go new file mode 100644 index 0000000..7a5e077 --- /dev/null +++ b/ws/connection_goroutine_test.go @@ -0,0 +1,404 @@ +package ws + +import ( + "fmt" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "io" + "testing" + "time" +) + +func TestStartReader(t *testing.T) { + t.Run("text messages route to incoming channel", func(t *testing.T) { + conn, _, incomingData, _ := setupTestConnection(t, nil) + defer conn.Close() + + testData := []byte("hello") + incomingData <- mockIncomingData{ + msgType: websocket.TextMessage, + data: testData, + err: nil, + } + + expectIncoming(t, conn, testData) + }) + + t.Run("binary messages route to incoming channel", func(t *testing.T) { + conn, _, incomingData, _ := setupTestConnection(t, nil) + defer conn.Close() + + testData := []byte{0x00, 0x01, 0x02} + incomingData <- mockIncomingData{ + msgType: websocket.BinaryMessage, + data: testData, + err: nil, + } + + expectIncoming(t, conn, testData) + }) + + t.Run("multiple messages processed sequentially", func(t *testing.T) { + conn, _, incomingData, _ := setupTestConnection(t, nil) + defer conn.Close() + + messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")} + for _, msg := range messages { + incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: msg, err: nil} + } + + for _, expected := range messages { + expectIncoming(t, conn, expected) + } + }) + + t.Run("read timeout disabled when zero", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode") + } + + config := &Config{ReadTimeout: 0} + + mockSocket := NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.once.Do(func() { + close(mockSocket.closed) + }) + return nil + } + + deadlineCalled := make(chan struct{}, 1) + mockSocket.SetReadDeadlineFunc = func(t time.Time) error { + deadlineCalled <- struct{}{} + return nil + } + + conn, err := NewConnectionFromSocket(mockSocket, config) + assert.NoError(t, err) + defer conn.Close() + + select { + case <-deadlineCalled: + t.Fatal("SetReadDeadline should not be called when timeout is zero") + case <-time.After(100 * time.Millisecond): + } + + }) + + t.Run("read timeout sets deadline when positive", func(t *testing.T) { + config := &Config{ReadTimeout: 30} + + incomingData := make(chan mockIncomingData, 10) + mockSocket := NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.once.Do(func() { + close(mockSocket.closed) + }) + return nil + } + + deadlineCalled := make(chan struct{}, 1) + mockSocket.SetReadDeadlineFunc = func(t time.Time) error { + deadlineCalled <- struct{}{} + return nil + } + + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + select { + case data := <-incomingData: + return data.msgType, data.data, data.err + case <-mockSocket.closed: + return 0, nil, io.EOF + } + } + + conn, err := NewConnectionFromSocket(mockSocket, config) + assert.NoError(t, err) + defer conn.Close() + + incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil} + + select { + case <-conn.Incoming(): + case <-time.After(100 * time.Millisecond): + } + + select { + case _, ok := <-deadlineCalled: + assert.True(t, ok, "SetReadDeadline should be called when timeout is positive") + case <-time.After(100 * time.Millisecond): + t.Fatal("SetReadDeadline was never called") + } + }) + + t.Run("reader exits on deadline error", func(t *testing.T) { + config := &Config{ReadTimeout: 1 * time.Millisecond} + + mockSocket := NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.once.Do(func() { + close(mockSocket.closed) + }) + return nil + } + + mockSocket.SetReadDeadlineFunc = func(t time.Time) error { + return fmt.Errorf("test error") + } + + conn, err := NewConnectionFromSocket(mockSocket, config) + assert.NoError(t, err) + defer conn.Close() + + select { + case err := <-conn.Errors(): + assert.ErrorContains(t, err, "failed to set read deadline") + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for deadline error") + } + + time.Sleep(10 * time.Millisecond) + assert.Equal(t, StateClosed, conn.State()) + + }) + + t.Run("reader exits on socket read error", func(t *testing.T) { + mockSocket := NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.once.Do(func() { + close(mockSocket.closed) + }) + return nil + } + + readErr := fmt.Errorf("read failed") + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + return 0, nil, readErr + } + + conn, err := NewConnectionFromSocket(mockSocket, nil) + assert.NoError(t, err) + defer conn.Close() + + select { + case err := <-conn.Errors(): + assert.Equal(t, readErr, err) + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for read error") + } + + time.Sleep(10 * time.Millisecond) + assert.Equal(t, StateClosed, conn.State()) + + }) +} + +func TestStartWriter(t *testing.T) { + t.Run("data from outgoing triggers write", func(t *testing.T) { + conn, _, _, outgoingData := setupTestConnection(t, nil) + defer conn.Close() + + testData := []byte("test message") + err := conn.Send(testData) + assert.NoError(t, err) + + expectWrite(t, outgoingData, websocket.TextMessage, testData) + }) + + t.Run("multiple messages processed sequentially", func(t *testing.T) { + conn, _, _, outgoingData := setupTestConnection(t, nil) + defer conn.Close() + + messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")} + for _, msg := range messages { + err := conn.Send(msg) + assert.NoError(t, err) + } + + for _, expected := range messages { + expectWrite(t, outgoingData, websocket.TextMessage, expected) + } + }) + + t.Run("write timeout disabled when zero", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode") + } + + config := &Config{WriteTimeout: 0} + + outgoingData := make(chan mockOutgoingData, 10) + mockSocket := NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.once.Do(func() { + close(mockSocket.closed) + }) + return nil + } + + deadlineCalled := make(chan struct{}, 1) + mockSocket.SetWriteDeadlineFunc = func(t time.Time) error { + deadlineCalled <- struct{}{} + return nil + } + + mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + select { + case outgoingData <- mockOutgoingData{msgType: msgType, data: data}: + case <-mockSocket.closed: + return io.EOF + } + return nil + } + + conn, err := NewConnectionFromSocket(mockSocket, config) + assert.NoError(t, err) + defer conn.Close() + + err = conn.Send([]byte("test")) + assert.NoError(t, err) + + time.Sleep(20 * time.Millisecond) + + select { + case <-deadlineCalled: + t.Fatal("SetWriteDeadline should not be called when timeout is zero") + case <-time.After(100 * time.Millisecond): + } + }) + + t.Run("write timeout sets deadline when positive", func(t *testing.T) { + config := &Config{WriteTimeout: 30 * time.Millisecond} + + outgoingData := make(chan mockOutgoingData, 10) + mockSocket := NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.once.Do(func() { + close(mockSocket.closed) + }) + return nil + } + + deadlineCalled := make(chan struct{}, 1) + mockSocket.SetWriteDeadlineFunc = func(t time.Time) error { + deadlineCalled <- struct{}{} + return nil + } + + mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + select { + case outgoingData <- mockOutgoingData{msgType: msgType, data: data}: + case <-mockSocket.closed: + return io.EOF + } + return nil + } + + conn, err := NewConnectionFromSocket(mockSocket, config) + assert.NoError(t, err) + defer conn.Close() + + err = conn.Send([]byte("test")) + assert.NoError(t, err) + + time.Sleep(20 * time.Millisecond) + + select { + case _, ok := <-deadlineCalled: + assert.True(t, ok, "SetWriteDeadline should be called when timeout is positive") + case <-time.After(100 * time.Millisecond): + t.Fatal("SetWriteDeadline was never called") + } + }) + + t.Run("writer exits on deadline error", func(t *testing.T) { + config := &Config{WriteTimeout: 1 * time.Millisecond} + + mockSocket := NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.once.Do(func() { + close(mockSocket.closed) + }) + return nil + } + + mockSocket.SetWriteDeadlineFunc = func(t time.Time) error { + return fmt.Errorf("test error") + } + + conn, err := NewConnectionFromSocket(mockSocket, config) + assert.NoError(t, err) + + err = conn.Send([]byte("test")) + assert.NoError(t, err) + defer conn.Close() + + select { + case err := <-conn.Errors(): + assert.ErrorContains(t, err, "failed to set write deadline") + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for deadline error") + } + + time.Sleep(10 * time.Millisecond) + assert.Equal(t, StateClosed, conn.State()) + }) + + t.Run("writer exits on socket write error", func(t *testing.T) { + mockSocket := NewMockSocket() + + writeErr := fmt.Errorf("write failed") + mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + return writeErr + } + + conn, err := NewConnectionFromSocket(mockSocket, nil) + assert.NoError(t, err) + defer conn.Close() + + err = conn.Send([]byte("test")) + assert.NoError(t, err) + + select { + case err := <-conn.Errors(): + assert.Equal(t, writeErr, err) + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for write error") + } + + time.Sleep(10 * time.Millisecond) + assert.Equal(t, StateClosed, conn.State()) + }) +} + +// Helpers + +func expectIncoming(t *testing.T, conn *Connection, expected []byte) { + t.Helper() + + select { + case received := <-conn.Incoming(): + assert.Equal(t, expected, received) + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for message") + } +} + +func expectWrite(t *testing.T, outgoingData chan mockOutgoingData, msgType int, expected []byte) { + t.Helper() + + select { + case call := <-outgoingData: + assert.Equal(t, msgType, call.msgType) + assert.Equal(t, expected, call.data) + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for write") + } +} diff --git a/ws/connection_send_test.go b/ws/connection_send_test.go new file mode 100644 index 0000000..11b40af --- /dev/null +++ b/ws/connection_send_test.go @@ -0,0 +1,113 @@ +package ws + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "sync" + "testing" + "time" +) + +func TestConnectionSend(t *testing.T) { + cases := []struct { + name string + setup func(*Connection) + data []byte + wantErr bool + wantErrText string + }{ + { + name: "send succeeds when open", + setup: func(c *Connection) {}, + data: []byte("test message"), + }, + { + name: "send fails when closed", + setup: func(c *Connection) { + c.Close() + }, + data: []byte("test"), + wantErr: true, + wantErrText: "connection closed", + }, + { + name: "send fails when queue full", + setup: func(c *Connection) { + // Fill outgoing channel + for i := 0; i < 100; i++ { + c.outgoing <- []byte("filler") + } + }, + data: []byte("overflow"), + wantErr: true, + wantErrText: "outgoing queue full", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + + tc.setup(conn) + + err = conn.Send(tc.data) + + if tc.wantErr { + assert.Error(t, err) + if tc.wantErrText != "" { + assert.ErrorContains(t, err, tc.wantErrText) + } + return + } + + assert.NoError(t, err) + + // Verify data appeared on outgoing channel + select { + case sent := <-conn.outgoing: + assert.Equal(t, tc.data, sent) + case <-time.After(50 * time.Millisecond): + t.Fatal("timeout: data not sent to outgoing channel") + } + }) + } +} + +// Run with `go test -race` to ensure no race conditions occur +func TestConnectionSendConcurrent(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + + // continuously consume outgoing channel in background + done := make(chan struct{}) + go func() { + for { + select { + case <-conn.outgoing: + case <-done: + return + } + } + }() + defer close(done) + + // Send from multiple goroutines concurrently + const goroutines = 5 + const messagesPerGoroutine = 10 + var wg sync.WaitGroup + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < messagesPerGoroutine; j++ { + data := []byte(fmt.Sprintf("msg-%d-%d", id, j)) + err := conn.Send(data) + assert.NoError(t, err) + } + }(i) + } + + wg.Wait() +} diff --git a/ws/connection_socket_test.go b/ws/connection_socket_test.go new file mode 100644 index 0000000..dc094b4 --- /dev/null +++ b/ws/connection_socket_test.go @@ -0,0 +1,143 @@ +package ws + +import ( + "errors" + "github.com/stretchr/testify/assert" + "net/http" + "testing" + "time" +) + +func TestNewDialer(t *testing.T) { + dialer := NewDialer() + + assert.NotNil(t, dialer) + _, ok := dialer.(*GorillaDialer) + assert.True(t, ok, "NewDialer should return *GorillaDialer") +} + +func TestNewGorillaDialer(t *testing.T) { + dialer := NewGorillaDialer() + + assert.NotNil(t, dialer) + assert.NotNil(t, dialer.Dialer) + assert.Equal(t, 45*time.Second, dialer.Dialer.HandshakeTimeout) + assert.Equal(t, 1024, dialer.Dialer.ReadBufferSize) + assert.Equal(t, 1024, dialer.Dialer.WriteBufferSize) +} + +func TestAcquireSocket(t *testing.T) { + cases := []struct { + name string + mockRuns []error + maxRetries int + wantRetryCount int + wantErr bool + }{ + { + name: "immediate success", + mockRuns: []error{nil}, + maxRetries: 3, + wantRetryCount: 0, + wantErr: false, + }, + { + name: "two failures, success", + mockRuns: []error{errors.New("1"), errors.New("2"), nil}, + maxRetries: 0, + wantRetryCount: 2, + wantErr: false, + }, + { + name: "three failures, failure", + mockRuns: []error{errors.New("1"), errors.New("2"), errors.New("3"), errors.New("4")}, + maxRetries: 3, + wantRetryCount: 3, + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + attemptIndex := 0 + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + err := tc.mockRuns[attemptIndex] + attemptIndex++ + if err != nil { + return nil, nil, err + } + return NewMockSocket(), nil, nil + }, + } + + retryMgr := NewRetryManager(&RetryConfig{ + MaxRetries: tc.maxRetries, + InitialDelay: 1 * time.Millisecond, + MaxDelay: 5 * time.Millisecond, + JitterFactor: 0.0, + }) + + socket, _, err := AcquireSocket(retryMgr, mockDialer, "ws://test") + + assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount()) + if tc.wantErr { + assert.Error(t, err) + assert.Nil(t, socket) + } else { + assert.NoError(t, err) + assert.NotNil(t, socket) + } + }) + } +} + +func TestAcquireSocketGuards(t *testing.T) { + validDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + return NewMockSocket(), nil, nil + }, + } + validRetryMgr := NewRetryManager(GetDefaultRetryConfig()) + + cases := []struct { + name string + retryMgr *RetryManager + dialer Dialer + url string + wantErr string + }{ + { + name: "nil retry manager", + retryMgr: nil, + dialer: validDialer, + url: "ws://test", + wantErr: "retry manager cannot be nil", + }, + { + name: "nil dialer", + retryMgr: validRetryMgr, + dialer: nil, + url: "ws://test", + wantErr: "dialer cannot be nil", + }, + { + name: "empty URL", + retryMgr: validRetryMgr, + dialer: validDialer, + url: "", + wantErr: "URL cannot be empty", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + socket, resp, err := AcquireSocket(tc.retryMgr, tc.dialer, tc.url) + + assert.Error(t, err) + assert.ErrorContains(t, err, tc.wantErr) + assert.Nil(t, socket) + assert.Nil(t, resp) + }) + } +} diff --git a/ws/connection_test.go b/ws/connection_test.go new file mode 100644 index 0000000..4012fd8 --- /dev/null +++ b/ws/connection_test.go @@ -0,0 +1,438 @@ +package ws + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "net/http" + "testing" + "time" +) + +// Connection state tests + +func TestConnectionStateString(t *testing.T) { + cases := []struct { + state ConnectionState + want string + }{ + {StateDisconnected, "disconnected"}, + {StateConnecting, "connecting"}, + {StateConnected, "connected"}, + {StateClosed, "closed"}, + {ConnectionState(99), "unknown"}, + } + + for _, tc := range cases { + t.Run(tc.want, func(t *testing.T) { + assert.Equal(t, tc.want, tc.state.String()) + }) + } +} + +func TestConnectionState(t *testing.T) { + // Test initial state + conn, _ := NewConnection("ws://test", nil) + assert.Equal(t, StateDisconnected, conn.State()) + + // Test state after FromSocket (should be Connected) + conn2, _ := NewConnectionFromSocket(NewMockSocket(), nil) + assert.Equal(t, StateConnected, conn2.State()) + + // Test state after close + conn.Close() + assert.Equal(t, StateClosed, conn.State()) +} + +// Connection constructor tests + +func TestNewConnection(t *testing.T) { + cases := []struct { + name string + url string + config *Config + wantErr bool + wantErrText string + }{ + { + name: "valid url, nil config", + url: "ws://example.com", + config: nil, + }, + { + name: "valid url, valid config", + url: "wss://relay.example.com:8080/path", + config: &Config{ReadTimeout: 30 * time.Second}, + }, + { + name: "invalid url", + url: "http://example.com", + config: nil, + wantErr: true, + wantErrText: "URL must use ws:// or wss:// scheme", + }, + { + name: "invalid config", + url: "ws://example.com", + config: &Config{ + Retry: &RetryConfig{ + InitialDelay: 10 * time.Second, + MaxDelay: 1 * time.Second, + }, + }, + wantErr: true, + wantErrText: "initial delay may not exceed maximum delay", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + conn, err := NewConnection(tc.url, tc.config) + + if tc.wantErr { + assert.Error(t, err) + if tc.wantErrText != "" { + assert.ErrorContains(t, err, tc.wantErrText) + } + assert.Nil(t, conn) + return + } + + assert.NoError(t, err) + assert.NotNil(t, conn) + + // Verify struct fields + assert.NotNil(t, conn.url) + assert.NotNil(t, conn.dialer) + assert.Nil(t, conn.socket) + assert.NotNil(t, conn.config) + assert.NotNil(t, conn.incoming) + assert.NotNil(t, conn.outgoing) + assert.NotNil(t, conn.errors) + assert.NotNil(t, conn.done) + assert.Equal(t, StateDisconnected, conn.state) + assert.False(t, conn.closed) + + // Verify default config is used if nil is passed + if tc.config == nil { + assert.Equal(t, GetDefaultConfig(), conn.config) + } else { + assert.Equal(t, tc.config, conn.config) + } + }) + } +} + +func TestNewConnectionFromSocket(t *testing.T) { + cases := []struct { + name string + socket Socket + config *Config + wantErr bool + wantErrText string + }{ + { + name: "nil socket", + socket: nil, + config: nil, + wantErr: true, + wantErrText: "socket cannot be nil", + }, + { + name: "valid socket with nil config", + socket: NewMockSocket(), + config: nil, + }, + { + name: "valid socket with valid config", + socket: NewMockSocket(), + config: &Config{ReadTimeout: 30 * time.Second}, + }, + { + name: "invalid config", + socket: NewMockSocket(), + config: &Config{ + Retry: &RetryConfig{ + InitialDelay: 10 * time.Second, + MaxDelay: 1 * time.Second, + }, + }, + wantErr: true, + wantErrText: "initial delay may not exceed maximum delay", + }, + { + name: "close handler set when provided", + socket: NewMockSocket(), + config: &Config{ + CloseHandler: func(code int, text string) error { + return nil + }, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // track if SetCloseHandler was called + closeHandlerSet := false + if tc.socket != nil { + mockSocket := tc.socket.(*MockSocket) + originalSetCloseHandler := mockSocket.SetCloseHandlerFunc + + // wrapper around the original handler function + mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) { + closeHandlerSet = true + if originalSetCloseHandler != nil { + originalSetCloseHandler(h) + } + } + } + + conn, err := NewConnectionFromSocket(tc.socket, tc.config) + + if tc.wantErr { + assert.Error(t, err) + if tc.wantErrText != "" { + assert.ErrorContains(t, err, tc.wantErrText) + } + assert.Nil(t, conn) + return + } + + assert.NoError(t, err) + assert.NotNil(t, conn) + + // Verify fields initialized correctly + assert.Nil(t, conn.url) + assert.Nil(t, conn.dialer) + assert.Equal(t, tc.socket, conn.socket) + assert.NotNil(t, conn.config) + assert.NotNil(t, conn.incoming) + assert.NotNil(t, conn.outgoing) + assert.NotNil(t, conn.errors) + assert.NotNil(t, conn.done) + assert.Equal(t, StateConnected, conn.state) + assert.False(t, conn.closed) + + // Verify config defaulting + if tc.config == nil { + assert.Equal(t, GetDefaultConfig(), conn.config) + } else { + assert.Equal(t, tc.config, conn.config) + } + + // Verify close handler was set if provided + if tc.config != nil && tc.config.CloseHandler != nil { + assert.True(t, closeHandlerSet, "CloseHandler should be set on socket") + } + }) + } +} + +// ws/connection_test.go + +// Add to existing file after TestNewConnectionFromSocket + +func TestConnect(t *testing.T) { + t.Run("connect fails when socket already present", func(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + + conn.socket = NewMockSocket() + + err = conn.Connect() + assert.Error(t, err) + assert.ErrorContains(t, err, "already has socket") + assert.Equal(t, StateDisconnected, conn.State()) + }) + + t.Run("connect fails when connection closed", func(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + + conn.Close() + + err = conn.Connect() + assert.Error(t, err) + assert.ErrorContains(t, err, "connection is closed") + assert.Equal(t, StateClosed, conn.State()) + }) + + t.Run("connect succeeds and starts goroutines", func(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + + outgoingData := make(chan mockOutgoingData, 10) + + mockSocket := NewMockSocket() + mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + outgoingData <- mockOutgoingData{msgType: msgType, data: data} + return nil + } + + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + return mockSocket, nil, nil + }, + } + conn.dialer = mockDialer + + err = conn.Connect() + assert.NoError(t, err) + assert.Equal(t, StateConnected, conn.State()) + + testData := []byte("test") + conn.Send(testData) + + time.Sleep(10 * time.Millisecond) + + select { + case msg := <-outgoingData: + assert.Equal(t, testData, msg.data) + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for message write") + } + + conn.Close() + close(outgoingData) + }) + + t.Run("connect retries on dial failure", func(t *testing.T) { + config := &Config{ + Retry: &RetryConfig{ + MaxRetries: 2, + InitialDelay: 1 * time.Millisecond, + MaxDelay: 5 * time.Millisecond, + JitterFactor: 0.0, + }, + } + conn, err := NewConnection("ws://test", config) + assert.NoError(t, err) + + attemptCount := 0 + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + attemptCount++ + if attemptCount < 3 { + return nil, nil, fmt.Errorf("dial failed") + } + return NewMockSocket(), nil, nil + }, + } + conn.dialer = mockDialer + + err = conn.Connect() + assert.NoError(t, err) + assert.Equal(t, 3, attemptCount) + assert.Equal(t, StateConnected, conn.State()) + + conn.Close() + }) + + t.Run("connect fails after max retries", func(t *testing.T) { + config := &Config{ + Retry: &RetryConfig{ + MaxRetries: 2, + InitialDelay: 1 * time.Millisecond, + MaxDelay: 5 * time.Millisecond, + JitterFactor: 0.0, + }, + } + conn, err := NewConnection("ws://test", config) + assert.NoError(t, err) + + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + return nil, nil, fmt.Errorf("dial failed") + }, + } + conn.dialer = mockDialer + + err = conn.Connect() + assert.Error(t, err) + assert.ErrorContains(t, err, "dial failed") + assert.Equal(t, StateDisconnected, conn.State()) + }) + + t.Run("state transitions during connect", func(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + assert.Equal(t, StateDisconnected, conn.State()) + + stateDuringDial := StateDisconnected + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + stateDuringDial = conn.state + return NewMockSocket(), nil, nil + }, + } + conn.dialer = mockDialer + + conn.Connect() + + assert.Equal(t, StateConnecting, stateDuringDial) + assert.Equal(t, StateConnected, conn.State()) + + conn.Close() + }) + + t.Run("close handler configured when provided", func(t *testing.T) { + handlerSet := false + config := &Config{ + CloseHandler: func(code int, text string) error { + return nil + }, + } + conn, err := NewConnection("ws://test", config) + assert.NoError(t, err) + + mockSocket := NewMockSocket() + mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) { + handlerSet = true + } + + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + return mockSocket, nil, nil + }, + } + conn.dialer = mockDialer + + conn.Connect() + + assert.True(t, handlerSet, "close handler should be set on socket") + + conn.Close() + }) +} + +// Connection method tests + +func TestConnectionIncoming(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + + incoming := conn.Incoming() + assert.NotNil(t, incoming) + + // send data through the channel to verify they are the same + testData := []byte("test") + conn.incoming <- testData + received := <-incoming + assert.Equal(t, testData, received) +} + +func TestConnectionErrors(t *testing.T) { + conn, err := NewConnection("ws://test", nil) + assert.NoError(t, err) + + errors := conn.Errors() + assert.NotNil(t, errors) + + // send data through the channel to verify they are the same + testErr := fmt.Errorf("test error") + conn.errors <- testErr + received := <-errors + assert.Equal(t, testErr, received) +} + +// Connect() tests diff --git a/ws/mocks_test.go b/ws/mocks_test.go new file mode 100644 index 0000000..9794dde --- /dev/null +++ b/ws/mocks_test.go @@ -0,0 +1,132 @@ +package ws + +import ( + "github.com/stretchr/testify/assert" + "io" + "net/http" + "sync" + "testing" + "time" +) + +// Dialer Mocks + +type MockDialer struct { + DialFunc func(string, http.Header) (Socket, *http.Response, error) +} + +func (m *MockDialer) Dial(url string, h http.Header) (Socket, *http.Response, error) { + return m.DialFunc(url, h) +} + +// Socket Mocks + +type MockSocket struct { + WriteMessageFunc func(int, []byte) error + SetReadDeadlineFunc func(t time.Time) error + SetWriteDeadlineFunc func(t time.Time) error + ReadMessageFunc func() (int, []byte, error) + CloseFunc func() error + SetCloseHandlerFunc func(func(int, string) error) + closed chan struct{} + once sync.Once +} + +func NewMockSocket() *MockSocket { + return &MockSocket{ + WriteMessageFunc: func(int, []byte) error { return nil }, + ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil }, + CloseFunc: func() error { return nil }, + + SetReadDeadlineFunc: func(time.Time) error { return nil }, + SetWriteDeadlineFunc: func(time.Time) error { return nil }, + SetCloseHandlerFunc: func(func(int, string) error) {}, + + closed: make(chan struct{}), + } + +} + +func (m *MockSocket) WriteMessage(t int, d []byte) error { + return m.WriteMessageFunc(t, d) +} + +func (m *MockSocket) ReadMessage() (int, []byte, error) { + return m.ReadMessageFunc() +} + +func (m *MockSocket) Close() error { + return m.CloseFunc() +} + +func (m *MockSocket) SetReadDeadline(t time.Time) error { + return m.SetReadDeadlineFunc(t) +} + +func (m *MockSocket) SetWriteDeadline(t time.Time) error { + return m.SetWriteDeadlineFunc(t) +} + +func (m *MockSocket) SetCloseHandler(h func(code int, text string) error) { + m.SetCloseHandlerFunc(h) +} + +// Connection Mocks + +type mockIncomingData struct { + msgType int + data []byte + err error +} + +type mockOutgoingData struct { + msgType int + data []byte +} + +func setupTestConnection(t *testing.T, config *Config) ( + conn *Connection, + mockSocket *MockSocket, + incomingData chan mockIncomingData, + outgoingData chan mockOutgoingData, +) { + t.Helper() + + incomingData = make(chan mockIncomingData, 10) + outgoingData = make(chan mockOutgoingData, 10) + + mockSocket = NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.once.Do(func() { + close(mockSocket.closed) + }) + return nil + } + + // Wire ReadMessage to pull from incomingData channel + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + select { + case data := <-incomingData: + return data.msgType, data.data, data.err + case <-mockSocket.closed: + return 0, nil, io.EOF + } + } + + // Wire WriteMessage to push to outgoingData channel + mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + select { + case outgoingData <- mockOutgoingData{msgType: msgType, data: data}: + case <-mockSocket.closed: + return io.EOF + } + return nil + } + + var err error + conn, err = NewConnectionFromSocket(mockSocket, config) + assert.NoError(t, err) + + return conn, mockSocket, incomingData, outgoingData +} diff --git a/ws/retry.go b/ws/retry.go new file mode 100644 index 0000000..4fbb7da --- /dev/null +++ b/ws/retry.go @@ -0,0 +1,66 @@ +package ws + +import ( + "math" + "math/rand" + "time" +) + +type RetryManager struct { + config *RetryConfig + retryCount int +} + +func NewRetryManager(config *RetryConfig) *RetryManager { + return &RetryManager{ + config: config, + retryCount: 0, + } +} + +func (r *RetryManager) ShouldRetry() bool { + if r.config == nil { + return false + } + + if r.config.MaxRetries > 0 && r.retryCount >= r.config.MaxRetries { + return false + } + + return true +} + +func (r *RetryManager) CalculateDelay() time.Duration { + if r.config == nil { + return time.Second + } + + // First attempt: immediate retry + if r.retryCount == 0 { + return 0 + } + + // Exponential backoff: InitialDelay * 2^(attempts-1) + backoffMultiplier := math.Pow(2, float64(r.retryCount-1)) + baseDelay := float64(r.config.InitialDelay) * backoffMultiplier + + // Apply jitter: delay * (1 + jitterFactor * (random - 0.5)) + random := rand.Float64() + jitterMultiplier := 1 + r.config.JitterFactor*(random-0.5) + delay := time.Duration(baseDelay * jitterMultiplier) + + // Cap at MaxDelay + if delay > r.config.MaxDelay { + delay = r.config.MaxDelay + } + + return delay +} + +func (m *RetryManager) RecordRetry() { + m.retryCount++ +} + +func (m *RetryManager) RetryCount() int { + return m.retryCount +} diff --git a/ws/retry_test.go b/ws/retry_test.go new file mode 100644 index 0000000..5f5498e --- /dev/null +++ b/ws/retry_test.go @@ -0,0 +1,147 @@ +package ws + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestNewRetryManager(t *testing.T) { + config := &RetryConfig{ + MaxRetries: 0, + } + + mgr := NewRetryManager(config) + + assert.Equal(t, config, mgr.config) + assert.Equal(t, 0, mgr.retryCount) + + // Should accept nil config + mgr = NewRetryManager(nil) + assert.Nil(t, mgr.config) + assert.Equal(t, 0, mgr.retryCount) +} + +func TestRecordRetry(t *testing.T) { + mgr := NewRetryManager(nil) + assert.Equal(t, mgr.retryCount, 0) + + mgr.RecordRetry() + assert.Equal(t, mgr.retryCount, 1) + + mgr.RecordRetry() + assert.Equal(t, mgr.retryCount, 2) +} + +func TestShouldRetry(t *testing.T) { + // never retry if config is nil + mgr := NewRetryManager(nil) + assert.False(t, mgr.ShouldRetry()) + + // always retry if max attempt count is zero + mgr = &RetryManager{ + config: &RetryConfig{ + MaxRetries: 0, + }, + retryCount: 1000, + } + assert.True(t, mgr.ShouldRetry()) + + // retry if below max attempt count + mgr = &RetryManager{ + config: &RetryConfig{ + MaxRetries: 10, + }, + retryCount: 5, + } + assert.True(t, mgr.ShouldRetry()) + + // do not retry if above max attempt count + mgr = &RetryManager{ + config: &RetryConfig{ + MaxRetries: 10, + }, + retryCount: 11, + } + assert.False(t, mgr.ShouldRetry()) +} + +func TestCalculateDelayDisabled(t *testing.T) { + // default delay if retry is disabled + mgr := NewRetryManager(nil) + assert.Equal(t, time.Second, mgr.CalculateDelay()) +} + +func TestCalculateDelayWithoutJitter(t *testing.T) { + mgr := NewRetryManager(&RetryConfig{ + MaxRetries: 0, + InitialDelay: 1 * time.Second, + MaxDelay: 5 * time.Second, + JitterFactor: 0.0, + }) + + // Retry 0: immediate + assert.Equal(t, 0*time.Second, mgr.CalculateDelay()) + mgr.RecordRetry() + + // Retry 1: 1s * 2^0 = 1s + assert.Equal(t, 1*time.Second, mgr.CalculateDelay()) + mgr.RecordRetry() + + // Retry 2: 1s * 2^1 = 2s + assert.Equal(t, 2*time.Second, mgr.CalculateDelay()) + mgr.RecordRetry() + + // Retry 3: 1s * 2^2 = 4s + assert.Equal(t, 4*time.Second, mgr.CalculateDelay()) + mgr.RecordRetry() + + // Retry 4: 1s * 2^3 = 8s, capped at 5s + assert.Equal(t, 5*time.Second, mgr.CalculateDelay()) + mgr.RecordRetry() + + // Retry 5: Still capped at 5s + assert.Equal(t, 5*time.Second, mgr.CalculateDelay()) +} + +func TestCalculateDelayWithJitter(t *testing.T) { + mgr := NewRetryManager(&RetryConfig{ + MaxRetries: 0, + InitialDelay: 1 * time.Second, + MaxDelay: 5 * time.Second, + JitterFactor: 0.5, + }) + + // Retry 0: immediate + assert.Equal(t, 0*time.Second, mgr.CalculateDelay()) + mgr.RecordRetry() + + // Retry 1: 1s * 2^0 = 1s (with jitter) + delay := mgr.CalculateDelay() + assert.GreaterOrEqual(t, delay, 750*time.Millisecond) + assert.LessOrEqual(t, delay, 1250*time.Millisecond) + mgr.RecordRetry() + + // Retry 2: 1s * 2^1 = 2s (with jitter) + delay = mgr.CalculateDelay() + assert.GreaterOrEqual(t, delay, 1500*time.Millisecond) + assert.LessOrEqual(t, delay, 2500*time.Millisecond) + mgr.RecordRetry() + + // Retry 3: 1s * 2^2 = 4s (with jitter) + delay = mgr.CalculateDelay() + assert.GreaterOrEqual(t, delay, 3*time.Second) + assert.LessOrEqual(t, delay, 5*time.Second) + mgr.RecordRetry() + + // Retry 4: 1s * 2^3 = 8s, capped at 5s (with jitter) + delay = mgr.CalculateDelay() + assert.GreaterOrEqual(t, delay, 3750*time.Millisecond) + assert.LessOrEqual(t, delay, 5*time.Second) + mgr.RecordRetry() + + // Retry 5: Still capped at 5s (with jitter) + delay = mgr.CalculateDelay() + assert.GreaterOrEqual(t, delay, 3750*time.Millisecond) + assert.LessOrEqual(t, delay, 5*time.Second) +} diff --git a/ws/url.go b/ws/url.go new file mode 100644 index 0000000..99c2983 --- /dev/null +++ b/ws/url.go @@ -0,0 +1,20 @@ +package ws + +import ( + "net/url" + + "git.wisehodl.dev/jay/go-honeybee/errors" +) + +func ParseURL(urlStr string) (*url.URL, error) { + parsedURL, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + + if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" { + return nil, errors.InvalidProtocol + } + + return parsedURL, nil +} diff --git a/ws/url_test.go b/ws/url_test.go new file mode 100644 index 0000000..0a25b8b --- /dev/null +++ b/ws/url_test.go @@ -0,0 +1,93 @@ +package ws + +import ( + "git.wisehodl.dev/jay/go-honeybee/errors" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseURL(t *testing.T) { + type wantURL struct { + scheme string + host string + path string + } + + cases := []struct { + name string + url string + want wantURL + wantErr error + wantErrText string + }{ + { + name: "valid ws url", + url: "ws://localhost:8080/relay", + want: wantURL{ + scheme: "ws", + host: "localhost:8080", + path: "/relay", + }, + }, + { + name: "valid wss url", + url: "wss://relay.example.com", + want: wantURL{ + scheme: "wss", + host: "relay.example.com", + path: "", + }, + }, + { + name: "http scheme rejected", + url: "http://example.com", + wantErr: errors.InvalidProtocol, + }, + { + name: "missing scheme", + url: "example.com:8080", + wantErr: errors.InvalidProtocol, + }, + { + name: "empty string", + url: "", + wantErr: errors.InvalidProtocol, + }, + { + name: "malformed url", + url: "ws://[::1:8080", + wantErrText: "missing ']' in host", + }, + { + name: "ipv6 address", + url: "ws://[::1]:8080/relay", + want: wantURL{ + scheme: "ws", + host: "[::1]:8080", + path: "/relay", + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := ParseURL(tc.url) + + if tc.wantErr != nil || tc.wantErrText != "" { + if tc.wantErr != nil { + assert.ErrorIs(t, err, tc.wantErr) + } + + if tc.wantErrText != "" { + assert.ErrorContains(t, err, tc.wantErrText) + } + return + } + + assert.NoError(t, err) + assert.Equal(t, tc.want.scheme, got.Scheme) + assert.Equal(t, tc.want.host, got.Host) + assert.Equal(t, tc.want.path, got.Path) + }) + } +} diff --git a/ws/ws_test.go b/ws/ws_test.go new file mode 100644 index 0000000..060be4a --- /dev/null +++ b/ws/ws_test.go @@ -0,0 +1,2 @@ +// ws package end-to-end tests +package ws