diff --git a/request.go b/request.go index f5a02d1..6a82415 100644 --- a/request.go +++ b/request.go @@ -84,7 +84,8 @@ const ( termSendFailed terminateReason = iota termCloseSent termReceivedClosed - termExternal + termDone + termCancelled ) // ---------------------------------------------------------------------------- @@ -178,9 +179,33 @@ func (m *RequestManager) Query( return } +func (m *RequestManager) Cancel(id string) error { + m.mu.Lock() + defer m.mu.Unlock() + + req, ok := m.reqs[id] + if !ok { + return fmt.Errorf("Cancel: unknown id %q", id) + } + + if sess, ok := m.sessions[id]; ok { + sess.Close() + } + + req.once.Do(func() { + close(req.buffer) + close(req.closed) + }) + delete(m.reqs, id) + + return nil +} + func (m *RequestManager) Close() { m.cancel() m.wg.Wait() + // call session.Close for each open session + // manually deregister and close each registered request. } func (m *RequestManager) spawnSession(req *request) { @@ -356,10 +381,11 @@ func (s *session) run() { return } case <-s.done: - s.terminate(termExternal) + s.terminate(termDone) return case <-s.ctx.Done(): - s.terminate(termExternal) + s.send(envelope.EncloseClose(s.id)) + s.terminate(termCancelled) return case <-s.eose: if s.closeOnEOSE { diff --git a/request_test.go b/request_test.go index 0db27fe..d39ee65 100644 --- a/request_test.go +++ b/request_test.go @@ -142,11 +142,11 @@ func TestRequestManager_Session(t *testing.T) { Eventually(t, func() bool { select { case r := <-h.terminatedWith: - return r == termExternal + return r == termDone default: return false } - }, "expected termExternal after done closed") + }, "expected termDone after done closed") }) t.Run("terminates on context cancel", func(t *testing.T) { @@ -170,11 +170,11 @@ func TestRequestManager_Session(t *testing.T) { Eventually(t, func() bool { select { case r := <-h.terminatedWith: - return r == termExternal + return r == termCancelled default: return false } - }, "expected termExternal after context cancel") + }, "expected termCancelled after context cancel") }) t.Run("terminates on closed signal", func(t *testing.T) { @@ -408,12 +408,84 @@ func TestRequestManager_Stream(t *testing.T) { func TestRequestManager_Cancel(t *testing.T) { t.Run("sends close, terminates session, deregisters", func(t *testing.T) { - // connect, call Stream, hold the id - // call Cancel(id) - // assert mock send was called with a CLOSE envelope for the id - // assert the session is removed from sessions - // assert the registration is removed from reqs - // assert the caller's events channel eventually closes + p, envoy := newMockEnvoy(t) + p.connect() + Eventually(t, envoy.IsConnected, "envoy should be connected") + + m := NewRequestManager(envoy) + filters := [][]byte{[]byte(`{}`)} + id, events, _ := m.Stream(filters) + + // drain the REQ send + Eventually(t, func() bool { + select { + case <-p.sent: + return true + default: + return false + } + }, "expected REQ send") + + err := m.Cancel(id) + assert.NoError(t, err) + + var got []byte + Eventually(t, func() bool { + select { + case got = <-p.sent: + return true + default: + return false + } + }, "expected CLOSE send") + assert.Equal(t, []byte(envelope.EncloseClose(id)), got) + + Eventually(t, func() bool { + m.mu.RLock() + _, ok := m.sessions[id] + m.mu.RUnlock() + return !ok + }, "session should be removed") + + m.mu.RLock() + _, ok := m.reqs[id] + m.mu.RUnlock() + assert.False(t, ok, "registration should be removed from reqs") + + Eventually(t, func() bool { + select { + case _, ok := <-events: + return !ok + default: + return false + } + }, "events channel should close after cancel") + }) + + t.Run("deregisters when no session is active", func(t *testing.T) { + _, envoy := newMockEnvoy(t) + // do not connect — no session will be spawned + + m := NewRequestManager(envoy) + filters := [][]byte{[]byte(`{}`)} + id, events, _ := m.Stream(filters) + + err := m.Cancel(id) + assert.NoError(t, err) + + m.mu.RLock() + _, ok := m.reqs[id] + m.mu.RUnlock() + assert.False(t, ok, "registration should be removed from reqs") + + Eventually(t, func() bool { + select { + case _, ok := <-events: + return !ok + default: + return false + } + }, "events channel should close after cancel") }) t.Run("returns error for unknown id", func(t *testing.T) { @@ -532,10 +604,13 @@ func TestRequestManager_Close(t *testing.T) { // assert all sessions are terminated (sessions map empty) }) - t.Run("does not deregister requests on close", func(t *testing.T) { + t.Run("deregisters all requests on close", func(t *testing.T) { // connect, open two streams // call manager.Close() - // assert registrations remain in reqs - // termExternal does not deregister; that is the caller's domain via Cancel + // -- calls session.Close for each registration + // -- manually cleans up the rest + // all sessions are stopped + // all request registrations are removed + // all registration channels close }) }