diff --git a/initiatorpool/errors.go b/initiatorpool/errors.go index 2044cce..8dfb8b4 100644 --- a/initiatorpool/errors.go +++ b/initiatorpool/errors.go @@ -4,18 +4,27 @@ import "errors" import "fmt" var ( + // Config errors InvalidKeepaliveTimeout = errors.New("keepalive timeout cannot be negative") InvalidMaxQueueSize = errors.New("maximum queue size cannot be negative") + + // Pool errors + ErrPoolClosed = errors.New("pool is closed") + ErrPeerNotFound = errors.New("peer not found") + ErrPeerExists = errors.New("peer already exists") + + // Worker errors + ErrConnectionUnavailable = errors.New("connection unavailable") ) func NewConfigError(text string) error { return fmt.Errorf("configuration error: %s", text) } -func NewPoolError(text string) error { - return fmt.Errorf("pool error: %s", text) +func NewPoolError(err error) error { + return fmt.Errorf("pool error: %w", err) } -func NewWorkerError(id string, text string) error { - return fmt.Errorf("worker %q error: %s", id, text) +func NewWorkerError(id string, err error) error { + return fmt.Errorf("worker %q error: %w", id, err) } diff --git a/initiatorpool/helper_test.go b/initiatorpool/helper_test.go new file mode 100644 index 0000000..a275bae --- /dev/null +++ b/initiatorpool/helper_test.go @@ -0,0 +1,62 @@ +package initiatorpool + +import ( + "fmt" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "git.wisehodl.dev/jay/go-honeybee/transport" + "github.com/stretchr/testify/assert" + "io" + "testing" +) + +func setupWorkerTestConnection(t *testing.T) ( + conn *transport.Connection, + mockSocket *honeybeetest.MockSocket, + incomingData chan honeybeetest.MockIncomingData, + outgoingData chan honeybeetest.MockOutgoingData, +) { + t.Helper() + + incomingData = make(chan honeybeetest.MockIncomingData, 100) + outgoingData = make(chan honeybeetest.MockOutgoingData, 100) + mockSocket = honeybeetest.NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.Once.Do(func() { close(mockSocket.Closed) }) + return nil + } + + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + select { + case data := <-incomingData: + return data.MsgType, data.Data, data.Err + case <-mockSocket.Closed: + return 0, nil, io.EOF + } + } + + mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + select { + case outgoingData <- honeybeetest.MockOutgoingData{MsgType: msgType, Data: data}: + return nil + case <-mockSocket.Closed: + return io.EOF + default: + return fmt.Errorf("mock outgoing channel unavailable") + } + } + + var err error + conn, err = transport.NewConnectionFromSocket(mockSocket, nil, nil) + assert.NoError(t, err) + return +} + +func connClosed(conn *transport.Connection) bool { + select { + case _, ok := <-conn.Errors(): + return !ok + default: + return false + } +} diff --git a/initiatorpool/pool.go b/initiatorpool/pool.go index faecbc6..ed85c78 100644 --- a/initiatorpool/pool.go +++ b/initiatorpool/pool.go @@ -155,12 +155,12 @@ func (p *Pool) Connect(id string) error { defer p.mu.Unlock() if p.closed { - return NewPoolError("pool is closed") + return NewPoolError(ErrPoolClosed) } _, exists := p.peers[id] if exists { - return NewPoolError("connection already exists") + return NewPoolError(ErrPeerExists) } // The worker factory must be non-blocking to avoid deadlocks @@ -199,13 +199,13 @@ func (p *Pool) Remove(id string) error { p.mu.Lock() if p.closed { p.mu.Unlock() - return NewPoolError("pool is closed") + return NewPoolError(ErrPoolClosed) } peer, exists := p.peers[id] if !exists { p.mu.Unlock() - return NewPoolError("connection not found") + return NewPoolError(ErrPeerNotFound) } delete(p.peers, id) p.mu.Unlock() @@ -225,12 +225,12 @@ func (p *Pool) Send(id string, data []byte) error { defer p.mu.RUnlock() if p.closed { - return NewPoolError("pool is closed") + return NewPoolError(ErrPoolClosed) } peer, exists := p.peers[id] if !exists { - return NewPoolError("connection not found") + return NewPoolError(ErrPeerNotFound) } return peer.worker.Send(data) diff --git a/initiatorpool/pool_test.go b/initiatorpool/pool_test.go index ccc62fa..b887a0b 100644 --- a/initiatorpool/pool_test.go +++ b/initiatorpool/pool_test.go @@ -64,7 +64,7 @@ func _TestPoolConnect(t *testing.T) { // trailing slash normalizes to same key err = pool.Connect("wss://test/") assert.Error(t, err) - assert.ErrorContains(t, err, "already exists") + assert.ErrorIs(t, err, ErrPeerExists) pool.mu.RLock() assert.Len(t, pool.peers, 1) @@ -152,7 +152,7 @@ func _TestPoolRemove(t *testing.T) { // remove unknown connection err = pool.Remove("wss://unknown") - assert.ErrorContains(t, err, "connection not found") + assert.ErrorIs(t, err, ErrPeerNotFound) }) t.Run("closed pool returns error", func(t *testing.T) { @@ -172,7 +172,7 @@ func _TestPoolRemove(t *testing.T) { // attempt to remove connection err = pool.Remove("wss://test") - assert.ErrorContains(t, err, "pool is closed") + assert.ErrorIs(t, err, ErrPoolClosed) }) } diff --git a/initiatorpool/worker.go b/initiatorpool/worker.go index ed495df..fb58e08 100644 --- a/initiatorpool/worker.go +++ b/initiatorpool/worker.go @@ -5,6 +5,7 @@ import ( "context" "git.wisehodl.dev/jay/go-honeybee/transport" "sync" + "sync/atomic" "time" ) @@ -16,11 +17,14 @@ type receivedMessage struct { } type Worker struct { - ctx context.Context - cancel context.CancelFunc - id string - config *WorkerConfig - outbound chan []byte + ctx context.Context + cancel context.CancelFunc + + id string + config *WorkerConfig + + conn atomic.Pointer[transport.Connection] + heartbeat chan struct{} } func NewWorker( @@ -40,25 +44,34 @@ func NewWorker( wctx, cancel := context.WithCancel(ctx) w := &Worker{ - ctx: wctx, - cancel: cancel, - id: id, - outbound: make(chan []byte, 64), - config: config, + ctx: wctx, + cancel: cancel, + id: id, + config: config, + heartbeat: make(chan struct{}), } return w, nil } func (w *Worker) Send(data []byte) error { - select { - case w.outbound <- data: - return nil - case <-w.ctx.Done(): - return NewWorkerError(w.id, "worker is stopped") - default: - return NewWorkerError(w.id, "outbound queue full") + conn := w.conn.Load() + if conn == nil { + return NewWorkerError(w.id, ErrConnectionUnavailable) } + + err := conn.Send(data) + + if err != nil { + return NewWorkerError(w.id, err) + } + + select { + case w.heartbeat <- struct{}{}: + case <-w.ctx.Done(): + } + + return nil } func (w *Worker) Start( @@ -76,7 +89,6 @@ func (w *Worker) runSession( wctx WorkerContext, messages chan<- receivedMessage, - heartbeat chan<- struct{}, dial chan<- struct{}, keepalive <-chan struct{}, @@ -88,19 +100,38 @@ func (w *Worker) runSession( func (w *Worker) runReader( conn *transport.Connection, messages chan<- receivedMessage, - heartbeat chan<- struct{}, sessionDone <-chan struct{}, onStop func(), ) { -} + defer func() { + conn.Close() + onStop() + }() -func (w *Worker) runWriter( - conn *transport.Connection, - outbound <-chan []byte, - heartbeat chan<- struct{}, - sessionDone <-chan struct{}, - onStop func(), -) { + for { + select { + case <-sessionDone: + return + case data, ok := <-conn.Incoming(): + if !ok { + // connection has closed + return + } + + // send message forward + messages <- receivedMessage{ + data: data, + receivedAt: time.Now(), + } + + // send heartbeat + select { + case w.heartbeat <- struct{}{}: + case <-sessionDone: + return + } + } + } } func (w *Worker) runStopMonitor( @@ -110,6 +141,16 @@ func (w *Worker) runStopMonitor( sessionDone <-chan struct{}, onStop func(), ) { + defer func() { + conn.Close() + onStop() + }() + + select { + case <-ctx.Done(): + case <-keepalive: + case <-sessionDone: + } } func (w *Worker) runForwarder( @@ -157,7 +198,6 @@ func (w *Worker) runForwarder( func (w *Worker) runKeepalive( ctx context.Context, - heartbeat <-chan struct{}, keepalive chan<- struct{}, ) { // disable keepalive timeout if not configured @@ -176,7 +216,7 @@ func (w *Worker) runKeepalive( select { case <-ctx.Done(): return - case <-heartbeat: + case <-w.heartbeat: // drain the timer channel and reset if !timer.Stop() { select { diff --git a/initiatorpool/worker_test.go b/initiatorpool/worker_test.go index 8e31119..829a734 100644 --- a/initiatorpool/worker_test.go +++ b/initiatorpool/worker_test.go @@ -8,6 +8,7 @@ import ( "git.wisehodl.dev/jay/go-honeybee/types" "github.com/stretchr/testify/assert" "net/http" + "sync" "sync/atomic" "testing" "time" @@ -109,18 +110,21 @@ func TestRunForwarder(t *testing.T) { func TestRunKeepalive(t *testing.T) { t.Run("heartbeat resets timer, no keepalive signal fired", func(t *testing.T) { - heartbeat := make(chan struct{}, 3) + heartbeat := make(chan struct{}) keepalive := make(chan struct{}, 1) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 100 * time.Millisecond}} - go w.runKeepalive(ctx, heartbeat, keepalive) + w := &Worker{ + config: &WorkerConfig{KeepaliveTimeout: 100 * time.Millisecond}, + heartbeat: heartbeat, + } + go w.runKeepalive(ctx, keepalive) // send heartbeats faster than the timeout for i := 0; i < 5; i++ { time.Sleep(30 * time.Millisecond) - heartbeat <- struct{}{} + w.heartbeat <- struct{}{} } // because the timer is being reset, keepalive signal should not be sent @@ -135,13 +139,12 @@ func TestRunKeepalive(t *testing.T) { }) t.Run("keepalive timeout fires signal", func(t *testing.T) { - heartbeat := make(chan struct{}) keepalive := make(chan struct{}, 1) ctx, cancel := context.WithCancel(context.Background()) defer cancel() w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 20 * time.Millisecond}} - go w.runKeepalive(ctx, heartbeat, keepalive) + go w.runKeepalive(ctx, keepalive) // send no heartbeats, wait for timeout and keepalive signal assert.Eventually(t, func() bool { @@ -155,14 +158,13 @@ func TestRunKeepalive(t *testing.T) { }) t.Run("exits on context cancellation", func(t *testing.T) { - heartbeat := make(chan struct{}) keepalive := make(chan struct{}, 1) ctx, cancel := context.WithCancel(context.Background()) w := &Worker{config: &WorkerConfig{KeepaliveTimeout: 20 * time.Second}} done := make(chan struct{}) go func() { - w.runKeepalive(ctx, heartbeat, keepalive) + w.runKeepalive(ctx, keepalive) close(done) }() @@ -178,6 +180,73 @@ func TestRunKeepalive(t *testing.T) { }) } +func TestRunStopMonitor(t *testing.T) { + t.Run("keepalive signal calls conn.Close and onStop", func(t *testing.T) { + conn, _, _, _ := setupWorkerTestConnection(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + keepalive := make(chan struct{}, 1) + sessionDone := make(chan struct{}) + onStopCalled := atomic.Bool{} + onStop := func() { onStopCalled.Store(true) } + + w := &Worker{id: "wss://test"} + go w.runStopMonitor(ctx, conn, keepalive, sessionDone, onStop) + + keepalive <- struct{}{} + + assert.Eventually(t, func() bool { + return connClosed(conn) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + + assert.True(t, onStopCalled.Load()) + }) + + t.Run("ctx.Done calls conn.Close and onStop", func(t *testing.T) { + conn, _, _, _ := setupWorkerTestConnection(t) + ctx, cancel := context.WithCancel(context.Background()) + + keepalive := make(chan struct{}) + sessionDone := make(chan struct{}) + onStopCalled := atomic.Bool{} + onStop := func() { onStopCalled.Store(true) } + + w := &Worker{id: "wss://test"} + go w.runStopMonitor(ctx, conn, keepalive, sessionDone, onStop) + + cancel() + + assert.Eventually(t, func() bool { + return connClosed(conn) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + + assert.True(t, onStopCalled.Load()) + }) + + t.Run("sessionDone close calls conn.Close and onStop", func(t *testing.T) { + conn, _, _, _ := setupWorkerTestConnection(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + keepalive := make(chan struct{}) + sessionDone := make(chan struct{}) + onStopCalled := atomic.Bool{} + onStop := func() { onStopCalled.Store(true) } + + w := &Worker{id: "wss://test"} + go w.runStopMonitor(ctx, conn, keepalive, sessionDone, onStop) + + close(sessionDone) + + assert.Eventually(t, func() bool { + return connClosed(conn) + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + + assert.True(t, onStopCalled.Load()) + }) +} + func TestRunDialer(t *testing.T) { t.Run("successful dial delivers connection to newConn", func(t *testing.T) { w := &Worker{id: "wss://test"} @@ -395,3 +464,104 @@ func TestRunDialer(t *testing.T) { assert.Empty(t, newConn) }) } + +func TestWorkerSend(t *testing.T) { + t.Run("data sent to mock socket", func(t *testing.T) { + conn, _, _, outgoingData := setupWorkerTestConnection(t) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + heartbeat := make(chan struct{}) + heartbeatCount := atomic.Int32{} + + w := &Worker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + heartbeat: heartbeat, + } + w.conn.Store(conn) + defer w.cancel() + + go func() { + for range heartbeat { + heartbeatCount.Add(1) + } + }() + + testData := []byte("hello") + err := w.Send(testData) + assert.NoError(t, err) + + // one heartbeat was sent + assert.Equal(t, 1, int(heartbeatCount.Load())) + + // message was sent by the socket + assert.Eventually(t, func() bool { + select { + case msg := <-outgoingData: + return string(msg.Data) == "hello" + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }) + + t.Run("sends one heartbeat per successful send", func(t *testing.T) { + conn, _, _, _ := setupWorkerTestConnection(t) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + heartbeat := make(chan struct{}) + heartbeatCount := atomic.Int32{} + + w := &Worker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + heartbeat: heartbeat, + } + w.conn.Store(conn) + defer w.cancel() + + go func() { + for range heartbeat { + heartbeatCount.Add(1) + } + }() + + const count = 3 + for i := 0; i < count; i++ { + err := w.Send([]byte(fmt.Sprintf("msg-%d", i))) + assert.NoError(t, err) + } + + assert.Equal(t, count, int(heartbeatCount.Load())) + }) + + t.Run("returns error if connection is unavailable", func(t *testing.T) { + // no connection available to worker + + ctx, cancel := context.WithCancel(context.Background()) + + heartbeat := make(chan struct{}) + + w := &Worker{ + ctx: ctx, + cancel: cancel, + id: "wss://test", + heartbeat: heartbeat, + } + defer w.cancel() + + go func() { + for range heartbeat { + } + }() + + err := w.Send([]byte("hello")) + assert.ErrorIs(t, err, ErrConnectionUnavailable) + }) +} diff --git a/transport/connection.go b/transport/connection.go index 4867ad2..4ede803 100644 --- a/transport/connection.go +++ b/transport/connection.go @@ -45,15 +45,15 @@ type Connection struct { logger *slog.Logger incoming chan []byte - outgoing chan []byte errors chan error done chan struct{} state ConnectionState - wg sync.WaitGroup - closed bool - mu sync.RWMutex + wg sync.WaitGroup + closed bool + mu sync.RWMutex + writeMu sync.Mutex } func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) (*Connection, error) { @@ -77,7 +77,6 @@ func NewConnection(urlStr string, config *ConnectionConfig, logger *slog.Logger) config: config, logger: logger, incoming: make(chan []byte, 100), - outgoing: make(chan []byte, 100), errors: make(chan error, 10), state: StateDisconnected, done: make(chan struct{}), @@ -108,7 +107,6 @@ func NewConnectionFromSocket( config: config, logger: logger, incoming: make(chan []byte, 100), - outgoing: make(chan []byte, 100), errors: make(chan error, 10), state: StateConnected, done: make(chan struct{}), @@ -119,7 +117,6 @@ func NewConnectionFromSocket( } conn.startReader() - conn.startWriter() return conn, nil } @@ -166,7 +163,6 @@ func (c *Connection) Connect(ctx context.Context) error { } c.startReader() - c.startWriter() return nil } @@ -221,63 +217,32 @@ func (c *Connection) startReader() { } -func (c *Connection) startWriter() { - c.wg.Add(1) - go func() { - defer c.wg.Done() - - for { - select { - case <-c.done: - return - case data := <-c.outgoing: - if c.config.WriteTimeout > 0 { - if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { - if c.logger != nil { - c.logger.Error("write deadline error", "error", err) - } - select { - case c.errors <- fmt.Errorf("failed to set write deadline: %w", err): - case <-c.done: - } - c.shutdown() - return - } - } - - if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil { - if c.logger != nil { - c.logger.Error("write error", "error", err) - } - select { - case c.errors <- err: - case <-c.done: - } - c.shutdown() - return - } - } - } - }() - -} - func (c *Connection) Send(data []byte) error { - c.mu.RLock() - defer c.mu.RUnlock() + c.writeMu.Lock() + defer c.writeMu.Unlock() if c.closed { - return NewConnectionError("connection closed") + return ErrConnectionClosed } - select { - case c.outgoing <- data: - return nil - case <-c.done: - return NewConnectionError("connection closing") - default: - return NewConnectionError("outgoing queue full") + if c.config.WriteTimeout > 0 { + if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { + if c.logger != nil { + c.logger.Error("write deadline error", "error", err) + } + c.shutdown() + return fmt.Errorf("failed to set write deadline: %w", err) + } } + + if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil { + if c.logger != nil { + c.logger.Error("write error", "error", err) + } + return fmt.Errorf("%w: %w", ErrWriteFailed, err) + } + + return nil } func (c *Connection) Incoming() <-chan []byte { @@ -326,7 +291,6 @@ func (c *Connection) shutdown() { c.wg.Wait() close(c.incoming) - close(c.outgoing) close(c.errors) }() diff --git a/transport/connection_goroutine_test.go b/transport/connection_goroutine_test.go index df54e0c..8203c63 100644 --- a/transport/connection_goroutine_test.go +++ b/transport/connection_goroutine_test.go @@ -5,10 +5,7 @@ import ( "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" - "io" - "strings" "testing" - "time" ) func TestStartReader(t *testing.T) { @@ -88,198 +85,3 @@ func TestStartReader(t *testing.T) { }, honeybeetest.TestTimeout, honeybeetest.TestTick) }) } - -func TestStartWriter(t *testing.T) { - t.Run("data from outgoing triggers write", func(t *testing.T) { - conn, _, _, outgoingData := setupTestConnection(t, nil) - defer conn.Close() - - testData := []byte("test message") - err := conn.Send(testData) - assert.NoError(t, err) - - honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, testData) - }) - - t.Run("multiple messages processed sequentially", func(t *testing.T) { - conn, _, _, outgoingData := setupTestConnection(t, nil) - defer conn.Close() - - messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")} - for _, msg := range messages { - err := conn.Send(msg) - assert.NoError(t, err) - } - - for _, expected := range messages { - honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, expected) - } - }) - - t.Run("write timeout disabled when zero", func(t *testing.T) { - if testing.Short() { - t.Skip("skipping test in short mode") - } - - config := &ConnectionConfig{WriteTimeout: 0} - - outgoingData := make(chan honeybeetest.MockOutgoingData, 10) - mockSocket := honeybeetest.NewMockSocket() - - mockSocket.CloseFunc = func() error { - mockSocket.Once.Do(func() { - close(mockSocket.Closed) - }) - return nil - } - - deadlineCalled := make(chan struct{}, 1) - mockSocket.SetWriteDeadlineFunc = func(t time.Time) error { - deadlineCalled <- struct{}{} - return nil - } - - mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { - select { - case outgoingData <- honeybeetest.MockOutgoingData{ - MsgType: msgType, Data: data}: - case <-mockSocket.Closed: - return io.EOF - } - return nil - } - - conn, err := NewConnectionFromSocket(mockSocket, config, nil) - assert.NoError(t, err) - defer conn.Close() - - err = conn.Send([]byte("test")) - assert.NoError(t, err) - - assert.Never(t, func() bool { - select { - case <-deadlineCalled: - return true - default: - return false - } - }, honeybeetest.NegativeTestTimeout, honeybeetest.TestTick, - "SetWriteDeadline should not be called when timeout is zero") - }) - - t.Run("write timeout sets deadline when positive", func(t *testing.T) { - config := &ConnectionConfig{WriteTimeout: 30 * time.Millisecond} - - outgoingData := make(chan honeybeetest.MockOutgoingData, 10) - mockSocket := honeybeetest.NewMockSocket() - - mockSocket.CloseFunc = func() error { - mockSocket.Once.Do(func() { - close(mockSocket.Closed) - }) - return nil - } - - deadlineCalled := make(chan struct{}, 1) - mockSocket.SetWriteDeadlineFunc = func(t time.Time) error { - deadlineCalled <- struct{}{} - return nil - } - - mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { - select { - case outgoingData <- honeybeetest.MockOutgoingData{ - MsgType: msgType, Data: data}: - case <-mockSocket.Closed: - return io.EOF - } - return nil - } - - conn, err := NewConnectionFromSocket(mockSocket, config, nil) - assert.NoError(t, err) - defer conn.Close() - - err = conn.Send([]byte("test")) - assert.NoError(t, err) - - assert.Eventually(t, func() bool { - select { - case <-deadlineCalled: - return true - default: - return false - } - }, honeybeetest.TestTimeout, honeybeetest.TestTick, - "SetWriteDeadline should be called when timeout is positive") - }) - - t.Run("writer exits on deadline error", func(t *testing.T) { - config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond} - - mockSocket := honeybeetest.NewMockSocket() - - mockSocket.CloseFunc = func() error { - mockSocket.Once.Do(func() { - close(mockSocket.Closed) - }) - return nil - } - - mockSocket.SetWriteDeadlineFunc = func(t time.Time) error { - return fmt.Errorf("test error") - } - - conn, err := NewConnectionFromSocket(mockSocket, config, nil) - assert.NoError(t, err) - - err = conn.Send([]byte("test")) - assert.NoError(t, err) - defer conn.Close() - - assert.Eventually(t, func() bool { - select { - case err := <-conn.Errors(): - return err != nil && - strings.Contains(err.Error(), "failed to set write deadline") - default: - return false - } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) - - assert.Eventually(t, func() bool { - return conn.State() == StateClosed - }, honeybeetest.TestTimeout, honeybeetest.TestTick) - }) - - t.Run("writer exits on socket write error", func(t *testing.T) { - mockSocket := honeybeetest.NewMockSocket() - - writeErr := fmt.Errorf("write failed") - mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { - return writeErr - } - - conn, err := NewConnectionFromSocket(mockSocket, nil, nil) - assert.NoError(t, err) - defer conn.Close() - - err = conn.Send([]byte("test")) - assert.NoError(t, err) - - assert.Eventually(t, func() bool { - select { - case err := <-conn.Errors(): - return err == writeErr - default: - return false - } - }, honeybeetest.TestTimeout, honeybeetest.TestTick) - - assert.Eventually(t, func() bool { - return conn.State() == StateClosed - }, honeybeetest.TestTimeout, honeybeetest.TestTick) - }) -} - -// Helpers diff --git a/transport/connection_send_test.go b/transport/connection_send_test.go index e5869d9..211a024 100644 --- a/transport/connection_send_test.go +++ b/transport/connection_send_test.go @@ -2,110 +2,243 @@ package transport import ( "fmt" + "git.wisehodl.dev/jay/go-honeybee/honeybeetest" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "io" "sync" "testing" + "time" ) func TestConnectionSend(t *testing.T) { - cases := []struct { - name string - setup func(*Connection) - data []byte - wantErr bool - wantErrText string - }{ - { - name: "send succeeds when open", - setup: func(c *Connection) {}, - data: []byte("test message"), - }, - { - name: "send fails when closed", - setup: func(c *Connection) { - c.Close() - }, - data: []byte("test"), - wantErr: true, - wantErrText: "connection closed", - }, - { - name: "send fails when queue full", - setup: func(c *Connection) { - // Fill outgoing channel - for i := 0; i < 100; i++ { - c.outgoing <- []byte("filler") - } - }, - data: []byte("overflow"), - wantErr: true, - wantErrText: "outgoing queue full", - }, - } + t.Run("writes message to socket", func(t *testing.T) { + conn, _, _, outgoingData := setupTestConnection(t, nil) + defer conn.Close() - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) + testData := []byte("test message") + err := conn.Send(testData) + assert.NoError(t, err) + + honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, testData) + }) + + t.Run("writes multiple message to socket", func(t *testing.T) { + conn, _, _, outgoingData := setupTestConnection(t, nil) + defer conn.Close() + + messages := [][]byte{[]byte("first"), []byte("second"), []byte("third")} + for _, msg := range messages { + err := conn.Send(msg) assert.NoError(t, err) - - tc.setup(conn) - - err = conn.Send(tc.data) - - if tc.wantErr { - assert.Error(t, err) - if tc.wantErrText != "" { - assert.ErrorContains(t, err, tc.wantErrText) - } - return - } - - assert.NoError(t, err) - - select { - case sent := <-conn.outgoing: - assert.Equal(t, tc.data, sent) - default: - t.Fatal("data not sent to outgoing channel") - } - }) - } -} - -// Run with `go test -race` to ensure no race conditions occur -func TestConnectionSendConcurrent(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) - assert.NoError(t, err) - - // continuously consume outgoing channel in background - done := make(chan struct{}) - go func() { - for { - select { - case <-conn.outgoing: - case <-done: - return - } } - }() - defer close(done) - // Send from multiple goroutines concurrently - const goroutines = 5 - const messagesPerGoroutine = 10 - var wg sync.WaitGroup + for _, expected := range messages { + honeybeetest.ExpectWrite(t, outgoingData, websocket.TextMessage, expected) + } + }) - for i := 0; i < goroutines; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - for j := 0; j < messagesPerGoroutine; j++ { - data := []byte(fmt.Sprintf("msg-%d-%d", id, j)) - err := conn.Send(data) - assert.NoError(t, err) + t.Run("concurrent sends write messages to socket", func(t *testing.T) { + conn, _, _, outgoingData := setupTestConnection(t, nil) + defer conn.Close() + + mu := sync.Mutex{} + messages := []string{} + done := make(chan struct{}) + + go func() { + for { + select { + case msg := <-outgoingData: + fmt.Printf("got message %s\n", string(msg.Data)) + mu.Lock() + messages = append(messages, string(msg.Data)) + mu.Unlock() + case <-done: + return + } } - }(i) - } + }() - wg.Wait() + defer close(done) + + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 10; j++ { + data := []byte(fmt.Sprintf("msg-%d-%d", id, j)) + fmt.Printf("sending message %s\n", string(data)) + for { + // send and retry until success + err := conn.Send(data) + if err != nil { + continue + } else { + break + } + } + } + }(i) + } + + wg.Wait() + + assert.Eventually(t, func() bool { + mu.Lock() + defer mu.Unlock() + return len(messages) == 50 + }, honeybeetest.TestTimeout, honeybeetest.TestTick, + "should have received 50 messages") + + }) + + t.Run("send fails when connection is closed", func(t *testing.T) { + conn, _, _, _ := setupTestConnection(t, nil) + conn.Close() + + testData := []byte("test message") + err := conn.Send(testData) + assert.ErrorIs(t, err, ErrConnectionClosed) + }) + + t.Run("write timeout disabled when zero", func(t *testing.T) { + config := &ConnectionConfig{WriteTimeout: 0} + + outgoingData := make(chan honeybeetest.MockOutgoingData, 10) + mockSocket := honeybeetest.NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.Once.Do(func() { + close(mockSocket.Closed) + }) + return nil + } + + deadlineCalled := make(chan struct{}, 1) + mockSocket.SetWriteDeadlineFunc = func(t time.Time) error { + deadlineCalled <- struct{}{} + return nil + } + + mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + select { + case outgoingData <- honeybeetest.MockOutgoingData{ + MsgType: msgType, Data: data}: + case <-mockSocket.Closed: + return io.EOF + } + return nil + } + + conn, err := NewConnectionFromSocket(mockSocket, config, nil) + assert.NoError(t, err) + defer conn.Close() + + err = conn.Send([]byte("test")) + assert.NoError(t, err) + + assert.Never(t, func() bool { + select { + case <-deadlineCalled: + return true + default: + return false + } + }, honeybeetest.NegativeTestTimeout, honeybeetest.TestTick, + "SetWriteDeadline should not be called when timeout is zero") + }) + + t.Run("write timeout sets deadline when positive", func(t *testing.T) { + config := &ConnectionConfig{WriteTimeout: 30 * time.Millisecond} + + outgoingData := make(chan honeybeetest.MockOutgoingData, 10) + mockSocket := honeybeetest.NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.Once.Do(func() { + close(mockSocket.Closed) + }) + return nil + } + + deadlineCalled := make(chan struct{}, 1) + mockSocket.SetWriteDeadlineFunc = func(t time.Time) error { + deadlineCalled <- struct{}{} + return nil + } + + mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + select { + case outgoingData <- honeybeetest.MockOutgoingData{ + MsgType: msgType, Data: data}: + case <-mockSocket.Closed: + return io.EOF + } + return nil + } + + conn, err := NewConnectionFromSocket(mockSocket, config, nil) + assert.NoError(t, err) + defer conn.Close() + + err = conn.Send([]byte("test")) + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + select { + case <-deadlineCalled: + return true + default: + return false + } + }, honeybeetest.TestTimeout, honeybeetest.TestTick, + "SetWriteDeadline should be called when timeout is positive") + }) + + t.Run("send fails on deadline error", func(t *testing.T) { + config := &ConnectionConfig{WriteTimeout: 1 * time.Millisecond} + + mockSocket := honeybeetest.NewMockSocket() + + mockSocket.CloseFunc = func() error { + mockSocket.Once.Do(func() { + close(mockSocket.Closed) + }) + return nil + } + + mockSocket.SetWriteDeadlineFunc = func(t time.Time) error { + return fmt.Errorf("test error") + } + + conn, err := NewConnectionFromSocket(mockSocket, config, nil) + assert.NoError(t, err) + defer conn.Close() + + err = conn.Send([]byte("test")) + assert.ErrorContains(t, err, "failed to set write deadline: test error") + + assert.Eventually(t, func() bool { + return conn.State() == StateClosed + }, honeybeetest.TestTimeout, honeybeetest.TestTick) + }) + + t.Run("send fails on socket write error", func(t *testing.T) { + mockSocket := honeybeetest.NewMockSocket() + + writeErr := fmt.Errorf("test error") + mockSocket.WriteMessageFunc = func(msgType int, data []byte) error { + return writeErr + } + + conn, err := NewConnectionFromSocket(mockSocket, nil, nil) + assert.NoError(t, err) + defer conn.Close() + + err = conn.Send([]byte("test")) + assert.ErrorIs(t, err, ErrWriteFailed) + assert.ErrorContains(t, err, "test error") + }) } diff --git a/transport/connection_test.go b/transport/connection_test.go index d02f4b0..25d0ad9 100644 --- a/transport/connection_test.go +++ b/transport/connection_test.go @@ -112,7 +112,6 @@ func TestNewConnection(t *testing.T) { assert.Nil(t, conn.socket) assert.NotNil(t, conn.config) assert.NotNil(t, conn.incoming) - assert.NotNil(t, conn.outgoing) assert.NotNil(t, conn.errors) assert.NotNil(t, conn.done) assert.Equal(t, StateDisconnected, conn.state) @@ -213,7 +212,6 @@ func TestNewConnectionFromSocket(t *testing.T) { assert.Equal(t, tc.socket, conn.socket) assert.NotNil(t, conn.config) assert.NotNil(t, conn.incoming) - assert.NotNil(t, conn.outgoing) assert.NotNil(t, conn.errors) assert.NotNil(t, conn.done) assert.Equal(t, StateConnected, conn.state) diff --git a/transport/errors.go b/transport/errors.go index bcc01dc..b176568 100644 --- a/transport/errors.go +++ b/transport/errors.go @@ -13,6 +13,10 @@ var ( InvalidRetryInitialDelay = errors.New("initial delay must be positive") InvalidRetryMaxDelay = errors.New("max delay must be positive") InvalidRetryJitterFactor = errors.New("jitter factor must be between 0.0 and 1.0") + + // Connection Errors + ErrConnectionClosed = errors.New("connection closed") + ErrWriteFailed = errors.New("write failed") ) func NewConfigError(text string) error { diff --git a/transport/logging_test.go b/transport/logging_test.go index e55f926..a103668 100644 --- a/transport/logging_test.go +++ b/transport/logging_test.go @@ -413,7 +413,7 @@ func TestWriterLogging(t *testing.T) { assert.NoError(t, err) err = conn.Send([]byte("test")) - assert.NoError(t, err) + assert.ErrorContains(t, err, "failed to set write deadline: deadline error") assert.Eventually(t, func() bool { return findLogRecord( @@ -443,7 +443,7 @@ func TestWriterLogging(t *testing.T) { assert.NoError(t, err) err = conn.Send([]byte("test")) - assert.NoError(t, err) + assert.ErrorContains(t, err, "write error") assert.Eventually(t, func() bool { return findLogRecord( diff --git a/transport/socket_test.go b/transport/socket_test.go index ffa2553..1c5dc95 100644 --- a/transport/socket_test.go +++ b/transport/socket_test.go @@ -199,9 +199,9 @@ func TestAcquireSocketContextCancellation(t *testing.T) { done <- err }() - // wait for first dial to complete, then cancel during sleep + // wait for first two dials to complete, then cancel during sleep assert.Eventually(t, func() bool { - return dialCount.Load() >= 1 + return dialCount.Load() > 1 }, honeybeetest.TestTimeout, honeybeetest.TestTick) cancel()