From 8c1371e3a06032b7cfddd9c7a0465ffa0132d67c Mon Sep 17 00:00:00 2001 From: Jay Date: Tue, 26 May 2026 14:59:03 -0400 Subject: [PATCH] pool: add ConnectOption/WithDialer for per-call dialer override on Connect --- pool.go | 28 ++++++++++++++++++++++++++-- pool_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/pool.go b/pool.go index ad2d8b2..2004a52 100644 --- a/pool.go +++ b/pool.go @@ -231,7 +231,22 @@ func (p *Pool) Close() { }() } -func (p *Pool) Connect(id string) error { +// ConnectOption configures a single Connect call. +type ConnectOption func(*connectOptions) + +type connectOptions struct { + dialer types.Dialer +} + +// WithDialer returns a ConnectOption that overrides the pool dialer for this +// connection only. +func WithDialer(d types.Dialer) ConnectOption { + return func(o *connectOptions) { + o.dialer = d + } +} + +func (p *Pool) Connect(id string, opts ...ConnectOption) error { if p.logger != nil { p.logger.Info("connecting", "peer", id) } @@ -258,8 +273,17 @@ func (p *Pool) Connect(id string) error { return err } + o := &connectOptions{} + for _, opt := range opts { + opt(o) + } + effectiveDialer := p.dialer + if o.dialer != nil { + effectiveDialer = o.dialer + } + cc := p.config.ConnectionConfig.Clone() - cc.Dialer = p.dialer + cc.Dialer = effectiveDialer pool := PoolPlugin{ Inbox: p.inbox, diff --git a/pool_test.go b/pool_test.go index 297b245..7e51780 100644 --- a/pool_test.go +++ b/pool_test.go @@ -90,6 +90,51 @@ func TestPoolConnect(t *testing.T) { }) } +func TestPoolConnectWithDialer(t *testing.T) { + t.Run("per-call dialer is used instead of pool dialer", func(t *testing.T) { + perCallUsed := false + perCallDialer := &honeybeetest.MockDialer{ + DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) { + perCallUsed = true + return honeybeetest.NewMockSocket(), nil, nil + }, + } + + // pool dialer should NOT be called + poolDialer := &honeybeetest.MockDialer{ + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { + t.Error("pool dialer should not be called when per-call dialer is provided") + return nil, nil, fmt.Errorf("unexpected call") + }, + } + + cc := *transport.GetDefaultConnectionConfig() + cc.Dialer = poolDialer + pool, err := NewPool(context.Background(), &PoolConfig{ + InboxBufferSize: 256, + EventsBufferSize: 10, + ConnectionConfig: cc, + WorkerConfig: *GetDefaultWorkerConfig(), + }, nil) + assert.NoError(t, err) + + err = pool.Connect("wss://test", WithDialer(perCallDialer)) + assert.NoError(t, err) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-pool.events: + return e.ID == "wss://test" && e.Kind == EventConnected + default: + return false + } + }, "expected connected event") + + assert.True(t, perCallUsed, "per-call dialer was not used") + pool.Close() + }) +} + func TestPoolClose(t *testing.T) { t.Run("channels close after pool close", func(t *testing.T) { pool, _ := NewPool(context.Background(), nil, nil)