Added logging.
This commit is contained in:
@@ -2,6 +2,7 @@ package honeybee
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -39,6 +40,7 @@ type Connection struct {
|
|||||||
dialer Dialer
|
dialer Dialer
|
||||||
socket Socket
|
socket Socket
|
||||||
config *Config
|
config *Config
|
||||||
|
logger *slog.Logger
|
||||||
|
|
||||||
incoming chan []byte
|
incoming chan []byte
|
||||||
outgoing chan []byte
|
outgoing chan []byte
|
||||||
@@ -53,7 +55,7 @@ type Connection struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnection(urlStr string, config *Config) (*Connection, error) {
|
func NewConnection(urlStr string, config *Config, logger *slog.Logger) (*Connection, error) {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
config = GetDefaultConfig()
|
config = GetDefaultConfig()
|
||||||
}
|
}
|
||||||
@@ -67,7 +69,7 @@ func NewConnection(urlStr string, config *Config) (*Connection, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Connection{
|
conn := &Connection{
|
||||||
url: parsedURL,
|
url: parsedURL,
|
||||||
dialer: NewDialer(),
|
dialer: NewDialer(),
|
||||||
socket: nil,
|
socket: nil,
|
||||||
@@ -77,10 +79,20 @@ func NewConnection(urlStr string, config *Config) (*Connection, error) {
|
|||||||
errors: make(chan error, 10),
|
errors: make(chan error, 10),
|
||||||
state: StateDisconnected,
|
state: StateDisconnected,
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnectionFromSocket(socket Socket, config *Config) (*Connection, error) {
|
if logger != nil {
|
||||||
|
conn.logger = logger.With(
|
||||||
|
"library", "honeybee",
|
||||||
|
"component", "Connection",
|
||||||
|
"url", parsedURL.String(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnectionFromSocket(socket Socket, config *Config, logger *slog.Logger) (*Connection, error) {
|
||||||
if socket == nil {
|
if socket == nil {
|
||||||
return nil, errors.NewConnectionError("socket cannot be nil")
|
return nil, errors.NewConnectionError("socket cannot be nil")
|
||||||
}
|
}
|
||||||
@@ -105,6 +117,13 @@ func NewConnectionFromSocket(socket Socket, config *Config) (*Connection, error)
|
|||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if logger != nil {
|
||||||
|
conn.logger = logger.With(
|
||||||
|
"library", "honeybee",
|
||||||
|
"component", "Connection",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
if config.CloseHandler != nil {
|
if config.CloseHandler != nil {
|
||||||
socket.SetCloseHandler(config.CloseHandler)
|
socket.SetCloseHandler(config.CloseHandler)
|
||||||
}
|
}
|
||||||
@@ -127,13 +146,20 @@ func (c *Connection) Connect() error {
|
|||||||
return errors.NewConnectionError("connection is closed")
|
return errors.NewConnectionError("connection is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Info("connecting")
|
||||||
|
}
|
||||||
|
|
||||||
c.state = StateConnecting
|
c.state = StateConnecting
|
||||||
|
|
||||||
retryMgr := NewRetryManager(c.config.Retry)
|
retryMgr := NewRetryManager(c.config.Retry)
|
||||||
socket, _, err := AcquireSocket(retryMgr, c.dialer, c.url.String())
|
socket, _, err := AcquireSocket(retryMgr, c.dialer, c.url.String(), c.logger)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.state = StateDisconnected
|
c.state = StateDisconnected
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Error("connection failed", "error", err)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,6 +170,10 @@ func (c *Connection) Connect() error {
|
|||||||
c.socket.SetCloseHandler(c.config.CloseHandler)
|
c.socket.SetCloseHandler(c.config.CloseHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Info("connected")
|
||||||
|
}
|
||||||
|
|
||||||
c.startReader()
|
c.startReader()
|
||||||
c.startWriter()
|
c.startWriter()
|
||||||
|
|
||||||
@@ -162,6 +192,9 @@ func (c *Connection) startReader() {
|
|||||||
default:
|
default:
|
||||||
if c.config.ReadTimeout > 0 {
|
if c.config.ReadTimeout > 0 {
|
||||||
if err := c.socket.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil {
|
if err := c.socket.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Error("read deadline error", "error", err)
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case c.errors <- fmt.Errorf("failed to set read deadline: %w", err):
|
case c.errors <- fmt.Errorf("failed to set read deadline: %w", err):
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
@@ -172,6 +205,9 @@ func (c *Connection) startReader() {
|
|||||||
}
|
}
|
||||||
messageType, data, err := c.socket.ReadMessage()
|
messageType, data, err := c.socket.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Error("read error", "error", err)
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case c.errors <- err:
|
case c.errors <- err:
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
@@ -208,6 +244,9 @@ func (c *Connection) startWriter() {
|
|||||||
case data := <-c.outgoing:
|
case data := <-c.outgoing:
|
||||||
if c.config.WriteTimeout > 0 {
|
if c.config.WriteTimeout > 0 {
|
||||||
if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
if err := c.socket.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Error("write deadline error", "error", err)
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case c.errors <- fmt.Errorf("failed to set write deadline: %w", err):
|
case c.errors <- fmt.Errorf("failed to set write deadline: %w", err):
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
@@ -218,6 +257,9 @@ func (c *Connection) startWriter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil {
|
if err := c.socket.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Error("write error", "error", err)
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case c.errors <- err:
|
case c.errors <- err:
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
@@ -263,7 +305,11 @@ func (c *Connection) Close() error {
|
|||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
|
|
||||||
alreadyClosed := c.closed
|
alreadyClosed := c.closed
|
||||||
|
currentState := c.state
|
||||||
if !alreadyClosed {
|
if !alreadyClosed {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Info("closing", "state", currentState.String())
|
||||||
|
}
|
||||||
c.closed = true
|
c.closed = true
|
||||||
c.state = StateClosed
|
c.state = StateClosed
|
||||||
close(c.done)
|
close(c.done)
|
||||||
@@ -279,6 +325,15 @@ func (c *Connection) Close() error {
|
|||||||
var err error
|
var err error
|
||||||
if socket != nil {
|
if socket != nil {
|
||||||
err = socket.Close()
|
err = socket.Close()
|
||||||
|
if err != nil {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Error("socket close failed", "error", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Info("closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.wg.Wait()
|
c.wg.Wait()
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
func TestDisconnectedConnectionClose(t *testing.T) {
|
func TestDisconnectedConnectionClose(t *testing.T) {
|
||||||
t.Run("close succeeds on disconnected connection", func(t *testing.T) {
|
t.Run("close succeeds on disconnected connection", func(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, StateDisconnected, conn.State())
|
assert.Equal(t, StateDisconnected, conn.State())
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("close is idempotent", func(t *testing.T) {
|
t.Run("close is idempotent", func(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = conn.Close()
|
err = conn.Close()
|
||||||
@@ -34,7 +34,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("close with nil socket", func(t *testing.T) {
|
t.Run("close with nil socket", func(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Nil(t, conn.socket)
|
assert.Nil(t, conn.socket)
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
return expectedErr
|
return expectedErr
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
conn.socket = mockSocket
|
conn.socket = mockSocket
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("channels close after close", func(t *testing.T) {
|
t.Run("channels close after close", func(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = conn.Close()
|
err = conn.Close()
|
||||||
@@ -92,7 +92,7 @@ func TestDisconnectedConnectionClose(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("send fails after close", func(t *testing.T) {
|
t.Run("send fails after close", func(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = conn.Close()
|
err = conn.Close()
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func TestStartReader(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, config)
|
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
@@ -114,7 +114,7 @@ func TestStartReader(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, config)
|
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
@@ -149,7 +149,7 @@ func TestStartReader(t *testing.T) {
|
|||||||
return fmt.Errorf("test error")
|
return fmt.Errorf("test error")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, config)
|
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
@@ -180,7 +180,7 @@ func TestStartReader(t *testing.T) {
|
|||||||
return 0, nil, readErr
|
return 0, nil, readErr
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, nil)
|
conn, err := NewConnectionFromSocket(mockSocket, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
@@ -256,7 +256,7 @@ func TestStartWriter(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, config)
|
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
@@ -300,7 +300,7 @@ func TestStartWriter(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, config)
|
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
@@ -333,7 +333,7 @@ func TestStartWriter(t *testing.T) {
|
|||||||
return fmt.Errorf("test error")
|
return fmt.Errorf("test error")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, config)
|
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = conn.Send([]byte("test"))
|
err = conn.Send([]byte("test"))
|
||||||
@@ -359,7 +359,7 @@ func TestStartWriter(t *testing.T) {
|
|||||||
return writeErr
|
return writeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, nil)
|
conn, err := NewConnectionFromSocket(mockSocket, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func TestConnectionSend(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
tc.setup(conn)
|
tc.setup(conn)
|
||||||
@@ -76,7 +76,7 @@ func TestConnectionSend(t *testing.T) {
|
|||||||
|
|
||||||
// Run with `go test -race` to ensure no race conditions occur
|
// Run with `go test -race` to ensure no race conditions occur
|
||||||
func TestConnectionSendConcurrent(t *testing.T) {
|
func TestConnectionSendConcurrent(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// continuously consume outgoing channel in background
|
// continuously consume outgoing channel in background
|
||||||
|
|||||||
@@ -31,11 +31,11 @@ func TestConnectionStateString(t *testing.T) {
|
|||||||
|
|
||||||
func TestConnectionState(t *testing.T) {
|
func TestConnectionState(t *testing.T) {
|
||||||
// Test initial state
|
// Test initial state
|
||||||
conn, _ := NewConnection("ws://test", nil)
|
conn, _ := NewConnection("ws://test", nil, nil)
|
||||||
assert.Equal(t, StateDisconnected, conn.State())
|
assert.Equal(t, StateDisconnected, conn.State())
|
||||||
|
|
||||||
// Test state after FromSocket (should be Connected)
|
// Test state after FromSocket (should be Connected)
|
||||||
conn2, _ := NewConnectionFromSocket(NewMockSocket(), nil)
|
conn2, _ := NewConnectionFromSocket(NewMockSocket(), nil, nil)
|
||||||
assert.Equal(t, StateConnected, conn2.State())
|
assert.Equal(t, StateConnected, conn2.State())
|
||||||
|
|
||||||
// Test state after close
|
// Test state after close
|
||||||
@@ -86,7 +86,7 @@ func TestNewConnection(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
conn, err := NewConnection(tc.url, tc.config)
|
conn, err := NewConnection(tc.url, tc.config, nil)
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -187,7 +187,7 @@ func TestNewConnectionFromSocket(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(tc.socket, tc.config)
|
conn, err := NewConnectionFromSocket(tc.socket, tc.config, nil)
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -234,7 +234,7 @@ func TestNewConnectionFromSocket(t *testing.T) {
|
|||||||
|
|
||||||
func TestConnect(t *testing.T) {
|
func TestConnect(t *testing.T) {
|
||||||
t.Run("connect fails when socket already present", func(t *testing.T) {
|
t.Run("connect fails when socket already present", func(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
conn.socket = NewMockSocket()
|
conn.socket = NewMockSocket()
|
||||||
@@ -246,7 +246,7 @@ func TestConnect(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("connect fails when connection closed", func(t *testing.T) {
|
t.Run("connect fails when connection closed", func(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
@@ -258,7 +258,7 @@ func TestConnect(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("connect succeeds and starts goroutines", func(t *testing.T) {
|
t.Run("connect succeeds and starts goroutines", func(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
outgoingData := make(chan mockOutgoingData, 10)
|
outgoingData := make(chan mockOutgoingData, 10)
|
||||||
@@ -305,7 +305,7 @@ func TestConnect(t *testing.T) {
|
|||||||
JitterFactor: 0.0,
|
JitterFactor: 0.0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conn, err := NewConnection("ws://test", config)
|
conn, err := NewConnection("ws://test", config, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
attemptCount := 0
|
attemptCount := 0
|
||||||
@@ -337,7 +337,7 @@ func TestConnect(t *testing.T) {
|
|||||||
JitterFactor: 0.0,
|
JitterFactor: 0.0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conn, err := NewConnection("ws://test", config)
|
conn, err := NewConnection("ws://test", config, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
mockDialer := &MockDialer{
|
mockDialer := &MockDialer{
|
||||||
@@ -354,7 +354,7 @@ func TestConnect(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("state transitions during connect", func(t *testing.T) {
|
t.Run("state transitions during connect", func(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, StateDisconnected, conn.State())
|
assert.Equal(t, StateDisconnected, conn.State())
|
||||||
|
|
||||||
@@ -382,7 +382,7 @@ func TestConnect(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conn, err := NewConnection("ws://test", config)
|
conn, err := NewConnection("ws://test", config, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
mockSocket := NewMockSocket()
|
mockSocket := NewMockSocket()
|
||||||
@@ -408,7 +408,7 @@ func TestConnect(t *testing.T) {
|
|||||||
// Connection method tests
|
// Connection method tests
|
||||||
|
|
||||||
func TestConnectionIncoming(t *testing.T) {
|
func TestConnectionIncoming(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
incoming := conn.Incoming()
|
incoming := conn.Incoming()
|
||||||
@@ -422,7 +422,7 @@ func TestConnectionIncoming(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectionErrors(t *testing.T) {
|
func TestConnectionErrors(t *testing.T) {
|
||||||
conn, err := NewConnection("ws://test", nil)
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
errors := conn.Errors()
|
errors := conn.Errors()
|
||||||
|
|||||||
470
logging_test.go
Normal file
470
logging_test.go
Normal file
@@ -0,0 +1,470 @@
|
|||||||
|
package honeybee
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helpers
|
||||||
|
|
||||||
|
type expectedLog struct {
|
||||||
|
level slog.Level
|
||||||
|
msg string
|
||||||
|
attrs map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertLogSequence(t *testing.T, records []slog.Record, expected []expectedLog) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
recIndex := 0
|
||||||
|
for expIndex, exp := range expected {
|
||||||
|
found := false
|
||||||
|
|
||||||
|
// Search forward through records
|
||||||
|
for recIndex < len(records) {
|
||||||
|
rec := records[recIndex]
|
||||||
|
|
||||||
|
if rec.Level == exp.level && strings.Contains(rec.Message, exp.msg) {
|
||||||
|
allAttrsMatch := true
|
||||||
|
for key, expectedValue := range exp.attrs {
|
||||||
|
if !assertAttributePresent(t, rec, key, expectedValue) {
|
||||||
|
allAttrsMatch = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if allAttrsMatch {
|
||||||
|
found = true
|
||||||
|
recIndex++ // Consume this record
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
recIndex++ // Move to next record
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected log not found: index=%d level=%v msg=%q attrs=%v",
|
||||||
|
expIndex, exp.level, exp.msg, exp.attrs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertLogRecord(
|
||||||
|
t *testing.T,
|
||||||
|
records []slog.Record,
|
||||||
|
level slog.Level,
|
||||||
|
msgSnippet string,
|
||||||
|
expectedAttrs ...slog.Attr,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
record := findLogRecord(records, level, msgSnippet)
|
||||||
|
if record == nil {
|
||||||
|
t.Fatalf("no log record found with level %v and message containing %q", level, msgSnippet)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expectedAttr := range expectedAttrs {
|
||||||
|
assertAttributePresent(t, *record, expectedAttr.Key, expectedAttr.Value.Any())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findLogRecord(
|
||||||
|
records []slog.Record, level slog.Level, msgSnippet string,
|
||||||
|
) *slog.Record {
|
||||||
|
for i := range records {
|
||||||
|
if records[i].Level == level && strings.Contains(records[i].Message, msgSnippet) {
|
||||||
|
return &records[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertAttributePresent(
|
||||||
|
t *testing.T, record slog.Record, key string, expectedValue any,
|
||||||
|
) bool {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var found bool
|
||||||
|
var actualValue any
|
||||||
|
|
||||||
|
record.Attrs(func(attr slog.Attr) bool {
|
||||||
|
if attr.Key == key {
|
||||||
|
found = true
|
||||||
|
actualValue = attr.Value.Any()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("attribute %q not found in log record", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valuesEqual(actualValue, expectedValue) {
|
||||||
|
t.Errorf("attribute %q mismatch: expected=%v actual=%v", key, expectedValue, actualValue)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func valuesEqual(a, b any) bool {
|
||||||
|
// Direct equality
|
||||||
|
if a == b {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle int/int64 conversions
|
||||||
|
aInt, aIsInt := toInt64(a)
|
||||||
|
bInt, bIsInt := toInt64(b)
|
||||||
|
if aIsInt && bIsInt {
|
||||||
|
return aInt == bInt
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func toInt64(v any) (int64, bool) {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case int:
|
||||||
|
return int64(val), true
|
||||||
|
case int64:
|
||||||
|
return val, true
|
||||||
|
case int32:
|
||||||
|
return int64(val), true
|
||||||
|
case int16:
|
||||||
|
return int64(val), true
|
||||||
|
case int8:
|
||||||
|
return int64(val), true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests
|
||||||
|
|
||||||
|
func TestConnectLogging(t *testing.T) {
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
mockHandler := newMockSlogHandler()
|
||||||
|
logger := slog.New(mockHandler)
|
||||||
|
|
||||||
|
conn, err := NewConnection("ws://test", nil, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
mockDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
return mockSocket, nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.dialer = mockDialer
|
||||||
|
|
||||||
|
err = conn.Connect()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
|
expected := []expectedLog{
|
||||||
|
{slog.LevelInfo, "connecting", map[string]any{}},
|
||||||
|
{slog.LevelInfo, "dialing", map[string]any{"attempt": 1}},
|
||||||
|
{slog.LevelInfo, "dial successful", map[string]any{"attempt": 1}},
|
||||||
|
{slog.LevelInfo, "connected", map[string]any{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
assertLogSequence(t, records, expected)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("max retries failure", func(t *testing.T) {
|
||||||
|
mockHandler := newMockSlogHandler()
|
||||||
|
logger := slog.New(mockHandler)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
Retry: &RetryConfig{
|
||||||
|
MaxRetries: 2,
|
||||||
|
InitialDelay: 1 * time.Millisecond,
|
||||||
|
MaxDelay: 5 * time.Millisecond,
|
||||||
|
JitterFactor: 0.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnection("ws://test", config, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
dialErr := fmt.Errorf("dial error")
|
||||||
|
mockDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
return nil, nil, dialErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.dialer = mockDialer
|
||||||
|
|
||||||
|
err = conn.Connect()
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
|
expected := []expectedLog{
|
||||||
|
{slog.LevelInfo, "connecting", map[string]any{}},
|
||||||
|
{slog.LevelInfo, "dialing", map[string]any{"attempt": 1}},
|
||||||
|
{slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}},
|
||||||
|
{slog.LevelInfo, "dialing", map[string]any{"attempt": 2}},
|
||||||
|
{slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}},
|
||||||
|
{slog.LevelInfo, "dialing", map[string]any{"attempt": 3}},
|
||||||
|
{slog.LevelError, "dial failed, max retries reached", map[string]any{"attempt": 3, "error": dialErr}},
|
||||||
|
{slog.LevelError, "connection failed", map[string]any{"error": dialErr}},
|
||||||
|
}
|
||||||
|
|
||||||
|
assertLogSequence(t, records, expected)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("success after retry", func(t *testing.T) {
|
||||||
|
mockHandler := newMockSlogHandler()
|
||||||
|
logger := slog.New(mockHandler)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
Retry: &RetryConfig{
|
||||||
|
MaxRetries: 3,
|
||||||
|
InitialDelay: 1 * time.Millisecond,
|
||||||
|
MaxDelay: 5 * time.Millisecond,
|
||||||
|
JitterFactor: 0.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnection("ws://test", config, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
attemptCount := 0
|
||||||
|
dialErr := fmt.Errorf("dial error")
|
||||||
|
mockDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
attemptCount++
|
||||||
|
if attemptCount < 3 {
|
||||||
|
return nil, nil, dialErr
|
||||||
|
}
|
||||||
|
return NewMockSocket(), nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.dialer = mockDialer
|
||||||
|
|
||||||
|
err = conn.Connect()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
|
expected := []expectedLog{
|
||||||
|
{slog.LevelInfo, "connecting", map[string]any{}},
|
||||||
|
{slog.LevelInfo, "dialing", map[string]any{"attempt": 1}},
|
||||||
|
{slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 1, "error": dialErr}},
|
||||||
|
{slog.LevelInfo, "dialing", map[string]any{"attempt": 2}},
|
||||||
|
{slog.LevelWarn, "dial failed, retrying", map[string]any{"attempt": 2, "error": dialErr}},
|
||||||
|
{slog.LevelInfo, "dialing", map[string]any{"attempt": 3}},
|
||||||
|
{slog.LevelInfo, "dial successful", map[string]any{"attempt": 3}},
|
||||||
|
{slog.LevelInfo, "connected", map[string]any{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
assertLogSequence(t, records, expected)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseLogging(t *testing.T) {
|
||||||
|
t.Run("normal close", func(t *testing.T) {
|
||||||
|
mockHandler := newMockSlogHandler()
|
||||||
|
logger := slog.New(mockHandler)
|
||||||
|
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
|
expected := []expectedLog{
|
||||||
|
{slog.LevelInfo, "closing", map[string]any{"state": "connected"}},
|
||||||
|
{slog.LevelInfo, "closed", map[string]any{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
assertLogSequence(t, records, expected)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("close with socket error", func(t *testing.T) {
|
||||||
|
mockHandler := newMockSlogHandler()
|
||||||
|
logger := slog.New(mockHandler)
|
||||||
|
|
||||||
|
closeErr := fmt.Errorf("close error")
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
mockSocket.CloseFunc = func() error {
|
||||||
|
return closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
|
expected := []expectedLog{
|
||||||
|
{slog.LevelInfo, "closing", map[string]any{"state": "connected"}},
|
||||||
|
{slog.LevelError, "socket close failed", map[string]any{"error": closeErr}},
|
||||||
|
}
|
||||||
|
|
||||||
|
assertLogSequence(t, records, expected)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReaderLogging(t *testing.T) {
|
||||||
|
t.Run("read deadline error", func(t *testing.T) {
|
||||||
|
mockHandler := newMockSlogHandler()
|
||||||
|
logger := slog.New(mockHandler)
|
||||||
|
|
||||||
|
config := &Config{ReadTimeout: 1 * time.Millisecond}
|
||||||
|
|
||||||
|
deadlineErr := fmt.Errorf("deadline error")
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
mockSocket.SetReadDeadlineFunc = func(time.Time) error {
|
||||||
|
return deadlineErr
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, config, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
|
assertLogRecord(t, records, slog.LevelError, "read deadline error")
|
||||||
|
|
||||||
|
record := findLogRecord(records, slog.LevelError, "read deadline error")
|
||||||
|
assert.NotNil(t, record)
|
||||||
|
assertAttributePresent(t, *record, "error", deadlineErr)
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("read message error", func(t *testing.T) {
|
||||||
|
mockHandler := newMockSlogHandler()
|
||||||
|
logger := slog.New(mockHandler)
|
||||||
|
|
||||||
|
readErr := fmt.Errorf("read error")
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
||||||
|
return 0, nil, readErr
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
|
assertLogRecord(t, records, slog.LevelError, "read error")
|
||||||
|
|
||||||
|
record := findLogRecord(records, slog.LevelError, "read error")
|
||||||
|
assert.NotNil(t, record)
|
||||||
|
assertAttributePresent(t, *record, "error", readErr)
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriterLogging(t *testing.T) {
|
||||||
|
t.Run("write deadline error", func(t *testing.T) {
|
||||||
|
mockHandler := newMockSlogHandler()
|
||||||
|
logger := slog.New(mockHandler)
|
||||||
|
|
||||||
|
config := &Config{WriteTimeout: 1 * time.Millisecond}
|
||||||
|
|
||||||
|
deadlineErr := fmt.Errorf("deadline error")
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
mockSocket.SetWriteDeadlineFunc = func(time.Time) error {
|
||||||
|
return deadlineErr
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, config, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Send([]byte("test"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
|
assertLogRecord(t, records, slog.LevelError, "write deadline error")
|
||||||
|
|
||||||
|
record := findLogRecord(records, slog.LevelError, "write deadline error")
|
||||||
|
assert.NotNil(t, record)
|
||||||
|
assertAttributePresent(t, *record, "error", deadlineErr)
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("write message error", func(t *testing.T) {
|
||||||
|
mockHandler := newMockSlogHandler()
|
||||||
|
logger := slog.New(mockHandler)
|
||||||
|
|
||||||
|
writeErr := fmt.Errorf("write error")
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
mockSocket.WriteMessageFunc = func(int, []byte) error {
|
||||||
|
return writeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnectionFromSocket(mockSocket, nil, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Send([]byte("test"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
records := mockHandler.GetRecords()
|
||||||
|
|
||||||
|
assertLogRecord(t, records, slog.LevelError, "write error")
|
||||||
|
|
||||||
|
record := findLogRecord(records, slog.LevelError, "write error")
|
||||||
|
assert.NotNil(t, record)
|
||||||
|
assertAttributePresent(t, *record, "error", writeErr)
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggingDisabled(t *testing.T) {
|
||||||
|
t.Run("nil logger produces no logs", func(t *testing.T) {
|
||||||
|
mockHandler := newMockSlogHandler()
|
||||||
|
|
||||||
|
conn, err := NewConnection("ws://test", nil, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mockSocket := NewMockSocket()
|
||||||
|
mockDialer := &MockDialer{
|
||||||
|
DialFunc: func(string, http.Header) (Socket, *http.Response, error) {
|
||||||
|
return mockSocket, nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.dialer = mockDialer
|
||||||
|
|
||||||
|
err = conn.Connect()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
records := mockHandler.GetRecords()
|
||||||
|
assert.Empty(t, records)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
package honeybee
|
package honeybee
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -125,8 +127,54 @@ func setupTestConnection(t *testing.T, config *Config) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
conn, err = NewConnectionFromSocket(mockSocket, config)
|
conn, err = NewConnectionFromSocket(mockSocket, config, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
return conn, mockSocket, incomingData, outgoingData
|
return conn, mockSocket, incomingData, outgoingData
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Logging mocks
|
||||||
|
|
||||||
|
type mockSlogHandler struct {
|
||||||
|
records []slog.Record
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockSlogHandler() *mockSlogHandler {
|
||||||
|
return &mockSlogHandler{
|
||||||
|
records: make([]slog.Record, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSlogHandler) Handle(ctx context.Context, record slog.Record) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.records = append(m.records, record)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSlogHandler) Enabled(ctx context.Context, level slog.Level) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSlogHandler) WithGroup(name string) slog.Handler {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSlogHandler) GetRecords() []slog.Record {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
result := make([]slog.Record, len(m.records))
|
||||||
|
copy(result, m.records)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSlogHandler) Clear() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.records = make([]slog.Record, 0)
|
||||||
|
}
|
||||||
|
|||||||
22
socket.go
22
socket.go
@@ -1,6 +1,7 @@
|
|||||||
package honeybee
|
package honeybee
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -54,6 +55,7 @@ func AcquireSocket(
|
|||||||
retryMgr *RetryManager,
|
retryMgr *RetryManager,
|
||||||
dialer Dialer,
|
dialer Dialer,
|
||||||
urlStr string,
|
urlStr string,
|
||||||
|
logger *slog.Logger,
|
||||||
) (Socket, *http.Response, error) {
|
) (Socket, *http.Response, error) {
|
||||||
if retryMgr == nil {
|
if retryMgr == nil {
|
||||||
return nil, nil, errors.NewConnectionError("retry manager cannot be nil")
|
return nil, nil, errors.NewConnectionError("retry manager cannot be nil")
|
||||||
@@ -66,16 +68,36 @@ func AcquireSocket(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Info("dialing", "attempt", retryMgr.RetryCount()+1)
|
||||||
|
}
|
||||||
|
|
||||||
socket, resp, err := dialer.Dial(urlStr, nil)
|
socket, resp, err := dialer.Dial(urlStr, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Info("dial successful", "attempt", retryMgr.RetryCount()+1)
|
||||||
|
}
|
||||||
return socket, resp, nil
|
return socket, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !retryMgr.ShouldRetry() {
|
if !retryMgr.ShouldRetry() {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Error("dial failed, max retries reached",
|
||||||
|
"error", err,
|
||||||
|
"attempt", retryMgr.RetryCount()+1)
|
||||||
|
}
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
delay := retryMgr.CalculateDelay()
|
delay := retryMgr.CalculateDelay()
|
||||||
|
|
||||||
|
if logger != nil {
|
||||||
|
logger.Warn("dial failed, retrying",
|
||||||
|
"error", err,
|
||||||
|
"attempt", retryMgr.RetryCount()+1,
|
||||||
|
"next_delay", delay)
|
||||||
|
}
|
||||||
|
|
||||||
time.Sleep(delay)
|
time.Sleep(delay)
|
||||||
retryMgr.RecordRetry()
|
retryMgr.RecordRetry()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func TestAcquireSocket(t *testing.T) {
|
|||||||
JitterFactor: 0.0,
|
JitterFactor: 0.0,
|
||||||
})
|
})
|
||||||
|
|
||||||
socket, _, err := AcquireSocket(retryMgr, mockDialer, "ws://test")
|
socket, _, err := AcquireSocket(retryMgr, mockDialer, "ws://test", nil)
|
||||||
|
|
||||||
assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
|
assert.Equal(t, tc.wantRetryCount, retryMgr.RetryCount())
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
@@ -132,7 +132,7 @@ func TestAcquireSocketGuards(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
socket, resp, err := AcquireSocket(tc.retryMgr, tc.dialer, tc.url)
|
socket, resp, err := AcquireSocket(tc.retryMgr, tc.dialer, tc.url, nil)
|
||||||
|
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.ErrorContains(t, err, tc.wantErr)
|
assert.ErrorContains(t, err, tc.wantErr)
|
||||||
|
|||||||
Reference in New Issue
Block a user