diff --git a/honeybee.go b/honeybee.go new file mode 100644 index 0000000..294b5aa --- /dev/null +++ b/honeybee.go @@ -0,0 +1,85 @@ +package honeybee + +import ( + "context" + "log/slog" + + "git.wisehodl.dev/jay/go-honeybee/initiatorpool" + "git.wisehodl.dev/jay/go-honeybee/transport" +) + +// Connection types + +type Connection = transport.Connection +type ConnectionConfig = transport.ConnectionConfig +type RetryConfig = transport.RetryConfig +type ConnectionOption = transport.ConnectionOption + +// Initator Pool types + +type InitiatorPool = initiatorpool.Pool +type InitiatorPoolConfig = initiatorpool.PoolConfig +type InitiatorPoolOption = initiatorpool.PoolOption +type InitiatorWorkerConfig = initiatorpool.WorkerConfig +type InitiatorWorkerOption = initiatorpool.WorkerOption +type InitiatorInboxMessage = initiatorpool.InboxMessage +type InitiatorPoolEvent = initiatorpool.PoolEvent +type InitiatorPoolEventKind = initiatorpool.PoolEventKind + +// Pool event constants + +const ( + EventConnected = initiatorpool.EventConnected + EventDisconnected = initiatorpool.EventDisconnected +) + +// Connection constructors + +func NewConnection(url string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) { + return transport.NewConnection(url, config, logger) +} + +func NewConnectionConfig(opts ...ConnectionOption) (*ConnectionConfig, error) { + return transport.NewConnectionConfig(opts...) +} + +// Connection options + +var ( + WithoutRetry = transport.WithoutRetry + WithRetryMaxRetries = transport.WithRetryMaxRetries + WithRetryInitialDelay = transport.WithRetryInitialDelay + WithRetryMaxDelay = transport.WithRetryMaxDelay + WithRetryJitterFactor = transport.WithRetryJitterFactor + WithWriteTimeout = transport.WithWriteTimeout + WithCloseHandler = transport.WithCloseHandler +) + +// Initiator Pool constructors + +func NewInitiatorPool(ctx context.Context, config *InitiatorPoolConfig, logger *slog.Logger) (*InitiatorPool, error) { + return initiatorpool.NewPool(ctx, config, logger) +} + +func NewInitiatorPoolConfig(opts ...InitiatorPoolOption) (*InitiatorPoolConfig, error) { + return initiatorpool.NewPoolConfig(opts...) +} + +func NewInitiatorWorkerConfig(opts ...InitiatorWorkerOption) (*InitiatorWorkerConfig, error) { + return initiatorpool.NewWorkerConfig(opts...) +} + +// Initiator Pool options + +var ( + WithConnectionConfig = initiatorpool.WithConnectionConfig + WithWorkerConfig = initiatorpool.WithWorkerConfig + WithWorkerFactory = initiatorpool.WithWorkerFactory +) + +// Initiator Worker options + +var ( + WithKeepaliveTimeout = initiatorpool.WithKeepaliveTimeout + WithMaxQueueSize = initiatorpool.WithMaxQueueSize +) diff --git a/initiatorpool/pool.go b/initiatorpool/pool.go index 0ca1232..cec20aa 100644 --- a/initiatorpool/pool.go +++ b/initiatorpool/pool.go @@ -100,8 +100,15 @@ func NewPool(ctx context.Context, config *PoolConfig, logger *slog.Logger, return p, nil } -func (p *Pool) Peers() map[string]*Peer { - return p.peers +func (p *Pool) Peers() []string { + p.mu.RLock() + defer p.mu.RUnlock() + + ids := make([]string, 0, len(p.peers)) + for i, _ := range p.peers { + ids = append(ids, i) + } + return ids } func (p *Pool) Inbox() chan InboxMessage { @@ -131,8 +138,9 @@ func (p *Pool) Close() { } p.closed = true - p.cancel() + p.cancel() // closes all workers + // remove all peers p.peers = make(map[string]*Peer) p.mu.Unlock() diff --git a/initiatorpool/pool_test.go b/initiatorpool/pool_test.go index f6b18da..206509c 100644 --- a/initiatorpool/pool_test.go +++ b/initiatorpool/pool_test.go @@ -4,31 +4,31 @@ import ( "context" "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" - "git.wisehodl.dev/jay/go-honeybee/transport" "git.wisehodl.dev/jay/go-honeybee/types" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "net/http" "testing" - "time" ) -// TODO: Worker must connect and emit events. -func _TestPoolConnect(t *testing.T) { +func setupPool(t *testing.T) (*Pool, *honeybeetest.MockDialer) { + t.Helper() + pool, err := NewPool(context.Background(), nil, nil) + assert.NoError(t, err) + dialer := &honeybeetest.MockDialer{ + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { + return honeybeetest.NewMockSocket(), nil, nil + }, + } + pool.dialer = dialer + return pool, dialer +} + +func TestPoolConnect(t *testing.T) { t.Run("successfully adds connection", func(t *testing.T) { - mockSocket := honeybeetest.NewMockSocket() - mockDialer := &honeybeetest.MockDialer{ - DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { - return mockSocket, nil, nil - }, - } + pool, _ := setupPool(t) - pool, err := NewPool(context.Background(), nil, nil) - assert.NoError(t, err) - - pool.dialer = mockDialer - - err = pool.Connect("wss://test") + err := pool.Connect("wss://test") assert.NoError(t, err) honeybeetest.Eventually(t, func() bool { @@ -40,25 +40,15 @@ func _TestPoolConnect(t *testing.T) { } }, "expected event") - _, exists := pool.peers["wss://test"] - assert.True(t, exists) + assert.Contains(t, pool.Peers(), "wss://test") pool.Close() }) t.Run("does not add duplicate", func(t *testing.T) { - mockSocket := honeybeetest.NewMockSocket() - mockDialer := &honeybeetest.MockDialer{ - DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { - return mockSocket, nil, nil - }, - } + pool, _ := setupPool(t) - pool, err := NewPool(context.Background(), nil, nil) - assert.NoError(t, err) - pool.dialer = mockDialer - - err = pool.Connect("wss://test") + err := pool.Connect("wss://test") assert.NoError(t, err) // trailing slash normalizes to same key @@ -66,119 +56,71 @@ func _TestPoolConnect(t *testing.T) { assert.Error(t, err) assert.ErrorIs(t, err, ErrPeerExists) - pool.mu.RLock() - assert.Len(t, pool.peers, 1) - pool.mu.RUnlock() - - pool.Close() - }) - - t.Run("fails to add connection", func(t *testing.T) { - pool, err := NewPool( - context.Background(), - &PoolConfig{ - ConnectionConfig: &transport.ConnectionConfig{ - Retry: &transport.RetryConfig{ - MaxRetries: 1, - InitialDelay: 1 * time.Millisecond, - MaxDelay: 5 * time.Millisecond, - }}, - }, nil) - assert.NoError(t, err) - pool.dialer = &honeybeetest.MockDialer{ - DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { - return nil, nil, fmt.Errorf("dial failed") - }, - } - - err = pool.Connect("wss://test") - assert.Error(t, err) - - pool.mu.RLock() - assert.Len(t, pool.peers, 0) - pool.mu.RUnlock() - - select { - case event := <-pool.events: - t.Fatalf("unexpected event: %+v", event) - default: - } + assert.Len(t, pool.Peers(), 1) pool.Close() }) } -// TODO: Worker must stop connection and emit events -func _TestPoolRemove(t *testing.T) { - t.Run("removes known url", func(t *testing.T) { - mockSocket := honeybeetest.NewMockSocket() - mockDialer := &honeybeetest.MockDialer{ - DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { - return mockSocket, nil, nil - }, - } +func TestPoolClose(t *testing.T) { + t.Run("channels close after pool close", func(t *testing.T) { + pool, _ := NewPool(context.Background(), nil, nil) + pool.Close() + _, ok := <-pool.Inbox() + assert.False(t, ok) + _, ok = <-pool.Events() + assert.False(t, ok) + _, ok = <-pool.Errors() + assert.False(t, ok) + }) - pool, err := NewPool(context.Background(), nil, nil) - assert.NoError(t, err) - pool.dialer = mockDialer + t.Run("connect after close returns error", func(t *testing.T) { + pool, _ := NewPool(context.Background(), nil, nil) + pool.Close() + err := pool.Connect("wss://test") + assert.ErrorIs(t, err, ErrPoolClosed) + }) +} + +func TestPoolRemove(t *testing.T) { + t.Run("removes known url", func(t *testing.T) { + pool, _ := setupPool(t) pool.Connect("wss://test") expectEvent(t, pool.events, "wss://test", EventConnected) - err = pool.Remove("wss://test/") + err := pool.Remove("wss://test/") assert.NoError(t, err) // expect a disconnected event expectEvent(t, pool.events, "wss://test", EventDisconnected) // connection no longer in pool - pool.mu.Lock() - defer pool.mu.Unlock() - _, ok := pool.peers["wss://peer2"] - assert.False(t, ok, "connection is still in pool") + assert.NotContains(t, pool.Peers(), "wss://test") }) t.Run("unknown url returns error", func(t *testing.T) { - mockSocket := honeybeetest.NewMockSocket() - mockDialer := &honeybeetest.MockDialer{ - DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { - return mockSocket, nil, nil - }, - } - - pool, err := NewPool(context.Background(), nil, nil) - assert.NoError(t, err) - pool.dialer = mockDialer + pool, _ := setupPool(t) // remove unknown connection - err = pool.Remove("wss://unknown") + err := pool.Remove("wss://unknown") assert.ErrorIs(t, err, ErrPeerNotFound) }) t.Run("closed pool returns error", func(t *testing.T) { - mockSocket := honeybeetest.NewMockSocket() - mockDialer := &honeybeetest.MockDialer{ - DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { - return mockSocket, nil, nil - }, - } - - pool, err := NewPool(context.Background(), nil, nil) - assert.NoError(t, err) - pool.dialer = mockDialer + pool, _ := setupPool(t) // close pool pool.Close() // attempt to remove connection - err = pool.Remove("wss://test") + err := pool.Remove("wss://test") assert.ErrorIs(t, err, ErrPoolClosed) }) } -// TODO: update worker to be responsible for send -func _TestPoolSend(t *testing.T) { +func TestPoolSend(t *testing.T) { mockSocket := honeybeetest.NewMockSocket() outgoingData := make(chan honeybeetest.MockOutgoingData, 10) mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { diff --git a/initiatorpool/worker_dialer_test.go b/initiatorpool/worker_dialer_test.go index 73d6fe9..6f0480e 100644 --- a/initiatorpool/worker_dialer_test.go +++ b/initiatorpool/worker_dialer_test.go @@ -8,6 +8,7 @@ import ( "git.wisehodl.dev/jay/go-honeybee/types" "github.com/stretchr/testify/assert" "net/http" + "sync" "sync/atomic" "testing" "time" @@ -57,11 +58,14 @@ func TestRunDialer(t *testing.T) { mockSocket := honeybeetest.NewMockSocket() connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry + started := make(chan struct{}) + startOnce := sync.Once{} wctx := WorkerContext{ Errors: make(chan error, 1), Dialer: &honeybeetest.MockDialer{ DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { dialCount.Add(1) + startOnce.Do(func() { close(started) }) <-gate return mockSocket, nil, nil }, @@ -73,7 +77,7 @@ func TestRunDialer(t *testing.T) { dial <- struct{}{} // wait for dial to start blocking on gate - time.Sleep(20 * time.Millisecond) + <-started // flood dial while dialer is blocked for i := 0; i < 5; i++ { diff --git a/initiatorpool/worker_keepalive_test.go b/initiatorpool/worker_keepalive_test.go index b5a069a..14b5b36 100644 --- a/initiatorpool/worker_keepalive_test.go +++ b/initiatorpool/worker_keepalive_test.go @@ -15,14 +15,14 @@ func TestRunKeepalive(t *testing.T) { defer cancel() w := &DefaultWorker{ - Config: &WorkerConfig{KeepaliveTimeout: 100 * time.Millisecond}, + Config: &WorkerConfig{KeepaliveTimeout: 200 * time.Millisecond}, Heartbeat: heartbeat, } go w.RunKeepalive(ctx, keepalive) // send heartbeats faster than the timeout for i := 0; i < 5; i++ { - time.Sleep(30 * time.Millisecond) + time.Sleep(20 * time.Millisecond) w.Heartbeat <- struct{}{} } diff --git a/transport/config.go b/transport/config.go index d3dd4a7..0ff5be2 100644 --- a/transport/config.go +++ b/transport/config.go @@ -147,15 +147,9 @@ func WithWriteTimeout(value time.Duration) ConnectionOption { } } -// WithRetry enables retry with default parameters (infinite retries, -// 1s initial delay, 5s max delay, 0.5 jitter factor). -// -// If passed after granular retry options (WithRetryMaxRetries, etc.), -// it will overwrite them. Use either WithRetry alone or the granular -// options; not both. -func WithRetry() ConnectionOption { +func WithoutRetry() ConnectionOption { return func(c *ConnectionConfig) error { - c.Retry = GetDefaultRetryConfig() + c.Retry = nil return nil } } diff --git a/transport/config_test.go b/transport/config_test.go index 1fc259a..179be1b 100644 --- a/transport/config_test.go +++ b/transport/config_test.go @@ -107,13 +107,12 @@ func TestWithWriteTimeout(t *testing.T) { } func TestWithRetry(t *testing.T) { - t.Run("default", func(t *testing.T) { - conf := &ConnectionConfig{} - opt := WithRetry() + t.Run("without retry", func(t *testing.T) { + conf := GetDefaultConnectionConfig() + opt := WithoutRetry() err := applyConnectionOptions(conf, opt) assert.NoError(t, err) - assert.NotNil(t, conf.Retry) - assert.Equal(t, conf.Retry, GetDefaultRetryConfig()) + assert.Nil(t, conf.Retry) }) t.Run("with attempts", func(t *testing.T) {