Removed wg from worker start. Fixed various races.

This commit is contained in:
Jay
2026-04-20 17:54:03 -04:00
parent aaf8571b9f
commit 2c7b22b3d9
8 changed files with 72 additions and 43 deletions
+2 -1
View File
@@ -278,7 +278,8 @@ func (p *Pool) addLocked(id string, socket types.Socket) error {
go func() { go func() {
defer cancel() defer cancel()
defer close(peer.done) defer close(peer.done)
worker.Start(pool, &p.wg) worker.Start(pool)
p.wg.Done()
}() }()
p.peers[id] = peer p.peers[id] = peer
+8 -9
View File
@@ -10,7 +10,7 @@ import (
) )
type Worker interface { type Worker interface {
Start(pool PoolPlugin, wg *sync.WaitGroup) Start(pool PoolPlugin)
Stop() Stop()
Send(data []byte) error Send(data []byte) error
} }
@@ -61,29 +61,28 @@ func NewWorker(
}, nil }, nil
} }
func (w *DefaultWorker) Start(pool PoolPlugin, wg *sync.WaitGroup) { func (w *DefaultWorker) Start(pool PoolPlugin) {
messages := make(chan ReceivedMessage, 256) messages := make(chan ReceivedMessage, 256)
var owg sync.WaitGroup var wg sync.WaitGroup
owg.Add(3) wg.Add(3)
go func() { go func() {
defer owg.Done() defer wg.Done()
RunReader(w.ctx, pool.OnExit, w.conn, messages, w.heartbeat) RunReader(w.ctx, pool.OnExit, w.conn, messages, w.heartbeat)
}() }()
go func() { go func() {
defer owg.Done() defer wg.Done()
RunForwarder(w.id, w.ctx, messages, pool.Inbox, w.config.MaxQueueSize) RunForwarder(w.id, w.ctx, messages, pool.Inbox, w.config.MaxQueueSize)
}() }()
go func() { go func() {
defer owg.Done() defer wg.Done()
RunWatchdog(w.ctx, pool.OnExit, w.heartbeat, w.config.InactivityTimeout) RunWatchdog(w.ctx, pool.OnExit, w.heartbeat, w.config.InactivityTimeout)
}() }()
owg.Wait() wg.Wait()
wg.Done()
} }
func (w *DefaultWorker) Stop() { func (w *DefaultWorker) Stop() {
+10 -7
View File
@@ -70,7 +70,7 @@ func TestWorkerStart(t *testing.T) {
v := setupWorkerTest(t) v := setupWorkerTest(t)
defer v.worker.Stop() defer v.worker.Stop()
go v.worker.Start(v.pool, v.wg) go v.worker.Start(v.pool)
v.incoming <- honeybeetest.MockIncomingData{ v.incoming <- honeybeetest.MockIncomingData{
MsgType: websocket.TextMessage, MsgType: websocket.TextMessage,
@@ -91,7 +91,7 @@ func TestWorkerStart(t *testing.T) {
v := setupWorkerTest(t) v := setupWorkerTest(t)
defer v.worker.Stop() defer v.worker.Stop()
go v.worker.Start(v.pool, v.wg) go v.worker.Start(v.pool)
v.incoming <- honeybeetest.MockIncomingData{ v.incoming <- honeybeetest.MockIncomingData{
Err: &websocket.CloseError{Code: websocket.CloseNormalClosure}, Err: &websocket.CloseError{Code: websocket.CloseNormalClosure},
@@ -107,7 +107,7 @@ func TestWorkerStart(t *testing.T) {
v := setupWorkerTest(t) v := setupWorkerTest(t)
defer v.worker.Stop() defer v.worker.Stop()
go v.worker.Start(v.pool, v.wg) go v.worker.Start(v.pool)
v.incoming <- honeybeetest.MockIncomingData{ v.incoming <- honeybeetest.MockIncomingData{
Err: &websocket.CloseError{Code: websocket.CloseProtocolError}, Err: &websocket.CloseError{Code: websocket.CloseProtocolError},
@@ -142,7 +142,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go worker.Start(pool, &wg) go func() {
worker.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
val := exitKind.Load() val := exitKind.Load()
@@ -154,7 +157,7 @@ func TestWorkerStart(t *testing.T) {
func TestWorkerStop(t *testing.T) { func TestWorkerStop(t *testing.T) {
v := setupWorkerTest(t) v := setupWorkerTest(t)
go v.worker.Start(v.pool, v.wg) go func() { v.worker.Start(v.pool); v.wg.Done() }()
v.worker.Stop() v.worker.Stop()
@@ -179,7 +182,7 @@ func TestWorkerSend(t *testing.T) {
v := setupWorkerTest(t) v := setupWorkerTest(t)
defer v.worker.Stop() defer v.worker.Stop()
go v.worker.Start(v.pool, v.wg) go v.worker.Start(v.pool)
err := v.worker.Send([]byte("hello")) err := v.worker.Send([]byte("hello"))
assert.NoError(t, err) assert.NoError(t, err)
@@ -214,7 +217,7 @@ func TestWorkerSend(t *testing.T) {
v := setupWorkerTest(t) v := setupWorkerTest(t)
defer v.worker.Stop() defer v.worker.Stop()
go v.worker.Start(v.pool, v.wg) go v.worker.Start(v.pool)
v.conn.Close() v.conn.Close()
+4 -1
View File
@@ -189,7 +189,10 @@ func (p *Pool) Connect(id string) error {
} }
p.wg.Add(1) p.wg.Add(1)
go worker.Start(pool, &p.wg) go func() {
worker.Start(pool)
p.wg.Done()
}()
p.peers[id] = &Peer{id: id, worker: worker} p.peers[id] = &Peer{id: id, worker: worker}
+9 -13
View File
@@ -12,7 +12,7 @@ import (
// Worker // Worker
type Worker interface { type Worker interface {
Start(pool PoolPlugin, wg *sync.WaitGroup) Start(pool PoolPlugin)
Stop() Stop()
Send(data []byte) error Send(data []byte) error
} }
@@ -56,35 +56,32 @@ func NewWorker(
return w, nil return w, nil
} }
func (w *DefaultWorker) Start( func (w *DefaultWorker) Start(pool PoolPlugin) {
pool PoolPlugin,
wg *sync.WaitGroup,
) {
dial := make(chan struct{}, 1) dial := make(chan struct{}, 1)
newConn := make(chan *transport.Connection, 1) newConn := make(chan *transport.Connection, 1)
messages := make(chan ReceivedMessage, 256) messages := make(chan ReceivedMessage, 256)
keepalive := make(chan struct{}, 1) keepalive := make(chan struct{}, 1)
var owg sync.WaitGroup var wg sync.WaitGroup
owg.Add(4) wg.Add(4)
go func() { go func() {
defer owg.Done() defer wg.Done()
RunDialer(w.id, w.ctx, pool, dial, newConn) RunDialer(w.id, w.ctx, pool, dial, newConn)
}() }()
go func() { go func() {
defer owg.Done() defer wg.Done()
RunKeepalive(w.ctx, w.heartbeat, keepalive, w.config.KeepaliveTimeout) RunKeepalive(w.ctx, w.heartbeat, keepalive, w.config.KeepaliveTimeout)
}() }()
go func() { go func() {
defer owg.Done() defer wg.Done()
RunForwarder(w.id, w.ctx, messages, pool.Inbox, w.config.MaxQueueSize) RunForwarder(w.id, w.ctx, messages, pool.Inbox, w.config.MaxQueueSize)
}() }()
go func() { go func() {
defer owg.Done() defer wg.Done()
session := &Session{ session := &Session{
id: w.id, id: w.id,
connPtr: &w.conn, connPtr: &w.conn,
@@ -97,8 +94,7 @@ func (w *DefaultWorker) Start(
session.Start(w.ctx, pool) session.Start(w.ctx, pool)
}() }()
owg.Wait() wg.Wait()
wg.Done()
} }
func (w *DefaultWorker) Stop() { func (w *DefaultWorker) Stop() {
+4 -2
View File
@@ -99,8 +99,10 @@ func TestRunDialer(t *testing.T) {
} }
}, "expected new connection") }, "expected new connection")
// connection was only dialed once // number of dials < number of dial requests
assert.Equal(t, int32(1), dialCount.Load()) honeybeetest.Never(t, func() bool {
return dialCount.Load() >= 5
}, "expected fewer dials than requests")
// dial channel still writable // dial channel still writable
select { select {
+7 -3
View File
@@ -38,8 +38,10 @@ func TestWorkerSend(t *testing.T) {
err := w.Send(testData) err := w.Send(testData)
assert.NoError(t, err) assert.NoError(t, err)
// one heartbeat was sent // at least one heartbeat was sent
assert.Equal(t, 1, int(heartbeatCount.Load())) honeybeetest.Eventually(t, func() bool {
return heartbeatCount.Load() >= 1
}, "expected heartbeats")
// message was sent by the socket // message was sent by the socket
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
@@ -82,7 +84,9 @@ func TestWorkerSend(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
assert.Equal(t, count, int(heartbeatCount.Load())) honeybeetest.Eventually(t, func() bool {
return heartbeatCount.Load() == count
}, "expected heartbeats")
}) })
t.Run("returns error if connection is unavailable", func(t *testing.T) { t.Run("returns error if connection is unavailable", func(t *testing.T) {
+28 -7
View File
@@ -62,7 +62,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go w.Start(pool, &wg) go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
select { select {
@@ -85,7 +88,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go w.Start(pool, &wg) go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
select { select {
@@ -135,7 +141,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go w.Start(pool, &wg) go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
select { select {
@@ -172,7 +181,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go w.Start(pool, &wg) go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
select { select {
@@ -215,7 +227,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go w.Start(pool, &wg) go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
select { select {
@@ -260,7 +275,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go w.Start(pool, &wg) go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
select { select {
@@ -302,7 +320,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go w.Start(pool, &wg) go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool { honeybeetest.Eventually(t, func() bool {
select { select {