From fdae43e715ff599a9f9553f329e310b495671a9b Mon Sep 17 00:00:00 2001 From: Jay Date: Tue, 14 Apr 2026 22:12:39 -0400 Subject: [PATCH] Started connection pool. Wrote Close and Add functions. --- errors/errors.go | 4 ++ pool.go | 168 +++++++++++++++++++++++++++++++++++++++++++++++ pool_test.go | 96 +++++++++++++++++++++++++++ 3 files changed, 268 insertions(+) create mode 100644 pool.go create mode 100644 pool_test.go diff --git a/errors/errors.go b/errors/errors.go index 564c8fc..0ba964c 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -23,3 +23,7 @@ func NewConfigError(text string) error { func NewConnectionError(text string) error { return fmt.Errorf("connection error: %s", text) } + +func NewPoolError(text string) error { + return fmt.Errorf("pool error: %s", text) +} diff --git a/pool.go b/pool.go new file mode 100644 index 0000000..0143d90 --- /dev/null +++ b/pool.go @@ -0,0 +1,168 @@ +package honeybee + +import ( + "git.wisehodl.dev/jay/go-honeybee/errors" + "log/slog" + "sync" + "time" +) + +type poolConnection struct { + inner *Connection + stop chan struct{} +} + +type InboundMessage struct { + URL string + Data []byte + ReceivedAt time.Time +} + +type PoolEventKind int + +const ( + EventConnected PoolEventKind = iota + EventDisconnected +) + +func (s PoolEventKind) String() string { + switch s { + case EventConnected: + return "connected" + case EventDisconnected: + return "disconnected" + default: + return "unknown" + } +} + +type PoolEvent struct { + URL string + Kind PoolEventKind +} + +type Pool struct { + mu sync.RWMutex + wg sync.WaitGroup + closed bool + connections map[string]*poolConnection + inbound chan InboundMessage + events chan PoolEvent + errors chan error + done chan struct{} + config *Config + dialer Dialer + logger *slog.Logger +} + +func NewPool(config *Config, logger *slog.Logger) (*Pool, error) { + if config == nil { + config = GetDefaultConfig() + } + + if err := ValidateConfig(config); err != nil { + return nil, err + } + + pool := &Pool{ + connections: make(map[string]*poolConnection), + inbound: make(chan InboundMessage, 256), + events: make(chan PoolEvent, 10), + errors: make(chan error, 10), + done: make(chan struct{}), + config: config, + dialer: NewDialer(), + logger: logger, + } + + return pool, nil +} + +func (p *Pool) Close() { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return + } + + p.closed = true + close(p.done) + + connections := p.connections + p.connections = make(map[string]*poolConnection) + + p.mu.Unlock() + + for _, conn := range connections { + conn.inner.Close() + } + + go func() { + p.wg.Wait() + close(p.inbound) + close(p.events) + close(p.errors) + }() +} + +func (p *Pool) Add(rawURL string) error { + url, err := NormalizeURL(rawURL) + if err != nil { + return err + } + + // Check for existing connection in pool + p.mu.Lock() + _, exists := p.connections[url] + p.mu.Unlock() + + if exists { + return errors.NewPoolError("connection already exists") + } + + // Create new connection + var logger *slog.Logger + if p.logger != nil { + logger = p.logger.With("url", url) + } + conn, err := NewConnection(url, p.config, logger) + if err != nil { + return err + } + conn.dialer = p.dialer + + // Attempt to connect + if err := conn.Connect(); err != nil { + return err + } + + p.mu.Lock() + if p.closed { + // The pool closed while this connection was established. + p.mu.Unlock() + conn.Close() + return errors.NewPoolError("pool is closed") + } + + // Add connection to pool + stop := make(chan struct{}) + if _, exists := p.connections[url]; exists { + // Another process connected to this url while this one was connecting + // Discard this connection and retain the existing one + p.mu.Unlock() + conn.Close() + return errors.NewPoolError("connection already exists") + } + p.connections[url] = &poolConnection{inner: conn, stop: stop} + p.mu.Unlock() + + // TODO: start this connection's incoming message forwarder + + select { + case p.events <- PoolEvent{URL: url, Kind: EventConnected}: + case <-p.done: + return nil + } + + return nil +} diff --git a/pool_test.go b/pool_test.go new file mode 100644 index 0000000..e39c6c2 --- /dev/null +++ b/pool_test.go @@ -0,0 +1,96 @@ +package honeybee + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "net/http" + "testing" + "time" +) + +func TestPoolAdd(t *testing.T) { + t.Run("successfully adds connection", func(t *testing.T) { + mockSocket := NewMockSocket() + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + return mockSocket, nil, nil + }, + } + + pool, err := NewPool(nil, nil) + assert.NoError(t, err) + + pool.dialer = mockDialer + + err = pool.Add("wss://test") + assert.NoError(t, err) + + select { + case event := <-pool.events: + assert.Equal(t, "wss://test", event.URL) + assert.Equal(t, EventConnected, event.Kind) + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for Connected event") + } + + pool.Close() + }) + + t.Run("does not add duplicate", func(t *testing.T) { + mockSocket := NewMockSocket() + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + return mockSocket, nil, nil + }, + } + + pool, err := NewPool(nil, nil) + assert.NoError(t, err) + pool.dialer = mockDialer + + err = pool.Add("wss://test") + assert.NoError(t, err) + + // trailing slash normalizes to same key + err = pool.Add("wss://test/") + assert.Error(t, err) + assert.ErrorContains(t, err, "already exists") + + pool.mu.RLock() + assert.Len(t, pool.connections, 1) + pool.mu.RUnlock() + + pool.Close() + }) + + t.Run("fails to add connection", func(t *testing.T) { + pool, err := NewPool(&Config{ + Retry: &RetryConfig{ + MaxRetries: 1, + InitialDelay: 1 * time.Millisecond, + MaxDelay: 5 * time.Millisecond, + }, + }, nil) + assert.NoError(t, err) + pool.dialer = &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + return nil, nil, fmt.Errorf("dial failed") + }, + } + + err = pool.Add("wss://test") + assert.Error(t, err) + + pool.mu.RLock() + assert.Len(t, pool.connections, 0) + pool.mu.RUnlock() + + select { + case event := <-pool.events: + t.Fatalf("unexpected event: %+v", event) + default: + } + + pool.Close() + }) +}