feat: add WithID/WithLabel options and monotonic counter for subscription IDs

This commit is contained in:
Jay
2026-05-17 22:37:57 -04:00
parent 715dfa17b0
commit 5bbfd9523b
2 changed files with 183 additions and 54 deletions
+69 -31
View File
@@ -2,13 +2,12 @@ package prism
import ( import (
"context" "context"
"crypto/rand"
"encoding/base32"
"fmt" "fmt"
"git.wisehodl.dev/jay/go-mana-component" "git.wisehodl.dev/jay/go-mana-component"
"git.wisehodl.dev/jay/go-roots-ws" "git.wisehodl.dev/jay/go-roots-ws"
"log/slog" "log/slog"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@@ -29,7 +28,8 @@ type ReqClosed struct {
} }
type RequestManager struct { type RequestManager struct {
reqs map[string]*request reqs map[string]*request
counter atomic.Uint64
envoy *Envoy envoy *Envoy
events <-chan OutboundPoolEvent events <-chan OutboundPoolEvent
@@ -43,6 +43,29 @@ type RequestManager struct {
logger *slog.Logger logger *slog.Logger
} }
// ----------------------------------------------------------------------------
// Options
// ----------------------------------------------------------------------------
type RequestOption func(*requestOptions)
type requestOptions struct {
id string
label string
}
// WithID sets an explicit subscription ID. Returns an error from Stream or
// Query if the ID is already in use.
func WithID(id string) RequestOption {
return func(o *requestOptions) { o.id = id }
}
// WithLabel sets the prefix for the generated subscription ID. The default
// prefix is "req". The counter is shared across all labels.
func WithLabel(label string) RequestOption {
return func(o *requestOptions) { o.label = label }
}
type request struct { type request struct {
id string id string
filters [][]byte filters [][]byte
@@ -58,21 +81,6 @@ type request struct {
closedOnce sync.Once closedOnce sync.Once
} }
// ----------------------------------------------------------------------------
// Helpers
// ----------------------------------------------------------------------------
var encoder = base32.StdEncoding.WithPadding(base32.NoPadding)
func generateID() string {
b := make([]byte, 5)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("generateID: %v", err))
}
return encoder.EncodeToString(b)
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Request Manager // Request Manager
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@@ -107,16 +115,32 @@ func NewRequestManager(e *Envoy) *RequestManager {
func (m *RequestManager) Stream( func (m *RequestManager) Stream(
filters [][]byte, filters [][]byte,
) (string, <-chan ReqEvent, <-chan ReqClosed) { opts ...RequestOption,
id, events, closed := m.newStream(filters, false) ) (string, <-chan ReqEvent, <-chan ReqClosed, error) {
return id, events, closed id, events, closed, err := m.newStream(filters, false, opts...)
return id, events, closed, err
} }
func (m *RequestManager) newStream( func (m *RequestManager) newStream(
filters [][]byte, filters [][]byte,
isQuery bool, isQuery bool,
) (string, <-chan ReqEvent, <-chan ReqClosed) { opts ...RequestOption,
id := generateID() ) (string, <-chan ReqEvent, <-chan ReqClosed, error) {
var o requestOptions
for _, opt := range opts {
opt(&o)
}
var id string
if o.id != "" {
id = o.id
} else {
label := o.label
if label == "" {
label = "req"
}
id = fmt.Sprintf("%s:%d", label, m.counter.Add(1))
}
buffer := make(chan ReqEvent, 64) buffer := make(chan ReqEvent, 64)
closed := make(chan ReqClosed, 1) closed := make(chan ReqClosed, 1)
@@ -145,24 +169,38 @@ func (m *RequestManager) newStream(
} }
m.mu.Lock() m.mu.Lock()
if _, exists := m.reqs[id]; exists {
m.mu.Unlock()
close(buffer)
if !isQuery {
// drain bufferedPipe goroutine
for range events {
}
}
return "", nil, nil, fmt.Errorf("Stream: id %q already in use", id)
}
m.reqs[id] = req m.reqs[id] = req
if m.envoy.IsConnected() { if m.envoy.IsConnected() {
m.activate(req) m.activate(req)
} }
m.mu.Unlock() m.mu.Unlock()
return id, events, closed return id, events, closed, nil
} }
func (m *RequestManager) Query( func (m *RequestManager) Query(
filters [][]byte, filters [][]byte,
timeout time.Duration, timeout time.Duration,
) ([]ReqEvent, *ReqClosed) { opts ...RequestOption,
) ([]ReqEvent, *ReqClosed, error) {
if !m.envoy.IsConnected() { if !m.envoy.IsConnected() {
return nil, nil return nil, nil, nil
} }
id, eventsCh, closedCh := m.newStream(filters, true) id, eventsCh, closedCh, err := m.newStream(filters, true, opts...)
if err != nil {
return nil, nil, err
}
ctx, cancel := context.WithTimeout(m.ctx, timeout) ctx, cancel := context.WithTimeout(m.ctx, timeout)
defer cancel() defer cancel()
@@ -172,17 +210,17 @@ func (m *RequestManager) Query(
select { select {
case ev, ok := <-eventsCh: case ev, ok := <-eventsCh:
if !ok { if !ok {
return result, nil return result, nil, nil
} }
result = append(result, ev) result = append(result, ev)
case cl, ok := <-closedCh: case cl, ok := <-closedCh:
if !ok { if !ok {
return result, nil return result, nil, nil
} }
return result, &cl return result, &cl, nil
case <-ctx.Done(): case <-ctx.Done():
m.Cancel(id) m.Cancel(id)
return result, nil return result, nil, nil
} }
} }
} }
+114 -23
View File
@@ -7,6 +7,86 @@ import (
"time" "time"
) )
func TestRequestManager_Options(t *testing.T) {
t.Run("default id uses req label and monotonic counter", func(t *testing.T) {
_, envoy := newMockEnvoy(t)
m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)}
idA, _, _, err := m.Stream(filters)
assert.NoError(t, err)
idB, _, _, err := m.Stream(filters)
assert.NoError(t, err)
assert.Equal(t, "req:1", idA)
assert.Equal(t, "req:2", idB)
})
t.Run("WithLabel sets prefix", func(t *testing.T) {
_, envoy := newMockEnvoy(t)
m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)}
idA, _, _, err := m.Stream(filters, WithLabel("feed"))
assert.NoError(t, err)
idB, _, _, err := m.Stream(filters, WithLabel("profile"))
assert.NoError(t, err)
assert.Equal(t, "feed:1", idA)
assert.Equal(t, "profile:2", idB)
})
t.Run("WithID uses caller id", func(t *testing.T) {
_, envoy := newMockEnvoy(t)
m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)}
id, _, _, err := m.Stream(filters, WithID("my-custom-id"))
assert.NoError(t, err)
assert.Equal(t, "my-custom-id", id)
})
t.Run("WithID wins over WithLabel", func(t *testing.T) {
_, envoy := newMockEnvoy(t)
m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)}
id, _, _, err := m.Stream(filters, WithLabel("feed"), WithID("explicit"))
assert.NoError(t, err)
assert.Equal(t, "explicit", id)
})
t.Run("WithID returns error on collision", func(t *testing.T) {
_, envoy := newMockEnvoy(t)
m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)}
_, _, _, err := m.Stream(filters, WithID("dup"))
assert.NoError(t, err)
_, _, _, err = m.Stream(filters, WithID("dup"))
assert.Error(t, err)
})
t.Run("WithID does not advance counter", func(t *testing.T) {
_, envoy := newMockEnvoy(t)
m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)}
_, _, _, err := m.Stream(filters, WithID("explicit"))
assert.NoError(t, err)
id, _, _, err := m.Stream(filters)
assert.NoError(t, err)
assert.Equal(t, "req:1", id)
})
}
func TestRequestManager_Stream(t *testing.T) { func TestRequestManager_Stream(t *testing.T) {
t.Run("sends req when connected", func(t *testing.T) { t.Run("sends req when connected", func(t *testing.T) {
p, envoy := newMockEnvoy(t) p, envoy := newMockEnvoy(t)
@@ -17,7 +97,8 @@ func TestRequestManager_Stream(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
id, events, closed := m.Stream(filters) id, events, closed, err := m.Stream(filters)
assert.NoError(t, err)
assert.NotEmpty(t, id) assert.NotEmpty(t, id)
assert.NotNil(t, events) assert.NotNil(t, events)
@@ -42,7 +123,8 @@ func TestRequestManager_Stream(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
id, events, closed := m.Stream(filters) id, events, closed, err := m.Stream(filters)
assert.NoError(t, err)
assert.NotEmpty(t, id) assert.NotEmpty(t, id)
assert.NotNil(t, events) assert.NotNil(t, events)
@@ -66,7 +148,8 @@ func TestRequestManager_Stream(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
id, events, _ := m.Stream(filters) id, events, _, err := m.Stream(filters)
assert.NoError(t, err)
// drain the REQ send // drain the REQ send
Eventually(t, func() bool { Eventually(t, func() bool {
@@ -110,7 +193,8 @@ func TestRequestManager_Stream(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
id, events, closed := m.Stream(filters) id, events, closed, err := m.Stream(filters)
assert.NoError(t, err)
// drain the REQ send // drain the REQ send
Eventually(t, func() bool { Eventually(t, func() bool {
@@ -162,7 +246,8 @@ func TestRequestManager_Stream(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
id, events, closed := m.Stream(filters) id, events, closed, err := m.Stream(filters)
assert.NoError(t, err)
// drain the REQ send // drain the REQ send
Eventually(t, func() bool { Eventually(t, func() bool {
@@ -221,7 +306,8 @@ func TestRequestManager_Cancel(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
id, events, _ := m.Stream(filters) id, events, _, streamErr := m.Stream(filters)
assert.NoError(t, streamErr)
// drain the REQ send // drain the REQ send
Eventually(t, func() bool { Eventually(t, func() bool {
@@ -269,7 +355,8 @@ func TestRequestManager_Cancel(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
id, events, _ := m.Stream(filters) id, events, _, streamErr := m.Stream(filters)
assert.NoError(t, streamErr)
err := m.Cancel(id) err := m.Cancel(id)
assert.NoError(t, err) assert.NoError(t, err)
@@ -327,7 +414,8 @@ func TestRequestManager_Query(t *testing.T) {
p.receive(envelope.EncloseEOSE(subID)) p.receive(envelope.EncloseEOSE(subID))
}() }()
events, closed := m.Query(filters, TestTimeout) events, closed, err := m.Query(filters, TestTimeout)
assert.NoError(t, err)
assert.Len(t, events, 3) assert.Len(t, events, 3)
assert.Nil(t, closed) assert.Nil(t, closed)
@@ -364,7 +452,8 @@ func TestRequestManager_Query(t *testing.T) {
p.receive(envelope.EncloseClosed(subID, reason)) p.receive(envelope.EncloseClosed(subID, reason))
}() }()
events, closed := m.Query(filters, TestTimeout) events, closed, err := m.Query(filters, TestTimeout)
assert.NoError(t, err)
assert.Empty(t, events) assert.Empty(t, events)
if assert.NotNil(t, closed) { if assert.NotNil(t, closed) {
@@ -398,8 +487,9 @@ func TestRequestManager_Query(t *testing.T) {
}() }()
start := time.Now() start := time.Now()
events, closed := m.Query(filters, queryTimeout) events, closed, err := m.Query(filters, queryTimeout)
elapsed := time.Since(start) elapsed := time.Since(start)
assert.NoError(t, err)
assert.GreaterOrEqual(t, elapsed, queryTimeout) assert.GreaterOrEqual(t, elapsed, queryTimeout)
assert.Len(t, events, 2) assert.Len(t, events, 2)
@@ -413,7 +503,8 @@ func TestRequestManager_Query(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
events, closed := m.Query([][]byte{[]byte(`{}`)}, TestTimeout) events, closed, err := m.Query([][]byte{[]byte(`{}`)}, TestTimeout)
assert.NoError(t, err)
assert.Nil(t, events) assert.Nil(t, events)
assert.Nil(t, closed) assert.Nil(t, closed)
@@ -429,8 +520,8 @@ func TestRequestManager_Reconnect(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
idA, _, _ := m.Stream(filters) idA, _, _, _ := m.Stream(filters)
idB, _, _ := m.Stream(filters) idB, _, _, _ := m.Stream(filters)
// drain both REQ sends // drain both REQ sends
for range 2 { for range 2 {
@@ -463,8 +554,8 @@ func TestRequestManager_Reconnect(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
idA, eventsA, closedA := m.Stream(filters) idA, eventsA, closedA, _ := m.Stream(filters)
idB, eventsB, closedB := m.Stream(filters) idB, eventsB, closedB, _ := m.Stream(filters)
for range 2 { for range 2 {
Eventually(t, func() bool { Eventually(t, func() bool {
@@ -539,8 +630,8 @@ func TestRequestManager_Reconnect(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
idA, _, _ := m.Stream(filters) idA, _, _, _ := m.Stream(filters)
idB, _, _ := m.Stream(filters) idB, _, _, _ := m.Stream(filters)
for range 2 { for range 2 {
Eventually(t, func() bool { Eventually(t, func() bool {
@@ -599,7 +690,7 @@ func TestRequestManager_Reconnect(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
t.Cleanup(func() { m.Close() }) t.Cleanup(func() { m.Close() })
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
id, events, _ := m.Stream(filters) id, events, _, _ := m.Stream(filters)
Eventually(t, func() bool { Eventually(t, func() bool {
select { select {
@@ -650,9 +741,9 @@ func TestRequestManager_Close(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
m.Stream(filters) _, _, _, _ = m.Stream(filters)
m.Stream(filters) _, _, _, _ = m.Stream(filters)
m.Stream(filters) _, _, _, _ = m.Stream(filters)
// drain all three REQ sends // drain all three REQ sends
for range 3 { for range 3 {
@@ -699,8 +790,8 @@ func TestRequestManager_Close(t *testing.T) {
m := NewRequestManager(envoy) m := NewRequestManager(envoy)
filters := [][]byte{[]byte(`{}`)} filters := [][]byte{[]byte(`{}`)}
_, eventsA, _ := m.Stream(filters) _, eventsA, _, _ := m.Stream(filters)
_, eventsB, _ := m.Stream(filters) _, eventsB, _, _ := m.Stream(filters)
for range 2 { for range 2 {
Eventually(t, func() bool { Eventually(t, func() bool {