diff --git a/envelope/find.go b/envelope/find.go index 2333383..99bd183 100644 --- a/envelope/find.go +++ b/envelope/find.go @@ -33,7 +33,7 @@ func ParseElement(element json.RawMessage, value interface{}, position string) e return nil } -// FindEvent extracts an event from an EVENT envelope with no subscription ID. +// FindEvent extracts an event from an EVENT envelope // Expected Format: ["EVENT", event] func FindEvent(env Envelope) ([]byte, error) { var arr []json.RawMessage @@ -57,6 +57,35 @@ func FindEvent(env Envelope) ([]byte, error) { return arr[1], nil } +// FindEventWithReq extracts an event from an EVENT envelope with a subscription ID. +// Expected Format: ["EVENT", "SUBID", event] +func FindEventWithReq(env Envelope) (string, []byte, error) { + var arr []json.RawMessage + if err := json.Unmarshal(env, &arr); err != nil { + return "", nil, fmt.Errorf("%w: %v", errors.InvalidJSON, err) + } + + if err := CheckArrayLength(arr, 3); err != nil { + return "", nil, err + } + + var label string + if err := ParseElement(arr[0], &label, "envelope label"); err != nil { + return "", nil, err + } + + if err := CheckLabel(label, "EVENT"); err != nil { + return "", nil, err + } + + var req string + if err := ParseElement(arr[1], &req, "request id"); err != nil { + return "", nil, err + } + + return req, arr[2], nil +} + // FindSubscriptionEvent extracts an event and subscription ID from an EVENT envelope. // Expected Format: ["EVENT", subID, event] func FindSubscriptionEvent(env Envelope) (subID string, event []byte, err error) { diff --git a/envelope/find_test.go b/envelope/find_test.go index 2160bd3..1c47095 100644 --- a/envelope/find_test.go +++ b/envelope/find_test.go @@ -65,6 +65,68 @@ func TestFindEvent(t *testing.T) { } } +func TestFindEventWithReq(t *testing.T) { + cases := []struct { + name string + env Envelope + wantReq string + wantEvent []byte + wantErr error + wantErrText string + }{ + { + name: "valid event", + env: []byte(`["EVENT","SUBID",{"id":"abc123","kind":1}]`), + wantReq: "SUBID", + wantEvent: []byte(`{"id":"abc123","kind":1}`), + }, + { + name: "wrong label", + env: []byte(`["REQ","SUBID",{"id":"abc123","kind":1}]`), + wantErr: errors.WrongEnvelopeLabel, + wantErrText: "expected EVENT, got REQ", + }, + { + name: "invalid json", + env: []byte(`invalid`), + wantErr: errors.InvalidJSON, + }, + { + name: "missing elements", + env: []byte(`["EVENT","SUBID"]`), + wantErr: errors.InvalidEnvelope, + wantErrText: "expected 3 elements, got 2", + }, + { + name: "extraneous elements", + env: []byte(`["EVENT","SUBID",{"id":"abc123"},"extra"]`), + wantReq: "SUBID", + wantEvent: []byte(`{"id":"abc123"}`), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + gotReq, gotEvent, err := FindEventWithReq(tc.env) + + if tc.wantErr != nil || tc.wantErrText != "" { + if tc.wantErr != nil { + assert.ErrorIs(t, err, tc.wantErr) + } + + if tc.wantErrText != "" { + assert.ErrorContains(t, err, tc.wantErrText) + } + return + } + + assert.NoError(t, err) + assert.Equal(t, tc.wantReq, gotReq) + assert.Equal(t, tc.wantEvent, gotEvent) + }) + } +} + func TestFindSubscriptionEvent(t *testing.T) { cases := []struct { name string