Refactored package structure.
This commit is contained in:
422
config.go
422
config.go
@@ -1,422 +0,0 @@
|
||||
package honeybee
|
||||
|
||||
import (
|
||||
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Types
|
||||
|
||||
type CloseHandler func(code int, text string) error
|
||||
type WorkerFactory func(
|
||||
id string,
|
||||
conn *Connection,
|
||||
onReconnect func() (*Connection, error),
|
||||
) Worker
|
||||
|
||||
// Initiator Pool Config
|
||||
|
||||
type InitiatorPoolConfig struct {
|
||||
ConnectionConfig *ConnectionConfig
|
||||
WorkerFactory WorkerFactory
|
||||
WorkerConfig *InitiatorWorkerConfig
|
||||
}
|
||||
|
||||
type InitiatorPoolOption func(*InitiatorPoolConfig) error
|
||||
|
||||
func NewInitiatorPoolConfig(options ...InitiatorPoolOption) (*InitiatorPoolConfig, error) {
|
||||
conf := GetDefaultInitiatorPoolConfig()
|
||||
if err := applyInitiatorPoolOptions(conf, options...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateInitiatorPoolConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func GetDefaultInitiatorPoolConfig() *InitiatorPoolConfig {
|
||||
return &InitiatorPoolConfig{
|
||||
ConnectionConfig: nil,
|
||||
WorkerFactory: nil,
|
||||
WorkerConfig: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func applyInitiatorPoolOptions(config *InitiatorPoolConfig, options ...InitiatorPoolOption) error {
|
||||
for _, option := range options {
|
||||
if err := option(config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateInitiatorPoolConfig(config *InitiatorPoolConfig) error {
|
||||
var err error
|
||||
|
||||
if config.ConnectionConfig != nil {
|
||||
err = validateConnectionConfig(config.ConnectionConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if config.WorkerConfig != nil {
|
||||
err = validateInitiatorWorkerConfig(config.WorkerConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func WithInitiatorConnectionConfig(cc *ConnectionConfig) InitiatorPoolOption {
|
||||
return func(c *InitiatorPoolConfig) error {
|
||||
err := validateConnectionConfig(cc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ConnectionConfig = cc
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithInitiatorWorkerConfig(wc *InitiatorWorkerConfig) InitiatorPoolOption {
|
||||
return func(c *InitiatorPoolConfig) error {
|
||||
err := validateInitiatorWorkerConfig(wc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.WorkerConfig = wc
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithInitiatorWorkerFactory(wf WorkerFactory) InitiatorPoolOption {
|
||||
return func(c *InitiatorPoolConfig) error {
|
||||
c.WorkerFactory = wf
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Responder Pool Config
|
||||
|
||||
type ResponderPoolConfig struct {
|
||||
ConnectionConfig *ConnectionConfig
|
||||
WorkerFactory WorkerFactory
|
||||
WorkerConfig *ResponderWorkerConfig
|
||||
}
|
||||
|
||||
// Connection Config
|
||||
|
||||
type ConnectionConfig struct {
|
||||
CloseHandler CloseHandler
|
||||
WriteTimeout time.Duration
|
||||
Retry *RetryConfig
|
||||
}
|
||||
|
||||
type RetryConfig struct {
|
||||
MaxRetries int
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
JitterFactor float64
|
||||
}
|
||||
|
||||
type ConnectionOption func(*ConnectionConfig) error
|
||||
|
||||
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
|
||||
conf := GetDefaultConnectionConfig()
|
||||
if err := applyConnectionOptions(conf, options...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateConnectionConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func GetDefaultConnectionConfig() *ConnectionConfig {
|
||||
return &ConnectionConfig{
|
||||
CloseHandler: nil,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
Retry: GetDefaultRetryConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
func GetDefaultRetryConfig() *RetryConfig {
|
||||
return &RetryConfig{
|
||||
MaxRetries: 0, // Infinite retries
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 5 * time.Second,
|
||||
JitterFactor: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOption) error {
|
||||
for _, option := range options {
|
||||
if err := option(config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConnectionConfig(config *ConnectionConfig) error {
|
||||
err := validateWriteTimeout(config.WriteTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if config.Retry != nil {
|
||||
err = validateMaxRetries(config.Retry.MaxRetries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateInitialDelay(config.Retry.InitialDelay)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateMaxDelay(config.Retry.MaxDelay)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateJitterFactor(config.Retry.JitterFactor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if config.Retry.InitialDelay > config.Retry.MaxDelay {
|
||||
return errors.NewConfigError("initial delay may not exceed maximum delay")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateWriteTimeout(value time.Duration) error {
|
||||
if value < 0 {
|
||||
return errors.InvalidWriteTimeout
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMaxRetries(value int) error {
|
||||
if value < 0 {
|
||||
return errors.InvalidRetryMaxRetries
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateInitialDelay(value time.Duration) error {
|
||||
if value <= 0 {
|
||||
return errors.InvalidRetryInitialDelay
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMaxDelay(value time.Duration) error {
|
||||
if value <= 0 {
|
||||
return errors.InvalidRetryMaxDelay
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateJitterFactor(value float64) error {
|
||||
if value < 0.0 || value > 1.0 {
|
||||
return errors.InvalidRetryJitterFactor
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WithCloseHandler(handler CloseHandler) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
c.CloseHandler = handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// When WriteTimeout is set to zero, read timeouts are disabled.
|
||||
func WithWriteTimeout(value time.Duration) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
err := validateWriteTimeout(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.WriteTimeout = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetry enables retry with default parameters (infinite retries,
|
||||
// 1s initial delay, 5s max delay, 0.5 jitter factor).
|
||||
//
|
||||
// If passed after granular retry options (WithRetryMaxRetries, etc.),
|
||||
// it will overwrite them. Use either WithRetry alone or the granular
|
||||
// options; not both.
|
||||
func WithRetry() ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryMaxRetries(value int) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateMaxRetries(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.MaxRetries = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryInitialDelay(value time.Duration) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateInitialDelay(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.InitialDelay = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryMaxDelay(value time.Duration) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateMaxDelay(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.MaxDelay = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryJitterFactor(value float64) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateJitterFactor(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.JitterFactor = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Initiator Worker Config
|
||||
|
||||
type InitiatorWorkerConfig struct {
|
||||
IdleTimeout time.Duration
|
||||
MaxQueueSize int
|
||||
}
|
||||
|
||||
type InitiatorWorkerOption func(*InitiatorWorkerConfig) error
|
||||
|
||||
func NewInitiatorWorkerConfig(options ...InitiatorWorkerOption) (*InitiatorWorkerConfig, error) {
|
||||
conf := GetDefaultInitiatorWorkerConfig()
|
||||
if err := applyInitiatorWorkerOptions(conf, options...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateInitiatorWorkerConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func GetDefaultInitiatorWorkerConfig() *InitiatorWorkerConfig {
|
||||
return &InitiatorWorkerConfig{
|
||||
IdleTimeout: 20 * time.Second,
|
||||
MaxQueueSize: 0, // disabled by default
|
||||
}
|
||||
}
|
||||
|
||||
func applyInitiatorWorkerOptions(config *InitiatorWorkerConfig, options ...InitiatorWorkerOption) error {
|
||||
for _, option := range options {
|
||||
if err := option(config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateInitiatorWorkerConfig(config *InitiatorWorkerConfig) error {
|
||||
err := validateIdleTimeout(config.IdleTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateMaxQueueSize(config.MaxQueueSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMaxQueueSize(value int) error {
|
||||
if value < 0 {
|
||||
return errors.InvalidMaxQueueSize
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateIdleTimeout(value time.Duration) error {
|
||||
if value < 0 {
|
||||
return errors.InvalidIdleTimeout
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// When IdleTimeout is set to zero, idle timeouts are disabled.
|
||||
func WithIdleTimeout(value time.Duration) InitiatorWorkerOption {
|
||||
return func(c *InitiatorWorkerConfig) error {
|
||||
err := validateIdleTimeout(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.IdleTimeout = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// When MaxQueueSize is set to zero, queue limits are disabled.
|
||||
func WithMaxQueueSize(value int) InitiatorWorkerOption {
|
||||
return func(c *InitiatorWorkerConfig) error {
|
||||
err := validateMaxQueueSize(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.MaxQueueSize = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Responder Worker Config
|
||||
|
||||
type ResponderWorkerConfig struct{}
|
||||
64
honeybeetest/helpers.go
Normal file
64
honeybeetest/helpers.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package honeybeetest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Constants
|
||||
|
||||
const (
|
||||
TestTimeout = 2 * time.Second
|
||||
TestTick = 10 * time.Millisecond
|
||||
NegativeTestTimeout = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
// Types
|
||||
|
||||
type MockIncomingData struct {
|
||||
MsgType int
|
||||
Data []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
type MockOutgoingData struct {
|
||||
MsgType int
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
func ExpectIncoming(t *testing.T, incoming <-chan []byte, expected []byte) {
|
||||
t.Helper()
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case received := <-incoming:
|
||||
return bytes.Equal(received, expected)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, TestTimeout, TestTick)
|
||||
}
|
||||
|
||||
func ExpectWrite(t *testing.T, outgoingData chan MockOutgoingData, msgType int, expected []byte) {
|
||||
t.Helper()
|
||||
|
||||
var call MockOutgoingData
|
||||
found := assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case received := <-outgoingData:
|
||||
call = received
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, TestTimeout, TestTick)
|
||||
|
||||
if found {
|
||||
|
||||
assert.Equal(t, msgType, call.MsgType)
|
||||
assert.Equal(t, expected, call.Data)
|
||||
}
|
||||
}
|
||||
119
honeybeetest/mocks.go
Normal file
119
honeybeetest/mocks.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package honeybeetest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Dialer Mocks
|
||||
|
||||
type MockDialer struct {
|
||||
DialFunc func(string, http.Header) (types.Socket, *http.Response, error)
|
||||
}
|
||||
|
||||
func (m *MockDialer) Dial(url string, h http.Header) (types.Socket, *http.Response, error) {
|
||||
return m.DialFunc(url, h)
|
||||
}
|
||||
|
||||
// Socket Mocks
|
||||
|
||||
type MockSocket struct {
|
||||
WriteMessageFunc func(int, []byte) error
|
||||
SetReadDeadlineFunc func(t time.Time) error
|
||||
SetWriteDeadlineFunc func(t time.Time) error
|
||||
ReadMessageFunc func() (int, []byte, error)
|
||||
CloseFunc func() error
|
||||
SetCloseHandlerFunc func(func(int, string) error)
|
||||
Closed chan struct{}
|
||||
Once sync.Once
|
||||
Mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewMockSocket() *MockSocket {
|
||||
return &MockSocket{
|
||||
WriteMessageFunc: func(int, []byte) error { return nil },
|
||||
ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil },
|
||||
CloseFunc: func() error { return nil },
|
||||
|
||||
SetReadDeadlineFunc: func(time.Time) error { return nil },
|
||||
SetWriteDeadlineFunc: func(time.Time) error { return nil },
|
||||
SetCloseHandlerFunc: func(func(int, string) error) {},
|
||||
|
||||
Closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (m *MockSocket) WriteMessage(t int, d []byte) error {
|
||||
return m.WriteMessageFunc(t, d)
|
||||
}
|
||||
|
||||
func (m *MockSocket) ReadMessage() (int, []byte, error) {
|
||||
return m.ReadMessageFunc()
|
||||
}
|
||||
|
||||
func (m *MockSocket) Close() error {
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockSocket) SetReadDeadline(t time.Time) error {
|
||||
return m.SetReadDeadlineFunc(t)
|
||||
}
|
||||
|
||||
func (m *MockSocket) SetWriteDeadline(t time.Time) error {
|
||||
return m.SetWriteDeadlineFunc(t)
|
||||
}
|
||||
|
||||
func (m *MockSocket) SetCloseHandler(h func(code int, text string) error) {
|
||||
m.SetCloseHandlerFunc(h)
|
||||
}
|
||||
|
||||
// Logging mocks
|
||||
|
||||
type MockSlogHandler struct {
|
||||
records []slog.Record
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMockSlogHandler() *MockSlogHandler {
|
||||
return &MockSlogHandler{
|
||||
records: make([]slog.Record, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockSlogHandler) Handle(ctx context.Context, record slog.Record) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.records = append(m.records, record)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockSlogHandler) Enabled(ctx context.Context, level slog.Level) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MockSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MockSlogHandler) WithGroup(name string) slog.Handler {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MockSlogHandler) GetRecords() []slog.Record {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]slog.Record, len(m.records))
|
||||
copy(result, m.records)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *MockSlogHandler) Clear() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.records = make([]slog.Record, 0)
|
||||
}
|
||||
189
initiator/config.go
Normal file
189
initiator/config.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package initiator
|
||||
|
||||
import (
|
||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Types
|
||||
|
||||
type WorkerFactory func(
|
||||
id string,
|
||||
conn *transport.Connection,
|
||||
onReconnect func() (*transport.Connection, error),
|
||||
) Worker
|
||||
|
||||
// Pool Config
|
||||
|
||||
type PoolConfig struct {
|
||||
ConnectionConfig *transport.ConnectionConfig
|
||||
WorkerFactory WorkerFactory
|
||||
WorkerConfig *WorkerConfig
|
||||
}
|
||||
|
||||
type PoolOption func(*PoolConfig) error
|
||||
|
||||
func NewPoolConfig(options ...PoolOption) (*PoolConfig, error) {
|
||||
conf := GetDefaultPoolConfig()
|
||||
if err := applyPoolOptions(conf, options...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ValidatePoolConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func GetDefaultPoolConfig() *PoolConfig {
|
||||
return &PoolConfig{
|
||||
ConnectionConfig: nil,
|
||||
WorkerFactory: nil,
|
||||
WorkerConfig: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func applyPoolOptions(config *PoolConfig, options ...PoolOption) error {
|
||||
for _, option := range options {
|
||||
if err := option(config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidatePoolConfig(config *PoolConfig) error {
|
||||
var err error
|
||||
|
||||
if config.ConnectionConfig != nil {
|
||||
err = transport.ValidateConnectionConfig(config.ConnectionConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if config.WorkerConfig != nil {
|
||||
err = ValidateWorkerConfig(config.WorkerConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func WithConnectionConfig(cc *transport.ConnectionConfig) PoolOption {
|
||||
return func(c *PoolConfig) error {
|
||||
err := transport.ValidateConnectionConfig(cc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ConnectionConfig = cc
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithWorkerConfig(wc *WorkerConfig) PoolOption {
|
||||
return func(c *PoolConfig) error {
|
||||
err := ValidateWorkerConfig(wc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.WorkerConfig = wc
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithWorkerFactory(wf WorkerFactory) PoolOption {
|
||||
return func(c *PoolConfig) error {
|
||||
c.WorkerFactory = wf
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Worker Config
|
||||
|
||||
type WorkerConfig struct {
|
||||
IdleTimeout time.Duration
|
||||
MaxQueueSize int
|
||||
}
|
||||
|
||||
type WorkerOption func(*WorkerConfig) error
|
||||
|
||||
func NewWorkerConfig(options ...WorkerOption) (*WorkerConfig, error) {
|
||||
conf := GetDefaultWorkerConfig()
|
||||
if err := applyWorkerOptions(conf, options...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ValidateWorkerConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func GetDefaultWorkerConfig() *WorkerConfig {
|
||||
return &WorkerConfig{
|
||||
IdleTimeout: 20 * time.Second,
|
||||
MaxQueueSize: 0, // disabled by default
|
||||
}
|
||||
}
|
||||
|
||||
func applyWorkerOptions(config *WorkerConfig, options ...WorkerOption) error {
|
||||
for _, option := range options {
|
||||
if err := option(config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateWorkerConfig(config *WorkerConfig) error {
|
||||
err := validateIdleTimeout(config.IdleTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateMaxQueueSize(config.MaxQueueSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMaxQueueSize(value int) error {
|
||||
if value < 0 {
|
||||
return InvalidMaxQueueSize
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateIdleTimeout(value time.Duration) error {
|
||||
if value < 0 {
|
||||
return InvalidIdleTimeout
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// When IdleTimeout is set to zero, idle timeouts are disabled.
|
||||
func WithIdleTimeout(value time.Duration) WorkerOption {
|
||||
return func(c *WorkerConfig) error {
|
||||
err := validateIdleTimeout(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.IdleTimeout = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// When MaxQueueSize is set to zero, queue limits are disabled.
|
||||
func WithMaxQueueSize(value int) WorkerOption {
|
||||
return func(c *WorkerConfig) error {
|
||||
err := validateMaxQueueSize(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.MaxQueueSize = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,17 @@
|
||||
package honeybee
|
||||
package initiator
|
||||
|
||||
import (
|
||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewPoolConfig(t *testing.T) {
|
||||
conf, err := NewInitiatorPoolConfig()
|
||||
conf, err := NewPoolConfig()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, conf, &InitiatorPoolConfig{
|
||||
assert.Equal(t, conf, &PoolConfig{
|
||||
ConnectionConfig: nil,
|
||||
WorkerConfig: nil,
|
||||
WorkerFactory: nil,
|
||||
@@ -18,9 +19,9 @@ func TestNewPoolConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultPoolConfig(t *testing.T) {
|
||||
conf := GetDefaultInitiatorPoolConfig()
|
||||
conf := GetDefaultPoolConfig()
|
||||
|
||||
assert.Equal(t, conf, &InitiatorPoolConfig{
|
||||
assert.Equal(t, conf, &PoolConfig{
|
||||
ConnectionConfig: nil,
|
||||
WorkerConfig: nil,
|
||||
WorkerFactory: nil,
|
||||
@@ -28,10 +29,10 @@ func TestDefaultPoolConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApplyPoolOptions(t *testing.T) {
|
||||
conf := &InitiatorPoolConfig{}
|
||||
err := applyInitiatorPoolOptions(
|
||||
conf := &PoolConfig{}
|
||||
err := applyPoolOptions(
|
||||
conf,
|
||||
WithInitiatorConnectionConfig(&ConnectionConfig{}),
|
||||
WithConnectionConfig(&transport.ConnectionConfig{}),
|
||||
)
|
||||
|
||||
assert.NoError(t, err)
|
||||
@@ -39,46 +40,46 @@ func TestApplyPoolOptions(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithConnectionConfig(t *testing.T) {
|
||||
conf := &InitiatorPoolConfig{}
|
||||
opt := WithInitiatorConnectionConfig(&ConnectionConfig{WriteTimeout: 1 * time.Second})
|
||||
err := applyInitiatorPoolOptions(conf, opt)
|
||||
conf := &PoolConfig{}
|
||||
opt := WithConnectionConfig(&transport.ConnectionConfig{WriteTimeout: 1 * time.Second})
|
||||
err := applyPoolOptions(conf, opt)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, conf.ConnectionConfig)
|
||||
assert.Equal(t, 1*time.Second, conf.ConnectionConfig.WriteTimeout)
|
||||
|
||||
// invalid config is rejected
|
||||
conf = &InitiatorPoolConfig{}
|
||||
opt = WithInitiatorConnectionConfig(&ConnectionConfig{WriteTimeout: -1 * time.Second})
|
||||
err = applyInitiatorPoolOptions(conf, opt)
|
||||
conf = &PoolConfig{}
|
||||
opt = WithConnectionConfig(&transport.ConnectionConfig{WriteTimeout: -1 * time.Second})
|
||||
err = applyPoolOptions(conf, opt)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidatePoolConfig(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
conf InitiatorPoolConfig
|
||||
conf PoolConfig
|
||||
wantErr error
|
||||
wantErrText string
|
||||
}{
|
||||
{
|
||||
name: "valid empty",
|
||||
conf: *&InitiatorPoolConfig{},
|
||||
conf: *&PoolConfig{},
|
||||
},
|
||||
{
|
||||
name: "valid defaults",
|
||||
conf: *GetDefaultInitiatorPoolConfig(),
|
||||
conf: *GetDefaultPoolConfig(),
|
||||
},
|
||||
{
|
||||
name: "valid complete",
|
||||
conf: InitiatorPoolConfig{
|
||||
ConnectionConfig: &ConnectionConfig{},
|
||||
conf: PoolConfig{
|
||||
ConnectionConfig: &transport.ConnectionConfig{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid connection config",
|
||||
conf: InitiatorPoolConfig{
|
||||
ConnectionConfig: &ConnectionConfig{
|
||||
Retry: &RetryConfig{
|
||||
conf: PoolConfig{
|
||||
ConnectionConfig: &transport.ConnectionConfig{
|
||||
Retry: &transport.RetryConfig{
|
||||
InitialDelay: 10 * time.Second,
|
||||
MaxDelay: 1 * time.Second,
|
||||
},
|
||||
@@ -90,7 +91,7 @@ func TestValidatePoolConfig(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validateInitiatorPoolConfig(&tc.conf)
|
||||
err := ValidatePoolConfig(&tc.conf)
|
||||
|
||||
if tc.wantErr != nil || tc.wantErrText != "" {
|
||||
if tc.wantErr != nil {
|
||||
17
initiator/errors.go
Normal file
17
initiator/errors.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package initiator
|
||||
|
||||
import "errors"
|
||||
import "fmt"
|
||||
|
||||
var (
|
||||
InvalidIdleTimeout = errors.New("idle timeout cannot be negative")
|
||||
InvalidMaxQueueSize = errors.New("maximum queue size cannot be negative")
|
||||
)
|
||||
|
||||
func NewConfigError(text string) error {
|
||||
return fmt.Errorf("configuration error: %s", text)
|
||||
}
|
||||
|
||||
func NewPoolError(text string) error {
|
||||
return fmt.Errorf("pool error: %s", text)
|
||||
}
|
||||
248
initiator/pool.go
Normal file
248
initiator/pool.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package initiator
|
||||
|
||||
import (
|
||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Types
|
||||
|
||||
type peer struct {
|
||||
conn *transport.Connection
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
type InboxMessage struct {
|
||||
ID string
|
||||
Data []byte
|
||||
ReceivedAt time.Time
|
||||
}
|
||||
|
||||
type PoolEventKind int
|
||||
|
||||
const (
|
||||
EventConnected PoolEventKind = iota
|
||||
EventDisconnected
|
||||
)
|
||||
|
||||
func (s PoolEventKind) String() string {
|
||||
switch s {
|
||||
case EventConnected:
|
||||
return "connected"
|
||||
case EventDisconnected:
|
||||
return "disconnected"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type PoolEvent struct {
|
||||
ID string
|
||||
Kind PoolEventKind
|
||||
}
|
||||
|
||||
// Pool
|
||||
|
||||
type Pool struct {
|
||||
peers map[string]*peer
|
||||
inbox chan InboxMessage
|
||||
events chan PoolEvent
|
||||
errors chan error
|
||||
done chan struct{}
|
||||
|
||||
dialer types.Dialer
|
||||
config *PoolConfig
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewPool(config *PoolConfig, logger *slog.Logger) (*Pool, error) {
|
||||
if config == nil {
|
||||
config = GetDefaultPoolConfig()
|
||||
}
|
||||
|
||||
if err := ValidatePoolConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &Pool{
|
||||
peers: make(map[string]*peer),
|
||||
inbox: make(chan InboxMessage, 256),
|
||||
events: make(chan PoolEvent, 10),
|
||||
errors: make(chan error, 10),
|
||||
done: make(chan struct{}),
|
||||
dialer: transport.NewDialer(),
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *Pool) Peers() map[string]*peer {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *Pool) Inbox() chan InboxMessage {
|
||||
return p.inbox
|
||||
}
|
||||
|
||||
func (p *Pool) Events() chan PoolEvent {
|
||||
return p.events
|
||||
}
|
||||
|
||||
func (p *Pool) Errors() chan error {
|
||||
return p.errors
|
||||
}
|
||||
|
||||
func (p *Pool) Close() {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
p.closed = true
|
||||
close(p.done)
|
||||
|
||||
peers := p.peers
|
||||
p.peers = make(map[string]*peer)
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
for _, conn := range peers {
|
||||
conn.conn.Close()
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.wg.Wait()
|
||||
close(p.inbox)
|
||||
close(p.events)
|
||||
close(p.errors)
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *Pool) Connect(id string) error {
|
||||
id, err := transport.NormalizeURL(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for existing connection in pool
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return NewPoolError("pool is closed")
|
||||
}
|
||||
_, exists := p.peers[id]
|
||||
p.mu.Unlock()
|
||||
|
||||
if exists {
|
||||
return NewPoolError("connection already exists")
|
||||
}
|
||||
|
||||
// Create new connection
|
||||
var logger *slog.Logger
|
||||
if p.logger != nil {
|
||||
logger = p.logger.With("id", id)
|
||||
}
|
||||
conn, err := transport.NewConnection(id, p.config.ConnectionConfig, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.SetDialer(p.dialer)
|
||||
|
||||
// Attempt to connect
|
||||
if err := conn.Connect(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
// The pool closed while this connection was established.
|
||||
p.mu.Unlock()
|
||||
conn.Close()
|
||||
return NewPoolError("pool is closed")
|
||||
}
|
||||
|
||||
// Add connection to pool
|
||||
stop := make(chan struct{})
|
||||
if _, exists := p.peers[id]; exists {
|
||||
// Another process connected to this url while this one was connecting
|
||||
// Discard this connection and retain the existing one
|
||||
p.mu.Unlock()
|
||||
conn.Close()
|
||||
return NewPoolError("connection already exists")
|
||||
}
|
||||
p.peers[id] = &peer{conn: conn, stop: stop}
|
||||
p.mu.Unlock()
|
||||
|
||||
// TODO: start this connection's incoming message forwarder
|
||||
|
||||
select {
|
||||
case p.events <- PoolEvent{ID: id, Kind: EventConnected}:
|
||||
case <-p.done:
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pool) Remove(id string) error {
|
||||
id, err := transport.NormalizeURL(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return NewPoolError("pool is closed")
|
||||
}
|
||||
|
||||
peer, exists := p.peers[id]
|
||||
if !exists {
|
||||
p.mu.Unlock()
|
||||
return NewPoolError("connection not found")
|
||||
}
|
||||
delete(p.peers, id)
|
||||
p.mu.Unlock()
|
||||
|
||||
close(peer.stop)
|
||||
peer.conn.Close()
|
||||
|
||||
select {
|
||||
case p.events <- PoolEvent{ID: id, Kind: EventDisconnected}:
|
||||
case <-p.done:
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pool) Send(id string, data []byte) error {
|
||||
id, err := transport.NormalizeURL(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.closed {
|
||||
return NewPoolError("pool is closed")
|
||||
}
|
||||
|
||||
peer, exists := p.peers[id]
|
||||
if !exists {
|
||||
return NewPoolError("connection not found")
|
||||
}
|
||||
|
||||
return peer.conn.Send(data)
|
||||
}
|
||||
@@ -1,7 +1,10 @@
|
||||
package honeybee
|
||||
package initiator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
@@ -11,14 +14,14 @@ import (
|
||||
|
||||
func TestPoolConnect(t *testing.T) {
|
||||
t.Run("successfully adds connection", func(t *testing.T) {
|
||||
mockSocket := NewMockSocket()
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
pool, err := NewInitiatorPool(nil, nil)
|
||||
pool, err := NewPool(nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pool.dialer = mockDialer
|
||||
@@ -33,7 +36,7 @@ func TestPoolConnect(t *testing.T) {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
_, exists := pool.peers["wss://test"]
|
||||
assert.True(t, exists)
|
||||
@@ -42,14 +45,14 @@ func TestPoolConnect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("does not add duplicate", func(t *testing.T) {
|
||||
mockSocket := NewMockSocket()
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
pool, err := NewInitiatorPool(nil, nil)
|
||||
pool, err := NewPool(nil, nil)
|
||||
assert.NoError(t, err)
|
||||
pool.dialer = mockDialer
|
||||
|
||||
@@ -69,18 +72,18 @@ func TestPoolConnect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("fails to add connection", func(t *testing.T) {
|
||||
pool, err := NewInitiatorPool(
|
||||
&InitiatorPoolConfig{
|
||||
ConnectionConfig: &ConnectionConfig{
|
||||
Retry: &RetryConfig{
|
||||
pool, err := NewPool(
|
||||
&PoolConfig{
|
||||
ConnectionConfig: &transport.ConnectionConfig{
|
||||
Retry: &transport.RetryConfig{
|
||||
MaxRetries: 1,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Millisecond,
|
||||
}},
|
||||
}, nil)
|
||||
assert.NoError(t, err)
|
||||
pool.dialer = &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
pool.dialer = &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return nil, nil, fmt.Errorf("dial failed")
|
||||
},
|
||||
}
|
||||
@@ -104,14 +107,14 @@ func TestPoolConnect(t *testing.T) {
|
||||
|
||||
func TestPoolRemove(t *testing.T) {
|
||||
t.Run("removes known url", func(t *testing.T) {
|
||||
mockSocket := NewMockSocket()
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
pool, err := NewInitiatorPool(nil, nil)
|
||||
pool, err := NewPool(nil, nil)
|
||||
assert.NoError(t, err)
|
||||
pool.dialer = mockDialer
|
||||
|
||||
@@ -132,14 +135,14 @@ func TestPoolRemove(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("unknown url returns error", func(t *testing.T) {
|
||||
mockSocket := NewMockSocket()
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
pool, err := NewInitiatorPool(nil, nil)
|
||||
pool, err := NewPool(nil, nil)
|
||||
assert.NoError(t, err)
|
||||
pool.dialer = mockDialer
|
||||
|
||||
@@ -149,14 +152,14 @@ func TestPoolRemove(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("closed pool returns error", func(t *testing.T) {
|
||||
mockSocket := NewMockSocket()
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
pool, err := NewInitiatorPool(nil, nil)
|
||||
pool, err := NewPool(nil, nil)
|
||||
assert.NoError(t, err)
|
||||
pool.dialer = mockDialer
|
||||
|
||||
@@ -171,19 +174,19 @@ func TestPoolRemove(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPoolSend(t *testing.T) {
|
||||
mockSocket := NewMockSocket()
|
||||
outgoingData := make(chan mockOutgoingData, 10)
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
outgoingData <- mockOutgoingData{msgType: msgType, data: data}
|
||||
outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}
|
||||
return nil
|
||||
}
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
pool, err := NewInitiatorPool(nil, nil)
|
||||
pool, err := NewPool(nil, nil)
|
||||
assert.NoError(t, err)
|
||||
pool.dialer = mockDialer
|
||||
|
||||
@@ -194,7 +197,7 @@ func TestPoolSend(t *testing.T) {
|
||||
err = pool.Send("wss://test", []byte("hello"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectWrite(t, outgoingData, websocket.TextMessage, []byte("hello"))
|
||||
honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, []byte("hello"))
|
||||
|
||||
pool.Close()
|
||||
}
|
||||
@@ -213,7 +216,7 @@ func expectEvent(
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick,
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
|
||||
fmt.Sprintf("expected event: URL=%q, Kind=%q",
|
||||
expectedURL, expectedKind.String()))
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package honeybee
|
||||
package initiator
|
||||
|
||||
import (
|
||||
"git.wisehodl.dev/jay/go-honeybee/transport"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -8,15 +9,6 @@ import (
|
||||
|
||||
// Types
|
||||
|
||||
// Worker Implementation
|
||||
|
||||
type Worker interface {
|
||||
Start(
|
||||
ctx *WorkerContext,
|
||||
wg *sync.WaitGroup,
|
||||
)
|
||||
}
|
||||
|
||||
type WorkerContext struct {
|
||||
Inbox chan<- InboxMessage
|
||||
Events chan<- PoolEvent
|
||||
@@ -26,40 +18,23 @@ type WorkerContext struct {
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
// Base Struct
|
||||
// Worker
|
||||
|
||||
type worker struct {
|
||||
id string
|
||||
type Worker struct {
|
||||
id string
|
||||
config *WorkerConfig
|
||||
onReconnect func() (*transport.Connection, error)
|
||||
}
|
||||
|
||||
func (w *worker) runForwarder(
|
||||
messages <-chan []byte,
|
||||
inbox chan<- []byte,
|
||||
stop <-chan struct{},
|
||||
poolDone <-chan struct{},
|
||||
maxQueueSize int,
|
||||
) {
|
||||
}
|
||||
|
||||
// Initiator Worker
|
||||
|
||||
type InitiatorWorker struct {
|
||||
*worker
|
||||
config *InitiatorWorkerConfig
|
||||
onReconnect func() (*Connection, error)
|
||||
}
|
||||
|
||||
func newInitiatorWorker(
|
||||
func NewWorker(
|
||||
id string,
|
||||
config *InitiatorWorkerConfig,
|
||||
onReconnect func() (*Connection, error),
|
||||
config *WorkerConfig,
|
||||
onReconnect func() (*transport.Connection, error),
|
||||
logger *slog.Logger,
|
||||
|
||||
) (*InitiatorWorker, error) {
|
||||
w := &InitiatorWorker{
|
||||
worker: &worker{
|
||||
id: id,
|
||||
},
|
||||
) (*Worker, error) {
|
||||
w := &Worker{
|
||||
id: id,
|
||||
config: config,
|
||||
onReconnect: onReconnect,
|
||||
}
|
||||
@@ -67,7 +42,7 @@ func newInitiatorWorker(
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (w *InitiatorWorker) Start(
|
||||
func (w *Worker) Start(
|
||||
inbox chan<- InboxMessage,
|
||||
events chan<- PoolEvent,
|
||||
stop <-chan struct{},
|
||||
@@ -76,32 +51,37 @@ func (w *InitiatorWorker) Start(
|
||||
) {
|
||||
}
|
||||
|
||||
func runReader(conn *Connection,
|
||||
func (w *Worker) runReader(conn *transport.Connection,
|
||||
messages chan<- []byte,
|
||||
heartbeat chan<- time.Time,
|
||||
reconnect chan<- struct{},
|
||||
newConn <-chan *Connection,
|
||||
newConn <-chan *transport.Connection,
|
||||
stop <-chan struct{},
|
||||
poolDone <-chan struct{},
|
||||
|
||||
) {
|
||||
}
|
||||
|
||||
func runHealthMonitor(
|
||||
func (w *Worker) runForwarder(
|
||||
messages <-chan []byte,
|
||||
inbox chan<- []byte,
|
||||
stop <-chan struct{},
|
||||
poolDone <-chan struct{},
|
||||
maxQueueSize int,
|
||||
) {
|
||||
}
|
||||
|
||||
func (w *Worker) runHealthMonitor(
|
||||
heartbeat <-chan time.Time,
|
||||
stop <-chan struct{},
|
||||
poolDone <-chan struct{},
|
||||
) {
|
||||
}
|
||||
|
||||
func runReconnector(
|
||||
func (w *Worker) runReconnector(
|
||||
reconnect <-chan struct{},
|
||||
newConn chan<- *Connection,
|
||||
newConn chan<- *transport.Connection,
|
||||
stop <-chan struct{},
|
||||
poolDone <-chan struct{},
|
||||
) {
|
||||
}
|
||||
|
||||
// Responder Worker
|
||||
|
||||
type ResponderWorker struct{}
|
||||
1
initiator/worker_test.go
Normal file
1
initiator/worker_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package initiator
|
||||
192
mocks_test.go
192
mocks_test.go
@@ -1,192 +0,0 @@
|
||||
package honeybee
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test Constants
|
||||
|
||||
const (
|
||||
testTimeout = 2 * time.Second
|
||||
testTick = 10 * time.Millisecond
|
||||
negativeTestTimeout = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
// Dialer Mocks
|
||||
|
||||
type MockDialer struct {
|
||||
DialFunc func(string, http.Header) (Socket, *http.Response, error)
|
||||
}
|
||||
|
||||
func (m *MockDialer) Dial(url string, h http.Header) (Socket, *http.Response, error) {
|
||||
return m.DialFunc(url, h)
|
||||
}
|
||||
|
||||
// Socket Mocks
|
||||
|
||||
type MockSocket struct {
|
||||
WriteMessageFunc func(int, []byte) error
|
||||
SetReadDeadlineFunc func(t time.Time) error
|
||||
SetWriteDeadlineFunc func(t time.Time) error
|
||||
ReadMessageFunc func() (int, []byte, error)
|
||||
CloseFunc func() error
|
||||
SetCloseHandlerFunc func(func(int, string) error)
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewMockSocket() *MockSocket {
|
||||
return &MockSocket{
|
||||
WriteMessageFunc: func(int, []byte) error { return nil },
|
||||
ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil },
|
||||
CloseFunc: func() error { return nil },
|
||||
|
||||
SetReadDeadlineFunc: func(time.Time) error { return nil },
|
||||
SetWriteDeadlineFunc: func(time.Time) error { return nil },
|
||||
SetCloseHandlerFunc: func(func(int, string) error) {},
|
||||
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (m *MockSocket) WriteMessage(t int, d []byte) error {
|
||||
return m.WriteMessageFunc(t, d)
|
||||
}
|
||||
|
||||
func (m *MockSocket) ReadMessage() (int, []byte, error) {
|
||||
return m.ReadMessageFunc()
|
||||
}
|
||||
|
||||
func (m *MockSocket) Close() error {
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockSocket) SetReadDeadline(t time.Time) error {
|
||||
return m.SetReadDeadlineFunc(t)
|
||||
}
|
||||
|
||||
func (m *MockSocket) SetWriteDeadline(t time.Time) error {
|
||||
return m.SetWriteDeadlineFunc(t)
|
||||
}
|
||||
|
||||
func (m *MockSocket) SetCloseHandler(h func(code int, text string) error) {
|
||||
m.SetCloseHandlerFunc(h)
|
||||
}
|
||||
|
||||
// Connection Mocks
|
||||
|
||||
type mockIncomingData struct {
|
||||
msgType int
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
type mockOutgoingData struct {
|
||||
msgType int
|
||||
data []byte
|
||||
}
|
||||
|
||||
func setupTestConnection(t *testing.T, config *ConnectionConfig) (
|
||||
conn *Connection,
|
||||
mockSocket *MockSocket,
|
||||
incomingData chan mockIncomingData,
|
||||
outgoingData chan mockOutgoingData,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
incomingData = make(chan mockIncomingData, 10)
|
||||
outgoingData = make(chan mockOutgoingData, 10)
|
||||
|
||||
mockSocket = NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.once.Do(func() {
|
||||
close(mockSocket.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wire ReadMessage to pull from incomingData channel
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
select {
|
||||
case data := <-incomingData:
|
||||
return data.msgType, data.data, data.err
|
||||
case <-mockSocket.closed:
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// Wire WriteMessage to push to outgoingData channel
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
select {
|
||||
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||
return nil
|
||||
case <-mockSocket.closed:
|
||||
return io.EOF
|
||||
default:
|
||||
return fmt.Errorf("mock outgoing chanel unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
conn, err = NewConnectionFromSocket(mockSocket, config, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return conn, mockSocket, incomingData, outgoingData
|
||||
}
|
||||
|
||||
// Logging mocks
|
||||
|
||||
type mockSlogHandler struct {
|
||||
records []slog.Record
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockSlogHandler() *mockSlogHandler {
|
||||
return &mockSlogHandler{
|
||||
records: make([]slog.Record, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockSlogHandler) Handle(ctx context.Context, record slog.Record) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.records = append(m.records, record)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSlogHandler) Enabled(ctx context.Context, level slog.Level) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSlogHandler) WithGroup(name string) slog.Handler {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSlogHandler) GetRecords() []slog.Record {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]slog.Record, len(m.records))
|
||||
copy(result, m.records)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *mockSlogHandler) Clear() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.records = make([]slog.Record, 0)
|
||||
}
|
||||
304
pool.go
304
pool.go
@@ -1,304 +0,0 @@
|
||||
package honeybee
|
||||
|
||||
import (
|
||||
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Types
|
||||
|
||||
type peer struct {
|
||||
conn *Connection
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
type InboxMessage struct {
|
||||
ID string
|
||||
Data []byte
|
||||
ReceivedAt time.Time
|
||||
}
|
||||
|
||||
type PoolEventKind int
|
||||
|
||||
const (
|
||||
EventConnected PoolEventKind = iota
|
||||
EventDisconnected
|
||||
)
|
||||
|
||||
func (s PoolEventKind) String() string {
|
||||
switch s {
|
||||
case EventConnected:
|
||||
return "connected"
|
||||
case EventDisconnected:
|
||||
return "disconnected"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type PoolEvent struct {
|
||||
ID string
|
||||
Kind PoolEventKind
|
||||
}
|
||||
|
||||
// Pool Implementation
|
||||
|
||||
type Pool interface {
|
||||
Send(id string, data []byte) error
|
||||
Inbox() <-chan InboxMessage
|
||||
Events() <-chan PoolEvent
|
||||
Errors() <-chan error
|
||||
Close()
|
||||
}
|
||||
|
||||
// Base Struct
|
||||
|
||||
type pool struct {
|
||||
peers map[string]*peer
|
||||
inbox chan InboxMessage
|
||||
events chan PoolEvent
|
||||
errors chan error
|
||||
done chan struct{}
|
||||
|
||||
config *InitiatorPoolConfig
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (p *pool) closeAll() {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
p.closed = true
|
||||
close(p.done)
|
||||
|
||||
peers := p.peers
|
||||
p.peers = make(map[string]*peer)
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
for _, conn := range peers {
|
||||
conn.conn.Close()
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.wg.Wait()
|
||||
close(p.inbox)
|
||||
close(p.events)
|
||||
close(p.errors)
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *pool) removePeer(id string) error {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return errors.NewPoolError("pool is closed")
|
||||
}
|
||||
|
||||
peer, exists := p.peers[id]
|
||||
if !exists {
|
||||
p.mu.Unlock()
|
||||
return errors.NewPoolError("connection not found")
|
||||
}
|
||||
delete(p.peers, id)
|
||||
p.mu.Unlock()
|
||||
|
||||
close(peer.stop)
|
||||
peer.conn.Close()
|
||||
|
||||
select {
|
||||
case p.events <- PoolEvent{ID: id, Kind: EventDisconnected}:
|
||||
case <-p.done:
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pool) send(id string, data []byte) error {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.closed {
|
||||
return errors.NewPoolError("pool is closed")
|
||||
}
|
||||
|
||||
peer, exists := p.peers[id]
|
||||
if !exists {
|
||||
return errors.NewPoolError("connection not found")
|
||||
}
|
||||
|
||||
return peer.conn.Send(data)
|
||||
}
|
||||
|
||||
// Initiator Pool
|
||||
|
||||
type InitiatorPool struct {
|
||||
*pool
|
||||
dialer Dialer
|
||||
}
|
||||
|
||||
func NewInitiatorPool(config *InitiatorPoolConfig, logger *slog.Logger) (*InitiatorPool, error) {
|
||||
if config == nil {
|
||||
config = GetDefaultInitiatorPoolConfig()
|
||||
}
|
||||
|
||||
if err := validateInitiatorPoolConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &InitiatorPool{
|
||||
pool: &pool{
|
||||
peers: make(map[string]*peer),
|
||||
inbox: make(chan InboxMessage, 256),
|
||||
events: make(chan PoolEvent, 10),
|
||||
errors: make(chan error, 10),
|
||||
done: make(chan struct{}),
|
||||
config: config,
|
||||
logger: logger,
|
||||
},
|
||||
dialer: NewDialer(),
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *InitiatorPool) Peers() map[string]*peer {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *InitiatorPool) Inbox() chan InboxMessage {
|
||||
return p.inbox
|
||||
}
|
||||
|
||||
func (p *InitiatorPool) Events() chan PoolEvent {
|
||||
return p.events
|
||||
}
|
||||
|
||||
func (p *InitiatorPool) Errors() chan error {
|
||||
return p.errors
|
||||
}
|
||||
|
||||
func (p *InitiatorPool) Close() {
|
||||
p.closeAll()
|
||||
}
|
||||
|
||||
func (p *InitiatorPool) Connect(url string) error {
|
||||
url, err := NormalizeURL(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for existing connection in pool
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return errors.NewPoolError("pool is closed")
|
||||
}
|
||||
_, exists := p.peers[url]
|
||||
p.mu.Unlock()
|
||||
|
||||
if exists {
|
||||
return errors.NewPoolError("connection already exists")
|
||||
}
|
||||
|
||||
// Create new connection
|
||||
var logger *slog.Logger
|
||||
if p.logger != nil {
|
||||
logger = p.logger.With("url", url)
|
||||
}
|
||||
conn, err := NewConnection(url, p.config.ConnectionConfig, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.dialer = p.dialer
|
||||
|
||||
// Attempt to connect
|
||||
if err := conn.Connect(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
// The pool closed while this connection was established.
|
||||
p.mu.Unlock()
|
||||
conn.Close()
|
||||
return errors.NewPoolError("pool is closed")
|
||||
}
|
||||
|
||||
// Add connection to pool
|
||||
stop := make(chan struct{})
|
||||
if _, exists := p.peers[url]; exists {
|
||||
// Another process connected to this url while this one was connecting
|
||||
// Discard this connection and retain the existing one
|
||||
p.mu.Unlock()
|
||||
conn.Close()
|
||||
return errors.NewPoolError("connection already exists")
|
||||
}
|
||||
p.peers[url] = &peer{conn: conn, stop: stop}
|
||||
p.mu.Unlock()
|
||||
|
||||
// TODO: start this connection's incoming message forwarder
|
||||
|
||||
select {
|
||||
case p.events <- PoolEvent{ID: url, Kind: EventConnected}:
|
||||
case <-p.done:
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *InitiatorPool) Remove(url string) error {
|
||||
url, err := NormalizeURL(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.removePeer(url)
|
||||
}
|
||||
|
||||
func (p *InitiatorPool) Send(url string, data []byte) error {
|
||||
url, err := NormalizeURL(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.send(url, data)
|
||||
}
|
||||
|
||||
// Responder Pool
|
||||
|
||||
type ResponderPool struct {
|
||||
*pool
|
||||
idGenerator func() string
|
||||
}
|
||||
|
||||
func (p *ResponderPool) Peers() map[string]*peer {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *ResponderPool) Inbox() chan InboxMessage {
|
||||
return p.inbox
|
||||
}
|
||||
|
||||
func (p *ResponderPool) Events() chan PoolEvent {
|
||||
return p.events
|
||||
}
|
||||
|
||||
func (p *ResponderPool) Errors() chan error {
|
||||
return p.errors
|
||||
}
|
||||
|
||||
func (p *ResponderPool) Close() {
|
||||
p.closeAll()
|
||||
}
|
||||
225
transport/config.go
Normal file
225
transport/config.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type CloseHandler func(code int, text string) error
|
||||
|
||||
type ConnectionConfig struct {
|
||||
CloseHandler CloseHandler
|
||||
WriteTimeout time.Duration
|
||||
Retry *RetryConfig
|
||||
}
|
||||
|
||||
type RetryConfig struct {
|
||||
MaxRetries int
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
JitterFactor float64
|
||||
}
|
||||
|
||||
type ConnectionOption func(*ConnectionConfig) error
|
||||
|
||||
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
|
||||
conf := GetDefaultConnectionConfig()
|
||||
if err := applyConnectionOptions(conf, options...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ValidateConnectionConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func GetDefaultConnectionConfig() *ConnectionConfig {
|
||||
return &ConnectionConfig{
|
||||
CloseHandler: nil,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
Retry: GetDefaultRetryConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
func GetDefaultRetryConfig() *RetryConfig {
|
||||
return &RetryConfig{
|
||||
MaxRetries: 0, // Infinite retries
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 5 * time.Second,
|
||||
JitterFactor: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOption) error {
|
||||
for _, option := range options {
|
||||
if err := option(config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateConnectionConfig(config *ConnectionConfig) error {
|
||||
err := validateWriteTimeout(config.WriteTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if config.Retry != nil {
|
||||
err = validateMaxRetries(config.Retry.MaxRetries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateInitialDelay(config.Retry.InitialDelay)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateMaxDelay(config.Retry.MaxDelay)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateJitterFactor(config.Retry.JitterFactor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if config.Retry.InitialDelay > config.Retry.MaxDelay {
|
||||
return NewConfigError("initial delay may not exceed maximum delay")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateWriteTimeout(value time.Duration) error {
|
||||
if value < 0 {
|
||||
return InvalidWriteTimeout
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMaxRetries(value int) error {
|
||||
if value < 0 {
|
||||
return InvalidRetryMaxRetries
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateInitialDelay(value time.Duration) error {
|
||||
if value <= 0 {
|
||||
return InvalidRetryInitialDelay
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMaxDelay(value time.Duration) error {
|
||||
if value <= 0 {
|
||||
return InvalidRetryMaxDelay
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateJitterFactor(value float64) error {
|
||||
if value < 0.0 || value > 1.0 {
|
||||
return InvalidRetryJitterFactor
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WithCloseHandler(handler CloseHandler) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
c.CloseHandler = handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// When WriteTimeout is set to zero, read timeouts are disabled.
|
||||
func WithWriteTimeout(value time.Duration) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
err := validateWriteTimeout(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.WriteTimeout = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetry enables retry with default parameters (infinite retries,
|
||||
// 1s initial delay, 5s max delay, 0.5 jitter factor).
|
||||
//
|
||||
// If passed after granular retry options (WithRetryMaxRetries, etc.),
|
||||
// it will overwrite them. Use either WithRetry alone or the granular
|
||||
// options; not both.
|
||||
func WithRetry() ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryMaxRetries(value int) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateMaxRetries(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.MaxRetries = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryInitialDelay(value time.Duration) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateInitialDelay(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.InitialDelay = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryMaxDelay(value time.Duration) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateMaxDelay(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.MaxDelay = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetryJitterFactor(value float64) ConnectionOption {
|
||||
return func(c *ConnectionConfig) error {
|
||||
if c.Retry == nil {
|
||||
c.Retry = GetDefaultRetryConfig()
|
||||
}
|
||||
|
||||
err := validateJitterFactor(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Retry.JitterFactor = value
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -72,7 +71,7 @@ func TestApplyConnectionOptions(t *testing.T) {
|
||||
WithRetryMaxRetries(-10),
|
||||
)
|
||||
|
||||
assert.ErrorIs(t, err, errors.InvalidRetryMaxRetries)
|
||||
assert.ErrorIs(t, err, InvalidRetryMaxRetries)
|
||||
}
|
||||
|
||||
// Option Tests
|
||||
@@ -103,7 +102,7 @@ func TestWithWriteTimeout(t *testing.T) {
|
||||
conf = &ConnectionConfig{}
|
||||
opt = WithWriteTimeout(-30)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, errors.InvalidWriteTimeout)
|
||||
assert.ErrorIs(t, err, InvalidWriteTimeout)
|
||||
assert.ErrorContains(t, err, "write timeout cannot be negative")
|
||||
}
|
||||
|
||||
@@ -132,7 +131,7 @@ func TestWithRetry(t *testing.T) {
|
||||
// negative disallowed
|
||||
opt = WithRetryMaxRetries(-10)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, errors.InvalidRetryMaxRetries)
|
||||
assert.ErrorIs(t, err, InvalidRetryMaxRetries)
|
||||
assert.ErrorContains(t, err, "max retry count cannot be negative")
|
||||
})
|
||||
|
||||
@@ -146,13 +145,13 @@ func TestWithRetry(t *testing.T) {
|
||||
// zero disallowed
|
||||
opt = WithRetryInitialDelay(0 * time.Second)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, errors.InvalidRetryInitialDelay)
|
||||
assert.ErrorIs(t, err, InvalidRetryInitialDelay)
|
||||
assert.ErrorContains(t, err, "initial delay must be positive")
|
||||
|
||||
// negative disallowed
|
||||
opt = WithRetryInitialDelay(-10 * time.Second)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, errors.InvalidRetryInitialDelay)
|
||||
assert.ErrorIs(t, err, InvalidRetryInitialDelay)
|
||||
})
|
||||
|
||||
t.Run("with max delay", func(t *testing.T) {
|
||||
@@ -165,13 +164,13 @@ func TestWithRetry(t *testing.T) {
|
||||
// zero disallowed
|
||||
opt = WithRetryMaxDelay(0 * time.Second)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, errors.InvalidRetryMaxDelay)
|
||||
assert.ErrorIs(t, err, InvalidRetryMaxDelay)
|
||||
assert.ErrorContains(t, err, "max delay must be positive")
|
||||
|
||||
// negative disallowed
|
||||
opt = WithRetryMaxDelay(-10 * time.Second)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, errors.InvalidRetryMaxDelay)
|
||||
assert.ErrorIs(t, err, InvalidRetryMaxDelay)
|
||||
})
|
||||
|
||||
t.Run("with jitter factor", func(t *testing.T) {
|
||||
@@ -185,13 +184,13 @@ func TestWithRetry(t *testing.T) {
|
||||
// negative disallowed
|
||||
opt = WithRetryJitterFactor(-1)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, errors.InvalidRetryJitterFactor)
|
||||
assert.ErrorIs(t, err, InvalidRetryJitterFactor)
|
||||
assert.ErrorContains(t, err, "jitter factor must be between 0.0 and 1.0")
|
||||
|
||||
// >1 disallowed
|
||||
opt = WithRetryJitterFactor(1.1)
|
||||
err = applyConnectionOptions(conf, opt)
|
||||
assert.ErrorIs(t, err, errors.InvalidRetryJitterFactor)
|
||||
assert.ErrorIs(t, err, InvalidRetryJitterFactor)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -239,7 +238,7 @@ func TestValidateConnectionConfig(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validateConnectionConfig(&tc.conf)
|
||||
err := ValidateConnectionConfig(&tc.conf)
|
||||
|
||||
if tc.wantErr != nil || tc.wantErrText != "" {
|
||||
if tc.wantErr != nil {
|
||||
@@ -1,14 +1,14 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
@@ -38,8 +38,8 @@ func (s ConnectionState) String() string {
|
||||
|
||||
type Connection struct {
|
||||
url *url.URL
|
||||
dialer Dialer
|
||||
socket Socket
|
||||
dialer types.Dialer
|
||||
socket types.Socket
|
||||
config *ConnectionConfig
|
||||
logger *slog.Logger
|
||||
|
||||
@@ -60,7 +60,7 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger)
|
||||
config = GetDefaultConnectionConfig()
|
||||
}
|
||||
|
||||
if err := validateConnectionConfig(config); err != nil {
|
||||
if err := ValidateConnectionConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -85,16 +85,16 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func NewConnectionFromSocket(socket Socket, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
|
||||
func NewConnectionFromSocket(socket types.Socket, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
|
||||
if socket == nil {
|
||||
return nil, errors.NewConnectionError("socket cannot be nil")
|
||||
return nil, NewConnectionError("socket cannot be nil")
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = GetDefaultConnectionConfig()
|
||||
}
|
||||
|
||||
if err := validateConnectionConfig(config); err != nil {
|
||||
if err := ValidateConnectionConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -126,11 +126,11 @@ func (c *Connection) Connect() error {
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.socket != nil {
|
||||
return errors.NewConnectionError("connection already has socket")
|
||||
return NewConnectionError("connection already has socket")
|
||||
}
|
||||
|
||||
if c.closed {
|
||||
return errors.NewConnectionError("connection is closed")
|
||||
return NewConnectionError("connection is closed")
|
||||
}
|
||||
|
||||
if c.logger != nil {
|
||||
@@ -177,7 +177,7 @@ func (c *Connection) startReader() {
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
var closeErr *websocket.CloseError
|
||||
if stderrors.As(err, &closeErr) {
|
||||
if errors.As(err, &closeErr) {
|
||||
switch closeErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
|
||||
c.logger.Info("connection closed by peer",
|
||||
@@ -263,16 +263,16 @@ func (c *Connection) Send(data []byte) error {
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.closed {
|
||||
return errors.NewConnectionError("connection closed")
|
||||
return NewConnectionError("connection closed")
|
||||
}
|
||||
|
||||
select {
|
||||
case c.outgoing <- data:
|
||||
return nil
|
||||
case <-c.done:
|
||||
return errors.NewConnectionError("connection closing")
|
||||
return NewConnectionError("connection closing")
|
||||
default:
|
||||
return errors.NewConnectionError("outgoing queue full")
|
||||
return NewConnectionError("outgoing queue full")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -337,3 +337,7 @@ func (c *Connection) State() ConnectionState {
|
||||
defer c.mu.RUnlock()
|
||||
return c.state
|
||||
}
|
||||
|
||||
func (c *Connection) SetDialer(d types.Dialer) {
|
||||
c.dialer = d
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
@@ -38,7 +39,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
|
||||
t.Run("socket close error does not propagate", func(t *testing.T) {
|
||||
expectedErr := fmt.Errorf("socket close failed")
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.CloseFunc = func() error {
|
||||
return expectedErr
|
||||
}
|
||||
@@ -64,7 +65,8 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick, "errors channel should close")
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
|
||||
"errors channel should close")
|
||||
})
|
||||
|
||||
t.Run("send fails after close", func(t *testing.T) {
|
||||
@@ -86,7 +88,8 @@ func TestConnectedConnectionClose(t *testing.T) {
|
||||
|
||||
// Send a message to ensure reader loop is blocking
|
||||
canary := []byte("canary")
|
||||
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: canary}
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.TextMessage, Data: canary}
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
@@ -95,7 +98,7 @@ func TestConnectedConnectionClose(t *testing.T) {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
conn.Close()
|
||||
assert.Equal(t, StateClosed, conn.State())
|
||||
@@ -119,9 +122,9 @@ func TestConnectedConnectionClose(t *testing.T) {
|
||||
conn, _, incomingData, _ := setupTestConnection(t, nil)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
incomingData <- mockIncomingData{
|
||||
msgType: websocket.TextMessage,
|
||||
data: []byte(fmt.Sprintf("in-%d", i)),
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.TextMessage,
|
||||
Data: []byte(fmt.Sprintf("in-%d", i)),
|
||||
}
|
||||
conn.Send([]byte(fmt.Sprintf("out-%d", i)))
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
@@ -17,13 +17,13 @@ func TestStartReader(t *testing.T) {
|
||||
defer conn.Close()
|
||||
|
||||
testData := []byte("hello")
|
||||
incomingData <- mockIncomingData{
|
||||
msgType: websocket.TextMessage,
|
||||
data: testData,
|
||||
err: nil,
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.TextMessage,
|
||||
Data: testData,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
expectIncoming(t, conn, testData)
|
||||
honeybeetest.ExpectIncoming(t, conn.Incoming(), testData)
|
||||
})
|
||||
|
||||
t.Run("binary messages route to incoming channel", func(t *testing.T) {
|
||||
@@ -31,13 +31,13 @@ func TestStartReader(t *testing.T) {
|
||||
defer conn.Close()
|
||||
|
||||
testData := []byte{0x00, 0x01, 0x02}
|
||||
incomingData <- mockIncomingData{
|
||||
msgType: websocket.BinaryMessage,
|
||||
data: testData,
|
||||
err: nil,
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.BinaryMessage,
|
||||
Data: testData,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
expectIncoming(t, conn, testData)
|
||||
honeybeetest.ExpectIncoming(t, conn.Incoming(), testData)
|
||||
})
|
||||
|
||||
t.Run("multiple messages processed sequentially", func(t *testing.T) {
|
||||
@@ -46,20 +46,21 @@ func TestStartReader(t *testing.T) {
|
||||
|
||||
messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")}
|
||||
for _, msg := range messages {
|
||||
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: msg, err: nil}
|
||||
incomingData <- honeybeetest.MockIncomingData{
|
||||
MsgType: websocket.TextMessage, Data: msg, Err: nil}
|
||||
}
|
||||
|
||||
for _, expected := range messages {
|
||||
expectIncoming(t, conn, expected)
|
||||
honeybeetest.ExpectIncoming(t, conn.Incoming(), expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reader exits on socket read error", func(t *testing.T) {
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.once.Do(func() {
|
||||
close(mockSocket.closed)
|
||||
mockSocket.Once.Do(func() {
|
||||
close(mockSocket.Closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
@@ -80,11 +81,11 @@ func TestStartReader(t *testing.T) {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return conn.State() == StateClosed
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -97,7 +98,7 @@ func TestStartWriter(t *testing.T) {
|
||||
err := conn.Send(testData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectWrite(t, outgoingData, websocket.TextMessage, testData)
|
||||
honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, testData)
|
||||
})
|
||||
|
||||
t.Run("multiple messages processed sequentially", func(t *testing.T) {
|
||||
@@ -111,7 +112,7 @@ func TestStartWriter(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, expected := range messages {
|
||||
expectWrite(t, outgoingData, websocket.TextMessage, expected)
|
||||
honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, expected)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -122,12 +123,12 @@ func TestStartWriter(t *testing.T) {
|
||||
|
||||
config := &ConnectionConfig{WriteTimeout: 0}
|
||||
|
||||
outgoingData := make(chan mockOutgoingData, 10)
|
||||
mockSocket := NewMockSocket()
|
||||
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.once.Do(func() {
|
||||
close(mockSocket.closed)
|
||||
mockSocket.Once.Do(func() {
|
||||
close(mockSocket.Closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
@@ -140,8 +141,9 @@ func TestStartWriter(t *testing.T) {
|
||||
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
select {
|
||||
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||
case <-mockSocket.closed:
|
||||
case outgoingData <- honeybeetest.MockOutgoingData{
|
||||
MsgType: msgType, Data: data}:
|
||||
case <-mockSocket.Closed:
|
||||
return io.EOF
|
||||
}
|
||||
return nil
|
||||
@@ -161,19 +163,19 @@ func TestStartWriter(t *testing.T) {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, negativeTestTimeout, testTick,
|
||||
}, honeybeetest.NegativeTestTimeout, honeybeetest.TestTick,
|
||||
"SetWriteDeadline should not be called when timeout is zero")
|
||||
})
|
||||
|
||||
t.Run("write timeout sets deadline when positive", func(t *testing.T) {
|
||||
config := &ConnectionConfig{WriteTimeout: 30 * time.Millisecond}
|
||||
|
||||
outgoingData := make(chan mockOutgoingData, 10)
|
||||
mockSocket := NewMockSocket()
|
||||
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.once.Do(func() {
|
||||
close(mockSocket.closed)
|
||||
mockSocket.Once.Do(func() {
|
||||
close(mockSocket.Closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
@@ -186,8 +188,9 @@ func TestStartWriter(t *testing.T) {
|
||||
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
select {
|
||||
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
|
||||
case <-mockSocket.closed:
|
||||
case outgoingData <- honeybeetest.MockOutgoingData{
|
||||
MsgType: msgType, Data: data}:
|
||||
case <-mockSocket.Closed:
|
||||
return io.EOF
|
||||
}
|
||||
return nil
|
||||
@@ -207,18 +210,18 @@ func TestStartWriter(t *testing.T) {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick,
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
|
||||
"SetWriteDeadline should be called when timeout is positive")
|
||||
})
|
||||
|
||||
t.Run("writer exits on deadline error", func(t *testing.T) {
|
||||
config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond}
|
||||
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.once.Do(func() {
|
||||
close(mockSocket.closed)
|
||||
mockSocket.Once.Do(func() {
|
||||
close(mockSocket.Closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
@@ -242,15 +245,15 @@ func TestStartWriter(t *testing.T) {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return conn.State() == StateClosed
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
|
||||
t.Run("writer exits on socket write error", func(t *testing.T) {
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
|
||||
writeErr := fmt.Errorf("write failed")
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
@@ -271,45 +274,12 @@ func TestStartWriter(t *testing.T) {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return conn.State() == StateClosed
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
func expectIncoming(t *testing.T, conn *Connection, expected []byte) {
|
||||
t.Helper()
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case received := <-conn.Incoming():
|
||||
return bytes.Equal(received, expected)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
}
|
||||
|
||||
func expectWrite(t *testing.T, outgoingData chan mockOutgoingData, msgType int, expected []byte) {
|
||||
t.Helper()
|
||||
|
||||
var call mockOutgoingData
|
||||
found := assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case received := <-outgoingData:
|
||||
call = received
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
|
||||
if found {
|
||||
|
||||
assert.Equal(t, msgType, call.msgType)
|
||||
assert.Equal(t, expected, call.data)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,9 +1,12 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -36,7 +39,7 @@ func TestConnectionState(t *testing.T) {
|
||||
assert.Equal(t, StateDisconnected, conn.State())
|
||||
|
||||
// Test state after FromSocket (should be Connected)
|
||||
conn2, _ := NewConnectionFromSocket(NewMockSocket(), nil, nil)
|
||||
conn2, _ := NewConnectionFromSocket(honeybeetest.NewMockSocket(), nil, nil)
|
||||
assert.Equal(t, StateConnected, conn2.State())
|
||||
|
||||
// Test state after close
|
||||
@@ -126,7 +129,7 @@ func TestNewConnection(t *testing.T) {
|
||||
func TestNewConnectionFromSocket(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
socket Socket
|
||||
socket types.Socket
|
||||
config *ConnectionConfig
|
||||
wantErr bool
|
||||
wantErrText string
|
||||
@@ -140,17 +143,17 @@ func TestNewConnectionFromSocket(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "valid socket with nil config",
|
||||
socket: NewMockSocket(),
|
||||
socket: honeybeetest.NewMockSocket(),
|
||||
config: nil,
|
||||
},
|
||||
{
|
||||
name: "valid socket with valid config",
|
||||
socket: NewMockSocket(),
|
||||
socket: honeybeetest.NewMockSocket(),
|
||||
config: &ConnectionConfig{WriteTimeout: 30 * time.Second},
|
||||
},
|
||||
{
|
||||
name: "invalid config",
|
||||
socket: NewMockSocket(),
|
||||
socket: honeybeetest.NewMockSocket(),
|
||||
config: &ConnectionConfig{
|
||||
Retry: &RetryConfig{
|
||||
InitialDelay: 10 * time.Second,
|
||||
@@ -162,7 +165,7 @@ func TestNewConnectionFromSocket(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "close handler set when provided",
|
||||
socket: NewMockSocket(),
|
||||
socket: honeybeetest.NewMockSocket(),
|
||||
config: &ConnectionConfig{
|
||||
CloseHandler: func(code int, text string) error {
|
||||
return nil
|
||||
@@ -176,7 +179,7 @@ func TestNewConnectionFromSocket(t *testing.T) {
|
||||
// track if SetCloseHandler was called
|
||||
closeHandlerSet := false
|
||||
if tc.socket != nil {
|
||||
mockSocket := tc.socket.(*MockSocket)
|
||||
mockSocket := tc.socket.(*honeybeetest.MockSocket)
|
||||
originalSetCloseHandler := mockSocket.SetCloseHandlerFunc
|
||||
|
||||
// wrapper around the original handler function
|
||||
@@ -234,7 +237,7 @@ func TestConnect(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn.socket = NewMockSocket()
|
||||
conn.socket = honeybeetest.NewMockSocket()
|
||||
|
||||
err = conn.Connect()
|
||||
assert.Error(t, err)
|
||||
@@ -258,16 +261,16 @@ func TestConnect(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
outgoingData := make(chan mockOutgoingData, 10)
|
||||
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
|
||||
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
outgoingData <- mockOutgoingData{msgType: msgType, data: data}
|
||||
outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}
|
||||
return nil
|
||||
}
|
||||
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
@@ -283,11 +286,11 @@ func TestConnect(t *testing.T) {
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case msg := <-outgoingData:
|
||||
return bytes.Equal(msg.data, testData)
|
||||
return bytes.Equal(msg.Data, testData)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
conn.Close()
|
||||
})
|
||||
@@ -305,13 +308,13 @@ func TestConnect(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
attemptCount := 0
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
attemptCount++
|
||||
if attemptCount < 3 {
|
||||
return nil, nil, fmt.Errorf("dial failed")
|
||||
}
|
||||
return NewMockSocket(), nil, nil
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
@@ -336,8 +339,8 @@ func TestConnect(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", config, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return nil, nil, fmt.Errorf("dial failed")
|
||||
},
|
||||
}
|
||||
@@ -355,10 +358,10 @@ func TestConnect(t *testing.T) {
|
||||
assert.Equal(t, StateDisconnected, conn.State())
|
||||
|
||||
stateDuringDial := StateDisconnected
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
stateDuringDial = conn.state
|
||||
return NewMockSocket(), nil, nil
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
@@ -381,13 +384,13 @@ func TestConnect(t *testing.T) {
|
||||
conn, err := NewConnection("ws://test", config, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) {
|
||||
handlerSet = true
|
||||
}
|
||||
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
@@ -431,4 +434,53 @@ func TestConnectionErrors(t *testing.T) {
|
||||
assert.Equal(t, testErr, received)
|
||||
}
|
||||
|
||||
// Connect() tests
|
||||
// Test helpers
|
||||
|
||||
func setupTestConnection(t *testing.T, config *ConnectionConfig) (
|
||||
conn *Connection,
|
||||
mockSocket *honeybeetest.MockSocket,
|
||||
incomingData chan honeybeetest.MockIncomingData,
|
||||
outgoingData chan honeybeetest.MockOutgoingData,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
incomingData = make(chan honeybeetest.MockIncomingData, 10)
|
||||
outgoingData = make(chan honeybeetest.MockOutgoingData, 10)
|
||||
|
||||
mockSocket = honeybeetest.NewMockSocket()
|
||||
|
||||
mockSocket.CloseFunc = func() error {
|
||||
mockSocket.Once.Do(func() {
|
||||
close(mockSocket.Closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wire ReadMessage to pull from incomingData channel
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
select {
|
||||
case data := <-incomingData:
|
||||
return data.MsgType, data.Data, data.Err
|
||||
case <-mockSocket.Closed:
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// Wire WriteMessage to push to outgoingData channel
|
||||
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
|
||||
select {
|
||||
case outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}:
|
||||
return nil
|
||||
case <-mockSocket.Closed:
|
||||
return io.EOF
|
||||
default:
|
||||
return fmt.Errorf("mock outgoing chanel unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
conn, err = NewConnectionFromSocket(mockSocket, config, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return conn, mockSocket, incomingData, outgoingData
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package errors
|
||||
package transport
|
||||
|
||||
import "errors"
|
||||
import "fmt"
|
||||
@@ -8,13 +8,11 @@ var (
|
||||
InvalidProtocol = errors.New("URL must use ws:// or wss:// scheme")
|
||||
|
||||
// Configuration Errors
|
||||
InvalidIdleTimeout = errors.New("idle timeout cannot be negative")
|
||||
InvalidWriteTimeout = errors.New("write timeout cannot be negative")
|
||||
InvalidRetryMaxRetries = errors.New("max retry count cannot be negative")
|
||||
InvalidRetryInitialDelay = errors.New("initial delay must be positive")
|
||||
InvalidRetryMaxDelay = errors.New("max delay must be positive")
|
||||
InvalidRetryJitterFactor = errors.New("jitter factor must be between 0.0 and 1.0")
|
||||
InvalidMaxQueueSize = errors.New("maximum queue size cannot be negative")
|
||||
)
|
||||
|
||||
func NewConfigError(text string) error {
|
||||
@@ -24,7 +22,3 @@ func NewConfigError(text string) error {
|
||||
func NewConnectionError(text string) error {
|
||||
return fmt.Errorf("connection error: %s", text)
|
||||
}
|
||||
|
||||
func NewPoolError(text string) error {
|
||||
return fmt.Errorf("pool error: %s", text)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -136,15 +138,15 @@ func toInt64(v any) (int64, bool) {
|
||||
|
||||
func TestConnectLogging(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
conn, err := NewConnection("ws://test", nil, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockSocket := NewMockSocket()
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
@@ -167,7 +169,7 @@ func TestConnectLogging(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("max retries failure", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
config := &ConnectionConfig{
|
||||
@@ -183,8 +185,8 @@ func TestConnectLogging(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
dialErr := fmt.Errorf("dial error")
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return nil, nil, dialErr
|
||||
},
|
||||
}
|
||||
@@ -210,7 +212,7 @@ func TestConnectLogging(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("success after retry", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
config := &ConnectionConfig{
|
||||
@@ -227,13 +229,13 @@ func TestConnectLogging(t *testing.T) {
|
||||
|
||||
attemptCount := 0
|
||||
dialErr := fmt.Errorf("dial error")
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
attemptCount++
|
||||
if attemptCount < 3 {
|
||||
return nil, nil, dialErr
|
||||
}
|
||||
return NewMockSocket(), nil, nil
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
conn.dialer = mockDialer
|
||||
@@ -261,10 +263,10 @@ func TestConnectLogging(t *testing.T) {
|
||||
|
||||
func TestCloseLogging(t *testing.T) {
|
||||
t.Run("normal close", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -273,7 +275,7 @@ func TestCloseLogging(t *testing.T) {
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelInfo, "closed") != nil
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
@@ -286,11 +288,11 @@ func TestCloseLogging(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("close with socket error", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
closeErr := fmt.Errorf("close error")
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.CloseFunc = func() error {
|
||||
return closeErr
|
||||
}
|
||||
@@ -303,7 +305,7 @@ func TestCloseLogging(t *testing.T) {
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelError, "socket close failed") != nil
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
@@ -318,10 +320,10 @@ func TestCloseLogging(t *testing.T) {
|
||||
|
||||
func TestReaderLogging(t *testing.T) {
|
||||
t.Run("clean close by peer", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
return 0, nil, &websocket.CloseError{
|
||||
Code: websocket.CloseNormalClosure,
|
||||
@@ -336,7 +338,7 @@ func TestReaderLogging(t *testing.T) {
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer") != nil
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
record := findLogRecord(mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer")
|
||||
assert.NotNil(t, record)
|
||||
@@ -346,10 +348,10 @@ func TestReaderLogging(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("unexpected close", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
return 0, nil, &websocket.CloseError{
|
||||
Code: websocket.CloseProtocolError,
|
||||
@@ -364,7 +366,7 @@ func TestReaderLogging(t *testing.T) {
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelError, "unexpected close") != nil
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
record := findLogRecord(mockHandler.GetRecords(), slog.LevelError, "unexpected close")
|
||||
assert.NotNil(t, record)
|
||||
@@ -374,10 +376,10 @@ func TestReaderLogging(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("read error", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
@@ -389,19 +391,19 @@ func TestReaderLogging(t *testing.T) {
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelError, "read error") != nil
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriterLogging(t *testing.T) {
|
||||
t.Run("write deadline error", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond}
|
||||
|
||||
deadlineErr := fmt.Errorf("deadline error")
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.SetWriteDeadlineFunc = func(time.Time) error {
|
||||
return deadlineErr
|
||||
}
|
||||
@@ -415,7 +417,7 @@ func TestWriterLogging(t *testing.T) {
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelError, "write deadline error") != nil
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
@@ -427,11 +429,11 @@ func TestWriterLogging(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("write message error", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
logger := slog.New(mockHandler)
|
||||
|
||||
writeErr := fmt.Errorf("write error")
|
||||
mockSocket := NewMockSocket()
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockSocket.WriteMessageFunc = func(int, []byte) error {
|
||||
return writeErr
|
||||
}
|
||||
@@ -445,7 +447,7 @@ func TestWriterLogging(t *testing.T) {
|
||||
assert.Eventually(t, func() bool {
|
||||
return findLogRecord(
|
||||
mockHandler.GetRecords(), slog.LevelError, "write error") != nil
|
||||
}, testTimeout, testTick)
|
||||
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
|
||||
|
||||
records := mockHandler.GetRecords()
|
||||
|
||||
@@ -459,14 +461,14 @@ func TestWriterLogging(t *testing.T) {
|
||||
|
||||
func TestLoggingDisabled(t *testing.T) {
|
||||
t.Run("nil logger produces no logs", func(t *testing.T) {
|
||||
mockHandler := newMockSlogHandler()
|
||||
mockHandler := honeybeetest.NewMockSlogHandler()
|
||||
|
||||
conn, err := NewConnection("ws://test", nil, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockSocket := NewMockSocket()
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockSocket := honeybeetest.NewMockSocket()
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return mockSocket, nil, nil
|
||||
},
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"math"
|
||||
@@ -1,4 +1,4 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -1,19 +1,15 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Dialer interface {
|
||||
Dial(urlStr string, requestHeader http.Header) (Socket, *http.Response, error)
|
||||
}
|
||||
|
||||
func NewDialer() Dialer {
|
||||
func NewDialer() types.Dialer {
|
||||
return NewGorillaDialer()
|
||||
}
|
||||
|
||||
@@ -35,36 +31,26 @@ func NewGorillaDialer() *GorillaDialer {
|
||||
func (d *GorillaDialer) Dial(
|
||||
urlStr string, requestHeader http.Header,
|
||||
) (
|
||||
Socket, *http.Response, error,
|
||||
types.Socket, *http.Response, error,
|
||||
) {
|
||||
conn, resp, err := d.Dialer.Dial(urlStr, requestHeader)
|
||||
return conn, resp, err
|
||||
}
|
||||
|
||||
type Socket interface {
|
||||
WriteMessage(messageType int, data []byte) error
|
||||
ReadMessage() (messageType int, p []byte, err error)
|
||||
Close() error
|
||||
|
||||
SetReadDeadline(t time.Time) error
|
||||
SetWriteDeadline(t time.Time) error
|
||||
SetCloseHandler(h func(code int, text string) error)
|
||||
}
|
||||
|
||||
func AcquireSocket(
|
||||
retryMgr *RetryManager,
|
||||
dialer Dialer,
|
||||
dialer types.Dialer,
|
||||
urlStr string,
|
||||
logger *slog.Logger,
|
||||
) (Socket, *http.Response, error) {
|
||||
) (types.Socket, *http.Response, error) {
|
||||
if retryMgr == nil {
|
||||
return nil, nil, errors.NewConnectionError("retry manager cannot be nil")
|
||||
return nil, nil, NewConnectionError("retry manager cannot be nil")
|
||||
}
|
||||
if dialer == nil {
|
||||
return nil, nil, errors.NewConnectionError("dialer cannot be nil")
|
||||
return nil, nil, NewConnectionError("dialer cannot be nil")
|
||||
}
|
||||
if urlStr == "" {
|
||||
return nil, nil, errors.NewConnectionError("URL cannot be empty")
|
||||
return nil, nil, NewConnectionError("URL cannot be empty")
|
||||
}
|
||||
|
||||
for {
|
||||
@@ -1,7 +1,9 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
|
||||
"git.wisehodl.dev/jay/go-honeybee/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"testing"
|
||||
@@ -60,14 +62,14 @@ func TestAcquireSocket(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
attemptIndex := 0
|
||||
mockDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
mockDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
err := tc.mockRuns[attemptIndex]
|
||||
attemptIndex++
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return NewMockSocket(), nil, nil
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -93,9 +95,9 @@ func TestAcquireSocket(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAcquireSocketGuards(t *testing.T) {
|
||||
validDialer := &MockDialer{
|
||||
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||
return NewMockSocket(), nil, nil
|
||||
validDialer := &honeybeetest.MockDialer{
|
||||
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
|
||||
return honeybeetest.NewMockSocket(), nil, nil
|
||||
},
|
||||
}
|
||||
validRetryMgr := NewRetryManager(GetDefaultRetryConfig())
|
||||
@@ -103,7 +105,7 @@ func TestAcquireSocketGuards(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
retryMgr *RetryManager
|
||||
dialer Dialer
|
||||
dialer types.Dialer
|
||||
url string
|
||||
wantErr string
|
||||
}{
|
||||
@@ -1,10 +1,8 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||
)
|
||||
|
||||
func ParseURL(urlStr string) (*url.URL, error) {
|
||||
@@ -14,7 +12,7 @@ func ParseURL(urlStr string) (*url.URL, error) {
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" {
|
||||
return nil, errors.InvalidProtocol
|
||||
return nil, InvalidProtocol
|
||||
}
|
||||
|
||||
return parsedURL, nil
|
||||
@@ -1,7 +1,6 @@
|
||||
package honeybee
|
||||
package transport
|
||||
|
||||
import (
|
||||
"git.wisehodl.dev/jay/go-honeybee/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
@@ -41,17 +40,17 @@ func TestParseURL(t *testing.T) {
|
||||
{
|
||||
name: "http scheme rejected",
|
||||
url: "http://example.com",
|
||||
wantErr: errors.InvalidProtocol,
|
||||
wantErr: InvalidProtocol,
|
||||
},
|
||||
{
|
||||
name: "missing scheme",
|
||||
url: "example.com:8080",
|
||||
wantErr: errors.InvalidProtocol,
|
||||
wantErr: InvalidProtocol,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
url: "",
|
||||
wantErr: errors.InvalidProtocol,
|
||||
wantErr: InvalidProtocol,
|
||||
},
|
||||
{
|
||||
name: "malformed url",
|
||||
@@ -161,5 +160,5 @@ func TestNormalizeURL(t *testing.T) {
|
||||
|
||||
func TestNormalizeURLError(t *testing.T) {
|
||||
_, err := NormalizeURL("http://relay.example.com")
|
||||
assert.ErrorIs(t, err, errors.InvalidProtocol)
|
||||
assert.ErrorIs(t, err, InvalidProtocol)
|
||||
}
|
||||
20
types/types.go
Normal file
20
types/types.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Dialer interface {
|
||||
Dial(urlStr string, requestHeader http.Header) (Socket, *http.Response, error)
|
||||
}
|
||||
|
||||
type Socket interface {
|
||||
WriteMessage(messageType int, data []byte) error
|
||||
ReadMessage() (messageType int, p []byte, err error)
|
||||
Close() error
|
||||
|
||||
SetReadDeadline(t time.Time) error
|
||||
SetWriteDeadline(t time.Time) error
|
||||
SetCloseHandler(h func(code int, text string) error)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package honeybee
|
||||
|
||||
import (
|
||||
// "github.com/stretchr/testify/assert"
|
||||
// "testing"
|
||||
// "time"
|
||||
)
|
||||
Reference in New Issue
Block a user