Refactored package structure.
This commit is contained in:
225
transport/config.go
Normal file
225
transport/config.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type CloseHandler func(code int, text string) error
|
||||
|
||||
type ConnectionConfig struct {
|
||||
CloseHandler CloseHandler
|
||||
WriteTimeout time.Duration
|
||||
Retry *RetryConfig
|
||||
}
|
||||
|
||||
type RetryConfig struct {
|
||||
MaxRetries int
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
JitterFactor float64
|
||||
}
|
||||
|
||||
type ConnectionOption func(*ConnectionConfig) error
|
||||
|
||||
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
|
||||
conf := GetDefaultConnectionConfig()
|
||||
if err := applyConnectionOptions(conf, options...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ValidateConnectionConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func GetDefaultConnectionConfig() *ConnectionConfig {
|
||||
return &ConnectionConfig{
|
||||
CloseHandler: nil,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
Retry: GetDefaultRetryConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
func GetDefaultRetryConfig() *RetryConfig {
|
||||
return &RetryConfig{
|
||||
MaxRetries: 0, // Infinite retries
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 5 * time.Second,
|
||||
JitterFactor: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOption) error {
|
||||
for _, option := range options {
|
||||
if err := option(config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateConnectionConfig(config *ConnectionConfig) error {
|
||||
err := validateWriteTimeout(config.WriteTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if config.Retry != nil {
|
||||
err = validateMaxRetries(config.Retry.MaxRetries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateInitialDelay(config.Retry.InitialDelay)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateMaxDelay(config.Retry.MaxDelay)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateJitterFactor(config.Retry.JitterFactor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if config.Retry.InitialDelay > config.Retry.MaxDelay {
|
||||
return NewConfigError("initial delay may not exceed maximum delay")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateWriteTimeout(value time.Duration) error {
|
||||
if value < 0 {
|
||||
return InvalidWriteTimeout
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMaxRetries(value int) error {
|
||||
if value < 0 {
|
||||
return InvalidRetryMaxRetries
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateInitialDelay(value time.Duration) error {
|
||||
if value <= 0 {
|
||||
return InvalidRetryInitialDelay
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMaxDelay(value time.Duration) error {
|
||||
if value <= 0 {
|
||||
return InvalidRetryMaxDelay
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateJitterFactor(value float64) error {
|
||||
if value < 0.0 || value > 1.0 {
|
||||
return InvalidRetryJitterFactor
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WithCloseHandler(handler CloseHandler) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
c.CloseHandler = handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// When WriteTimeout is set to zero, read timeouts are disabled.
|
||||
func WithWriteTimeout(value time.Duration) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
err := validateWriteTimeout(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.WriteTimeout = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetry enables retry with default parameters (infinite retries,
|
||||
// 1s initial delay, 5s max delay, 0.5 jitter factor).
|
||||
//
|
||||
// If passed after granular retry options (WithRetryMaxRetries, etc.),
|
||||
// it will overwrite them. Use either WithRetry alone or the granular
|
||||
// options; not both.
|
||||
func WithRetry() ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryMaxRetries(value int) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateMaxRetries(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.MaxRetries = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryInitialDelay(value time.Duration) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateInitialDelay(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.InitialDelay = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryMaxDelay(value time.Duration) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateMaxDelay(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.MaxDelay = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryJitterFactor(value float64) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateJitterFactor(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.JitterFactor = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
257
transport/config_test.go
Normal file
257
transport/config_test.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Connection Config Tests
|
||||
|
||||
func TestNewConnectionConfig(t *testing.T) {
|
||||
conf, err := NewConnectionConfig()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, conf, &ConnectionConfig{
|
||||
CloseHandler: nil,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
Retry: GetDefaultRetryConfig(),
|
||||
})
|
||||
|
||||
// errors propagate
|
||||
_, err = NewConnectionConfig(WithRetryMaxRetries(-1))
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = NewConnectionConfig(WithRetryInitialDelay(10), WithRetryMaxDelay(1))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// Default Tests
|
||||
|
||||
func TestDefaultConnectionConfig(t *testing.T) {
|
||||
conf := GetDefaultConnectionConfig()
|
||||
|
||||
assert.Equal(t, conf, &ConnectionConfig{
|
||||
CloseHandler: nil,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
Retry: GetDefaultRetryConfig(),
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultRetryConnectionConfig(t *testing.T) {
|
||||
conf := GetDefaultRetryConfig()
|
||||
|
||||
assert.Equal(t, conf, &RetryConfig{
|
||||
MaxRetries: 0,
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 5 * time.Second,
|
||||
JitterFactor: 0.5,
|
||||
})
|
||||
}
|
||||
|
||||
// Builder Tests
|
||||
|
||||
func TestApplyConnectionOptions(t *testing.T) {
|
||||
conf := &ConnectionConfig{}
|
||||
err := applyConnectionOptions(
|
||||
conf,
|
||||
WithRetryMaxRetries(0),
|
||||
WithRetryInitialDelay(3*time.Second),
|
||||
WithRetryJitterFactor(0.5),
|
||||
)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, conf.Retry.MaxRetries)
|
||||
assert.Equal(t, 3*time.Second, conf.Retry.InitialDelay)
|
||||
assert.Equal(t, 0.5, conf.Retry.JitterFactor)
|
||||
|
||||
// errors propagate
|
||||
err = applyConnectionOptions(
|
||||
conf,
|
||||
WithRetryMaxRetries(-10),
|
||||
)
|
||||
|
||||
assert.ErrorIs(t, err, InvalidRetryMaxRetries)
|
||||
}
|
||||
|
||||
// Option Tests
|
||||
|
||||
func TestWithCloseHandler(t *testing.T) {
|
||||
conf := &ConnectionConfig{}
|
||||
opt := WithCloseHandler(func(code int, text string) error { return nil })
|
||||
err := applyConnectionOptions(conf, opt)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, conf.CloseHandler(0, ""))
|
||||
}
|
||||
|
||||
func TestWithWriteTimeout(t *testing.T) {
|
||||
conf := &ConnectionConfig{}
|
||||
opt := WithWriteTimeout(30)
|
||||
err := applyConnectionOptions(conf, opt)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, conf.WriteTimeout, time.Duration(30))
|
||||
|
||||
// zero allowed
|
||||
conf = &ConnectionConfig{}
|
||||
opt = WithWriteTimeout(0)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, conf.WriteTimeout, time.Duration(0))
|
||||
|
||||
// negative disallowed
|
||||
conf = &ConnectionConfig{}
|
||||
opt = WithWriteTimeout(-30)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, InvalidWriteTimeout)
|
||||
assert.ErrorContains(t, err, "write timeout cannot be negative")
|
||||
}
|
||||
|
||||
func TestWithRetry(t *testing.T) {
|
||||
t.Run("default", func(t *testing.T) {
|
||||
conf := &ConnectionConfig{}
|
||||
opt := WithRetry()
|
||||
err := applyConnectionOptions(conf, opt)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, conf.Retry)
|
||||
assert.Equal(t, conf.Retry, GetDefaultRetryConfig())
|
||||
})
|
||||
|
||||
t.Run("with attempts", func(t *testing.T) {
|
||||
conf := &ConnectionConfig{}
|
||||
opt := WithRetryMaxRetries(3)
|
||||
err := applyConnectionOptions(conf, opt)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, conf.Retry.MaxRetries)
|
||||
|
||||
// zero allowed
|
||||
opt = WithRetryMaxRetries(0)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// negative disallowed
|
||||
opt = WithRetryMaxRetries(-10)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, InvalidRetryMaxRetries)
|
||||
assert.ErrorContains(t, err, "max retry count cannot be negative")
|
||||
})
|
||||
|
||||
t.Run("with initial delay", func(t *testing.T) {
|
||||
conf := &ConnectionConfig{}
|
||||
opt := WithRetryInitialDelay(10 * time.Second)
|
||||
err := applyConnectionOptions(conf, opt)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10*time.Second, conf.Retry.InitialDelay)
|
||||
|
||||
// zero disallowed
|
||||
opt = WithRetryInitialDelay(0 * time.Second)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, InvalidRetryInitialDelay)
|
||||
assert.ErrorContains(t, err, "initial delay must be positive")
|
||||
|
||||
// negative disallowed
|
||||
opt = WithRetryInitialDelay(-10 * time.Second)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, InvalidRetryInitialDelay)
|
||||
})
|
||||
|
||||
t.Run("with max delay", func(t *testing.T) {
|
||||
conf := &ConnectionConfig{}
|
||||
opt := WithRetryMaxDelay(10 * time.Second)
|
||||
err := applyConnectionOptions(conf, opt)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10*time.Second, conf.Retry.MaxDelay)
|
||||
|
||||
// zero disallowed
|
||||
opt = WithRetryMaxDelay(0 * time.Second)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, InvalidRetryMaxDelay)
|
||||
assert.ErrorContains(t, err, "max delay must be positive")
|
||||
|
||||
// negative disallowed
|
||||
opt = WithRetryMaxDelay(-10 * time.Second)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, InvalidRetryMaxDelay)
|
||||
})
|
||||
|
||||
t.Run("with jitter factor", func(t *testing.T) {
|
||||
conf := &ConnectionConfig{}
|
||||
|
||||
opt := WithRetryJitterFactor(0.2)
|
||||
err := applyConnectionOptions(conf, opt)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0.2, conf.Retry.JitterFactor)
|
||||
|
||||
// negative disallowed
|
||||
opt = WithRetryJitterFactor(-1)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, InvalidRetryJitterFactor)
|
||||
assert.ErrorContains(t, err, "jitter factor must be between 0.0 and 1.0")
|
||||
|
||||
// >1 disallowed
|
||||
opt = WithRetryJitterFactor(1.1)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, InvalidRetryJitterFactor)
|
||||
})
|
||||
}
|
||||
|
||||
// Validation Tests
|
||||
|
||||
func TestValidateConnectionConfig(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
conf ConnectionConfig
|
||||
wantErr error
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "valid empty",
|
||||
conf: *&ConnectionConfig{},
|
||||
},
|
||||
{
|
||||
name: "valid defaults",
|
||||
conf: *GetDefaultConnectionConfig(),
|
||||
},
|
||||
{
|
||||
name: "valid complete",
|
||||
conf: ConnectionConfig{
|
||||
CloseHandler: (func(code int, text string) error { return nil }),
|
||||
WriteTimeout: time.Duration(30),
|
||||
Retry: &RetryConfig{
|
||||
MaxRetries: 0,
|
||||
InitialDelay: 2 * time.Second,
|
||||
MaxDelay: 10 * time.Second,
|
||||
JitterFactor: 0.2,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid - initial delay > max delay",
|
||||
conf: ConnectionConfig{
|
||||
Retry: &RetryConfig{
|
||||
InitialDelay: 10 * time.Second,
|
||||
MaxDelay: 1 * time.Second,
|
||||
},
|
||||
},
|
||||
wantErrText: "initial delay may not exceed maximum delay",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateConnectionConfig(&tc.conf)
|
||||
|
||||
if tc.wantErr != nil || tc.wantErrText != "" {
|
||||
if tc.wantErr != nil {
|
||||
assert.ErrorIs(t, err, tc.wantErr)
|
||||
}
|
||||
|
||||
if tc.wantErrText != "" {
|
||||
assert.ErrorContains(t, err, tc.wantErrText)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
343
transport/connection.go
Normal file
343
transport/connection.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type ConnectionState int
|
||||
|
||||
const (
|
||||
StateDisconnected ConnectionState = iota
|
||||
StateConnecting
|
||||
StateConnected
|
||||
StateClosed
|
||||
)
|
||||
|
||||
func (s ConnectionState) String() string {
|
||||
switch s {
|
||||
case StateDisconnected:
|
||||
return "disconnected"
|
||||
case StateConnecting:
|
||||
return "connecting"
|
||||
case StateConnected:
|
||||
return "connected"
|
||||
case StateClosed:
|
||||
return "closed"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
url *url.URL
|
||||
dialer types.Dialer
|
||||
socket types.Socket
|
||||
config *ConnectionConfig
|
||||
logger *slog.Logger
|
||||
|
||||
incoming chan []byte
|
||||
outgoing chan []byte
|
||||
errors chan error
|
||||
done chan struct{}
|
||||
|
||||
state ConnectionState
|
||||
|
||||
wg sync.WaitGroup
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
|
||||
if config == nil {
|
||||
config = GetDefaultConnectionConfig()
|
||||
}
|
||||
|
||||
if err := ValidateConnectionConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedURL, err := ParseURL(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := &Connection{
|
||||
url: parsedURL,
|
||||
dialer: NewDialer(),
|
||||
socket: nil,
|
||||
config: config,
|
||||
logger: logger,
|
||||
incoming: make(chan []byte, 100),
|
||||
outgoing: make(chan []byte, 100),
|
||||
errors: make(chan error, 10),
|
||||
state: StateDisconnected,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func NewConnectionFromSocket(socket types.Socket, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
|
||||
if socket == nil {
|
||||
return nil, NewConnectionError("socket cannot be nil")
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = GetDefaultConnectionConfig()
|
||||
}
|
||||
|
||||
if err := ValidateConnectionConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := &Connection{
|
||||
url: nil,
|
||||
dialer: nil,
|
||||
socket: socket,
|
||||
config: config,
|
||||
logger: logger,
|
||||
incoming: make(chan []byte, 100),
|
||||
outgoing: make(chan []byte, 100),
|
||||
errors: make(chan error, 10),
|
||||
state: StateConnected,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if config.CloseHandler != nil {
|
||||
socket.SetCloseHandler(config.CloseHandler)
|
||||
}
|
||||
|
||||
conn.startReader()
|
||||
conn.startWriter()
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Connect() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.socket != nil {
|
||||
return NewConnectionError("connection already has socket")
|
||||
}
|
||||
|
||||
if c.closed {
|
||||
return NewConnectionError("connection is closed")
|
||||
}
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("connecting")
|
||||
}
|
||||
|
||||
c.state = StateConnecting
|
||||
|
||||
retryMgr := NewRetryManager(c.config.Retry)
|
||||
socket, _, err := AcquireSocket(retryMgr, c.dialer, c.url.String(), c.logger)
|
||||
|
||||
if err != nil {
|
||||
c.state = StateDisconnected
|
||||
if c.logger != nil {
|
||||
c.logger.Error("connection failed", "error", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c.socket = socket
|
||||
c.state = StateConnected
|
||||
|
||||
if c.config.CloseHandler != nil {
|
||||
c.socket.SetCloseHandler(c.config.CloseHandler)
|
||||
}
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("connected")
|
||||
}
|
||||
|
||||
c.startReader()
|
||||
c.startWriter()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) startReader() {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
|
||||
for {
|
||||
messageType, data, err := c.socket.ReadMessage()
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
var closeErr *websocket.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
switch closeErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||
c.logger.Info("connection closed by peer",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
default:
|
||||
c.logger.Error("unexpected close",
|
||||
"code", closeErr.Code,
|
||||
"text", closeErr.Text,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
c.logger.Error("read error", "error", err)
|
||||
}
|
||||
}
|
||||
select {
|
||||
case c.errors <- err:
|
||||
case <-c.done:
|
||||
}
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
|
||||
if messageType == websocket.TextMessage ||
|
||||
messageType == websocket.BinaryMessage {
|
||||
select {
|
||||
case c.incoming <- data:
|
||||
case <-c.done:
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) startWriter() {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case data := <-c.outgoing:
|
||||
if c.config.WriteTimeout > 0 {
|
||||
if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("write deadline error", "error", err)
|
||||
}
|
||||
select {
|
||||
case c.errors <- fmt.Errorf("failed to set write deadline: %w", err):
|
||||
case <-c.done:
|
||||
}
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("write error", "error", err)
|
||||
}
|
||||
select {
|
||||
case c.errors <- err:
|
||||
case <-c.done:
|
||||
}
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) Send(data []byte) error {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.closed {
|
||||
return NewConnectionError("connection closed")
|
||||
}
|
||||
|
||||
select {
|
||||
case c.outgoing <- data:
|
||||
return nil
|
||||
case <-c.done:
|
||||
return NewConnectionError("connection closing")
|
||||
default:
|
||||
return NewConnectionError("outgoing queue full")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) Incoming() <-chan []byte {
|
||||
return c.incoming
|
||||
}
|
||||
|
||||
func (c *Connection) Errors() <-chan error {
|
||||
return c.errors
|
||||
}
|
||||
|
||||
func (c *Connection) shutdown() {
|
||||
c.mu.Lock()
|
||||
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closing", "state", c.state.String())
|
||||
}
|
||||
c.closed = true
|
||||
c.state = StateClosed
|
||||
socket := c.socket
|
||||
close(c.done)
|
||||
c.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
if socket != nil {
|
||||
// force immediate timeout of any blocked network I/O
|
||||
expired := time.Now().Add(-1 * time.Minute)
|
||||
socket.SetReadDeadline(expired)
|
||||
socket.SetWriteDeadline(expired)
|
||||
err := socket.Close()
|
||||
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("socket close failed", "error", err)
|
||||
}
|
||||
} else {
|
||||
if c.logger != nil {
|
||||
c.logger.Info("closed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.wg.Wait()
|
||||
close(c.incoming)
|
||||
close(c.outgoing)
|
||||
close(c.errors)
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) Close() {
|
||||
c.shutdown()
|
||||
}
|
||||
|
||||
func (c *Connection) State() ConnectionState {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.state
|
||||
}
|
||||
|
||||
func (c *Connection) SetDialer(d types.Dialer) {
|
||||
c.dialer = d
|
||||
}
|
||||
134
transport/connection_close_test.go
Normal file
134
transport/connection_close_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
t.Run("close succeeds on disconnected connection", func(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, StateDisconnected, conn.State())
|
||||
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
t.Run("close is idempotent", func(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
t.Run("close with nil socket", func(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, conn.socket)
|
||||
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
t.Run("socket close error does not propagate", func(t *testing.T) {
|
||||
expectedErr := fmt.Errorf("socket close failed")
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.CloseFunc = func() error {
|
||||
return expectedErr
|
||||
}
|
||||
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
conn.socket = mockSocket
|
||||
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
t.Run("channels close after close", func(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case _, ok := <-conn.Errors():
|
||||
return !ok
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
|
||||
"errors channel should close")
|
||||
})
|
||||
|
||||
t.Run("send fails after close", func(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
|
||||
err = conn.Send([]byte("test"))
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "connection closed")
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestConnectedConnectionClose(t *testing.T) {
|
||||
t.Run("blocked on ReadMessage, unblocks on closed", func(t *testing.T) {
|
||||
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||
|
||||
// Send a message to ensure reader loop is blocking
|
||||
canary := []byte("canary")
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.TextMessage, Data: canary}
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case msg := <-conn.Incoming():
|
||||
return bytes.Equal(msg, canary)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
t.Run("writer active during close exits cleanly", func(t *testing.T) {
|
||||
conn, _, _, _ := setupTestConnection(t, nil)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
conn.Send([]byte("message"))
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
|
||||
err := conn.Send([]byte("late"))
|
||||
assert.Error(t, err, "Send should fail after close")
|
||||
assert.ErrorContains(t, err, "connection closed")
|
||||
})
|
||||
|
||||
t.Run("both goroutines active during close", func(t *testing.T) {
|
||||
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.TextMessage,
|
||||
Data: []byte(fmt.Sprintf("in-%d", i)),
|
||||
}
|
||||
conn.Send([]byte(fmt.Sprintf("out-%d", i)))
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
})
|
||||
}
|
||||
285
transport/connection_goroutine_test.go
Normal file
285
transport/connection_goroutine_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStartReader(t *testing.T) {
|
||||
t.Run("text messages route to incoming channel", func(t *testing.T) {
|
||||
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||
defer conn.Close()
|
||||
|
||||
testData := []byte("hello")
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.TextMessage,
|
||||
Data: testData,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
honeybeetest.ExpectIncoming(t, conn.Incoming(), testData)
|
||||
})
|
||||
|
||||
t.Run("binary messages route to incoming channel", func(t *testing.T) {
|
||||
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||
defer conn.Close()
|
||||
|
||||
testData := []byte{0x00, 0x01, 0x02}
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.BinaryMessage,
|
||||
Data: testData,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
honeybeetest.ExpectIncoming(t, conn.Incoming(), testData)
|
||||
})
|
||||
|
||||
t.Run("multiple messages processed sequentially", func(t *testing.T) {
|
||||
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||
defer conn.Close()
|
||||
|
||||
messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")}
|
||||
for _, msg := range messages {
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.TextMessage, Data: msg, Err: nil}
|
||||
}
|
||||
|
||||
for _, expected := range messages {
|
||||
honeybeetest.ExpectIncoming(t, conn.Incoming(), expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reader exits on socket read error", func(t *testing.T) {
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.Once.Do(func() {
|
||||
close(mockSocket.Closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
readErr := fmt.Errorf("read failed")
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
return 0, nil, readErr
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case err := <-conn.Errors():
|
||||
return err == readErr
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return conn.State() == StateClosed
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStartWriter(t *testing.T) {
|
||||
t.Run("data from outgoing triggers write", func(t *testing.T) {
|
||||
conn, _, _, outgoingData := setupTestConnection(t, nil)
|
||||
defer conn.Close()
|
||||
|
||||
testData := []byte("test message")
|
||||
err := conn.Send(testData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, testData)
|
||||
})
|
||||
|
||||
t.Run("multiple messages processed sequentially", func(t *testing.T) {
|
||||
conn, _, _, outgoingData := setupTestConnection(t, nil)
|
||||
defer conn.Close()
|
||||
|
||||
messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")}
|
||||
for _, msg := range messages {
|
||||
err := conn.Send(msg)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for _, expected := range messages {
|
||||
honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("write timeout disabled when zero", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode")
|
||||
}
|
||||
|
||||
config := &ConnectionConfig{WriteTimeout: 0}
|
||||
|
||||
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.Once.Do(func() {
|
||||
close(mockSocket.Closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
deadlineCalled := make(chan struct{}, 1)
|
||||
mockSocket.SetWriteDeadlineFunc = func(t time.Time) error {
|
||||
deadlineCalled <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
select {
|
||||
case outgoingData <- honeybeetest.MockOutgoingData{
|
||||
MsgType: msgType, Data: data}:
|
||||
case <-mockSocket.Closed:
|
||||
return io.EOF
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Send([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Never(t, func() bool {
|
||||
select {
|
||||
case <-deadlineCalled:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, honeybeetest.NegativeTestTimeout, honeybeetest.TestTick,
|
||||
"SetWriteDeadline should not be called when timeout is zero")
|
||||
})
|
||||
|
||||
t.Run("write timeout sets deadline when positive", func(t *testing.T) {
|
||||
config := &ConnectionConfig{WriteTimeout: 30 * time.Millisecond}
|
||||
|
||||
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.Once.Do(func() {
|
||||
close(mockSocket.Closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
deadlineCalled := make(chan struct{}, 1)
|
||||
mockSocket.SetWriteDeadlineFunc = func(t time.Time) error {
|
||||
deadlineCalled <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
select {
|
||||
case outgoingData <- honeybeetest.MockOutgoingData{
|
||||
MsgType: msgType, Data: data}:
|
||||
case <-mockSocket.Closed:
|
||||
return io.EOF
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Send([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-deadlineCalled:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
|
||||
"SetWriteDeadline should be called when timeout is positive")
|
||||
})
|
||||
|
||||
t.Run("writer exits on deadline error", func(t *testing.T) {
|
||||
config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond}
|
||||
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.Once.Do(func() {
|
||||
close(mockSocket.Closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
mockSocket.SetWriteDeadlineFunc = func(t time.Time) error {
|
||||
return fmt.Errorf("test error")
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = conn.Send([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case err := <-conn.Errors():
|
||||
return err != nil &&
|
||||
strings.Contains(err.Error(), "failed to set write deadline")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return conn.State() == StateClosed
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
|
||||
t.Run("writer exits on socket write error", func(t *testing.T) {
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
|
||||
writeErr := fmt.Errorf("write failed")
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
return writeErr
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Send([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case err := <-conn.Errors():
|
||||
return err == writeErr
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return conn.State() == StateClosed
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
}
|
||||
|
||||
// Helpers
|
||||
111
transport/connection_send_test.go
Normal file
111
transport/connection_send_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConnectionSend(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
setup func(*Connection)
|
||||
data []byte
|
||||
wantErr bool
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "send succeeds when open",
|
||||
setup: func(c *Connection) {},
|
||||
data: []byte("test message"),
|
||||
},
|
||||
{
|
||||
name: "send fails when closed",
|
||||
setup: func(c *Connection) {
|
||||
c.Close()
|
||||
},
|
||||
data: []byte("test"),
|
||||
wantErr: true,
|
||||
wantErrText: "connection closed",
|
||||
},
|
||||
{
|
||||
name: "send fails when queue full",
|
||||
setup: func(c *Connection) {
|
||||
// Fill outgoing channel
|
||||
for i := 0; i < 100; i++ {
|
||||
c.outgoing <- []byte("filler")
|
||||
}
|
||||
},
|
||||
data: []byte("overflow"),
|
||||
wantErr: true,
|
||||
wantErrText: "outgoing queue full",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tc.setup(conn)
|
||||
|
||||
err = conn.Send(tc.data)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tc.wantErrText != "" {
|
||||
assert.ErrorContains(t, err, tc.wantErrText)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case sent := <-conn.outgoing:
|
||||
assert.Equal(t, tc.data, sent)
|
||||
default:
|
||||
t.Fatal("data not sent to outgoing channel")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Run with `go test -race` to ensure no race conditions occur
|
||||
func TestConnectionSendConcurrent(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// continuously consume outgoing channel in background
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-conn.outgoing:
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
// Send from multiple goroutines concurrently
|
||||
const goroutines = 5
|
||||
const messagesPerGoroutine = 10
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < messagesPerGoroutine; j++ {
|
||||
data := []byte(fmt.Sprintf("msg-%d-%d", id, j))
|
||||
err := conn.Send(data)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
486
transport/connection_test.go
Normal file
486
transport/connection_test.go
Normal file
@@ -0,0 +1,486 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Connection state tests
|
||||
|
||||
func TestConnectionStateString(t *testing.T) {
|
||||
cases := []struct {
|
||||
state ConnectionState
|
||||
want string
|
||||
}{
|
||||
{StateDisconnected, "disconnected"},
|
||||
{StateConnecting, "connecting"},
|
||||
{StateConnected, "connected"},
|
||||
{StateClosed, "closed"},
|
||||
{ConnectionState(99), "unknown"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.want, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, tc.state.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionState(t *testing.T) {
|
||||
// Test initial state
|
||||
conn, _ := NewConnection("ws://test", nil, nil)
|
||||
assert.Equal(t, StateDisconnected, conn.State())
|
||||
|
||||
// Test state after FromSocket (should be Connected)
|
||||
conn2, _ := NewConnectionFromSocket(honeybeetest.NewMockSocket(), nil, nil)
|
||||
assert.Equal(t, StateConnected, conn2.State())
|
||||
|
||||
// Test state after close
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
}
|
||||
|
||||
// Connection constructor tests
|
||||
|
||||
func TestNewConnection(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
config *ConnectionConfig
|
||||
wantErr bool
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "valid url, nil config",
|
||||
url: "ws://example.com",
|
||||
config: nil,
|
||||
},
|
||||
{
|
||||
name: "valid url, valid config",
|
||||
url: "wss://relay.example.com:8080/path",
|
||||
config: &ConnectionConfig{WriteTimeout: 30 * time.Second},
|
||||
},
|
||||
{
|
||||
name: "invalid url",
|
||||
url: "http://example.com",
|
||||
config: nil,
|
||||
wantErr: true,
|
||||
wantErrText: "URL must use ws:// or wss:// scheme",
|
||||
},
|
||||
{
|
||||
name: "invalid config",
|
||||
url: "ws://example.com",
|
||||
config: &ConnectionConfig{
|
||||
Retry: &RetryConfig{
|
||||
InitialDelay: 10 * time.Second,
|
||||
MaxDelay: 1 * time.Second,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrText: "initial delay may not exceed maximum delay",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
conn, err := NewConnection(tc.url, tc.config, nil)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tc.wantErrText != "" {
|
||||
assert.ErrorContains(t, err, tc.wantErrText)
|
||||
}
|
||||
assert.Nil(t, conn)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, conn)
|
||||
|
||||
// Verify struct fields
|
||||
assert.NotNil(t, conn.url)
|
||||
assert.NotNil(t, conn.dialer)
|
||||
assert.Nil(t, conn.socket)
|
||||
assert.NotNil(t, conn.config)
|
||||
assert.NotNil(t, conn.incoming)
|
||||
assert.NotNil(t, conn.outgoing)
|
||||
assert.NotNil(t, conn.errors)
|
||||
assert.NotNil(t, conn.done)
|
||||
assert.Equal(t, StateDisconnected, conn.state)
|
||||
assert.False(t, conn.closed)
|
||||
|
||||
// Verify default config is used if nil is passed
|
||||
if tc.config == nil {
|
||||
assert.Equal(t, GetDefaultConnectionConfig(), conn.config)
|
||||
} else {
|
||||
assert.Equal(t, tc.config, conn.config)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromSocket(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
socket types.Socket
|
||||
config *ConnectionConfig
|
||||
wantErr bool
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "nil socket",
|
||||
socket: nil,
|
||||
config: nil,
|
||||
wantErr: true,
|
||||
wantErrText: "socket cannot be nil",
|
||||
},
|
||||
{
|
||||
name: "valid socket with nil config",
|
||||
socket: honeybeetest.NewMockSocket(),
|
||||
config: nil,
|
||||
},
|
||||
{
|
||||
name: "valid socket with valid config",
|
||||
socket: honeybeetest.NewMockSocket(),
|
||||
config: &ConnectionConfig{WriteTimeout: 30 * time.Second},
|
||||
},
|
||||
{
|
||||
name: "invalid config",
|
||||
socket: honeybeetest.NewMockSocket(),
|
||||
config: &ConnectionConfig{
|
||||
Retry: &RetryConfig{
|
||||
InitialDelay: 10 * time.Second,
|
||||
MaxDelay: 1 * time.Second,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrText: "initial delay may not exceed maximum delay",
|
||||
},
|
||||
{
|
||||
name: "close handler set when provided",
|
||||
socket: honeybeetest.NewMockSocket(),
|
||||
config: &ConnectionConfig{
|
||||
CloseHandler: func(code int, text string) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// track if SetCloseHandler was called
|
||||
closeHandlerSet := false
|
||||
if tc.socket != nil {
|
||||
mockSocket := tc.socket.(*honeybeetest.MockSocket)
|
||||
originalSetCloseHandler := mockSocket.SetCloseHandlerFunc
|
||||
|
||||
// wrapper around the original handler function
|
||||
mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) {
|
||||
closeHandlerSet = true
|
||||
if originalSetCloseHandler != nil {
|
||||
originalSetCloseHandler(h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(tc.socket, tc.config, nil)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tc.wantErrText != "" {
|
||||
assert.ErrorContains(t, err, tc.wantErrText)
|
||||
}
|
||||
assert.Nil(t, conn)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, conn)
|
||||
|
||||
// Verify fields initialized correctly
|
||||
assert.Nil(t, conn.url)
|
||||
assert.Nil(t, conn.dialer)
|
||||
assert.Equal(t, tc.socket, conn.socket)
|
||||
assert.NotNil(t, conn.config)
|
||||
assert.NotNil(t, conn.incoming)
|
||||
assert.NotNil(t, conn.outgoing)
|
||||
assert.NotNil(t, conn.errors)
|
||||
assert.NotNil(t, conn.done)
|
||||
assert.Equal(t, StateConnected, conn.state)
|
||||
assert.False(t, conn.closed)
|
||||
|
||||
// Verify default config is used if nil is passed
|
||||
if tc.config == nil {
|
||||
assert.Equal(t, GetDefaultConnectionConfig(), conn.config)
|
||||
} else {
|
||||
assert.Equal(t, tc.config, conn.config)
|
||||
}
|
||||
|
||||
// Verify close handler was set if provided
|
||||
if tc.config != nil && tc.config.CloseHandler != nil {
|
||||
assert.True(t, closeHandlerSet, "CloseHandler should be set on socket")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
t.Run("connect fails when socket already present", func(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn.socket = honeybeetest.NewMockSocket()
|
||||
|
||||
err = conn.Connect()
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "already has socket")
|
||||
assert.Equal(t, StateDisconnected, conn.State())
|
||||
})
|
||||
|
||||
t.Run("connect fails when connection closed", func(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
|
||||
err = conn.Connect()
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "connection is closed")
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
})
|
||||
|
||||
t.Run("connect succeeds and starts goroutines", func(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
|
||||
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}
|
||||
return nil
|
||||
}
|
||||
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
|
||||
err = conn.Connect()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, StateConnected, conn.State())
|
||||
|
||||
testData := []byte("test")
|
||||
conn.Send(testData)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case msg := <-outgoingData:
|
||||
return bytes.Equal(msg.Data, testData)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
conn.Close()
|
||||
})
|
||||
|
||||
t.Run("connect retries on dial failure", func(t *testing.T) {
|
||||
config := &ConnectionConfig{
|
||||
Retry: &RetryConfig{
|
||||
MaxRetries: 2,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Millisecond,
|
||||
JitterFactor: 0.0,
|
||||
},
|
||||
}
|
||||
conn, err := NewConnection("ws://test", config, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
attemptCount := 0
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
attemptCount++
|
||||
if attemptCount < 3 {
|
||||
return nil, nil, fmt.Errorf("dial failed")
|
||||
}
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
|
||||
err = conn.Connect()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, attemptCount)
|
||||
assert.Equal(t, StateConnected, conn.State())
|
||||
|
||||
conn.Close()
|
||||
})
|
||||
|
||||
t.Run("connect fails after max retries", func(t *testing.T) {
|
||||
config := &ConnectionConfig{
|
||||
Retry: &RetryConfig{
|
||||
MaxRetries: 2,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Millisecond,
|
||||
JitterFactor: 0.0,
|
||||
},
|
||||
}
|
||||
conn, err := NewConnection("ws://test", config, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return nil, nil, fmt.Errorf("dial failed")
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
|
||||
err = conn.Connect()
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "dial failed")
|
||||
assert.Equal(t, StateDisconnected, conn.State())
|
||||
})
|
||||
|
||||
t.Run("state transitions during connect", func(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, StateDisconnected, conn.State())
|
||||
|
||||
stateDuringDial := StateDisconnected
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
stateDuringDial = conn.state
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
|
||||
conn.Connect()
|
||||
|
||||
assert.Equal(t, StateConnecting, stateDuringDial)
|
||||
assert.Equal(t, StateConnected, conn.State())
|
||||
|
||||
conn.Close()
|
||||
})
|
||||
|
||||
t.Run("close handler configured when provided", func(t *testing.T) {
|
||||
handlerSet := false
|
||||
config := &ConnectionConfig{
|
||||
CloseHandler: func(code int, text string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
conn, err := NewConnection("ws://test", config, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) {
|
||||
handlerSet = true
|
||||
}
|
||||
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
|
||||
conn.Connect()
|
||||
|
||||
assert.True(t, handlerSet, "close handler should be set on socket")
|
||||
|
||||
conn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
// Connection method tests
|
||||
|
||||
func TestConnectionIncoming(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
incoming := conn.Incoming()
|
||||
assert.NotNil(t, incoming)
|
||||
|
||||
// send data through the channel to verify they are the same
|
||||
testData := []byte("test")
|
||||
conn.incoming <- testData
|
||||
received := <-incoming
|
||||
assert.Equal(t, testData, received)
|
||||
}
|
||||
|
||||
func TestConnectionErrors(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
errors := conn.Errors()
|
||||
assert.NotNil(t, errors)
|
||||
|
||||
// send data through the channel to verify they are the same
|
||||
testErr := fmt.Errorf("test error")
|
||||
conn.errors <- testErr
|
||||
received := <-errors
|
||||
assert.Equal(t, testErr, received)
|
||||
}
|
||||
|
||||
// Test helpers
|
||||
|
||||
func setupTestConnection(t *testing.T, config *ConnectionConfig) (
|
||||
conn *Connection,
|
||||
mockSocket *honeybeetest.MockSocket,
|
||||
incomingData chan honeybeetest.MockIncomingData,
|
||||
outgoingData chan honeybeetest.MockOutgoingData,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
incomingData = make(chan honeybeetest.MockIncomingData, 10)
|
||||
outgoingData = make(chan honeybeetest.MockOutgoingData, 10)
|
||||
|
||||
mockSocket = honeybeetest.NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.Once.Do(func() {
|
||||
close(mockSocket.Closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wire ReadMessage to pull from incomingData channel
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
select {
|
||||
case data := <-incomingData:
|
||||
return data.MsgType, data.Data, data.Err
|
||||
case <-mockSocket.Closed:
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// Wire WriteMessage to push to outgoingData channel
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
select {
|
||||
case outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}:
|
||||
return nil
|
||||
case <-mockSocket.Closed:
|
||||
return io.EOF
|
||||
default:
|
||||
return fmt.Errorf("mock outgoing chanel unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
conn, err = NewConnectionFromSocket(mockSocket, config, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return conn, mockSocket, incomingData, outgoingData
|
||||
}
|
||||
24
transport/errors.go
Normal file
24
transport/errors.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package transport
|
||||
|
||||
import "errors"
|
||||
import "fmt"
|
||||
|
||||
var (
|
||||
// URL Errors
|
||||
InvalidProtocol = errors.New("URL must use ws:// or wss:// scheme")
|
||||
|
||||
// Configuration Errors
|
||||
InvalidWriteTimeout = errors.New("write timeout cannot be negative")
|
||||
InvalidRetryMaxRetries = errors.New("max retry count cannot be negative")
|
||||
InvalidRetryInitialDelay = errors.New("initial delay must be positive")
|
||||
InvalidRetryMaxDelay = errors.New("max delay must be positive")
|
||||
InvalidRetryJitterFactor = errors.New("jitter factor must be between 0.0 and 1.0")
|
||||
)
|
||||
|
||||
func NewConfigError(text string) error {
|
||||
return fmt.Errorf("configuration error: %s", text)
|
||||
}
|
||||
|
||||
func NewConnectionError(text string) error {
|
||||
return fmt.Errorf("connection error: %s", text)
|
||||
}
|
||||
485
transport/logging_test.go
Normal file
485
transport/logging_test.go
Normal file
@@ -0,0 +1,485 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Helpers
|
||||
|
||||
type expectedLog struct {
|
||||
level slog.Level
|
||||
msg string
|
||||
attrs map[string]any
|
||||
}
|
||||
|
||||
func assertLogSequence(t *testing.T, records []slog.Record, expected []expectedLog) {
|
||||
t.Helper()
|
||||
|
||||
recIndex := 0
|
||||
for expIndex, exp := range expected {
|
||||
found := false
|
||||
|
||||
// Search forward through records
|
||||
for recIndex < len(records) {
|
||||
rec := records[recIndex]
|
||||
|
||||
if rec.Level == exp.level && strings.Contains(rec.Message, exp.msg) {
|
||||
allAttrsMatch := true
|
||||
for key, expectedValue := range exp.attrs {
|
||||
if !assertAttributePresent(t, rec, key, expectedValue) {
|
||||
allAttrsMatch = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allAttrsMatch {
|
||||
found = true
|
||||
recIndex++ // Consume this record
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
recIndex++ // Move to next record
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Fatalf(
|
||||
"expected log not found: index=%d level=%v msg=%q attrs=%v",
|
||||
expIndex, exp.level, exp.msg, exp.attrs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findLogRecord(
|
||||
records []slog.Record, level slog.Level, msgSnippet string,
|
||||
) *slog.Record {
|
||||
for i := range records {
|
||||
if records[i].Level == level && strings.Contains(records[i].Message, msgSnippet) {
|
||||
return &records[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertAttributePresent(
|
||||
t *testing.T, record slog.Record, key string, expectedValue any,
|
||||
) bool {
|
||||
t.Helper()
|
||||
|
||||
var found bool
|
||||
var actualValue any
|
||||
|
||||
record.Attrs(func(attr slog.Attr) bool {
|
||||
if attr.Key == key {
|
||||
found = true
|
||||
actualValue = attr.Value.Any()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !found {
|
||||
t.Fatalf("attribute %q not found in log record", key)
|
||||
}
|
||||
|
||||
if !valuesEqual(actualValue, expectedValue) {
|
||||
t.Errorf("attribute %q mismatch: expected=%v actual=%v", key, expectedValue, actualValue)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func valuesEqual(a, b any) bool {
|
||||
// Direct equality
|
||||
if a == b {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle int/int64 conversions
|
||||
aInt, aIsInt := toInt64(a)
|
||||
bInt, bIsInt := toInt64(b)
|
||||
if aIsInt && bIsInt {
|
||||
return aInt == bInt
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func toInt64(v any) (int64, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return int64(val), true
|
||||
case int64:
|
||||
return val, true
|
||||
case int32:
|
||||
return int64(val), true
|
||||
case int16:
|
||||
return int64(val), true
|
||||
case int8:
|
||||
return int64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// Tests
|
||||
|
||||
func TestConnectLogging(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
conn, err := NewConnection("ws://test", nil, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
|
||||
err = conn.Connect()
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
expected := []expectedLog{
|
||||
{slog.LevelInfo, "connecting", map[string]any{}},
|
||||
{slog.LevelInfo, "dialing", map[string]any{"attempt": 1}},
|
||||
{slog.LevelInfo, "dial successful", map[string]any{"attempt": 1}},
|
||||
{slog.LevelInfo, "connected", map[string]any{}},
|
||||
}
|
||||
|
||||
assertLogSequence(t, records, expected)
|
||||
})
|
||||
|
||||
t.Run("max retries failure", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
config := &ConnectionConfig{
|
||||
Retry: &RetryConfig{
|
||||
MaxRetries: 2,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Millisecond,
|
||||
JitterFactor: 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
conn, err := NewConnection("ws://test", config, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dialErr := fmt.Errorf("dial error")
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return nil, nil, dialErr
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
|
||||
err = conn.Connect()
|
||||
assert.Error(t, err)
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
expected := []expectedLog{
|
||||
{slog.LevelInfo, "connecting", map[string]any{}},
|
||||
{slog.LevelInfo, "dialing", map[string]any{"attempt": 1}},
|
||||
{slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}},
|
||||
{slog.LevelInfo, "dialing", map[string]any{"attempt": 2}},
|
||||
{slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}},
|
||||
{slog.LevelInfo, "dialing", map[string]any{"attempt": 3}},
|
||||
{slog.LevelError, "dial failed, max retries reached", map[string]any{"attempt": 3, "error": dialErr}},
|
||||
{slog.LevelError, "connection failed", map[string]any{"error": dialErr}},
|
||||
}
|
||||
|
||||
assertLogSequence(t, records, expected)
|
||||
})
|
||||
|
||||
t.Run("success after retry", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
config := &ConnectionConfig{
|
||||
Retry: &RetryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Millisecond,
|
||||
JitterFactor: 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
conn, err := NewConnection("ws://test", config, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
attemptCount := 0
|
||||
dialErr := fmt.Errorf("dial error")
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
attemptCount++
|
||||
if attemptCount < 3 {
|
||||
return nil, nil, dialErr
|
||||
}
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
|
||||
err = conn.Connect()
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
expected := []expectedLog{
|
||||
{slog.LevelInfo, "connecting", map[string]any{}},
|
||||
{slog.LevelInfo, "dialing", map[string]any{"attempt": 1}},
|
||||
{slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}},
|
||||
{slog.LevelInfo, "dialing", map[string]any{"attempt": 2}},
|
||||
{slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}},
|
||||
{slog.LevelInfo, "dialing", map[string]any{"attempt": 3}},
|
||||
{slog.LevelInfo, "dial successful", map[string]any{"attempt": 3}},
|
||||
{slog.LevelInfo, "connected", map[string]any{}},
|
||||
}
|
||||
|
||||
assertLogSequence(t, records, expected)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseLogging(t *testing.T) {
|
||||
t.Run("normal close", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelInfo, "closed") != nil
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
expected := []expectedLog{
|
||||
{slog.LevelInfo, "closing", map[string]any{"state": "connected"}},
|
||||
{slog.LevelInfo, "closed", map[string]any{}},
|
||||
}
|
||||
|
||||
assertLogSequence(t, records, expected)
|
||||
})
|
||||
|
||||
t.Run("close with socket error", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
closeErr := fmt.Errorf("close error")
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.CloseFunc = func() error {
|
||||
return closeErr
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelError, "socket close failed") != nil
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
expected := []expectedLog{
|
||||
{slog.LevelInfo, "closing", map[string]any{"state": "connected"}},
|
||||
{slog.LevelError, "socket close failed", map[string]any{"error": closeErr}},
|
||||
}
|
||||
|
||||
assertLogSequence(t, records, expected)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReaderLogging(t *testing.T) {
|
||||
t.Run("clean close by peer", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
return 0, nil, &websocket.CloseError{
|
||||
Code: websocket.CloseNormalClosure,
|
||||
Text: "goodbye",
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer") != nil
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
record := findLogRecord(mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer")
|
||||
assert.NotNil(t, record)
|
||||
assertAttributePresent(t, *record, "code", websocket.CloseNormalClosure)
|
||||
assertAttributePresent(t, *record, "text", "goodbye")
|
||||
|
||||
})
|
||||
|
||||
t.Run("unexpected close", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
return 0, nil, &websocket.CloseError{
|
||||
Code: websocket.CloseProtocolError,
|
||||
Text: "bad protocol",
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelError, "unexpected close") != nil
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
record := findLogRecord(mockHandler.GetRecords(), slog.LevelError, "unexpected close")
|
||||
assert.NotNil(t, record)
|
||||
assertAttributePresent(t, *record, "code", websocket.CloseProtocolError)
|
||||
assertAttributePresent(t, *record, "text", "bad protocol")
|
||||
|
||||
})
|
||||
|
||||
t.Run("read error", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelError, "read error") != nil
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriterLogging(t *testing.T) {
|
||||
t.Run("write deadline error", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond}
|
||||
|
||||
deadlineErr := fmt.Errorf("deadline error")
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.SetWriteDeadlineFunc = func(time.Time) error {
|
||||
return deadlineErr
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, config, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = conn.Send([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelError, "write deadline error") != nil
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
record := findLogRecord(records, slog.LevelError, "write deadline error")
|
||||
assert.NotNil(t, record)
|
||||
assertAttributePresent(t, *record, "error", deadlineErr)
|
||||
|
||||
conn.Close()
|
||||
})
|
||||
|
||||
t.Run("write message error", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
writeErr := fmt.Errorf("write error")
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.WriteMessageFunc = func(int, []byte) error {
|
||||
return writeErr
|
||||
}
|
||||
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = conn.Send([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelError, "write error") != nil
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
record := findLogRecord(records, slog.LevelError, "write error")
|
||||
assert.NotNil(t, record)
|
||||
assertAttributePresent(t, *record, "error", writeErr)
|
||||
|
||||
conn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoggingDisabled(t *testing.T) {
|
||||
t.Run("nil logger produces no logs", func(t *testing.T) {
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
|
||||
err = conn.Connect()
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
assert.Empty(t, records)
|
||||
})
|
||||
}
|
||||
66
transport/retry.go
Normal file
66
transport/retry.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RetryManager struct {
|
||||
config *RetryConfig
|
||||
retryCount int
|
||||
}
|
||||
|
||||
func NewRetryManager(config *RetryConfig) *RetryManager {
|
||||
return &RetryManager{
|
||||
config: config,
|
||||
retryCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RetryManager) ShouldRetry() bool {
|
||||
if r.config == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if r.config.MaxRetries > 0 && r.retryCount >= r.config.MaxRetries {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *RetryManager) CalculateDelay() time.Duration {
|
||||
if r.config == nil {
|
||||
return time.Second
|
||||
}
|
||||
|
||||
// First attempt: immediate retry
|
||||
if r.retryCount == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Exponential backoff: InitialDelay * 2^(attempts-1)
|
||||
backoffMultiplier := math.Pow(2, float64(r.retryCount-1))
|
||||
baseDelay := float64(r.config.InitialDelay) * backoffMultiplier
|
||||
|
||||
// Apply jitter: delay * (1 + jitterFactor * (random - 0.5))
|
||||
random := rand.Float64()
|
||||
jitterMultiplier := 1 + r.config.JitterFactor*(random-0.5)
|
||||
delay := time.Duration(baseDelay * jitterMultiplier)
|
||||
|
||||
// Cap at MaxDelay
|
||||
if delay > r.config.MaxDelay {
|
||||
delay = r.config.MaxDelay
|
||||
}
|
||||
|
||||
return delay
|
||||
}
|
||||
|
||||
func (m *RetryManager) RecordRetry() {
|
||||
m.retryCount++
|
||||
}
|
||||
|
||||
func (m *RetryManager) RetryCount() int {
|
||||
return m.retryCount
|
||||
}
|
||||
147
transport/retry_test.go
Normal file
147
transport/retry_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewRetryManager(t *testing.T) {
|
||||
config := &RetryConfig{
|
||||
MaxRetries: 0,
|
||||
}
|
||||
|
||||
mgr := NewRetryManager(config)
|
||||
|
||||
assert.Equal(t, config, mgr.config)
|
||||
assert.Equal(t, 0, mgr.retryCount)
|
||||
|
||||
// Should accept nil config
|
||||
mgr = NewRetryManager(nil)
|
||||
assert.Nil(t, mgr.config)
|
||||
assert.Equal(t, 0, mgr.retryCount)
|
||||
}
|
||||
|
||||
func TestRecordRetry(t *testing.T) {
|
||||
mgr := NewRetryManager(nil)
|
||||
assert.Equal(t, mgr.retryCount, 0)
|
||||
|
||||
mgr.RecordRetry()
|
||||
assert.Equal(t, mgr.retryCount, 1)
|
||||
|
||||
mgr.RecordRetry()
|
||||
assert.Equal(t, mgr.retryCount, 2)
|
||||
}
|
||||
|
||||
func TestShouldRetry(t *testing.T) {
|
||||
// never retry if config is nil
|
||||
mgr := NewRetryManager(nil)
|
||||
assert.False(t, mgr.ShouldRetry())
|
||||
|
||||
// always retry if max attempt count is zero
|
||||
mgr = &RetryManager{
|
||||
config: &RetryConfig{
|
||||
MaxRetries: 0,
|
||||
},
|
||||
retryCount: 1000,
|
||||
}
|
||||
assert.True(t, mgr.ShouldRetry())
|
||||
|
||||
// retry if below max attempt count
|
||||
mgr = &RetryManager{
|
||||
config: &RetryConfig{
|
||||
MaxRetries: 10,
|
||||
},
|
||||
retryCount: 5,
|
||||
}
|
||||
assert.True(t, mgr.ShouldRetry())
|
||||
|
||||
// do not retry if above max attempt count
|
||||
mgr = &RetryManager{
|
||||
config: &RetryConfig{
|
||||
MaxRetries: 10,
|
||||
},
|
||||
retryCount: 11,
|
||||
}
|
||||
assert.False(t, mgr.ShouldRetry())
|
||||
}
|
||||
|
||||
func TestCalculateDelayDisabled(t *testing.T) {
|
||||
// default delay if retry is disabled
|
||||
mgr := NewRetryManager(nil)
|
||||
assert.Equal(t, time.Second, mgr.CalculateDelay())
|
||||
}
|
||||
|
||||
func TestCalculateDelayWithoutJitter(t *testing.T) {
|
||||
mgr := NewRetryManager(&RetryConfig{
|
||||
MaxRetries: 0,
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 5 * time.Second,
|
||||
JitterFactor: 0.0,
|
||||
})
|
||||
|
||||
// Retry 0: immediate
|
||||
assert.Equal(t, 0*time.Second, mgr.CalculateDelay())
|
||||
mgr.RecordRetry()
|
||||
|
||||
// Retry 1: 1s * 2^0 = 1s
|
||||
assert.Equal(t, 1*time.Second, mgr.CalculateDelay())
|
||||
mgr.RecordRetry()
|
||||
|
||||
// Retry 2: 1s * 2^1 = 2s
|
||||
assert.Equal(t, 2*time.Second, mgr.CalculateDelay())
|
||||
mgr.RecordRetry()
|
||||
|
||||
// Retry 3: 1s * 2^2 = 4s
|
||||
assert.Equal(t, 4*time.Second, mgr.CalculateDelay())
|
||||
mgr.RecordRetry()
|
||||
|
||||
// Retry 4: 1s * 2^3 = 8s, capped at 5s
|
||||
assert.Equal(t, 5*time.Second, mgr.CalculateDelay())
|
||||
mgr.RecordRetry()
|
||||
|
||||
// Retry 5: Still capped at 5s
|
||||
assert.Equal(t, 5*time.Second, mgr.CalculateDelay())
|
||||
}
|
||||
|
||||
func TestCalculateDelayWithJitter(t *testing.T) {
|
||||
mgr := NewRetryManager(&RetryConfig{
|
||||
MaxRetries: 0,
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 5 * time.Second,
|
||||
JitterFactor: 0.5,
|
||||
})
|
||||
|
||||
// Retry 0: immediate
|
||||
assert.Equal(t, 0*time.Second, mgr.CalculateDelay())
|
||||
mgr.RecordRetry()
|
||||
|
||||
// Retry 1: 1s * 2^0 = 1s (with jitter)
|
||||
delay := mgr.CalculateDelay()
|
||||
assert.GreaterOrEqual(t, delay, 750*time.Millisecond)
|
||||
assert.LessOrEqual(t, delay, 1250*time.Millisecond)
|
||||
mgr.RecordRetry()
|
||||
|
||||
// Retry 2: 1s * 2^1 = 2s (with jitter)
|
||||
delay = mgr.CalculateDelay()
|
||||
assert.GreaterOrEqual(t, delay, 1500*time.Millisecond)
|
||||
assert.LessOrEqual(t, delay, 2500*time.Millisecond)
|
||||
mgr.RecordRetry()
|
||||
|
||||
// Retry 3: 1s * 2^2 = 4s (with jitter)
|
||||
delay = mgr.CalculateDelay()
|
||||
assert.GreaterOrEqual(t, delay, 3*time.Second)
|
||||
assert.LessOrEqual(t, delay, 5*time.Second)
|
||||
mgr.RecordRetry()
|
||||
|
||||
// Retry 4: 1s * 2^3 = 8s, capped at 5s (with jitter)
|
||||
delay = mgr.CalculateDelay()
|
||||
assert.GreaterOrEqual(t, delay, 3750*time.Millisecond)
|
||||
assert.LessOrEqual(t, delay, 5*time.Second)
|
||||
mgr.RecordRetry()
|
||||
|
||||
// Retry 5: Still capped at 5s (with jitter)
|
||||
delay = mgr.CalculateDelay()
|
||||
assert.GreaterOrEqual(t, delay, 3750*time.Millisecond)
|
||||
assert.LessOrEqual(t, delay, 5*time.Second)
|
||||
}
|
||||
90
transport/socket.go
Normal file
90
transport/socket.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func NewDialer() types.Dialer {
|
||||
return NewGorillaDialer()
|
||||
}
|
||||
|
||||
type GorillaDialer struct {
|
||||
*websocket.Dialer
|
||||
}
|
||||
|
||||
func NewGorillaDialer() *GorillaDialer {
|
||||
return &GorillaDialer{
|
||||
Dialer: &websocket.Dialer{
|
||||
HandshakeTimeout: 45 * time.Second,
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the Socket interface
|
||||
func (d *GorillaDialer) Dial(
|
||||
urlStr string, requestHeader http.Header,
|
||||
) (
|
||||
types.Socket, *http.Response, error,
|
||||
) {
|
||||
conn, resp, err := d.Dialer.Dial(urlStr, requestHeader)
|
||||
return conn, resp, err
|
||||
}
|
||||
|
||||
func AcquireSocket(
|
||||
retryMgr *RetryManager,
|
||||
dialer types.Dialer,
|
||||
urlStr string,
|
||||
logger *slog.Logger,
|
||||
) (types.Socket, *http.Response, error) {
|
||||
if retryMgr == nil {
|
||||
return nil, nil, NewConnectionError("retry manager cannot be nil")
|
||||
}
|
||||
if dialer == nil {
|
||||
return nil, nil, NewConnectionError("dialer cannot be nil")
|
||||
}
|
||||
if urlStr == "" {
|
||||
return nil, nil, NewConnectionError("URL cannot be empty")
|
||||
}
|
||||
|
||||
for {
|
||||
if logger != nil {
|
||||
logger.Info("dialing", "attempt", retryMgr.RetryCount()+1)
|
||||
}
|
||||
|
||||
socket, resp, err := dialer.Dial(urlStr, nil)
|
||||
if err == nil {
|
||||
if logger != nil {
|
||||
logger.Info("dial successful", "attempt", retryMgr.RetryCount()+1)
|
||||
}
|
||||
return socket, resp, nil
|
||||
}
|
||||
|
||||
if !retryMgr.ShouldRetry() {
|
||||
if logger != nil {
|
||||
logger.Error("dial failed, max retries reached",
|
||||
"error", err,
|
||||
"attempt", retryMgr.RetryCount()+1)
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
delay := retryMgr.CalculateDelay()
|
||||
|
||||
if logger != nil {
|
||||
logger.Warn("dial failed, retrying",
|
||||
"error", err,
|
||||
"attempt", retryMgr.RetryCount()+1,
|
||||
"next_delay", delay)
|
||||
}
|
||||
|
||||
time.Sleep(delay)
|
||||
retryMgr.RecordRetry()
|
||||
}
|
||||
}
|
||||
145
transport/socket_test.go
Normal file
145
transport/socket_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewDialer(t *testing.T) {
|
||||
dialer := NewDialer()
|
||||
|
||||
assert.NotNil(t, dialer)
|
||||
_, ok := dialer.(*GorillaDialer)
|
||||
assert.True(t, ok, "NewDialer should return *GorillaDialer")
|
||||
}
|
||||
|
||||
func TestNewGorillaDialer(t *testing.T) {
|
||||
dialer := NewGorillaDialer()
|
||||
|
||||
assert.NotNil(t, dialer)
|
||||
assert.NotNil(t, dialer.Dialer)
|
||||
assert.Equal(t, 45*time.Second, dialer.Dialer.HandshakeTimeout)
|
||||
assert.Equal(t, 1024, dialer.Dialer.ReadBufferSize)
|
||||
assert.Equal(t, 1024, dialer.Dialer.WriteBufferSize)
|
||||
}
|
||||
|
||||
func TestAcquireSocket(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
mockRuns []error
|
||||
maxRetries int
|
||||
wantRetryCount int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "immediate success",
|
||||
mockRuns: []error{nil},
|
||||
maxRetries: 3,
|
||||
wantRetryCount: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "two failures, success",
|
||||
mockRuns: []error{errors.New("1"), errors.New("2"), nil},
|
||||
maxRetries: 0,
|
||||
wantRetryCount: 2,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "three failures, failure",
|
||||
mockRuns: []error{errors.New("1"), errors.New("2"), errors.New("3"), errors.New("4")},
|
||||
maxRetries: 3,
|
||||
wantRetryCount: 3,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
attemptIndex := 0
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
err := tc.mockRuns[attemptIndex]
|
||||
attemptIndex++
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
retryMgr := NewRetryManager(&RetryConfig{
|
||||
MaxRetries: tc.maxRetries,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Millisecond,
|
||||
JitterFactor: 0.0,
|
||||
})
|
||||
|
||||
socket, _, err := AcquireSocket(retryMgr, mockDialer, "ws://test", nil)
|
||||
|
||||
assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, socket)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, socket)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireSocketGuards(t *testing.T) {
|
||||
validDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
validRetryMgr := NewRetryManager(GetDefaultRetryConfig())
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
retryMgr *RetryManager
|
||||
dialer types.Dialer
|
||||
url string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "nil retry manager",
|
||||
retryMgr: nil,
|
||||
dialer: validDialer,
|
||||
url: "ws://test",
|
||||
wantErr: "retry manager cannot be nil",
|
||||
},
|
||||
{
|
||||
name: "nil dialer",
|
||||
retryMgr: validRetryMgr,
|
||||
dialer: nil,
|
||||
url: "ws://test",
|
||||
wantErr: "dialer cannot be nil",
|
||||
},
|
||||
{
|
||||
name: "empty URL",
|
||||
retryMgr: validRetryMgr,
|
||||
dialer: validDialer,
|
||||
url: "",
|
||||
wantErr: "URL cannot be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
socket, resp, err := AcquireSocket(tc.retryMgr, tc.dialer, tc.url, nil)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, tc.wantErr)
|
||||
assert.Nil(t, socket)
|
||||
assert.Nil(t, resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
38
transport/url.go
Normal file
38
transport/url.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ParseURL(urlStr string) (*url.URL, error) {
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" {
|
||||
return nil, InvalidProtocol
|
||||
}
|
||||
|
||||
return parsedURL, nil
|
||||
}
|
||||
|
||||
func NormalizeURL(input string) (string, error) {
|
||||
parsed, err := ParseURL(strings.ToLower(input))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
host := parsed.Hostname()
|
||||
port := parsed.Port()
|
||||
if (parsed.Scheme == "wss" && port == "443") ||
|
||||
(parsed.Scheme == "ws" && port == "80") {
|
||||
parsed.Host = host
|
||||
}
|
||||
|
||||
parsed.Path = strings.TrimRight(parsed.Path, "/")
|
||||
|
||||
return parsed.String(), nil
|
||||
|
||||
}
|
||||
164
transport/url_test.go
Normal file
164
transport/url_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseURL(t *testing.T) {
|
||||
type wantURL struct {
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
want wantURL
|
||||
wantErr error
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "valid ws url",
|
||||
url: "ws://localhost:8080/relay",
|
||||
want: wantURL{
|
||||
scheme: "ws",
|
||||
host: "localhost:8080",
|
||||
path: "/relay",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid wss url",
|
||||
url: "wss://relay.example.com",
|
||||
want: wantURL{
|
||||
scheme: "wss",
|
||||
host: "relay.example.com",
|
||||
path: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "http scheme rejected",
|
||||
url: "http://example.com",
|
||||
wantErr: InvalidProtocol,
|
||||
},
|
||||
{
|
||||
name: "missing scheme",
|
||||
url: "example.com:8080",
|
||||
wantErr: InvalidProtocol,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
url: "",
|
||||
wantErr: InvalidProtocol,
|
||||
},
|
||||
{
|
||||
name: "malformed url",
|
||||
url: "ws://[::1:8080",
|
||||
wantErrText: "missing ']' in host",
|
||||
},
|
||||
{
|
||||
name: "ipv6 address",
|
||||
url: "ws://[::1]:8080/relay",
|
||||
want: wantURL{
|
||||
scheme: "ws",
|
||||
host: "[::1]:8080",
|
||||
path: "/relay",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := ParseURL(tc.url)
|
||||
|
||||
if tc.wantErr != nil || tc.wantErrText != "" {
|
||||
if tc.wantErr != nil {
|
||||
assert.ErrorIs(t, err, tc.wantErr)
|
||||
}
|
||||
|
||||
if tc.wantErrText != "" {
|
||||
assert.ErrorContains(t, err, tc.wantErrText)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.want.scheme, got.Scheme)
|
||||
assert.Equal(t, tc.want.host, got.Host)
|
||||
assert.Equal(t, tc.want.path, got.Path)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeURL(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "strip trailing slash",
|
||||
input: "wss://relay.example.com/",
|
||||
expected: "wss://relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "strip multiple trailing slashes",
|
||||
input: "wss://relay.example.com//",
|
||||
expected: "wss://relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "strip trailing slash with path",
|
||||
input: "wss://relay.example.com/path/",
|
||||
expected: "wss://relay.example.com/path",
|
||||
},
|
||||
{
|
||||
name: "lowercase scheme",
|
||||
input: "WSS://relay.example.com",
|
||||
expected: "wss://relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "lowercase host",
|
||||
input: "wss://Relay.Example.Com",
|
||||
expected: "wss://relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "strip default wss port",
|
||||
input: "wss://relay.example.com:443",
|
||||
expected: "wss://relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "strip default ws port",
|
||||
input: "ws://relay.example.com:80",
|
||||
expected: "ws://relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "preserve non-default port",
|
||||
input: "wss://relay.example.com:8080",
|
||||
expected: "wss://relay.example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "preserve path",
|
||||
input: "wss://relay.example.com/nostr",
|
||||
expected: "wss://relay.example.com/nostr",
|
||||
},
|
||||
{
|
||||
name: "no change needed",
|
||||
input: "wss://relay.example.com",
|
||||
expected: "wss://relay.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := NormalizeURL(tc.input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeURLError(t *testing.T) {
|
||||
_, err := NormalizeURL("http://relay.example.com")
|
||||
assert.ErrorIs(t, err, InvalidProtocol)
|
||||
}
|
||||
Reference in New Issue
Block a user