Completed core connection wrapper.
This commit is contained in:
2
c2p
2
c2p
@@ -1,2 +1,2 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
code2prompt -c -i c2p -i go.sum -i LICENSE
|
code2prompt -c -e c2p -e go.sum -e LICENSE
|
||||||
|
|||||||
25
errors/errors.go
Normal file
25
errors/errors.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package errors
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// URL Errors
|
||||||
|
InvalidProtocol = errors.New("URL must use ws:// or wss:// scheme")
|
||||||
|
|
||||||
|
// Configuration Errors
|
||||||
|
InvalidReadTimeout = errors.New("read timeout must be positive")
|
||||||
|
InvalidWriteTimeout = errors.New("write timeout must be positive")
|
||||||
|
InvalidRetryMaxRetries = errors.New("max retry count cannot be negative")
|
||||||
|
InvalidRetryInitialDelay = errors.New("initial delay must be positive")
|
||||||
|
InvalidRetryMaxDelay = errors.New("max delay must be positive")
|
||||||
|
InvalidRetryJitterFactor = errors.New("jitter factor must be between 0.0 and 1.0")
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewConfigError(text string) error {
|
||||||
|
return fmt.Errorf("configuration error: %s", text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnectionError(text string) error {
|
||||||
|
return fmt.Errorf("connection error: %s", text)
|
||||||
|
}
|
||||||
10
go.mod
10
go.mod
@@ -3,6 +3,12 @@ module git.wisehodl.dev/jay/go-honeybee
|
|||||||
go 1.23.5
|
go 1.23.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/gorilla/websocket v1.5.3 // indirect
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/stretchr/testify v1.11.1 // indirect
|
github.com/stretchr/testify v1.11.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -1,4 +1,12 @@
|
|||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
1
honeybee.go
Normal file
1
honeybee.go
Normal file
@@ -0,0 +1 @@
|
|||||||
|
package honeybee
|
||||||
163
ws/config.go
Normal file
163
ws/config.go
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CloseHandler func(code int, text string) error
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
CloseHandler CloseHandler
|
||||||
|
ReadTimeout time.Duration
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
Retry *RetryConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type RetryConfig struct {
|
||||||
|
MaxRetries int
|
||||||
|
InitialDelay time.Duration
|
||||||
|
MaxDelay time.Duration
|
||||||
|
JitterFactor float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConfigOption func(*Config) error
|
||||||
|
|
||||||
|
func NewConfig(options ...ConfigOption) (*Config, error) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
if err := SetConfig(conf, options...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := ValidateConfig(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDefaultConfig() *Config {
|
||||||
|
return &Config{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDefaultRetryConfig() *RetryConfig {
|
||||||
|
return &RetryConfig{
|
||||||
|
MaxRetries: 0, // Infinite retries
|
||||||
|
InitialDelay: 1 * time.Second,
|
||||||
|
MaxDelay: 5 * time.Second,
|
||||||
|
JitterFactor: 0.5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetConfig(config *Config, options ...ConfigOption) error {
|
||||||
|
for _, option := range options {
|
||||||
|
if err := option(config); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateConfig(config *Config) error {
|
||||||
|
if config.Retry != nil {
|
||||||
|
if config.Retry.InitialDelay > config.Retry.MaxDelay {
|
||||||
|
return errors.NewConfigError("initial delay may not exceed maximum delay")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configuration Options
|
||||||
|
|
||||||
|
func WithCloseHandler(handler CloseHandler) ConfigOption {
|
||||||
|
return func(c *Config) error {
|
||||||
|
c.CloseHandler = handler
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// When ReadTimeout is set to zero, read timeouts are disabled.
|
||||||
|
func WithReadTimeout(value time.Duration) ConfigOption {
|
||||||
|
return func(c *Config) error {
|
||||||
|
if value < 0 {
|
||||||
|
return errors.InvalidReadTimeout
|
||||||
|
}
|
||||||
|
c.ReadTimeout = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// When WriteTimeout is set to zero, read timeouts are disabled.
|
||||||
|
func WithWriteTimeout(value time.Duration) ConfigOption {
|
||||||
|
return func(c *Config) error {
|
||||||
|
if value < 0 {
|
||||||
|
return errors.InvalidWriteTimeout
|
||||||
|
}
|
||||||
|
c.WriteTimeout = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRetry enables retry with default parameters (infinite retries,
|
||||||
|
// 1s initial delay, 5s max delay, 0.5 jitter factor).
|
||||||
|
//
|
||||||
|
// If passed after granular retry options (WithRetryMaxRetries, etc.),
|
||||||
|
// it will overwrite them. Use either WithRetry alone or the granular
|
||||||
|
// options; not both.
|
||||||
|
func WithRetry() ConfigOption {
|
||||||
|
return func(c *Config) error {
|
||||||
|
c.Retry = GetDefaultRetryConfig()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryMaxRetries(value int) ConfigOption {
|
||||||
|
return func(c *Config) error {
|
||||||
|
if c.Retry == nil {
|
||||||
|
c.Retry = GetDefaultRetryConfig()
|
||||||
|
}
|
||||||
|
if value < 0 {
|
||||||
|
return errors.InvalidRetryMaxRetries
|
||||||
|
}
|
||||||
|
c.Retry.MaxRetries = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryInitialDelay(value time.Duration) ConfigOption {
|
||||||
|
return func(c *Config) error {
|
||||||
|
if c.Retry == nil {
|
||||||
|
c.Retry = GetDefaultRetryConfig()
|
||||||
|
}
|
||||||
|
if value <= 0 {
|
||||||
|
return errors.InvalidRetryInitialDelay
|
||||||
|
}
|
||||||
|
c.Retry.InitialDelay = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryMaxDelay(value time.Duration) ConfigOption {
|
||||||
|
return func(c *Config) error {
|
||||||
|
if c.Retry == nil {
|
||||||
|
c.Retry = GetDefaultRetryConfig()
|
||||||
|
}
|
||||||
|
if value <= 0 {
|
||||||
|
return errors.InvalidRetryMaxDelay
|
||||||
|
}
|
||||||
|
c.Retry.MaxDelay = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryJitterFactor(value float64) ConfigOption {
|
||||||
|
return func(c *Config) error {
|
||||||
|
if c.Retry == nil {
|
||||||
|
c.Retry = GetDefaultRetryConfig()
|
||||||
|
}
|
||||||
|
if value < 0.0 || value > 1.0 {
|
||||||
|
return errors.InvalidRetryJitterFactor
|
||||||
|
}
|
||||||
|
c.Retry.JitterFactor = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
274
ws/config_test.go
Normal file
274
ws/config_test.go
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config Tests
|
||||||
|
|
||||||
|
func TestNewConfig(t *testing.T) {
|
||||||
|
conf, err := NewConfig(WithRetry())
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, conf, &Config{
|
||||||
|
Retry: GetDefaultRetryConfig(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// errors propagate
|
||||||
|
_, err = NewConfig(WithRetryMaxRetries(-1))
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
_, err = NewConfig(WithRetryInitialDelay(10), WithRetryMaxDelay(1))
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default Config Tests
|
||||||
|
|
||||||
|
func TestDefaultConfig(t *testing.T) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
|
||||||
|
assert.Nil(t, conf.CloseHandler)
|
||||||
|
assert.Nil(t, conf.Retry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultRetryConfig(t *testing.T) {
|
||||||
|
conf := GetDefaultRetryConfig()
|
||||||
|
|
||||||
|
assert.Equal(t, conf, &RetryConfig{
|
||||||
|
MaxRetries: 0,
|
||||||
|
InitialDelay: 1 * time.Second,
|
||||||
|
MaxDelay: 5 * time.Second,
|
||||||
|
JitterFactor: 0.5,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config Builder Tests
|
||||||
|
|
||||||
|
func TestSetConfig(t *testing.T) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
err := SetConfig(
|
||||||
|
conf,
|
||||||
|
WithRetryMaxRetries(0),
|
||||||
|
WithRetryInitialDelay(3*time.Second),
|
||||||
|
WithRetryJitterFactor(0.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, conf.Retry.MaxRetries)
|
||||||
|
assert.Equal(t, 3*time.Second, conf.Retry.InitialDelay)
|
||||||
|
assert.Equal(t, 0.5, conf.Retry.JitterFactor)
|
||||||
|
|
||||||
|
// errors propagate
|
||||||
|
err = SetConfig(
|
||||||
|
conf,
|
||||||
|
WithRetryMaxRetries(-10),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, errors.InvalidRetryMaxRetries)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config Option Tests
|
||||||
|
|
||||||
|
func TestWithCloseHandler(t *testing.T) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
opt := WithCloseHandler(func(code int, text string) error { return nil })
|
||||||
|
err := SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, conf.CloseHandler(0, ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithReadTimeout(t *testing.T) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
opt := WithReadTimeout(30)
|
||||||
|
err := SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, conf.ReadTimeout, time.Duration(30))
|
||||||
|
|
||||||
|
// zero allowed
|
||||||
|
conf = GetDefaultConfig()
|
||||||
|
opt = WithReadTimeout(0)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, conf.ReadTimeout, time.Duration(0))
|
||||||
|
|
||||||
|
// negative disallowed
|
||||||
|
conf = GetDefaultConfig()
|
||||||
|
opt = WithReadTimeout(-30)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.ErrorIs(t, err, errors.InvalidReadTimeout)
|
||||||
|
assert.ErrorContains(t, err, "read timeout must be positive")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithWriteTimeout(t *testing.T) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
opt := WithWriteTimeout(30)
|
||||||
|
err := SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, conf.WriteTimeout, time.Duration(30))
|
||||||
|
|
||||||
|
// zero allowed
|
||||||
|
conf = GetDefaultConfig()
|
||||||
|
opt = WithWriteTimeout(0)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, conf.WriteTimeout, time.Duration(0))
|
||||||
|
|
||||||
|
// negative disallowed
|
||||||
|
conf = GetDefaultConfig()
|
||||||
|
opt = WithWriteTimeout(-30)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.ErrorIs(t, err, errors.InvalidWriteTimeout)
|
||||||
|
assert.ErrorContains(t, err, "write timeout must be positive")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetry(t *testing.T) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
opt := WithRetry()
|
||||||
|
err := SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, conf.Retry)
|
||||||
|
assert.Equal(t, conf.Retry, GetDefaultRetryConfig())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryAttempts(t *testing.T) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
opt := WithRetryMaxRetries(3)
|
||||||
|
err := SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 3, conf.Retry.MaxRetries)
|
||||||
|
|
||||||
|
// zero allowed
|
||||||
|
opt = WithRetryMaxRetries(0)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// negative disallowed
|
||||||
|
opt = WithRetryMaxRetries(-10)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.ErrorIs(t, err, errors.InvalidRetryMaxRetries)
|
||||||
|
assert.ErrorContains(t, err, "max retry count cannot be negative")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryInitialDelay(t *testing.T) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
opt := WithRetryInitialDelay(10 * time.Second)
|
||||||
|
err := SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 10*time.Second, conf.Retry.InitialDelay)
|
||||||
|
|
||||||
|
// zero disallowed
|
||||||
|
opt = WithRetryInitialDelay(0 * time.Second)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.ErrorIs(t, err, errors.InvalidRetryInitialDelay)
|
||||||
|
assert.ErrorContains(t, err, "initial delay must be positive")
|
||||||
|
|
||||||
|
// negative disallowed
|
||||||
|
opt = WithRetryInitialDelay(-10 * time.Second)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.ErrorIs(t, err, errors.InvalidRetryInitialDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryMaxDelay(t *testing.T) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
opt := WithRetryMaxDelay(10 * time.Second)
|
||||||
|
err := SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 10*time.Second, conf.Retry.MaxDelay)
|
||||||
|
|
||||||
|
// zero disallowed
|
||||||
|
opt = WithRetryMaxDelay(0 * time.Second)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.ErrorIs(t, err, errors.InvalidRetryMaxDelay)
|
||||||
|
assert.ErrorContains(t, err, "max delay must be positive")
|
||||||
|
|
||||||
|
// negative disallowed
|
||||||
|
opt = WithRetryMaxDelay(-10 * time.Second)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.ErrorIs(t, err, errors.InvalidRetryMaxDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryJitterFactor(t *testing.T) {
|
||||||
|
conf := GetDefaultConfig()
|
||||||
|
|
||||||
|
opt := WithRetryJitterFactor(0.2)
|
||||||
|
err := SetConfig(conf, opt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0.2, conf.Retry.JitterFactor)
|
||||||
|
|
||||||
|
// negative disallowed
|
||||||
|
opt = WithRetryJitterFactor(-1)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.ErrorIs(t, err, errors.InvalidRetryJitterFactor)
|
||||||
|
assert.ErrorContains(t, err, "jitter factor must be between 0.0 and 1.0")
|
||||||
|
|
||||||
|
// >1 disallowed
|
||||||
|
opt = WithRetryJitterFactor(1.1)
|
||||||
|
err = SetConfig(conf, opt)
|
||||||
|
assert.ErrorIs(t, err, errors.InvalidRetryJitterFactor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config Validation Tests
|
||||||
|
|
||||||
|
func TestValidateConfig(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
conf Config
|
||||||
|
wantErr error
|
||||||
|
wantErrText string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid empty",
|
||||||
|
conf: *GetDefaultConfig(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid defaults",
|
||||||
|
conf: *GetDefaultConfig(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid complete",
|
||||||
|
conf: Config{
|
||||||
|
CloseHandler: (func(code int, text string) error { return nil }),
|
||||||
|
ReadTimeout: time.Duration(30),
|
||||||
|
WriteTimeout: time.Duration(30),
|
||||||
|
Retry: &RetryConfig{
|
||||||
|
MaxRetries: 0,
|
||||||
|
InitialDelay: 2 * time.Second,
|
||||||
|
MaxDelay: 10 * time.Second,
|
||||||
|
JitterFactor: 0.2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid - initial delay > max delay",
|
||||||
|
conf: Config{
|
||||||
|
Retry: &RetryConfig{
|
||||||
|
InitialDelay: 10 * time.Second,
|
||||||
|
MaxDelay: 1 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErrText: "initial delay may not exceed maximum delay",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := ValidateConfig(&tc.conf)
|
||||||
|
|
||||||
|
if tc.wantErr != nil || tc.wantErrText != "" {
|
||||||
|
if tc.wantErr != nil {
|
||||||
|
assert.ErrorIs(t, err, tc.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.wantErrText != "" {
|
||||||
|
assert.ErrorContains(t, err, tc.wantErrText)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
371
ws/connection.go
Normal file
371
ws/connection.go
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dialer interface {
|
||||||
|
Dial(urlStr string, requestHeader http.Header) (Socket, *http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDialer() Dialer {
|
||||||
|
return NewGorillaDialer()
|
||||||
|
}
|
||||||
|
|
||||||
|
type GorillaDialer struct {
|
||||||
|
*websocket.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGorillaDialer() *GorillaDialer {
|
||||||
|
return &GorillaDialer{
|
||||||
|
Dialer: &websocket.Dialer{
|
||||||
|
HandshakeTimeout: 45 * time.Second,
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the Socket interface
|
||||||
|
func (d *GorillaDialer) Dial(
|
||||||
|
urlStr string, requestHeader http.Header,
|
||||||
|
) (
|
||||||
|
Socket, *http.Response, error,
|
||||||
|
) {
|
||||||
|
conn, resp, err := d.Dialer.Dial(urlStr, requestHeader)
|
||||||
|
return conn, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type Socket interface {
|
||||||
|
WriteMessage(messageType int, data []byte) error
|
||||||
|
ReadMessage() (messageType int, p []byte, err error)
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
SetReadDeadline(t time.Time) error
|
||||||
|
SetWriteDeadline(t time.Time) error
|
||||||
|
SetCloseHandler(h func(code int, text string) error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func AcquireSocket(
|
||||||
|
retryMgr *RetryManager,
|
||||||
|
dialer Dialer,
|
||||||
|
urlStr string,
|
||||||
|
) (Socket, *http.Response, error) {
|
||||||
|
if retryMgr == nil {
|
||||||
|
return nil, nil, errors.NewConnectionError("retry manager cannot be nil")
|
||||||
|
}
|
||||||
|
if dialer == nil {
|
||||||
|
return nil, nil, errors.NewConnectionError("dialer cannot be nil")
|
||||||
|
}
|
||||||
|
if urlStr == "" {
|
||||||
|
return nil, nil, errors.NewConnectionError("URL cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
socket, resp, err := dialer.Dial(urlStr, nil)
|
||||||
|
if err == nil {
|
||||||
|
return socket, resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !retryMgr.ShouldRetry() {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
delay := retryMgr.CalculateDelay()
|
||||||
|
time.Sleep(delay)
|
||||||
|
retryMgr.RecordRetry()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnectionState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateDisconnected ConnectionState = iota
|
||||||
|
StateConnecting
|
||||||
|
StateConnected
|
||||||
|
StateClosed
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s ConnectionState) String() string {
|
||||||
|
switch s {
|
||||||
|
case StateDisconnected:
|
||||||
|
return "disconnected"
|
||||||
|
case StateConnecting:
|
||||||
|
return "connecting"
|
||||||
|
case StateConnected:
|
||||||
|
return "connected"
|
||||||
|
case StateClosed:
|
||||||
|
return "closed"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Connection struct {
|
||||||
|
url *url.URL
|
||||||
|
dialer Dialer
|
||||||
|
socket Socket
|
||||||
|
config *Config
|
||||||
|
|
||||||
|
incoming chan []byte
|
||||||
|
outgoing chan []byte
|
||||||
|
errors chan error
|
||||||
|
done chan struct{}
|
||||||
|
|
||||||
|
state ConnectionState
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
once sync.Once
|
||||||
|
closed bool
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnection(urlStr string, config *Config) (*Connection, error) {
|
||||||
|
if config == nil {
|
||||||
|
config = GetDefaultConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ValidateConfig(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedURL, err := ParseURL(urlStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Connection{
|
||||||
|
url: parsedURL,
|
||||||
|
dialer: NewDialer(),
|
||||||
|
socket: nil,
|
||||||
|
config: config,
|
||||||
|
incoming: make(chan []byte, 100),
|
||||||
|
outgoing: make(chan []byte, 100),
|
||||||
|
errors: make(chan error, 10),
|
||||||
|
state: StateDisconnected,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnectionFromSocket(socket Socket, config *Config) (*Connection, error) {
|
||||||
|
if socket == nil {
|
||||||
|
return nil, errors.NewConnectionError("socket cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
config = GetDefaultConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ValidateConfig(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &Connection{
|
||||||
|
url: nil,
|
||||||
|
dialer: nil,
|
||||||
|
socket: socket,
|
||||||
|
config: config,
|
||||||
|
incoming: make(chan []byte, 100),
|
||||||
|
outgoing: make(chan []byte, 100),
|
||||||
|
errors: make(chan error, 10),
|
||||||
|
state: StateConnected,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.CloseHandler != nil {
|
||||||
|
socket.SetCloseHandler(config.CloseHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.startReader()
|
||||||
|
conn.startWriter()
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Connect() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.socket != nil {
|
||||||
|
return errors.NewConnectionError("connection already has socket")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.closed {
|
||||||
|
return errors.NewConnectionError("connection is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.state = StateConnecting
|
||||||
|
|
||||||
|
retryMgr := NewRetryManager(c.config.Retry)
|
||||||
|
socket, _, err := AcquireSocket(retryMgr, c.dialer, c.url.String())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.state = StateDisconnected
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.socket = socket
|
||||||
|
c.state = StateConnected
|
||||||
|
|
||||||
|
if c.config.CloseHandler != nil {
|
||||||
|
c.socket.SetCloseHandler(c.config.CloseHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.startReader()
|
||||||
|
c.startWriter()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) startReader() {
|
||||||
|
c.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer c.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if c.config.ReadTimeout > 0 {
|
||||||
|
if err := c.socket.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil {
|
||||||
|
select {
|
||||||
|
case c.errors <- fmt.Errorf("failed to set read deadline: %w", err):
|
||||||
|
case <-c.done:
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messageType, data, err := c.socket.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case c.errors <- err:
|
||||||
|
case <-c.done:
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if messageType == websocket.TextMessage ||
|
||||||
|
messageType == websocket.BinaryMessage {
|
||||||
|
select {
|
||||||
|
case c.incoming <- data:
|
||||||
|
case <-c.done:
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) startWriter() {
|
||||||
|
c.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer c.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
case data := <-c.outgoing:
|
||||||
|
if c.config.WriteTimeout > 0 {
|
||||||
|
if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||||
|
select {
|
||||||
|
case c.errors <- fmt.Errorf("failed to set write deadline: %w", err):
|
||||||
|
case <-c.done:
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||||
|
select {
|
||||||
|
case c.errors <- err:
|
||||||
|
case <-c.done:
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Send(data []byte) error {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
if c.closed {
|
||||||
|
return errors.NewConnectionError("connection closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.outgoing <- data:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errors.NewConnectionError("outgoing queue full")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Incoming() <-chan []byte {
|
||||||
|
return c.incoming
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Errors() <-chan error {
|
||||||
|
return c.errors
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the connection and waits for goroutines to exit.
|
||||||
|
// If the underlying socket blocks indefinitely on read or write operations,
|
||||||
|
// Close will also block. This is expected behavior - hung sockets require
|
||||||
|
// external intervention (timeouts, process termination, etc).
|
||||||
|
func (c *Connection) Close() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
|
||||||
|
alreadyClosed := c.closed
|
||||||
|
if !alreadyClosed {
|
||||||
|
c.closed = true
|
||||||
|
c.state = StateClosed
|
||||||
|
close(c.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
socket := c.socket
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if alreadyClosed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if socket != nil {
|
||||||
|
err = socket.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.wg.Wait()
|
||||||
|
|
||||||
|
close(c.incoming)
|
||||||
|
close(c.outgoing)
|
||||||
|
close(c.errors)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) State() ConnectionState {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
return c.state
|
||||||
|
}
|
||||||
158
ws/connection_close_test.go
Normal file
158
ws/connection_close_test.go
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDisconnectedConnectionClose(t *testing.T) {
|
||||||
|
t.Run("close succeeds on disconnected connection", func(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, StateDisconnected, conn.State())
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("close is idempotent", func(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Second close should succeed without error
|
||||||
|
err = conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("close with nil socket", func(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, conn.socket)
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("socket close error propagates", func(t *testing.T) {
|
||||||
|
expectedErr := fmt.Errorf("socket close failed")
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
mockSocket.CloseFunc = func() error {
|
||||||
|
return expectedErr
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
conn.socket = mockSocket
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
assert.Equal(t, expectedErr, err)
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("channels close after close", func(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify incoming channel closed
|
||||||
|
select {
|
||||||
|
case _, ok := <-conn.incoming:
|
||||||
|
assert.False(t, ok, "incoming channel should be closed")
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
t.Fatal("timeout waiting for incoming channel closure")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify outgoing channel closed
|
||||||
|
select {
|
||||||
|
case _, ok := <-conn.outgoing:
|
||||||
|
assert.False(t, ok, "outgoing channel should be closed")
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
t.Fatal("timeout waiting for outgoing channel closure")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify errors channel closed
|
||||||
|
select {
|
||||||
|
case _, ok := <-conn.errors:
|
||||||
|
assert.False(t, ok, "errors channel should be closed")
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
t.Fatal("timeout waiting for errors channel closure")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("send fails after close", func(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Send([]byte("test"))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "connection closed")
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectedConnectionClose(t *testing.T) {
|
||||||
|
t.Run("blocked on ReadMessage, unblocks on closed", func(t *testing.T) {
|
||||||
|
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||||
|
|
||||||
|
// Wait for reader to block
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
err := conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
|
||||||
|
close(incomingData)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("writer active during close exits cleanly", func(t *testing.T) {
|
||||||
|
conn, _, _, outgoingData := setupTestConnection(t, nil)
|
||||||
|
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
conn.Send([]byte("message"))
|
||||||
|
}
|
||||||
|
|
||||||
|
err := conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Send([]byte("late"))
|
||||||
|
assert.Error(t, err, "Send should fail after close")
|
||||||
|
assert.ErrorContains(t, err, "connection closed")
|
||||||
|
|
||||||
|
close(outgoingData)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("both goroutines active during close", func(t *testing.T) {
|
||||||
|
conn, _, incomingData, outgoingData := setupTestConnection(t, nil)
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
incomingData <- mockIncomingData{
|
||||||
|
msgType: websocket.TextMessage,
|
||||||
|
data: []byte(fmt.Sprintf("in-%d", i)),
|
||||||
|
}
|
||||||
|
conn.Send([]byte(fmt.Sprintf("out-%d", i)))
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
err := conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
close(incomingData)
|
||||||
|
close(outgoingData)
|
||||||
|
})
|
||||||
|
}
|
||||||
404
ws/connection_goroutine_test.go
Normal file
404
ws/connection_goroutine_test.go
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStartReader(t *testing.T) {
|
||||||
|
t.Run("text messages route to incoming channel", func(t *testing.T) {
|
||||||
|
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
testData := []byte("hello")
|
||||||
|
incomingData <- mockIncomingData{
|
||||||
|
msgType: websocket.TextMessage,
|
||||||
|
data: testData,
|
||||||
|
err: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectIncoming(t, conn, testData)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("binary messages route to incoming channel", func(t *testing.T) {
|
||||||
|
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
testData := []byte{0x00, 0x01, 0x02}
|
||||||
|
incomingData <- mockIncomingData{
|
||||||
|
msgType: websocket.BinaryMessage,
|
||||||
|
data: testData,
|
||||||
|
err: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectIncoming(t, conn, testData)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple messages processed sequentially", func(t *testing.T) {
|
||||||
|
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")}
|
||||||
|
for _, msg := range messages {
|
||||||
|
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: msg, err: nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range messages {
|
||||||
|
expectIncoming(t, conn, expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("read timeout disabled when zero", func(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{ReadTimeout: 0}
|
||||||
|
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
|
||||||
|
mockSocket.CloseFunc = func() error {
|
||||||
|
mockSocket.once.Do(func() {
|
||||||
|
close(mockSocket.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
deadlineCalled := make(chan struct{}, 1)
|
||||||
|
mockSocket.SetReadDeadlineFunc = func(t time.Time) error {
|
||||||
|
deadlineCalled <- struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-deadlineCalled:
|
||||||
|
t.Fatal("SetReadDeadline should not be called when timeout is zero")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("read timeout sets deadline when positive", func(t *testing.T) {
|
||||||
|
config := &Config{ReadTimeout: 30}
|
||||||
|
|
||||||
|
incomingData := make(chan mockIncomingData, 10)
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
|
||||||
|
mockSocket.CloseFunc = func() error {
|
||||||
|
mockSocket.once.Do(func() {
|
||||||
|
close(mockSocket.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
deadlineCalled := make(chan struct{}, 1)
|
||||||
|
mockSocket.SetReadDeadlineFunc = func(t time.Time) error {
|
||||||
|
deadlineCalled <- struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||||
|
select {
|
||||||
|
case data := <-incomingData:
|
||||||
|
return data.msgType, data.data, data.err
|
||||||
|
case <-mockSocket.closed:
|
||||||
|
return 0, nil, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-conn.Incoming():
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case _, ok := <-deadlineCalled:
|
||||||
|
assert.True(t, ok, "SetReadDeadline should be called when timeout is positive")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("SetReadDeadline was never called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reader exits on deadline error", func(t *testing.T) {
|
||||||
|
config := &Config{ReadTimeout: 1 * time.Millisecond}
|
||||||
|
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
|
||||||
|
mockSocket.CloseFunc = func() error {
|
||||||
|
mockSocket.once.Do(func() {
|
||||||
|
close(mockSocket.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSocket.SetReadDeadlineFunc = func(t time.Time) error {
|
||||||
|
return fmt.Errorf("test error")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-conn.Errors():
|
||||||
|
assert.ErrorContains(t, err, "failed to set read deadline")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("timeout waiting for deadline error")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reader exits on socket read error", func(t *testing.T) {
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
|
||||||
|
mockSocket.CloseFunc = func() error {
|
||||||
|
mockSocket.once.Do(func() {
|
||||||
|
close(mockSocket.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
readErr := fmt.Errorf("read failed")
|
||||||
|
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||||
|
return 0, nil, readErr
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-conn.Errors():
|
||||||
|
assert.Equal(t, readErr, err)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("timeout waiting for read error")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartWriter(t *testing.T) {
|
||||||
|
t.Run("data from outgoing triggers write", func(t *testing.T) {
|
||||||
|
conn, _, _, outgoingData := setupTestConnection(t, nil)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
testData := []byte("test message")
|
||||||
|
err := conn.Send(testData)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
expectWrite(t, outgoingData, websocket.TextMessage, testData)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple messages processed sequentially", func(t *testing.T) {
|
||||||
|
conn, _, _, outgoingData := setupTestConnection(t, nil)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")}
|
||||||
|
for _, msg := range messages {
|
||||||
|
err := conn.Send(msg)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range messages {
|
||||||
|
expectWrite(t, outgoingData, websocket.TextMessage, expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("write timeout disabled when zero", func(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{WriteTimeout: 0}
|
||||||
|
|
||||||
|
outgoingData := make(chan mockOutgoingData, 10)
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
|
||||||
|
mockSocket.CloseFunc = func() error {
|
||||||
|
mockSocket.once.Do(func() {
|
||||||
|
close(mockSocket.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
deadlineCalled := make(chan struct{}, 1)
|
||||||
|
mockSocket.SetWriteDeadlineFunc = func(t time.Time) error {
|
||||||
|
deadlineCalled <- struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||||
|
select {
|
||||||
|
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||||
|
case <-mockSocket.closed:
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
err = conn.Send([]byte("test"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-deadlineCalled:
|
||||||
|
t.Fatal("SetWriteDeadline should not be called when timeout is zero")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("write timeout sets deadline when positive", func(t *testing.T) {
|
||||||
|
config := &Config{WriteTimeout: 30 * time.Millisecond}
|
||||||
|
|
||||||
|
outgoingData := make(chan mockOutgoingData, 10)
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
|
||||||
|
mockSocket.CloseFunc = func() error {
|
||||||
|
mockSocket.once.Do(func() {
|
||||||
|
close(mockSocket.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
deadlineCalled := make(chan struct{}, 1)
|
||||||
|
mockSocket.SetWriteDeadlineFunc = func(t time.Time) error {
|
||||||
|
deadlineCalled <- struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||||
|
select {
|
||||||
|
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||||
|
case <-mockSocket.closed:
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
err = conn.Send([]byte("test"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case _, ok := <-deadlineCalled:
|
||||||
|
assert.True(t, ok, "SetWriteDeadline should be called when timeout is positive")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("SetWriteDeadline was never called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("writer exits on deadline error", func(t *testing.T) {
|
||||||
|
config := &Config{WriteTimeout: 1 * time.Millisecond}
|
||||||
|
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
|
||||||
|
mockSocket.CloseFunc = func() error {
|
||||||
|
mockSocket.once.Do(func() {
|
||||||
|
close(mockSocket.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSocket.SetWriteDeadlineFunc = func(t time.Time) error {
|
||||||
|
return fmt.Errorf("test error")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Send([]byte("test"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-conn.Errors():
|
||||||
|
assert.ErrorContains(t, err, "failed to set write deadline")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("timeout waiting for deadline error")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("writer exits on socket write error", func(t *testing.T) {
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
|
||||||
|
writeErr := fmt.Errorf("write failed")
|
||||||
|
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||||
|
return writeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
err = conn.Send([]byte("test"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-conn.Errors():
|
||||||
|
assert.Equal(t, writeErr, err)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("timeout waiting for write error")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helpers
|
||||||
|
|
||||||
|
func expectIncoming(t *testing.T, conn *Connection, expected []byte) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case received := <-conn.Incoming():
|
||||||
|
assert.Equal(t, expected, received)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("timeout waiting for message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectWrite(t *testing.T, outgoingData chan mockOutgoingData, msgType int, expected []byte) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case call := <-outgoingData:
|
||||||
|
assert.Equal(t, msgType, call.msgType)
|
||||||
|
assert.Equal(t, expected, call.data)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("timeout waiting for write")
|
||||||
|
}
|
||||||
|
}
|
||||||
113
ws/connection_send_test.go
Normal file
113
ws/connection_send_test.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnectionSend(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
setup func(*Connection)
|
||||||
|
data []byte
|
||||||
|
wantErr bool
|
||||||
|
wantErrText string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "send succeeds when open",
|
||||||
|
setup: func(c *Connection) {},
|
||||||
|
data: []byte("test message"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "send fails when closed",
|
||||||
|
setup: func(c *Connection) {
|
||||||
|
c.Close()
|
||||||
|
},
|
||||||
|
data: []byte("test"),
|
||||||
|
wantErr: true,
|
||||||
|
wantErrText: "connection closed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "send fails when queue full",
|
||||||
|
setup: func(c *Connection) {
|
||||||
|
// Fill outgoing channel
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
c.outgoing <- []byte("filler")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
data: []byte("overflow"),
|
||||||
|
wantErr: true,
|
||||||
|
wantErrText: "outgoing queue full",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
tc.setup(conn)
|
||||||
|
|
||||||
|
err = conn.Send(tc.data)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
if tc.wantErrText != "" {
|
||||||
|
assert.ErrorContains(t, err, tc.wantErrText)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify data appeared on outgoing channel
|
||||||
|
select {
|
||||||
|
case sent := <-conn.outgoing:
|
||||||
|
assert.Equal(t, tc.data, sent)
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
t.Fatal("timeout: data not sent to outgoing channel")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run with `go test -race` to ensure no race conditions occur
|
||||||
|
func TestConnectionSendConcurrent(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// continuously consume outgoing channel in background
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-conn.outgoing:
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
// Send from multiple goroutines concurrently
|
||||||
|
const goroutines = 5
|
||||||
|
const messagesPerGoroutine = 10
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < messagesPerGoroutine; j++ {
|
||||||
|
data := []byte(fmt.Sprintf("msg-%d-%d", id, j))
|
||||||
|
err := conn.Send(data)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
143
ws/connection_socket_test.go
Normal file
143
ws/connection_socket_test.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDialer(t *testing.T) {
|
||||||
|
dialer := NewDialer()
|
||||||
|
|
||||||
|
assert.NotNil(t, dialer)
|
||||||
|
_, ok := dialer.(*GorillaDialer)
|
||||||
|
assert.True(t, ok, "NewDialer should return *GorillaDialer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGorillaDialer(t *testing.T) {
|
||||||
|
dialer := NewGorillaDialer()
|
||||||
|
|
||||||
|
assert.NotNil(t, dialer)
|
||||||
|
assert.NotNil(t, dialer.Dialer)
|
||||||
|
assert.Equal(t, 45*time.Second, dialer.Dialer.HandshakeTimeout)
|
||||||
|
assert.Equal(t, 1024, dialer.Dialer.ReadBufferSize)
|
||||||
|
assert.Equal(t, 1024, dialer.Dialer.WriteBufferSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAcquireSocket(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
mockRuns []error
|
||||||
|
maxRetries int
|
||||||
|
wantRetryCount int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "immediate success",
|
||||||
|
mockRuns: []error{nil},
|
||||||
|
maxRetries: 3,
|
||||||
|
wantRetryCount: 0,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two failures, success",
|
||||||
|
mockRuns: []error{errors.New("1"), errors.New("2"), nil},
|
||||||
|
maxRetries: 0,
|
||||||
|
wantRetryCount: 2,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three failures, failure",
|
||||||
|
mockRuns: []error{errors.New("1"), errors.New("2"), errors.New("3"), errors.New("4")},
|
||||||
|
maxRetries: 3,
|
||||||
|
wantRetryCount: 3,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
attemptIndex := 0
|
||||||
|
mockDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
err := tc.mockRuns[attemptIndex]
|
||||||
|
attemptIndex++
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return NewMockSocket(), nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
retryMgr := NewRetryManager(&RetryConfig{
|
||||||
|
MaxRetries: tc.maxRetries,
|
||||||
|
InitialDelay: 1 * time.Millisecond,
|
||||||
|
MaxDelay: 5 * time.Millisecond,
|
||||||
|
JitterFactor: 0.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
socket, _, err := AcquireSocket(retryMgr, mockDialer, "ws://test")
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, socket)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, socket)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAcquireSocketGuards(t *testing.T) {
|
||||||
|
validDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
return NewMockSocket(), nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
validRetryMgr := NewRetryManager(GetDefaultRetryConfig())
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
retryMgr *RetryManager
|
||||||
|
dialer Dialer
|
||||||
|
url string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil retry manager",
|
||||||
|
retryMgr: nil,
|
||||||
|
dialer: validDialer,
|
||||||
|
url: "ws://test",
|
||||||
|
wantErr: "retry manager cannot be nil",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil dialer",
|
||||||
|
retryMgr: validRetryMgr,
|
||||||
|
dialer: nil,
|
||||||
|
url: "ws://test",
|
||||||
|
wantErr: "dialer cannot be nil",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty URL",
|
||||||
|
retryMgr: validRetryMgr,
|
||||||
|
dialer: validDialer,
|
||||||
|
url: "",
|
||||||
|
wantErr: "URL cannot be empty",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
socket, resp, err := AcquireSocket(tc.retryMgr, tc.dialer, tc.url)
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, tc.wantErr)
|
||||||
|
assert.Nil(t, socket)
|
||||||
|
assert.Nil(t, resp)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
438
ws/connection_test.go
Normal file
438
ws/connection_test.go
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Connection state tests
|
||||||
|
|
||||||
|
func TestConnectionStateString(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
state ConnectionState
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{StateDisconnected, "disconnected"},
|
||||||
|
{StateConnecting, "connecting"},
|
||||||
|
{StateConnected, "connected"},
|
||||||
|
{StateClosed, "closed"},
|
||||||
|
{ConnectionState(99), "unknown"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.want, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.want, tc.state.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionState(t *testing.T) {
|
||||||
|
// Test initial state
|
||||||
|
conn, _ := NewConnection("ws://test", nil)
|
||||||
|
assert.Equal(t, StateDisconnected, conn.State())
|
||||||
|
|
||||||
|
// Test state after FromSocket (should be Connected)
|
||||||
|
conn2, _ := NewConnectionFromSocket(NewMockSocket(), nil)
|
||||||
|
assert.Equal(t, StateConnected, conn2.State())
|
||||||
|
|
||||||
|
// Test state after close
|
||||||
|
conn.Close()
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection constructor tests
|
||||||
|
|
||||||
|
func TestNewConnection(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
url string
|
||||||
|
config *Config
|
||||||
|
wantErr bool
|
||||||
|
wantErrText string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid url, nil config",
|
||||||
|
url: "ws://example.com",
|
||||||
|
config: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid url, valid config",
|
||||||
|
url: "wss://relay.example.com:8080/path",
|
||||||
|
config: &Config{ReadTimeout: 30 * time.Second},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid url",
|
||||||
|
url: "http://example.com",
|
||||||
|
config: nil,
|
||||||
|
wantErr: true,
|
||||||
|
wantErrText: "URL must use ws:// or wss:// scheme",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid config",
|
||||||
|
url: "ws://example.com",
|
||||||
|
config: &Config{
|
||||||
|
Retry: &RetryConfig{
|
||||||
|
InitialDelay: 10 * time.Second,
|
||||||
|
MaxDelay: 1 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
wantErrText: "initial delay may not exceed maximum delay",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
conn, err := NewConnection(tc.url, tc.config)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
if tc.wantErrText != "" {
|
||||||
|
assert.ErrorContains(t, err, tc.wantErrText)
|
||||||
|
}
|
||||||
|
assert.Nil(t, conn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Verify struct fields
|
||||||
|
assert.NotNil(t, conn.url)
|
||||||
|
assert.NotNil(t, conn.dialer)
|
||||||
|
assert.Nil(t, conn.socket)
|
||||||
|
assert.NotNil(t, conn.config)
|
||||||
|
assert.NotNil(t, conn.incoming)
|
||||||
|
assert.NotNil(t, conn.outgoing)
|
||||||
|
assert.NotNil(t, conn.errors)
|
||||||
|
assert.NotNil(t, conn.done)
|
||||||
|
assert.Equal(t, StateDisconnected, conn.state)
|
||||||
|
assert.False(t, conn.closed)
|
||||||
|
|
||||||
|
// Verify default config is used if nil is passed
|
||||||
|
if tc.config == nil {
|
||||||
|
assert.Equal(t, GetDefaultConfig(), conn.config)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tc.config, conn.config)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionFromSocket(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
socket Socket
|
||||||
|
config *Config
|
||||||
|
wantErr bool
|
||||||
|
wantErrText string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil socket",
|
||||||
|
socket: nil,
|
||||||
|
config: nil,
|
||||||
|
wantErr: true,
|
||||||
|
wantErrText: "socket cannot be nil",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid socket with nil config",
|
||||||
|
socket: NewMockSocket(),
|
||||||
|
config: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid socket with valid config",
|
||||||
|
socket: NewMockSocket(),
|
||||||
|
config: &Config{ReadTimeout: 30 * time.Second},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid config",
|
||||||
|
socket: NewMockSocket(),
|
||||||
|
config: &Config{
|
||||||
|
Retry: &RetryConfig{
|
||||||
|
InitialDelay: 10 * time.Second,
|
||||||
|
MaxDelay: 1 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
wantErrText: "initial delay may not exceed maximum delay",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "close handler set when provided",
|
||||||
|
socket: NewMockSocket(),
|
||||||
|
config: &Config{
|
||||||
|
CloseHandler: func(code int, text string) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// track if SetCloseHandler was called
|
||||||
|
closeHandlerSet := false
|
||||||
|
if tc.socket != nil {
|
||||||
|
mockSocket := tc.socket.(*MockSocket)
|
||||||
|
originalSetCloseHandler := mockSocket.SetCloseHandlerFunc
|
||||||
|
|
||||||
|
// wrapper around the original handler function
|
||||||
|
mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) {
|
||||||
|
closeHandlerSet = true
|
||||||
|
if originalSetCloseHandler != nil {
|
||||||
|
originalSetCloseHandler(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(tc.socket, tc.config)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
if tc.wantErrText != "" {
|
||||||
|
assert.ErrorContains(t, err, tc.wantErrText)
|
||||||
|
}
|
||||||
|
assert.Nil(t, conn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Verify fields initialized correctly
|
||||||
|
assert.Nil(t, conn.url)
|
||||||
|
assert.Nil(t, conn.dialer)
|
||||||
|
assert.Equal(t, tc.socket, conn.socket)
|
||||||
|
assert.NotNil(t, conn.config)
|
||||||
|
assert.NotNil(t, conn.incoming)
|
||||||
|
assert.NotNil(t, conn.outgoing)
|
||||||
|
assert.NotNil(t, conn.errors)
|
||||||
|
assert.NotNil(t, conn.done)
|
||||||
|
assert.Equal(t, StateConnected, conn.state)
|
||||||
|
assert.False(t, conn.closed)
|
||||||
|
|
||||||
|
// Verify config defaulting
|
||||||
|
if tc.config == nil {
|
||||||
|
assert.Equal(t, GetDefaultConfig(), conn.config)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tc.config, conn.config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify close handler was set if provided
|
||||||
|
if tc.config != nil && tc.config.CloseHandler != nil {
|
||||||
|
assert.True(t, closeHandlerSet, "CloseHandler should be set on socket")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ws/connection_test.go
|
||||||
|
|
||||||
|
// Add to existing file after TestNewConnectionFromSocket
|
||||||
|
|
||||||
|
func TestConnect(t *testing.T) {
|
||||||
|
t.Run("connect fails when socket already present", func(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
conn.socket = NewMockSocket()
|
||||||
|
|
||||||
|
err = conn.Connect()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "already has socket")
|
||||||
|
assert.Equal(t, StateDisconnected, conn.State())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("connect fails when connection closed", func(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
|
||||||
|
err = conn.Connect()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "connection is closed")
|
||||||
|
assert.Equal(t, StateClosed, conn.State())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("connect succeeds and starts goroutines", func(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
outgoingData := make(chan mockOutgoingData, 10)
|
||||||
|
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||||
|
outgoingData <- mockOutgoingData{msgType: msgType, data: data}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mockDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
return mockSocket, nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.dialer = mockDialer
|
||||||
|
|
||||||
|
err = conn.Connect()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, StateConnected, conn.State())
|
||||||
|
|
||||||
|
testData := []byte("test")
|
||||||
|
conn.Send(testData)
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msg := <-outgoingData:
|
||||||
|
assert.Equal(t, testData, msg.data)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("timeout waiting for message write")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
close(outgoingData)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("connect retries on dial failure", func(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
Retry: &RetryConfig{
|
||||||
|
MaxRetries: 2,
|
||||||
|
InitialDelay: 1 * time.Millisecond,
|
||||||
|
MaxDelay: 5 * time.Millisecond,
|
||||||
|
JitterFactor: 0.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := NewConnection("ws://test", config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
attemptCount := 0
|
||||||
|
mockDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
attemptCount++
|
||||||
|
if attemptCount < 3 {
|
||||||
|
return nil, nil, fmt.Errorf("dial failed")
|
||||||
|
}
|
||||||
|
return NewMockSocket(), nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.dialer = mockDialer
|
||||||
|
|
||||||
|
err = conn.Connect()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 3, attemptCount)
|
||||||
|
assert.Equal(t, StateConnected, conn.State())
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("connect fails after max retries", func(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
Retry: &RetryConfig{
|
||||||
|
MaxRetries: 2,
|
||||||
|
InitialDelay: 1 * time.Millisecond,
|
||||||
|
MaxDelay: 5 * time.Millisecond,
|
||||||
|
JitterFactor: 0.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := NewConnection("ws://test", config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mockDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
return nil, nil, fmt.Errorf("dial failed")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.dialer = mockDialer
|
||||||
|
|
||||||
|
err = conn.Connect()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "dial failed")
|
||||||
|
assert.Equal(t, StateDisconnected, conn.State())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("state transitions during connect", func(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, StateDisconnected, conn.State())
|
||||||
|
|
||||||
|
stateDuringDial := StateDisconnected
|
||||||
|
mockDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
stateDuringDial = conn.state
|
||||||
|
return NewMockSocket(), nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.dialer = mockDialer
|
||||||
|
|
||||||
|
conn.Connect()
|
||||||
|
|
||||||
|
assert.Equal(t, StateConnecting, stateDuringDial)
|
||||||
|
assert.Equal(t, StateConnected, conn.State())
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("close handler configured when provided", func(t *testing.T) {
|
||||||
|
handlerSet := false
|
||||||
|
config := &Config{
|
||||||
|
CloseHandler: func(code int, text string) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := NewConnection("ws://test", config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) {
|
||||||
|
handlerSet = true
|
||||||
|
}
|
||||||
|
|
||||||
|
mockDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
return mockSocket, nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.dialer = mockDialer
|
||||||
|
|
||||||
|
conn.Connect()
|
||||||
|
|
||||||
|
assert.True(t, handlerSet, "close handler should be set on socket")
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection method tests
|
||||||
|
|
||||||
|
func TestConnectionIncoming(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
incoming := conn.Incoming()
|
||||||
|
assert.NotNil(t, incoming)
|
||||||
|
|
||||||
|
// send data through the channel to verify they are the same
|
||||||
|
testData := []byte("test")
|
||||||
|
conn.incoming <- testData
|
||||||
|
received := <-incoming
|
||||||
|
assert.Equal(t, testData, received)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionErrors(t *testing.T) {
|
||||||
|
conn, err := NewConnection("ws://test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
errors := conn.Errors()
|
||||||
|
assert.NotNil(t, errors)
|
||||||
|
|
||||||
|
// send data through the channel to verify they are the same
|
||||||
|
testErr := fmt.Errorf("test error")
|
||||||
|
conn.errors <- testErr
|
||||||
|
received := <-errors
|
||||||
|
assert.Equal(t, testErr, received)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect() tests
|
||||||
132
ws/mocks_test.go
Normal file
132
ws/mocks_test.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dialer Mocks
|
||||||
|
|
||||||
|
type MockDialer struct {
|
||||||
|
DialFunc func(string, http.Header) (Socket, *http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDialer) Dial(url string, h http.Header) (Socket, *http.Response, error) {
|
||||||
|
return m.DialFunc(url, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Socket Mocks
|
||||||
|
|
||||||
|
type MockSocket struct {
|
||||||
|
WriteMessageFunc func(int, []byte) error
|
||||||
|
SetReadDeadlineFunc func(t time.Time) error
|
||||||
|
SetWriteDeadlineFunc func(t time.Time) error
|
||||||
|
ReadMessageFunc func() (int, []byte, error)
|
||||||
|
CloseFunc func() error
|
||||||
|
SetCloseHandlerFunc func(func(int, string) error)
|
||||||
|
closed chan struct{}
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockSocket() *MockSocket {
|
||||||
|
return &MockSocket{
|
||||||
|
WriteMessageFunc: func(int, []byte) error { return nil },
|
||||||
|
ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil },
|
||||||
|
CloseFunc: func() error { return nil },
|
||||||
|
|
||||||
|
SetReadDeadlineFunc: func(time.Time) error { return nil },
|
||||||
|
SetWriteDeadlineFunc: func(time.Time) error { return nil },
|
||||||
|
SetCloseHandlerFunc: func(func(int, string) error) {},
|
||||||
|
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocket) WriteMessage(t int, d []byte) error {
|
||||||
|
return m.WriteMessageFunc(t, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocket) ReadMessage() (int, []byte, error) {
|
||||||
|
return m.ReadMessageFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocket) Close() error {
|
||||||
|
return m.CloseFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocket) SetReadDeadline(t time.Time) error {
|
||||||
|
return m.SetReadDeadlineFunc(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocket) SetWriteDeadline(t time.Time) error {
|
||||||
|
return m.SetWriteDeadlineFunc(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocket) SetCloseHandler(h func(code int, text string) error) {
|
||||||
|
m.SetCloseHandlerFunc(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection Mocks
|
||||||
|
|
||||||
|
type mockIncomingData struct {
|
||||||
|
msgType int
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockOutgoingData struct {
|
||||||
|
msgType int
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestConnection(t *testing.T, config *Config) (
|
||||||
|
conn *Connection,
|
||||||
|
mockSocket *MockSocket,
|
||||||
|
incomingData chan mockIncomingData,
|
||||||
|
outgoingData chan mockOutgoingData,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
incomingData = make(chan mockIncomingData, 10)
|
||||||
|
outgoingData = make(chan mockOutgoingData, 10)
|
||||||
|
|
||||||
|
mockSocket = NewMockSocket()
|
||||||
|
|
||||||
|
mockSocket.CloseFunc = func() error {
|
||||||
|
mockSocket.once.Do(func() {
|
||||||
|
close(mockSocket.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wire ReadMessage to pull from incomingData channel
|
||||||
|
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||||
|
select {
|
||||||
|
case data := <-incomingData:
|
||||||
|
return data.msgType, data.data, data.err
|
||||||
|
case <-mockSocket.closed:
|
||||||
|
return 0, nil, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wire WriteMessage to push to outgoingData channel
|
||||||
|
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||||
|
select {
|
||||||
|
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||||
|
case <-mockSocket.closed:
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
conn, err = NewConnectionFromSocket(mockSocket, config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
return conn, mockSocket, incomingData, outgoingData
|
||||||
|
}
|
||||||
66
ws/retry.go
Normal file
66
ws/retry.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RetryManager struct {
|
||||||
|
config *RetryConfig
|
||||||
|
retryCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRetryManager(config *RetryConfig) *RetryManager {
|
||||||
|
return &RetryManager{
|
||||||
|
config: config,
|
||||||
|
retryCount: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RetryManager) ShouldRetry() bool {
|
||||||
|
if r.config == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.MaxRetries > 0 && r.retryCount >= r.config.MaxRetries {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RetryManager) CalculateDelay() time.Duration {
|
||||||
|
if r.config == nil {
|
||||||
|
return time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// First attempt: immediate retry
|
||||||
|
if r.retryCount == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exponential backoff: InitialDelay * 2^(attempts-1)
|
||||||
|
backoffMultiplier := math.Pow(2, float64(r.retryCount-1))
|
||||||
|
baseDelay := float64(r.config.InitialDelay) * backoffMultiplier
|
||||||
|
|
||||||
|
// Apply jitter: delay * (1 + jitterFactor * (random - 0.5))
|
||||||
|
random := rand.Float64()
|
||||||
|
jitterMultiplier := 1 + r.config.JitterFactor*(random-0.5)
|
||||||
|
delay := time.Duration(baseDelay * jitterMultiplier)
|
||||||
|
|
||||||
|
// Cap at MaxDelay
|
||||||
|
if delay > r.config.MaxDelay {
|
||||||
|
delay = r.config.MaxDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RetryManager) RecordRetry() {
|
||||||
|
m.retryCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RetryManager) RetryCount() int {
|
||||||
|
return m.retryCount
|
||||||
|
}
|
||||||
147
ws/retry_test.go
Normal file
147
ws/retry_test.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewRetryManager(t *testing.T) {
|
||||||
|
config := &RetryConfig{
|
||||||
|
MaxRetries: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := NewRetryManager(config)
|
||||||
|
|
||||||
|
assert.Equal(t, config, mgr.config)
|
||||||
|
assert.Equal(t, 0, mgr.retryCount)
|
||||||
|
|
||||||
|
// Should accept nil config
|
||||||
|
mgr = NewRetryManager(nil)
|
||||||
|
assert.Nil(t, mgr.config)
|
||||||
|
assert.Equal(t, 0, mgr.retryCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordRetry(t *testing.T) {
|
||||||
|
mgr := NewRetryManager(nil)
|
||||||
|
assert.Equal(t, mgr.retryCount, 0)
|
||||||
|
|
||||||
|
mgr.RecordRetry()
|
||||||
|
assert.Equal(t, mgr.retryCount, 1)
|
||||||
|
|
||||||
|
mgr.RecordRetry()
|
||||||
|
assert.Equal(t, mgr.retryCount, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldRetry(t *testing.T) {
|
||||||
|
// never retry if config is nil
|
||||||
|
mgr := NewRetryManager(nil)
|
||||||
|
assert.False(t, mgr.ShouldRetry())
|
||||||
|
|
||||||
|
// always retry if max attempt count is zero
|
||||||
|
mgr = &RetryManager{
|
||||||
|
config: &RetryConfig{
|
||||||
|
MaxRetries: 0,
|
||||||
|
},
|
||||||
|
retryCount: 1000,
|
||||||
|
}
|
||||||
|
assert.True(t, mgr.ShouldRetry())
|
||||||
|
|
||||||
|
// retry if below max attempt count
|
||||||
|
mgr = &RetryManager{
|
||||||
|
config: &RetryConfig{
|
||||||
|
MaxRetries: 10,
|
||||||
|
},
|
||||||
|
retryCount: 5,
|
||||||
|
}
|
||||||
|
assert.True(t, mgr.ShouldRetry())
|
||||||
|
|
||||||
|
// do not retry if above max attempt count
|
||||||
|
mgr = &RetryManager{
|
||||||
|
config: &RetryConfig{
|
||||||
|
MaxRetries: 10,
|
||||||
|
},
|
||||||
|
retryCount: 11,
|
||||||
|
}
|
||||||
|
assert.False(t, mgr.ShouldRetry())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateDelayDisabled(t *testing.T) {
|
||||||
|
// default delay if retry is disabled
|
||||||
|
mgr := NewRetryManager(nil)
|
||||||
|
assert.Equal(t, time.Second, mgr.CalculateDelay())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateDelayWithoutJitter(t *testing.T) {
|
||||||
|
mgr := NewRetryManager(&RetryConfig{
|
||||||
|
MaxRetries: 0,
|
||||||
|
InitialDelay: 1 * time.Second,
|
||||||
|
MaxDelay: 5 * time.Second,
|
||||||
|
JitterFactor: 0.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Retry 0: immediate
|
||||||
|
assert.Equal(t, 0*time.Second, mgr.CalculateDelay())
|
||||||
|
mgr.RecordRetry()
|
||||||
|
|
||||||
|
// Retry 1: 1s * 2^0 = 1s
|
||||||
|
assert.Equal(t, 1*time.Second, mgr.CalculateDelay())
|
||||||
|
mgr.RecordRetry()
|
||||||
|
|
||||||
|
// Retry 2: 1s * 2^1 = 2s
|
||||||
|
assert.Equal(t, 2*time.Second, mgr.CalculateDelay())
|
||||||
|
mgr.RecordRetry()
|
||||||
|
|
||||||
|
// Retry 3: 1s * 2^2 = 4s
|
||||||
|
assert.Equal(t, 4*time.Second, mgr.CalculateDelay())
|
||||||
|
mgr.RecordRetry()
|
||||||
|
|
||||||
|
// Retry 4: 1s * 2^3 = 8s, capped at 5s
|
||||||
|
assert.Equal(t, 5*time.Second, mgr.CalculateDelay())
|
||||||
|
mgr.RecordRetry()
|
||||||
|
|
||||||
|
// Retry 5: Still capped at 5s
|
||||||
|
assert.Equal(t, 5*time.Second, mgr.CalculateDelay())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateDelayWithJitter(t *testing.T) {
|
||||||
|
mgr := NewRetryManager(&RetryConfig{
|
||||||
|
MaxRetries: 0,
|
||||||
|
InitialDelay: 1 * time.Second,
|
||||||
|
MaxDelay: 5 * time.Second,
|
||||||
|
JitterFactor: 0.5,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Retry 0: immediate
|
||||||
|
assert.Equal(t, 0*time.Second, mgr.CalculateDelay())
|
||||||
|
mgr.RecordRetry()
|
||||||
|
|
||||||
|
// Retry 1: 1s * 2^0 = 1s (with jitter)
|
||||||
|
delay := mgr.CalculateDelay()
|
||||||
|
assert.GreaterOrEqual(t, delay, 750*time.Millisecond)
|
||||||
|
assert.LessOrEqual(t, delay, 1250*time.Millisecond)
|
||||||
|
mgr.RecordRetry()
|
||||||
|
|
||||||
|
// Retry 2: 1s * 2^1 = 2s (with jitter)
|
||||||
|
delay = mgr.CalculateDelay()
|
||||||
|
assert.GreaterOrEqual(t, delay, 1500*time.Millisecond)
|
||||||
|
assert.LessOrEqual(t, delay, 2500*time.Millisecond)
|
||||||
|
mgr.RecordRetry()
|
||||||
|
|
||||||
|
// Retry 3: 1s * 2^2 = 4s (with jitter)
|
||||||
|
delay = mgr.CalculateDelay()
|
||||||
|
assert.GreaterOrEqual(t, delay, 3*time.Second)
|
||||||
|
assert.LessOrEqual(t, delay, 5*time.Second)
|
||||||
|
mgr.RecordRetry()
|
||||||
|
|
||||||
|
// Retry 4: 1s * 2^3 = 8s, capped at 5s (with jitter)
|
||||||
|
delay = mgr.CalculateDelay()
|
||||||
|
assert.GreaterOrEqual(t, delay, 3750*time.Millisecond)
|
||||||
|
assert.LessOrEqual(t, delay, 5*time.Second)
|
||||||
|
mgr.RecordRetry()
|
||||||
|
|
||||||
|
// Retry 5: Still capped at 5s (with jitter)
|
||||||
|
delay = mgr.CalculateDelay()
|
||||||
|
assert.GreaterOrEqual(t, delay, 3750*time.Millisecond)
|
||||||
|
assert.LessOrEqual(t, delay, 5*time.Second)
|
||||||
|
}
|
||||||
20
ws/url.go
Normal file
20
ws/url.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ParseURL(urlStr string) (*url.URL, error) {
|
||||||
|
parsedURL, err := url.Parse(urlStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" {
|
||||||
|
return nil, errors.InvalidProtocol
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedURL, nil
|
||||||
|
}
|
||||||
93
ws/url_test.go
Normal file
93
ws/url_test.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseURL(t *testing.T) {
|
||||||
|
type wantURL struct {
|
||||||
|
scheme string
|
||||||
|
host string
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
url string
|
||||||
|
want wantURL
|
||||||
|
wantErr error
|
||||||
|
wantErrText string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid ws url",
|
||||||
|
url: "ws://localhost:8080/relay",
|
||||||
|
want: wantURL{
|
||||||
|
scheme: "ws",
|
||||||
|
host: "localhost:8080",
|
||||||
|
path: "/relay",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid wss url",
|
||||||
|
url: "wss://relay.example.com",
|
||||||
|
want: wantURL{
|
||||||
|
scheme: "wss",
|
||||||
|
host: "relay.example.com",
|
||||||
|
path: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http scheme rejected",
|
||||||
|
url: "http://example.com",
|
||||||
|
wantErr: errors.InvalidProtocol,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing scheme",
|
||||||
|
url: "example.com:8080",
|
||||||
|
wantErr: errors.InvalidProtocol,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
url: "",
|
||||||
|
wantErr: errors.InvalidProtocol,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed url",
|
||||||
|
url: "ws://[::1:8080",
|
||||||
|
wantErrText: "missing ']' in host",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6 address",
|
||||||
|
url: "ws://[::1]:8080/relay",
|
||||||
|
want: wantURL{
|
||||||
|
scheme: "ws",
|
||||||
|
host: "[::1]:8080",
|
||||||
|
path: "/relay",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got, err := ParseURL(tc.url)
|
||||||
|
|
||||||
|
if tc.wantErr != nil || tc.wantErrText != "" {
|
||||||
|
if tc.wantErr != nil {
|
||||||
|
assert.ErrorIs(t, err, tc.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.wantErrText != "" {
|
||||||
|
assert.ErrorContains(t, err, tc.wantErrText)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.want.scheme, got.Scheme)
|
||||||
|
assert.Equal(t, tc.want.host, got.Host)
|
||||||
|
assert.Equal(t, tc.want.path, got.Path)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
2
ws/ws_test.go
Normal file
2
ws/ws_test.go
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// ws package end-to-end tests
|
||||||
|
package ws
|
||||||
Reference in New Issue
Block a user