From af54f2b41cc973e12d7a0855ab6c911c6097c6ca Mon Sep 17 00:00:00 2001 From: Jay Date: Wed, 4 Feb 2026 13:10:44 -0500 Subject: [PATCH] Added logging. --- connection.go | 65 ++++- connection_close_test.go | 12 +- connection_goroutine_test.go | 16 +- connection_send_test.go | 4 +- connection_test.go | 26 +- logging_test.go | 470 +++++++++++++++++++++++++++++++++++ mocks_test.go | 50 +++- socket.go | 22 ++ socket_test.go | 4 +- 9 files changed, 632 insertions(+), 37 deletions(-) create mode 100644 logging_test.go diff --git a/connection.go b/connection.go index b09a17f..1179046 100644 --- a/connection.go +++ b/connection.go @@ -2,6 +2,7 @@ package honeybee import ( "fmt" + "log/slog" "net/url" "sync" "time" @@ -39,6 +40,7 @@ type Connection struct { dialer Dialer socket Socket config *Config + logger *slog.Logger incoming chan []byte outgoing chan []byte @@ -53,7 +55,7 @@ type Connection struct { mu sync.RWMutex } -func NewConnection(urlStr string, config *Config) (*Connection, error) { +func NewConnection(urlStr string, config *Config, logger *slog.Logger) (*Connection, error) { if config == nil { config = GetDefaultConfig() } @@ -67,7 +69,7 @@ func NewConnection(urlStr string, config *Config) (*Connection, error) { return nil, err } - return &Connection{ + conn := &Connection{ url: parsedURL, dialer: NewDialer(), socket: nil, @@ -77,10 +79,20 @@ func NewConnection(urlStr string, config *Config) (*Connection, error) { errors: make(chan error, 10), state: StateDisconnected, done: make(chan struct{}), - }, nil + } + + if logger != nil { + conn.logger = logger.With( + "library", "honeybee", + "component", "Connection", + "url", parsedURL.String(), + ) + } + + return conn, nil } -func NewConnectionFromSocket(socket Socket, config *Config) (*Connection, error) { +func NewConnectionFromSocket(socket Socket, config *Config, logger *slog.Logger) (*Connection, error) { if socket == nil { return nil, errors.NewConnectionError("socket cannot be nil") } @@ -105,6 +117,13 @@ func NewConnectionFromSocket(socket Socket, config *Config) (*Connection, error) done: make(chan struct{}), } + if logger != nil { + conn.logger = logger.With( + "library", "honeybee", + "component", "Connection", + ) + } + if config.CloseHandler != nil { socket.SetCloseHandler(config.CloseHandler) } @@ -127,13 +146,20 @@ func (c *Connection) Connect() error { return errors.NewConnectionError("connection is closed") } + if c.logger != nil { + c.logger.Info("connecting") + } + c.state = StateConnecting retryMgr := NewRetryManager(c.config.Retry) - socket, _, err := AcquireSocket(retryMgr, c.dialer, c.url.String()) + socket, _, err := AcquireSocket(retryMgr, c.dialer, c.url.String(), c.logger) if err != nil { c.state = StateDisconnected + if c.logger != nil { + c.logger.Error("connection failed", "error", err) + } return err } @@ -144,6 +170,10 @@ func (c *Connection) Connect() error { c.socket.SetCloseHandler(c.config.CloseHandler) } + if c.logger != nil { + c.logger.Info("connected") + } + c.startReader() c.startWriter() @@ -162,6 +192,9 @@ func (c *Connection) startReader() { default: if c.config.ReadTimeout > 0 { if err := c.socket.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil { + if c.logger != nil { + c.logger.Error("read deadline error", "error", err) + } select { case c.errors <- fmt.Errorf("failed to set read deadline: %w", err): case <-c.done: @@ -172,6 +205,9 @@ func (c *Connection) startReader() { } messageType, data, err := c.socket.ReadMessage() if err != nil { + if c.logger != nil { + c.logger.Error("read error", "error", err) + } select { case c.errors <- err: case <-c.done: @@ -208,6 +244,9 @@ func (c *Connection) startWriter() { case data := <-c.outgoing: if c.config.WriteTimeout > 0 { if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { + if c.logger != nil { + c.logger.Error("write deadline error", "error", err) + } select { case c.errors <- fmt.Errorf("failed to set write deadline: %w", err): case <-c.done: @@ -218,6 +257,9 @@ func (c *Connection) startWriter() { } if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil { + if c.logger != nil { + c.logger.Error("write error", "error", err) + } select { case c.errors <- err: case <-c.done: @@ -263,7 +305,11 @@ func (c *Connection) Close() error { c.mu.Lock() alreadyClosed := c.closed + currentState := c.state if !alreadyClosed { + if c.logger != nil { + c.logger.Info("closing", "state", currentState.String()) + } c.closed = true c.state = StateClosed close(c.done) @@ -279,6 +325,15 @@ func (c *Connection) Close() error { var err error if socket != nil { err = socket.Close() + if err != nil { + if c.logger != nil { + c.logger.Error("socket close failed", "error", err) + } + } else { + if c.logger != nil { + c.logger.Info("closed") + } + } } c.wg.Wait() diff --git a/connection_close_test.go b/connection_close_test.go index e094c6b..b9588a2 100644 --- a/connection_close_test.go +++ b/connection_close_test.go @@ -11,7 +11,7 @@ import ( func TestDisconnectedConnectionClose(t *testing.T) { t.Run("close succeeds on disconnected connection", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) assert.Equal(t, StateDisconnected, conn.State()) @@ -21,7 +21,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { }) t.Run("close is idempotent", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) err = conn.Close() @@ -34,7 +34,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { }) t.Run("close with nil socket", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) assert.Nil(t, conn.socket) @@ -50,7 +50,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { return expectedErr } - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) conn.socket = mockSocket @@ -60,7 +60,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { }) t.Run("channels close after close", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) err = conn.Close() @@ -92,7 +92,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { }) t.Run("send fails after close", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) err = conn.Close() diff --git a/connection_goroutine_test.go b/connection_goroutine_test.go index d59d4f3..a9e52c0 100644 --- a/connection_goroutine_test.go +++ b/connection_goroutine_test.go @@ -74,7 +74,7 @@ func TestStartReader(t *testing.T) { return nil } - conn, err := NewConnectionFromSocket(mockSocket, config) + conn, err := NewConnectionFromSocket(mockSocket, config, nil) assert.NoError(t, err) defer conn.Close() @@ -114,7 +114,7 @@ func TestStartReader(t *testing.T) { } } - conn, err := NewConnectionFromSocket(mockSocket, config) + conn, err := NewConnectionFromSocket(mockSocket, config, nil) assert.NoError(t, err) defer conn.Close() @@ -149,7 +149,7 @@ func TestStartReader(t *testing.T) { return fmt.Errorf("test error") } - conn, err := NewConnectionFromSocket(mockSocket, config) + conn, err := NewConnectionFromSocket(mockSocket, config, nil) assert.NoError(t, err) defer conn.Close() @@ -180,7 +180,7 @@ func TestStartReader(t *testing.T) { return 0, nil, readErr } - conn, err := NewConnectionFromSocket(mockSocket, nil) + conn, err := NewConnectionFromSocket(mockSocket, nil, nil) assert.NoError(t, err) defer conn.Close() @@ -256,7 +256,7 @@ func TestStartWriter(t *testing.T) { return nil } - conn, err := NewConnectionFromSocket(mockSocket, config) + conn, err := NewConnectionFromSocket(mockSocket, config, nil) assert.NoError(t, err) defer conn.Close() @@ -300,7 +300,7 @@ func TestStartWriter(t *testing.T) { return nil } - conn, err := NewConnectionFromSocket(mockSocket, config) + conn, err := NewConnectionFromSocket(mockSocket, config, nil) assert.NoError(t, err) defer conn.Close() @@ -333,7 +333,7 @@ func TestStartWriter(t *testing.T) { return fmt.Errorf("test error") } - conn, err := NewConnectionFromSocket(mockSocket, config) + conn, err := NewConnectionFromSocket(mockSocket, config, nil) assert.NoError(t, err) err = conn.Send([]byte("test")) @@ -359,7 +359,7 @@ func TestStartWriter(t *testing.T) { return writeErr } - conn, err := NewConnectionFromSocket(mockSocket, nil) + conn, err := NewConnectionFromSocket(mockSocket, nil, nil) assert.NoError(t, err) defer conn.Close() diff --git a/connection_send_test.go b/connection_send_test.go index 54c0c3a..c475b6b 100644 --- a/connection_send_test.go +++ b/connection_send_test.go @@ -46,7 +46,7 @@ func TestConnectionSend(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) tc.setup(conn) @@ -76,7 +76,7 @@ func TestConnectionSend(t *testing.T) { // Run with `go test -race` to ensure no race conditions occur func TestConnectionSendConcurrent(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) // continuously consume outgoing channel in background diff --git a/connection_test.go b/connection_test.go index 7996046..0e533f1 100644 --- a/connection_test.go +++ b/connection_test.go @@ -31,11 +31,11 @@ func TestConnectionStateString(t *testing.T) { func TestConnectionState(t *testing.T) { // Test initial state - conn, _ := NewConnection("ws://test", nil) + conn, _ := NewConnection("ws://test", nil, nil) assert.Equal(t, StateDisconnected, conn.State()) // Test state after FromSocket (should be Connected) - conn2, _ := NewConnectionFromSocket(NewMockSocket(), nil) + conn2, _ := NewConnectionFromSocket(NewMockSocket(), nil, nil) assert.Equal(t, StateConnected, conn2.State()) // Test state after close @@ -86,7 +86,7 @@ func TestNewConnection(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - conn, err := NewConnection(tc.url, tc.config) + conn, err := NewConnection(tc.url, tc.config, nil) if tc.wantErr { assert.Error(t, err) @@ -187,7 +187,7 @@ func TestNewConnectionFromSocket(t *testing.T) { } } - conn, err := NewConnectionFromSocket(tc.socket, tc.config) + conn, err := NewConnectionFromSocket(tc.socket, tc.config, nil) if tc.wantErr { assert.Error(t, err) @@ -234,7 +234,7 @@ func TestNewConnectionFromSocket(t *testing.T) { func TestConnect(t *testing.T) { t.Run("connect fails when socket already present", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) conn.socket = NewMockSocket() @@ -246,7 +246,7 @@ func TestConnect(t *testing.T) { }) t.Run("connect fails when connection closed", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) conn.Close() @@ -258,7 +258,7 @@ func TestConnect(t *testing.T) { }) t.Run("connect succeeds and starts goroutines", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) outgoingData := make(chan mockOutgoingData, 10) @@ -305,7 +305,7 @@ func TestConnect(t *testing.T) { JitterFactor: 0.0, }, } - conn, err := NewConnection("ws://test", config) + conn, err := NewConnection("ws://test", config, nil) assert.NoError(t, err) attemptCount := 0 @@ -337,7 +337,7 @@ func TestConnect(t *testing.T) { JitterFactor: 0.0, }, } - conn, err := NewConnection("ws://test", config) + conn, err := NewConnection("ws://test", config, nil) assert.NoError(t, err) mockDialer := &MockDialer{ @@ -354,7 +354,7 @@ func TestConnect(t *testing.T) { }) t.Run("state transitions during connect", func(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) assert.Equal(t, StateDisconnected, conn.State()) @@ -382,7 +382,7 @@ func TestConnect(t *testing.T) { return nil }, } - conn, err := NewConnection("ws://test", config) + conn, err := NewConnection("ws://test", config, nil) assert.NoError(t, err) mockSocket := NewMockSocket() @@ -408,7 +408,7 @@ func TestConnect(t *testing.T) { // Connection method tests func TestConnectionIncoming(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) incoming := conn.Incoming() @@ -422,7 +422,7 @@ func TestConnectionIncoming(t *testing.T) { } func TestConnectionErrors(t *testing.T) { - conn, err := NewConnection("ws://test", nil) + conn, err := NewConnection("ws://test", nil, nil) assert.NoError(t, err) errors := conn.Errors() diff --git a/logging_test.go b/logging_test.go new file mode 100644 index 0000000..8e1b888 --- /dev/null +++ b/logging_test.go @@ -0,0 +1,470 @@ +package honeybee + +import ( + "fmt" + "log/slog" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// Helpers + +type expectedLog struct { + level slog.Level + msg string + attrs map[string]any +} + +func assertLogSequence(t *testing.T, records []slog.Record, expected []expectedLog) { + t.Helper() + + recIndex := 0 + for expIndex, exp := range expected { + found := false + + // Search forward through records + for recIndex < len(records) { + rec := records[recIndex] + + if rec.Level == exp.level && strings.Contains(rec.Message, exp.msg) { + allAttrsMatch := true + for key, expectedValue := range exp.attrs { + if !assertAttributePresent(t, rec, key, expectedValue) { + allAttrsMatch = false + break + } + } + + if allAttrsMatch { + found = true + recIndex++ // Consume this record + break + } + } + + recIndex++ // Move to next record + } + + if !found { + t.Fatalf( + "expected log not found: index=%d level=%v msg=%q attrs=%v", + expIndex, exp.level, exp.msg, exp.attrs) + } + } +} + +func assertLogRecord( + t *testing.T, + records []slog.Record, + level slog.Level, + msgSnippet string, + expectedAttrs ...slog.Attr, +) { + t.Helper() + + record := findLogRecord(records, level, msgSnippet) + if record == nil { + t.Fatalf("no log record found with level %v and message containing %q", level, msgSnippet) + } + + for _, expectedAttr := range expectedAttrs { + assertAttributePresent(t, *record, expectedAttr.Key, expectedAttr.Value.Any()) + } +} + +func findLogRecord( + records []slog.Record, level slog.Level, msgSnippet string, +) *slog.Record { + for i := range records { + if records[i].Level == level && strings.Contains(records[i].Message, msgSnippet) { + return &records[i] + } + } + return nil +} + +func assertAttributePresent( + t *testing.T, record slog.Record, key string, expectedValue any, +) bool { + t.Helper() + + var found bool + var actualValue any + + record.Attrs(func(attr slog.Attr) bool { + if attr.Key == key { + found = true + actualValue = attr.Value.Any() + return false + } + return true + }) + + if !found { + t.Fatalf("attribute %q not found in log record", key) + } + + if !valuesEqual(actualValue, expectedValue) { + t.Errorf("attribute %q mismatch: expected=%v actual=%v", key, expectedValue, actualValue) + return false + } + + return true +} + +func valuesEqual(a, b any) bool { + // Direct equality + if a == b { + return true + } + + // Handle int/int64 conversions + aInt, aIsInt := toInt64(a) + bInt, bIsInt := toInt64(b) + if aIsInt && bIsInt { + return aInt == bInt + } + + return false +} + +func toInt64(v any) (int64, bool) { + switch val := v.(type) { + case int: + return int64(val), true + case int64: + return val, true + case int32: + return int64(val), true + case int16: + return int64(val), true + case int8: + return int64(val), true + default: + return 0, false + } +} + +// Tests + +func TestConnectLogging(t *testing.T) { + t.Run("success", func(t *testing.T) { + mockHandler := newMockSlogHandler() + logger := slog.New(mockHandler) + + conn, err := NewConnection("ws://test", nil, logger) + assert.NoError(t, err) + + mockSocket := NewMockSocket() + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + return mockSocket, nil, nil + }, + } + conn.dialer = mockDialer + + err = conn.Connect() + assert.NoError(t, err) + defer conn.Close() + + records := mockHandler.GetRecords() + + expected := []expectedLog{ + {slog.LevelInfo, "connecting", map[string]any{}}, + {slog.LevelInfo, "dialing", map[string]any{"attempt": 1}}, + {slog.LevelInfo, "dial successful", map[string]any{"attempt": 1}}, + {slog.LevelInfo, "connected", map[string]any{}}, + } + + assertLogSequence(t, records, expected) + }) + + t.Run("max retries failure", func(t *testing.T) { + mockHandler := newMockSlogHandler() + logger := slog.New(mockHandler) + + config := &Config{ + Retry: &RetryConfig{ + MaxRetries: 2, + InitialDelay: 1 * time.Millisecond, + MaxDelay: 5 * time.Millisecond, + JitterFactor: 0.0, + }, + } + + conn, err := NewConnection("ws://test", config, logger) + assert.NoError(t, err) + + dialErr := fmt.Errorf("dial error") + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + return nil, nil, dialErr + }, + } + conn.dialer = mockDialer + + err = conn.Connect() + assert.Error(t, err) + + records := mockHandler.GetRecords() + + expected := []expectedLog{ + {slog.LevelInfo, "connecting", map[string]any{}}, + {slog.LevelInfo, "dialing", map[string]any{"attempt": 1}}, + {slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}}, + {slog.LevelInfo, "dialing", map[string]any{"attempt": 2}}, + {slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}}, + {slog.LevelInfo, "dialing", map[string]any{"attempt": 3}}, + {slog.LevelError, "dial failed, max retries reached", map[string]any{"attempt": 3, "error": dialErr}}, + {slog.LevelError, "connection failed", map[string]any{"error": dialErr}}, + } + + assertLogSequence(t, records, expected) + }) + + t.Run("success after retry", func(t *testing.T) { + mockHandler := newMockSlogHandler() + logger := slog.New(mockHandler) + + config := &Config{ + Retry: &RetryConfig{ + MaxRetries: 3, + InitialDelay: 1 * time.Millisecond, + MaxDelay: 5 * time.Millisecond, + JitterFactor: 0.0, + }, + } + + conn, err := NewConnection("ws://test", config, logger) + assert.NoError(t, err) + + attemptCount := 0 + dialErr := fmt.Errorf("dial error") + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + attemptCount++ + if attemptCount < 3 { + return nil, nil, dialErr + } + return NewMockSocket(), nil, nil + }, + } + conn.dialer = mockDialer + + err = conn.Connect() + assert.NoError(t, err) + defer conn.Close() + + records := mockHandler.GetRecords() + + expected := []expectedLog{ + {slog.LevelInfo, "connecting", map[string]any{}}, + {slog.LevelInfo, "dialing", map[string]any{"attempt": 1}}, + {slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}}, + {slog.LevelInfo, "dialing", map[string]any{"attempt": 2}}, + {slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}}, + {slog.LevelInfo, "dialing", map[string]any{"attempt": 3}}, + {slog.LevelInfo, "dial successful", map[string]any{"attempt": 3}}, + {slog.LevelInfo, "connected", map[string]any{}}, + } + + assertLogSequence(t, records, expected) + }) +} + +func TestCloseLogging(t *testing.T) { + t.Run("normal close", func(t *testing.T) { + mockHandler := newMockSlogHandler() + logger := slog.New(mockHandler) + + mockSocket := NewMockSocket() + conn, err := NewConnectionFromSocket(mockSocket, nil, logger) + assert.NoError(t, err) + + err = conn.Close() + assert.NoError(t, err) + + records := mockHandler.GetRecords() + + expected := []expectedLog{ + {slog.LevelInfo, "closing", map[string]any{"state": "connected"}}, + {slog.LevelInfo, "closed", map[string]any{}}, + } + + assertLogSequence(t, records, expected) + }) + + t.Run("close with socket error", func(t *testing.T) { + mockHandler := newMockSlogHandler() + logger := slog.New(mockHandler) + + closeErr := fmt.Errorf("close error") + mockSocket := NewMockSocket() + mockSocket.CloseFunc = func() error { + return closeErr + } + + conn, err := NewConnectionFromSocket(mockSocket, nil, logger) + assert.NoError(t, err) + + err = conn.Close() + assert.Error(t, err) + + records := mockHandler.GetRecords() + + expected := []expectedLog{ + {slog.LevelInfo, "closing", map[string]any{"state": "connected"}}, + {slog.LevelError, "socket close failed", map[string]any{"error": closeErr}}, + } + + assertLogSequence(t, records, expected) + }) +} + +func TestReaderLogging(t *testing.T) { + t.Run("read deadline error", func(t *testing.T) { + mockHandler := newMockSlogHandler() + logger := slog.New(mockHandler) + + config := &Config{ReadTimeout: 1 * time.Millisecond} + + deadlineErr := fmt.Errorf("deadline error") + mockSocket := NewMockSocket() + mockSocket.SetReadDeadlineFunc = func(time.Time) error { + return deadlineErr + } + + conn, err := NewConnectionFromSocket(mockSocket, config, logger) + assert.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + records := mockHandler.GetRecords() + + assertLogRecord(t, records, slog.LevelError, "read deadline error") + + record := findLogRecord(records, slog.LevelError, "read deadline error") + assert.NotNil(t, record) + assertAttributePresent(t, *record, "error", deadlineErr) + + conn.Close() + }) + + t.Run("read message error", func(t *testing.T) { + mockHandler := newMockSlogHandler() + logger := slog.New(mockHandler) + + readErr := fmt.Errorf("read error") + mockSocket := NewMockSocket() + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + return 0, nil, readErr + } + + conn, err := NewConnectionFromSocket(mockSocket, nil, logger) + assert.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + records := mockHandler.GetRecords() + + assertLogRecord(t, records, slog.LevelError, "read error") + + record := findLogRecord(records, slog.LevelError, "read error") + assert.NotNil(t, record) + assertAttributePresent(t, *record, "error", readErr) + + conn.Close() + }) +} + +func TestWriterLogging(t *testing.T) { + t.Run("write deadline error", func(t *testing.T) { + mockHandler := newMockSlogHandler() + logger := slog.New(mockHandler) + + config := &Config{WriteTimeout: 1 * time.Millisecond} + + deadlineErr := fmt.Errorf("deadline error") + mockSocket := NewMockSocket() + mockSocket.SetWriteDeadlineFunc = func(time.Time) error { + return deadlineErr + } + + conn, err := NewConnectionFromSocket(mockSocket, config, logger) + assert.NoError(t, err) + + err = conn.Send([]byte("test")) + assert.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + records := mockHandler.GetRecords() + + assertLogRecord(t, records, slog.LevelError, "write deadline error") + + record := findLogRecord(records, slog.LevelError, "write deadline error") + assert.NotNil(t, record) + assertAttributePresent(t, *record, "error", deadlineErr) + + conn.Close() + }) + + t.Run("write message error", func(t *testing.T) { + mockHandler := newMockSlogHandler() + logger := slog.New(mockHandler) + + writeErr := fmt.Errorf("write error") + mockSocket := NewMockSocket() + mockSocket.WriteMessageFunc = func(int, []byte) error { + return writeErr + } + + conn, err := NewConnectionFromSocket(mockSocket, nil, logger) + assert.NoError(t, err) + + err = conn.Send([]byte("test")) + assert.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + records := mockHandler.GetRecords() + + assertLogRecord(t, records, slog.LevelError, "write error") + + record := findLogRecord(records, slog.LevelError, "write error") + assert.NotNil(t, record) + assertAttributePresent(t, *record, "error", writeErr) + + conn.Close() + }) +} + +func TestLoggingDisabled(t *testing.T) { + t.Run("nil logger produces no logs", func(t *testing.T) { + mockHandler := newMockSlogHandler() + + conn, err := NewConnection("ws://test", nil, nil) + assert.NoError(t, err) + + mockSocket := NewMockSocket() + mockDialer := &MockDialer{ + DialFunc: func(string, http.Header) (Socket, *http.Response, error) { + return mockSocket, nil, nil + }, + } + conn.dialer = mockDialer + + err = conn.Connect() + assert.NoError(t, err) + + err = conn.Close() + assert.NoError(t, err) + + records := mockHandler.GetRecords() + assert.Empty(t, records) + }) +} diff --git a/mocks_test.go b/mocks_test.go index bc3ddf8..46a8a77 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -1,8 +1,10 @@ package honeybee import ( + "context" "github.com/stretchr/testify/assert" "io" + "log/slog" "net/http" "sync" "testing" @@ -125,8 +127,54 @@ func setupTestConnection(t *testing.T, config *Config) ( } var err error - conn, err = NewConnectionFromSocket(mockSocket, config) + conn, err = NewConnectionFromSocket(mockSocket, config, nil) assert.NoError(t, err) return conn, mockSocket, incomingData, outgoingData } + +// Logging mocks + +type mockSlogHandler struct { + records []slog.Record + mu sync.RWMutex +} + +func newMockSlogHandler() *mockSlogHandler { + return &mockSlogHandler{ + records: make([]slog.Record, 0), + } +} + +func (m *mockSlogHandler) Handle(ctx context.Context, record slog.Record) error { + m.mu.Lock() + defer m.mu.Unlock() + m.records = append(m.records, record) + return nil +} + +func (m *mockSlogHandler) Enabled(ctx context.Context, level slog.Level) bool { + return true +} + +func (m *mockSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return m +} + +func (m *mockSlogHandler) WithGroup(name string) slog.Handler { + return m +} + +func (m *mockSlogHandler) GetRecords() []slog.Record { + m.mu.RLock() + defer m.mu.RUnlock() + result := make([]slog.Record, len(m.records)) + copy(result, m.records) + return result +} + +func (m *mockSlogHandler) Clear() { + m.mu.Lock() + defer m.mu.Unlock() + m.records = make([]slog.Record, 0) +} diff --git a/socket.go b/socket.go index 94756bc..1f6538c 100644 --- a/socket.go +++ b/socket.go @@ -1,6 +1,7 @@ package honeybee import ( + "log/slog" "net/http" "time" @@ -54,6 +55,7 @@ func AcquireSocket( retryMgr *RetryManager, dialer Dialer, urlStr string, + logger *slog.Logger, ) (Socket, *http.Response, error) { if retryMgr == nil { return nil, nil, errors.NewConnectionError("retry manager cannot be nil") @@ -66,16 +68,36 @@ func AcquireSocket( } for { + if logger != nil { + logger.Info("dialing", "attempt", retryMgr.RetryCount()+1) + } + socket, resp, err := dialer.Dial(urlStr, nil) if err == nil { + if logger != nil { + logger.Info("dial successful", "attempt", retryMgr.RetryCount()+1) + } return socket, resp, nil } if !retryMgr.ShouldRetry() { + if logger != nil { + logger.Error("dial failed, max retries reached", + "error", err, + "attempt", retryMgr.RetryCount()+1) + } return nil, nil, err } delay := retryMgr.CalculateDelay() + + if logger != nil { + logger.Warn("dial failed, retrying", + "error", err, + "attempt", retryMgr.RetryCount()+1, + "next_delay", delay) + } + time.Sleep(delay) retryMgr.RecordRetry() } diff --git a/socket_test.go b/socket_test.go index 5c5e6c4..5fae4e8 100644 --- a/socket_test.go +++ b/socket_test.go @@ -78,7 +78,7 @@ func TestAcquireSocket(t *testing.T) { JitterFactor: 0.0, }) - socket, _, err := AcquireSocket(retryMgr, mockDialer, "ws://test") + socket, _, err := AcquireSocket(retryMgr, mockDialer, "ws://test", nil) assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount()) if tc.wantErr { @@ -132,7 +132,7 @@ func TestAcquireSocketGuards(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - socket, resp, err := AcquireSocket(tc.retryMgr, tc.dialer, tc.url) + socket, resp, err := AcquireSocket(tc.retryMgr, tc.dialer, tc.url, nil) assert.Error(t, err) assert.ErrorContains(t, err, tc.wantErr)