diff --git a/initiatorpool/errors.go b/initiatorpool/errors.go index 8dfb8b4..a20fcff 100644 --- a/initiatorpool/errors.go +++ b/initiatorpool/errors.go @@ -17,8 +17,8 @@ var ( ErrConnectionUnavailable = errors.New("connection unavailable") ) -func NewConfigError(text string) error { - return fmt.Errorf("configuration error: %s", text) +func NewConfigError(err error) error { + return fmt.Errorf("configuration error: %w", err) } func NewPoolError(err error) error { diff --git a/initiatorpool/worker_session_inner_test.go b/initiatorpool/worker_session_inner_test.go index f86fea1..bcb6f01 100644 --- a/initiatorpool/worker_session_inner_test.go +++ b/initiatorpool/worker_session_inner_test.go @@ -125,7 +125,7 @@ func TestRunReader(t *testing.T) { incomingData <- honeybeetest.MockIncomingData{Err: io.EOF} err := <-conn.Errors() - assert.Equal(t, io.EOF, err) + assert.ErrorIs(t, err, io.EOF) honeybeetest.Eventually(t, func() bool { return conn.State() == transport.StateClosed diff --git a/transport/config.go b/transport/config.go index 0ff5be2..13c3a9b 100644 --- a/transport/config.go +++ b/transport/config.go @@ -86,7 +86,7 @@ func ValidateConnectionConfig(config *ConnectionConfig) error { } if config.Retry.InitialDelay > config.Retry.MaxDelay { - return NewConfigError("initial delay may not exceed maximum delay") + return NewConfigError(InvalidDelays) } } diff --git a/transport/connection.go b/transport/connection.go index c2e0c75..41a5273 100644 --- a/transport/connection.go +++ b/transport/connection.go @@ -91,7 +91,7 @@ func NewConnectionFromSocket( socket types.Socket, config *ConnectionConfig, logger *slog.Logger, ) (*Connection, error) { if socket == nil { - return nil, NewConnectionError("socket cannot be nil") + return nil, NewConnectionError(ErrNilSocket) } if config == nil { @@ -128,11 +128,11 @@ func (c *Connection) Connect(ctx context.Context) error { defer c.mu.Unlock() if c.socket != nil { - return NewConnectionError("connection already has socket") + return NewConnectionError(ErrSocketExists) } if c.closed { - return NewConnectionError("connection is closed") + return NewConnectionError(ErrConnectionClosed) } if c.logger != nil { @@ -150,7 +150,7 @@ func (c *Connection) Connect(ctx context.Context) error { if c.logger != nil { c.logger.Error("connection failed", "error", err) } - return err + return NewConnectionError(err) } c.socket = socket @@ -217,7 +217,7 @@ func (c *Connection) shutdownSetClosed(wait bool) error { c.mu.Lock() if c.closed { c.mu.Unlock() - return ErrConnectionClosed + return NewConnectionError(ErrConnectionClosed) } c.closed = true c.state = StateClosed @@ -277,29 +277,37 @@ func (c *Connection) startReader() { default: messageType, data, err := c.socket.ReadMessage() if err != nil { - if c.logger != nil { - var closeErr *websocket.CloseError - if errors.As(err, &closeErr) { - switch closeErr.Code { - case websocket.CloseNormalClosure, websocket.CloseGoingAway: + var wrappedErr error + var closeErr *websocket.CloseError + if errors.As(err, &closeErr) { + switch closeErr.Code { + case websocket.CloseNormalClosure, websocket.CloseGoingAway: + if c.logger != nil { c.logger.Info("connection closed by peer", "code", closeErr.Code, "text", closeErr.Text, ) - default: + } + wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedClean, err) + default: + if c.logger != nil { c.logger.Error("unexpected close", "code", closeErr.Code, "text", closeErr.Text, ) } - } else { + wrappedErr = fmt.Errorf("%w: %w", ErrPeerClosedUnexpected, err) + } + } else { + if c.logger != nil { c.logger.Error("read error", "error", err) } + wrappedErr = fmt.Errorf("%w: %w", ErrReadError, err) } select { case <-c.done: - case c.errors <- err: + case c.errors <- wrappedErr: } return } @@ -316,7 +324,6 @@ func (c *Connection) startReader() { } } }() - } func (c *Connection) Send(data []byte) error { @@ -324,7 +331,7 @@ func (c *Connection) Send(data []byte) error { defer c.writeMu.Unlock() if c.closed { - return ErrConnectionClosed + return NewConnectionError(ErrConnectionClosed) } if c.config.WriteTimeout > 0 { @@ -333,7 +340,7 @@ func (c *Connection) Send(data []byte) error { c.logger.Error("write deadline error", "error", err) } c.shutdownExternal() - return fmt.Errorf("failed to set write deadline: %w", err) + return NewConnectionError(fmt.Errorf("%w: %w", ErrFailedWriteDeadline, err)) } } @@ -341,7 +348,7 @@ func (c *Connection) Send(data []byte) error { if c.logger != nil { c.logger.Error("write error", "error", err) } - return fmt.Errorf("%w: %w", ErrWriteFailed, err) + return NewConnectionError(fmt.Errorf("%w: %w", ErrWriteFailed, err)) } return nil diff --git a/transport/connection_close_test.go b/transport/connection_close_test.go index e6e5405..fbd2206 100644 --- a/transport/connection_close_test.go +++ b/transport/connection_close_test.go @@ -73,7 +73,7 @@ func TestDisconnectedConnectionClose(t *testing.T) { err = conn.Send([]byte("test")) assert.Error(t, err) - assert.ErrorContains(t, err, "connection closed") + assert.ErrorIs(t, err, ErrConnectionClosed) }) } diff --git a/transport/connection_send_test.go b/transport/connection_send_test.go index 7cf64f1..d556eec 100644 --- a/transport/connection_send_test.go +++ b/transport/connection_send_test.go @@ -213,7 +213,7 @@ func TestConnectionSend(t *testing.T) { defer conn.Close() err = conn.Send([]byte("test")) - assert.ErrorContains(t, err, "failed to set write deadline: test error") + assert.ErrorIs(t, err, ErrFailedWriteDeadline) honeybeetest.Eventually(t, func() bool { return conn.State() == StateClosed diff --git a/transport/connection_test.go b/transport/connection_test.go index 8c7f8e6..95cc1b9 100644 --- a/transport/connection_test.go +++ b/transport/connection_test.go @@ -3,9 +3,11 @@ package transport import ( "bytes" "context" + "errors" "fmt" "git.wisehodl.dev/jay/go-honeybee/honeybeetest" "git.wisehodl.dev/jay/go-honeybee/types" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "io" "net/http" @@ -241,7 +243,7 @@ func TestConnect(t *testing.T) { err = conn.Connect(context.Background()) assert.Error(t, err) - assert.ErrorContains(t, err, "already has socket") + assert.ErrorIs(t, err, ErrSocketExists) assert.Equal(t, StateDisconnected, conn.State()) }) @@ -253,7 +255,7 @@ func TestConnect(t *testing.T) { err = conn.Connect(context.Background()) assert.Error(t, err) - assert.ErrorContains(t, err, "connection is closed") + assert.ErrorIs(t, err, ErrConnectionClosed) assert.Equal(t, StateClosed, conn.State()) }) @@ -467,17 +469,72 @@ func TestConnectionIncoming(t *testing.T) { } func TestConnectionErrors(t *testing.T) { - conn, err := NewConnection("ws://test", nil, nil) - assert.NoError(t, err) + t.Run("clean close by peer", func(t *testing.T) { + mockSocket := honeybeetest.NewMockSocket() + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + return 0, nil, &websocket.CloseError{ + Code: websocket.CloseNormalClosure, + Text: "goodbye", + } + } - errors := conn.Errors() - assert.NotNil(t, errors) + conn, err := NewConnectionFromSocket(mockSocket, nil, nil) + assert.NoError(t, err) + defer conn.Close() - // send data through the channel to verify they are the same - testErr := fmt.Errorf("test error") - conn.errors <- testErr - received := <-errors - assert.Equal(t, testErr, received) + honeybeetest.Eventually(t, func() bool { + select { + case err := <-conn.Errors(): + return errors.Is(err, ErrPeerClosedClean) + default: + return false + } + }, "expected clean close error") + }) + + t.Run("unexpected close", func(t *testing.T) { + mockSocket := honeybeetest.NewMockSocket() + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + return 0, nil, &websocket.CloseError{ + Code: websocket.CloseProtocolError, + Text: "bad protocol", + } + } + + conn, err := NewConnectionFromSocket(mockSocket, nil, nil) + assert.NoError(t, err) + defer conn.Close() + + honeybeetest.Eventually(t, func() bool { + select { + case err := <-conn.Errors(): + return errors.Is(err, ErrPeerClosedUnexpected) + default: + return false + } + }, "expected unexpected close error") + }) + + t.Run("read error", func(t *testing.T) { + mockSocket := honeybeetest.NewMockSocket() + mockSocket.ReadMessageFunc = func() (int, []byte, error) { + return 0, nil, io.EOF + } + + conn, err := NewConnectionFromSocket(mockSocket, nil, nil) + assert.NoError(t, err) + defer conn.Close() + + honeybeetest.Eventually(t, func() bool { + select { + case err := <-conn.Errors(): + return errors.Is(err, ErrReadError) + default: + return false + } + }, "expected read error") + + }) } // Test helpers diff --git a/transport/errors.go b/transport/errors.go index b176568..84c1bd0 100644 --- a/transport/errors.go +++ b/transport/errors.go @@ -13,16 +13,28 @@ var ( InvalidRetryInitialDelay = errors.New("initial delay must be positive") InvalidRetryMaxDelay = errors.New("max delay must be positive") InvalidRetryJitterFactor = errors.New("jitter factor must be between 0.0 and 1.0") + InvalidDelays = errors.New("initial delay may not exceed maximum delay") + + // Socket Errors + ErrNilRetryManager = errors.New("retry manager cannot be nil") + ErrNilDialer = errors.New("dialer cannot be nil") + ErrEmptyURL = errors.New("URL cannot be empty") // Connection Errors - ErrConnectionClosed = errors.New("connection closed") - ErrWriteFailed = errors.New("write failed") + ErrConnectionClosed = errors.New("connection closed") + ErrWriteFailed = errors.New("write failed") + ErrNilSocket = errors.New("socket cannot be nil") + ErrSocketExists = errors.New("socket already exists") + ErrFailedWriteDeadline = errors.New("failed to set write deadline") + ErrPeerClosedClean = errors.New("peer closed connection cleanly") + ErrPeerClosedUnexpected = errors.New("peer closed connection unexpectedly") + ErrReadError = errors.New("read error") ) -func NewConfigError(text string) error { - return fmt.Errorf("configuration error: %s", text) +func NewConfigError(err error) error { + return fmt.Errorf("configuration error: %w", err) } -func NewConnectionError(text string) error { - return fmt.Errorf("connection error: %s", text) +func NewConnectionError(err error) error { + return fmt.Errorf("connection error: %w", err) } diff --git a/transport/socket.go b/transport/socket.go index e13debf..7eaf327 100644 --- a/transport/socket.go +++ b/transport/socket.go @@ -54,13 +54,13 @@ func AcquireSocket( } if retryMgr == nil { - return nil, nil, NewConnectionError("retry manager cannot be nil") + return nil, nil, NewConnectionError(ErrNilRetryManager) } if dialer == nil { - return nil, nil, NewConnectionError("dialer cannot be nil") + return nil, nil, NewConnectionError(ErrNilDialer) } if url == "" { - return nil, nil, NewConnectionError("URL cannot be empty") + return nil, nil, NewConnectionError(ErrEmptyURL) } for {