pool: add ConnectOption/WithDialer for per-call dialer override on Connect

This commit is contained in:
Jay
2026-05-26 14:59:03 -04:00
parent d4da16f82a
commit 8c1371e3a0
2 changed files with 71 additions and 2 deletions
+26 -2
View File
@@ -231,7 +231,22 @@ func (p *Pool) Close() {
}() }()
} }
func (p *Pool) Connect(id string) error { // ConnectOption configures a single Connect call.
type ConnectOption func(*connectOptions)
type connectOptions struct {
dialer types.Dialer
}
// WithDialer returns a ConnectOption that overrides the pool dialer for this
// connection only.
func WithDialer(d types.Dialer) ConnectOption {
return func(o *connectOptions) {
o.dialer = d
}
}
func (p *Pool) Connect(id string, opts ...ConnectOption) error {
if p.logger != nil { if p.logger != nil {
p.logger.Info("connecting", "peer", id) p.logger.Info("connecting", "peer", id)
} }
@@ -258,8 +273,17 @@ func (p *Pool) Connect(id string) error {
return err return err
} }
o := &connectOptions{}
for _, opt := range opts {
opt(o)
}
effectiveDialer := p.dialer
if o.dialer != nil {
effectiveDialer = o.dialer
}
cc := p.config.ConnectionConfig.Clone() cc := p.config.ConnectionConfig.Clone()
cc.Dialer = p.dialer cc.Dialer = effectiveDialer
pool := PoolPlugin{ pool := PoolPlugin{
Inbox: p.inbox, Inbox: p.inbox,
+45
View File
@@ -90,6 +90,51 @@ func TestPoolConnect(t *testing.T) {
}) })
} }
func TestPoolConnectWithDialer(t *testing.T) {
t.Run("per-call dialer is used instead of pool dialer", func(t *testing.T) {
perCallUsed := false
perCallDialer := &honeybeetest.MockDialer{
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
perCallUsed = true
return honeybeetest.NewMockSocket(), nil, nil
},
}
// pool dialer should NOT be called
poolDialer := &honeybeetest.MockDialer{
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
t.Error("pool dialer should not be called when per-call dialer is provided")
return nil, nil, fmt.Errorf("unexpected call")
},
}
cc := *transport.GetDefaultConnectionConfig()
cc.Dialer = poolDialer
pool, err := NewPool(context.Background(), &PoolConfig{
InboxBufferSize: 256,
EventsBufferSize: 10,
ConnectionConfig: cc,
WorkerConfig: *GetDefaultWorkerConfig(),
}, nil)
assert.NoError(t, err)
err = pool.Connect("wss://test", WithDialer(perCallDialer))
assert.NoError(t, err)
honeybeetest.Eventually(t, func() bool {
select {
case e := <-pool.events:
return e.ID == "wss://test" && e.Kind == EventConnected
default:
return false
}
}, "expected connected event")
assert.True(t, perCallUsed, "per-call dialer was not used")
pool.Close()
})
}
func TestPoolClose(t *testing.T) { func TestPoolClose(t *testing.T) {
t.Run("channels close after pool close", func(t *testing.T) { t.Run("channels close after pool close", func(t *testing.T) {
pool, _ := NewPool(context.Background(), nil, nil) pool, _ := NewPool(context.Background(), nil, nil)