From 2c7b22b3d994cf7ddb5e35def3b24d5eb9a2a30d Mon Sep 17 00:00:00 2001 From: Jay Date: Mon, 20 Apr 2026 17:54:03 -0400 Subject: [PATCH] Removed wg from worker start. Fixed various races. --- inbound/pool.go | 3 ++- inbound/worker.go | 17 ++++++++--------- inbound/worker_test.go | 17 ++++++++++------- outbound/pool.go | 5 ++++- outbound/worker.go | 22 +++++++++------------ outbound/worker_dialer_test.go | 6 ++++-- outbound/worker_send_test.go | 10 +++++++--- outbound/worker_start_test.go | 35 +++++++++++++++++++++++++++------- 8 files changed, 72 insertions(+), 43 deletions(-) diff --git a/inbound/pool.go b/inbound/pool.go index 1596b56..efe9026 100644 --- a/inbound/pool.go +++ b/inbound/pool.go @@ -278,7 +278,8 @@ func (p *Pool) addLocked(id string, socket types.Socket) error { go func() { defer cancel() defer close(peer.done) - worker.Start(pool, &p.wg) + worker.Start(pool) + p.wg.Done() }() p.peers[id] = peer diff --git a/inbound/worker.go b/inbound/worker.go index 82d01be..d5c10fa 100644 --- a/inbound/worker.go +++ b/inbound/worker.go @@ -10,7 +10,7 @@ import ( ) type Worker interface { - Start(pool PoolPlugin, wg *sync.WaitGroup) + Start(pool PoolPlugin) Stop() Send(data []byte) error } @@ -61,29 +61,28 @@ func NewWorker( }, nil } -func (w *DefaultWorker) Start(pool PoolPlugin, wg *sync.WaitGroup) { +func (w *DefaultWorker) Start(pool PoolPlugin) { messages := make(chan ReceivedMessage, 256) - var owg sync.WaitGroup - owg.Add(3) + var wg sync.WaitGroup + wg.Add(3) go func() { - defer owg.Done() + defer wg.Done() RunReader(w.ctx, pool.OnExit, w.conn, messages, w.heartbeat) }() go func() { - defer owg.Done() + defer wg.Done() RunForwarder(w.id, w.ctx, messages, pool.Inbox, w.config.MaxQueueSize) }() go func() { - defer owg.Done() + defer wg.Done() RunWatchdog(w.ctx, pool.OnExit, w.heartbeat, w.config.InactivityTimeout) }() - owg.Wait() - wg.Done() + wg.Wait() } func (w *DefaultWorker) Stop() { diff --git a/inbound/worker_test.go b/inbound/worker_test.go index 040d10b..1fe4ef1 100644 --- a/inbound/worker_test.go +++ b/inbound/worker_test.go @@ -70,7 +70,7 @@ func TestWorkerStart(t *testing.T) { v := setupWorkerTest(t) defer v.worker.Stop() - go v.worker.Start(v.pool, v.wg) + go v.worker.Start(v.pool) v.incoming <- honeybeetest.MockIncomingData{ MsgType: websocket.TextMessage, @@ -91,7 +91,7 @@ func TestWorkerStart(t *testing.T) { v := setupWorkerTest(t) defer v.worker.Stop() - go v.worker.Start(v.pool, v.wg) + go v.worker.Start(v.pool) v.incoming <- honeybeetest.MockIncomingData{ Err: &websocket.CloseError{Code: websocket.CloseNormalClosure}, @@ -107,7 +107,7 @@ func TestWorkerStart(t *testing.T) { v := setupWorkerTest(t) defer v.worker.Stop() - go v.worker.Start(v.pool, v.wg) + go v.worker.Start(v.pool) v.incoming <- honeybeetest.MockIncomingData{ Err: &websocket.CloseError{Code: websocket.CloseProtocolError}, @@ -142,7 +142,10 @@ func TestWorkerStart(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - go worker.Start(pool, &wg) + go func() { + worker.Start(pool) + wg.Done() + }() honeybeetest.Eventually(t, func() bool { val := exitKind.Load() @@ -154,7 +157,7 @@ func TestWorkerStart(t *testing.T) { func TestWorkerStop(t *testing.T) { v := setupWorkerTest(t) - go v.worker.Start(v.pool, v.wg) + go func() { v.worker.Start(v.pool); v.wg.Done() }() v.worker.Stop() @@ -179,7 +182,7 @@ func TestWorkerSend(t *testing.T) { v := setupWorkerTest(t) defer v.worker.Stop() - go v.worker.Start(v.pool, v.wg) + go v.worker.Start(v.pool) err := v.worker.Send([]byte("hello")) assert.NoError(t, err) @@ -214,7 +217,7 @@ func TestWorkerSend(t *testing.T) { v := setupWorkerTest(t) defer v.worker.Stop() - go v.worker.Start(v.pool, v.wg) + go v.worker.Start(v.pool) v.conn.Close() diff --git a/outbound/pool.go b/outbound/pool.go index c057279..424273c 100644 --- a/outbound/pool.go +++ b/outbound/pool.go @@ -189,7 +189,10 @@ func (p *Pool) Connect(id string) error { } p.wg.Add(1) - go worker.Start(pool, &p.wg) + go func() { + worker.Start(pool) + p.wg.Done() + }() p.peers[id] = &Peer{id: id, worker: worker} diff --git a/outbound/worker.go b/outbound/worker.go index 99d0e49..b11bf3b 100644 --- a/outbound/worker.go +++ b/outbound/worker.go @@ -12,7 +12,7 @@ import ( // Worker type Worker interface { - Start(pool PoolPlugin, wg *sync.WaitGroup) + Start(pool PoolPlugin) Stop() Send(data []byte) error } @@ -56,35 +56,32 @@ func NewWorker( return w, nil } -func (w *DefaultWorker) Start( - pool PoolPlugin, - wg *sync.WaitGroup, -) { +func (w *DefaultWorker) Start(pool PoolPlugin) { dial := make(chan struct{}, 1) newConn := make(chan *transport.Connection, 1) messages := make(chan ReceivedMessage, 256) keepalive := make(chan struct{}, 1) - var owg sync.WaitGroup - owg.Add(4) + var wg sync.WaitGroup + wg.Add(4) go func() { - defer owg.Done() + defer wg.Done() RunDialer(w.id, w.ctx, pool, dial, newConn) }() go func() { - defer owg.Done() + defer wg.Done() RunKeepalive(w.ctx, w.heartbeat, keepalive, w.config.KeepaliveTimeout) }() go func() { - defer owg.Done() + defer wg.Done() RunForwarder(w.id, w.ctx, messages, pool.Inbox, w.config.MaxQueueSize) }() go func() { - defer owg.Done() + defer wg.Done() session := &Session{ id: w.id, connPtr: &w.conn, @@ -97,8 +94,7 @@ func (w *DefaultWorker) Start( session.Start(w.ctx, pool) }() - owg.Wait() - wg.Done() + wg.Wait() } func (w *DefaultWorker) Stop() { diff --git a/outbound/worker_dialer_test.go b/outbound/worker_dialer_test.go index 29e901c..b41351a 100644 --- a/outbound/worker_dialer_test.go +++ b/outbound/worker_dialer_test.go @@ -99,8 +99,10 @@ func TestRunDialer(t *testing.T) { } }, "expected new connection") - // connection was only dialed once - assert.Equal(t, int32(1), dialCount.Load()) + // number of dials < number of dial requests + honeybeetest.Never(t, func() bool { + return dialCount.Load() >= 5 + }, "expected fewer dials than requests") // dial channel still writable select { diff --git a/outbound/worker_send_test.go b/outbound/worker_send_test.go index 32587ff..5058b0b 100644 --- a/outbound/worker_send_test.go +++ b/outbound/worker_send_test.go @@ -38,8 +38,10 @@ func TestWorkerSend(t *testing.T) { err := w.Send(testData) assert.NoError(t, err) - // one heartbeat was sent - assert.Equal(t, 1, int(heartbeatCount.Load())) + // at least one heartbeat was sent + honeybeetest.Eventually(t, func() bool { + return heartbeatCount.Load() >= 1 + }, "expected heartbeats") // message was sent by the socket honeybeetest.Eventually(t, func() bool { @@ -82,7 +84,9 @@ func TestWorkerSend(t *testing.T) { assert.NoError(t, err) } - assert.Equal(t, count, int(heartbeatCount.Load())) + honeybeetest.Eventually(t, func() bool { + return heartbeatCount.Load() == count + }, "expected heartbeats") }) t.Run("returns error if connection is unavailable", func(t *testing.T) { diff --git a/outbound/worker_start_test.go b/outbound/worker_start_test.go index 0e86f38..e057d82 100644 --- a/outbound/worker_start_test.go +++ b/outbound/worker_start_test.go @@ -62,7 +62,10 @@ func TestWorkerStart(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - go w.Start(pool, &wg) + go func() { + w.Start(pool) + wg.Done() + }() honeybeetest.Eventually(t, func() bool { select { @@ -85,7 +88,10 @@ func TestWorkerStart(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - go w.Start(pool, &wg) + go func() { + w.Start(pool) + wg.Done() + }() honeybeetest.Eventually(t, func() bool { select { @@ -135,7 +141,10 @@ func TestWorkerStart(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - go w.Start(pool, &wg) + go func() { + w.Start(pool) + wg.Done() + }() honeybeetest.Eventually(t, func() bool { select { @@ -172,7 +181,10 @@ func TestWorkerStart(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - go w.Start(pool, &wg) + go func() { + w.Start(pool) + wg.Done() + }() honeybeetest.Eventually(t, func() bool { select { @@ -215,7 +227,10 @@ func TestWorkerStart(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - go w.Start(pool, &wg) + go func() { + w.Start(pool) + wg.Done() + }() honeybeetest.Eventually(t, func() bool { select { @@ -260,7 +275,10 @@ func TestWorkerStart(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - go w.Start(pool, &wg) + go func() { + w.Start(pool) + wg.Done() + }() honeybeetest.Eventually(t, func() bool { select { @@ -302,7 +320,10 @@ func TestWorkerStart(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - go w.Start(pool, &wg) + go func() { + w.Start(pool) + wg.Done() + }() honeybeetest.Eventually(t, func() bool { select {