diff --git a/request.go b/request.go index 6fab8c9..824d582 100644 --- a/request.go +++ b/request.go @@ -2,13 +2,12 @@ package prism import ( "context" - "crypto/rand" - "encoding/base32" "fmt" "git.wisehodl.dev/jay/go-mana-component" "git.wisehodl.dev/jay/go-roots-ws" "log/slog" "sync" + "sync/atomic" "time" ) @@ -29,7 +28,8 @@ type ReqClosed struct { } type RequestManager struct { - reqs map[string]*request + reqs map[string]*request + counter atomic.Uint64 envoy *Envoy events <-chan OutboundPoolEvent @@ -43,6 +43,29 @@ type RequestManager struct { logger *slog.Logger } +// ---------------------------------------------------------------------------- +// Options +// ---------------------------------------------------------------------------- + +type RequestOption func(*requestOptions) + +type requestOptions struct { + id string + label string +} + +// WithID sets an explicit subscription ID. Returns an error from Stream or +// Query if the ID is already in use. +func WithID(id string) RequestOption { + return func(o *requestOptions) { o.id = id } +} + +// WithLabel sets the prefix for the generated subscription ID. The default +// prefix is "req". The counter is shared across all labels. +func WithLabel(label string) RequestOption { + return func(o *requestOptions) { o.label = label } +} + type request struct { id string filters [][]byte @@ -58,21 +81,6 @@ type request struct { closedOnce sync.Once } -// ---------------------------------------------------------------------------- -// Helpers -// ---------------------------------------------------------------------------- - -var encoder = base32.StdEncoding.WithPadding(base32.NoPadding) - -func generateID() string { - b := make([]byte, 5) - if _, err := rand.Read(b); err != nil { - panic(fmt.Sprintf("generateID: %v", err)) - } - - return encoder.EncodeToString(b) -} - // ---------------------------------------------------------------------------- // Request Manager // ---------------------------------------------------------------------------- @@ -107,16 +115,32 @@ func NewRequestManager(e *Envoy) *RequestManager { func (m *RequestManager) Stream( filters [][]byte, -) (string, <-chan ReqEvent, <-chan ReqClosed) { - id, events, closed := m.newStream(filters, false) - return id, events, closed + opts ...RequestOption, +) (string, <-chan ReqEvent, <-chan ReqClosed, error) { + id, events, closed, err := m.newStream(filters, false, opts...) + return id, events, closed, err } func (m *RequestManager) newStream( filters [][]byte, isQuery bool, -) (string, <-chan ReqEvent, <-chan ReqClosed) { - id := generateID() + opts ...RequestOption, +) (string, <-chan ReqEvent, <-chan ReqClosed, error) { + var o requestOptions + for _, opt := range opts { + opt(&o) + } + + var id string + if o.id != "" { + id = o.id + } else { + label := o.label + if label == "" { + label = "req" + } + id = fmt.Sprintf("%s:%d", label, m.counter.Add(1)) + } buffer := make(chan ReqEvent, 64) closed := make(chan ReqClosed, 1) @@ -145,24 +169,38 @@ func (m *RequestManager) newStream( } m.mu.Lock() + if _, exists := m.reqs[id]; exists { + m.mu.Unlock() + close(buffer) + if !isQuery { + // drain bufferedPipe goroutine + for range events { + } + } + return "", nil, nil, fmt.Errorf("Stream: id %q already in use", id) + } m.reqs[id] = req if m.envoy.IsConnected() { m.activate(req) } m.mu.Unlock() - return id, events, closed + return id, events, closed, nil } func (m *RequestManager) Query( filters [][]byte, timeout time.Duration, -) ([]ReqEvent, *ReqClosed) { + opts ...RequestOption, +) ([]ReqEvent, *ReqClosed, error) { if !m.envoy.IsConnected() { - return nil, nil + return nil, nil, nil } - id, eventsCh, closedCh := m.newStream(filters, true) + id, eventsCh, closedCh, err := m.newStream(filters, true, opts...) + if err != nil { + return nil, nil, err + } ctx, cancel := context.WithTimeout(m.ctx, timeout) defer cancel() @@ -172,17 +210,17 @@ func (m *RequestManager) Query( select { case ev, ok := <-eventsCh: if !ok { - return result, nil + return result, nil, nil } result = append(result, ev) case cl, ok := <-closedCh: if !ok { - return result, nil + return result, nil, nil } - return result, &cl + return result, &cl, nil case <-ctx.Done(): m.Cancel(id) - return result, nil + return result, nil, nil } } } diff --git a/request_test.go b/request_test.go index 03f4320..11a4907 100644 --- a/request_test.go +++ b/request_test.go @@ -7,6 +7,86 @@ import ( "time" ) +func TestRequestManager_Options(t *testing.T) { + t.Run("default id uses req label and monotonic counter", func(t *testing.T) { + _, envoy := newMockEnvoy(t) + m := NewRequestManager(envoy) + t.Cleanup(func() { m.Close() }) + + filters := [][]byte{[]byte(`{}`)} + idA, _, _, err := m.Stream(filters) + assert.NoError(t, err) + idB, _, _, err := m.Stream(filters) + assert.NoError(t, err) + + assert.Equal(t, "req:1", idA) + assert.Equal(t, "req:2", idB) + }) + + t.Run("WithLabel sets prefix", func(t *testing.T) { + _, envoy := newMockEnvoy(t) + m := NewRequestManager(envoy) + t.Cleanup(func() { m.Close() }) + + filters := [][]byte{[]byte(`{}`)} + idA, _, _, err := m.Stream(filters, WithLabel("feed")) + assert.NoError(t, err) + idB, _, _, err := m.Stream(filters, WithLabel("profile")) + assert.NoError(t, err) + + assert.Equal(t, "feed:1", idA) + assert.Equal(t, "profile:2", idB) + }) + + t.Run("WithID uses caller id", func(t *testing.T) { + _, envoy := newMockEnvoy(t) + m := NewRequestManager(envoy) + t.Cleanup(func() { m.Close() }) + + filters := [][]byte{[]byte(`{}`)} + id, _, _, err := m.Stream(filters, WithID("my-custom-id")) + assert.NoError(t, err) + assert.Equal(t, "my-custom-id", id) + }) + + t.Run("WithID wins over WithLabel", func(t *testing.T) { + _, envoy := newMockEnvoy(t) + m := NewRequestManager(envoy) + t.Cleanup(func() { m.Close() }) + + filters := [][]byte{[]byte(`{}`)} + id, _, _, err := m.Stream(filters, WithLabel("feed"), WithID("explicit")) + assert.NoError(t, err) + assert.Equal(t, "explicit", id) + }) + + t.Run("WithID returns error on collision", func(t *testing.T) { + _, envoy := newMockEnvoy(t) + m := NewRequestManager(envoy) + t.Cleanup(func() { m.Close() }) + + filters := [][]byte{[]byte(`{}`)} + _, _, _, err := m.Stream(filters, WithID("dup")) + assert.NoError(t, err) + + _, _, _, err = m.Stream(filters, WithID("dup")) + assert.Error(t, err) + }) + + t.Run("WithID does not advance counter", func(t *testing.T) { + _, envoy := newMockEnvoy(t) + m := NewRequestManager(envoy) + t.Cleanup(func() { m.Close() }) + + filters := [][]byte{[]byte(`{}`)} + _, _, _, err := m.Stream(filters, WithID("explicit")) + assert.NoError(t, err) + id, _, _, err := m.Stream(filters) + assert.NoError(t, err) + assert.Equal(t, "req:1", id) + }) +} + func TestRequestManager_Stream(t *testing.T) { t.Run("sends req when connected", func(t *testing.T) { p, envoy := newMockEnvoy(t) @@ -17,7 +97,8 @@ func TestRequestManager_Stream(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - id, events, closed := m.Stream(filters) + id, events, closed, err := m.Stream(filters) + assert.NoError(t, err) assert.NotEmpty(t, id) assert.NotNil(t, events) @@ -42,7 +123,8 @@ func TestRequestManager_Stream(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - id, events, closed := m.Stream(filters) + id, events, closed, err := m.Stream(filters) + assert.NoError(t, err) assert.NotEmpty(t, id) assert.NotNil(t, events) @@ -66,7 +148,8 @@ func TestRequestManager_Stream(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - id, events, _ := m.Stream(filters) + id, events, _, err := m.Stream(filters) + assert.NoError(t, err) // drain the REQ send Eventually(t, func() bool { @@ -110,7 +193,8 @@ func TestRequestManager_Stream(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - id, events, closed := m.Stream(filters) + id, events, closed, err := m.Stream(filters) + assert.NoError(t, err) // drain the REQ send Eventually(t, func() bool { @@ -162,7 +246,8 @@ func TestRequestManager_Stream(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - id, events, closed := m.Stream(filters) + id, events, closed, err := m.Stream(filters) + assert.NoError(t, err) // drain the REQ send Eventually(t, func() bool { @@ -221,7 +306,8 @@ func TestRequestManager_Cancel(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - id, events, _ := m.Stream(filters) + id, events, _, streamErr := m.Stream(filters) + assert.NoError(t, streamErr) // drain the REQ send Eventually(t, func() bool { @@ -269,7 +355,8 @@ func TestRequestManager_Cancel(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - id, events, _ := m.Stream(filters) + id, events, _, streamErr := m.Stream(filters) + assert.NoError(t, streamErr) err := m.Cancel(id) assert.NoError(t, err) @@ -327,7 +414,8 @@ func TestRequestManager_Query(t *testing.T) { p.receive(envelope.EncloseEOSE(subID)) }() - events, closed := m.Query(filters, TestTimeout) + events, closed, err := m.Query(filters, TestTimeout) + assert.NoError(t, err) assert.Len(t, events, 3) assert.Nil(t, closed) @@ -364,7 +452,8 @@ func TestRequestManager_Query(t *testing.T) { p.receive(envelope.EncloseClosed(subID, reason)) }() - events, closed := m.Query(filters, TestTimeout) + events, closed, err := m.Query(filters, TestTimeout) + assert.NoError(t, err) assert.Empty(t, events) if assert.NotNil(t, closed) { @@ -398,8 +487,9 @@ func TestRequestManager_Query(t *testing.T) { }() start := time.Now() - events, closed := m.Query(filters, queryTimeout) + events, closed, err := m.Query(filters, queryTimeout) elapsed := time.Since(start) + assert.NoError(t, err) assert.GreaterOrEqual(t, elapsed, queryTimeout) assert.Len(t, events, 2) @@ -413,7 +503,8 @@ func TestRequestManager_Query(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) - events, closed := m.Query([][]byte{[]byte(`{}`)}, TestTimeout) + events, closed, err := m.Query([][]byte{[]byte(`{}`)}, TestTimeout) + assert.NoError(t, err) assert.Nil(t, events) assert.Nil(t, closed) @@ -429,8 +520,8 @@ func TestRequestManager_Reconnect(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - idA, _, _ := m.Stream(filters) - idB, _, _ := m.Stream(filters) + idA, _, _, _ := m.Stream(filters) + idB, _, _, _ := m.Stream(filters) // drain both REQ sends for range 2 { @@ -463,8 +554,8 @@ func TestRequestManager_Reconnect(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - idA, eventsA, closedA := m.Stream(filters) - idB, eventsB, closedB := m.Stream(filters) + idA, eventsA, closedA, _ := m.Stream(filters) + idB, eventsB, closedB, _ := m.Stream(filters) for range 2 { Eventually(t, func() bool { @@ -539,8 +630,8 @@ func TestRequestManager_Reconnect(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - idA, _, _ := m.Stream(filters) - idB, _, _ := m.Stream(filters) + idA, _, _, _ := m.Stream(filters) + idB, _, _, _ := m.Stream(filters) for range 2 { Eventually(t, func() bool { @@ -599,7 +690,7 @@ func TestRequestManager_Reconnect(t *testing.T) { m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) filters := [][]byte{[]byte(`{}`)} - id, events, _ := m.Stream(filters) + id, events, _, _ := m.Stream(filters) Eventually(t, func() bool { select { @@ -650,9 +741,9 @@ func TestRequestManager_Close(t *testing.T) { m := NewRequestManager(envoy) filters := [][]byte{[]byte(`{}`)} - m.Stream(filters) - m.Stream(filters) - m.Stream(filters) + _, _, _, _ = m.Stream(filters) + _, _, _, _ = m.Stream(filters) + _, _, _, _ = m.Stream(filters) // drain all three REQ sends for range 3 { @@ -699,8 +790,8 @@ func TestRequestManager_Close(t *testing.T) { m := NewRequestManager(envoy) filters := [][]byte{[]byte(`{}`)} - _, eventsA, _ := m.Stream(filters) - _, eventsB, _ := m.Stream(filters) + _, eventsA, _, _ := m.Stream(filters) + _, eventsB, _, _ := m.Stream(filters) for range 2 { Eventually(t, func() bool {