diff --git a/initiatorpool/worker.go b/initiatorpool/worker.go index 87da7bc..81bbecf 100644 --- a/initiatorpool/worker.go +++ b/initiatorpool/worker.go @@ -46,16 +46,6 @@ func NewWorker( return w, nil } -func (w *Worker) dial(ctx WorkerContext) (*transport.Connection, error) { - conn, err := transport.NewConnection(w.id, ctx.ConnectionConfig, ctx.Logger) - if err != nil { - return nil, err - } - - conn.SetDialer(ctx.Dialer) - return conn, conn.Connect() -} - func (w *Worker) Send(data []byte) error { select { case w.outbound <- data: @@ -209,6 +199,16 @@ func (w *Worker) runKeepalive( } } +func (w *Worker) dial(ctx WorkerContext) (*transport.Connection, error) { + conn, err := transport.NewConnection(w.id, ctx.ConnectionConfig, ctx.Logger) + if err != nil { + return nil, err + } + + conn.SetDialer(ctx.Dialer) + return conn, conn.Connect() +} + func (w *Worker) runDialer( dial <-chan struct{}, newConn chan<- *transport.Connection, @@ -216,4 +216,49 @@ func (w *Worker) runDialer( stop <-chan struct{}, poolDone <-chan struct{}, ) { + for { + select { + case <-stop: + return + case <-poolDone: + return + case <-dial: + // drain dial signals while connection is being established + done := make(chan struct{}) + go func() { + for { + select { + case <-dial: + case <-done: + return + } + } + }() + + // dial a new connection + conn, err := w.dial(ctx) + close(done) + + // send error if dial failed and continue + if err != nil { + select { + case ctx.Errors <- err: + case <-stop: + case <-poolDone: + } + continue + } + + // send the new connection or close and exit + select { + case newConn <- conn: + case <-stop: + conn.Close() + return + case <-poolDone: + conn.Close() + return + } + } + } } diff --git a/initiatorpool/worker_test.go b/initiatorpool/worker_test.go index 07b554a..c563394 100644 --- a/initiatorpool/worker_test.go +++ b/initiatorpool/worker_test.go @@ -1,11 +1,13 @@ package initiatorpool import ( + "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" - // "git.wisehodl.dev/jay/go-honeybee/transport" - // "git.wisehodl.dev/jay/go-honeybee/types" + "git.wisehodl.dev/jay/go-honeybee/transport" + "git.wisehodl.dev/jay/go-honeybee/types" "github.com/stretchr/testify/assert" - // "net/http" + "net/http" + "sync/atomic" "testing" "time" ) @@ -103,6 +105,29 @@ func TestRunForwarder(t *testing.T) { } }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) + + t.Run("exits on pool done", func(t *testing.T) { + messages := make(chan receivedMessage, 1) + inbox := make(chan InboxMessage, 1) + poolDone := make(chan struct{}) + + w := &Worker{id: "wss://test"} + done := make(chan struct{}) + go func() { + w.runForwarder(messages, inbox, nil, poolDone, 0) + close(done) + }() + + close(poolDone) + assert.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }) } func TestRunKeepalive(t *testing.T) { @@ -164,7 +189,6 @@ func TestRunKeepalive(t *testing.T) { close(done) }() - // send stop signal close(stop) assert.Eventually(t, func() bool { select { @@ -176,4 +200,224 @@ func TestRunKeepalive(t *testing.T) { }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) + t.Run("exits on stop", func(t *testing.T) { + heartbeat := make(chan struct{}) + keepalive := make(chan struct{}, 1) + poolDone := make(chan struct{}) + + w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 20 * time.Second}} + done := make(chan struct{}) + go func() { + w.runKeepalive(heartbeat, keepalive, nil, poolDone) + close(done) + }() + + close(poolDone) + assert.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }) +} + +func TestRunDialer(t *testing.T) { + t.Run("successful dial delivers connection to newConn", func(t *testing.T) { + w := &Worker{id: "wss://test"} + dial := make(chan struct{}, 1) + newConn := make(chan *transport.Connection, 1) + stop := make(chan struct{}) + defer close(stop) + + mockSocket := honeybeetest.NewMockSocket() + ctx := WorkerContext{ + Errors: make(chan error, 1), + Dialer: &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + return mockSocket, nil, nil + }, + }, + } + + go w.runDialer(dial, newConn, ctx, stop, nil) + dial <- struct{}{} + + assert.Eventually(t, func() bool { + select { + case <-newConn: + return true + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }) + + t.Run("concurrent dial signals are drained; only one connection produced.", + func(t *testing.T) { + w := &Worker{id: "wss://test"} + dial := make(chan struct{}, 1) + newConn := make(chan *transport.Connection, 1) + stop := make(chan struct{}) + defer close(stop) + + gate := make(chan struct{}) + dialCount := atomic.Int32{} + + mockSocket := honeybeetest.NewMockSocket() + connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry + ctx := WorkerContext{ + Errors: make(chan error, 1), + Dialer: &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + dialCount.Add(1) + <-gate + return mockSocket, nil, nil + }, + }, + ConnectionConfig: connConfig, + } + + go w.runDialer(dial, newConn, ctx, stop, nil) + dial <- struct{}{} + + // wait for dial to start blocking on gate + time.Sleep(20 * time.Millisecond) + + // flood dial while dialer is blocked + for i := 0; i < 5; i++ { + select { + case dial <- struct{}{}: + default: + } + } + + close(gate) + + // connection is cleared to connect + assert.Eventually(t, func() bool { + select { + case <-newConn: + return true + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + + // connection was only dialed once + assert.Equal(t, int32(1), dialCount.Load()) + + // dial channel still writable + select { + case dial <- struct{}{}: + default: + t.Fatal("dial channel should still accept sends") + } + }) + + t.Run("dial failure emits error, succeeds on next signal", func(t *testing.T) { + w := &Worker{id: "wss://test"} + errors := make(chan error, 1) + dial := make(chan struct{}, 1) + newConn := make(chan *transport.Connection, 1) + stop := make(chan struct{}) + defer close(stop) + + // use atomic counter to fail first dial and pass second + dialCount := atomic.Int32{} + mockSocket := honeybeetest.NewMockSocket() + connConfig := &transport.ConnectionConfig{Retry: nil} // disable retry + ctx := WorkerContext{ + Errors: errors, + Dialer: &honeybeetest.MockDialer{ + DialFunc: func(string, http.Header) (types.Socket, *http.Response, error) { + if dialCount.Add(1) == 1 { + // fail first + return nil, nil, fmt.Errorf("dial failed") + } + // pass second + return mockSocket, nil, nil + }, + }, + ConnectionConfig: connConfig, + } + + go w.runDialer(dial, newConn, ctx, stop, nil) + dial <- struct{}{} + + assert.Eventually(t, func() bool { + select { + case err := <-errors: + return err != nil + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + + dial <- struct{}{} + + assert.Eventually(t, func() bool { + select { + case <-newConn: + return true + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }) + + t.Run("exits on stop", func(t *testing.T) { + w := &Worker{id: "wss://test"} + dial := make(chan struct{}, 1) + newConn := make(chan *transport.Connection, 1) + stop := make(chan struct{}) + + ctx := WorkerContext{Errors: make(chan error, 1)} + + done := make(chan struct{}) + go func() { + w.runDialer(dial, newConn, ctx, stop, nil) + close(done) + }() + + close(stop) + + assert.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }) + + t.Run("exits on pool done", func(t *testing.T) { + w := &Worker{id: "wss://test"} + dial := make(chan struct{}, 1) + newConn := make(chan *transport.Connection, 1) + poolDone := make(chan struct{}) + + ctx := WorkerContext{Errors: make(chan error, 1)} + + done := make(chan struct{}) + go func() { + w.runDialer(dial, newConn, ctx, nil, poolDone) + close(done) + }() + + close(poolDone) + + assert.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }) + }