diff --git a/courier_test.go b/courier_test.go index e8c4c49..2900da0 100644 --- a/courier_test.go +++ b/courier_test.go @@ -2,10 +2,12 @@ package prism import ( "context" + "fmt" "git.wisehodl.dev/jay/go-mana-component" "github.com/stretchr/testify/assert" + "sync/atomic" "testing" - // "time" + "time" ) // Helpers @@ -27,9 +29,9 @@ func newTestLetter(ctx context.Context, id uint64) OutboundLetter { func TestCourierSendsAfterConnect(t *testing.T) { ctx := component.MustNew(context.Background(), "prism", "test") - sent := make(chan []byte, 1) + var sendCount atomic.Uint32 sendFunc := func(data Envelope) error { - sent <- data + sendCount.Add(1) return nil } @@ -37,12 +39,12 @@ func TestCourierSendsAfterConnect(t *testing.T) { called := make(chan LetterOutcome, 1) c.Enqueue(newTestLetter(ctx, 1), func(o LetterOutcome) { called <- o }) - Never(t, func() bool { return len(sent) > 0 }, + Never(t, func() bool { return sendCount.Load() > 0 }, "should not have sent while disconnected") c.HandleConnect() - Eventually(t, func() bool { return len(sent) > 0 }, + Eventually(t, func() bool { return sendCount.Load() > 0 }, "should have sent after connect") var outcome LetterOutcome @@ -57,28 +59,193 @@ func TestCourierSendsAfterConnect(t *testing.T) { assert.Equal(t, uint64(1), outcome.LetterID) assert.Equal(t, "wss://test", outcome.PeerID) - assert.Equal(t, "sent", outcome.Kind.String()) + assert.Equal(t, OutcomeSent, outcome.Kind) assert.False(t, outcome.SentAt.IsZero()) assert.True(t, outcome.MissedAt.IsZero()) assert.Equal(t, 0, outcome.Retries) } -func TestCourierSequentialSends(t *testing.T) { +func TestCourierMultipleSends(t *testing.T) { + ctx := component.MustNew(context.Background(), "prism", "test") + var sendCount atomic.Uint32 + sendFunc := func(data Envelope) error { + sendCount.Add(1) + return nil + } + + c := NewCourier(ctx, sendFunc, nil) + c.HandleConnect() + + outcomes := make([]LetterOutcome, 0, 2) + called := make(chan LetterOutcome, 4) + c.Enqueue(newTestLetter(ctx, 1), func(o LetterOutcome) { called <- o }) + c.Enqueue(newTestLetter(ctx, 2), func(o LetterOutcome) { called <- o }) + + Eventually(t, func() bool { return sendCount.Load() == 2 }, + "should have sent letters") + + Eventually(t, func() bool { + select { + default: + return false + case o := <-called: + outcomes = append(outcomes, o) + return len(outcomes) == 2 + } + }, "should have returned 2 outcomes") + + // callbacks are called in goroutines and may arrive out of order + assert.Equal(t, OutcomeSent, outcomes[0].Kind) + assert.Equal(t, OutcomeSent, outcomes[1].Kind) } func TestCourierSkipsCancelledLetter(t *testing.T) { + ctx := component.MustNew(context.Background(), "prism", "test") + var sendCount atomic.Uint32 + sendFunc := func(data Envelope) error { + sendCount.Add(1) + return nil + } + + c := NewCourier(ctx, sendFunc, nil) + c.HandleConnect() + + l := newTestLetter(ctx, 1) + l.cancel() + + called := make(chan LetterOutcome, 1) + c.Enqueue(l, func(o LetterOutcome) { called <- o }) + + var outcome LetterOutcome + Eventually(t, func() bool { + select { + default: + return false + case outcome = <-called: + return true + } + }, "should have returned outcome") + + assert.Equal(t, OutcomeCancelled, outcome.Kind) } func TestCourierRetryOnFailure(t *testing.T) { + ctx := component.MustNew(context.Background(), "prism", "test") + var sendCount atomic.Uint32 + sendFunc := func(data Envelope) error { + sendCount.Add(1) + if sendCount.Load() < 3 { + return fmt.Errorf("transient failure") + } + return nil + } + + c := NewCourier(ctx, sendFunc, nil) + c.HandleConnect() + + called := make(chan LetterOutcome, 1) + c.Enqueue(newTestLetter(ctx, 1), func(o LetterOutcome) { called <- o }) + + Eventually(t, func() bool { return sendCount.Load() > 0 }, + "should send eventually") + + var outcome LetterOutcome + Eventually(t, func() bool { + select { + default: + return false + case outcome = <-called: + return true + } + }, "should have returned outcome") + + assert.Equal(t, OutcomeSent, outcome.Kind) + assert.Equal(t, 2, outcome.Retries) } func TestCourierPauseOnDisconnect(t *testing.T) { + ctx := component.MustNew(context.Background(), "prism", "test") + var sendCount atomic.Uint32 + var gate atomic.Bool + gate.Store(false) + sendFunc := func(data Envelope) error { + // gated send + if gate.Load() { + sendCount.Add(1) + return nil + } + + return fmt.Errorf("gate is closed") + } + + c := NewCourier(ctx, sendFunc, nil) + c.HandleConnect() + + // queue a letter + called := make(chan LetterOutcome, 1) + c.Enqueue(newTestLetter(ctx, 1), func(o LetterOutcome) { called <- o }) + + // manually wait for letters to queue + time.Sleep(100 * time.Millisecond) + + // manually wait for disconnect toggle + c.HandleDisconnect() + time.Sleep(100 * time.Millisecond) + + // open gate + gate.Store(true) + + // should never have sent in this time + Never(t, func() bool { return sendCount.Load() > 0 }, + "should not have sent while disconnected") + + // reconnect, gate is open, letter should send + c.HandleConnect() + Eventually(t, func() bool { return sendCount.Load() > 0 }, + "should have sent") } func TestCourierDrainOnClose(t *testing.T) { + ctx := component.MustNew(context.Background(), "prism", "test") + var sendCount atomic.Uint32 + sendFunc := func(data Envelope) error { + sendCount.Add(1) + return nil + } + + c := NewCourier(ctx, sendFunc, nil) + + // do not connect, queue some letters + outcomes := make([]LetterOutcome, 0, 2) + called := make(chan LetterOutcome, 4) + c.Enqueue(newTestLetter(ctx, 1), func(o LetterOutcome) { called <- o }) + c.Enqueue(newTestLetter(ctx, 2), func(o LetterOutcome) { called <- o }) + + // should not send any letters + Never(t, func() bool { return sendCount.Load() > 0 }, + "should not have sent letters") + + // close the courier + c.Close() + + // expect each letter to return cancelled + Eventually(t, func() bool { + select { + default: + return false + case o := <-called: + outcomes = append(outcomes, o) + return len(outcomes) == 2 + } + }, "should have returned 2 outcomes") + + if len(outcomes) >= 2 { + assert.Equal(t, OutcomeCancelled, outcomes[0].Kind) + assert.Equal(t, OutcomeCancelled, outcomes[1].Kind) + } } diff --git a/post.go b/post.go index 464031c..d9c18d8 100644 --- a/post.go +++ b/post.go @@ -3,6 +3,7 @@ package prism import ( "container/list" "context" + "fmt" "git.wisehodl.dev/jay/go-mana-component" "log/slog" "sync" @@ -88,6 +89,7 @@ type Courier struct { ctx context.Context cancel context.CancelFunc + mu sync.Mutex wg sync.WaitGroup logger *slog.Logger } @@ -199,6 +201,7 @@ func (c *Courier) HandleDisconnect() { } func (c *Courier) Close() { + c.command(&cmdCloseCourier{}) c.cancel() c.wg.Wait() } @@ -208,6 +211,7 @@ func (c *Courier) Close() { func (c *Courier) command(cmd courierCommand) { select { case <-c.ctx.Done(): + fmt.Println("here") case c.cmd <- cmd: } } @@ -364,3 +368,18 @@ func (cmd cmdHandleSendResult) apply(c *Courier) { c.doneOnce(cmd.traveller) } } + +type cmdCloseCourier struct{} + +func (cmd cmdCloseCourier) apply(c *Courier) { + // cancel remaining letters + for { + t, ok := c.pop() + if !ok { + break + } + t.letter.cancel() + t.setMissedAt(time.Now()) + c.doneOnce(t) + } +}