Refactored package structure.

This commit is contained in:
Jay
2026-04-17 14:53:29 -04:00
parent c14d04f7b3
commit 3af3696d86
29 changed files with 1210 additions and 1259 deletions

422
config.go
View File

@@ -1,422 +0,0 @@
package honeybee
import (
"git.wisehodl.dev/jay/go-honeybee/errors"
"time"
)
// Types
type CloseHandler func(code int, text string) error
type WorkerFactory func(
id string,
conn *Connection,
onReconnect func() (*Connection, error),
) Worker
// Initiator Pool Config
type InitiatorPoolConfig struct {
ConnectionConfig *ConnectionConfig
WorkerFactory WorkerFactory
WorkerConfig *InitiatorWorkerConfig
}
type InitiatorPoolOption func(*InitiatorPoolConfig) error
func NewInitiatorPoolConfig(options ...InitiatorPoolOption) (*InitiatorPoolConfig, error) {
conf := GetDefaultInitiatorPoolConfig()
if err := applyInitiatorPoolOptions(conf, options...); err != nil {
return nil, err
}
if err := validateInitiatorPoolConfig(conf); err != nil {
return nil, err
}
return conf, nil
}
func GetDefaultInitiatorPoolConfig() *InitiatorPoolConfig {
return &InitiatorPoolConfig{
ConnectionConfig: nil,
WorkerFactory: nil,
WorkerConfig: nil,
}
}
func applyInitiatorPoolOptions(config *InitiatorPoolConfig, options ...InitiatorPoolOption) error {
for _, option := range options {
if err := option(config); err != nil {
return err
}
}
return nil
}
func validateInitiatorPoolConfig(config *InitiatorPoolConfig) error {
var err error
if config.ConnectionConfig != nil {
err = validateConnectionConfig(config.ConnectionConfig)
if err != nil {
return err
}
}
if config.WorkerConfig != nil {
err = validateInitiatorWorkerConfig(config.WorkerConfig)
if err != nil {
return err
}
}
return nil
}
func WithInitiatorConnectionConfig(cc *ConnectionConfig) InitiatorPoolOption {
return func(c *InitiatorPoolConfig) error {
err := validateConnectionConfig(cc)
if err != nil {
return err
}
c.ConnectionConfig = cc
return nil
}
}
func WithInitiatorWorkerConfig(wc *InitiatorWorkerConfig) InitiatorPoolOption {
return func(c *InitiatorPoolConfig) error {
err := validateInitiatorWorkerConfig(wc)
if err != nil {
return err
}
c.WorkerConfig = wc
return nil
}
}
func WithInitiatorWorkerFactory(wf WorkerFactory) InitiatorPoolOption {
return func(c *InitiatorPoolConfig) error {
c.WorkerFactory = wf
return nil
}
}
// Responder Pool Config
type ResponderPoolConfig struct {
ConnectionConfig *ConnectionConfig
WorkerFactory WorkerFactory
WorkerConfig *ResponderWorkerConfig
}
// Connection Config
type ConnectionConfig struct {
CloseHandler CloseHandler
WriteTimeout time.Duration
Retry *RetryConfig
}
type RetryConfig struct {
MaxRetries int
InitialDelay time.Duration
MaxDelay time.Duration
JitterFactor float64
}
type ConnectionOption func(*ConnectionConfig) error
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
conf := GetDefaultConnectionConfig()
if err := applyConnectionOptions(conf, options...); err != nil {
return nil, err
}
if err := validateConnectionConfig(conf); err != nil {
return nil, err
}
return conf, nil
}
func GetDefaultConnectionConfig() *ConnectionConfig {
return &ConnectionConfig{
CloseHandler: nil,
WriteTimeout: 30 * time.Second,
Retry: GetDefaultRetryConfig(),
}
}
func GetDefaultRetryConfig() *RetryConfig {
return &RetryConfig{
MaxRetries: 0, // Infinite retries
InitialDelay: 1 * time.Second,
MaxDelay: 5 * time.Second,
JitterFactor: 0.5,
}
}
func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOption) error {
for _, option := range options {
if err := option(config); err != nil {
return err
}
}
return nil
}
func validateConnectionConfig(config *ConnectionConfig) error {
err := validateWriteTimeout(config.WriteTimeout)
if err != nil {
return err
}
if config.Retry != nil {
err = validateMaxRetries(config.Retry.MaxRetries)
if err != nil {
return err
}
err = validateInitialDelay(config.Retry.InitialDelay)
if err != nil {
return err
}
err = validateMaxDelay(config.Retry.MaxDelay)
if err != nil {
return err
}
err = validateJitterFactor(config.Retry.JitterFactor)
if err != nil {
return err
}
if config.Retry.InitialDelay > config.Retry.MaxDelay {
return errors.NewConfigError("initial delay may not exceed maximum delay")
}
}
return nil
}
func validateWriteTimeout(value time.Duration) error {
if value < 0 {
return errors.InvalidWriteTimeout
}
return nil
}
func validateMaxRetries(value int) error {
if value < 0 {
return errors.InvalidRetryMaxRetries
}
return nil
}
func validateInitialDelay(value time.Duration) error {
if value <= 0 {
return errors.InvalidRetryInitialDelay
}
return nil
}
func validateMaxDelay(value time.Duration) error {
if value <= 0 {
return errors.InvalidRetryMaxDelay
}
return nil
}
func validateJitterFactor(value float64) error {
if value < 0.0 || value > 1.0 {
return errors.InvalidRetryJitterFactor
}
return nil
}
func WithCloseHandler(handler CloseHandler) ConnectionOption {
return func(c *ConnectionConfig) error {
c.CloseHandler = handler
return nil
}
}
// When WriteTimeout is set to zero, read timeouts are disabled.
func WithWriteTimeout(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error {
err := validateWriteTimeout(value)
if err != nil {
return err
}
c.WriteTimeout = value
return nil
}
}
// WithRetry enables retry with default parameters (infinite retries,
// 1s initial delay, 5s max delay, 0.5 jitter factor).
//
// If passed after granular retry options (WithRetryMaxRetries, etc.),
// it will overwrite them. Use either WithRetry alone or the granular
// options; not both.
func WithRetry() ConnectionOption {
return func(c *ConnectionConfig) error {
c.Retry = GetDefaultRetryConfig()
return nil
}
}
func WithRetryMaxRetries(value int) ConnectionOption {
return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateMaxRetries(value)
if err != nil {
return err
}
c.Retry.MaxRetries = value
return nil
}
}
func WithRetryInitialDelay(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateInitialDelay(value)
if err != nil {
return err
}
c.Retry.InitialDelay = value
return nil
}
}
func WithRetryMaxDelay(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateMaxDelay(value)
if err != nil {
return err
}
c.Retry.MaxDelay = value
return nil
}
}
func WithRetryJitterFactor(value float64) ConnectionOption {
return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateJitterFactor(value)
if err != nil {
return err
}
c.Retry.JitterFactor = value
return nil
}
}
// Initiator Worker Config
type InitiatorWorkerConfig struct {
IdleTimeout time.Duration
MaxQueueSize int
}
type InitiatorWorkerOption func(*InitiatorWorkerConfig) error
func NewInitiatorWorkerConfig(options ...InitiatorWorkerOption) (*InitiatorWorkerConfig, error) {
conf := GetDefaultInitiatorWorkerConfig()
if err := applyInitiatorWorkerOptions(conf, options...); err != nil {
return nil, err
}
if err := validateInitiatorWorkerConfig(conf); err != nil {
return nil, err
}
return conf, nil
}
func GetDefaultInitiatorWorkerConfig() *InitiatorWorkerConfig {
return &InitiatorWorkerConfig{
IdleTimeout: 20 * time.Second,
MaxQueueSize: 0, // disabled by default
}
}
func applyInitiatorWorkerOptions(config *InitiatorWorkerConfig, options ...InitiatorWorkerOption) error {
for _, option := range options {
if err := option(config); err != nil {
return err
}
}
return nil
}
func validateInitiatorWorkerConfig(config *InitiatorWorkerConfig) error {
err := validateIdleTimeout(config.IdleTimeout)
if err != nil {
return err
}
err = validateMaxQueueSize(config.MaxQueueSize)
if err != nil {
return err
}
return nil
}
func validateMaxQueueSize(value int) error {
if value < 0 {
return errors.InvalidMaxQueueSize
}
return nil
}
func validateIdleTimeout(value time.Duration) error {
if value < 0 {
return errors.InvalidIdleTimeout
}
return nil
}
// When IdleTimeout is set to zero, idle timeouts are disabled.
func WithIdleTimeout(value time.Duration) InitiatorWorkerOption {
return func(c *InitiatorWorkerConfig) error {
err := validateIdleTimeout(value)
if err != nil {
return err
}
c.IdleTimeout = value
return nil
}
}
// When MaxQueueSize is set to zero, queue limits are disabled.
func WithMaxQueueSize(value int) InitiatorWorkerOption {
return func(c *InitiatorWorkerConfig) error {
err := validateMaxQueueSize(value)
if err != nil {
return err
}
c.MaxQueueSize = value
return nil
}
}
// Responder Worker Config
type ResponderWorkerConfig struct{}

64
honeybeetest/helpers.go Normal file
View File

@@ -0,0 +1,64 @@
package honeybeetest
import (
"bytes"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
// Constants
const (
TestTimeout = 2 * time.Second
TestTick = 10 * time.Millisecond
NegativeTestTimeout = 100 * time.Millisecond
)
// Types
type MockIncomingData struct {
MsgType int
Data []byte
Err error
}
type MockOutgoingData struct {
MsgType int
Data []byte
}
// Helpers
func ExpectIncoming(t *testing.T, incoming <-chan []byte, expected []byte) {
t.Helper()
assert.Eventually(t, func() bool {
select {
case received := <-incoming:
return bytes.Equal(received, expected)
default:
return false
}
}, TestTimeout, TestTick)
}
func ExpectWrite(t *testing.T, outgoingData chan MockOutgoingData, msgType int, expected []byte) {
t.Helper()
var call MockOutgoingData
found := assert.Eventually(t, func() bool {
select {
case received := <-outgoingData:
call = received
return true
default:
return false
}
}, TestTimeout, TestTick)
if found {
assert.Equal(t, msgType, call.MsgType)
assert.Equal(t, expected, call.Data)
}
}

119
honeybeetest/mocks.go Normal file
View File

@@ -0,0 +1,119 @@
package honeybeetest
import (
"context"
"git.wisehodl.dev/jay/go-honeybee/types"
"log/slog"
"net/http"
"sync"
"time"
)
// Dialer Mocks
type MockDialer struct {
DialFunc func(string, http.Header) (types.Socket, *http.Response, error)
}
func (m *MockDialer) Dial(url string, h http.Header) (types.Socket, *http.Response, error) {
return m.DialFunc(url, h)
}
// Socket Mocks
type MockSocket struct {
WriteMessageFunc func(int, []byte) error
SetReadDeadlineFunc func(t time.Time) error
SetWriteDeadlineFunc func(t time.Time) error
ReadMessageFunc func() (int, []byte, error)
CloseFunc func() error
SetCloseHandlerFunc func(func(int, string) error)
Closed chan struct{}
Once sync.Once
Mu sync.Mutex
}
func NewMockSocket() *MockSocket {
return &MockSocket{
WriteMessageFunc: func(int, []byte) error { return nil },
ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil },
CloseFunc: func() error { return nil },
SetReadDeadlineFunc: func(time.Time) error { return nil },
SetWriteDeadlineFunc: func(time.Time) error { return nil },
SetCloseHandlerFunc: func(func(int, string) error) {},
Closed: make(chan struct{}),
}
}
func (m *MockSocket) WriteMessage(t int, d []byte) error {
return m.WriteMessageFunc(t, d)
}
func (m *MockSocket) ReadMessage() (int, []byte, error) {
return m.ReadMessageFunc()
}
func (m *MockSocket) Close() error {
return m.CloseFunc()
}
func (m *MockSocket) SetReadDeadline(t time.Time) error {
return m.SetReadDeadlineFunc(t)
}
func (m *MockSocket) SetWriteDeadline(t time.Time) error {
return m.SetWriteDeadlineFunc(t)
}
func (m *MockSocket) SetCloseHandler(h func(code int, text string) error) {
m.SetCloseHandlerFunc(h)
}
// Logging mocks
type MockSlogHandler struct {
records []slog.Record
mu sync.RWMutex
}
func NewMockSlogHandler() *MockSlogHandler {
return &MockSlogHandler{
records: make([]slog.Record, 0),
}
}
func (m *MockSlogHandler) Handle(ctx context.Context, record slog.Record) error {
m.mu.Lock()
defer m.mu.Unlock()
m.records = append(m.records, record)
return nil
}
func (m *MockSlogHandler) Enabled(ctx context.Context, level slog.Level) bool {
return true
}
func (m *MockSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return m
}
func (m *MockSlogHandler) WithGroup(name string) slog.Handler {
return m
}
func (m *MockSlogHandler) GetRecords() []slog.Record {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]slog.Record, len(m.records))
copy(result, m.records)
return result
}
func (m *MockSlogHandler) Clear() {
m.mu.Lock()
defer m.mu.Unlock()
m.records = make([]slog.Record, 0)
}

189
initiator/config.go Normal file
View File

@@ -0,0 +1,189 @@
package initiator
import (
"git.wisehodl.dev/jay/go-honeybee/transport"
"time"
)
// Types
type WorkerFactory func(
id string,
conn *transport.Connection,
onReconnect func() (*transport.Connection, error),
) Worker
// Pool Config
type PoolConfig struct {
ConnectionConfig *transport.ConnectionConfig
WorkerFactory WorkerFactory
WorkerConfig *WorkerConfig
}
type PoolOption func(*PoolConfig) error
func NewPoolConfig(options ...PoolOption) (*PoolConfig, error) {
conf := GetDefaultPoolConfig()
if err := applyPoolOptions(conf, options...); err != nil {
return nil, err
}
if err := ValidatePoolConfig(conf); err != nil {
return nil, err
}
return conf, nil
}
func GetDefaultPoolConfig() *PoolConfig {
return &PoolConfig{
ConnectionConfig: nil,
WorkerFactory: nil,
WorkerConfig: nil,
}
}
func applyPoolOptions(config *PoolConfig, options ...PoolOption) error {
for _, option := range options {
if err := option(config); err != nil {
return err
}
}
return nil
}
func ValidatePoolConfig(config *PoolConfig) error {
var err error
if config.ConnectionConfig != nil {
err = transport.ValidateConnectionConfig(config.ConnectionConfig)
if err != nil {
return err
}
}
if config.WorkerConfig != nil {
err = ValidateWorkerConfig(config.WorkerConfig)
if err != nil {
return err
}
}
return nil
}
func WithConnectionConfig(cc *transport.ConnectionConfig) PoolOption {
return func(c *PoolConfig) error {
err := transport.ValidateConnectionConfig(cc)
if err != nil {
return err
}
c.ConnectionConfig = cc
return nil
}
}
func WithWorkerConfig(wc *WorkerConfig) PoolOption {
return func(c *PoolConfig) error {
err := ValidateWorkerConfig(wc)
if err != nil {
return err
}
c.WorkerConfig = wc
return nil
}
}
func WithWorkerFactory(wf WorkerFactory) PoolOption {
return func(c *PoolConfig) error {
c.WorkerFactory = wf
return nil
}
}
// Worker Config
type WorkerConfig struct {
IdleTimeout time.Duration
MaxQueueSize int
}
type WorkerOption func(*WorkerConfig) error
func NewWorkerConfig(options ...WorkerOption) (*WorkerConfig, error) {
conf := GetDefaultWorkerConfig()
if err := applyWorkerOptions(conf, options...); err != nil {
return nil, err
}
if err := ValidateWorkerConfig(conf); err != nil {
return nil, err
}
return conf, nil
}
func GetDefaultWorkerConfig() *WorkerConfig {
return &WorkerConfig{
IdleTimeout: 20 * time.Second,
MaxQueueSize: 0, // disabled by default
}
}
func applyWorkerOptions(config *WorkerConfig, options ...WorkerOption) error {
for _, option := range options {
if err := option(config); err != nil {
return err
}
}
return nil
}
func ValidateWorkerConfig(config *WorkerConfig) error {
err := validateIdleTimeout(config.IdleTimeout)
if err != nil {
return err
}
err = validateMaxQueueSize(config.MaxQueueSize)
if err != nil {
return err
}
return nil
}
func validateMaxQueueSize(value int) error {
if value < 0 {
return InvalidMaxQueueSize
}
return nil
}
func validateIdleTimeout(value time.Duration) error {
if value < 0 {
return InvalidIdleTimeout
}
return nil
}
// When IdleTimeout is set to zero, idle timeouts are disabled.
func WithIdleTimeout(value time.Duration) WorkerOption {
return func(c *WorkerConfig) error {
err := validateIdleTimeout(value)
if err != nil {
return err
}
c.IdleTimeout = value
return nil
}
}
// When MaxQueueSize is set to zero, queue limits are disabled.
func WithMaxQueueSize(value int) WorkerOption {
return func(c *WorkerConfig) error {
err := validateMaxQueueSize(value)
if err != nil {
return err
}
c.MaxQueueSize = value
return nil
}
}

View File

@@ -1,16 +1,17 @@
package honeybee
package initiator
import (
"git.wisehodl.dev/jay/go-honeybee/transport"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestNewPoolConfig(t *testing.T) {
conf, err := NewInitiatorPoolConfig()
conf, err := NewPoolConfig()
assert.NoError(t, err)
assert.Equal(t, conf, &InitiatorPoolConfig{
assert.Equal(t, conf, &PoolConfig{
ConnectionConfig: nil,
WorkerConfig: nil,
WorkerFactory: nil,
@@ -18,9 +19,9 @@ func TestNewPoolConfig(t *testing.T) {
}
func TestDefaultPoolConfig(t *testing.T) {
conf := GetDefaultInitiatorPoolConfig()
conf := GetDefaultPoolConfig()
assert.Equal(t, conf, &InitiatorPoolConfig{
assert.Equal(t, conf, &PoolConfig{
ConnectionConfig: nil,
WorkerConfig: nil,
WorkerFactory: nil,
@@ -28,10 +29,10 @@ func TestDefaultPoolConfig(t *testing.T) {
}
func TestApplyPoolOptions(t *testing.T) {
conf := &InitiatorPoolConfig{}
err := applyInitiatorPoolOptions(
conf := &PoolConfig{}
err := applyPoolOptions(
conf,
WithInitiatorConnectionConfig(&ConnectionConfig{}),
WithConnectionConfig(&transport.ConnectionConfig{}),
)
assert.NoError(t, err)
@@ -39,46 +40,46 @@ func TestApplyPoolOptions(t *testing.T) {
}
func TestWithConnectionConfig(t *testing.T) {
conf := &InitiatorPoolConfig{}
opt := WithInitiatorConnectionConfig(&ConnectionConfig{WriteTimeout: 1 * time.Second})
err := applyInitiatorPoolOptions(conf, opt)
conf := &PoolConfig{}
opt := WithConnectionConfig(&transport.ConnectionConfig{WriteTimeout: 1 * time.Second})
err := applyPoolOptions(conf, opt)
assert.NoError(t, err)
assert.NotNil(t, conf.ConnectionConfig)
assert.Equal(t, 1*time.Second, conf.ConnectionConfig.WriteTimeout)
// invalid config is rejected
conf = &InitiatorPoolConfig{}
opt = WithInitiatorConnectionConfig(&ConnectionConfig{WriteTimeout: -1 * time.Second})
err = applyInitiatorPoolOptions(conf, opt)
conf = &PoolConfig{}
opt = WithConnectionConfig(&transport.ConnectionConfig{WriteTimeout: -1 * time.Second})
err = applyPoolOptions(conf, opt)
assert.Error(t, err)
}
func TestValidatePoolConfig(t *testing.T) {
cases := []struct {
name string
conf InitiatorPoolConfig
conf PoolConfig
wantErr error
wantErrText string
}{
{
name: "valid empty",
conf: *&InitiatorPoolConfig{},
conf: *&PoolConfig{},
},
{
name: "valid defaults",
conf: *GetDefaultInitiatorPoolConfig(),
conf: *GetDefaultPoolConfig(),
},
{
name: "valid complete",
conf: InitiatorPoolConfig{
ConnectionConfig: &ConnectionConfig{},
conf: PoolConfig{
ConnectionConfig: &transport.ConnectionConfig{},
},
},
{
name: "invalid connection config",
conf: InitiatorPoolConfig{
ConnectionConfig: &ConnectionConfig{
Retry: &RetryConfig{
conf: PoolConfig{
ConnectionConfig: &transport.ConnectionConfig{
Retry: &transport.RetryConfig{
InitialDelay: 10 * time.Second,
MaxDelay: 1 * time.Second,
},
@@ -90,7 +91,7 @@ func TestValidatePoolConfig(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := validateInitiatorPoolConfig(&tc.conf)
err := ValidatePoolConfig(&tc.conf)
if tc.wantErr != nil || tc.wantErrText != "" {
if tc.wantErr != nil {

17
initiator/errors.go Normal file
View File

@@ -0,0 +1,17 @@
package initiator
import "errors"
import "fmt"
var (
InvalidIdleTimeout = errors.New("idle timeout cannot be negative")
InvalidMaxQueueSize = errors.New("maximum queue size cannot be negative")
)
func NewConfigError(text string) error {
return fmt.Errorf("configuration error: %s", text)
}
func NewPoolError(text string) error {
return fmt.Errorf("pool error: %s", text)
}

248
initiator/pool.go Normal file
View File

@@ -0,0 +1,248 @@
package initiator
import (
"git.wisehodl.dev/jay/go-honeybee/transport"
"git.wisehodl.dev/jay/go-honeybee/types"
"log/slog"
"sync"
"time"
)
// Types
type peer struct {
conn *transport.Connection
stop chan struct{}
}
type InboxMessage struct {
ID string
Data []byte
ReceivedAt time.Time
}
type PoolEventKind int
const (
EventConnected PoolEventKind = iota
EventDisconnected
)
func (s PoolEventKind) String() string {
switch s {
case EventConnected:
return "connected"
case EventDisconnected:
return "disconnected"
default:
return "unknown"
}
}
type PoolEvent struct {
ID string
Kind PoolEventKind
}
// Pool
type Pool struct {
peers map[string]*peer
inbox chan InboxMessage
events chan PoolEvent
errors chan error
done chan struct{}
dialer types.Dialer
config *PoolConfig
logger *slog.Logger
mu sync.RWMutex
wg sync.WaitGroup
closed bool
}
func NewPool(config *PoolConfig, logger *slog.Logger) (*Pool, error) {
if config == nil {
config = GetDefaultPoolConfig()
}
if err := ValidatePoolConfig(config); err != nil {
return nil, err
}
p := &Pool{
peers: make(map[string]*peer),
inbox: make(chan InboxMessage, 256),
events: make(chan PoolEvent, 10),
errors: make(chan error, 10),
done: make(chan struct{}),
dialer: transport.NewDialer(),
config: config,
logger: logger,
}
return p, nil
}
func (p *Pool) Peers() map[string]*peer {
return p.peers
}
func (p *Pool) Inbox() chan InboxMessage {
return p.inbox
}
func (p *Pool) Events() chan PoolEvent {
return p.events
}
func (p *Pool) Errors() chan error {
return p.errors
}
func (p *Pool) Close() {
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return
}
p.closed = true
close(p.done)
peers := p.peers
p.peers = make(map[string]*peer)
p.mu.Unlock()
for _, conn := range peers {
conn.conn.Close()
}
go func() {
p.wg.Wait()
close(p.inbox)
close(p.events)
close(p.errors)
}()
}
func (p *Pool) Connect(id string) error {
id, err := transport.NormalizeURL(id)
if err != nil {
return err
}
// Check for existing connection in pool
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return NewPoolError("pool is closed")
}
_, exists := p.peers[id]
p.mu.Unlock()
if exists {
return NewPoolError("connection already exists")
}
// Create new connection
var logger *slog.Logger
if p.logger != nil {
logger = p.logger.With("id", id)
}
conn, err := transport.NewConnection(id, p.config.ConnectionConfig, logger)
if err != nil {
return err
}
conn.SetDialer(p.dialer)
// Attempt to connect
if err := conn.Connect(); err != nil {
return err
}
p.mu.Lock()
if p.closed {
// The pool closed while this connection was established.
p.mu.Unlock()
conn.Close()
return NewPoolError("pool is closed")
}
// Add connection to pool
stop := make(chan struct{})
if _, exists := p.peers[id]; exists {
// Another process connected to this url while this one was connecting
// Discard this connection and retain the existing one
p.mu.Unlock()
conn.Close()
return NewPoolError("connection already exists")
}
p.peers[id] = &peer{conn: conn, stop: stop}
p.mu.Unlock()
// TODO: start this connection's incoming message forwarder
select {
case p.events <- PoolEvent{ID: id, Kind: EventConnected}:
case <-p.done:
return nil
}
return nil
}
func (p *Pool) Remove(id string) error {
id, err := transport.NormalizeURL(id)
if err != nil {
return err
}
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return NewPoolError("pool is closed")
}
peer, exists := p.peers[id]
if !exists {
p.mu.Unlock()
return NewPoolError("connection not found")
}
delete(p.peers, id)
p.mu.Unlock()
close(peer.stop)
peer.conn.Close()
select {
case p.events <- PoolEvent{ID: id, Kind: EventDisconnected}:
case <-p.done:
return nil
}
return nil
}
func (p *Pool) Send(id string, data []byte) error {
id, err := transport.NormalizeURL(id)
if err != nil {
return err
}
p.mu.RLock()
defer p.mu.RUnlock()
if p.closed {
return NewPoolError("pool is closed")
}
peer, exists := p.peers[id]
if !exists {
return NewPoolError("connection not found")
}
return peer.conn.Send(data)
}

View File

@@ -1,7 +1,10 @@
package honeybee
package initiator
import (
"fmt"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"git.wisehodl.dev/jay/go-honeybee/transport"
"git.wisehodl.dev/jay/go-honeybee/types"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"net/http"
@@ -11,14 +14,14 @@ import (
func TestPoolConnect(t *testing.T) {
t.Run("successfully adds connection", func(t *testing.T) {
mockSocket := NewMockSocket()
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
pool, err := NewInitiatorPool(nil, nil)
pool, err := NewPool(nil, nil)
assert.NoError(t, err)
pool.dialer = mockDialer
@@ -33,7 +36,7 @@ func TestPoolConnect(t *testing.T) {
default:
return false
}
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
_, exists := pool.peers["wss://test"]
assert.True(t, exists)
@@ -42,14 +45,14 @@ func TestPoolConnect(t *testing.T) {
})
t.Run("does not add duplicate", func(t *testing.T) {
mockSocket := NewMockSocket()
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
pool, err := NewInitiatorPool(nil, nil)
pool, err := NewPool(nil, nil)
assert.NoError(t, err)
pool.dialer = mockDialer
@@ -69,18 +72,18 @@ func TestPoolConnect(t *testing.T) {
})
t.Run("fails to add connection", func(t *testing.T) {
pool, err := NewInitiatorPool(
&InitiatorPoolConfig{
ConnectionConfig: &ConnectionConfig{
Retry: &RetryConfig{
pool, err := NewPool(
&PoolConfig{
ConnectionConfig: &transport.ConnectionConfig{
Retry: &transport.RetryConfig{
MaxRetries: 1,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 5 * time.Millisecond,
}},
}, nil)
assert.NoError(t, err)
pool.dialer = &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
pool.dialer = &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return nil, nil, fmt.Errorf("dial failed")
},
}
@@ -104,14 +107,14 @@ func TestPoolConnect(t *testing.T) {
func TestPoolRemove(t *testing.T) {
t.Run("removes known url", func(t *testing.T) {
mockSocket := NewMockSocket()
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
pool, err := NewInitiatorPool(nil, nil)
pool, err := NewPool(nil, nil)
assert.NoError(t, err)
pool.dialer = mockDialer
@@ -132,14 +135,14 @@ func TestPoolRemove(t *testing.T) {
})
t.Run("unknown url returns error", func(t *testing.T) {
mockSocket := NewMockSocket()
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
pool, err := NewInitiatorPool(nil, nil)
pool, err := NewPool(nil, nil)
assert.NoError(t, err)
pool.dialer = mockDialer
@@ -149,14 +152,14 @@ func TestPoolRemove(t *testing.T) {
})
t.Run("closed pool returns error", func(t *testing.T) {
mockSocket := NewMockSocket()
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
pool, err := NewInitiatorPool(nil, nil)
pool, err := NewPool(nil, nil)
assert.NoError(t, err)
pool.dialer = mockDialer
@@ -171,19 +174,19 @@ func TestPoolRemove(t *testing.T) {
}
func TestPoolSend(t *testing.T) {
mockSocket := NewMockSocket()
outgoingData := make(chan mockOutgoingData, 10)
mockSocket := honeybeetest.NewMockSocket()
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
outgoingData <- mockOutgoingData{msgType: msgType, data: data}
outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}
return nil
}
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
pool, err := NewInitiatorPool(nil, nil)
pool, err := NewPool(nil, nil)
assert.NoError(t, err)
pool.dialer = mockDialer
@@ -194,7 +197,7 @@ func TestPoolSend(t *testing.T) {
err = pool.Send("wss://test", []byte("hello"))
assert.NoError(t, err)
expectWrite(t, outgoingData, websocket.TextMessage, []byte("hello"))
honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, []byte("hello"))
pool.Close()
}
@@ -213,7 +216,7 @@ func expectEvent(
default:
return false
}
}, testTimeout, testTick,
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
fmt.Sprintf("expected event: URL=%q, Kind=%q",
expectedURL, expectedKind.String()))
}

View File

@@ -1,6 +1,7 @@
package honeybee
package initiator
import (
"git.wisehodl.dev/jay/go-honeybee/transport"
"log/slog"
"sync"
"time"
@@ -8,15 +9,6 @@ import (
// Types
// Worker Implementation
type Worker interface {
Start(
ctx *WorkerContext,
wg *sync.WaitGroup,
)
}
type WorkerContext struct {
Inbox chan<- InboxMessage
Events chan<- PoolEvent
@@ -26,40 +18,23 @@ type WorkerContext struct {
Logger *slog.Logger
}
// Base Struct
// Worker
type worker struct {
type Worker struct {
id string
config *WorkerConfig
onReconnect func() (*transport.Connection, error)
}
func (w *worker) runForwarder(
messages <-chan []byte,
inbox chan<- []byte,
stop <-chan struct{},
poolDone <-chan struct{},
maxQueueSize int,
) {
}
// Initiator Worker
type InitiatorWorker struct {
*worker
config *InitiatorWorkerConfig
onReconnect func() (*Connection, error)
}
func newInitiatorWorker(
func NewWorker(
id string,
config *InitiatorWorkerConfig,
onReconnect func() (*Connection, error),
config *WorkerConfig,
onReconnect func() (*transport.Connection, error),
logger *slog.Logger,
) (*InitiatorWorker, error) {
w := &InitiatorWorker{
worker: &worker{
) (*Worker, error) {
w := &Worker{
id: id,
},
config: config,
onReconnect: onReconnect,
}
@@ -67,7 +42,7 @@ func newInitiatorWorker(
return w, nil
}
func (w *InitiatorWorker) Start(
func (w *Worker) Start(
inbox chan<- InboxMessage,
events chan<- PoolEvent,
stop <-chan struct{},
@@ -76,32 +51,37 @@ func (w *InitiatorWorker) Start(
) {
}
func runReader(conn *Connection,
func (w *Worker) runReader(conn *transport.Connection,
messages chan<- []byte,
heartbeat chan<- time.Time,
reconnect chan<- struct{},
newConn <-chan *Connection,
newConn <-chan *transport.Connection,
stop <-chan struct{},
poolDone <-chan struct{},
) {
}
func runHealthMonitor(
func (w *Worker) runForwarder(
messages <-chan []byte,
inbox chan<- []byte,
stop <-chan struct{},
poolDone <-chan struct{},
maxQueueSize int,
) {
}
func (w *Worker) runHealthMonitor(
heartbeat <-chan time.Time,
stop <-chan struct{},
poolDone <-chan struct{},
) {
}
func runReconnector(
func (w *Worker) runReconnector(
reconnect <-chan struct{},
newConn chan<- *Connection,
newConn chan<- *transport.Connection,
stop <-chan struct{},
poolDone <-chan struct{},
) {
}
// Responder Worker
type ResponderWorker struct{}

1
initiator/worker_test.go Normal file
View File

@@ -0,0 +1 @@
package initiator

View File

@@ -1,192 +0,0 @@
package honeybee
import (
"context"
"fmt"
"github.com/stretchr/testify/assert"
"io"
"log/slog"
"net/http"
"sync"
"testing"
"time"
)
// Test Constants
const (
testTimeout = 2 * time.Second
testTick = 10 * time.Millisecond
negativeTestTimeout = 100 * time.Millisecond
)
// Dialer Mocks
type MockDialer struct {
DialFunc func(string, http.Header) (Socket, *http.Response, error)
}
func (m *MockDialer) Dial(url string, h http.Header) (Socket, *http.Response, error) {
return m.DialFunc(url, h)
}
// Socket Mocks
type MockSocket struct {
WriteMessageFunc func(int, []byte) error
SetReadDeadlineFunc func(t time.Time) error
SetWriteDeadlineFunc func(t time.Time) error
ReadMessageFunc func() (int, []byte, error)
CloseFunc func() error
SetCloseHandlerFunc func(func(int, string) error)
closed chan struct{}
once sync.Once
mu sync.Mutex
}
func NewMockSocket() *MockSocket {
return &MockSocket{
WriteMessageFunc: func(int, []byte) error { return nil },
ReadMessageFunc: func() (int, []byte, error) { return 0, []byte("message"), nil },
CloseFunc: func() error { return nil },
SetReadDeadlineFunc: func(time.Time) error { return nil },
SetWriteDeadlineFunc: func(time.Time) error { return nil },
SetCloseHandlerFunc: func(func(int, string) error) {},
closed: make(chan struct{}),
}
}
func (m *MockSocket) WriteMessage(t int, d []byte) error {
return m.WriteMessageFunc(t, d)
}
func (m *MockSocket) ReadMessage() (int, []byte, error) {
return m.ReadMessageFunc()
}
func (m *MockSocket) Close() error {
return m.CloseFunc()
}
func (m *MockSocket) SetReadDeadline(t time.Time) error {
return m.SetReadDeadlineFunc(t)
}
func (m *MockSocket) SetWriteDeadline(t time.Time) error {
return m.SetWriteDeadlineFunc(t)
}
func (m *MockSocket) SetCloseHandler(h func(code int, text string) error) {
m.SetCloseHandlerFunc(h)
}
// Connection Mocks
type mockIncomingData struct {
msgType int
data []byte
err error
}
type mockOutgoingData struct {
msgType int
data []byte
}
func setupTestConnection(t *testing.T, config *ConnectionConfig) (
conn *Connection,
mockSocket *MockSocket,
incomingData chan mockIncomingData,
outgoingData chan mockOutgoingData,
) {
t.Helper()
incomingData = make(chan mockIncomingData, 10)
outgoingData = make(chan mockOutgoingData, 10)
mockSocket = NewMockSocket()
mockSocket.CloseFunc = func() error {
mockSocket.once.Do(func() {
close(mockSocket.closed)
})
return nil
}
// Wire ReadMessage to pull from incomingData channel
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
select {
case data := <-incomingData:
return data.msgType, data.data, data.err
case <-mockSocket.closed:
return 0, nil, io.EOF
}
}
// Wire WriteMessage to push to outgoingData channel
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
select {
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
return nil
case <-mockSocket.closed:
return io.EOF
default:
return fmt.Errorf("mock outgoing chanel unavailable")
}
}
var err error
conn, err = NewConnectionFromSocket(mockSocket, config, nil)
assert.NoError(t, err)
return conn, mockSocket, incomingData, outgoingData
}
// Logging mocks
type mockSlogHandler struct {
records []slog.Record
mu sync.RWMutex
}
func newMockSlogHandler() *mockSlogHandler {
return &mockSlogHandler{
records: make([]slog.Record, 0),
}
}
func (m *mockSlogHandler) Handle(ctx context.Context, record slog.Record) error {
m.mu.Lock()
defer m.mu.Unlock()
m.records = append(m.records, record)
return nil
}
func (m *mockSlogHandler) Enabled(ctx context.Context, level slog.Level) bool {
return true
}
func (m *mockSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return m
}
func (m *mockSlogHandler) WithGroup(name string) slog.Handler {
return m
}
func (m *mockSlogHandler) GetRecords() []slog.Record {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]slog.Record, len(m.records))
copy(result, m.records)
return result
}
func (m *mockSlogHandler) Clear() {
m.mu.Lock()
defer m.mu.Unlock()
m.records = make([]slog.Record, 0)
}

304
pool.go
View File

@@ -1,304 +0,0 @@
package honeybee
import (
"git.wisehodl.dev/jay/go-honeybee/errors"
"log/slog"
"sync"
"time"
)
// Types
type peer struct {
conn *Connection
stop chan struct{}
}
type InboxMessage struct {
ID string
Data []byte
ReceivedAt time.Time
}
type PoolEventKind int
const (
EventConnected PoolEventKind = iota
EventDisconnected
)
func (s PoolEventKind) String() string {
switch s {
case EventConnected:
return "connected"
case EventDisconnected:
return "disconnected"
default:
return "unknown"
}
}
type PoolEvent struct {
ID string
Kind PoolEventKind
}
// Pool Implementation
type Pool interface {
Send(id string, data []byte) error
Inbox() <-chan InboxMessage
Events() <-chan PoolEvent
Errors() <-chan error
Close()
}
// Base Struct
type pool struct {
peers map[string]*peer
inbox chan InboxMessage
events chan PoolEvent
errors chan error
done chan struct{}
config *InitiatorPoolConfig
logger *slog.Logger
mu sync.RWMutex
wg sync.WaitGroup
closed bool
}
func (p *pool) closeAll() {
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return
}
p.closed = true
close(p.done)
peers := p.peers
p.peers = make(map[string]*peer)
p.mu.Unlock()
for _, conn := range peers {
conn.conn.Close()
}
go func() {
p.wg.Wait()
close(p.inbox)
close(p.events)
close(p.errors)
}()
}
func (p *pool) removePeer(id string) error {
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return errors.NewPoolError("pool is closed")
}
peer, exists := p.peers[id]
if !exists {
p.mu.Unlock()
return errors.NewPoolError("connection not found")
}
delete(p.peers, id)
p.mu.Unlock()
close(peer.stop)
peer.conn.Close()
select {
case p.events <- PoolEvent{ID: id, Kind: EventDisconnected}:
case <-p.done:
return nil
}
return nil
}
func (p *pool) send(id string, data []byte) error {
p.mu.RLock()
defer p.mu.RUnlock()
if p.closed {
return errors.NewPoolError("pool is closed")
}
peer, exists := p.peers[id]
if !exists {
return errors.NewPoolError("connection not found")
}
return peer.conn.Send(data)
}
// Initiator Pool
type InitiatorPool struct {
*pool
dialer Dialer
}
func NewInitiatorPool(config *InitiatorPoolConfig, logger *slog.Logger) (*InitiatorPool, error) {
if config == nil {
config = GetDefaultInitiatorPoolConfig()
}
if err := validateInitiatorPoolConfig(config); err != nil {
return nil, err
}
p := &InitiatorPool{
pool: &pool{
peers: make(map[string]*peer),
inbox: make(chan InboxMessage, 256),
events: make(chan PoolEvent, 10),
errors: make(chan error, 10),
done: make(chan struct{}),
config: config,
logger: logger,
},
dialer: NewDialer(),
}
return p, nil
}
func (p *InitiatorPool) Peers() map[string]*peer {
return p.peers
}
func (p *InitiatorPool) Inbox() chan InboxMessage {
return p.inbox
}
func (p *InitiatorPool) Events() chan PoolEvent {
return p.events
}
func (p *InitiatorPool) Errors() chan error {
return p.errors
}
func (p *InitiatorPool) Close() {
p.closeAll()
}
func (p *InitiatorPool) Connect(url string) error {
url, err := NormalizeURL(url)
if err != nil {
return err
}
// Check for existing connection in pool
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return errors.NewPoolError("pool is closed")
}
_, exists := p.peers[url]
p.mu.Unlock()
if exists {
return errors.NewPoolError("connection already exists")
}
// Create new connection
var logger *slog.Logger
if p.logger != nil {
logger = p.logger.With("url", url)
}
conn, err := NewConnection(url, p.config.ConnectionConfig, logger)
if err != nil {
return err
}
conn.dialer = p.dialer
// Attempt to connect
if err := conn.Connect(); err != nil {
return err
}
p.mu.Lock()
if p.closed {
// The pool closed while this connection was established.
p.mu.Unlock()
conn.Close()
return errors.NewPoolError("pool is closed")
}
// Add connection to pool
stop := make(chan struct{})
if _, exists := p.peers[url]; exists {
// Another process connected to this url while this one was connecting
// Discard this connection and retain the existing one
p.mu.Unlock()
conn.Close()
return errors.NewPoolError("connection already exists")
}
p.peers[url] = &peer{conn: conn, stop: stop}
p.mu.Unlock()
// TODO: start this connection's incoming message forwarder
select {
case p.events <- PoolEvent{ID: url, Kind: EventConnected}:
case <-p.done:
return nil
}
return nil
}
func (p *InitiatorPool) Remove(url string) error {
url, err := NormalizeURL(url)
if err != nil {
return err
}
return p.removePeer(url)
}
func (p *InitiatorPool) Send(url string, data []byte) error {
url, err := NormalizeURL(url)
if err != nil {
return err
}
return p.send(url, data)
}
// Responder Pool
type ResponderPool struct {
*pool
idGenerator func() string
}
func (p *ResponderPool) Peers() map[string]*peer {
return p.peers
}
func (p *ResponderPool) Inbox() chan InboxMessage {
return p.inbox
}
func (p *ResponderPool) Events() chan PoolEvent {
return p.events
}
func (p *ResponderPool) Errors() chan error {
return p.errors
}
func (p *ResponderPool) Close() {
p.closeAll()
}

225
transport/config.go Normal file
View File

@@ -0,0 +1,225 @@
package transport
import (
"time"
)
type CloseHandler func(code int, text string) error
type ConnectionConfig struct {
CloseHandler CloseHandler
WriteTimeout time.Duration
Retry *RetryConfig
}
type RetryConfig struct {
MaxRetries int
InitialDelay time.Duration
MaxDelay time.Duration
JitterFactor float64
}
type ConnectionOption func(*ConnectionConfig) error
func NewConnectionConfig(options ...ConnectionOption) (*ConnectionConfig, error) {
conf := GetDefaultConnectionConfig()
if err := applyConnectionOptions(conf, options...); err != nil {
return nil, err
}
if err := ValidateConnectionConfig(conf); err != nil {
return nil, err
}
return conf, nil
}
func GetDefaultConnectionConfig() *ConnectionConfig {
return &ConnectionConfig{
CloseHandler: nil,
WriteTimeout: 30 * time.Second,
Retry: GetDefaultRetryConfig(),
}
}
func GetDefaultRetryConfig() *RetryConfig {
return &RetryConfig{
MaxRetries: 0, // Infinite retries
InitialDelay: 1 * time.Second,
MaxDelay: 5 * time.Second,
JitterFactor: 0.5,
}
}
func applyConnectionOptions(config *ConnectionConfig, options ...ConnectionOption) error {
for _, option := range options {
if err := option(config); err != nil {
return err
}
}
return nil
}
func ValidateConnectionConfig(config *ConnectionConfig) error {
err := validateWriteTimeout(config.WriteTimeout)
if err != nil {
return err
}
if config.Retry != nil {
err = validateMaxRetries(config.Retry.MaxRetries)
if err != nil {
return err
}
err = validateInitialDelay(config.Retry.InitialDelay)
if err != nil {
return err
}
err = validateMaxDelay(config.Retry.MaxDelay)
if err != nil {
return err
}
err = validateJitterFactor(config.Retry.JitterFactor)
if err != nil {
return err
}
if config.Retry.InitialDelay > config.Retry.MaxDelay {
return NewConfigError("initial delay may not exceed maximum delay")
}
}
return nil
}
func validateWriteTimeout(value time.Duration) error {
if value < 0 {
return InvalidWriteTimeout
}
return nil
}
func validateMaxRetries(value int) error {
if value < 0 {
return InvalidRetryMaxRetries
}
return nil
}
func validateInitialDelay(value time.Duration) error {
if value <= 0 {
return InvalidRetryInitialDelay
}
return nil
}
func validateMaxDelay(value time.Duration) error {
if value <= 0 {
return InvalidRetryMaxDelay
}
return nil
}
func validateJitterFactor(value float64) error {
if value < 0.0 || value > 1.0 {
return InvalidRetryJitterFactor
}
return nil
}
func WithCloseHandler(handler CloseHandler) ConnectionOption {
return func(c *ConnectionConfig) error {
c.CloseHandler = handler
return nil
}
}
// When WriteTimeout is set to zero, read timeouts are disabled.
func WithWriteTimeout(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error {
err := validateWriteTimeout(value)
if err != nil {
return err
}
c.WriteTimeout = value
return nil
}
}
// WithRetry enables retry with default parameters (infinite retries,
// 1s initial delay, 5s max delay, 0.5 jitter factor).
//
// If passed after granular retry options (WithRetryMaxRetries, etc.),
// it will overwrite them. Use either WithRetry alone or the granular
// options; not both.
func WithRetry() ConnectionOption {
return func(c *ConnectionConfig) error {
c.Retry = GetDefaultRetryConfig()
return nil
}
}
func WithRetryMaxRetries(value int) ConnectionOption {
return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateMaxRetries(value)
if err != nil {
return err
}
c.Retry.MaxRetries = value
return nil
}
}
func WithRetryInitialDelay(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateInitialDelay(value)
if err != nil {
return err
}
c.Retry.InitialDelay = value
return nil
}
}
func WithRetryMaxDelay(value time.Duration) ConnectionOption {
return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateMaxDelay(value)
if err != nil {
return err
}
c.Retry.MaxDelay = value
return nil
}
}
func WithRetryJitterFactor(value float64) ConnectionOption {
return func(c *ConnectionConfig) error {
if c.Retry == nil {
c.Retry = GetDefaultRetryConfig()
}
err := validateJitterFactor(value)
if err != nil {
return err
}
c.Retry.JitterFactor = value
return nil
}
}

View File

@@ -1,7 +1,6 @@
package honeybee
package transport
import (
"git.wisehodl.dev/jay/go-honeybee/errors"
"github.com/stretchr/testify/assert"
"testing"
"time"
@@ -72,7 +71,7 @@ func TestApplyConnectionOptions(t *testing.T) {
WithRetryMaxRetries(-10),
)
assert.ErrorIs(t, err, errors.InvalidRetryMaxRetries)
assert.ErrorIs(t, err, InvalidRetryMaxRetries)
}
// Option Tests
@@ -103,7 +102,7 @@ func TestWithWriteTimeout(t *testing.T) {
conf = &ConnectionConfig{}
opt = WithWriteTimeout(-30)
err = applyConnectionOptions(conf, opt)
assert.ErrorIs(t, err, errors.InvalidWriteTimeout)
assert.ErrorIs(t, err, InvalidWriteTimeout)
assert.ErrorContains(t, err, "write timeout cannot be negative")
}
@@ -132,7 +131,7 @@ func TestWithRetry(t *testing.T) {
// negative disallowed
opt = WithRetryMaxRetries(-10)
err = applyConnectionOptions(conf, opt)
assert.ErrorIs(t, err, errors.InvalidRetryMaxRetries)
assert.ErrorIs(t, err, InvalidRetryMaxRetries)
assert.ErrorContains(t, err, "max retry count cannot be negative")
})
@@ -146,13 +145,13 @@ func TestWithRetry(t *testing.T) {
// zero disallowed
opt = WithRetryInitialDelay(0 * time.Second)
err = applyConnectionOptions(conf, opt)
assert.ErrorIs(t, err, errors.InvalidRetryInitialDelay)
assert.ErrorIs(t, err, InvalidRetryInitialDelay)
assert.ErrorContains(t, err, "initial delay must be positive")
// negative disallowed
opt = WithRetryInitialDelay(-10 * time.Second)
err = applyConnectionOptions(conf, opt)
assert.ErrorIs(t, err, errors.InvalidRetryInitialDelay)
assert.ErrorIs(t, err, InvalidRetryInitialDelay)
})
t.Run("with max delay", func(t *testing.T) {
@@ -165,13 +164,13 @@ func TestWithRetry(t *testing.T) {
// zero disallowed
opt = WithRetryMaxDelay(0 * time.Second)
err = applyConnectionOptions(conf, opt)
assert.ErrorIs(t, err, errors.InvalidRetryMaxDelay)
assert.ErrorIs(t, err, InvalidRetryMaxDelay)
assert.ErrorContains(t, err, "max delay must be positive")
// negative disallowed
opt = WithRetryMaxDelay(-10 * time.Second)
err = applyConnectionOptions(conf, opt)
assert.ErrorIs(t, err, errors.InvalidRetryMaxDelay)
assert.ErrorIs(t, err, InvalidRetryMaxDelay)
})
t.Run("with jitter factor", func(t *testing.T) {
@@ -185,13 +184,13 @@ func TestWithRetry(t *testing.T) {
// negative disallowed
opt = WithRetryJitterFactor(-1)
err = applyConnectionOptions(conf, opt)
assert.ErrorIs(t, err, errors.InvalidRetryJitterFactor)
assert.ErrorIs(t, err, InvalidRetryJitterFactor)
assert.ErrorContains(t, err, "jitter factor must be between 0.0 and 1.0")
// >1 disallowed
opt = WithRetryJitterFactor(1.1)
err = applyConnectionOptions(conf, opt)
assert.ErrorIs(t, err, errors.InvalidRetryJitterFactor)
assert.ErrorIs(t, err, InvalidRetryJitterFactor)
})
}
@@ -239,7 +238,7 @@ func TestValidateConnectionConfig(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := validateConnectionConfig(&tc.conf)
err := ValidateConnectionConfig(&tc.conf)
if tc.wantErr != nil || tc.wantErrText != "" {
if tc.wantErr != nil {

View File

@@ -1,14 +1,14 @@
package honeybee
package transport
import (
stderrors "errors"
"errors"
"fmt"
"log/slog"
"net/url"
"sync"
"time"
"git.wisehodl.dev/jay/go-honeybee/errors"
"git.wisehodl.dev/jay/go-honeybee/types"
"github.com/gorilla/websocket"
)
@@ -38,8 +38,8 @@ func (s ConnectionState) String() string {
type Connection struct {
url *url.URL
dialer Dialer
socket Socket
dialer types.Dialer
socket types.Socket
config *ConnectionConfig
logger *slog.Logger
@@ -60,7 +60,7 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger)
config = GetDefaultConnectionConfig()
}
if err := validateConnectionConfig(config); err != nil {
if err := ValidateConnectionConfig(config); err != nil {
return nil, err
}
@@ -85,16 +85,16 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger)
return conn, nil
}
func NewConnectionFromSocket(socket Socket, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
func NewConnectionFromSocket(socket types.Socket, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) {
if socket == nil {
return nil, errors.NewConnectionError("socket cannot be nil")
return nil, NewConnectionError("socket cannot be nil")
}
if config == nil {
config = GetDefaultConnectionConfig()
}
if err := validateConnectionConfig(config); err != nil {
if err := ValidateConnectionConfig(config); err != nil {
return nil, err
}
@@ -126,11 +126,11 @@ func (c *Connection) Connect() error {
defer c.mu.Unlock()
if c.socket != nil {
return errors.NewConnectionError("connection already has socket")
return NewConnectionError("connection already has socket")
}
if c.closed {
return errors.NewConnectionError("connection is closed")
return NewConnectionError("connection is closed")
}
if c.logger != nil {
@@ -177,7 +177,7 @@ func (c *Connection) startReader() {
if err != nil {
if c.logger != nil {
var closeErr *websocket.CloseError
if stderrors.As(err, &closeErr) {
if errors.As(err, &closeErr) {
switch closeErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway:
c.logger.Info("connection closed by peer",
@@ -263,16 +263,16 @@ func (c *Connection) Send(data []byte) error {
defer c.mu.RUnlock()
if c.closed {
return errors.NewConnectionError("connection closed")
return NewConnectionError("connection closed")
}
select {
case c.outgoing <- data:
return nil
case <-c.done:
return errors.NewConnectionError("connection closing")
return NewConnectionError("connection closing")
default:
return errors.NewConnectionError("outgoing queue full")
return NewConnectionError("outgoing queue full")
}
}
@@ -337,3 +337,7 @@ func (c *Connection) State() ConnectionState {
defer c.mu.RUnlock()
return c.state
}
func (c *Connection) SetDialer(d types.Dialer) {
c.dialer = d
}

View File

@@ -1,8 +1,9 @@
package honeybee
package transport
import (
"bytes"
"fmt"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"testing"
@@ -38,7 +39,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
t.Run("socket close error does not propagate", func(t *testing.T) {
expectedErr := fmt.Errorf("socket close failed")
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.CloseFunc = func() error {
return expectedErr
}
@@ -64,7 +65,8 @@ func TestDisconnectedConnectionClose(t *testing.T) {
default:
return false
}
}, testTimeout, testTick, "errors channel should close")
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
"errors channel should close")
})
t.Run("send fails after close", func(t *testing.T) {
@@ -86,7 +88,8 @@ func TestConnectedConnectionClose(t *testing.T) {
// Send a message to ensure reader loop is blocking
canary := []byte("canary")
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: canary}
incomingData <- honeybeetest.MockIncomingData{
MsgType: websocket.TextMessage, Data: canary}
assert.Eventually(t, func() bool {
select {
@@ -95,7 +98,7 @@ func TestConnectedConnectionClose(t *testing.T) {
default:
return false
}
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
conn.Close()
assert.Equal(t, StateClosed, conn.State())
@@ -119,9 +122,9 @@ func TestConnectedConnectionClose(t *testing.T) {
conn, _, incomingData, _ := setupTestConnection(t, nil)
for i := 0; i < 10; i++ {
incomingData <- mockIncomingData{
msgType: websocket.TextMessage,
data: []byte(fmt.Sprintf("in-%d", i)),
incomingData <- honeybeetest.MockIncomingData{
MsgType: websocket.TextMessage,
Data: []byte(fmt.Sprintf("in-%d", i)),
}
conn.Send([]byte(fmt.Sprintf("out-%d", i)))
}

View File

@@ -1,8 +1,8 @@
package honeybee
package transport
import (
"bytes"
"fmt"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"io"
@@ -17,13 +17,13 @@ func TestStartReader(t *testing.T) {
defer conn.Close()
testData := []byte("hello")
incomingData <- mockIncomingData{
msgType: websocket.TextMessage,
data: testData,
err: nil,
incomingData <- honeybeetest.MockIncomingData{
MsgType: websocket.TextMessage,
Data: testData,
Err: nil,
}
expectIncoming(t, conn, testData)
honeybeetest.ExpectIncoming(t, conn.Incoming(), testData)
})
t.Run("binary messages route to incoming channel", func(t *testing.T) {
@@ -31,13 +31,13 @@ func TestStartReader(t *testing.T) {
defer conn.Close()
testData := []byte{0x00, 0x01, 0x02}
incomingData <- mockIncomingData{
msgType: websocket.BinaryMessage,
data: testData,
err: nil,
incomingData <- honeybeetest.MockIncomingData{
MsgType: websocket.BinaryMessage,
Data: testData,
Err: nil,
}
expectIncoming(t, conn, testData)
honeybeetest.ExpectIncoming(t, conn.Incoming(), testData)
})
t.Run("multiple messages processed sequentially", func(t *testing.T) {
@@ -46,20 +46,21 @@ func TestStartReader(t *testing.T) {
messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")}
for _, msg := range messages {
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: msg, err: nil}
incomingData <- honeybeetest.MockIncomingData{
MsgType: websocket.TextMessage, Data: msg, Err: nil}
}
for _, expected := range messages {
expectIncoming(t, conn, expected)
honeybeetest.ExpectIncoming(t, conn.Incoming(), expected)
}
})
t.Run("reader exits on socket read error", func(t *testing.T) {
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.CloseFunc = func() error {
mockSocket.once.Do(func() {
close(mockSocket.closed)
mockSocket.Once.Do(func() {
close(mockSocket.Closed)
})
return nil
}
@@ -80,11 +81,11 @@ func TestStartReader(t *testing.T) {
default:
return false
}
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
assert.Eventually(t, func() bool {
return conn.State() == StateClosed
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
})
}
@@ -97,7 +98,7 @@ func TestStartWriter(t *testing.T) {
err := conn.Send(testData)
assert.NoError(t, err)
expectWrite(t, outgoingData, websocket.TextMessage, testData)
honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, testData)
})
t.Run("multiple messages processed sequentially", func(t *testing.T) {
@@ -111,7 +112,7 @@ func TestStartWriter(t *testing.T) {
}
for _, expected := range messages {
expectWrite(t, outgoingData, websocket.TextMessage, expected)
honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, expected)
}
})
@@ -122,12 +123,12 @@ func TestStartWriter(t *testing.T) {
config := &ConnectionConfig{WriteTimeout: 0}
outgoingData := make(chan mockOutgoingData, 10)
mockSocket := NewMockSocket()
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
mockSocket := honeybeetest.NewMockSocket()
mockSocket.CloseFunc = func() error {
mockSocket.once.Do(func() {
close(mockSocket.closed)
mockSocket.Once.Do(func() {
close(mockSocket.Closed)
})
return nil
}
@@ -140,8 +141,9 @@ func TestStartWriter(t *testing.T) {
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
select {
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
case <-mockSocket.closed:
case outgoingData <- honeybeetest.MockOutgoingData{
MsgType: msgType, Data: data}:
case <-mockSocket.Closed:
return io.EOF
}
return nil
@@ -161,19 +163,19 @@ func TestStartWriter(t *testing.T) {
default:
return false
}
}, negativeTestTimeout, testTick,
}, honeybeetest.NegativeTestTimeout, honeybeetest.TestTick,
"SetWriteDeadline should not be called when timeout is zero")
})
t.Run("write timeout sets deadline when positive", func(t *testing.T) {
config := &ConnectionConfig{WriteTimeout: 30 * time.Millisecond}
outgoingData := make(chan mockOutgoingData, 10)
mockSocket := NewMockSocket()
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
mockSocket := honeybeetest.NewMockSocket()
mockSocket.CloseFunc = func() error {
mockSocket.once.Do(func() {
close(mockSocket.closed)
mockSocket.Once.Do(func() {
close(mockSocket.Closed)
})
return nil
}
@@ -186,8 +188,9 @@ func TestStartWriter(t *testing.T) {
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
select {
case outgoingData <- mockOutgoingData{msgType: msgType, data: data}:
case <-mockSocket.closed:
case outgoingData <- honeybeetest.MockOutgoingData{
MsgType: msgType, Data: data}:
case <-mockSocket.Closed:
return io.EOF
}
return nil
@@ -207,18 +210,18 @@ func TestStartWriter(t *testing.T) {
default:
return false
}
}, testTimeout, testTick,
}, honeybeetest.TestTimeout, honeybeetest.TestTick,
"SetWriteDeadline should be called when timeout is positive")
})
t.Run("writer exits on deadline error", func(t *testing.T) {
config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond}
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.CloseFunc = func() error {
mockSocket.once.Do(func() {
close(mockSocket.closed)
mockSocket.Once.Do(func() {
close(mockSocket.Closed)
})
return nil
}
@@ -242,15 +245,15 @@ func TestStartWriter(t *testing.T) {
default:
return false
}
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
assert.Eventually(t, func() bool {
return conn.State() == StateClosed
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
})
t.Run("writer exits on socket write error", func(t *testing.T) {
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
writeErr := fmt.Errorf("write failed")
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
@@ -271,45 +274,12 @@ func TestStartWriter(t *testing.T) {
default:
return false
}
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
assert.Eventually(t, func() bool {
return conn.State() == StateClosed
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
})
}
// Helpers
func expectIncoming(t *testing.T, conn *Connection, expected []byte) {
t.Helper()
assert.Eventually(t, func() bool {
select {
case received := <-conn.Incoming():
return bytes.Equal(received, expected)
default:
return false
}
}, testTimeout, testTick)
}
func expectWrite(t *testing.T, outgoingData chan mockOutgoingData, msgType int, expected []byte) {
t.Helper()
var call mockOutgoingData
found := assert.Eventually(t, func() bool {
select {
case received := <-outgoingData:
call = received
return true
default:
return false
}
}, testTimeout, testTick)
if found {
assert.Equal(t, msgType, call.msgType)
assert.Equal(t, expected, call.data)
}
}

View File

@@ -1,4 +1,4 @@
package honeybee
package transport
import (
"fmt"

View File

@@ -1,9 +1,12 @@
package honeybee
package transport
import (
"bytes"
"fmt"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"git.wisehodl.dev/jay/go-honeybee/types"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"testing"
"time"
@@ -36,7 +39,7 @@ func TestConnectionState(t *testing.T) {
assert.Equal(t, StateDisconnected, conn.State())
// Test state after FromSocket (should be Connected)
conn2, _ := NewConnectionFromSocket(NewMockSocket(), nil, nil)
conn2, _ := NewConnectionFromSocket(honeybeetest.NewMockSocket(), nil, nil)
assert.Equal(t, StateConnected, conn2.State())
// Test state after close
@@ -126,7 +129,7 @@ func TestNewConnection(t *testing.T) {
func TestNewConnectionFromSocket(t *testing.T) {
cases := []struct {
name string
socket Socket
socket types.Socket
config *ConnectionConfig
wantErr bool
wantErrText string
@@ -140,17 +143,17 @@ func TestNewConnectionFromSocket(t *testing.T) {
},
{
name: "valid socket with nil config",
socket: NewMockSocket(),
socket: honeybeetest.NewMockSocket(),
config: nil,
},
{
name: "valid socket with valid config",
socket: NewMockSocket(),
socket: honeybeetest.NewMockSocket(),
config: &ConnectionConfig{WriteTimeout: 30 * time.Second},
},
{
name: "invalid config",
socket: NewMockSocket(),
socket: honeybeetest.NewMockSocket(),
config: &ConnectionConfig{
Retry: &RetryConfig{
InitialDelay: 10 * time.Second,
@@ -162,7 +165,7 @@ func TestNewConnectionFromSocket(t *testing.T) {
},
{
name: "close handler set when provided",
socket: NewMockSocket(),
socket: honeybeetest.NewMockSocket(),
config: &ConnectionConfig{
CloseHandler: func(code int, text string) error {
return nil
@@ -176,7 +179,7 @@ func TestNewConnectionFromSocket(t *testing.T) {
// track if SetCloseHandler was called
closeHandlerSet := false
if tc.socket != nil {
mockSocket := tc.socket.(*MockSocket)
mockSocket := tc.socket.(*honeybeetest.MockSocket)
originalSetCloseHandler := mockSocket.SetCloseHandlerFunc
// wrapper around the original handler function
@@ -234,7 +237,7 @@ func TestConnect(t *testing.T) {
conn, err := NewConnection("ws://test", nil, nil)
assert.NoError(t, err)
conn.socket = NewMockSocket()
conn.socket = honeybeetest.NewMockSocket()
err = conn.Connect()
assert.Error(t, err)
@@ -258,16 +261,16 @@ func TestConnect(t *testing.T) {
conn, err := NewConnection("ws://test", nil, nil)
assert.NoError(t, err)
outgoingData := make(chan mockOutgoingData, 10)
outgoingData := make(chan honeybeetest.MockOutgoingData, 10)
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
outgoingData <- mockOutgoingData{msgType: msgType, data: data}
outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}
return nil
}
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
@@ -283,11 +286,11 @@ func TestConnect(t *testing.T) {
assert.Eventually(t, func() bool {
select {
case msg := <-outgoingData:
return bytes.Equal(msg.data, testData)
return bytes.Equal(msg.Data, testData)
default:
return false
}
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
conn.Close()
})
@@ -305,13 +308,13 @@ func TestConnect(t *testing.T) {
assert.NoError(t, err)
attemptCount := 0
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
attemptCount++
if attemptCount < 3 {
return nil, nil, fmt.Errorf("dial failed")
}
return NewMockSocket(), nil, nil
return honeybeetest.NewMockSocket(), nil, nil
},
}
conn.dialer = mockDialer
@@ -336,8 +339,8 @@ func TestConnect(t *testing.T) {
conn, err := NewConnection("ws://test", config, nil)
assert.NoError(t, err)
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return nil, nil, fmt.Errorf("dial failed")
},
}
@@ -355,10 +358,10 @@ func TestConnect(t *testing.T) {
assert.Equal(t, StateDisconnected, conn.State())
stateDuringDial := StateDisconnected
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
stateDuringDial = conn.state
return NewMockSocket(), nil, nil
return honeybeetest.NewMockSocket(), nil, nil
},
}
conn.dialer = mockDialer
@@ -381,13 +384,13 @@ func TestConnect(t *testing.T) {
conn, err := NewConnection("ws://test", config, nil)
assert.NoError(t, err)
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.SetCloseHandlerFunc = func(h func(int, string) error) {
handlerSet = true
}
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
@@ -431,4 +434,53 @@ func TestConnectionErrors(t *testing.T) {
assert.Equal(t, testErr, received)
}
// Connect() tests
// Test helpers
func setupTestConnection(t *testing.T, config *ConnectionConfig) (
conn *Connection,
mockSocket *honeybeetest.MockSocket,
incomingData chan honeybeetest.MockIncomingData,
outgoingData chan honeybeetest.MockOutgoingData,
) {
t.Helper()
incomingData = make(chan honeybeetest.MockIncomingData, 10)
outgoingData = make(chan honeybeetest.MockOutgoingData, 10)
mockSocket = honeybeetest.NewMockSocket()
mockSocket.CloseFunc = func() error {
mockSocket.Once.Do(func() {
close(mockSocket.Closed)
})
return nil
}
// Wire ReadMessage to pull from incomingData channel
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
select {
case data := <-incomingData:
return data.MsgType, data.Data, data.Err
case <-mockSocket.Closed:
return 0, nil, io.EOF
}
}
// Wire WriteMessage to push to outgoingData channel
mockSocket.WriteMessageFunc = func(msgType int, data []byte) error {
select {
case outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}:
return nil
case <-mockSocket.Closed:
return io.EOF
default:
return fmt.Errorf("mock outgoing chanel unavailable")
}
}
var err error
conn, err = NewConnectionFromSocket(mockSocket, config, nil)
assert.NoError(t, err)
return conn, mockSocket, incomingData, outgoingData
}

View File

@@ -1,4 +1,4 @@
package errors
package transport
import "errors"
import "fmt"
@@ -8,13 +8,11 @@ var (
InvalidProtocol = errors.New("URL must use ws:// or wss:// scheme")
// Configuration Errors
InvalidIdleTimeout = errors.New("idle timeout cannot be negative")
InvalidWriteTimeout = errors.New("write timeout cannot be negative")
InvalidRetryMaxRetries = errors.New("max retry count cannot be negative")
InvalidRetryInitialDelay = errors.New("initial delay must be positive")
InvalidRetryMaxDelay = errors.New("max delay must be positive")
InvalidRetryJitterFactor = errors.New("jitter factor must be between 0.0 and 1.0")
InvalidMaxQueueSize = errors.New("maximum queue size cannot be negative")
)
func NewConfigError(text string) error {
@@ -24,7 +22,3 @@ func NewConfigError(text string) error {
func NewConnectionError(text string) error {
return fmt.Errorf("connection error: %s", text)
}
func NewPoolError(text string) error {
return fmt.Errorf("pool error: %s", text)
}

View File

@@ -1,4 +1,4 @@
package honeybee
package transport
import (
"fmt"
@@ -9,6 +9,8 @@ import (
"testing"
"time"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"git.wisehodl.dev/jay/go-honeybee/types"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)
@@ -136,15 +138,15 @@ func toInt64(v any) (int64, bool) {
func TestConnectLogging(t *testing.T) {
t.Run("success", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
logger := slog.New(mockHandler)
conn, err := NewConnection("ws://test", nil, logger)
assert.NoError(t, err)
mockSocket := NewMockSocket()
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}
@@ -167,7 +169,7 @@ func TestConnectLogging(t *testing.T) {
})
t.Run("max retries failure", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
logger := slog.New(mockHandler)
config := &ConnectionConfig{
@@ -183,8 +185,8 @@ func TestConnectLogging(t *testing.T) {
assert.NoError(t, err)
dialErr := fmt.Errorf("dial error")
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return nil, nil, dialErr
},
}
@@ -210,7 +212,7 @@ func TestConnectLogging(t *testing.T) {
})
t.Run("success after retry", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
logger := slog.New(mockHandler)
config := &ConnectionConfig{
@@ -227,13 +229,13 @@ func TestConnectLogging(t *testing.T) {
attemptCount := 0
dialErr := fmt.Errorf("dial error")
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
attemptCount++
if attemptCount < 3 {
return nil, nil, dialErr
}
return NewMockSocket(), nil, nil
return honeybeetest.NewMockSocket(), nil, nil
},
}
conn.dialer = mockDialer
@@ -261,10 +263,10 @@ func TestConnectLogging(t *testing.T) {
func TestCloseLogging(t *testing.T) {
t.Run("normal close", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
logger := slog.New(mockHandler)
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
assert.NoError(t, err)
@@ -273,7 +275,7 @@ func TestCloseLogging(t *testing.T) {
assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelInfo, "closed") != nil
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
records := mockHandler.GetRecords()
@@ -286,11 +288,11 @@ func TestCloseLogging(t *testing.T) {
})
t.Run("close with socket error", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
logger := slog.New(mockHandler)
closeErr := fmt.Errorf("close error")
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.CloseFunc = func() error {
return closeErr
}
@@ -303,7 +305,7 @@ func TestCloseLogging(t *testing.T) {
assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelError, "socket close failed") != nil
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
records := mockHandler.GetRecords()
@@ -318,10 +320,10 @@ func TestCloseLogging(t *testing.T) {
func TestReaderLogging(t *testing.T) {
t.Run("clean close by peer", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
logger := slog.New(mockHandler)
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
return 0, nil, &websocket.CloseError{
Code: websocket.CloseNormalClosure,
@@ -336,7 +338,7 @@ func TestReaderLogging(t *testing.T) {
assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer") != nil
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
record := findLogRecord(mockHandler.GetRecords(), slog.LevelInfo, "connection closed by peer")
assert.NotNil(t, record)
@@ -346,10 +348,10 @@ func TestReaderLogging(t *testing.T) {
})
t.Run("unexpected close", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
logger := slog.New(mockHandler)
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
return 0, nil, &websocket.CloseError{
Code: websocket.CloseProtocolError,
@@ -364,7 +366,7 @@ func TestReaderLogging(t *testing.T) {
assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelError, "unexpected close") != nil
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
record := findLogRecord(mockHandler.GetRecords(), slog.LevelError, "unexpected close")
assert.NotNil(t, record)
@@ -374,10 +376,10 @@ func TestReaderLogging(t *testing.T) {
})
t.Run("read error", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
logger := slog.New(mockHandler)
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
return 0, nil, io.EOF
}
@@ -389,19 +391,19 @@ func TestReaderLogging(t *testing.T) {
assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelError, "read error") != nil
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
})
}
func TestWriterLogging(t *testing.T) {
t.Run("write deadline error", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
logger := slog.New(mockHandler)
config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond}
deadlineErr := fmt.Errorf("deadline error")
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.SetWriteDeadlineFunc = func(time.Time) error {
return deadlineErr
}
@@ -415,7 +417,7 @@ func TestWriterLogging(t *testing.T) {
assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelError, "write deadline error") != nil
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
records := mockHandler.GetRecords()
@@ -427,11 +429,11 @@ func TestWriterLogging(t *testing.T) {
})
t.Run("write message error", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
logger := slog.New(mockHandler)
writeErr := fmt.Errorf("write error")
mockSocket := NewMockSocket()
mockSocket := honeybeetest.NewMockSocket()
mockSocket.WriteMessageFunc = func(int, []byte) error {
return writeErr
}
@@ -445,7 +447,7 @@ func TestWriterLogging(t *testing.T) {
assert.Eventually(t, func() bool {
return findLogRecord(
mockHandler.GetRecords(), slog.LevelError, "write error") != nil
}, testTimeout, testTick)
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
records := mockHandler.GetRecords()
@@ -459,14 +461,14 @@ func TestWriterLogging(t *testing.T) {
func TestLoggingDisabled(t *testing.T) {
t.Run("nil logger produces no logs", func(t *testing.T) {
mockHandler := newMockSlogHandler()
mockHandler := honeybeetest.NewMockSlogHandler()
conn, err := NewConnection("ws://test", nil, nil)
assert.NoError(t, err)
mockSocket := NewMockSocket()
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockSocket := honeybeetest.NewMockSocket()
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
}

View File

@@ -1,4 +1,4 @@
package honeybee
package transport
import (
"math"

View File

@@ -1,4 +1,4 @@
package honeybee
package transport
import (
"github.com/stretchr/testify/assert"

View File

@@ -1,19 +1,15 @@
package honeybee
package transport
import (
"log/slog"
"net/http"
"time"
"git.wisehodl.dev/jay/go-honeybee/errors"
"git.wisehodl.dev/jay/go-honeybee/types"
"github.com/gorilla/websocket"
)
type Dialer interface {
Dial(urlStr string, requestHeader http.Header) (Socket, *http.Response, error)
}
func NewDialer() Dialer {
func NewDialer() types.Dialer {
return NewGorillaDialer()
}
@@ -35,36 +31,26 @@ func NewGorillaDialer() *GorillaDialer {
func (d *GorillaDialer) Dial(
urlStr string, requestHeader http.Header,
) (
Socket, *http.Response, error,
types.Socket, *http.Response, error,
) {
conn, resp, err := d.Dialer.Dial(urlStr, requestHeader)
return conn, resp, err
}
type Socket interface {
WriteMessage(messageType int, data []byte) error
ReadMessage() (messageType int, p []byte, err error)
Close() error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
SetCloseHandler(h func(code int, text string) error)
}
func AcquireSocket(
retryMgr *RetryManager,
dialer Dialer,
dialer types.Dialer,
urlStr string,
logger *slog.Logger,
) (Socket, *http.Response, error) {
) (types.Socket, *http.Response, error) {
if retryMgr == nil {
return nil, nil, errors.NewConnectionError("retry manager cannot be nil")
return nil, nil, NewConnectionError("retry manager cannot be nil")
}
if dialer == nil {
return nil, nil, errors.NewConnectionError("dialer cannot be nil")
return nil, nil, NewConnectionError("dialer cannot be nil")
}
if urlStr == "" {
return nil, nil, errors.NewConnectionError("URL cannot be empty")
return nil, nil, NewConnectionError("URL cannot be empty")
}
for {

View File

@@ -1,7 +1,9 @@
package honeybee
package transport
import (
"errors"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"git.wisehodl.dev/jay/go-honeybee/types"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
@@ -60,14 +62,14 @@ func TestAcquireSocket(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
attemptIndex := 0
mockDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
mockDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
err := tc.mockRuns[attemptIndex]
attemptIndex++
if err != nil {
return nil, nil, err
}
return NewMockSocket(), nil, nil
return honeybeetest.NewMockSocket(), nil, nil
},
}
@@ -93,9 +95,9 @@ func TestAcquireSocket(t *testing.T) {
}
func TestAcquireSocketGuards(t *testing.T) {
validDialer := &MockDialer{
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
return NewMockSocket(), nil, nil
validDialer := &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
return honeybeetest.NewMockSocket(), nil, nil
},
}
validRetryMgr := NewRetryManager(GetDefaultRetryConfig())
@@ -103,7 +105,7 @@ func TestAcquireSocketGuards(t *testing.T) {
cases := []struct {
name string
retryMgr *RetryManager
dialer Dialer
dialer types.Dialer
url string
wantErr string
}{

View File

@@ -1,10 +1,8 @@
package honeybee
package transport
import (
"net/url"
"strings"
"git.wisehodl.dev/jay/go-honeybee/errors"
)
func ParseURL(urlStr string) (*url.URL, error) {
@@ -14,7 +12,7 @@ func ParseURL(urlStr string) (*url.URL, error) {
}
if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" {
return nil, errors.InvalidProtocol
return nil, InvalidProtocol
}
return parsedURL, nil

View File

@@ -1,7 +1,6 @@
package honeybee
package transport
import (
"git.wisehodl.dev/jay/go-honeybee/errors"
"github.com/stretchr/testify/assert"
"testing"
)
@@ -41,17 +40,17 @@ func TestParseURL(t *testing.T) {
{
name: "http scheme rejected",
url: "http://example.com",
wantErr: errors.InvalidProtocol,
wantErr: InvalidProtocol,
},
{
name: "missing scheme",
url: "example.com:8080",
wantErr: errors.InvalidProtocol,
wantErr: InvalidProtocol,
},
{
name: "empty string",
url: "",
wantErr: errors.InvalidProtocol,
wantErr: InvalidProtocol,
},
{
name: "malformed url",
@@ -161,5 +160,5 @@ func TestNormalizeURL(t *testing.T) {
func TestNormalizeURLError(t *testing.T) {
_, err := NormalizeURL("http://relay.example.com")
assert.ErrorIs(t, err, errors.InvalidProtocol)
assert.ErrorIs(t, err, InvalidProtocol)
}

20
types/types.go Normal file
View File

@@ -0,0 +1,20 @@
package types
import (
"net/http"
"time"
)
type Dialer interface {
Dial(urlStr string, requestHeader http.Header) (Socket, *http.Response, error)
}
type Socket interface {
WriteMessage(messageType int, data []byte) error
ReadMessage() (messageType int, p []byte, err error)
Close() error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
SetCloseHandler(h func(code int, text string) error)
}

View File

@@ -1,7 +0,0 @@
package honeybee
import (
// "github.com/stretchr/testify/assert"
// "testing"
// "time"
)