From 3af3696d86f7ac1b0a4bd41339b4611bfa35ad3a Mon Sep 17 00:00:00 2001 From: Jay Date: Fri, 17 Apr 2026 14:53:29 -0400 Subject: [PATCH] Refactored package structure. --- config.go | 422 ------------------ honeybeetest/helpers.go | 64 +++ honeybeetest/mocks.go | 119 +++++ initiator/config.go | 189 ++++++++ .../config_pool_test.go | 47 +- initiator/errors.go | 17 + initiator/pool.go | 248 ++++++++++ pool_test.go => initiator/pool_test.go | 75 ++-- worker.go => initiator/worker.go | 76 ++-- initiator/worker_test.go | 1 + mocks_test.go | 192 -------- pool.go | 304 ------------- transport/config.go | 225 ++++++++++ .../config_test.go | 23 +- connection.go => transport/connection.go | 34 +- .../connection_close_test.go | 19 +- .../connection_goroutine_test.go | 122 ++--- .../connection_send_test.go | 2 +- .../connection_test.go | 108 +++-- {errors => transport}/errors.go | 8 +- logging_test.go => transport/logging_test.go | 76 ++-- retry.go => transport/retry.go | 2 +- retry_test.go => transport/retry_test.go | 2 +- socket.go => transport/socket.go | 32 +- socket_test.go => transport/socket_test.go | 18 +- url.go => transport/url.go | 6 +- url_test.go => transport/url_test.go | 11 +- types/types.go | 20 + worker_test.go | 7 - 29 files changed, 1210 insertions(+), 1259 deletions(-) delete mode 100644 config.go create mode 100644 honeybeetest/helpers.go create mode 100644 honeybeetest/mocks.go create mode 100644 initiator/config.go rename config_pool_test.go => initiator/config_pool_test.go (60%) create mode 100644 initiator/errors.go create mode 100644 initiator/pool.go rename pool_test.go => initiator/pool_test.go (65%) rename worker.go => initiator/worker.go (54%) create mode 100644 initiator/worker_test.go delete mode 100644 mocks_test.go delete mode 100644 pool.go create mode 100644 transport/config.go rename config_connection_test.go => transport/config_test.go (90%) rename connection.go => transport/connection.go (87%) rename connection_close_test.go => transport/connection_close_test.go (85%) rename connection_goroutine_test.go => transport/connection_goroutine_test.go (70%) rename connection_send_test.go => transport/connection_send_test.go (99%) rename connection_test.go => transport/connection_test.go (76%) rename {errors => transport}/errors.go (73%) rename logging_test.go => transport/logging_test.go (84%) rename retry.go => transport/retry.go (98%) rename retry_test.go => transport/retry_test.go (99%) rename socket.go => transport/socket.go (65%) rename socket_test.go => transport/socket_test.go (86%) rename url.go => transport/url.go (86%) rename url_test.go => transport/url_test.go (93%) create mode 100644 types/types.go delete mode 100644 worker_test.go diff --git a/config.go b/config.go deleted file mode 100644 index 37a2650..0000000 --- a/config.go +++ /dev/null @@ -1,422 +0,0 @@ -package honeybee - -import ( - "git.wisehodl.dev/jay/go-honeybee/errors" - "time" -) - -// Types - -type CloseHandler func(code int, text string) error -type WorkerFactory func( - id string, - conn *Connection, - onReconnect func() (*Connection, error), -) Worker - -// Initiator Pool Config - -type InitiatorPoolConfig struct { - ConnectionConfig *ConnectionConfig - WorkerFactory WorkerFactory - WorkerConfig *InitiatorWorkerConfig -} - -type InitiatorPoolOption func(*InitiatorPoolConfig) error - -func NewInitiatorPoolConfig(options ...InitiatorPoolOption) (*InitiatorPoolConfig, error) { - conf := GetDefaultInitiatorPoolConfig() - if err := applyInitiatorPoolOptions(conf, options...); err != nil { - return nil, err - } - if err := validateInitiatorPoolConfig(conf); err != nil { - return nil, err - } - return conf, nil -} - -func GetDefaultInitiatorPoolConfig() *InitiatorPoolConfig { - return &InitiatorPoolConfig{ - ConnectionConfig: nil, - WorkerFactory: nil, - WorkerConfig: nil, - } -} - -func applyInitiatorPoolOptions(config *InitiatorPoolConfig, options ...InitiatorPoolOption) error { - for _, option := range options { - if err := option(config); err != nil { - return err - } - } - return nil -} - -func validateInitiatorPoolConfig(config *InitiatorPoolConfig) error { - var err error - - if config.ConnectionConfig != nil { - err = validateConnectionConfig(config.ConnectionConfig) - if err != nil { - return err - } - } - - if config.WorkerConfig != nil { - err = validateInitiatorWorkerConfig(config.WorkerConfig) - if err != nil { - return err - } - } - - return nil -} - -func WithInitiatorConnectionConfig(cc *ConnectionConfig) InitiatorPoolOption { - return func(c *InitiatorPoolConfig) error { - err := validateConnectionConfig(cc) - if err != nil { - return err - } - c.ConnectionConfig = cc - return nil - } -} - -func WithInitiatorWorkerConfig(wc *InitiatorWorkerConfig) InitiatorPoolOption { - return func(c *InitiatorPoolConfig) error { - err := validateInitiatorWorkerConfig(wc) - if err != nil { - return err - } - c.WorkerConfig = wc - return nil - } -} - -func WithInitiatorWorkerFactory(wf WorkerFactory) InitiatorPoolOption { - return func(c *InitiatorPoolConfig) error { - c.WorkerFactory = wf - return nil - } -} - -// Responder Pool Config - -type ResponderPoolConfig struct { - ConnectionConfig *ConnectionConfig - WorkerFactory WorkerFactory - WorkerConfig *ResponderWorkerConfig -} - -// Connection Config - -type ConnectionConfig struct { - CloseHandler CloseHandler - WriteTimeout time.Duration - Retry *RetryConfig -} - -type RetryConfig struct { - MaxRetries int - InitialDelay time.Duration - MaxDelay time.Duration - JitterFactor float64 -} - -type ConnectionOption func(*ConnectionConfig) error - -func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) { - conf := GetDefaultConnectionConfig() - if err := applyConnectionOptions(conf, options...); err != nil { - return nil, err - } - if err := validateConnectionConfig(conf); err != nil { - return nil, err - } - return conf, nil -} - -func GetDefaultConnectionConfig() *ConnectionConfig { - return &ConnectionConfig{ - CloseHandler: nil, - WriteTimeout: 30 * time.Second, - Retry: GetDefaultRetryConfig(), - } -} - -func GetDefaultRetryConfig() *RetryConfig { - return &RetryConfig{ - MaxRetries: 0, // Infinite retries - InitialDelay: 1 * time.Second, - MaxDelay: 5 * time.Second, - JitterFactor: 0.5, - } -} - -func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOption) error { - for _, option := range options { - if err := option(config); err != nil { - return err - } - } - return nil -} - -func validateConnectionConfig(config *ConnectionConfig) error { - err := validateWriteTimeout(config.WriteTimeout) - if err != nil { - return err - } - - if config.Retry != nil { - err = validateMaxRetries(config.Retry.MaxRetries) - if err != nil { - return err - } - - err = validateInitialDelay(config.Retry.InitialDelay) - if err != nil { - return err - } - - err = validateMaxDelay(config.Retry.MaxDelay) - if err != nil { - return err - } - - err = validateJitterFactor(config.Retry.JitterFactor) - if err != nil { - return err - } - - if config.Retry.InitialDelay > config.Retry.MaxDelay { - return errors.NewConfigError("initial delay may not exceed maximum delay") - } - } - - return nil -} - -func validateWriteTimeout(value time.Duration) error { - if value < 0 { - return errors.InvalidWriteTimeout - } - return nil -} - -func validateMaxRetries(value int) error { - if value < 0 { - return errors.InvalidRetryMaxRetries - } - return nil -} - -func validateInitialDelay(value time.Duration) error { - if value <= 0 { - return errors.InvalidRetryInitialDelay - } - return nil -} - -func validateMaxDelay(value time.Duration) error { - if value <= 0 { - return errors.InvalidRetryMaxDelay - } - return nil -} - -func validateJitterFactor(value float64) error { - if value < 0.0 || value > 1.0 { - return errors.InvalidRetryJitterFactor - } - return nil -} - -func WithCloseHandler(handler CloseHandler) ConnectionOption { - return func(c *ConnectionConfig) error { - c.CloseHandler = handler - return nil - } -} - -// When WriteTimeout is set to zero, read timeouts are disabled. -func WithWriteTimeout(value time.Duration) ConnectionOption { - return func(c *ConnectionConfig) error { - err := validateWriteTimeout(value) - if err != nil { - return err - } - c.WriteTimeout = value - return nil - } -} - -// WithRetry enables retry with default parameters (infinite retries, -// 1s initial delay, 5s max delay, 0.5 jitter factor). -// -// If passed after granular retry options (WithRetryMaxRetries, etc.), -// it will overwrite them. Use either WithRetry alone or the granular -// options; not both. -func WithRetry() ConnectionOption { - return func(c *ConnectionConfig) error { - c.Retry = GetDefaultRetryConfig() - return nil - } -} - -func WithRetryMaxRetries(value int) ConnectionOption { - return func(c *ConnectionConfig) error { - if c.Retry == nil { - c.Retry = GetDefaultRetryConfig() - } - - err := validateMaxRetries(value) - if err != nil { - return err - } - - c.Retry.MaxRetries = value - return nil - } -} - -func WithRetryInitialDelay(value time.Duration) ConnectionOption { - return func(c *ConnectionConfig) error { - if c.Retry == nil { - c.Retry = GetDefaultRetryConfig() - } - - err := validateInitialDelay(value) - if err != nil { - return err - } - - c.Retry.InitialDelay = value - return nil - } -} - -func WithRetryMaxDelay(value time.Duration) ConnectionOption { - return func(c *ConnectionConfig) error { - if c.Retry == nil { - c.Retry = GetDefaultRetryConfig() - } - - err := validateMaxDelay(value) - if err != nil { - return err - } - - c.Retry.MaxDelay = value - return nil - } -} - -func WithRetryJitterFactor(value float64) ConnectionOption { - return func(c *ConnectionConfig) error { - if c.Retry == nil { - c.Retry = GetDefaultRetryConfig() - } - - err := validateJitterFactor(value) - if err != nil { - return err - } - - c.Retry.JitterFactor = value - return nil - } -} - -// Initiator Worker Config - -type InitiatorWorkerConfig struct { - IdleTimeout time.Duration - MaxQueueSize int -} - -type InitiatorWorkerOption func(*InitiatorWorkerConfig) error - -func NewInitiatorWorkerConfig(options ...InitiatorWorkerOption) (*InitiatorWorkerConfig, error) { - conf := GetDefaultInitiatorWorkerConfig() - if err := applyInitiatorWorkerOptions(conf, options...); err != nil { - return nil, err - } - if err := validateInitiatorWorkerConfig(conf); err != nil { - return nil, err - } - return conf, nil -} - -func GetDefaultInitiatorWorkerConfig() *InitiatorWorkerConfig { - return &InitiatorWorkerConfig{ - IdleTimeout: 20 * time.Second, - MaxQueueSize: 0, // disabled by default - } -} - -func applyInitiatorWorkerOptions(config *InitiatorWorkerConfig, options ...InitiatorWorkerOption) error { - for _, option := range options { - if err := option(config); err != nil { - return err - } - } - return nil -} - -func validateInitiatorWorkerConfig(config *InitiatorWorkerConfig) error { - err := validateIdleTimeout(config.IdleTimeout) - if err != nil { - return err - } - - err = validateMaxQueueSize(config.MaxQueueSize) - if err != nil { - return err - } - - return nil -} - -func validateMaxQueueSize(value int) error { - if value < 0 { - return errors.InvalidMaxQueueSize - } - return nil -} - -func validateIdleTimeout(value time.Duration) error { - if value < 0 { - return errors.InvalidIdleTimeout - } - return nil -} - -// When IdleTimeout is set to zero, idle timeouts are disabled. -func WithIdleTimeout(value time.Duration) InitiatorWorkerOption { - return func(c *InitiatorWorkerConfig) error { - err := validateIdleTimeout(value) - if err != nil { - return err - } - c.IdleTimeout = value - return nil - } -} - -// When MaxQueueSize is set to zero, queue limits are disabled. -func WithMaxQueueSize(value int) InitiatorWorkerOption { - return func(c *InitiatorWorkerConfig) error { - err := validateMaxQueueSize(value) - if err != nil { - return err - } - c.MaxQueueSize = value - return nil - } -} - -// Responder Worker Config - -type ResponderWorkerConfig struct{} diff --git a/honeybeetest/helpers.go b/honeybeetest/helpers.go new file mode 100644 index 0000000..4caa402 --- /dev/null +++ b/honeybeetest/helpers.go @@ -0,0 +1,64 @@ +package honeybeetest + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +// Constants + +const ( + TestTimeout = 2 * time.Second + TestTick = 10 * time.Millisecond + NegativeTestTimeout = 100 * time.Millisecond +) + +// Types + +type MockIncomingData struct { + MsgType int + Data []byte + Err error +} + +type MockOutgoingData struct { + MsgType int + Data []byte +} + +// Helpers + +func ExpectIncoming(t *testing.T, incoming <-chan []byte, expected []byte) { + t.Helper() + assert.Eventually(t, func() bool { + select { + case received := <-incoming: + return bytes.Equal(received, expected) + default: + return false + } + }, TestTimeout, TestTick) +} + +func ExpectWrite(t *testing.T, outgoingData chan MockOutgoingData, msgType int, expected []byte) { + t.Helper() + + var call MockOutgoingData + found := assert.Eventually(t, func() bool { + select { + case received := <-outgoingData: + call = received + return true + default: + return false + } + }, TestTimeout, TestTick) + + if found { + + assert.Equal(t, msgType, call.MsgType) + assert.Equal(t, expected, call.Data) + } +} diff --git a/honeybeetest/mocks.go b/honeybeetest/mocks.go new file mode 100644 index 0000000..51a92c8 --- /dev/null +++ b/honeybeetest/mocks.go @@ -0,0 +1,119 @@ +package honeybeetest + +import ( + "context" + "git.wisehodl.dev/jay/go-honeybee/types" + "log/slog" + "net/http" + "sync" + "time" +) + +// Dialer Mocks + +type MockDialer struct { + DialFunc func(string, http.Header) (types.Socket, *http.Response, error) +} + +func (m *MockDialer) Dial(url string, h http.Header) (types.Socket, *http.Response, error) { + return m.DialFunc(url, h) +} + +// Socket Mocks + +type MockSocket struct { + WriteMessageFunc func(int, []byte) error + SetReadDeadlineFunc func(t time.Time) error + SetWriteDeadlineFunc func(t time.Time) error + ReadMessageFunc func() (int, []byte, error) + CloseFunc func() error + SetCloseHandlerFunc func(func(int, string) error) + Closed chan struct{} + Once sync.Once + Mu sync.Mutex +} + +func NewMockSocket() *MockSocket { + return &MockSocket{ + WriteMessageFunc: func(int, []byte) error { return nil }, + ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil }, + CloseFunc: func() error { return nil }, + + SetReadDeadlineFunc: func(time.Time) error { return nil }, + SetWriteDeadlineFunc: func(time.Time) error { return nil }, + SetCloseHandlerFunc: func(func(int, string) error) {}, + + Closed: make(chan struct{}), + } + +} + +func (m *MockSocket) WriteMessage(t int, d []byte) error { + return m.WriteMessageFunc(t, d) +} + +func (m *MockSocket) ReadMessage() (int, []byte, error) { + return m.ReadMessageFunc() +} + +func (m *MockSocket) Close() error { + return m.CloseFunc() +} + +func (m *MockSocket) SetReadDeadline(t time.Time) error { + return m.SetReadDeadlineFunc(t) +} + +func (m *MockSocket) SetWriteDeadline(t time.Time) error { + return m.SetWriteDeadlineFunc(t) +} + +func (m *MockSocket) SetCloseHandler(h func(code int, text string) error) { + m.SetCloseHandlerFunc(h) +} + +// Logging mocks + +type MockSlogHandler struct { + records []slog.Record + mu sync.RWMutex +} + +func NewMockSlogHandler() *MockSlogHandler { + return &MockSlogHandler{ + records: make([]slog.Record, 0), + } +} + +func (m *MockSlogHandler) Handle(ctx context.Context, record slog.Record) error { + m.mu.Lock() + defer m.mu.Unlock() + m.records = append(m.records, record) + return nil +} + +func (m *MockSlogHandler) Enabled(ctx context.Context, level slog.Level) bool { + return true +} + +func (m *MockSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return m +} + +func (m *MockSlogHandler) WithGroup(name string) slog.Handler { + return m +} + +func (m *MockSlogHandler) GetRecords() []slog.Record { + m.mu.RLock() + defer m.mu.RUnlock() + result := make([]slog.Record, len(m.records)) + copy(result, m.records) + return result +} + +func (m *MockSlogHandler) Clear() { + m.mu.Lock() + defer m.mu.Unlock() + m.records = make([]slog.Record, 0) +} diff --git a/initiator/config.go b/initiator/config.go new file mode 100644 index 0000000..59b1552 --- /dev/null +++ b/initiator/config.go @@ -0,0 +1,189 @@ +package initiator + +import ( + "git.wisehodl.dev/jay/go-honeybee/transport" + "time" +) + +// Types + +type WorkerFactory func( + id string, + conn *transport.Connection, + onReconnect func() (*transport.Connection, error), +) Worker + +// Pool Config + +type PoolConfig struct { + ConnectionConfig *transport.ConnectionConfig + WorkerFactory WorkerFactory + WorkerConfig *WorkerConfig +} + +type PoolOption func(*PoolConfig) error + +func NewPoolConfig(options ...PoolOption) (*PoolConfig, error) { + conf := GetDefaultPoolConfig() + if err := applyPoolOptions(conf, options...); err != nil { + return nil, err + } + if err := ValidatePoolConfig(conf); err != nil { + return nil, err + } + return conf, nil +} + +func GetDefaultPoolConfig() *PoolConfig { + return &PoolConfig{ + ConnectionConfig: nil, + WorkerFactory: nil, + WorkerConfig: nil, + } +} + +func applyPoolOptions(config *PoolConfig, options ...PoolOption) error { + for _, option := range options { + if err := option(config); err != nil { + return err + } + } + return nil +} + +func ValidatePoolConfig(config *PoolConfig) error { + var err error + + if config.ConnectionConfig != nil { + err = transport.ValidateConnectionConfig(config.ConnectionConfig) + if err != nil { + return err + } + } + + if config.WorkerConfig != nil { + err = ValidateWorkerConfig(config.WorkerConfig) + if err != nil { + return err + } + } + + return nil +} + +func WithConnectionConfig(cc *transport.ConnectionConfig) PoolOption { + return func(c *PoolConfig) error { + err := transport.ValidateConnectionConfig(cc) + if err != nil { + return err + } + c.ConnectionConfig = cc + return nil + } +} + +func WithWorkerConfig(wc *WorkerConfig) PoolOption { + return func(c *PoolConfig) error { + err := ValidateWorkerConfig(wc) + if err != nil { + return err + } + c.WorkerConfig = wc + return nil + } +} + +func WithWorkerFactory(wf WorkerFactory) PoolOption { + return func(c *PoolConfig) error { + c.WorkerFactory = wf + return nil + } +} + +// Worker Config + +type WorkerConfig struct { + IdleTimeout time.Duration + MaxQueueSize int +} + +type WorkerOption func(*WorkerConfig) error + +func NewWorkerConfig(options ...WorkerOption) (*WorkerConfig, error) { + conf := GetDefaultWorkerConfig() + if err := applyWorkerOptions(conf, options...); err != nil { + return nil, err + } + if err := ValidateWorkerConfig(conf); err != nil { + return nil, err + } + return conf, nil +} + +func GetDefaultWorkerConfig() *WorkerConfig { + return &WorkerConfig{ + IdleTimeout: 20 * time.Second, + MaxQueueSize: 0, // disabled by default + } +} + +func applyWorkerOptions(config *WorkerConfig, options ...WorkerOption) error { + for _, option := range options { + if err := option(config); err != nil { + return err + } + } + return nil +} + +func ValidateWorkerConfig(config *WorkerConfig) error { + err := validateIdleTimeout(config.IdleTimeout) + if err != nil { + return err + } + + err = validateMaxQueueSize(config.MaxQueueSize) + if err != nil { + return err + } + + return nil +} + +func validateMaxQueueSize(value int) error { + if value < 0 { + return InvalidMaxQueueSize + } + return nil +} + +func validateIdleTimeout(value time.Duration) error { + if value < 0 { + return InvalidIdleTimeout + } + return nil +} + +// When IdleTimeout is set to zero, idle timeouts are disabled. +func WithIdleTimeout(value time.Duration) WorkerOption { + return func(c *WorkerConfig) error { + err := validateIdleTimeout(value) + if err != nil { + return err + } + c.IdleTimeout = value + return nil + } +} + +// When MaxQueueSize is set to zero, queue limits are disabled. +func WithMaxQueueSize(value int) WorkerOption { + return func(c *WorkerConfig) error { + err := validateMaxQueueSize(value) + if err != nil { + return err + } + c.MaxQueueSize = value + return nil + } +} diff --git a/config_pool_test.go b/initiator/config_pool_test.go similarity index 60% rename from config_pool_test.go rename to initiator/config_pool_test.go index 9200c7a..a1d5bdd 100644 --- a/config_pool_test.go +++ b/initiator/config_pool_test.go @@ -1,16 +1,17 @@ -package honeybee +package initiator import ( + "git.wisehodl.dev/jay/go-honeybee/transport" "github.com/stretchr/testify/assert" "testing" "time" ) func TestNewPoolConfig(t *testing.T) { - conf, err := NewInitiatorPoolConfig() + conf, err := NewPoolConfig() assert.NoError(t, err) - assert.Equal(t, conf, &InitiatorPoolConfig{ + assert.Equal(t, conf, &PoolConfig{ ConnectionConfig: nil, WorkerConfig: nil, WorkerFactory: nil, @@ -18,9 +19,9 @@ func TestNewPoolConfig(t *testing.T) { } func TestDefaultPoolConfig(t *testing.T) { - conf := GetDefaultInitiatorPoolConfig() + conf := GetDefaultPoolConfig() - assert.Equal(t, conf, &InitiatorPoolConfig{ + assert.Equal(t, conf, &PoolConfig{ ConnectionConfig: nil, WorkerConfig: nil, WorkerFactory: nil, @@ -28,10 +29,10 @@ func TestDefaultPoolConfig(t *testing.T) { } func TestApplyPoolOptions(t *testing.T) { - conf := &InitiatorPoolConfig{} - err := applyInitiatorPoolOptions( + conf := &PoolConfig{} + err := applyPoolOptions( conf, - WithInitiatorConnectionConfig(&ConnectionConfig{}), + WithConnectionConfig(&transport.ConnectionConfig{}), ) assert.NoError(t, err) @@ -39,46 +40,46 @@ func TestApplyPoolOptions(t *testing.T) { } func TestWithConnectionConfig(t *testing.T) { - conf := &InitiatorPoolConfig{} - opt := WithInitiatorConnectionConfig(&ConnectionConfig{WriteTimeout: 1 * time.Second}) - err := applyInitiatorPoolOptions(conf, opt) + conf := &PoolConfig{} + opt := WithConnectionConfig(&transport.ConnectionConfig{WriteTimeout: 1 * time.Second}) + err := applyPoolOptions(conf, opt) assert.NoError(t, err) assert.NotNil(t, conf.ConnectionConfig) assert.Equal(t, 1*time.Second, conf.ConnectionConfig.WriteTimeout) // invalid config is rejected - conf = &InitiatorPoolConfig{} - opt = WithInitiatorConnectionConfig(&ConnectionConfig{WriteTimeout: -1 * time.Second}) - err = applyInitiatorPoolOptions(conf, opt) + conf = &PoolConfig{} + opt = WithConnectionConfig(&transport.ConnectionConfig{WriteTimeout: -1 * time.Second}) + err = applyPoolOptions(conf, opt) assert.Error(t, err) } func TestValidatePoolConfig(t *testing.T) { cases := []struct { name string - conf InitiatorPoolConfig + conf PoolConfig wantErr error wantErrText string }{ { name: "valid empty", - conf: *&InitiatorPoolConfig{}, + conf: *&PoolConfig{}, }, { name: "valid defaults", - conf: *GetDefaultInitiatorPoolConfig(), + conf: *GetDefaultPoolConfig(), }, { name: "valid complete", - conf: InitiatorPoolConfig{ - ConnectionConfig: &ConnectionConfig{}, + conf: PoolConfig{ + ConnectionConfig: &transport.ConnectionConfig{}, }, }, { name: "invalid connection config", - conf: InitiatorPoolConfig{ - ConnectionConfig: &ConnectionConfig{ - Retry: &RetryConfig{ + conf: PoolConfig{ + ConnectionConfig: &transport.ConnectionConfig{ + Retry: &transport.RetryConfig{ InitialDelay: 10 * time.Second, MaxDelay: 1 * time.Second, }, @@ -90,7 +91,7 @@ func TestValidatePoolConfig(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - err := validateInitiatorPoolConfig(&tc.conf) + err := ValidatePoolConfig(&tc.conf) if tc.wantErr != nil || tc.wantErrText != "" { if tc.wantErr != nil { diff --git a/initiator/errors.go b/initiator/errors.go new file mode 100644 index 0000000..cf31318 --- /dev/null +++ b/initiator/errors.go @@ -0,0 +1,17 @@ +package initiator + +import "errors" +import "fmt" + +var ( + InvalidIdleTimeout = errors.New("idle timeout cannot be negative") + InvalidMaxQueueSize = errors.New("maximum queue size cannot be negative") +) + +func NewConfigError(text string) error { + return fmt.Errorf("configuration error: %s", text) +} + +func NewPoolError(text string) error { + return fmt.Errorf("pool error: %s", text) +} diff --git a/initiator/pool.go b/initiator/pool.go new file mode 100644 index 0000000..651e8ef --- /dev/null +++ b/initiator/pool.go @@ -0,0 +1,248 @@ +package initiator + +import ( + "git.wisehodl.dev/jay/go-honeybee/transport" + "git.wisehodl.dev/jay/go-honeybee/types" + "log/slog" + "sync" + "time" +) + +// Types + +type peer struct { + conn *transport.Connection + stop chan struct{} +} + +type InboxMessage struct { + ID string + Data []byte + ReceivedAt time.Time +} + +type PoolEventKind int + +const ( + EventConnected PoolEventKind = iota + EventDisconnected +) + +func (s PoolEventKind) String() string { + switch s { + case EventConnected: + return "connected" + case EventDisconnected: + return "disconnected" + default: + return "unknown" + } +} + +type PoolEvent struct { + ID string + Kind PoolEventKind +} + +// Pool + +type Pool struct { + peers map[string]*peer + inbox chan InboxMessage + events chan PoolEvent + errors chan error + done chan struct{} + + dialer types.Dialer + config *PoolConfig + logger *slog.Logger + + mu sync.RWMutex + wg sync.WaitGroup + closed bool +} + +func NewPool(config *PoolConfig, logger *slog.Logger) (*Pool, error) { + if config == nil { + config = GetDefaultPoolConfig() + } + + if err := ValidatePoolConfig(config); err != nil { + return nil, err + } + + p := &Pool{ + peers: make(map[string]*peer), + inbox: make(chan InboxMessage, 256), + events: make(chan PoolEvent, 10), + errors: make(chan error, 10), + done: make(chan struct{}), + dialer: transport.NewDialer(), + config: config, + logger: logger, + } + + return p, nil +} + +func (p *Pool) Peers() map[string]*peer { + return p.peers +} + +func (p *Pool) Inbox() chan InboxMessage { + return p.inbox +} + +func (p *Pool) Events() chan PoolEvent { + return p.events +} + +func (p *Pool) Errors() chan error { + return p.errors +} + +func (p *Pool) Close() { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return + } + + p.closed = true + close(p.done) + + peers := p.peers + p.peers = make(map[string]*peer) + + p.mu.Unlock() + + for _, conn := range peers { + conn.conn.Close() + } + + go func() { + p.wg.Wait() + close(p.inbox) + close(p.events) + close(p.errors) + }() +} + +func (p *Pool) Connect(id string) error { + id, err := transport.NormalizeURL(id) + if err != nil { + return err + } + + // Check for existing connection in pool + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return NewPoolError("pool is closed") + } + _, exists := p.peers[id] + p.mu.Unlock() + + if exists { + return NewPoolError("connection already exists") + } + + // Create new connection + var logger *slog.Logger + if p.logger != nil { + logger = p.logger.With("id", id) + } + conn, err := transport.NewConnection(id, p.config.ConnectionConfig, logger) + if err != nil { + return err + } + conn.SetDialer(p.dialer) + + // Attempt to connect + if err := conn.Connect(); err != nil { + return err + } + + p.mu.Lock() + if p.closed { + // The pool closed while this connection was established. + p.mu.Unlock() + conn.Close() + return NewPoolError("pool is closed") + } + + // Add connection to pool + stop := make(chan struct{}) + if _, exists := p.peers[id]; exists { + // Another process connected to this url while this one was connecting + // Discard this connection and retain the existing one + p.mu.Unlock() + conn.Close() + return NewPoolError("connection already exists") + } + p.peers[id] = &peer{conn: conn, stop: stop} + p.mu.Unlock() + + // TODO: start this connection's incoming message forwarder + + select { + case p.events <- PoolEvent{ID: id, Kind: EventConnected}: + case <-p.done: + return nil + } + + return nil +} + +func (p *Pool) Remove(id string) error { + id, err := transport.NormalizeURL(id) + if err != nil { + return err + } + + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return NewPoolError("pool is closed") + } + + peer, exists := p.peers[id] + if !exists { + p.mu.Unlock() + return NewPoolError("connection not found") + } + delete(p.peers, id) + p.mu.Unlock() + + close(peer.stop) + peer.conn.Close() + + select { + case p.events <- PoolEvent{ID: id, Kind: EventDisconnected}: + case <-p.done: + return nil + } + + return nil +} + +func (p *Pool) Send(id string, data []byte) error { + id, err := transport.NormalizeURL(id) + if err != nil { + return err + } + + p.mu.RLock() + defer p.mu.RUnlock() + + if p.closed { + return NewPoolError("pool is closed") + } + + peer, exists := p.peers[id] + if !exists { + return NewPoolError("connection not found") + } + + return peer.conn.Send(data) +} diff --git a/pool_test.go b/initiator/pool_test.go similarity index 65% rename from pool_test.go rename to initiator/pool_test.go index 0b06902..2346002 100644 --- a/pool_test.go +++ b/initiator/pool_test.go @@ -1,7 +1,10 @@ -package honeybee +package initiator import ( "fmt" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "git.wisehodl.dev/jay/go-honeybee/transport" + "git.wisehodl.dev/jay/go-honeybee/types" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "net/http" @@ -11,14 +14,14 @@ import ( func TestPoolConnect(t *testing.T) { t.Run("successfully adds connection", func(t *testing.T) { - mockSocket := NewMockSocket() - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockSocket := honeybeetest.NewMockSocket() + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewInitiatorPool(nil, nil) + pool, err := NewPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -33,7 +36,7 @@ func TestPoolConnect(t *testing.T) { default: return false } - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) _, exists := pool.peers["wss://test"] assert.True(t, exists) @@ -42,14 +45,14 @@ func TestPoolConnect(t *testing.T) { }) t.Run("does not add duplicate", func(t *testing.T) { - mockSocket := NewMockSocket() - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockSocket := honeybeetest.NewMockSocket() + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewInitiatorPool(nil, nil) + pool, err := NewPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -69,18 +72,18 @@ func TestPoolConnect(t *testing.T) { }) t.Run("fails to add connection", func(t *testing.T) { - pool, err := NewInitiatorPool( - &InitiatorPoolConfig{ - ConnectionConfig: &ConnectionConfig{ - Retry: &RetryConfig{ + pool, err := NewPool( + &PoolConfig{ + ConnectionConfig: &transport.ConnectionConfig{ + Retry: &transport.RetryConfig{ MaxRetries: 1, InitialDelay: 1 * time.Millisecond, MaxDelay: 5 * time.Millisecond, }}, }, nil) assert.NoError(t, err) - pool.dialer = &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + pool.dialer = &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return nil, nil, fmt.Errorf("dial failed") }, } @@ -104,14 +107,14 @@ func TestPoolConnect(t *testing.T) { func TestPoolRemove(t *testing.T) { t.Run("removes known url", func(t *testing.T) { - mockSocket := NewMockSocket() - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockSocket := honeybeetest.NewMockSocket() + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewInitiatorPool(nil, nil) + pool, err := NewPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -132,14 +135,14 @@ func TestPoolRemove(t *testing.T) { }) t.Run("unknown url returns error", func(t *testing.T) { - mockSocket := NewMockSocket() - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockSocket := honeybeetest.NewMockSocket() + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewInitiatorPool(nil, nil) + pool, err := NewPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -149,14 +152,14 @@ func TestPoolRemove(t *testing.T) { }) t.Run("closed pool returns error", func(t *testing.T) { - mockSocket := NewMockSocket() - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockSocket := honeybeetest.NewMockSocket() + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewInitiatorPool(nil, nil) + pool, err := NewPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -171,19 +174,19 @@ func TestPoolRemove(t *testing.T) { } func TestPoolSend(t *testing.T) { - mockSocket := NewMockSocket() - outgoingData := make(chan mockOutgoingData, 10) + mockSocket := honeybeetest.NewMockSocket() + outgoingData := make(chan honeybeetest.MockOutgoingData, 10) mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { - outgoingData <- mockOutgoingData{msgType: msgType, data: data} + outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data} return nil } - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } - pool, err := NewInitiatorPool(nil, nil) + pool, err := NewPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -194,7 +197,7 @@ func TestPoolSend(t *testing.T) { err = pool.Send("wss://test", []byte("hello")) assert.NoError(t, err) - expectWrite(t, outgoingData, websocket.TextMessage, []byte("hello")) + honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, []byte("hello")) pool.Close() } @@ -213,7 +216,7 @@ func expectEvent( default: return false } - }, testTimeout, testTick, + }, honeybeetest.TestTimeout, honeybeetest.TestTick, fmt.Sprintf("expected event: URL=%q, Kind=%q", expectedURL, expectedKind.String())) } diff --git a/worker.go b/initiator/worker.go similarity index 54% rename from worker.go rename to initiator/worker.go index 2dbb2b5..1635e8b 100644 --- a/worker.go +++ b/initiator/worker.go @@ -1,6 +1,7 @@ -package honeybee +package initiator import ( + "git.wisehodl.dev/jay/go-honeybee/transport" "log/slog" "sync" "time" @@ -8,15 +9,6 @@ import ( // Types -// Worker Implementation - -type Worker interface { - Start( - ctx *WorkerContext, - wg *sync.WaitGroup, - ) -} - type WorkerContext struct { Inbox chan<- InboxMessage Events chan<- PoolEvent @@ -26,40 +18,23 @@ type WorkerContext struct { Logger *slog.Logger } -// Base Struct +// Worker -type worker struct { - id string +type Worker struct { + id string + config *WorkerConfig + onReconnect func() (*transport.Connection, error) } -func (w *worker) runForwarder( - messages <-chan []byte, - inbox chan<- []byte, - stop <-chan struct{}, - poolDone <-chan struct{}, - maxQueueSize int, -) { -} - -// Initiator Worker - -type InitiatorWorker struct { - *worker - config *InitiatorWorkerConfig - onReconnect func() (*Connection, error) -} - -func newInitiatorWorker( +func NewWorker( id string, - config *InitiatorWorkerConfig, - onReconnect func() (*Connection, error), + config *WorkerConfig, + onReconnect func() (*transport.Connection, error), logger *slog.Logger, -) (*InitiatorWorker, error) { - w := &InitiatorWorker{ - worker: &worker{ - id: id, - }, +) (*Worker, error) { + w := &Worker{ + id: id, config: config, onReconnect: onReconnect, } @@ -67,7 +42,7 @@ func newInitiatorWorker( return w, nil } -func (w *InitiatorWorker) Start( +func (w *Worker) Start( inbox chan<- InboxMessage, events chan<- PoolEvent, stop <-chan struct{}, @@ -76,32 +51,37 @@ func (w *InitiatorWorker) Start( ) { } -func runReader(conn *Connection, +func (w *Worker) runReader(conn *transport.Connection, messages chan<- []byte, heartbeat chan<- time.Time, reconnect chan<- struct{}, - newConn <-chan *Connection, + newConn <-chan *transport.Connection, stop <-chan struct{}, poolDone <-chan struct{}, ) { } -func runHealthMonitor( +func (w *Worker) runForwarder( + messages <-chan []byte, + inbox chan<- []byte, + stop <-chan struct{}, + poolDone <-chan struct{}, + maxQueueSize int, +) { +} + +func (w *Worker) runHealthMonitor( heartbeat <-chan time.Time, stop <-chan struct{}, poolDone <-chan struct{}, ) { } -func runReconnector( +func (w *Worker) runReconnector( reconnect <-chan struct{}, - newConn chan<- *Connection, + newConn chan<- *transport.Connection, stop <-chan struct{}, poolDone <-chan struct{}, ) { } - -// Responder Worker - -type ResponderWorker struct{} diff --git a/initiator/worker_test.go b/initiator/worker_test.go new file mode 100644 index 0000000..cacb148 --- /dev/null +++ b/initiator/worker_test.go @@ -0,0 +1 @@ +package initiator diff --git a/mocks_test.go b/mocks_test.go deleted file mode 100644 index 7206980..0000000 --- a/mocks_test.go +++ /dev/null @@ -1,192 +0,0 @@ -package honeybee - -import ( - "context" - "fmt" - "github.com/stretchr/testify/assert" - "io" - "log/slog" - "net/http" - "sync" - "testing" - "time" -) - -// Test Constants - -const ( - testTimeout = 2 * time.Second - testTick = 10 * time.Millisecond - negativeTestTimeout = 100 * time.Millisecond -) - -// Dialer Mocks - -type MockDialer struct { - DialFunc func(string, http.Header) (Socket, *http.Response, error) -} - -func (m *MockDialer) Dial(url string, h http.Header) (Socket, *http.Response, error) { - return m.DialFunc(url, h) -} - -// Socket Mocks - -type MockSocket struct { - WriteMessageFunc func(int, []byte) error - SetReadDeadlineFunc func(t time.Time) error - SetWriteDeadlineFunc func(t time.Time) error - ReadMessageFunc func() (int, []byte, error) - CloseFunc func() error - SetCloseHandlerFunc func(func(int, string) error) - closed chan struct{} - once sync.Once - mu sync.Mutex -} - -func NewMockSocket() *MockSocket { - return &MockSocket{ - WriteMessageFunc: func(int, []byte) error { return nil }, - ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil }, - CloseFunc: func() error { return nil }, - - SetReadDeadlineFunc: func(time.Time) error { return nil }, - SetWriteDeadlineFunc: func(time.Time) error { return nil }, - SetCloseHandlerFunc: func(func(int, string) error) {}, - - closed: make(chan struct{}), - } - -} - -func (m *MockSocket) WriteMessage(t int, d []byte) error { - return m.WriteMessageFunc(t, d) -} - -func (m *MockSocket) ReadMessage() (int, []byte, error) { - return m.ReadMessageFunc() -} - -func (m *MockSocket) Close() error { - return m.CloseFunc() -} - -func (m *MockSocket) SetReadDeadline(t time.Time) error { - return m.SetReadDeadlineFunc(t) -} - -func (m *MockSocket) SetWriteDeadline(t time.Time) error { - return m.SetWriteDeadlineFunc(t) -} - -func (m *MockSocket) SetCloseHandler(h func(code int, text string) error) { - m.SetCloseHandlerFunc(h) -} - -// Connection Mocks - -type mockIncomingData struct { - msgType int - data []byte - err error -} - -type mockOutgoingData struct { - msgType int - data []byte -} - -func setupTestConnection(t *testing.T, config *ConnectionConfig) ( - conn *Connection, - mockSocket *MockSocket, - incomingData chan mockIncomingData, - outgoingData chan mockOutgoingData, -) { - t.Helper() - - incomingData = make(chan mockIncomingData, 10) - outgoingData = make(chan mockOutgoingData, 10) - - mockSocket = NewMockSocket() - - mockSocket.CloseFunc = func() error { - mockSocket.once.Do(func() { - close(mockSocket.closed) - }) - return nil - } - - // Wire ReadMessage to pull from incomingData channel - mockSocket.ReadMessageFunc = func() (int, []byte, error) { - select { - case data := <-incomingData: - return data.msgType, data.data, data.err - case <-mockSocket.closed: - return 0, nil, io.EOF - } - } - - // Wire WriteMessage to push to outgoingData channel - mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { - select { - case outgoingData <- mockOutgoingData{msgType: msgType, data: data}: - return nil - case <-mockSocket.closed: - return io.EOF - default: - return fmt.Errorf("mock outgoing chanel unavailable") - } - } - - var err error - conn, err = NewConnectionFromSocket(mockSocket, config, nil) - assert.NoError(t, err) - - return conn, mockSocket, incomingData, outgoingData -} - -// Logging mocks - -type mockSlogHandler struct { - records []slog.Record - mu sync.RWMutex -} - -func newMockSlogHandler() *mockSlogHandler { - return &mockSlogHandler{ - records: make([]slog.Record, 0), - } -} - -func (m *mockSlogHandler) Handle(ctx context.Context, record slog.Record) error { - m.mu.Lock() - defer m.mu.Unlock() - m.records = append(m.records, record) - return nil -} - -func (m *mockSlogHandler) Enabled(ctx context.Context, level slog.Level) bool { - return true -} - -func (m *mockSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler { - return m -} - -func (m *mockSlogHandler) WithGroup(name string) slog.Handler { - return m -} - -func (m *mockSlogHandler) GetRecords() []slog.Record { - m.mu.RLock() - defer m.mu.RUnlock() - result := make([]slog.Record, len(m.records)) - copy(result, m.records) - return result -} - -func (m *mockSlogHandler) Clear() { - m.mu.Lock() - defer m.mu.Unlock() - m.records = make([]slog.Record, 0) -} diff --git a/pool.go b/pool.go deleted file mode 100644 index 0e68edd..0000000 --- a/pool.go +++ /dev/null @@ -1,304 +0,0 @@ -package honeybee - -import ( - "git.wisehodl.dev/jay/go-honeybee/errors" - "log/slog" - "sync" - "time" -) - -// Types - -type peer struct { - conn *Connection - stop chan struct{} -} - -type InboxMessage struct { - ID string - Data []byte - ReceivedAt time.Time -} - -type PoolEventKind int - -const ( - EventConnected PoolEventKind = iota - EventDisconnected -) - -func (s PoolEventKind) String() string { - switch s { - case EventConnected: - return "connected" - case EventDisconnected: - return "disconnected" - default: - return "unknown" - } -} - -type PoolEvent struct { - ID string - Kind PoolEventKind -} - -// Pool Implementation - -type Pool interface { - Send(id string, data []byte) error - Inbox() <-chan InboxMessage - Events() <-chan PoolEvent - Errors() <-chan error - Close() -} - -// Base Struct - -type pool struct { - peers map[string]*peer - inbox chan InboxMessage - events chan PoolEvent - errors chan error - done chan struct{} - - config *InitiatorPoolConfig - logger *slog.Logger - - mu sync.RWMutex - wg sync.WaitGroup - closed bool -} - -func (p *pool) closeAll() { - p.mu.Lock() - if p.closed { - p.mu.Unlock() - return - } - - p.closed = true - close(p.done) - - peers := p.peers - p.peers = make(map[string]*peer) - - p.mu.Unlock() - - for _, conn := range peers { - conn.conn.Close() - } - - go func() { - p.wg.Wait() - close(p.inbox) - close(p.events) - close(p.errors) - }() -} - -func (p *pool) removePeer(id string) error { - p.mu.Lock() - if p.closed { - p.mu.Unlock() - return errors.NewPoolError("pool is closed") - } - - peer, exists := p.peers[id] - if !exists { - p.mu.Unlock() - return errors.NewPoolError("connection not found") - } - delete(p.peers, id) - p.mu.Unlock() - - close(peer.stop) - peer.conn.Close() - - select { - case p.events <- PoolEvent{ID: id, Kind: EventDisconnected}: - case <-p.done: - return nil - } - - return nil -} - -func (p *pool) send(id string, data []byte) error { - p.mu.RLock() - defer p.mu.RUnlock() - - if p.closed { - return errors.NewPoolError("pool is closed") - } - - peer, exists := p.peers[id] - if !exists { - return errors.NewPoolError("connection not found") - } - - return peer.conn.Send(data) -} - -// Initiator Pool - -type InitiatorPool struct { - *pool - dialer Dialer -} - -func NewInitiatorPool(config *InitiatorPoolConfig, logger *slog.Logger) (*InitiatorPool, error) { - if config == nil { - config = GetDefaultInitiatorPoolConfig() - } - - if err := validateInitiatorPoolConfig(config); err != nil { - return nil, err - } - - p := &InitiatorPool{ - pool: &pool{ - peers: make(map[string]*peer), - inbox: make(chan InboxMessage, 256), - events: make(chan PoolEvent, 10), - errors: make(chan error, 10), - done: make(chan struct{}), - config: config, - logger: logger, - }, - dialer: NewDialer(), - } - - return p, nil -} - -func (p *InitiatorPool) Peers() map[string]*peer { - return p.peers -} - -func (p *InitiatorPool) Inbox() chan InboxMessage { - return p.inbox -} - -func (p *InitiatorPool) Events() chan PoolEvent { - return p.events -} - -func (p *InitiatorPool) Errors() chan error { - return p.errors -} - -func (p *InitiatorPool) Close() { - p.closeAll() -} - -func (p *InitiatorPool) Connect(url string) error { - url, err := NormalizeURL(url) - if err != nil { - return err - } - - // Check for existing connection in pool - p.mu.Lock() - if p.closed { - p.mu.Unlock() - return errors.NewPoolError("pool is closed") - } - _, exists := p.peers[url] - p.mu.Unlock() - - if exists { - return errors.NewPoolError("connection already exists") - } - - // Create new connection - var logger *slog.Logger - if p.logger != nil { - logger = p.logger.With("url", url) - } - conn, err := NewConnection(url, p.config.ConnectionConfig, logger) - if err != nil { - return err - } - conn.dialer = p.dialer - - // Attempt to connect - if err := conn.Connect(); err != nil { - return err - } - - p.mu.Lock() - if p.closed { - // The pool closed while this connection was established. - p.mu.Unlock() - conn.Close() - return errors.NewPoolError("pool is closed") - } - - // Add connection to pool - stop := make(chan struct{}) - if _, exists := p.peers[url]; exists { - // Another process connected to this url while this one was connecting - // Discard this connection and retain the existing one - p.mu.Unlock() - conn.Close() - return errors.NewPoolError("connection already exists") - } - p.peers[url] = &peer{conn: conn, stop: stop} - p.mu.Unlock() - - // TODO: start this connection's incoming message forwarder - - select { - case p.events <- PoolEvent{ID: url, Kind: EventConnected}: - case <-p.done: - return nil - } - - return nil -} - -func (p *InitiatorPool) Remove(url string) error { - url, err := NormalizeURL(url) - if err != nil { - return err - } - - return p.removePeer(url) -} - -func (p *InitiatorPool) Send(url string, data []byte) error { - url, err := NormalizeURL(url) - if err != nil { - return err - } - - return p.send(url, data) -} - -// Responder Pool - -type ResponderPool struct { - *pool - idGenerator func() string -} - -func (p *ResponderPool) Peers() map[string]*peer { - return p.peers -} - -func (p *ResponderPool) Inbox() chan InboxMessage { - return p.inbox -} - -func (p *ResponderPool) Events() chan PoolEvent { - return p.events -} - -func (p *ResponderPool) Errors() chan error { - return p.errors -} - -func (p *ResponderPool) Close() { - p.closeAll() -} diff --git a/transport/config.go b/transport/config.go new file mode 100644 index 0000000..d3dd4a7 --- /dev/null +++ b/transport/config.go @@ -0,0 +1,225 @@ +package transport + +import ( + "time" +) + +type CloseHandler func(code int, text string) error + +type ConnectionConfig struct { + CloseHandler CloseHandler + WriteTimeout time.Duration + Retry *RetryConfig +} + +type RetryConfig struct { + MaxRetries int + InitialDelay time.Duration + MaxDelay time.Duration + JitterFactor float64 +} + +type ConnectionOption func(*ConnectionConfig) error + +func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) { + conf := GetDefaultConnectionConfig() + if err := applyConnectionOptions(conf, options...); err != nil { + return nil, err + } + if err := ValidateConnectionConfig(conf); err != nil { + return nil, err + } + return conf, nil +} + +func GetDefaultConnectionConfig() *ConnectionConfig { + return &ConnectionConfig{ + CloseHandler: nil, + WriteTimeout: 30 * time.Second, + Retry: GetDefaultRetryConfig(), + } +} + +func GetDefaultRetryConfig() *RetryConfig { + return &RetryConfig{ + MaxRetries: 0, // Infinite retries + InitialDelay: 1 * time.Second, + MaxDelay: 5 * time.Second, + JitterFactor: 0.5, + } +} + +func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOption) error { + for _, option := range options { + if err := option(config); err != nil { + return err + } + } + return nil +} + +func ValidateConnectionConfig(config *ConnectionConfig) error { + err := validateWriteTimeout(config.WriteTimeout) + if err != nil { + return err + } + + if config.Retry != nil { + err = validateMaxRetries(config.Retry.MaxRetries) + if err != nil { + return err + } + + err = validateInitialDelay(config.Retry.InitialDelay) + if err != nil { + return err + } + + err = validateMaxDelay(config.Retry.MaxDelay) + if err != nil { + return err + } + + err = validateJitterFactor(config.Retry.JitterFactor) + if err != nil { + return err + } + + if config.Retry.InitialDelay > config.Retry.MaxDelay { + return NewConfigError("initial delay may not exceed maximum delay") + } + } + + return nil +} + +func validateWriteTimeout(value time.Duration) error { + if value < 0 { + return InvalidWriteTimeout + } + return nil +} + +func validateMaxRetries(value int) error { + if value < 0 { + return InvalidRetryMaxRetries + } + return nil +} + +func validateInitialDelay(value time.Duration) error { + if value <= 0 { + return InvalidRetryInitialDelay + } + return nil +} + +func validateMaxDelay(value time.Duration) error { + if value <= 0 { + return InvalidRetryMaxDelay + } + return nil +} + +func validateJitterFactor(value float64) error { + if value < 0.0 || value > 1.0 { + return InvalidRetryJitterFactor + } + return nil +} + +func WithCloseHandler(handler CloseHandler) ConnectionOption { + return func(c *ConnectionConfig) error { + c.CloseHandler = handler + return nil + } +} + +// When WriteTimeout is set to zero, read timeouts are disabled. +func WithWriteTimeout(value time.Duration) ConnectionOption { + return func(c *ConnectionConfig) error { + err := validateWriteTimeout(value) + if err != nil { + return err + } + c.WriteTimeout = value + return nil + } +} + +// WithRetry enables retry with default parameters (infinite retries, +// 1s initial delay, 5s max delay, 0.5 jitter factor). +// +// If passed after granular retry options (WithRetryMaxRetries, etc.), +// it will overwrite them. Use either WithRetry alone or the granular +// options; not both. +func WithRetry() ConnectionOption { + return func(c *ConnectionConfig) error { + c.Retry = GetDefaultRetryConfig() + return nil + } +} + +func WithRetryMaxRetries(value int) ConnectionOption { + return func(c *ConnectionConfig) error { + if c.Retry == nil { + c.Retry = GetDefaultRetryConfig() + } + + err := validateMaxRetries(value) + if err != nil { + return err + } + + c.Retry.MaxRetries = value + return nil + } +} + +func WithRetryInitialDelay(value time.Duration) ConnectionOption { + return func(c *ConnectionConfig) error { + if c.Retry == nil { + c.Retry = GetDefaultRetryConfig() + } + + err := validateInitialDelay(value) + if err != nil { + return err + } + + c.Retry.InitialDelay = value + return nil + } +} + +func WithRetryMaxDelay(value time.Duration) ConnectionOption { + return func(c *ConnectionConfig) error { + if c.Retry == nil { + c.Retry = GetDefaultRetryConfig() + } + + err := validateMaxDelay(value) + if err != nil { + return err + } + + c.Retry.MaxDelay = value + return nil + } +} + +func WithRetryJitterFactor(value float64) ConnectionOption { + return func(c *ConnectionConfig) error { + if c.Retry == nil { + c.Retry = GetDefaultRetryConfig() + } + + err := validateJitterFactor(value) + if err != nil { + return err + } + + c.Retry.JitterFactor = value + return nil + } +} diff --git a/config_connection_test.go b/transport/config_test.go similarity index 90% rename from config_connection_test.go rename to transport/config_test.go index 9d38c83..1fc259a 100644 --- a/config_connection_test.go +++ b/transport/config_test.go @@ -1,7 +1,6 @@ -package honeybee +package transport import ( - "git.wisehodl.dev/jay/go-honeybee/errors" "github.com/stretchr/testify/assert" "testing" "time" @@ -72,7 +71,7 @@ func TestApplyConnectionOptions(t *testing.T) { WithRetryMaxRetries(-10), ) - assert.ErrorIs(t, err, errors.InvalidRetryMaxRetries) + assert.ErrorIs(t, err, InvalidRetryMaxRetries) } // Option Tests @@ -103,7 +102,7 @@ func TestWithWriteTimeout(t *testing.T) { conf = &ConnectionConfig{} opt = WithWriteTimeout(-30) err = applyConnectionOptions(conf, opt) - assert.ErrorIs(t, err, errors.InvalidWriteTimeout) + assert.ErrorIs(t, err, InvalidWriteTimeout) assert.ErrorContains(t, err, "write timeout cannot be negative") } @@ -132,7 +131,7 @@ func TestWithRetry(t *testing.T) { // negative disallowed opt = WithRetryMaxRetries(-10) err = applyConnectionOptions(conf, opt) - assert.ErrorIs(t, err, errors.InvalidRetryMaxRetries) + assert.ErrorIs(t, err, InvalidRetryMaxRetries) assert.ErrorContains(t, err, "max retry count cannot be negative") }) @@ -146,13 +145,13 @@ func TestWithRetry(t *testing.T) { // zero disallowed opt = WithRetryInitialDelay(0 * time.Second) err = applyConnectionOptions(conf, opt) - assert.ErrorIs(t, err, errors.InvalidRetryInitialDelay) + assert.ErrorIs(t, err, InvalidRetryInitialDelay) assert.ErrorContains(t, err, "initial delay must be positive") // negative disallowed opt = WithRetryInitialDelay(-10 * time.Second) err = applyConnectionOptions(conf, opt) - assert.ErrorIs(t, err, errors.InvalidRetryInitialDelay) + assert.ErrorIs(t, err, InvalidRetryInitialDelay) }) t.Run("with max delay", func(t *testing.T) { @@ -165,13 +164,13 @@ func TestWithRetry(t *testing.T) { // zero disallowed opt = WithRetryMaxDelay(0 * time.Second) err = applyConnectionOptions(conf, opt) - assert.ErrorIs(t, err, errors.InvalidRetryMaxDelay) + assert.ErrorIs(t, err, InvalidRetryMaxDelay) assert.ErrorContains(t, err, "max delay must be positive") // negative disallowed opt = WithRetryMaxDelay(-10 * time.Second) err = applyConnectionOptions(conf, opt) - assert.ErrorIs(t, err, errors.InvalidRetryMaxDelay) + assert.ErrorIs(t, err, InvalidRetryMaxDelay) }) t.Run("with jitter factor", func(t *testing.T) { @@ -185,13 +184,13 @@ func TestWithRetry(t *testing.T) { // negative disallowed opt = WithRetryJitterFactor(-1) err = applyConnectionOptions(conf, opt) - assert.ErrorIs(t, err, errors.InvalidRetryJitterFactor) + assert.ErrorIs(t, err, InvalidRetryJitterFactor) assert.ErrorContains(t, err, "jitter factor must be between 0.0 and 1.0") // >1 disallowed opt = WithRetryJitterFactor(1.1) err = applyConnectionOptions(conf, opt) - assert.ErrorIs(t, err, errors.InvalidRetryJitterFactor) + assert.ErrorIs(t, err, InvalidRetryJitterFactor) }) } @@ -239,7 +238,7 @@ func TestValidateConnectionConfig(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - err := validateConnectionConfig(&tc.conf) + err := ValidateConnectionConfig(&tc.conf) if tc.wantErr != nil || tc.wantErrText != "" { if tc.wantErr != nil { diff --git a/connection.go b/transport/connection.go similarity index 87% rename from connection.go rename to transport/connection.go index e11a3b8..369a7d4 100644 --- a/connection.go +++ b/transport/connection.go @@ -1,14 +1,14 @@ -package honeybee +package transport import ( - stderrors "errors" + "errors" "fmt" "log/slog" "net/url" "sync" "time" - "git.wisehodl.dev/jay/go-honeybee/errors" + "git.wisehodl.dev/jay/go-honeybee/types" "github.com/gorilla/websocket" ) @@ -38,8 +38,8 @@ func (s ConnectionState) String() string { type Connection struct { url *url.URL - dialer Dialer - socket Socket + dialer types.Dialer + socket types.Socket config *ConnectionConfig logger *slog.Logger @@ -60,7 +60,7 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) config = GetDefaultConnectionConfig() } - if err := validateConnectionConfig(config); err != nil { + if err := ValidateConnectionConfig(config); err != nil { return nil, err } @@ -85,16 +85,16 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) return conn, nil } -func NewConnectionFromSocket(socket Socket, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) { +func NewConnectionFromSocket(socket types.Socket, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) { if socket == nil { - return nil, errors.NewConnectionError("socket cannot be nil") + return nil, NewConnectionError("socket cannot be nil") } if config == nil { config = GetDefaultConnectionConfig() } - if err := validateConnectionConfig(config); err != nil { + if err := ValidateConnectionConfig(config); err != nil { return nil, err } @@ -126,11 +126,11 @@ func (c *Connection) Connect() error { defer c.mu.Unlock() if c.socket != nil { - return errors.NewConnectionError("connection already has socket") + return NewConnectionError("connection already has socket") } if c.closed { - return errors.NewConnectionError("connection is closed") + return NewConnectionError("connection is closed") } if c.logger != nil { @@ -177,7 +177,7 @@ func (c *Connection) startReader() { if err != nil { if c.logger != nil { var closeErr *websocket.CloseError - if stderrors.As(err, &closeErr) { + if errors.As(err, &closeErr) { switch closeErr.Code { case websocket.CloseNormalClosure, websocket.CloseGoingAway: c.logger.Info("connection closed by peer", @@ -263,16 +263,16 @@ func (c *Connection) Send(data []byte) error { defer c.mu.RUnlock() if c.closed { - return errors.NewConnectionError("connection closed") + return NewConnectionError("connection closed") } select { case c.outgoing <- data: return nil case <-c.done: - return errors.NewConnectionError("connection closing") + return NewConnectionError("connection closing") default: - return errors.NewConnectionError("outgoing queue full") + return NewConnectionError("outgoing queue full") } } @@ -337,3 +337,7 @@ func (c *Connection) State() ConnectionState { defer c.mu.RUnlock() return c.state } + +func (c *Connection) SetDialer(d types.Dialer) { + c.dialer = d +} diff --git a/connection_close_test.go b/transport/connection_close_test.go similarity index 85% rename from connection_close_test.go rename to transport/connection_close_test.go index eddb2c9..fa44a99 100644 --- a/connection_close_test.go +++ b/transport/connection_close_test.go @@ -1,8 +1,9 @@ -package honeybee +package transport import ( "bytes" "fmt" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "testing" @@ -38,7 +39,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { t.Run("socket close error does not propagate", func(t *testing.T) { expectedErr := fmt.Errorf("socket close failed") - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.CloseFunc = func() error { return expectedErr } @@ -64,7 +65,8 @@ func TestDisconnectedConnectionClose(t *testing.T) { default: return false } - }, testTimeout, testTick, "errors channel should close") + }, honeybeetest.TestTimeout, honeybeetest.TestTick, + "errors channel should close") }) t.Run("send fails after close", func(t *testing.T) { @@ -86,7 +88,8 @@ func TestConnectedConnectionClose(t *testing.T) { // Send a message to ensure reader loop is blocking canary := []byte("canary") - incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: canary} + incomingData <- honeybeetest.MockIncomingData{ + MsgType: websocket.TextMessage, Data: canary} assert.Eventually(t, func() bool { select { @@ -95,7 +98,7 @@ func TestConnectedConnectionClose(t *testing.T) { default: return false } - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) conn.Close() assert.Equal(t, StateClosed, conn.State()) @@ -119,9 +122,9 @@ func TestConnectedConnectionClose(t *testing.T) { conn, _, incomingData, _ := setupTestConnection(t, nil) for i := 0; i < 10; i++ { - incomingData <- mockIncomingData{ - msgType: websocket.TextMessage, - data: []byte(fmt.Sprintf("in-%d", i)), + incomingData <- honeybeetest.MockIncomingData{ + MsgType: websocket.TextMessage, + Data: []byte(fmt.Sprintf("in-%d", i)), } conn.Send([]byte(fmt.Sprintf("out-%d", i))) } diff --git a/connection_goroutine_test.go b/transport/connection_goroutine_test.go similarity index 70% rename from connection_goroutine_test.go rename to transport/connection_goroutine_test.go index 3c7699c..df54e0c 100644 --- a/connection_goroutine_test.go +++ b/transport/connection_goroutine_test.go @@ -1,8 +1,8 @@ -package honeybee +package transport import ( - "bytes" "fmt" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "io" @@ -17,13 +17,13 @@ func TestStartReader(t *testing.T) { defer conn.Close() testData := []byte("hello") - incomingData <- mockIncomingData{ - msgType: websocket.TextMessage, - data: testData, - err: nil, + incomingData <- honeybeetest.MockIncomingData{ + MsgType: websocket.TextMessage, + Data: testData, + Err: nil, } - expectIncoming(t, conn, testData) + honeybeetest.ExpectIncoming(t, conn.Incoming(), testData) }) t.Run("binary messages route to incoming channel", func(t *testing.T) { @@ -31,13 +31,13 @@ func TestStartReader(t *testing.T) { defer conn.Close() testData := []byte{0x00, 0x01, 0x02} - incomingData <- mockIncomingData{ - msgType: websocket.BinaryMessage, - data: testData, - err: nil, + incomingData <- honeybeetest.MockIncomingData{ + MsgType: websocket.BinaryMessage, + Data: testData, + Err: nil, } - expectIncoming(t, conn, testData) + honeybeetest.ExpectIncoming(t, conn.Incoming(), testData) }) t.Run("multiple messages processed sequentially", func(t *testing.T) { @@ -46,20 +46,21 @@ func TestStartReader(t *testing.T) { messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")} for _, msg := range messages { - incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: msg, err: nil} + incomingData <- honeybeetest.MockIncomingData{ + MsgType: websocket.TextMessage, Data: msg, Err: nil} } for _, expected := range messages { - expectIncoming(t, conn, expected) + honeybeetest.ExpectIncoming(t, conn.Incoming(), expected) } }) t.Run("reader exits on socket read error", func(t *testing.T) { - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.CloseFunc = func() error { - mockSocket.once.Do(func() { - close(mockSocket.closed) + mockSocket.Once.Do(func() { + close(mockSocket.Closed) }) return nil } @@ -80,11 +81,11 @@ func TestStartReader(t *testing.T) { default: return false } - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) assert.Eventually(t, func() bool { return conn.State() == StateClosed - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) } @@ -97,7 +98,7 @@ func TestStartWriter(t *testing.T) { err := conn.Send(testData) assert.NoError(t, err) - expectWrite(t, outgoingData, websocket.TextMessage, testData) + honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, testData) }) t.Run("multiple messages processed sequentially", func(t *testing.T) { @@ -111,7 +112,7 @@ func TestStartWriter(t *testing.T) { } for _, expected := range messages { - expectWrite(t, outgoingData, websocket.TextMessage, expected) + honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, expected) } }) @@ -122,12 +123,12 @@ func TestStartWriter(t *testing.T) { config := &ConnectionConfig{WriteTimeout: 0} - outgoingData := make(chan mockOutgoingData, 10) - mockSocket := NewMockSocket() + outgoingData := make(chan honeybeetest.MockOutgoingData, 10) + mockSocket := honeybeetest.NewMockSocket() mockSocket.CloseFunc = func() error { - mockSocket.once.Do(func() { - close(mockSocket.closed) + mockSocket.Once.Do(func() { + close(mockSocket.Closed) }) return nil } @@ -140,8 +141,9 @@ func TestStartWriter(t *testing.T) { mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { select { - case outgoingData <- mockOutgoingData{msgType: msgType, data: data}: - case <-mockSocket.closed: + case outgoingData <- honeybeetest.MockOutgoingData{ + MsgType: msgType, Data: data}: + case <-mockSocket.Closed: return io.EOF } return nil @@ -161,19 +163,19 @@ func TestStartWriter(t *testing.T) { default: return false } - }, negativeTestTimeout, testTick, + }, honeybeetest.NegativeTestTimeout, honeybeetest.TestTick, "SetWriteDeadline should not be called when timeout is zero") }) t.Run("write timeout sets deadline when positive", func(t *testing.T) { config := &ConnectionConfig{WriteTimeout: 30 * time.Millisecond} - outgoingData := make(chan mockOutgoingData, 10) - mockSocket := NewMockSocket() + outgoingData := make(chan honeybeetest.MockOutgoingData, 10) + mockSocket := honeybeetest.NewMockSocket() mockSocket.CloseFunc = func() error { - mockSocket.once.Do(func() { - close(mockSocket.closed) + mockSocket.Once.Do(func() { + close(mockSocket.Closed) }) return nil } @@ -186,8 +188,9 @@ func TestStartWriter(t *testing.T) { mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { select { - case outgoingData <- mockOutgoingData{msgType: msgType, data: data}: - case <-mockSocket.closed: + case outgoingData <- honeybeetest.MockOutgoingData{ + MsgType: msgType, Data: data}: + case <-mockSocket.Closed: return io.EOF } return nil @@ -207,18 +210,18 @@ func TestStartWriter(t *testing.T) { default: return false } - }, testTimeout, testTick, + }, honeybeetest.TestTimeout, honeybeetest.TestTick, "SetWriteDeadline should be called when timeout is positive") }) t.Run("writer exits on deadline error", func(t *testing.T) { config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond} - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.CloseFunc = func() error { - mockSocket.once.Do(func() { - close(mockSocket.closed) + mockSocket.Once.Do(func() { + close(mockSocket.Closed) }) return nil } @@ -242,15 +245,15 @@ func TestStartWriter(t *testing.T) { default: return false } - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) assert.Eventually(t, func() bool { return conn.State() == StateClosed - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) t.Run("writer exits on socket write error", func(t *testing.T) { - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() writeErr := fmt.Errorf("write failed") mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { @@ -271,45 +274,12 @@ func TestStartWriter(t *testing.T) { default: return false } - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) assert.Eventually(t, func() bool { return conn.State() == StateClosed - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) } // Helpers - -func expectIncoming(t *testing.T, conn *Connection, expected []byte) { - t.Helper() - assert.Eventually(t, func() bool { - select { - case received := <-conn.Incoming(): - return bytes.Equal(received, expected) - default: - return false - } - }, testTimeout, testTick) -} - -func expectWrite(t *testing.T, outgoingData chan mockOutgoingData, msgType int, expected []byte) { - t.Helper() - - var call mockOutgoingData - found := assert.Eventually(t, func() bool { - select { - case received := <-outgoingData: - call = received - return true - default: - return false - } - }, testTimeout, testTick) - - if found { - - assert.Equal(t, msgType, call.msgType) - assert.Equal(t, expected, call.data) - } -} diff --git a/connection_send_test.go b/transport/connection_send_test.go similarity index 99% rename from connection_send_test.go rename to transport/connection_send_test.go index 03b0ba0..e5869d9 100644 --- a/connection_send_test.go +++ b/transport/connection_send_test.go @@ -1,4 +1,4 @@ -package honeybee +package transport import ( "fmt" diff --git a/connection_test.go b/transport/connection_test.go similarity index 76% rename from connection_test.go rename to transport/connection_test.go index ed4b68c..eb30c62 100644 --- a/connection_test.go +++ b/transport/connection_test.go @@ -1,9 +1,12 @@ -package honeybee +package transport import ( "bytes" "fmt" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "git.wisehodl.dev/jay/go-honeybee/types" "github.com/stretchr/testify/assert" + "io" "net/http" "testing" "time" @@ -36,7 +39,7 @@ func TestConnectionState(t *testing.T) { assert.Equal(t, StateDisconnected, conn.State()) // Test state after FromSocket (should be Connected) - conn2, _ := NewConnectionFromSocket(NewMockSocket(), nil, nil) + conn2, _ := NewConnectionFromSocket(honeybeetest.NewMockSocket(), nil, nil) assert.Equal(t, StateConnected, conn2.State()) // Test state after close @@ -126,7 +129,7 @@ func TestNewConnection(t *testing.T) { func TestNewConnectionFromSocket(t *testing.T) { cases := []struct { name string - socket Socket + socket types.Socket config *ConnectionConfig wantErr bool wantErrText string @@ -140,17 +143,17 @@ func TestNewConnectionFromSocket(t *testing.T) { }, { name: "valid socket with nil config", - socket: NewMockSocket(), + socket: honeybeetest.NewMockSocket(), config: nil, }, { name: "valid socket with valid config", - socket: NewMockSocket(), + socket: honeybeetest.NewMockSocket(), config: &ConnectionConfig{WriteTimeout: 30 * time.Second}, }, { name: "invalid config", - socket: NewMockSocket(), + socket: honeybeetest.NewMockSocket(), config: &ConnectionConfig{ Retry: &RetryConfig{ InitialDelay: 10 * time.Second, @@ -162,7 +165,7 @@ func TestNewConnectionFromSocket(t *testing.T) { }, { name: "close handler set when provided", - socket: NewMockSocket(), + socket: honeybeetest.NewMockSocket(), config: &ConnectionConfig{ CloseHandler: func(code int, text string) error { return nil @@ -176,7 +179,7 @@ func TestNewConnectionFromSocket(t *testing.T) { // track if SetCloseHandler was called closeHandlerSet := false if tc.socket != nil { - mockSocket := tc.socket.(*MockSocket) + mockSocket := tc.socket.(*honeybeetest.MockSocket) originalSetCloseHandler := mockSocket.SetCloseHandlerFunc // wrapper around the original handler function @@ -234,7 +237,7 @@ func TestConnect(t *testing.T) { conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) - conn.socket = NewMockSocket() + conn.socket = honeybeetest.NewMockSocket() err = conn.Connect() assert.Error(t, err) @@ -258,16 +261,16 @@ func TestConnect(t *testing.T) { conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) - outgoingData := make(chan mockOutgoingData, 10) + outgoingData := make(chan honeybeetest.MockOutgoingData, 10) - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { - outgoingData <- mockOutgoingData{msgType: msgType, data: data} + outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data} return nil } - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } @@ -283,11 +286,11 @@ func TestConnect(t *testing.T) { assert.Eventually(t, func() bool { select { case msg := <-outgoingData: - return bytes.Equal(msg.data, testData) + return bytes.Equal(msg.Data, testData) default: return false } - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) conn.Close() }) @@ -305,13 +308,13 @@ func TestConnect(t *testing.T) { assert.NoError(t, err) attemptCount := 0 - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { attemptCount++ if attemptCount < 3 { return nil, nil, fmt.Errorf("dial failed") } - return NewMockSocket(), nil, nil + return honeybeetest.NewMockSocket(), nil, nil }, } conn.dialer = mockDialer @@ -336,8 +339,8 @@ func TestConnect(t *testing.T) { conn, err := NewConnection("ws://test", config, nil) assert.NoError(t, err) - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return nil, nil, fmt.Errorf("dial failed") }, } @@ -355,10 +358,10 @@ func TestConnect(t *testing.T) { assert.Equal(t, StateDisconnected, conn.State()) stateDuringDial := StateDisconnected - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { stateDuringDial = conn.state - return NewMockSocket(), nil, nil + return honeybeetest.NewMockSocket(), nil, nil }, } conn.dialer = mockDialer @@ -381,13 +384,13 @@ func TestConnect(t *testing.T) { conn, err := NewConnection("ws://test", config, nil) assert.NoError(t, err) - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) { handlerSet = true } - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } @@ -431,4 +434,53 @@ func TestConnectionErrors(t *testing.T) { assert.Equal(t, testErr, received) } -// Connect() tests +// Test helpers + +func setupTestConnection(t *testing.T, config *ConnectionConfig) ( + conn *Connection, + mockSocket *honeybeetest.MockSocket, + incomingData chan honeybeetest.MockIncomingData, + outgoingData chan honeybeetest.MockOutgoingData, +) { + t.Helper() + + incomingData = make(chan honeybeetest.MockIncomingData, 10) + outgoingData = make(chan honeybeetest.MockOutgoingData, 10) + + mockSocket = honeybeetest.NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.Once.Do(func() { + close(mockSocket.Closed) + }) + return nil + } + + // Wire ReadMessage to pull from incomingData channel + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + select { + case data := <-incomingData: + return data.MsgType, data.Data, data.Err + case <-mockSocket.Closed: + return 0, nil, io.EOF + } + } + + // Wire WriteMessage to push to outgoingData channel + mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + select { + case outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}: + return nil + case <-mockSocket.Closed: + return io.EOF + default: + return fmt.Errorf("mock outgoing chanel unavailable") + } + } + + var err error + conn, err = NewConnectionFromSocket(mockSocket, config, nil) + assert.NoError(t, err) + + return conn, mockSocket, incomingData, outgoingData +} diff --git a/errors/errors.go b/transport/errors.go similarity index 73% rename from errors/errors.go rename to transport/errors.go index d768a9f..bcc01dc 100644 --- a/errors/errors.go +++ b/transport/errors.go @@ -1,4 +1,4 @@ -package errors +package transport import "errors" import "fmt" @@ -8,13 +8,11 @@ var ( InvalidProtocol = errors.New("URL must use ws:// or wss:// scheme") // Configuration Errors - InvalidIdleTimeout = errors.New("idle timeout cannot be negative") InvalidWriteTimeout = errors.New("write timeout cannot be negative") InvalidRetryMaxRetries = errors.New("max retry count cannot be negative") InvalidRetryInitialDelay = errors.New("initial delay must be positive") InvalidRetryMaxDelay = errors.New("max delay must be positive") InvalidRetryJitterFactor = errors.New("jitter factor must be between 0.0 and 1.0") - InvalidMaxQueueSize = errors.New("maximum queue size cannot be negative") ) func NewConfigError(text string) error { @@ -24,7 +22,3 @@ func NewConfigError(text string) error { func NewConnectionError(text string) error { return fmt.Errorf("connection error: %s", text) } - -func NewPoolError(text string) error { - return fmt.Errorf("pool error: %s", text) -} diff --git a/logging_test.go b/transport/logging_test.go similarity index 84% rename from logging_test.go rename to transport/logging_test.go index 905a787..ee42381 100644 --- a/logging_test.go +++ b/transport/logging_test.go @@ -1,4 +1,4 @@ -package honeybee +package transport import ( "fmt" @@ -9,6 +9,8 @@ import ( "testing" "time" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "git.wisehodl.dev/jay/go-honeybee/types" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) @@ -136,15 +138,15 @@ func toInt64(v any) (int64, bool) { func TestConnectLogging(t *testing.T) { t.Run("success", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() logger := slog.New(mockHandler) conn, err := NewConnection("ws://test", nil, logger) assert.NoError(t, err) - mockSocket := NewMockSocket() - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockSocket := honeybeetest.NewMockSocket() + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } @@ -167,7 +169,7 @@ func TestConnectLogging(t *testing.T) { }) t.Run("max retries failure", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() logger := slog.New(mockHandler) config := &ConnectionConfig{ @@ -183,8 +185,8 @@ func TestConnectLogging(t *testing.T) { assert.NoError(t, err) dialErr := fmt.Errorf("dial error") - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return nil, nil, dialErr }, } @@ -210,7 +212,7 @@ func TestConnectLogging(t *testing.T) { }) t.Run("success after retry", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() logger := slog.New(mockHandler) config := &ConnectionConfig{ @@ -227,13 +229,13 @@ func TestConnectLogging(t *testing.T) { attemptCount := 0 dialErr := fmt.Errorf("dial error") - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { attemptCount++ if attemptCount < 3 { return nil, nil, dialErr } - return NewMockSocket(), nil, nil + return honeybeetest.NewMockSocket(), nil, nil }, } conn.dialer = mockDialer @@ -261,10 +263,10 @@ func TestConnectLogging(t *testing.T) { func TestCloseLogging(t *testing.T) { t.Run("normal close", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() logger := slog.New(mockHandler) - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() conn, err := NewConnectionFromSocket(mockSocket, nil, logger) assert.NoError(t, err) @@ -273,7 +275,7 @@ func TestCloseLogging(t *testing.T) { assert.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelInfo, "closed") != nil - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) records := mockHandler.GetRecords() @@ -286,11 +288,11 @@ func TestCloseLogging(t *testing.T) { }) t.Run("close with socket error", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() logger := slog.New(mockHandler) closeErr := fmt.Errorf("close error") - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.CloseFunc = func() error { return closeErr } @@ -303,7 +305,7 @@ func TestCloseLogging(t *testing.T) { assert.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelError, "socket close failed") != nil - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) records := mockHandler.GetRecords() @@ -318,10 +320,10 @@ func TestCloseLogging(t *testing.T) { func TestReaderLogging(t *testing.T) { t.Run("clean close by peer", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() logger := slog.New(mockHandler) - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.ReadMessageFunc = func() (int, []byte, error) { return 0, nil, &websocket.CloseError{ Code: websocket.CloseNormalClosure, @@ -336,7 +338,7 @@ func TestReaderLogging(t *testing.T) { assert.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer") != nil - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) record := findLogRecord(mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer") assert.NotNil(t, record) @@ -346,10 +348,10 @@ func TestReaderLogging(t *testing.T) { }) t.Run("unexpected close", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() logger := slog.New(mockHandler) - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.ReadMessageFunc = func() (int, []byte, error) { return 0, nil, &websocket.CloseError{ Code: websocket.CloseProtocolError, @@ -364,7 +366,7 @@ func TestReaderLogging(t *testing.T) { assert.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelError, "unexpected close") != nil - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) record := findLogRecord(mockHandler.GetRecords(), slog.LevelError, "unexpected close") assert.NotNil(t, record) @@ -374,10 +376,10 @@ func TestReaderLogging(t *testing.T) { }) t.Run("read error", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() logger := slog.New(mockHandler) - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.ReadMessageFunc = func() (int, []byte, error) { return 0, nil, io.EOF } @@ -389,19 +391,19 @@ func TestReaderLogging(t *testing.T) { assert.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelError, "read error") != nil - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) } func TestWriterLogging(t *testing.T) { t.Run("write deadline error", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() logger := slog.New(mockHandler) config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond} deadlineErr := fmt.Errorf("deadline error") - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.SetWriteDeadlineFunc = func(time.Time) error { return deadlineErr } @@ -415,7 +417,7 @@ func TestWriterLogging(t *testing.T) { assert.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelError, "write deadline error") != nil - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) records := mockHandler.GetRecords() @@ -427,11 +429,11 @@ func TestWriterLogging(t *testing.T) { }) t.Run("write message error", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() logger := slog.New(mockHandler) writeErr := fmt.Errorf("write error") - mockSocket := NewMockSocket() + mockSocket := honeybeetest.NewMockSocket() mockSocket.WriteMessageFunc = func(int, []byte) error { return writeErr } @@ -445,7 +447,7 @@ func TestWriterLogging(t *testing.T) { assert.Eventually(t, func() bool { return findLogRecord( mockHandler.GetRecords(), slog.LevelError, "write error") != nil - }, testTimeout, testTick) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) records := mockHandler.GetRecords() @@ -459,14 +461,14 @@ func TestWriterLogging(t *testing.T) { func TestLoggingDisabled(t *testing.T) { t.Run("nil logger produces no logs", func(t *testing.T) { - mockHandler := newMockSlogHandler() + mockHandler := honeybeetest.NewMockSlogHandler() conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) - mockSocket := NewMockSocket() - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockSocket := honeybeetest.NewMockSocket() + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { return mockSocket, nil, nil }, } diff --git a/retry.go b/transport/retry.go similarity index 98% rename from retry.go rename to transport/retry.go index 5c8995f..16063ba 100644 --- a/retry.go +++ b/transport/retry.go @@ -1,4 +1,4 @@ -package honeybee +package transport import ( "math" diff --git a/retry_test.go b/transport/retry_test.go similarity index 99% rename from retry_test.go rename to transport/retry_test.go index f7d90ae..e594e94 100644 --- a/retry_test.go +++ b/transport/retry_test.go @@ -1,4 +1,4 @@ -package honeybee +package transport import ( "github.com/stretchr/testify/assert" diff --git a/socket.go b/transport/socket.go similarity index 65% rename from socket.go rename to transport/socket.go index 1f6538c..42ffbe9 100644 --- a/socket.go +++ b/transport/socket.go @@ -1,19 +1,15 @@ -package honeybee +package transport import ( "log/slog" "net/http" "time" - "git.wisehodl.dev/jay/go-honeybee/errors" + "git.wisehodl.dev/jay/go-honeybee/types" "github.com/gorilla/websocket" ) -type Dialer interface { - Dial(urlStr string, requestHeader http.Header) (Socket, *http.Response, error) -} - -func NewDialer() Dialer { +func NewDialer() types.Dialer { return NewGorillaDialer() } @@ -35,36 +31,26 @@ func NewGorillaDialer() *GorillaDialer { func (d *GorillaDialer) Dial( urlStr string, requestHeader http.Header, ) ( - Socket, *http.Response, error, + types.Socket, *http.Response, error, ) { conn, resp, err := d.Dialer.Dial(urlStr, requestHeader) return conn, resp, err } -type Socket interface { - WriteMessage(messageType int, data []byte) error - ReadMessage() (messageType int, p []byte, err error) - Close() error - - SetReadDeadline(t time.Time) error - SetWriteDeadline(t time.Time) error - SetCloseHandler(h func(code int, text string) error) -} - func AcquireSocket( retryMgr *RetryManager, - dialer Dialer, + dialer types.Dialer, urlStr string, logger *slog.Logger, -) (Socket, *http.Response, error) { +) (types.Socket, *http.Response, error) { if retryMgr == nil { - return nil, nil, errors.NewConnectionError("retry manager cannot be nil") + return nil, nil, NewConnectionError("retry manager cannot be nil") } if dialer == nil { - return nil, nil, errors.NewConnectionError("dialer cannot be nil") + return nil, nil, NewConnectionError("dialer cannot be nil") } if urlStr == "" { - return nil, nil, errors.NewConnectionError("URL cannot be empty") + return nil, nil, NewConnectionError("URL cannot be empty") } for { diff --git a/socket_test.go b/transport/socket_test.go similarity index 86% rename from socket_test.go rename to transport/socket_test.go index 5fae4e8..21ba921 100644 --- a/socket_test.go +++ b/transport/socket_test.go @@ -1,7 +1,9 @@ -package honeybee +package transport import ( "errors" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "git.wisehodl.dev/jay/go-honeybee/types" "github.com/stretchr/testify/assert" "net/http" "testing" @@ -60,14 +62,14 @@ func TestAcquireSocket(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { attemptIndex := 0 - mockDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + mockDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { err := tc.mockRuns[attemptIndex] attemptIndex++ if err != nil { return nil, nil, err } - return NewMockSocket(), nil, nil + return honeybeetest.NewMockSocket(), nil, nil }, } @@ -93,9 +95,9 @@ func TestAcquireSocket(t *testing.T) { } func TestAcquireSocketGuards(t *testing.T) { - validDialer := &MockDialer{ - DialFunc: func(string, http.Header) (Socket, *http.Response, error) { - return NewMockSocket(), nil, nil + validDialer := &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + return honeybeetest.NewMockSocket(), nil, nil }, } validRetryMgr := NewRetryManager(GetDefaultRetryConfig()) @@ -103,7 +105,7 @@ func TestAcquireSocketGuards(t *testing.T) { cases := []struct { name string retryMgr *RetryManager - dialer Dialer + dialer types.Dialer url string wantErr string }{ diff --git a/url.go b/transport/url.go similarity index 86% rename from url.go rename to transport/url.go index 05147d4..17b0b54 100644 --- a/url.go +++ b/transport/url.go @@ -1,10 +1,8 @@ -package honeybee +package transport import ( "net/url" "strings" - - "git.wisehodl.dev/jay/go-honeybee/errors" ) func ParseURL(urlStr string) (*url.URL, error) { @@ -14,7 +12,7 @@ func ParseURL(urlStr string) (*url.URL, error) { } if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" { - return nil, errors.InvalidProtocol + return nil, InvalidProtocol } return parsedURL, nil diff --git a/url_test.go b/transport/url_test.go similarity index 93% rename from url_test.go rename to transport/url_test.go index 933d327..39364a4 100644 --- a/url_test.go +++ b/transport/url_test.go @@ -1,7 +1,6 @@ -package honeybee +package transport import ( - "git.wisehodl.dev/jay/go-honeybee/errors" "github.com/stretchr/testify/assert" "testing" ) @@ -41,17 +40,17 @@ func TestParseURL(t *testing.T) { { name: "http scheme rejected", url: "http://example.com", - wantErr: errors.InvalidProtocol, + wantErr: InvalidProtocol, }, { name: "missing scheme", url: "example.com:8080", - wantErr: errors.InvalidProtocol, + wantErr: InvalidProtocol, }, { name: "empty string", url: "", - wantErr: errors.InvalidProtocol, + wantErr: InvalidProtocol, }, { name: "malformed url", @@ -161,5 +160,5 @@ func TestNormalizeURL(t *testing.T) { func TestNormalizeURLError(t *testing.T) { _, err := NormalizeURL("http://relay.example.com") - assert.ErrorIs(t, err, errors.InvalidProtocol) + assert.ErrorIs(t, err, InvalidProtocol) } diff --git a/types/types.go b/types/types.go new file mode 100644 index 0000000..f1e6fc4 --- /dev/null +++ b/types/types.go @@ -0,0 +1,20 @@ +package types + +import ( + "net/http" + "time" +) + +type Dialer interface { + Dial(urlStr string, requestHeader http.Header) (Socket, *http.Response, error) +} + +type Socket interface { + WriteMessage(messageType int, data []byte) error + ReadMessage() (messageType int, p []byte, err error) + Close() error + + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error + SetCloseHandler(h func(code int, text string) error) +} diff --git a/worker_test.go b/worker_test.go deleted file mode 100644 index 53eb7a8..0000000 --- a/worker_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package honeybee - -import ( -// "github.com/stretchr/testify/assert" -// "testing" -// "time" -)