Injected context cancellation for dial and retry cancellation.

This commit is contained in:
Jay
2026-04-18 17:11:22 -04:00
parent e1cdc1cf9c
commit b4c5c897e8
12 changed files with 182 additions and 230 deletions

View File

@@ -1,6 +1,7 @@
package initiatorpool
import (
"context"
"fmt"
"git.wisehodl.dev/jay/go-honeybee/honeybeetest"
"git.wisehodl.dev/jay/go-honeybee/transport"
@@ -18,11 +19,11 @@ func TestRunForwarder(t *testing.T) {
t.Run("message passes through to inbox", func(t *testing.T) {
messages := make(chan receivedMessage, 1)
inbox := make(chan InboxMessage, 1)
stop := make(chan struct{})
defer close(stop)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
w := &Worker{id: "wss://test"}
go w.runForwarder(messages, inbox, stop, nil, 0)
go w.runForwarder(ctx, messages, inbox, 0)
messages <- receivedMessage{data: []byte("hello"), receivedAt: time.Now()}
@@ -39,8 +40,8 @@ func TestRunForwarder(t *testing.T) {
t.Run("oldest message dropped when queue is full", func(t *testing.T) {
messages := make(chan receivedMessage, 1)
inbox := make(chan InboxMessage, 1)
stop := make(chan struct{})
defer close(stop)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
gate := make(chan struct{})
gatedInbox := make(chan InboxMessage)
@@ -54,7 +55,7 @@ func TestRunForwarder(t *testing.T) {
}()
w := &Worker{id: "wss://test"}
go w.runForwarder(messages, gatedInbox, stop, nil, 2)
go w.runForwarder(ctx, messages, gatedInbox, 2)
// send three messages while the gated inbox is blocked
messages <- receivedMessage{data: []byte("first"), receivedAt: time.Now()}
@@ -83,42 +84,20 @@ func TestRunForwarder(t *testing.T) {
})
t.Run("exits on stop", func(t *testing.T) {
t.Run("exits on context cancellation", func(t *testing.T) {
messages := make(chan receivedMessage, 1)
inbox := make(chan InboxMessage, 1)
stop := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
w := &Worker{id: "wss://test"}
done := make(chan struct{})
go func() {
w.runForwarder(messages, inbox, stop, nil, 0)
w.runForwarder(ctx, messages, inbox, 0)
close(done)
}()
close(stop)
assert.Eventually(t, func() bool {
select {
case <-done:
return true
default:
return false
}
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
})
t.Run("exits on pool done", func(t *testing.T) {
messages := make(chan receivedMessage, 1)
inbox := make(chan InboxMessage, 1)
poolDone := make(chan struct{})
w := &Worker{id: "wss://test"}
done := make(chan struct{})
go func() {
w.runForwarder(messages, inbox, nil, poolDone, 0)
close(done)
}()
close(poolDone)
cancel()
assert.Eventually(t, func() bool {
select {
case <-done:
@@ -134,11 +113,11 @@ func TestRunKeepalive(t *testing.T) {
t.Run("heartbeat resets timer, no keepalive signal fired", func(t *testing.T) {
heartbeat := make(chan struct{}, 3)
keepalive := make(chan struct{}, 1)
stop := make(chan struct{})
defer close(stop)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 100 * time.Millisecond}}
go w.runKeepalive(heartbeat, keepalive, stop, nil)
go w.runKeepalive(ctx, heartbeat, keepalive)
// send heartbeats faster than the timeout
for i := 0; i < 5; i++ {
@@ -160,11 +139,11 @@ func TestRunKeepalive(t *testing.T) {
t.Run("keepalive timeout fires signal", func(t *testing.T) {
heartbeat := make(chan struct{})
keepalive := make(chan struct{}, 1)
stop := make(chan struct{})
defer close(stop)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 20 * time.Millisecond}}
go w.runKeepalive(heartbeat, keepalive, stop, nil)
go w.runKeepalive(ctx, heartbeat, keepalive)
// send no heartbeats, wait for timeout and keepalive signal
assert.Eventually(t, func() bool {
@@ -177,42 +156,19 @@ func TestRunKeepalive(t *testing.T) {
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
})
t.Run("exits on stop", func(t *testing.T) {
t.Run("exits on context cancellation", func(t *testing.T) {
heartbeat := make(chan struct{})
keepalive := make(chan struct{}, 1)
stop := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 20 * time.Second}}
done := make(chan struct{})
go func() {
w.runKeepalive(heartbeat, keepalive, stop, nil)
w.runKeepalive(ctx, heartbeat, keepalive)
close(done)
}()
close(stop)
assert.Eventually(t, func() bool {
select {
case <-done:
return true
default:
return false
}
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
})
t.Run("exits on stop", func(t *testing.T) {
heartbeat := make(chan struct{})
keepalive := make(chan struct{}, 1)
poolDone := make(chan struct{})
w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 20 * time.Second}}
done := make(chan struct{})
go func() {
w.runKeepalive(heartbeat, keepalive, nil, poolDone)
close(done)
}()
close(poolDone)
cancel()
assert.Eventually(t, func() bool {
select {
case <-done:
@@ -229,20 +185,20 @@ func TestRunDialer(t *testing.T) {
w := &Worker{id: "wss://test"}
dial := make(chan struct{}, 1)
newConn := make(chan *transport.Connection, 1)
stop := make(chan struct{})
defer close(stop)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mockSocket := honeybeetest.NewMockSocket()
ctx := WorkerContext{
wctx := WorkerContext{
Errors: make(chan error, 1),
Dialer: &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
return mockSocket, nil, nil
},
},
}
go w.runDialer(dial, newConn, ctx, stop, nil)
go w.runDialer(ctx, wctx, dial, newConn)
dial <- struct{}{}
assert.Eventually(t, func() bool {
@@ -260,18 +216,18 @@ func TestRunDialer(t *testing.T) {
w := &Worker{id: "wss://test"}
dial := make(chan struct{}, 1)
newConn := make(chan *transport.Connection, 1)
stop := make(chan struct{})
defer close(stop)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
gate := make(chan struct{})
dialCount := atomic.Int32{}
mockSocket := honeybeetest.NewMockSocket()
connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry
ctx := WorkerContext{
wctx := WorkerContext{
Errors: make(chan error, 1),
Dialer: &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) {
dialCount.Add(1)
<-gate
return mockSocket, nil, nil
@@ -280,7 +236,7 @@ func TestRunDialer(t *testing.T) {
ConnectionConfig: connConfig,
}
go w.runDialer(dial, newConn, ctx, stop, nil)
go w.runDialer(ctx, wctx, dial, newConn)
dial <- struct{}{}
// wait for dial to start blocking on gate
@@ -322,17 +278,19 @@ func TestRunDialer(t *testing.T) {
errors := make(chan error, 1)
dial := make(chan struct{}, 1)
newConn := make(chan *transport.Connection, 1)
stop := make(chan struct{})
defer close(stop)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// use atomic counter to fail first dial and pass second
dialCount := atomic.Int32{}
mockSocket := honeybeetest.NewMockSocket()
connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry
ctx := WorkerContext{
wctx := WorkerContext{
Errors: errors,
Dialer: &honeybeetest.MockDialer{
DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) {
DialContextFunc: func(
context.Context, string, http.Header,
) (types.Socket, *http.Response, error) {
if dialCount.Add(1) == 1 {
// fail first
return nil, nil, fmt.Errorf("dial failed")
@@ -344,7 +302,7 @@ func TestRunDialer(t *testing.T) {
ConnectionConfig: connConfig,
}
go w.runDialer(dial, newConn, ctx, stop, nil)
go w.runDialer(ctx, wctx, dial, newConn)
dial <- struct{}{}
assert.Eventually(t, func() bool {
@@ -368,21 +326,21 @@ func TestRunDialer(t *testing.T) {
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
})
t.Run("exits on stop", func(t *testing.T) {
t.Run("exits on context cancellation", func(t *testing.T) {
w := &Worker{id: "wss://test"}
dial := make(chan struct{}, 1)
newConn := make(chan *transport.Connection, 1)
stop := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
ctx := WorkerContext{Errors: make(chan error, 1)}
wctx := WorkerContext{Errors: make(chan error, 1)}
done := make(chan struct{})
go func() {
w.runDialer(dial, newConn, ctx, stop, nil)
w.runDialer(ctx, wctx, dial, newConn)
close(done)
}()
close(stop)
cancel()
assert.Eventually(t, func() bool {
select {
@@ -393,31 +351,4 @@ func TestRunDialer(t *testing.T) {
}
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
})
t.Run("exits on pool done", func(t *testing.T) {
w := &Worker{id: "wss://test"}
dial := make(chan struct{}, 1)
newConn := make(chan *transport.Connection, 1)
poolDone := make(chan struct{})
ctx := WorkerContext{Errors: make(chan error, 1)}
done := make(chan struct{})
go func() {
w.runDialer(dial, newConn, ctx, nil, poolDone)
close(done)
}()
close(poolDone)
assert.Eventually(t, func() bool {
select {
case <-done:
return true
default:
return false
}
}, honeybeetest.TestTimeout, honeybeetest.TestTick)
})
}