Refactored pool into an implementation-driven module with inbound and outbound variants.

This commit is contained in:
Jay
2026-04-15 17:03:49 -04:00
parent 8402cbfa8d
commit 031df8c98d
2 changed files with 169 additions and 91 deletions

223
pool.go
View File

@@ -7,13 +7,15 @@ import (
"time" "time"
) )
type poolConnection struct { // Types
inner *Connection
stop chan struct{} type peer struct {
conn *Connection
stop chan struct{}
} }
type InboundMessage struct { type InboxMessage struct {
URL string ID string
Data []byte Data []byte
ReceivedAt time.Time ReceivedAt time.Time
} }
@@ -37,48 +39,38 @@ func (s PoolEventKind) String() string {
} }
type PoolEvent struct { type PoolEvent struct {
URL string ID string
Kind PoolEventKind Kind PoolEventKind
} }
type Pool struct { // Pool Implementation
mu sync.RWMutex
wg sync.WaitGroup type Pool interface {
closed bool Send(id string, data []byte) error
connections map[string]*poolConnection Inbox() <-chan InboxMessage
inbound chan InboundMessage Events() <-chan PoolEvent
events chan PoolEvent Errors() <-chan error
errors chan error Close()
done chan struct{}
config *Config
dialer Dialer
logger *slog.Logger
} }
func NewPool(config *Config, logger *slog.Logger) (*Pool, error) { // Base Struct
if config == nil {
config = GetDefaultConfig()
}
if err := ValidateConfig(config); err != nil { type pool struct {
return nil, err peers map[string]*peer
} inbox chan InboxMessage
events chan PoolEvent
errors chan error
done chan struct{}
pool := &Pool{ config *Config
connections: make(map[string]*poolConnection), logger *slog.Logger
inbound: make(chan InboundMessage, 256),
events: make(chan PoolEvent, 10),
errors: make(chan error, 10),
done: make(chan struct{}),
config: config,
dialer: NewDialer(),
logger: logger,
}
return pool, nil mu sync.RWMutex
wg sync.WaitGroup
closed bool
} }
func (p *Pool) Close() { func (p *pool) closeAll() {
p.mu.Lock() p.mu.Lock()
if p.closed { if p.closed {
p.mu.Unlock() p.mu.Unlock()
@@ -88,25 +80,104 @@ func (p *Pool) Close() {
p.closed = true p.closed = true
close(p.done) close(p.done)
connections := p.connections peers := p.peers
p.connections = make(map[string]*poolConnection) p.peers = make(map[string]*peer)
p.mu.Unlock() p.mu.Unlock()
for _, conn := range connections { for _, conn := range peers {
conn.inner.Close() conn.conn.Close()
} }
go func() { go func() {
p.wg.Wait() p.wg.Wait()
close(p.inbound) close(p.inbox)
close(p.events) close(p.events)
close(p.errors) close(p.errors)
}() }()
} }
func (p *Pool) Add(rawURL string) error { func (p *pool) removePeer(id string) error {
url, err := NormalizeURL(rawURL) p.mu.Lock()
if p.closed {
p.mu.Unlock()
return errors.NewPoolError("pool is closed")
}
peer, exists := p.peers[id]
if !exists {
p.mu.Unlock()
return errors.NewPoolError("connection not found")
}
delete(p.peers, id)
p.mu.Unlock()
close(peer.stop)
peer.conn.Close()
select {
case p.events <- PoolEvent{ID: id, Kind: EventDisconnected}:
case <-p.done:
return nil
}
return nil
}
// Outbound Pool
type OutboundPool struct {
*pool
dialer Dialer
}
func NewOutboundPool(config *Config, logger *slog.Logger) (*OutboundPool, error) {
if config == nil {
config = GetDefaultConfig()
}
if err := ValidateConfig(config); err != nil {
return nil, err
}
p := &OutboundPool{
pool: &pool{
peers: make(map[string]*peer),
inbox: make(chan InboxMessage, 256),
events: make(chan PoolEvent, 10),
errors: make(chan error, 10),
done: make(chan struct{}),
config: config,
logger: logger,
},
dialer: NewDialer(),
}
return p, nil
}
func (p *OutboundPool) Peers() map[string]*peer {
return p.peers
}
func (p *OutboundPool) Inbox() chan InboxMessage {
return p.inbox
}
func (p *OutboundPool) Events() chan PoolEvent {
return p.events
}
func (p *OutboundPool) Errors() chan error {
return p.errors
}
func (p *OutboundPool) Close() {
p.closeAll()
}
func (p *OutboundPool) Connect(url string) error {
url, err := NormalizeURL(url)
if err != nil { if err != nil {
return err return err
} }
@@ -117,7 +188,7 @@ func (p *Pool) Add(rawURL string) error {
p.mu.Unlock() p.mu.Unlock()
return errors.NewPoolError("pool is closed") return errors.NewPoolError("pool is closed")
} }
_, exists := p.connections[url] _, exists := p.peers[url]
p.mu.Unlock() p.mu.Unlock()
if exists { if exists {
@@ -150,20 +221,20 @@ func (p *Pool) Add(rawURL string) error {
// Add connection to pool // Add connection to pool
stop := make(chan struct{}) stop := make(chan struct{})
if _, exists := p.connections[url]; exists { if _, exists := p.peers[url]; exists {
// Another process connected to this url while this one was connecting // Another process connected to this url while this one was connecting
// Discard this connection and retain the existing one // Discard this connection and retain the existing one
p.mu.Unlock() p.mu.Unlock()
conn.Close() conn.Close()
return errors.NewPoolError("connection already exists") return errors.NewPoolError("connection already exists")
} }
p.connections[url] = &poolConnection{inner: conn, stop: stop} p.peers[url] = &peer{conn: conn, stop: stop}
p.mu.Unlock() p.mu.Unlock()
// TODO: start this connection's incoming message forwarder // TODO: start this connection's incoming message forwarder
select { select {
case p.events <- PoolEvent{URL: url, Kind: EventConnected}: case p.events <- PoolEvent{ID: url, Kind: EventConnected}:
case <-p.done: case <-p.done:
return nil return nil
} }
@@ -171,34 +242,38 @@ func (p *Pool) Add(rawURL string) error {
return nil return nil
} }
func (p *Pool) Remove(rawURL string) error { func (p *OutboundPool) Remove(url string) error {
url, err := NormalizeURL(rawURL) url, err := NormalizeURL(url)
if err != nil { if err != nil {
return err return err
} }
p.mu.Lock() return p.removePeer(url)
if p.closed { }
p.mu.Unlock()
return errors.NewPoolError("pool is closed") // Inbound Pool
}
type InboundPool struct {
conn, exists := p.connections[url] *pool
if !exists { idGenerator func() string
p.mu.Unlock() }
return errors.NewPoolError("connection not found")
} func (p *InboundPool) Peers() map[string]*peer {
delete(p.connections, url) return p.peers
p.mu.Unlock() }
close(conn.stop) func (p *InboundPool) Inbox() chan InboxMessage {
conn.inner.Close() return p.inbox
}
select {
case p.events <- PoolEvent{URL: url, Kind: EventDisconnected}: func (p *InboundPool) Events() chan PoolEvent {
case <-p.done: return p.events
return nil }
}
func (p *InboundPool) Errors() chan error {
return nil return p.errors
}
func (p *InboundPool) Close() {
p.closeAll()
} }

View File

@@ -8,7 +8,7 @@ import (
"time" "time"
) )
func TestPoolAdd(t *testing.T) { func TestPoolConnect(t *testing.T) {
t.Run("successfully adds connection", func(t *testing.T) { t.Run("successfully adds connection", func(t *testing.T) {
mockSocket := NewMockSocket() mockSocket := NewMockSocket()
mockDialer := &MockDialer{ mockDialer := &MockDialer{
@@ -17,23 +17,26 @@ func TestPoolAdd(t *testing.T) {
}, },
} }
pool, err := NewPool(nil, nil) pool, err := NewOutboundPool(nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
pool.dialer = mockDialer pool.dialer = mockDialer
err = pool.Add("wss://test") err = pool.Connect("wss://test")
assert.NoError(t, err) assert.NoError(t, err)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
select { select {
case event := <-pool.events: case event := <-pool.events:
return event.URL == "wss://test" && event.Kind == EventConnected return event.ID == "wss://test" && event.Kind == EventConnected
default: default:
return false return false
} }
}, testTimeout, testTick) }, testTimeout, testTick)
_, exists := pool.peers["wss://test"]
assert.True(t, exists)
pool.Close() pool.Close()
}) })
@@ -45,27 +48,27 @@ func TestPoolAdd(t *testing.T) {
}, },
} }
pool, err := NewPool(nil, nil) pool, err := NewOutboundPool(nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
pool.dialer = mockDialer pool.dialer = mockDialer
err = pool.Add("wss://test") err = pool.Connect("wss://test")
assert.NoError(t, err) assert.NoError(t, err)
// trailing slash normalizes to same key // trailing slash normalizes to same key
err = pool.Add("wss://test/") err = pool.Connect("wss://test/")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorContains(t, err, "already exists") assert.ErrorContains(t, err, "already exists")
pool.mu.RLock() pool.mu.RLock()
assert.Len(t, pool.connections, 1) assert.Len(t, pool.peers, 1)
pool.mu.RUnlock() pool.mu.RUnlock()
pool.Close() pool.Close()
}) })
t.Run("fails to add connection", func(t *testing.T) { t.Run("fails to add connection", func(t *testing.T) {
pool, err := NewPool(&Config{ pool, err := NewOutboundPool(&Config{
Retry: &RetryConfig{ Retry: &RetryConfig{
MaxRetries: 1, MaxRetries: 1,
InitialDelay: 1 * time.Millisecond, InitialDelay: 1 * time.Millisecond,
@@ -79,11 +82,11 @@ func TestPoolAdd(t *testing.T) {
}, },
} }
err = pool.Add("wss://test") err = pool.Connect("wss://test")
assert.Error(t, err) assert.Error(t, err)
pool.mu.RLock() pool.mu.RLock()
assert.Len(t, pool.connections, 0) assert.Len(t, pool.peers, 0)
pool.mu.RUnlock() pool.mu.RUnlock()
select { select {
@@ -105,11 +108,11 @@ func TestPoolRemove(t *testing.T) {
}, },
} }
pool, err := NewPool(nil, nil) pool, err := NewOutboundPool(nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
pool.dialer = mockDialer pool.dialer = mockDialer
pool.Add("wss://test") pool.Connect("wss://test")
expectEvent(t, pool.events, "wss://test", EventConnected) expectEvent(t, pool.events, "wss://test", EventConnected)
err = pool.Remove("wss://test/") err = pool.Remove("wss://test/")
@@ -121,7 +124,7 @@ func TestPoolRemove(t *testing.T) {
// connection no longer in pool // connection no longer in pool
pool.mu.Lock() pool.mu.Lock()
defer pool.mu.Unlock() defer pool.mu.Unlock()
_, ok := pool.connections["wss://peer2"] _, ok := pool.peers["wss://peer2"]
assert.False(t, ok, "connection is still in pool") assert.False(t, ok, "connection is still in pool")
}) })
@@ -133,7 +136,7 @@ func TestPoolRemove(t *testing.T) {
}, },
} }
pool, err := NewPool(nil, nil) pool, err := NewOutboundPool(nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
pool.dialer = mockDialer pool.dialer = mockDialer
@@ -150,7 +153,7 @@ func TestPoolRemove(t *testing.T) {
}, },
} }
pool, err := NewPool(nil, nil) pool, err := NewOutboundPool(nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
pool.dialer = mockDialer pool.dialer = mockDialer
@@ -174,7 +177,7 @@ func expectEvent(
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
select { select {
case e := <-events: case e := <-events:
return e.URL == expectedURL && e.Kind == expectedKind return e.ID == expectedURL && e.Kind == expectedKind
default: default:
return false return false
} }