From 031df8c98d04b0986bb9015a6dce62d10275ba13 Mon Sep 17 00:00:00 2001 From: Jay Date: Wed, 15 Apr 2026 17:03:49 -0400 Subject: [PATCH] Refactored pool into an implementation-driven module with inbound and outbound variants. --- pool.go | 223 ++++++++++++++++++++++++++++++++++----------------- pool_test.go | 37 +++++---- 2 files changed, 169 insertions(+), 91 deletions(-) diff --git a/pool.go b/pool.go index e98b898..55bdb36 100644 --- a/pool.go +++ b/pool.go @@ -7,13 +7,15 @@ import ( "time" ) -type poolConnection struct { - inner *Connection - stop chan struct{} +// Types + +type peer struct { + conn *Connection + stop chan struct{} } -type InboundMessage struct { - URL string +type InboxMessage struct { + ID string Data []byte ReceivedAt time.Time } @@ -37,48 +39,38 @@ func (s PoolEventKind) String() string { } type PoolEvent struct { - URL string + ID string Kind PoolEventKind } -type Pool struct { - mu sync.RWMutex - wg sync.WaitGroup - closed bool - connections map[string]*poolConnection - inbound chan InboundMessage - events chan PoolEvent - errors chan error - done chan struct{} - config *Config - dialer Dialer - logger *slog.Logger +// Pool Implementation + +type Pool interface { + Send(id string, data []byte) error + Inbox() <-chan InboxMessage + Events() <-chan PoolEvent + Errors() <-chan error + Close() } -func NewPool(config *Config, logger *slog.Logger) (*Pool, error) { - if config == nil { - config = GetDefaultConfig() - } +// Base Struct - if err := ValidateConfig(config); err != nil { - return nil, err - } +type pool struct { + peers map[string]*peer + inbox chan InboxMessage + events chan PoolEvent + errors chan error + done chan struct{} - pool := &Pool{ - connections: make(map[string]*poolConnection), - inbound: make(chan InboundMessage, 256), - events: make(chan PoolEvent, 10), - errors: make(chan error, 10), - done: make(chan struct{}), - config: config, - dialer: NewDialer(), - logger: logger, - } + config *Config + logger *slog.Logger - return pool, nil + mu sync.RWMutex + wg sync.WaitGroup + closed bool } -func (p *Pool) Close() { +func (p *pool) closeAll() { p.mu.Lock() if p.closed { p.mu.Unlock() @@ -88,25 +80,104 @@ func (p *Pool) Close() { p.closed = true close(p.done) - connections := p.connections - p.connections = make(map[string]*poolConnection) + peers := p.peers + p.peers = make(map[string]*peer) p.mu.Unlock() - for _, conn := range connections { - conn.inner.Close() + for _, conn := range peers { + conn.conn.Close() } go func() { p.wg.Wait() - close(p.inbound) + close(p.inbox) close(p.events) close(p.errors) }() } -func (p *Pool) Add(rawURL string) error { - url, err := NormalizeURL(rawURL) +func (p *pool) removePeer(id string) error { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return errors.NewPoolError("pool is closed") + } + + peer, exists := p.peers[id] + if !exists { + p.mu.Unlock() + return errors.NewPoolError("connection not found") + } + delete(p.peers, id) + p.mu.Unlock() + + close(peer.stop) + peer.conn.Close() + + select { + case p.events <- PoolEvent{ID: id, Kind: EventDisconnected}: + case <-p.done: + return nil + } + + return nil +} + +// Outbound Pool + +type OutboundPool struct { + *pool + dialer Dialer +} + +func NewOutboundPool(config *Config, logger *slog.Logger) (*OutboundPool, error) { + if config == nil { + config = GetDefaultConfig() + } + + if err := ValidateConfig(config); err != nil { + return nil, err + } + + p := &OutboundPool{ + pool: &pool{ + peers: make(map[string]*peer), + inbox: make(chan InboxMessage, 256), + events: make(chan PoolEvent, 10), + errors: make(chan error, 10), + done: make(chan struct{}), + config: config, + logger: logger, + }, + dialer: NewDialer(), + } + + return p, nil +} + +func (p *OutboundPool) Peers() map[string]*peer { + return p.peers +} + +func (p *OutboundPool) Inbox() chan InboxMessage { + return p.inbox +} + +func (p *OutboundPool) Events() chan PoolEvent { + return p.events +} + +func (p *OutboundPool) Errors() chan error { + return p.errors +} + +func (p *OutboundPool) Close() { + p.closeAll() +} + +func (p *OutboundPool) Connect(url string) error { + url, err := NormalizeURL(url) if err != nil { return err } @@ -117,7 +188,7 @@ func (p *Pool) Add(rawURL string) error { p.mu.Unlock() return errors.NewPoolError("pool is closed") } - _, exists := p.connections[url] + _, exists := p.peers[url] p.mu.Unlock() if exists { @@ -150,20 +221,20 @@ func (p *Pool) Add(rawURL string) error { // Add connection to pool stop := make(chan struct{}) - if _, exists := p.connections[url]; exists { + if _, exists := p.peers[url]; exists { // Another process connected to this url while this one was connecting // Discard this connection and retain the existing one p.mu.Unlock() conn.Close() return errors.NewPoolError("connection already exists") } - p.connections[url] = &poolConnection{inner: conn, stop: stop} + p.peers[url] = &peer{conn: conn, stop: stop} p.mu.Unlock() // TODO: start this connection's incoming message forwarder select { - case p.events <- PoolEvent{URL: url, Kind: EventConnected}: + case p.events <- PoolEvent{ID: url, Kind: EventConnected}: case <-p.done: return nil } @@ -171,34 +242,38 @@ func (p *Pool) Add(rawURL string) error { return nil } -func (p *Pool) Remove(rawURL string) error { - url, err := NormalizeURL(rawURL) +func (p *OutboundPool) Remove(url string) error { + url, err := NormalizeURL(url) if err != nil { return err } - p.mu.Lock() - if p.closed { - p.mu.Unlock() - return errors.NewPoolError("pool is closed") - } - - conn, exists := p.connections[url] - if !exists { - p.mu.Unlock() - return errors.NewPoolError("connection not found") - } - delete(p.connections, url) - p.mu.Unlock() - - close(conn.stop) - conn.inner.Close() - - select { - case p.events <- PoolEvent{URL: url, Kind: EventDisconnected}: - case <-p.done: - return nil - } - - return nil + return p.removePeer(url) +} + +// Inbound Pool + +type InboundPool struct { + *pool + idGenerator func() string +} + +func (p *InboundPool) Peers() map[string]*peer { + return p.peers +} + +func (p *InboundPool) Inbox() chan InboxMessage { + return p.inbox +} + +func (p *InboundPool) Events() chan PoolEvent { + return p.events +} + +func (p *InboundPool) Errors() chan error { + return p.errors +} + +func (p *InboundPool) Close() { + p.closeAll() } diff --git a/pool_test.go b/pool_test.go index 9893da2..7289df3 100644 --- a/pool_test.go +++ b/pool_test.go @@ -8,7 +8,7 @@ import ( "time" ) -func TestPoolAdd(t *testing.T) { +func TestPoolConnect(t *testing.T) { t.Run("successfully adds connection", func(t *testing.T) { mockSocket := NewMockSocket() mockDialer := &MockDialer{ @@ -17,23 +17,26 @@ func TestPoolAdd(t *testing.T) { }, } - pool, err := NewPool(nil, nil) + pool, err := NewOutboundPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer - err = pool.Add("wss://test") + err = pool.Connect("wss://test") assert.NoError(t, err) assert.Eventually(t, func() bool { select { case event := <-pool.events: - return event.URL == "wss://test" && event.Kind == EventConnected + return event.ID == "wss://test" && event.Kind == EventConnected default: return false } }, testTimeout, testTick) + _, exists := pool.peers["wss://test"] + assert.True(t, exists) + pool.Close() }) @@ -45,27 +48,27 @@ func TestPoolAdd(t *testing.T) { }, } - pool, err := NewPool(nil, nil) + pool, err := NewOutboundPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer - err = pool.Add("wss://test") + err = pool.Connect("wss://test") assert.NoError(t, err) // trailing slash normalizes to same key - err = pool.Add("wss://test/") + err = pool.Connect("wss://test/") assert.Error(t, err) assert.ErrorContains(t, err, "already exists") pool.mu.RLock() - assert.Len(t, pool.connections, 1) + assert.Len(t, pool.peers, 1) pool.mu.RUnlock() pool.Close() }) t.Run("fails to add connection", func(t *testing.T) { - pool, err := NewPool(&Config{ + pool, err := NewOutboundPool(&Config{ Retry: &RetryConfig{ MaxRetries: 1, InitialDelay: 1 * time.Millisecond, @@ -79,11 +82,11 @@ func TestPoolAdd(t *testing.T) { }, } - err = pool.Add("wss://test") + err = pool.Connect("wss://test") assert.Error(t, err) pool.mu.RLock() - assert.Len(t, pool.connections, 0) + assert.Len(t, pool.peers, 0) pool.mu.RUnlock() select { @@ -105,11 +108,11 @@ func TestPoolRemove(t *testing.T) { }, } - pool, err := NewPool(nil, nil) + pool, err := NewOutboundPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer - pool.Add("wss://test") + pool.Connect("wss://test") expectEvent(t, pool.events, "wss://test", EventConnected) err = pool.Remove("wss://test/") @@ -121,7 +124,7 @@ func TestPoolRemove(t *testing.T) { // connection no longer in pool pool.mu.Lock() defer pool.mu.Unlock() - _, ok := pool.connections["wss://peer2"] + _, ok := pool.peers["wss://peer2"] assert.False(t, ok, "connection is still in pool") }) @@ -133,7 +136,7 @@ func TestPoolRemove(t *testing.T) { }, } - pool, err := NewPool(nil, nil) + pool, err := NewOutboundPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -150,7 +153,7 @@ func TestPoolRemove(t *testing.T) { }, } - pool, err := NewPool(nil, nil) + pool, err := NewOutboundPool(nil, nil) assert.NoError(t, err) pool.dialer = mockDialer @@ -174,7 +177,7 @@ func expectEvent( assert.Eventually(t, func() bool { select { case e := <-events: - return e.URL == expectedURL && e.Kind == expectedKind + return e.ID == expectedURL && e.Kind == expectedKind default: return false }