diff --git a/initiatorpool/worker.go b/initiatorpool/worker.go index ac5b761..4c7074b 100644 --- a/initiatorpool/worker.go +++ b/initiatorpool/worker.go @@ -55,9 +55,24 @@ func NewWorker( } func (w *Worker) Start( - ctx WorkerContext, + wctx WorkerContext, wg *sync.WaitGroup, ) { + dial := make(chan struct{}, 1) + newConn := make(chan *transport.Connection, 1) + messages := make(chan receivedMessage, 256) + keepalive := make(chan struct{}, 1) + + var owg sync.WaitGroup + owg.Add(4) + + go func() { defer owg.Done(); w.runDialer(w.ctx, wctx, dial, newConn) }() + go func() { defer owg.Done(); w.runKeepalive(w.ctx, keepalive) }() + go func() { defer owg.Done(); w.runForwarder(w.ctx, messages, wctx.Inbox, w.config.MaxQueueSize) }() + go func() { defer owg.Done(); w.runSession(w.ctx, wctx, messages, dial, keepalive, newConn) }() + + owg.Wait() + wg.Done() } func (w *Worker) Stop() { diff --git a/initiatorpool/worker_start_test.go b/initiatorpool/worker_start_test.go new file mode 100644 index 0000000..4e78389 --- /dev/null +++ b/initiatorpool/worker_start_test.go @@ -0,0 +1,316 @@ +package initiatorpool + +import ( + "context" + "fmt" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "git.wisehodl.dev/jay/go-honeybee/transport" + "git.wisehodl.dev/jay/go-honeybee/types" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "net/http" + "sync" + "testing" +) + +func makeWorkerContext(t *testing.T) ( + inbox chan InboxMessage, + events chan PoolEvent, + errors chan error, + wctx WorkerContext, +) { + t.Helper() + inbox = make(chan InboxMessage, 256) + events = make(chan PoolEvent, 10) + errors = make(chan error, 10) + wctx = WorkerContext{ + Inbox: inbox, + Events: events, + Errors: errors, + } + return +} + +func makeWorker(t *testing.T, ctx context.Context, cancel context.CancelFunc) *Worker { + t.Helper() + return &Worker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + config: GetDefaultWorkerConfig(), + heartbeat: make(chan struct{}), + } +} + +func mockDialer(socket *honeybeetest.MockSocket) *honeybeetest.MockDialer { + return &honeybeetest.MockDialer{ + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { + return socket, nil, nil + }, + } +} + +func TestWorkerStart(t *testing.T) { + t.Run("EventConnected emitted after dial succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, _, wctx := makeWorkerContext(t) + mockSocket := honeybeetest.NewMockSocket() + wctx.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Add(1) + go w.Start(wctx, &wg) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.ID == w.id && e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + }) + + t.Run("Send delivers data to socket", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, _, wctx := makeWorkerContext(t) + _, mockSocket, _, outgoingData := setupWorkerTestConnection(t) + wctx.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Add(1) + go w.Start(wctx, &wg) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + err := w.Send([]byte("hello")) + assert.NoError(t, err) + + honeybeetest.Eventually(t, func() bool { + select { + case msg := <-outgoingData: + return string(msg.Data) == "hello" + default: + return false + } + }, "expected data on socket") + }) + + t.Run("socket data arrives on Inbox", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + inbox, events, _, wctx := makeWorkerContext(t) + + incomingData := make(chan honeybeetest.MockIncomingData, 10) + mockSocket := honeybeetest.NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.Once.Do(func() { close(mockSocket.Closed) }) + return nil + } + + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + select { + case data := <-incomingData: + return data.MsgType, data.Data, data.Err + } + } + + wctx.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Add(1) + go w.Start(wctx, &wg) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + incomingData <- honeybeetest.MockIncomingData{ + MsgType: websocket.TextMessage, + Data: []byte("hello"), + } + + honeybeetest.Eventually(t, func() bool { + select { + case msg := <-inbox: + return msg.ID == w.id && string(msg.Data) == "hello" + default: + return false + } + }, "expected message on Inbox") + }) + + t.Run("socket close produces EventDisconnected then EventConnected", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, _, wctx := makeWorkerContext(t) + _, mockSocket, incomingData, _ := setupWorkerTestConnection(t) + wctx.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Add(1) + go w.Start(wctx, &wg) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + close(incomingData) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventDisconnected + default: + return false + } + }, "expected EventDisconnected") + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected second EventConnected") + }) + + t.Run("Stop produces EventDisconnected and wg drains", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, events, _, wctx := makeWorkerContext(t) + mockSocket := honeybeetest.NewMockSocket() + wctx.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Add(1) + go w.Start(wctx, &wg) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + w.Stop() + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventDisconnected + default: + return false + } + }, "expected EventDisconnected") + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + honeybeetest.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, "expected wg to drain") + }) + + t.Run("parent context cancel exits cleanly and wg drains", func(t *testing.T) { + parentCtx, parentCancel := context.WithCancel(context.Background()) + workerCtx, workerCancel := context.WithCancel(parentCtx) + + w := makeWorker(t, workerCtx, workerCancel) + _, events, _, wctx := makeWorkerContext(t) + mockSocket := honeybeetest.NewMockSocket() + wctx.Dialer = mockDialer(mockSocket) + + var wg sync.WaitGroup + wg.Add(1) + go w.Start(wctx, &wg) + + honeybeetest.Eventually(t, func() bool { + select { + case e := <-events: + return e.Kind == EventConnected + default: + return false + } + }, "expected EventConnected") + + // drain events after parent cancel — we don't assert what they are, + // only that the worker exits + parentCancel() + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + honeybeetest.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, "expected wg to drain after parent cancel") + }) + + t.Run("dial failure emits to Errors", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := makeWorker(t, ctx, cancel) + _, _, errors, wctx := makeWorkerContext(t) + wctx.ConnectionConfig = &transport.ConnectionConfig{Retry: nil} + wctx.Dialer = &honeybeetest.MockDialer{ + DialContextFunc: func(context.Context, string, http.Header) (types.Socket, *http.Response, error) { + return nil, nil, fmt.Errorf("dial failed") + }, + } + + var wg sync.WaitGroup + wg.Add(1) + go w.Start(wctx, &wg) + + honeybeetest.Eventually(t, func() bool { + select { + case err := <-errors: + return err != nil + default: + return false + } + }, "expected error on Errors channel") + }) +}