Remove read deadlines from the connection.
Prevents the connection from closing from a lack of incoming messages.
This commit is contained in:
13
config.go
13
config.go
@@ -9,7 +9,6 @@ type CloseHandler func(code int, text string) error
|
|||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
CloseHandler CloseHandler
|
CloseHandler CloseHandler
|
||||||
ReadTimeout time.Duration
|
|
||||||
WriteTimeout time.Duration
|
WriteTimeout time.Duration
|
||||||
Retry *RetryConfig
|
Retry *RetryConfig
|
||||||
}
|
}
|
||||||
@@ -37,7 +36,6 @@ func NewConfig(options ...ConfigOption) (*Config, error) {
|
|||||||
func GetDefaultConfig() *Config {
|
func GetDefaultConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
CloseHandler: nil,
|
CloseHandler: nil,
|
||||||
ReadTimeout: 30 * time.Second,
|
|
||||||
WriteTimeout: 30 * time.Second,
|
WriteTimeout: 30 * time.Second,
|
||||||
Retry: GetDefaultRetryConfig(),
|
Retry: GetDefaultRetryConfig(),
|
||||||
}
|
}
|
||||||
@@ -80,17 +78,6 @@ func WithCloseHandler(handler CloseHandler) ConfigOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// When ReadTimeout is set to zero, read timeouts are disabled.
|
|
||||||
func WithReadTimeout(value time.Duration) ConfigOption {
|
|
||||||
return func(c *Config) error {
|
|
||||||
if value < 0 {
|
|
||||||
return errors.InvalidReadTimeout
|
|
||||||
}
|
|
||||||
c.ReadTimeout = value
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// When WriteTimeout is set to zero, read timeouts are disabled.
|
// When WriteTimeout is set to zero, read timeouts are disabled.
|
||||||
func WithWriteTimeout(value time.Duration) ConfigOption {
|
func WithWriteTimeout(value time.Duration) ConfigOption {
|
||||||
return func(c *Config) error {
|
return func(c *Config) error {
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ func TestNewConfig(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, conf, &Config{
|
assert.Equal(t, conf, &Config{
|
||||||
CloseHandler: nil,
|
CloseHandler: nil,
|
||||||
ReadTimeout: 30 * time.Second,
|
|
||||||
WriteTimeout: 30 * time.Second,
|
WriteTimeout: 30 * time.Second,
|
||||||
Retry: GetDefaultRetryConfig(),
|
Retry: GetDefaultRetryConfig(),
|
||||||
})
|
})
|
||||||
@@ -35,7 +34,6 @@ func TestDefaultConfig(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, conf, &Config{
|
assert.Equal(t, conf, &Config{
|
||||||
CloseHandler: nil,
|
CloseHandler: nil,
|
||||||
ReadTimeout: 30 * time.Second,
|
|
||||||
WriteTimeout: 30 * time.Second,
|
WriteTimeout: 30 * time.Second,
|
||||||
Retry: GetDefaultRetryConfig(),
|
Retry: GetDefaultRetryConfig(),
|
||||||
})
|
})
|
||||||
@@ -87,28 +85,6 @@ func TestWithCloseHandler(t *testing.T) {
|
|||||||
assert.Nil(t, conf.CloseHandler(0, ""))
|
assert.Nil(t, conf.CloseHandler(0, ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWithReadTimeout(t *testing.T) {
|
|
||||||
conf := &Config{}
|
|
||||||
opt := WithReadTimeout(30)
|
|
||||||
err := SetConfig(conf, opt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, conf.ReadTimeout, time.Duration(30))
|
|
||||||
|
|
||||||
// zero allowed
|
|
||||||
conf = &Config{}
|
|
||||||
opt = WithReadTimeout(0)
|
|
||||||
err = SetConfig(conf, opt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, conf.ReadTimeout, time.Duration(0))
|
|
||||||
|
|
||||||
// negative disallowed
|
|
||||||
conf = &Config{}
|
|
||||||
opt = WithReadTimeout(-30)
|
|
||||||
err = SetConfig(conf, opt)
|
|
||||||
assert.ErrorIs(t, err, errors.InvalidReadTimeout)
|
|
||||||
assert.ErrorContains(t, err, "read timeout must be positive")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithWriteTimeout(t *testing.T) {
|
func TestWithWriteTimeout(t *testing.T) {
|
||||||
conf := &Config{}
|
conf := &Config{}
|
||||||
opt := WithWriteTimeout(30)
|
opt := WithWriteTimeout(30)
|
||||||
@@ -238,7 +214,6 @@ func TestValidateConfig(t *testing.T) {
|
|||||||
name: "valid complete",
|
name: "valid complete",
|
||||||
conf: Config{
|
conf: Config{
|
||||||
CloseHandler: (func(code int, text string) error { return nil }),
|
CloseHandler: (func(code int, text string) error { return nil }),
|
||||||
ReadTimeout: time.Duration(30),
|
|
||||||
WriteTimeout: time.Duration(30),
|
WriteTimeout: time.Duration(30),
|
||||||
Retry: &RetryConfig{
|
Retry: &RetryConfig{
|
||||||
MaxRetries: 0,
|
MaxRetries: 0,
|
||||||
|
|||||||
@@ -172,47 +172,29 @@ func (c *Connection) startReader() {
|
|||||||
defer c.wg.Done()
|
defer c.wg.Done()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
messageType, data, err := c.socket.ReadMessage()
|
||||||
case <-c.done:
|
if err != nil {
|
||||||
return
|
if c.logger != nil {
|
||||||
default:
|
c.logger.Error("read error", "error", err)
|
||||||
if c.config.ReadTimeout > 0 {
|
|
||||||
if err := c.socket.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil {
|
|
||||||
if c.logger != nil {
|
|
||||||
c.logger.Error("read deadline error", "error", err)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case c.errors <- fmt.Errorf("failed to set read deadline: %w", err):
|
|
||||||
case <-c.done:
|
|
||||||
}
|
|
||||||
c.shutdown()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
messageType, data, err := c.socket.ReadMessage()
|
select {
|
||||||
if err != nil {
|
case c.errors <- err:
|
||||||
if c.logger != nil {
|
case <-c.done:
|
||||||
c.logger.Error("read error", "error", err)
|
}
|
||||||
}
|
c.shutdown()
|
||||||
select {
|
return
|
||||||
case c.errors <- err:
|
}
|
||||||
case <-c.done:
|
|
||||||
}
|
if messageType == websocket.TextMessage ||
|
||||||
|
messageType == websocket.BinaryMessage {
|
||||||
|
select {
|
||||||
|
case c.incoming <- data:
|
||||||
|
case <-c.done:
|
||||||
c.shutdown()
|
c.shutdown()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if messageType == websocket.TextMessage ||
|
|
||||||
messageType == websocket.BinaryMessage {
|
|
||||||
select {
|
|
||||||
case c.incoming <- data:
|
|
||||||
case <-c.done:
|
|
||||||
c.shutdown()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
@@ -54,127 +54,6 @@ func TestStartReader(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("read timeout disabled when zero", func(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skipping test in short mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
config := &Config{ReadTimeout: 0}
|
|
||||||
|
|
||||||
mockSocket := NewMockSocket()
|
|
||||||
|
|
||||||
mockSocket.CloseFunc = func() error {
|
|
||||||
mockSocket.once.Do(func() {
|
|
||||||
close(mockSocket.closed)
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
deadlineCalled := make(chan struct{}, 1)
|
|
||||||
mockSocket.SetReadDeadlineFunc = func(t time.Time) error {
|
|
||||||
deadlineCalled <- struct{}{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
assert.Never(t, func() bool {
|
|
||||||
select {
|
|
||||||
case <-deadlineCalled:
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}, negativeTestTimeout, testTick,
|
|
||||||
"SetReadDeadline should not be called when timeout is zero")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("read timeout sets deadline when positive", func(t *testing.T) {
|
|
||||||
config := &Config{ReadTimeout: 30}
|
|
||||||
|
|
||||||
incomingData := make(chan mockIncomingData, 10)
|
|
||||||
mockSocket := NewMockSocket()
|
|
||||||
|
|
||||||
mockSocket.CloseFunc = func() error {
|
|
||||||
mockSocket.once.Do(func() {
|
|
||||||
close(mockSocket.closed)
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
deadlineCalled := make(chan struct{}, 1)
|
|
||||||
mockSocket.SetReadDeadlineFunc = func(t time.Time) error {
|
|
||||||
deadlineCalled <- struct{}{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mockSocket.ReadMessageFunc = func() (int, []byte, error) {
|
|
||||||
select {
|
|
||||||
case data := <-incomingData:
|
|
||||||
return data.msgType, data.data, data.err
|
|
||||||
case <-mockSocket.closed:
|
|
||||||
return 0, nil, io.EOF
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
incomingData <- mockIncomingData{msgType: websocket.TextMessage, data: []byte("test"), err: nil}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-conn.Incoming():
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
|
||||||
select {
|
|
||||||
case <-deadlineCalled:
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}, testTimeout, testTick,
|
|
||||||
"SetWriteDeadline should be called when timeout is positive")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("reader exits on deadline error", func(t *testing.T) {
|
|
||||||
config := &Config{ReadTimeout: 1 * time.Millisecond}
|
|
||||||
|
|
||||||
mockSocket := NewMockSocket()
|
|
||||||
|
|
||||||
mockSocket.CloseFunc = func() error {
|
|
||||||
mockSocket.once.Do(func() {
|
|
||||||
close(mockSocket.closed)
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mockSocket.SetReadDeadlineFunc = func(t time.Time) error {
|
|
||||||
return fmt.Errorf("test error")
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, config, nil)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
|
||||||
select {
|
|
||||||
case err := <-conn.Errors():
|
|
||||||
return err != nil &&
|
|
||||||
strings.Contains(err.Error(), "failed to set read deadline")
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}, testTimeout, testTick)
|
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
|
||||||
return conn.State() == StateClosed
|
|
||||||
}, testTimeout, testTick)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("reader exits on socket read error", func(t *testing.T) {
|
t.Run("reader exits on socket read error", func(t *testing.T) {
|
||||||
mockSocket := NewMockSocket()
|
mockSocket := NewMockSocket()
|
||||||
|
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func TestNewConnection(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "valid url, valid config",
|
name: "valid url, valid config",
|
||||||
url: "wss://relay.example.com:8080/path",
|
url: "wss://relay.example.com:8080/path",
|
||||||
config: &Config{ReadTimeout: 30 * time.Second},
|
config: &Config{WriteTimeout: 30 * time.Second},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid url",
|
name: "invalid url",
|
||||||
@@ -146,7 +146,7 @@ func TestNewConnectionFromSocket(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "valid socket with valid config",
|
name: "valid socket with valid config",
|
||||||
socket: NewMockSocket(),
|
socket: NewMockSocket(),
|
||||||
config: &Config{ReadTimeout: 30 * time.Second},
|
config: &Config{WriteTimeout: 30 * time.Second},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid config",
|
name: "invalid config",
|
||||||
|
|||||||
@@ -315,35 +315,6 @@ func TestCloseLogging(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestReaderLogging(t *testing.T) {
|
func TestReaderLogging(t *testing.T) {
|
||||||
t.Run("read deadline error", func(t *testing.T) {
|
|
||||||
mockHandler := newMockSlogHandler()
|
|
||||||
logger := slog.New(mockHandler)
|
|
||||||
|
|
||||||
config := &Config{ReadTimeout: 1 * time.Millisecond}
|
|
||||||
|
|
||||||
deadlineErr := fmt.Errorf("deadline error")
|
|
||||||
mockSocket := NewMockSocket()
|
|
||||||
mockSocket.SetReadDeadlineFunc = func(time.Time) error {
|
|
||||||
return deadlineErr
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := NewConnectionFromSocket(mockSocket, config, logger)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
|
||||||
return findLogRecord(
|
|
||||||
mockHandler.GetRecords(), slog.LevelError, "read deadline error") != nil
|
|
||||||
}, testTimeout, testTick)
|
|
||||||
|
|
||||||
records := mockHandler.GetRecords()
|
|
||||||
|
|
||||||
record := findLogRecord(records, slog.LevelError, "read deadline error")
|
|
||||||
assert.NotNil(t, record)
|
|
||||||
assertAttributePresent(t, *record, "error", deadlineErr)
|
|
||||||
|
|
||||||
conn.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("read message error", func(t *testing.T) {
|
t.Run("read message error", func(t *testing.T) {
|
||||||
mockHandler := newMockSlogHandler()
|
mockHandler := newMockSlogHandler()
|
||||||
logger := slog.New(mockHandler)
|
logger := slog.New(mockHandler)
|
||||||
|
|||||||
Reference in New Issue
Block a user