From eea208738a123e72d7de35c5c659b3f5fba3883c Mon Sep 17 00:00:00 2001 From: Jay Date: Mon, 20 Apr 2026 14:19:50 -0400 Subject: [PATCH] completed inbound pool. Refactored to inbound/outbound semantics. --- honeybee.go | 117 ++++-- {responderpool => inbound}/config.go | 19 +- {responderpool => inbound}/config_test.go | 2 +- {responderpool => inbound}/errors.go | 3 +- {responderpool => inbound}/helpers_test.go | 39 +- inbound/pool.go | 297 +++++++++++++ inbound/pool_test.go | 395 ++++++++++++++++++ {responderpool => inbound}/worker.go | 112 ++++- .../worker_forwarder_test.go | 2 +- .../worker_reader_test.go | 66 ++- inbound/worker_test.go | 228 ++++++++++ .../worker_watchdog_test.go | 43 +- {initiatorpool => outbound}/config.go | 2 +- .../config_pool_test.go | 2 +- {initiatorpool => outbound}/errors.go | 2 +- {initiatorpool => outbound}/helper_test.go | 4 +- {initiatorpool => outbound}/pool.go | 67 ++- {initiatorpool => outbound}/pool_test.go | 40 +- {initiatorpool => outbound}/worker.go | 72 ++-- .../worker_dialer_test.go | 2 +- .../worker_forwarder_test.go | 2 +- .../worker_keepalive_test.go | 25 +- .../worker_send_test.go | 40 +- .../worker_session_inner_test.go | 14 +- .../worker_session_test.go | 6 +- .../worker_start_test.go | 20 +- responderpool/pool.go | 24 -- 27 files changed, 1401 insertions(+), 244 deletions(-) rename {responderpool => inbound}/config.go (90%) rename {responderpool => inbound}/config_test.go (99%) rename {responderpool => inbound}/errors.go (85%) rename {responderpool => inbound}/helpers_test.go (52%) create mode 100644 inbound/pool.go create mode 100644 inbound/pool_test.go rename {responderpool => inbound}/worker.go (55%) rename {initiatorpool => inbound}/worker_forwarder_test.go (99%) rename {responderpool => inbound}/worker_reader_test.go (68%) create mode 100644 inbound/worker_test.go rename {responderpool => inbound}/worker_watchdog_test.go (64%) rename {initiatorpool => outbound}/config.go (99%) rename {initiatorpool => outbound}/config_pool_test.go (99%) rename {initiatorpool => outbound}/errors.go (97%) rename {initiatorpool => outbound}/helper_test.go (94%) rename {initiatorpool => outbound}/pool.go (92%) rename {initiatorpool => outbound}/pool_test.go (99%) rename {initiatorpool => outbound}/worker.go (86%) rename {initiatorpool => outbound}/worker_dialer_test.go (99%) rename {responderpool => outbound}/worker_forwarder_test.go (99%) rename {initiatorpool => outbound}/worker_keepalive_test.go (76%) rename {initiatorpool => outbound}/worker_send_test.go (77%) rename {initiatorpool => outbound}/worker_session_inner_test.go (92%) rename {initiatorpool => outbound}/worker_session_test.go (98%) rename {initiatorpool => outbound}/worker_start_test.go (94%) delete mode 100644 responderpool/pool.go diff --git a/honeybee.go b/honeybee.go index 294b5aa..c6b0933 100644 --- a/honeybee.go +++ b/honeybee.go @@ -4,8 +4,10 @@ import ( "context" "log/slog" - "git.wisehodl.dev/jay/go-honeybee/initiatorpool" + "git.wisehodl.dev/jay/go-honeybee/inbound" + "git.wisehodl.dev/jay/go-honeybee/outbound" "git.wisehodl.dev/jay/go-honeybee/transport" + "git.wisehodl.dev/jay/go-honeybee/types" ) // Connection types @@ -15,22 +17,52 @@ type ConnectionConfig = transport.ConnectionConfig type RetryConfig = transport.RetryConfig type ConnectionOption = transport.ConnectionOption -// Initator Pool types +// Outbound Pool types -type InitiatorPool = initiatorpool.Pool -type InitiatorPoolConfig = initiatorpool.PoolConfig -type InitiatorPoolOption = initiatorpool.PoolOption -type InitiatorWorkerConfig = initiatorpool.WorkerConfig -type InitiatorWorkerOption = initiatorpool.WorkerOption -type InitiatorInboxMessage = initiatorpool.InboxMessage -type InitiatorPoolEvent = initiatorpool.PoolEvent -type InitiatorPoolEventKind = initiatorpool.PoolEventKind +type OutboundPool = outbound.Pool +type OutboundPoolConfig = outbound.PoolConfig +type OutboundPoolOption = outbound.PoolOption +type OutboundWorkerConfig = outbound.WorkerConfig +type OutboundWorkerOption = outbound.WorkerOption +type OutboundInboxMessage = outbound.InboxMessage +type OutboundPoolEvent = outbound.PoolEvent +type OutboundPoolEventKind = outbound.PoolEventKind // Pool event constants const ( - EventConnected = initiatorpool.EventConnected - EventDisconnected = initiatorpool.EventDisconnected + EventConnected = outbound.EventConnected + EventDisconnected = outbound.EventDisconnected +) + +// Inbound Pool types + +type InboundPool = inbound.Pool +type InboundPoolConfig = inbound.PoolConfig +type InboundPoolOption = inbound.PoolOption +type InboundWorkerConfig = inbound.WorkerConfig +type InboundWorkerOption = inbound.WorkerOption +type InboundWorkerFactory = inbound.WorkerFactory +type InboundWorker = inbound.Worker +type InboundWorkerExitKind = inbound.WorkerExitKind +type InboundInboxMessage = inbound.InboxMessage +type InboundPoolEvent = inbound.PoolEvent +type InboundPoolEventKind = inbound.PoolEventKind + +// Inbound Pool event constants + +const ( + EventPeerDisconnected = inbound.EventPeerDisconnected + EventPeerDropped = inbound.EventPeerDropped + EventPeerEvicted = inbound.EventPeerEvicted +) + +// Inbound Worker exit kinds + +const ( + ExitCleanDisconnect = inbound.ExitCleanDisconnect + ExitUnexpectedDrop = inbound.ExitUnexpectedDrop + ExitInactive = inbound.ExitInactive ) // Connection constructors @@ -55,31 +87,64 @@ var ( WithCloseHandler = transport.WithCloseHandler ) -// Initiator Pool constructors +// Outbound Pool constructors -func NewInitiatorPool(ctx context.Context, config *InitiatorPoolConfig, logger *slog.Logger) (*InitiatorPool, error) { - return initiatorpool.NewPool(ctx, config, logger) +func NewOutboundPool(ctx context.Context, config *OutboundPoolConfig, logger *slog.Logger) (*OutboundPool, error) { + return outbound.NewPool(ctx, config, logger) } -func NewInitiatorPoolConfig(opts ...InitiatorPoolOption) (*InitiatorPoolConfig, error) { - return initiatorpool.NewPoolConfig(opts...) +func NewOutboundPoolConfig(opts ...OutboundPoolOption) (*OutboundPoolConfig, error) { + return outbound.NewPoolConfig(opts...) } -func NewInitiatorWorkerConfig(opts ...InitiatorWorkerOption) (*InitiatorWorkerConfig, error) { - return initiatorpool.NewWorkerConfig(opts...) +func NewOutboundWorkerConfig(opts ...OutboundWorkerOption) (*OutboundWorkerConfig, error) { + return outbound.NewWorkerConfig(opts...) } -// Initiator Pool options +// Outbound Pool options var ( - WithConnectionConfig = initiatorpool.WithConnectionConfig - WithWorkerConfig = initiatorpool.WithWorkerConfig - WithWorkerFactory = initiatorpool.WithWorkerFactory + WithOutboundConnectionConfig = outbound.WithConnectionConfig + WithOutboundWorkerConfig = outbound.WithWorkerConfig + WithOutboundWorkerFactory = outbound.WithWorkerFactory ) -// Initiator Worker options +// Outbound Worker options var ( - WithKeepaliveTimeout = initiatorpool.WithKeepaliveTimeout - WithMaxQueueSize = initiatorpool.WithMaxQueueSize + WithOutboundKeepaliveTimeout = outbound.WithKeepaliveTimeout + WithOutboundMaxQueueSize = outbound.WithMaxQueueSize ) + +// Inbound Pool constructors + +func NewInboundPool(ctx context.Context, config *InboundPoolConfig, logger *slog.Logger) (*InboundPool, error) { + return inbound.NewPool(ctx, config, logger) +} + +func NewInboundPoolConfig(opts ...InboundPoolOption) (*InboundPoolConfig, error) { + return inbound.NewPoolConfig(opts...) +} + +func NewInboundWorkerConfig(opts ...InboundWorkerOption) (*InboundWorkerConfig, error) { + return inbound.NewWorkerConfig(opts...) +} + +// Inbound Pool options + +var ( + WithInboundConnectionConfig = inbound.WithConnectionConfig + WithInboundWorkerConfig = inbound.WithWorkerConfig + WithInboundWorkerFactory = inbound.WithWorkerFactory +) + +// Inbound Worker options + +var ( + WithInboundDeadTimeout = inbound.WithDeadTimeout + WithInboundMaxQueueSize = inbound.WithMaxQueueSize +) + +// Socket type — needed for inbound pool.Add and pool.Replace + +type Socket = types.Socket diff --git a/responderpool/config.go b/inbound/config.go similarity index 90% rename from responderpool/config.go rename to inbound/config.go index 7ba5f71..3e9a7ad 100644 --- a/responderpool/config.go +++ b/inbound/config.go @@ -1,7 +1,8 @@ // responderpool/config.go -package responderpool +package inbound import ( + "context" "git.wisehodl.dev/jay/go-honeybee/transport" "time" ) @@ -90,9 +91,17 @@ func WithDeadTimeout(value time.Duration) WorkerOption { // Pool Config +type WorkerFactory func( + ctx context.Context, + id string, + conn *transport.Connection, + config *WorkerConfig, +) (Worker, error) + type PoolConfig struct { ConnectionConfig *transport.ConnectionConfig WorkerConfig *WorkerConfig + WorkerFactory WorkerFactory } type PoolOption func(*PoolConfig) error @@ -112,6 +121,7 @@ func GetDefaultPoolConfig() *PoolConfig { return &PoolConfig{ ConnectionConfig: nil, WorkerConfig: nil, + WorkerFactory: nil, } } @@ -157,3 +167,10 @@ func WithWorkerConfig(wc *WorkerConfig) PoolOption { return nil } } + +func WithWorkerFactory(wf WorkerFactory) PoolOption { + return func(c *PoolConfig) error { + c.WorkerFactory = wf + return nil + } +} diff --git a/responderpool/config_test.go b/inbound/config_test.go similarity index 99% rename from responderpool/config_test.go rename to inbound/config_test.go index f24139a..d4d5c22 100644 --- a/responderpool/config_test.go +++ b/inbound/config_test.go @@ -1,5 +1,5 @@ // responderpool/config_test.go -package responderpool +package inbound import ( "git.wisehodl.dev/jay/go-honeybee/transport" diff --git a/responderpool/errors.go b/inbound/errors.go similarity index 85% rename from responderpool/errors.go rename to inbound/errors.go index c702df1..a6b7c58 100644 --- a/responderpool/errors.go +++ b/inbound/errors.go @@ -1,9 +1,10 @@ -package responderpool +package inbound import "errors" var ( // Pool errors + PoolError = errors.New("pool error") ErrPoolClosed = errors.New("pool is closed") ErrPeerNotFound = errors.New("peer not found") ErrPeerExists = errors.New("peer already exists") diff --git a/responderpool/helpers_test.go b/inbound/helpers_test.go similarity index 52% rename from responderpool/helpers_test.go rename to inbound/helpers_test.go index 0c15dd8..1b6c0da 100644 --- a/responderpool/helpers_test.go +++ b/inbound/helpers_test.go @@ -1,7 +1,6 @@ -package responderpool +package inbound import ( - "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/transport" "github.com/stretchr/testify/assert" @@ -9,9 +8,8 @@ import ( "testing" ) -func setupReaderTestConnection(t *testing.T) ( - conn *transport.Connection, - mock *honeybeetest.MockSocket, +func setupTestSocket(t *testing.T) ( + socket *honeybeetest.MockSocket, incoming chan honeybeetest.MockIncomingData, outgoing chan honeybeetest.MockOutgoingData, ) { @@ -19,38 +17,51 @@ func setupReaderTestConnection(t *testing.T) ( incoming = make(chan honeybeetest.MockIncomingData, 10) outgoing = make(chan honeybeetest.MockOutgoingData, 10) - mock = honeybeetest.NewMockSocket() + socket = honeybeetest.NewMockSocket() - mock.CloseFunc = func() error { - mock.Once.Do(func() { close(mock.Closed) }) + socket.CloseFunc = func() error { + socket.Once.Do(func() { close(socket.Closed) }) return nil } - mock.ReadMessageFunc = func() (int, []byte, error) { + socket.ReadMessageFunc = func() (int, []byte, error) { select { case data, ok := <-incoming: if !ok { return 0, nil, io.EOF } return data.MsgType, data.Data, data.Err - case <-mock.Closed: + case <-socket.Closed: return 0, nil, io.EOF } } - mock.WriteMessageFunc = func(msgType int, data []byte) error { + socket.WriteMessageFunc = func(msgType int, data []byte) error { select { case outgoing <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}: return nil - case <-mock.Closed: + case <-socket.Closed: return io.EOF default: - return fmt.Errorf("mock outgoing channel unavailable") + return io.EOF } } + return +} + +func setupTestConnection(t *testing.T) ( + conn *transport.Connection, + socket *honeybeetest.MockSocket, + incoming chan honeybeetest.MockIncomingData, + outgoing chan honeybeetest.MockOutgoingData, +) { + t.Helper() + + socket, incoming, outgoing = setupTestSocket(t) + var err error - conn, err = transport.NewConnectionFromSocket(mock, nil, nil) + conn, err = transport.NewConnectionFromSocket(socket, nil, nil) assert.NoError(t, err) return } diff --git a/inbound/pool.go b/inbound/pool.go new file mode 100644 index 0000000..35b96c0 --- /dev/null +++ b/inbound/pool.go @@ -0,0 +1,297 @@ +package inbound + +import ( + "context" + "fmt" + "git.wisehodl.dev/jay/go-honeybee/transport" + "git.wisehodl.dev/jay/go-honeybee/types" + "log/slog" + "sync" + "time" +) + +// Types + +type PoolEventKind string + +const ( + EventPeerDisconnected PoolEventKind = "disconnected" + EventPeerDropped PoolEventKind = "dropped" + EventPeerEvicted PoolEventKind = "evicted" +) + +var workerToPoolEvent = map[WorkerExitKind]PoolEventKind{ + ExitCleanDisconnect: EventPeerDisconnected, + ExitUnexpectedDrop: EventPeerDropped, + ExitInactive: EventPeerEvicted, +} + +type OnExitFunction func(kind WorkerExitKind) + +type PoolEvent struct { + ID string + Kind PoolEventKind +} + +type InboxMessage struct { + ID string + Data []byte + ReceivedAt time.Time +} + +type PoolPlugin struct { + Inbox chan<- InboxMessage + Events chan<- PoolEvent + Logger *slog.Logger + OnExit OnExitFunction +} + +// Pool + +type Peer struct { + id string + conn *transport.Connection + worker Worker + done chan struct{} +} + +type Pool struct { + ctx context.Context + cancel context.CancelFunc + + peers map[string]*Peer + inbox chan InboxMessage + events chan PoolEvent + + config *PoolConfig + logger *slog.Logger + + mu sync.RWMutex + wg sync.WaitGroup + closed bool +} + +func NewPool(ctx context.Context, config *PoolConfig, logger *slog.Logger) (*Pool, error) { + if config == nil { + config = GetDefaultPoolConfig() + } + + // If a custom factory is supplied, config.WorkerConfig is not used. + // The factory function should be non-blocking or else Connect() may cause + // deadlocks. + if config.WorkerFactory == nil { + config.WorkerFactory = func( + ctx context.Context, + id string, + conn *transport.Connection, + config *WorkerConfig, + ) (Worker, error) { + return NewWorker(ctx, id, conn, config) + } + } + + if err := ValidatePoolConfig(config); err != nil { + return nil, err + } + + pctx, cancel := context.WithCancel(ctx) + + return &Pool{ + ctx: pctx, + cancel: cancel, + peers: make(map[string]*Peer), + inbox: make(chan InboxMessage, 256), + events: make(chan PoolEvent, 10), + config: config, + logger: logger, + }, nil +} + +func (p *Pool) Peers() []string { + p.mu.RLock() + defer p.mu.RUnlock() + + ids := make([]string, 0, len(p.peers)) + for id := range p.peers { + ids = append(ids, id) + } + + return ids +} + +func (p *Pool) Inbox() <-chan InboxMessage { + return p.inbox +} + +func (p *Pool) Events() <-chan PoolEvent { + return p.events +} + +func (p *Pool) Close() { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return + } + + p.closed = true + p.cancel() + + // remove all peers + p.peers = make(map[string]*Peer) + + // close all connections + for _, peer := range p.peers { + peer.worker.Stop() + peer.conn.Close() + } + + p.mu.Unlock() + + go func() { + p.wg.Wait() + close(p.inbox) + close(p.events) + }() +} + +func (p *Pool) Add(id string, socket types.Socket) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return ErrPoolClosed + } + + if _, exists := p.peers[id]; exists { + return ErrPeerExists + } + + return p.addLocked(id, socket) +} + +func (p *Pool) Replace(id string, socket types.Socket) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return ErrPoolClosed + } + + if peer, exists := p.peers[id]; exists { + p.removeLocked(peer) + } else { + return ErrPeerNotFound + } + + return p.addLocked(id, socket) +} + +func (p *Pool) Remove(id string) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return ErrPoolClosed + } + + peer, exists := p.peers[id] + if !exists { + return ErrPeerNotFound + } + + p.removeLocked(peer) + + return nil +} + +func (p *Pool) Send(id string, data []byte) error { + p.mu.RLock() + defer p.mu.RUnlock() + + if p.closed { + return ErrPoolClosed + } + + peer, exists := p.peers[id] + if !exists { + return ErrPeerNotFound + } + + return peer.worker.Send(data) +} + +// addLocked constructs and registers a peer. Caller must hold p.mu write lock. +func (p *Pool) addLocked(id string, socket types.Socket) error { + conn, err := transport.NewConnectionFromSocket( + socket, p.config.ConnectionConfig, p.logger) + if err != nil { + return err + } + + // The worker factory must be non-blocking to avoid deadlocks + wctx, cancel := context.WithCancel(p.ctx) + worker, err := p.config.WorkerFactory(wctx, id, conn, p.config.WorkerConfig) + if err != nil { + cancel() + conn.Close() + return fmt.Errorf("%w: %w", PoolError, err) + } + + var once sync.Once + onExit := func(kind WorkerExitKind) { + once.Do(func() { + p.mu.Lock() + delete(p.peers, id) + p.mu.Unlock() + + conn.Close() + + select { + case p.events <- PoolEvent{ID: id, Kind: workerToPoolEvent[kind]}: + case <-p.ctx.Done(): + return + } + }) + } + + var logger *slog.Logger + if p.logger != nil { + logger = p.logger.With("id", id) + } + + pool := PoolPlugin{ + Inbox: p.inbox, + Events: p.events, + Logger: logger, + OnExit: onExit, + } + + peer := &Peer{ + id: id, + conn: conn, + worker: worker, + done: make(chan struct{}), + } + + p.wg.Add(1) + go func() { + defer cancel() + defer close(peer.done) + worker.Start(pool, &p.wg) + }() + + p.peers[id] = peer + + return nil +} + +// removeLocked closes and unregisters a peer. Caller must hold p.mu write lock. +func (p *Pool) removeLocked(peer *Peer) { + delete(p.peers, peer.id) + peer.worker.Stop() + go func() { + <-peer.done + peer.conn.Close() + }() +} diff --git a/inbound/pool_test.go b/inbound/pool_test.go new file mode 100644 index 0000000..02deb83 --- /dev/null +++ b/inbound/pool_test.go @@ -0,0 +1,395 @@ +package inbound + +import ( + "context" + "fmt" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "slices" + "testing" + "time" +) + +// Helpers + +func setupPool(t *testing.T) *Pool { + t.Helper() + pool, err := NewPool(context.Background(), nil, nil) + assert.NoError(t, err) + return pool +} + +func expectEvent( + t *testing.T, + events <-chan PoolEvent, + expectedURL string, + expectedKind PoolEventKind, +) { + t.Helper() + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.ID == expectedURL && e.Kind == expectedKind + default: + return false + } + }, fmt.Sprintf("expected event: URL=%q, Kind=%q", expectedURL, expectedKind)) +} + +// Tests + +func TestPoolAdd(t *testing.T) { + t.Run("successfully adds peer", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, _, _ := setupTestSocket(t) + err := pool.Add("peer-1", socket) + assert.NoError(t, err) + }) + + t.Run("peer appears in Peers after add", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, _, _ := setupTestSocket(t) + err := pool.Add("peer-1", socket) + assert.NoError(t, err) + + assert.Contains(t, pool.Peers(), "peer-1") + }) + + t.Run("duplicate id returns ErrPeerExists", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket1, _, _ := setupTestSocket(t) + socket2, _, _ := setupTestSocket(t) + + err := pool.Add("peer-1", socket1) + assert.NoError(t, err) + + err = pool.Add("peer-1", socket2) + assert.ErrorIs(t, err, ErrPeerExists) + }) + + t.Run("closed pool returns ErrPoolClosed", func(t *testing.T) { + pool := setupPool(t) + pool.Close() + + socket, _, _ := setupTestSocket(t) + err := pool.Add("peer-1", socket) + assert.ErrorIs(t, err, ErrPoolClosed) + }) +} + +func TestPoolReplace(t *testing.T) { + t.Run("replaces existing peer", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket1, _, _ := setupTestSocket(t) + socket2, _, _ := setupTestSocket(t) + + err := pool.Add("peer-1", socket1) + assert.NoError(t, err) + + err = pool.Replace("peer-1", socket2) + assert.NoError(t, err) + + assert.Contains(t, pool.Peers(), "peer-1") + }) + + t.Run("unknown id returns ErrPeerNotFound", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, _, _ := setupTestSocket(t) + err := pool.Replace("unknown", socket) + assert.ErrorIs(t, err, ErrPeerNotFound) + }) + + t.Run("closed pool returns ErrPoolClosed", func(t *testing.T) { + pool := setupPool(t) + pool.Close() + + socket, _, _ := setupTestSocket(t) + err := pool.Replace("peer-1", socket) + assert.ErrorIs(t, err, ErrPoolClosed) + }) + + t.Run("no event emitted for replaced peer", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket1, _, _ := setupTestSocket(t) + socket2, _, _ := setupTestSocket(t) + + err := pool.Add("peer-1", socket1) + assert.NoError(t, err) + + err = pool.Replace("peer-1", socket2) + assert.NoError(t, err) + + honeybeetest.Never(t, func() bool { + select { + case <-pool.Events(): + return true + default: + return false + } + }, "no event expected on replace") + }) +} + +func TestPoolRemove(t *testing.T) { + t.Run("removes known peer", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, _, _ := setupTestSocket(t) + err := pool.Add("peer-1", socket) + assert.NoError(t, err) + + err = pool.Remove("peer-1") + assert.NoError(t, err) + + assert.NotContains(t, pool.Peers(), "peer-1") + }) + + t.Run("unknown id returns ErrPeerNotFound", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + err := pool.Remove("unknown") + assert.ErrorIs(t, err, ErrPeerNotFound) + }) + + t.Run("closed pool returns ErrPoolClosed", func(t *testing.T) { + pool := setupPool(t) + pool.Close() + + err := pool.Remove("peer-1") + assert.ErrorIs(t, err, ErrPoolClosed) + }) + + t.Run("no event emitted on remove", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, _, _ := setupTestSocket(t) + err := pool.Add("peer-1", socket) + assert.NoError(t, err) + + err = pool.Remove("peer-1") + assert.NoError(t, err) + + honeybeetest.Never(t, func() bool { + select { + case e := <-pool.Events(): + fmt.Printf("got event: %v", e) + return true + default: + return false + } + }, "no event expected on remove") + }) +} + +func TestPoolSend(t *testing.T) { + t.Run("data reaches socket", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, _, outgoing := setupTestSocket(t) + err := pool.Add("peer-1", socket) + assert.NoError(t, err) + + err = pool.Send("peer-1", []byte("hello")) + assert.NoError(t, err) + + honeybeetest.ExpectWrite(t, outgoing, websocket.TextMessage, []byte("hello")) + }) + + t.Run("unknown id returns ErrPeerNotFound", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + err := pool.Send("unknown", []byte("hello")) + assert.ErrorIs(t, err, ErrPeerNotFound) + }) + + t.Run("closed pool returns ErrPoolClosed", func(t *testing.T) { + pool := setupPool(t) + pool.Close() + + err := pool.Send("peer-1", []byte("hello")) + assert.ErrorIs(t, err, ErrPoolClosed) + }) +} + +func TestPoolClose(t *testing.T) { + t.Run("inbox and events channels close after pool close", func(t *testing.T) { + pool := setupPool(t) + pool.Close() + + _, ok := <-pool.Inbox() + assert.False(t, ok) + _, ok = <-pool.Events() + assert.False(t, ok) + }) + + t.Run("add after close returns ErrPoolClosed", func(t *testing.T) { + pool := setupPool(t) + pool.Close() + + socket, _, _ := setupTestSocket(t) + err := pool.Add("peer-1", socket) + assert.ErrorIs(t, err, ErrPoolClosed) + }) + + t.Run("close is idempotent", func(t *testing.T) { + pool := setupPool(t) + pool.Close() + pool.Close() + }) +} + +func TestPoolPeers(t *testing.T) { + t.Run("reflects active peers after add", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket1, _, _ := setupTestSocket(t) + socket2, _, _ := setupTestSocket(t) + + pool.Add("peer-1", socket1) + pool.Add("peer-2", socket2) + + peers := pool.Peers() + assert.Contains(t, peers, "peer-1") + assert.Contains(t, peers, "peer-2") + }) + + t.Run("loses entry after remove", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, _, _ := setupTestSocket(t) + pool.Add("peer-1", socket) + pool.Remove("peer-1") + + assert.NotContains(t, pool.Peers(), "peer-1") + }) + + t.Run("loses entry after peer self-disconnects", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, incoming, _ := setupTestSocket(t) + pool.Add("peer-1", socket) + + close(incoming) + + honeybeetest.Eventually(t, func() bool { + return !slices.Contains(pool.Peers(), "peer-1") + }, "expected peer to be removed after self-disconnect") + }) +} + +func TestPoolEvents(t *testing.T) { + t.Run("EventPeerDisconnected emitted on clean close", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, incoming, _ := setupTestSocket(t) + pool.Add("peer-1", socket) + + incoming <- honeybeetest.MockIncomingData{ + Err: &websocket.CloseError{Code: websocket.CloseNormalClosure}, + } + + expectEvent(t, pool.Events(), "peer-1", EventPeerDisconnected) + + honeybeetest.Eventually(t, func() bool { + return !slices.Contains(pool.Peers(), "peer-1") + }, "expected peer auto-removed") + }) + + t.Run("EventPeerDropped emitted on unexpected close", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, incoming, _ := setupTestSocket(t) + pool.Add("peer-1", socket) + + incoming <- honeybeetest.MockIncomingData{ + Err: &websocket.CloseError{Code: websocket.CloseProtocolError}, + } + + expectEvent(t, pool.Events(), "peer-1", EventPeerDropped) + + honeybeetest.Eventually(t, func() bool { + return !slices.Contains(pool.Peers(), "peer-1") + }, "expected peer auto-removed") + }) + + t.Run("EventPeerEvicted emitted on watchdog timeout", func(t *testing.T) { + config, err := NewPoolConfig( + WithWorkerConfig(&WorkerConfig{DeadTimeout: 20 * time.Millisecond}), + ) + assert.NoError(t, err) + + pool, err := NewPool(context.Background(), config, nil) + assert.NoError(t, err) + defer pool.Close() + + socket, _, _ := setupTestSocket(t) + pool.Add("peer-1", socket) + + expectEvent(t, pool.Events(), "peer-1", EventPeerEvicted) + + honeybeetest.Eventually(t, func() bool { + return !slices.Contains(pool.Peers(), "peer-1") + }, "expected peer auto-removed") + }) + + t.Run("no event emitted on Remove", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket, _, _ := setupTestSocket(t) + pool.Add("peer-1", socket) + pool.Remove("peer-1") + + honeybeetest.Never(t, func() bool { + select { + case <-pool.Events(): + return true + default: + return false + } + }, "no event expected on Remove") + }) + + t.Run("no event emitted on Replace of old peer", func(t *testing.T) { + pool := setupPool(t) + defer pool.Close() + + socket1, _, _ := setupTestSocket(t) + socket2, _, _ := setupTestSocket(t) + + pool.Add("peer-1", socket1) + pool.Replace("peer-1", socket2) + + honeybeetest.Never(t, func() bool { + select { + case <-pool.Events(): + return true + default: + return false + } + }, "no event expected on Replace") + }) +} diff --git a/responderpool/worker.go b/inbound/worker.go similarity index 55% rename from responderpool/worker.go rename to inbound/worker.go index 29cc450..45fa208 100644 --- a/responderpool/worker.go +++ b/inbound/worker.go @@ -1,23 +1,111 @@ -package responderpool +package inbound import ( "container/list" "context" "errors" "git.wisehodl.dev/jay/go-honeybee/transport" + "sync" "time" ) -type onEventFunc func(kind PoolEventKind) +type Worker interface { + Start(pool PoolPlugin, wg *sync.WaitGroup) + Stop() + Send(data []byte) error +} + +type WorkerExitKind string + +const ( + ExitCleanDisconnect WorkerExitKind = "disconnected" + ExitUnexpectedDrop WorkerExitKind = "dropped" + ExitInactive WorkerExitKind = "inactive" +) type ReceivedMessage struct { data []byte receivedAt time.Time } +type DefaultWorker struct { + id string + conn *transport.Connection + heartbeat chan struct{} + config *WorkerConfig + ctx context.Context + cancel context.CancelFunc +} + +func NewWorker( + ctx context.Context, + id string, + conn *transport.Connection, + config *WorkerConfig, +) (*DefaultWorker, error) { + if config == nil { + config = GetDefaultWorkerConfig() + } + if err := ValidateWorkerConfig(config); err != nil { + return nil, err + } + + wctx, cancel := context.WithCancel(ctx) + return &DefaultWorker{ + id: id, + conn: conn, + heartbeat: make(chan struct{}), + config: config, + ctx: wctx, + cancel: cancel, + }, nil +} + +func (w *DefaultWorker) Start(pool PoolPlugin, wg *sync.WaitGroup) { + messages := make(chan ReceivedMessage, 256) + + var owg sync.WaitGroup + owg.Add(3) + + go func() { + defer owg.Done() + RunReader(w.ctx, pool.OnExit, w.conn, messages, w.heartbeat) + }() + + go func() { + defer owg.Done() + RunForwarder(w.id, w.ctx, messages, pool.Inbox, w.config.MaxQueueSize) + }() + + go func() { + defer owg.Done() + RunWatchdog(w.ctx, pool.OnExit, w.heartbeat, w.config.DeadTimeout) + }() + + owg.Wait() + wg.Done() +} + +func (w *DefaultWorker) Stop() { + w.cancel() +} + +func (w *DefaultWorker) Send(data []byte) error { + if err := w.conn.Send(data); err != nil { + return err + } + + select { + case w.heartbeat <- struct{}{}: + case <-w.ctx.Done(): + } + + return nil +} + func RunReader( ctx context.Context, - onPeerClose onEventFunc, + onPeerClose OnExitFunction, conn *transport.Connection, messages chan<- ReceivedMessage, @@ -31,14 +119,14 @@ func RunReader( if !ok { // determine exit kind // by default, the peer dropped unexpectedly - kind := EventPeerDropped + kind := ExitUnexpectedDrop select { // the peer-side error is sent before the connection is closed, // so a non-blocking call here is correct // if an error is not sent, then assume the default event kind case err := <-conn.Errors(): if errors.Is(err, transport.ErrPeerClosedClean) { - kind = EventPeerDisconnected + kind = ExitCleanDisconnect } default: } @@ -104,17 +192,21 @@ func RunForwarder( func RunWatchdog( ctx context.Context, - onInactive func(), + onInactive OnExitFunction, heartbeat <-chan struct{}, timeout time.Duration, ) { // disable watchdog timeout if not configured if timeout <= 0 { + // drain heartbeats // wait for cancel and exit - select { - case <-ctx.Done(): + for { + select { + case <-heartbeat: + case <-ctx.Done(): + return + } } - return } timer := time.NewTimer(timeout) @@ -136,7 +228,7 @@ func RunWatchdog( // timer completed case <-timer.C: // signal peer is inactive - onInactive() + onInactive(ExitInactive) return } } diff --git a/initiatorpool/worker_forwarder_test.go b/inbound/worker_forwarder_test.go similarity index 99% rename from initiatorpool/worker_forwarder_test.go rename to inbound/worker_forwarder_test.go index 2a8ab48..18b05ca 100644 --- a/initiatorpool/worker_forwarder_test.go +++ b/inbound/worker_forwarder_test.go @@ -1,4 +1,4 @@ -package initiatorpool +package inbound import ( "context" diff --git a/responderpool/worker_reader_test.go b/inbound/worker_reader_test.go similarity index 68% rename from responderpool/worker_reader_test.go rename to inbound/worker_reader_test.go index 3e5cc1a..0be4ed4 100644 --- a/responderpool/worker_reader_test.go +++ b/inbound/worker_reader_test.go @@ -1,8 +1,9 @@ -package responderpool +package inbound import ( "context" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "git.wisehodl.dev/jay/go-honeybee/transport" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "io" @@ -13,7 +14,7 @@ import ( func TestRunReader(t *testing.T) { t.Run("message forwarded with correct data and non-zero receivedAt", func(t *testing.T) { - conn, _, incoming, _ := setupReaderTestConnection(t) + conn, _, incoming, _ := setupTestConnection(t) defer conn.Close() messages := make(chan ReceivedMessage, 1) @@ -21,7 +22,7 @@ func TestRunReader(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - go RunReader(ctx, func(PoolEventKind) {}, conn, messages, heartbeat) + go RunReader(ctx, func(WorkerExitKind) {}, conn, messages, heartbeat) before := time.Now() incoming <- honeybeetest.MockIncomingData{MsgType: websocket.TextMessage, Data: []byte("hello")} @@ -37,7 +38,7 @@ func TestRunReader(t *testing.T) { }) t.Run("heartbeat sent per forwarded message", func(t *testing.T) { - conn, _, incoming, _ := setupReaderTestConnection(t) + conn, _, incoming, _ := setupTestConnection(t) defer conn.Close() messages := make(chan ReceivedMessage, 10) @@ -55,7 +56,7 @@ func TestRunReader(t *testing.T) { for range messages { } }() - go RunReader(ctx, func(PoolEventKind) {}, conn, messages, heartbeat) + go RunReader(ctx, func(WorkerExitKind) {}, conn, messages, heartbeat) const n = 3 for i := 0; i < n; i++ { @@ -67,20 +68,27 @@ func TestRunReader(t *testing.T) { }, "expected heartbeats") }) - t.Run("clean close calls onPeerClose with EventPeerDisconnected", func(t *testing.T) { - conn, mock, _, _ := setupReaderTestConnection(t) + t.Run("clean close calls onPeerClose with ExitCleanDisconnect", func(t *testing.T) { + mock := honeybeetest.NewMockSocket() + mock.CloseFunc = func() error { + mock.Once.Do(func() { close(mock.Closed) }) + return nil + } mock.ReadMessageFunc = func() (int, []byte, error) { return 0, nil, &websocket.CloseError{Code: websocket.CloseNormalClosure} } + conn, err := transport.NewConnectionFromSocket(mock, nil, nil) + assert.NoError(t, err) + messages := make(chan ReceivedMessage, 1) heartbeat := make(chan struct{}, 1) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - var gotKind PoolEventKind + var gotKind WorkerExitKind done := make(chan struct{}) - go RunReader(ctx, func(kind PoolEventKind) { + go RunReader(ctx, func(kind WorkerExitKind) { gotKind = kind close(done) }, conn, messages, heartbeat) @@ -94,23 +102,30 @@ func TestRunReader(t *testing.T) { } }, "expected onPeerClose") - assert.Equal(t, EventPeerDisconnected, gotKind) + assert.Equal(t, ExitCleanDisconnect, gotKind) }) - t.Run("unexpected close calls onPeerClose with EventPeerDropped", func(t *testing.T) { - conn, mock, _, _ := setupReaderTestConnection(t) + t.Run("unexpected close calls onPeerClose with ExitUnexpectedDrop", func(t *testing.T) { + mock := honeybeetest.NewMockSocket() + mock.CloseFunc = func() error { + mock.Once.Do(func() { close(mock.Closed) }) + return nil + } mock.ReadMessageFunc = func() (int, []byte, error) { return 0, nil, &websocket.CloseError{Code: websocket.CloseProtocolError} } + conn, err := transport.NewConnectionFromSocket(mock, nil, nil) + assert.NoError(t, err) + messages := make(chan ReceivedMessage, 1) heartbeat := make(chan struct{}, 1) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - var gotKind PoolEventKind + var gotKind WorkerExitKind done := make(chan struct{}) - go RunReader(ctx, func(kind PoolEventKind) { + go RunReader(ctx, func(kind WorkerExitKind) { gotKind = kind close(done) }, conn, messages, heartbeat) @@ -124,23 +139,30 @@ func TestRunReader(t *testing.T) { } }, "expected onPeerClose") - assert.Equal(t, EventPeerDropped, gotKind) + assert.Equal(t, ExitUnexpectedDrop, gotKind) }) - t.Run("read error calls onPeerClose with EventPeerDropped", func(t *testing.T) { - conn, mock, _, _ := setupReaderTestConnection(t) + t.Run("read error calls onPeerClose with ExitUnexpectedDrop", func(t *testing.T) { + mock := honeybeetest.NewMockSocket() + mock.CloseFunc = func() error { + mock.Once.Do(func() { close(mock.Closed) }) + return nil + } mock.ReadMessageFunc = func() (int, []byte, error) { return 0, nil, io.EOF } + conn, err := transport.NewConnectionFromSocket(mock, nil, nil) + assert.NoError(t, err) + messages := make(chan ReceivedMessage, 1) heartbeat := make(chan struct{}, 1) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - var gotKind PoolEventKind + var gotKind WorkerExitKind done := make(chan struct{}) - go RunReader(ctx, func(kind PoolEventKind) { + go RunReader(ctx, func(kind WorkerExitKind) { gotKind = kind close(done) }, conn, messages, heartbeat) @@ -154,11 +176,11 @@ func TestRunReader(t *testing.T) { } }, "expected onPeerClose") - assert.Equal(t, EventPeerDropped, gotKind) + assert.Equal(t, ExitUnexpectedDrop, gotKind) }) t.Run("ctx.Done exits without calling onPeerClose", func(t *testing.T) { - conn, _, _, _ := setupReaderTestConnection(t) + conn, _, _, _ := setupTestConnection(t) defer conn.Close() messages := make(chan ReceivedMessage, 1) @@ -168,7 +190,7 @@ func TestRunReader(t *testing.T) { called := atomic.Bool{} done := make(chan struct{}) go func() { - RunReader(ctx, func(PoolEventKind) { + RunReader(ctx, func(WorkerExitKind) { called.Store(true) }, conn, messages, heartbeat) close(done) diff --git a/inbound/worker_test.go b/inbound/worker_test.go new file mode 100644 index 0000000..10129b9 --- /dev/null +++ b/inbound/worker_test.go @@ -0,0 +1,228 @@ +package inbound + +import ( + "context" + "fmt" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "git.wisehodl.dev/jay/go-honeybee/transport" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "sync" + "sync/atomic" + "testing" + "time" +) + +type workerTestVars struct { + worker *DefaultWorker + conn *transport.Connection + incoming chan honeybeetest.MockIncomingData + outgoing chan honeybeetest.MockOutgoingData + pool PoolPlugin + inbox chan InboxMessage + events chan PoolEvent + exitKind *atomic.Value + wg *sync.WaitGroup +} + +func setupWorkerTest(t *testing.T) workerTestVars { + t.Helper() + + conn, _, incoming, outgoing := setupTestConnection(t) + + ctx, cancel := context.WithCancel(context.Background()) + var err error + worker, err := NewWorker(ctx, "peer-1", conn, nil) + assert.NoError(t, err) + worker.cancel = cancel + + inbox := make(chan InboxMessage, 256) + events := make(chan PoolEvent, 10) + exitKind := &atomic.Value{} + + var once sync.Once + pool := PoolPlugin{ + Inbox: inbox, + Events: events, + OnExit: func(kind WorkerExitKind) { + once.Do(func() { exitKind.Store(kind) }) + }, + } + + wg := &sync.WaitGroup{} + wg.Add(1) + + return workerTestVars{ + worker: worker, + conn: conn, + incoming: incoming, + outgoing: outgoing, + pool: pool, + inbox: inbox, + events: events, + exitKind: exitKind, + wg: wg, + } +} + +func TestWorkerStart(t *testing.T) { + t.Run("socket data arrives on inbox", func(t *testing.T) { + v := setupWorkerTest(t) + defer v.worker.Stop() + + go v.worker.Start(v.pool, v.wg) + + v.incoming <- honeybeetest.MockIncomingData{ + MsgType: websocket.TextMessage, + Data: []byte("hello"), + } + + honeybeetest.Eventually(t, func() bool { + select { + case msg := <-v.inbox: + return msg.ID == "peer-1" && string(msg.Data) == "hello" + default: + return false + } + }, "expected message on inbox") + }) + + t.Run("clean peer close calls OnExit with ExitCleanDisconnect", func(t *testing.T) { + v := setupWorkerTest(t) + defer v.worker.Stop() + + go v.worker.Start(v.pool, v.wg) + + v.incoming <- honeybeetest.MockIncomingData{ + Err: &websocket.CloseError{Code: websocket.CloseNormalClosure}, + } + + honeybeetest.Eventually(t, func() bool { + val := v.exitKind.Load() + return val != nil && val.(WorkerExitKind) == ExitCleanDisconnect + }, "expected ExitCleanDisconnect") + }) + + t.Run("unexpected peer close calls OnExit with ExitUnexpectedDrop", func(t *testing.T) { + v := setupWorkerTest(t) + defer v.worker.Stop() + + go v.worker.Start(v.pool, v.wg) + + v.incoming <- honeybeetest.MockIncomingData{ + Err: &websocket.CloseError{Code: websocket.CloseProtocolError}, + } + + honeybeetest.Eventually(t, func() bool { + val := v.exitKind.Load() + return val != nil && val.(WorkerExitKind) == ExitUnexpectedDrop + }, "expected ExitUnexpectedDrop") + }) + + t.Run("watchdog timeout calls OnExit with ExitInactive", func(t *testing.T) { + conn, _, _, _ := setupTestConnection(t) + + ctx, cancel := context.WithCancel(context.Background()) + worker, err := NewWorker(ctx, "peer-1", conn, &WorkerConfig{ + DeadTimeout: 20 * time.Millisecond, + }) + assert.NoError(t, err) + worker.cancel = cancel + defer worker.Stop() + + exitKind := &atomic.Value{} + var once sync.Once + pool := PoolPlugin{ + Inbox: make(chan InboxMessage, 256), + Events: make(chan PoolEvent, 10), + OnExit: func(kind WorkerExitKind) { + once.Do(func() { exitKind.Store(kind) }) + }, + } + + var wg sync.WaitGroup + wg.Add(1) + go worker.Start(pool, &wg) + + honeybeetest.Eventually(t, func() bool { + val := exitKind.Load() + return val != nil && val.(WorkerExitKind) == ExitInactive + }, "expected ExitInactive") + }) +} + +func TestWorkerStop(t *testing.T) { + v := setupWorkerTest(t) + + go v.worker.Start(v.pool, v.wg) + + v.worker.Stop() + + done := make(chan struct{}) + go func() { v.wg.Wait(); close(done) }() + + honeybeetest.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, "expected wg to drain") + + // does not call onExit + assert.Nil(t, v.exitKind.Load()) +} + +func TestWorkerSend(t *testing.T) { + t.Run("Send delivers data to socket", func(t *testing.T) { + v := setupWorkerTest(t) + defer v.worker.Stop() + + go v.worker.Start(v.pool, v.wg) + + err := v.worker.Send([]byte("hello")) + assert.NoError(t, err) + + honeybeetest.ExpectWrite(t, v.outgoing, websocket.TextMessage, []byte("hello")) + }) + + t.Run("Send produces heartbeats", func(t *testing.T) { + v := setupWorkerTest(t) + defer v.worker.Stop() + + count := atomic.Int32{} + go func() { + for range v.worker.heartbeat { + count.Add(1) + } + }() + + // do not start the worker, allow heartbeats to be drained manually + + for i := 0; i < 3; i++ { + err := v.worker.Send([]byte(fmt.Sprintf("msg-%d", i))) + assert.NoError(t, err) + } + + honeybeetest.Eventually(t, func() bool { + return count.Load() == 3 + }, "expected heartbeats") + }) + + t.Run("Send returns error after connection closed", func(t *testing.T) { + v := setupWorkerTest(t) + defer v.worker.Stop() + + go v.worker.Start(v.pool, v.wg) + + v.conn.Close() + + honeybeetest.Eventually(t, func() bool { + return v.conn.State() == transport.StateClosed + }, "expected connection closed") + + err := v.worker.Send([]byte("hello")) + assert.Error(t, err) + }) +} diff --git a/responderpool/worker_watchdog_test.go b/inbound/worker_watchdog_test.go similarity index 64% rename from responderpool/worker_watchdog_test.go rename to inbound/worker_watchdog_test.go index a612607..614270e 100644 --- a/responderpool/worker_watchdog_test.go +++ b/inbound/worker_watchdog_test.go @@ -1,4 +1,4 @@ -package responderpool +package inbound import ( "context" @@ -16,7 +16,7 @@ func TestRunWatchdog(t *testing.T) { defer cancel() called := atomic.Bool{} - go RunWatchdog(ctx, func() { called.Store(true) }, heartbeat, 200*time.Millisecond) + go RunWatchdog(ctx, func(WorkerExitKind) { called.Store(true) }, heartbeat, 200*time.Millisecond) for i := 0; i < 5; i++ { time.Sleep(20 * time.Millisecond) @@ -33,10 +33,12 @@ func TestRunWatchdog(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + var gotKind WorkerExitKind count := atomic.Int32{} done := make(chan struct{}) - go RunWatchdog(ctx, func() { + go RunWatchdog(ctx, func(kind WorkerExitKind) { count.Add(1) + gotKind = kind close(done) }, heartbeat, 20*time.Millisecond) @@ -50,6 +52,7 @@ func TestRunWatchdog(t *testing.T) { }, "expected onInactive") assert.Equal(t, int32(1), count.Load()) + assert.Equal(t, ExitInactive, gotKind) }) t.Run("ctx.Done exits without calling onInactive", func(t *testing.T) { @@ -59,7 +62,7 @@ func TestRunWatchdog(t *testing.T) { called := atomic.Bool{} done := make(chan struct{}) go func() { - RunWatchdog(ctx, func() { called.Store(true) }, heartbeat, 20*time.Second) + RunWatchdog(ctx, func(WorkerExitKind) { called.Store(true) }, heartbeat, 20*time.Second) close(done) }() @@ -77,14 +80,14 @@ func TestRunWatchdog(t *testing.T) { assert.False(t, called.Load()) }) - t.Run("zero timeout exits on ctx.Done without firing", func(t *testing.T) { + t.Run("zero timeout exits on ctx.Done without firing onInactive", func(t *testing.T) { heartbeat := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background()) called := atomic.Bool{} done := make(chan struct{}) go func() { - RunWatchdog(ctx, func() { called.Store(true) }, heartbeat, 0) + RunWatchdog(ctx, func(WorkerExitKind) { called.Store(true) }, heartbeat, 0) close(done) }() @@ -101,4 +104,32 @@ func TestRunWatchdog(t *testing.T) { assert.False(t, called.Load()) }) + + t.Run("disabled keepalive drains heartbeats without blocking", func(t *testing.T) { + heartbeat := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan struct{}) + go func() { + RunWatchdog(ctx, func(WorkerExitKind) {}, heartbeat, 0) + close(done) + }() + + // these must not block + for i := 0; i < 5; i++ { + heartbeat <- struct{}{} + } + + cancel() + + honeybeetest.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, "expected RunWatchdog to exit") + }) } diff --git a/initiatorpool/config.go b/outbound/config.go similarity index 99% rename from initiatorpool/config.go rename to outbound/config.go index 7816668..43cf587 100644 --- a/initiatorpool/config.go +++ b/outbound/config.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "context" diff --git a/initiatorpool/config_pool_test.go b/outbound/config_pool_test.go similarity index 99% rename from initiatorpool/config_pool_test.go rename to outbound/config_pool_test.go index 740e73f..4bbe9f1 100644 --- a/initiatorpool/config_pool_test.go +++ b/outbound/config_pool_test.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "git.wisehodl.dev/jay/go-honeybee/transport" diff --git a/initiatorpool/errors.go b/outbound/errors.go similarity index 97% rename from initiatorpool/errors.go rename to outbound/errors.go index a20fcff..9282aed 100644 --- a/initiatorpool/errors.go +++ b/outbound/errors.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import "errors" import "fmt" diff --git a/initiatorpool/helper_test.go b/outbound/helper_test.go similarity index 94% rename from initiatorpool/helper_test.go rename to outbound/helper_test.go index c359449..7fa555d 100644 --- a/initiatorpool/helper_test.go +++ b/outbound/helper_test.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "fmt" @@ -9,7 +9,7 @@ import ( "testing" ) -func setupWorkerTestConnection(t *testing.T) ( +func setupTestConnection(t *testing.T) ( conn *transport.Connection, mockSocket *honeybeetest.MockSocket, incomingData chan honeybeetest.MockIncomingData, diff --git a/initiatorpool/pool.go b/outbound/pool.go similarity index 92% rename from initiatorpool/pool.go rename to outbound/pool.go index a86d63a..c057279 100644 --- a/initiatorpool/pool.go +++ b/outbound/pool.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "context" @@ -11,26 +11,6 @@ import ( // Types -type Peer struct { - id string - worker Worker -} - -type PoolPlugin struct { - Inbox chan<- InboxMessage - Events chan<- PoolEvent - Errors chan<- error - Logger *slog.Logger - Dialer types.Dialer - ConnectionConfig *transport.ConnectionConfig -} - -type InboxMessage struct { - ID string - Data []byte - ReceivedAt time.Time -} - type PoolEventKind string const ( @@ -43,8 +23,28 @@ type PoolEvent struct { Kind PoolEventKind } +type InboxMessage struct { + ID string + Data []byte + ReceivedAt time.Time +} + +type PoolPlugin struct { + Inbox chan<- InboxMessage + Events chan<- PoolEvent + Errors chan<- error + Logger *slog.Logger + Dialer types.Dialer + ConnectionConfig *transport.ConnectionConfig +} + // Pool +type Peer struct { + id string + worker Worker +} + type Pool struct { ctx context.Context cancel context.CancelFunc @@ -85,7 +85,7 @@ func NewPool(ctx context.Context, config *PoolConfig, logger *slog.Logger, pctx, cancel := context.WithCancel(ctx) - p := &Pool{ + return &Pool{ ctx: pctx, cancel: cancel, peers: make(map[string]*Peer), @@ -95,9 +95,7 @@ func NewPool(ctx context.Context, config *PoolConfig, logger *slog.Logger, dialer: transport.NewDialer(), config: config, logger: logger, - } - - return p, nil + }, nil } func (p *Pool) Peers() []string { @@ -111,15 +109,15 @@ func (p *Pool) Peers() []string { return ids } -func (p *Pool) Inbox() chan InboxMessage { +func (p *Pool) Inbox() <-chan InboxMessage { return p.inbox } -func (p *Pool) Events() chan PoolEvent { +func (p *Pool) Events() <-chan PoolEvent { return p.events } -func (p *Pool) Errors() chan error { +func (p *Pool) Errors() <-chan error { return p.errors } @@ -165,9 +163,8 @@ func (p *Pool) Connect(id string) error { if p.closed { return NewPoolError(ErrPoolClosed) } - _, exists := p.peers[id] - if exists { + if _, exists := p.peers[id]; exists { return NewPoolError(ErrPeerExists) } @@ -181,7 +178,8 @@ func (p *Pool) Connect(id string) error { if p.logger != nil { logger = p.logger.With("id", id) } - ctx := PoolPlugin{ + + pool := PoolPlugin{ Inbox: p.inbox, Events: p.events, Errors: p.errors, @@ -191,7 +189,7 @@ func (p *Pool) Connect(id string) error { } p.wg.Add(1) - go worker.Start(ctx, &p.wg) + go worker.Start(pool, &p.wg) p.peers[id] = &Peer{id: id, worker: worker} @@ -205,18 +203,17 @@ func (p *Pool) Remove(id string) error { } p.mu.Lock() + defer p.mu.Unlock() + if p.closed { - p.mu.Unlock() return NewPoolError(ErrPoolClosed) } peer, exists := p.peers[id] if !exists { - p.mu.Unlock() return NewPoolError(ErrPeerNotFound) } delete(p.peers, id) - p.mu.Unlock() peer.worker.Stop() diff --git a/initiatorpool/pool_test.go b/outbound/pool_test.go similarity index 99% rename from initiatorpool/pool_test.go rename to outbound/pool_test.go index 206509c..92e8fc2 100644 --- a/initiatorpool/pool_test.go +++ b/outbound/pool_test.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "context" @@ -11,6 +11,8 @@ import ( "testing" ) +// Helpers + func setupPool(t *testing.T) (*Pool, *honeybeetest.MockDialer) { t.Helper() pool, err := NewPool(context.Background(), nil, nil) @@ -24,6 +26,25 @@ func setupPool(t *testing.T) (*Pool, *honeybeetest.MockDialer) { return pool, dialer } +func expectEvent( + t *testing.T, + events chan PoolEvent, + expectedURL string, + expectedKind PoolEventKind, +) { + t.Helper() + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.ID == expectedURL && e.Kind == expectedKind + default: + return false + } + }, fmt.Sprintf("expected event: URL=%q, Kind=%q", expectedURL, expectedKind)) +} + +// Tests + func TestPoolConnect(t *testing.T) { t.Run("successfully adds connection", func(t *testing.T) { pool, _ := setupPool(t) @@ -148,20 +169,3 @@ func TestPoolSend(t *testing.T) { pool.Close() } - -func expectEvent( - t *testing.T, - events chan PoolEvent, - expectedURL string, - expectedKind PoolEventKind, -) { - t.Helper() - honeybeetest.Eventually(t, func() bool { - select { - case e := <-events: - return e.ID == expectedURL && e.Kind == expectedKind - default: - return false - } - }, fmt.Sprintf("expected event: URL=%q, Kind=%q", expectedURL, expectedKind)) -} diff --git a/initiatorpool/worker.go b/outbound/worker.go similarity index 86% rename from initiatorpool/worker.go rename to outbound/worker.go index 9d6b69f..99d0e49 100644 --- a/initiatorpool/worker.go +++ b/outbound/worker.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "container/list" @@ -23,14 +23,12 @@ type ReceivedMessage struct { } type DefaultWorker struct { - Ctx context.Context - Cancel context.CancelFunc - - Id string - Config *WorkerConfig - - Conn atomic.Pointer[transport.Connection] - Heartbeat chan struct{} + id string + conn atomic.Pointer[transport.Connection] + heartbeat chan struct{} + config *WorkerConfig + ctx context.Context + cancel context.CancelFunc } func NewWorker( @@ -42,19 +40,17 @@ func NewWorker( if config == nil { config = GetDefaultWorkerConfig() } - - err := ValidateWorkerConfig(config) - if err != nil { + if err := ValidateWorkerConfig(config); err != nil { return nil, err } - pool, cancel := context.WithCancel(ctx) + wctx, wcancel := context.WithCancel(ctx) w := &DefaultWorker{ - Ctx: pool, - Cancel: cancel, - Id: id, - Config: config, - Heartbeat: make(chan struct{}), + id: id, + config: config, + heartbeat: make(chan struct{}), + ctx: wctx, + cancel: wcancel, } return w, nil @@ -74,31 +70,31 @@ func (w *DefaultWorker) Start( go func() { defer owg.Done() - RunDialer(w.Id, w.Ctx, pool, dial, newConn) + RunDialer(w.id, w.ctx, pool, dial, newConn) }() go func() { defer owg.Done() - RunKeepalive(w.Ctx, w.Heartbeat, keepalive, w.Config.KeepaliveTimeout) + RunKeepalive(w.ctx, w.heartbeat, keepalive, w.config.KeepaliveTimeout) }() go func() { defer owg.Done() - RunForwarder(w.Id, w.Ctx, messages, pool.Inbox, w.Config.MaxQueueSize) + RunForwarder(w.id, w.ctx, messages, pool.Inbox, w.config.MaxQueueSize) }() go func() { defer owg.Done() session := &Session{ - id: w.Id, - connPtr: &w.Conn, + id: w.id, + connPtr: &w.conn, messages: messages, - heartbeat: w.Heartbeat, + heartbeat: w.heartbeat, dial: dial, keepalive: keepalive, newConn: newConn, } - session.Start(w.Ctx, pool) + session.Start(w.ctx, pool) }() owg.Wait() @@ -106,25 +102,23 @@ func (w *DefaultWorker) Start( } func (w *DefaultWorker) Stop() { - w.Cancel() + w.cancel() } func (w *DefaultWorker) Send(data []byte) error { - conn := w.Conn.Load() + conn := w.conn.Load() if conn == nil { // connection not established by session - return NewWorkerError(w.Id, ErrConnectionUnavailable) + return NewWorkerError(w.id, ErrConnectionUnavailable) } - err := conn.Send(data) - - if err != nil { - return NewWorkerError(w.Id, err) + if err := conn.Send(data); err != nil { + return NewWorkerError(w.id, err) } select { - case w.Heartbeat <- struct{}{}: - case <-w.Ctx.Done(): + case w.heartbeat <- struct{}{}: + case <-w.ctx.Done(): } return nil @@ -313,11 +307,15 @@ func RunKeepalive( ) { // disable keepalive timeout if not configured if timeout <= 0 { + // drain heartbeats // wait for cancel and exit - select { - case <-ctx.Done(): + for { + select { + case <-heartbeat: + case <-ctx.Done(): + return + } } - return } timer := time.NewTimer(timeout) diff --git a/initiatorpool/worker_dialer_test.go b/outbound/worker_dialer_test.go similarity index 99% rename from initiatorpool/worker_dialer_test.go rename to outbound/worker_dialer_test.go index c8e22dc..29e901c 100644 --- a/initiatorpool/worker_dialer_test.go +++ b/outbound/worker_dialer_test.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "context" diff --git a/responderpool/worker_forwarder_test.go b/outbound/worker_forwarder_test.go similarity index 99% rename from responderpool/worker_forwarder_test.go rename to outbound/worker_forwarder_test.go index c0459f8..c16288f 100644 --- a/responderpool/worker_forwarder_test.go +++ b/outbound/worker_forwarder_test.go @@ -1,4 +1,4 @@ -package responderpool +package outbound import ( "context" diff --git a/initiatorpool/worker_keepalive_test.go b/outbound/worker_keepalive_test.go similarity index 76% rename from initiatorpool/worker_keepalive_test.go rename to outbound/worker_keepalive_test.go index f633e04..c94cabb 100644 --- a/initiatorpool/worker_keepalive_test.go +++ b/outbound/worker_keepalive_test.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "context" @@ -76,4 +76,27 @@ func TestRunKeepalive(t *testing.T) { } }, "expected done signal") }) + + t.Run("disabled keepalive drains heartbeats without blocking", func(t *testing.T) { + heartbeat := make(chan struct{}) + keepalive := make(chan struct{}, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go RunKeepalive(ctx, heartbeat, keepalive, 0) + + // these must not block + for i := 0; i < 5; i++ { + heartbeat <- struct{}{} + } + + honeybeetest.Never(t, func() bool { + select { + case <-keepalive: + return true + default: + return false + } + }, "keepalive signal should not fire when disabled") + }) } diff --git a/initiatorpool/worker_send_test.go b/outbound/worker_send_test.go similarity index 77% rename from initiatorpool/worker_send_test.go rename to outbound/worker_send_test.go index aded7da..32587ff 100644 --- a/initiatorpool/worker_send_test.go +++ b/outbound/worker_send_test.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "context" @@ -11,7 +11,7 @@ import ( func TestWorkerSend(t *testing.T) { t.Run("data sent to mock socket", func(t *testing.T) { - conn, _, _, outgoingData := setupWorkerTestConnection(t) + conn, _, _, outgoingData := setupTestConnection(t) defer conn.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -20,13 +20,13 @@ func TestWorkerSend(t *testing.T) { heartbeatCount := atomic.Int32{} w := &DefaultWorker{ - Ctx: ctx, - Cancel: cancel, - Id: "wss://test", - Heartbeat: heartbeat, + ctx: ctx, + cancel: cancel, + id: "wss://test", + heartbeat: heartbeat, } - w.Conn.Store(conn) - defer w.Cancel() + w.conn.Store(conn) + defer w.cancel() go func() { for range heartbeat { @@ -53,7 +53,7 @@ func TestWorkerSend(t *testing.T) { }) t.Run("sends one heartbeat per successful send", func(t *testing.T) { - conn, _, _, _ := setupWorkerTestConnection(t) + conn, _, _, _ := setupTestConnection(t) defer conn.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -62,13 +62,13 @@ func TestWorkerSend(t *testing.T) { heartbeatCount := atomic.Int32{} w := &DefaultWorker{ - Ctx: ctx, - Cancel: cancel, - Id: "wss://test", - Heartbeat: heartbeat, + ctx: ctx, + cancel: cancel, + id: "wss://test", + heartbeat: heartbeat, } - w.Conn.Store(conn) - defer w.Cancel() + w.conn.Store(conn) + defer w.cancel() go func() { for range heartbeat { @@ -93,12 +93,12 @@ func TestWorkerSend(t *testing.T) { heartbeat := make(chan struct{}) w := &DefaultWorker{ - Ctx: ctx, - Cancel: cancel, - Id: "wss://test", - Heartbeat: heartbeat, + ctx: ctx, + cancel: cancel, + id: "wss://test", + heartbeat: heartbeat, } - defer w.Cancel() + defer w.cancel() go func() { for range heartbeat { diff --git a/initiatorpool/worker_session_inner_test.go b/outbound/worker_session_inner_test.go similarity index 92% rename from initiatorpool/worker_session_inner_test.go rename to outbound/worker_session_inner_test.go index b4696d9..adeb160 100644 --- a/initiatorpool/worker_session_inner_test.go +++ b/outbound/worker_session_inner_test.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "context" @@ -15,7 +15,7 @@ import ( func TestRunReader(t *testing.T) { t.Run("message arrives with correct data and non-zero receivedAt", func(t *testing.T) { - conn, _, incomingData, _ := setupWorkerTestConnection(t) + conn, _, incomingData, _ := setupTestConnection(t) defer conn.Close() messages := make(chan ReceivedMessage, 1) @@ -46,7 +46,7 @@ func TestRunReader(t *testing.T) { }) t.Run("heartbeat receives one signal per message", func(t *testing.T) { - conn, _, incomingData, _ := setupWorkerTestConnection(t) + conn, _, incomingData, _ := setupTestConnection(t) defer conn.Close() messages := make(chan ReceivedMessage, 10) @@ -80,7 +80,7 @@ func TestRunReader(t *testing.T) { }) t.Run("incoming channel close calls conn.Close and onStop", func(t *testing.T) { - conn, _, incomingData, _ := setupWorkerTestConnection(t) + conn, _, incomingData, _ := setupTestConnection(t) messages := make(chan ReceivedMessage, 1) heartbeat := make(chan struct{}) @@ -118,7 +118,7 @@ func TestRunReader(t *testing.T) { }) t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { - conn, _, _, _ := setupWorkerTestConnection(t) + conn, _, _, _ := setupTestConnection(t) messages := make(chan ReceivedMessage, 1) heartbeat := make(chan struct{}) @@ -145,7 +145,7 @@ func TestRunReader(t *testing.T) { func TestRunStopMonitor(t *testing.T) { t.Run("keepalive signal calls conn.Close and cancel", func(t *testing.T) { - conn, _, _, _ := setupWorkerTestConnection(t) + conn, _, _, _ := setupTestConnection(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -170,7 +170,7 @@ func TestRunStopMonitor(t *testing.T) { }) t.Run("ctx.Done calls conn.Close and cancel", func(t *testing.T) { - conn, _, _, _ := setupWorkerTestConnection(t) + conn, _, _, _ := setupTestConnection(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/initiatorpool/worker_session_test.go b/outbound/worker_session_test.go similarity index 98% rename from initiatorpool/worker_session_test.go rename to outbound/worker_session_test.go index 209aabe..ad21468 100644 --- a/initiatorpool/worker_session_test.go +++ b/outbound/worker_session_test.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "context" @@ -45,7 +45,7 @@ func setup(t *testing.T) ( ) { t.Helper() ctx, cancel = context.WithCancel(context.Background()) - conn, mockSocket, incomingData, outgoingData := setupWorkerTestConnection(t) + conn, mockSocket, incomingData, outgoingData := setupTestConnection(t) vars = testVars{ id: "wss://test", dial: make(chan struct{}, 1), @@ -319,7 +319,7 @@ func TestRunSessionDisconnect(t *testing.T) { close(v.incomingData) drainEvent(t, events, EventDisconnected) - conn2, _, _, _ := setupWorkerTestConnection(t) + conn2, _, _, _ := setupTestConnection(t) v.newConn <- conn2 drainEvent(t, events, EventConnected) }) diff --git a/initiatorpool/worker_start_test.go b/outbound/worker_start_test.go similarity index 94% rename from initiatorpool/worker_start_test.go rename to outbound/worker_start_test.go index 6e748f1..0e86f38 100644 --- a/initiatorpool/worker_start_test.go +++ b/outbound/worker_start_test.go @@ -1,4 +1,4 @@ -package initiatorpool +package outbound import ( "context" @@ -34,11 +34,11 @@ func makeWorkerContext(t *testing.T) ( func makeWorker(t *testing.T, ctx context.Context, cancel context.CancelFunc) *DefaultWorker { t.Helper() return &DefaultWorker{ - Ctx: ctx, - Cancel: cancel, - Id: "wss://test", - Config: GetDefaultWorkerConfig(), - Heartbeat: make(chan struct{}), + ctx: ctx, + cancel: cancel, + id: "wss://test", + config: GetDefaultWorkerConfig(), + heartbeat: make(chan struct{}), } } @@ -67,7 +67,7 @@ func TestWorkerStart(t *testing.T) { honeybeetest.Eventually(t, func() bool { select { case e := <-events: - return e.ID == w.Id && e.Kind == EventConnected + return e.ID == w.id && e.Kind == EventConnected default: return false } @@ -80,7 +80,7 @@ func TestWorkerStart(t *testing.T) { w := makeWorker(t, ctx, cancel) _, events, _, pool := makeWorkerContext(t) - _, mockSocket, _, outgoingData := setupWorkerTestConnection(t) + _, mockSocket, _, outgoingData := setupTestConnection(t) pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup @@ -154,7 +154,7 @@ func TestWorkerStart(t *testing.T) { honeybeetest.Eventually(t, func() bool { select { case msg := <-inbox: - return msg.ID == w.Id && string(msg.Data) == "hello" + return msg.ID == w.id && string(msg.Data) == "hello" default: return false } @@ -167,7 +167,7 @@ func TestWorkerStart(t *testing.T) { w := makeWorker(t, ctx, cancel) _, events, _, pool := makeWorkerContext(t) - _, mockSocket, incomingData, _ := setupWorkerTestConnection(t) + _, mockSocket, incomingData, _ := setupTestConnection(t) pool.Dialer = mockDialer(mockSocket) var wg sync.WaitGroup diff --git a/responderpool/pool.go b/responderpool/pool.go deleted file mode 100644 index 72faedc..0000000 --- a/responderpool/pool.go +++ /dev/null @@ -1,24 +0,0 @@ -package responderpool - -import ( - "time" -) - -type PoolEventKind string - -const ( - EventPeerDisconnected PoolEventKind = "disconnected" - EventPeerDropped PoolEventKind = "dropped" - EventPeerEvicted PoolEventKind = "evicted" -) - -type PoolEvent struct { - ID string - Kind PoolEventKind -} - -type InboxMessage struct { - ID string - Data []byte - ReceivedAt time.Time -}