diff --git a/request.go b/request.go index 90310a7..3f9171c 100644 --- a/request.go +++ b/request.go @@ -29,8 +29,7 @@ type ReqClosed struct { } type RequestManager struct { - reqs map[string]*request - sessions map[string]*session + reqs map[string]*request envoy *Envoy events <-chan OutboundPoolEvent @@ -45,19 +44,17 @@ type RequestManager struct { } type request struct { - id string - filters [][]byte - buffer chan ReqEvent - events chan ReqEvent - closed chan ReqClosed - deregisterOnce sync.Once -} - -type session struct { id string - req []byte + filters [][]byte + isQuery bool - request *request + active bool + + buffer chan ReqEvent + events chan ReqEvent + closed chan ReqClosed + + once sync.Once } // ---------------------------------------------------------------------------- @@ -84,8 +81,7 @@ func NewRequestManager(e *Envoy) *RequestManager { component.MustExtend(e.Context(), "request_manager")) m := &RequestManager{ - reqs: make(map[string]*request), - sessions: make(map[string]*session), + reqs: make(map[string]*request), envoy: e, events: e.SubscribeEvents(), @@ -122,6 +118,7 @@ func (m *RequestManager) Stream( buffer: buffer, events: events, closed: closed, + isQuery: false, } m.mu.Lock() @@ -131,7 +128,7 @@ func (m *RequestManager) Stream( close(events) }() if m.envoy.IsConnected() { - m.spawnSession(req, false) + m.activateLock(req) } m.mu.Unlock() @@ -155,11 +152,12 @@ func (m *RequestManager) Query( filters: filters, buffer: eventsCh, closed: closedCh, + isQuery: true, } m.mu.Lock() m.reqs[id] = req - m.spawnSession(req, true) + m.activateLock(req) m.mu.Unlock() ctx, cancel := context.WithTimeout(m.ctx, timeout) @@ -193,12 +191,12 @@ func (m *RequestManager) Cancel(id string) error { return fmt.Errorf("Cancel: unknown id %q", id) } - if _, ok := m.sessions[id]; ok { + if req.active { go m.envoy.Send(envelope.EncloseClose(id)) - delete(m.sessions, id) + req.active = false } - req.deregisterOnce.Do(func() { + req.once.Do(func() { close(req.buffer) close(req.closed) }) @@ -215,54 +213,29 @@ func (m *RequestManager) Close() { defer m.mu.Unlock() for id, req := range m.reqs { - if _, ok := m.sessions[id]; ok { + if req.active { go m.envoy.Send(envelope.EncloseClose(id)) } - req.deregisterOnce.Do(func() { + req.once.Do(func() { close(req.buffer) close(req.closed) }) delete(m.reqs, id) } - for id := range m.sessions { - delete(m.sessions, id) - } } -func (m *RequestManager) spawnSession(req *request, query bool) { - sess := &session{ - id: req.id, - req: envelope.EncloseReq(req.id, req.filters), - isQuery: query, - request: req, - } - m.sessions[req.id] = sess - go m.envoy.Send(sess.req) +func (m *RequestManager) activateLock(req *request) { + req.active = true + go m.envoy.Send(envelope.EncloseReq(req.id, req.filters)) } -func (m *RequestManager) deregister(req *request) { - req.deregisterOnce.Do(func() { +func (m *RequestManager) removeLock(req *request) { + req.active = false + req.once.Do(func() { close(req.buffer) close(req.closed) }) delete(m.reqs, req.id) - delete(m.sessions, req.id) -} - -func (m *RequestManager) start() { - m.mu.Lock() - defer m.mu.Unlock() - for _, req := range m.reqs { - m.spawnSession(req, false) - } -} - -func (m *RequestManager) stop() { - m.mu.Lock() - defer m.mu.Unlock() - for id := range m.sessions { - delete(m.sessions, id) - } } func (m *RequestManager) handleEvents() { @@ -276,12 +249,20 @@ func (m *RequestManager) handleEvents() { if !ok { return } + m.mu.Lock() switch ev.Kind { case EventConnected: - m.start() + for _, req := range m.reqs { + if !req.isQuery { + m.activateLock(req) + } + } case EventDisconnected: - m.stop() + for _, req := range m.reqs { + req.active = false + } } + m.mu.Unlock() } } } @@ -332,13 +313,13 @@ func (m *RequestManager) dispatchInbox(msg InboxMessage) { return } m.mu.Lock() - sess, ok := m.sessions[subID] + req, ok := m.reqs[subID] if !ok { m.mu.Unlock() return } - if sess.isQuery { - m.deregister(sess.request) + if req.active && req.isQuery { + m.removeLock(req) go m.envoy.Send(envelope.EncloseClose(subID)) } m.mu.Unlock() @@ -359,7 +340,7 @@ func (m *RequestManager) dispatchInbox(msg InboxMessage) { ReceivedAt: msg.ReceivedAt, Data: message, } - m.deregister(req) + m.removeLock(req) m.mu.Unlock() } } diff --git a/request_test.go b/request_test.go index 79aa58c..d07a789 100644 --- a/request_test.go +++ b/request_test.go @@ -8,7 +8,7 @@ import ( ) func TestRequestManager_Stream(t *testing.T) { - t.Run("spawns session and sends req when connected", func(t *testing.T) { + t.Run("sends req when connected", func(t *testing.T) { p, envoy := newMockEnvoy(t) p.connect() @@ -36,7 +36,7 @@ func TestRequestManager_Stream(t *testing.T) { assert.Equal(t, []byte(envelope.EncloseReq(id, filters)), got) }) - t.Run("registers but does not spawn session when disconnected", func(t *testing.T) { + t.Run("does not send req when disconnected", func(t *testing.T) { p, envoy := newMockEnvoy(t) m := NewRequestManager(envoy) @@ -126,10 +126,10 @@ func TestRequestManager_Stream(t *testing.T) { Never(t, func() bool { m.mu.RLock() - _, ok := m.sessions[id] + req, ok := m.reqs[id] m.mu.RUnlock() - return !ok - }, "session should not be removed after eose") + return !ok || !req.active + }, "request should remain registered and active after eose for stream") Never(t, func() bool { select { @@ -213,7 +213,7 @@ func TestRequestManager_Stream(t *testing.T) { } func TestRequestManager_Cancel(t *testing.T) { - t.Run("sends close, terminates session, deregisters", func(t *testing.T) { + t.Run("sends close and deregisters", func(t *testing.T) { p, envoy := newMockEnvoy(t) p.connect() Eventually(t, envoy.IsConnected, "envoy should be connected") @@ -248,10 +248,8 @@ func TestRequestManager_Cancel(t *testing.T) { assert.Equal(t, []byte(envelope.EncloseClose(id)), got) m.mu.RLock() - _, sessOk := m.sessions[id] _, reqOk := m.reqs[id] m.mu.RUnlock() - assert.False(t, sessOk, "session should be removed") assert.False(t, reqOk, "registration should be removed from reqs") Eventually(t, func() bool { @@ -264,9 +262,9 @@ func TestRequestManager_Cancel(t *testing.T) { }, "events channel should close after cancel") }) - t.Run("deregisters when no session is active", func(t *testing.T) { + t.Run("deregisters when inactive", func(t *testing.T) { _, envoy := newMockEnvoy(t) - // do not connect — no session will be spawned + // do not connect — request will not be active m := NewRequestManager(envoy) t.Cleanup(func() { m.Close() }) @@ -422,11 +420,39 @@ func TestRequestManager_Query(t *testing.T) { }) } -func _TestRequestManager_Reconnect(t *testing.T) { - t.Run("sessions terminate on disconnect", func(t *testing.T) { - // connect, open two streams - // send a disconnect event into the mock events channel - // assert both sessions are removed from sessions map +func TestRequestManager_Reconnect(t *testing.T) { + t.Run("requests deactivate on disconnect", func(t *testing.T) { + p, envoy := newMockEnvoy(t) + p.connect() + Eventually(t, envoy.IsConnected, "envoy should be connected") + + m := NewRequestManager(envoy) + t.Cleanup(func() { m.Close() }) + filters := [][]byte{[]byte(`{}`)} + idA, _, _ := m.Stream(filters) + idB, _, _ := m.Stream(filters) + + // drain both REQ sends + for range 2 { + Eventually(t, func() bool { + select { + case <-p.sent: + return true + default: + return false + } + }, "expected REQ send") + } + + p.disconnect() + + Eventually(t, func() bool { + m.mu.RLock() + defer m.mu.RUnlock() + reqA, okA := m.reqs[idA] + reqB, okB := m.reqs[idB] + return okA && okB && !reqA.active && !reqB.active + }, "both requests should be inactive after disconnect") }) t.Run("registrations survive disconnect", func(t *testing.T) { @@ -454,7 +480,7 @@ func _TestRequestManager_Reconnect(t *testing.T) { } func TestRequestManager_Close(t *testing.T) { - t.Run("terminates all sessions without deadlock", func(t *testing.T) { + t.Run("deactivates all requests without deadlock", func(t *testing.T) { p, envoy := newMockEnvoy(t) p.connect() Eventually(t, envoy.IsConnected, "envoy should be connected") @@ -493,9 +519,14 @@ func TestRequestManager_Close(t *testing.T) { }, "Close should return without deadlock") m.mu.RLock() - count := len(m.sessions) + activeCount := 0 + for _, req := range m.reqs { + if req.active { + activeCount++ + } + } m.mu.RUnlock() - assert.Equal(t, 0, count, "all sessions should be terminated") + assert.Equal(t, 0, activeCount, "all requests should be inactive after close") }) t.Run("deregisters all requests on close", func(t *testing.T) {