pool: add ConnectOption/WithDialer for per-call dialer override on Connect
This commit is contained in:
@@ -231,7 +231,22 @@ func (p *Pool) Close() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pool) Connect(id string) error {
|
// ConnectOption configures a single Connect call.
|
||||||
|
type ConnectOption func(*connectOptions)
|
||||||
|
|
||||||
|
type connectOptions struct {
|
||||||
|
dialer types.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDialer returns a ConnectOption that overrides the pool dialer for this
|
||||||
|
// connection only.
|
||||||
|
func WithDialer(d types.Dialer) ConnectOption {
|
||||||
|
return func(o *connectOptions) {
|
||||||
|
o.dialer = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) Connect(id string, opts ...ConnectOption) error {
|
||||||
if p.logger != nil {
|
if p.logger != nil {
|
||||||
p.logger.Info("connecting", "peer", id)
|
p.logger.Info("connecting", "peer", id)
|
||||||
}
|
}
|
||||||
@@ -258,8 +273,17 @@ func (p *Pool) Connect(id string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
o := &connectOptions{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
effectiveDialer := p.dialer
|
||||||
|
if o.dialer != nil {
|
||||||
|
effectiveDialer = o.dialer
|
||||||
|
}
|
||||||
|
|
||||||
cc := p.config.ConnectionConfig.Clone()
|
cc := p.config.ConnectionConfig.Clone()
|
||||||
cc.Dialer = p.dialer
|
cc.Dialer = effectiveDialer
|
||||||
|
|
||||||
pool := PoolPlugin{
|
pool := PoolPlugin{
|
||||||
Inbox: p.inbox,
|
Inbox: p.inbox,
|
||||||
|
|||||||
@@ -90,6 +90,51 @@ func TestPoolConnect(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPoolConnectWithDialer(t *testing.T) {
|
||||||
|
t.Run("per-call dialer is used instead of pool dialer", func(t *testing.T) {
|
||||||
|
perCallUsed := false
|
||||||
|
perCallDialer := &honeybeetest.MockDialer{
|
||||||
|
DialContextFunc: func(ctx context.Context, url string, h http.Header) (types.Socket, *http.Response, error) {
|
||||||
|
perCallUsed = true
|
||||||
|
return honeybeetest.NewMockSocket(), nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// pool dialer should NOT be called
|
||||||
|
poolDialer := &honeybeetest.MockDialer{
|
||||||
|
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
|
||||||
|
t.Error("pool dialer should not be called when per-call dialer is provided")
|
||||||
|
return nil, nil, fmt.Errorf("unexpected call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cc := *transport.GetDefaultConnectionConfig()
|
||||||
|
cc.Dialer = poolDialer
|
||||||
|
pool, err := NewPool(context.Background(), &PoolConfig{
|
||||||
|
InboxBufferSize: 256,
|
||||||
|
EventsBufferSize: 10,
|
||||||
|
ConnectionConfig: cc,
|
||||||
|
WorkerConfig: *GetDefaultWorkerConfig(),
|
||||||
|
}, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = pool.Connect("wss://test", WithDialer(perCallDialer))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
honeybeetest.Eventually(t, func() bool {
|
||||||
|
select {
|
||||||
|
case e := <-pool.events:
|
||||||
|
return e.ID == "wss://test" && e.Kind == EventConnected
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}, "expected connected event")
|
||||||
|
|
||||||
|
assert.True(t, perCallUsed, "per-call dialer was not used")
|
||||||
|
pool.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestPoolClose(t *testing.T) {
|
func TestPoolClose(t *testing.T) {
|
||||||
t.Run("channels close after pool close", func(t *testing.T) {
|
t.Run("channels close after pool close", func(t *testing.T) {
|
||||||
pool, _ := NewPool(context.Background(), nil, nil)
|
pool, _ := NewPool(context.Background(), nil, nil)
|
||||||
|
|||||||
Reference in New Issue
Block a user