Removed wg from worker start. Fixed various races.

This commit is contained in:
Jay
2026-04-20 17:54:03 -04:00
parent aaf8571b9f
commit 2c7b22b3d9
8 changed files with 72 additions and 43 deletions
+2 -1
View File
@@ -278,7 +278,8 @@ func (p *Pool) addLocked(id string, socket types.Socket) error {
go func() {
defer cancel()
defer close(peer.done)
worker.Start(pool, &p.wg)
worker.Start(pool)
p.wg.Done()
}()
p.peers[id] = peer
+8 -9
View File
@@ -10,7 +10,7 @@ import (
)
type Worker interface {
Start(pool PoolPlugin, wg *sync.WaitGroup)
Start(pool PoolPlugin)
Stop()
Send(data []byte) error
}
@@ -61,29 +61,28 @@ func NewWorker(
}, nil
}
func (w *DefaultWorker) Start(pool PoolPlugin, wg *sync.WaitGroup) {
func (w *DefaultWorker) Start(pool PoolPlugin) {
messages := make(chan ReceivedMessage, 256)
var owg sync.WaitGroup
owg.Add(3)
var wg sync.WaitGroup
wg.Add(3)
go func() {
defer owg.Done()
defer wg.Done()
RunReader(w.ctx, pool.OnExit, w.conn, messages, w.heartbeat)
}()
go func() {
defer owg.Done()
defer wg.Done()
RunForwarder(w.id, w.ctx, messages, pool.Inbox, w.config.MaxQueueSize)
}()
go func() {
defer owg.Done()
defer wg.Done()
RunWatchdog(w.ctx, pool.OnExit, w.heartbeat, w.config.InactivityTimeout)
}()
owg.Wait()
wg.Done()
wg.Wait()
}
func (w *DefaultWorker) Stop() {
+10 -7
View File
@@ -70,7 +70,7 @@ func TestWorkerStart(t *testing.T) {
v := setupWorkerTest(t)
defer v.worker.Stop()
go v.worker.Start(v.pool, v.wg)
go v.worker.Start(v.pool)
v.incoming <- honeybeetest.MockIncomingData{
MsgType: websocket.TextMessage,
@@ -91,7 +91,7 @@ func TestWorkerStart(t *testing.T) {
v := setupWorkerTest(t)
defer v.worker.Stop()
go v.worker.Start(v.pool, v.wg)
go v.worker.Start(v.pool)
v.incoming <- honeybeetest.MockIncomingData{
Err: &websocket.CloseError{Code: websocket.CloseNormalClosure},
@@ -107,7 +107,7 @@ func TestWorkerStart(t *testing.T) {
v := setupWorkerTest(t)
defer v.worker.Stop()
go v.worker.Start(v.pool, v.wg)
go v.worker.Start(v.pool)
v.incoming <- honeybeetest.MockIncomingData{
Err: &websocket.CloseError{Code: websocket.CloseProtocolError},
@@ -142,7 +142,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go worker.Start(pool, &wg)
go func() {
worker.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool {
val := exitKind.Load()
@@ -154,7 +157,7 @@ func TestWorkerStart(t *testing.T) {
func TestWorkerStop(t *testing.T) {
v := setupWorkerTest(t)
go v.worker.Start(v.pool, v.wg)
go func() { v.worker.Start(v.pool); v.wg.Done() }()
v.worker.Stop()
@@ -179,7 +182,7 @@ func TestWorkerSend(t *testing.T) {
v := setupWorkerTest(t)
defer v.worker.Stop()
go v.worker.Start(v.pool, v.wg)
go v.worker.Start(v.pool)
err := v.worker.Send([]byte("hello"))
assert.NoError(t, err)
@@ -214,7 +217,7 @@ func TestWorkerSend(t *testing.T) {
v := setupWorkerTest(t)
defer v.worker.Stop()
go v.worker.Start(v.pool, v.wg)
go v.worker.Start(v.pool)
v.conn.Close()
+4 -1
View File
@@ -189,7 +189,10 @@ func (p *Pool) Connect(id string) error {
}
p.wg.Add(1)
go worker.Start(pool, &p.wg)
go func() {
worker.Start(pool)
p.wg.Done()
}()
p.peers[id] = &Peer{id: id, worker: worker}
+9 -13
View File
@@ -12,7 +12,7 @@ import (
// Worker
type Worker interface {
Start(pool PoolPlugin, wg *sync.WaitGroup)
Start(pool PoolPlugin)
Stop()
Send(data []byte) error
}
@@ -56,35 +56,32 @@ func NewWorker(
return w, nil
}
func (w *DefaultWorker) Start(
pool PoolPlugin,
wg *sync.WaitGroup,
) {
func (w *DefaultWorker) Start(pool PoolPlugin) {
dial := make(chan struct{}, 1)
newConn := make(chan *transport.Connection, 1)
messages := make(chan ReceivedMessage, 256)
keepalive := make(chan struct{}, 1)
var owg sync.WaitGroup
owg.Add(4)
var wg sync.WaitGroup
wg.Add(4)
go func() {
defer owg.Done()
defer wg.Done()
RunDialer(w.id, w.ctx, pool, dial, newConn)
}()
go func() {
defer owg.Done()
defer wg.Done()
RunKeepalive(w.ctx, w.heartbeat, keepalive, w.config.KeepaliveTimeout)
}()
go func() {
defer owg.Done()
defer wg.Done()
RunForwarder(w.id, w.ctx, messages, pool.Inbox, w.config.MaxQueueSize)
}()
go func() {
defer owg.Done()
defer wg.Done()
session := &Session{
id: w.id,
connPtr: &w.conn,
@@ -97,8 +94,7 @@ func (w *DefaultWorker) Start(
session.Start(w.ctx, pool)
}()
owg.Wait()
wg.Done()
wg.Wait()
}
func (w *DefaultWorker) Stop() {
+4 -2
View File
@@ -99,8 +99,10 @@ func TestRunDialer(t *testing.T) {
}
}, "expected new connection")
// connection was only dialed once
assert.Equal(t, int32(1), dialCount.Load())
// number of dials < number of dial requests
honeybeetest.Never(t, func() bool {
return dialCount.Load() >= 5
}, "expected fewer dials than requests")
// dial channel still writable
select {
+7 -3
View File
@@ -38,8 +38,10 @@ func TestWorkerSend(t *testing.T) {
err := w.Send(testData)
assert.NoError(t, err)
// one heartbeat was sent
assert.Equal(t, 1, int(heartbeatCount.Load()))
// at least one heartbeat was sent
honeybeetest.Eventually(t, func() bool {
return heartbeatCount.Load() >= 1
}, "expected heartbeats")
// message was sent by the socket
honeybeetest.Eventually(t, func() bool {
@@ -82,7 +84,9 @@ func TestWorkerSend(t *testing.T) {
assert.NoError(t, err)
}
assert.Equal(t, count, int(heartbeatCount.Load()))
honeybeetest.Eventually(t, func() bool {
return heartbeatCount.Load() == count
}, "expected heartbeats")
})
t.Run("returns error if connection is unavailable", func(t *testing.T) {
+28 -7
View File
@@ -62,7 +62,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go w.Start(pool, &wg)
go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool {
select {
@@ -85,7 +88,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go w.Start(pool, &wg)
go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool {
select {
@@ -135,7 +141,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go w.Start(pool, &wg)
go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool {
select {
@@ -172,7 +181,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go w.Start(pool, &wg)
go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool {
select {
@@ -215,7 +227,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go w.Start(pool, &wg)
go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool {
select {
@@ -260,7 +275,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go w.Start(pool, &wg)
go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool {
select {
@@ -302,7 +320,10 @@ func TestWorkerStart(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go w.Start(pool, &wg)
go func() {
w.Start(pool)
wg.Done()
}()
honeybeetest.Eventually(t, func() bool {
select {