diff --git a/write.go b/write.go index fb44a48..74e5ea5 100644 --- a/write.go +++ b/write.go @@ -18,7 +18,6 @@ type WriteOptions struct { type EventTraveller struct { ID string - JSON []byte Event roots.ValidatedEvent Subgraph *EventSubgraph Error error @@ -37,12 +36,10 @@ type WriteReport struct { Error error } -var ErrMalformedJSON = errors.New("unrecognized event format") -var ErrInvalidEvent = errors.New("invalid event") var ErrDuplicate = errors.New("event already exists") func WriteEvents( - events [][]byte, + events []roots.ValidatedEvent, driver neo4j.Driver, boltdb *bolt.DB, opts *WriteOptions, ) WriteReport { @@ -59,14 +56,15 @@ func WriteEvents( return WriteReport{Error: fmt.Errorf("error setting up bolt db: %w", err)} } - travellers := createEventTravellers(events) - parsed, parseExcluded := parseEventJSON(travellers) - queued, policyExcluded := enforcePolicyRules(parsed, boltdb, opts.BoltReadBatchSize) + travellers := make([]EventTraveller, 0, len(events)) + for _, e := range events { + travellers = append(travellers, EventTraveller{ID: e.ID(), Event: e}) + } + + queued, excluded := enforcePolicyRules(travellers, boltdb, opts.BoltReadBatchSize) converted := convertEventsToSubgraphs(queued, opts.Expanders) writeResult := writeEventsToDatabases(driver, boltdb, converted) - excluded := append(parseExcluded, policyExcluded...) - return WriteReport{ ExcludedEvents: excluded, CreatedEventCount: len(events) - len(excluded), @@ -85,38 +83,6 @@ func setDefaultWriteOptions(opts *WriteOptions) { } } -func createEventTravellers(jsons [][]byte) []EventTraveller { - travellers := make([]EventTraveller, 0, len(jsons)) - for _, j := range jsons { - travellers = append(travellers, EventTraveller{JSON: j}) - } - return travellers -} - -func parseEventJSON(in []EventTraveller) (parsed []EventTraveller, excluded []EventTraveller) { - for _, traveller := range in { - var raw roots.Event - err := json.Unmarshal(traveller.JSON, &raw) - if err != nil { - traveller.Error = fmt.Errorf("rejected: %w: %w", ErrMalformedJSON, err) - excluded = append(excluded, traveller) - continue - } - - validated, err := roots.NewValidatedEvent(raw) - if err != nil { - traveller.Error = fmt.Errorf("rejected: %w: %w", ErrInvalidEvent, err) - excluded = append(excluded, traveller) - continue - } - - traveller.ID = validated.ID() - traveller.Event = validated - parsed = append(parsed, traveller) - } - return parsed, excluded -} - func enforcePolicyRules(in []EventTraveller, boltdb *bolt.DB, batchSize int) (queued []EventTraveller, excluded []EventTraveller) { for i := 0; i < len(in); i += batchSize { end := i + batchSize @@ -169,8 +135,11 @@ func writeEventsToDatabases(driver neo4j.Driver, boltdb *bolt.DB, travellers []E func writeEventsToBoltDB(boltdb *bolt.DB, travellers []EventTraveller) error { var events []EventBlob for _, traveller := range travellers { - events = append(events, - EventBlob{ID: []byte(traveller.ID), JSON: traveller.JSON}) + j, err := json.Marshal(traveller.Event) + if err != nil { + return fmt.Errorf("failed to serialize event %s: %w", traveller.ID, err) + } + events = append(events, EventBlob{ID: []byte(traveller.ID), JSON: j}) } return BatchWriteEvents(boltdb, events) } @@ -190,5 +159,3 @@ func writeEventsToGraphDB(driver neo4j.Driver, travellers []EventTraveller) ([]n return MergeSubgraph(context.Background(), driver, batch) } - - diff --git a/write_test.go b/write_test.go index f4c568d..2685aca 100644 --- a/write_test.go +++ b/write_test.go @@ -6,139 +6,8 @@ import ( "testing" ) -// Test helpers - -func validEventJSON() []byte { - return []byte(`{"id":"c7a702e6158744ca03508bbb4c90f9dbb0d6e88fefbfaa511d5ab24b4e3c48ad","pubkey":"cfa87f35acbde29ba1ab3ee42de527b2cad33ac487e80cf2d6405ea0042c8fef","created_at":1760740551,"kind":1,"tags":[],"content":"hello world","sig":"83b71e15649c9e9da362c175f988c36404cabf357a976d869102a74451cfb8af486f6088b5631033b4927bd46cad7a0d90d7f624aefc0ac260364aa65c36071a"}`) -} - -func invalidEventJSON() []byte { - return []byte(`{"id":"abc123","pubkey":"xyz789","created_at":1000,"kind":1,"content":"test","tags":[],"sig":"abc"}`) -} - -func malformedEventJSON() []byte { - return []byte(`{malformed json`) -} - // Pipeline stage tests -func TestCreateEventTravellers(t *testing.T) { - cases := []struct { - name string - input [][]byte - expected []EventTraveller - }{ - { - name: "empty input", - input: [][]byte{}, - expected: []EventTraveller{}, - }, - { - name: "single json", - input: [][]byte{[]byte("test1")}, - expected: []EventTraveller{ - {JSON: []byte("test1")}, - }, - }, - { - name: "multiple jsons", - input: [][]byte{[]byte("test1"), []byte("test2"), []byte("test3")}, - expected: []EventTraveller{ - {JSON: []byte("test1")}, - {JSON: []byte("test2")}, - {JSON: []byte("test3")}, - }, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - result := createEventTravellers(tc.input) - - assert.Equal(t, len(tc.expected), len(result)) - for i := range tc.expected { - assert.Equal(t, tc.expected[i].JSON, result[i].JSON) - } - }) - } -} - -func TestParseEventJSON(t *testing.T) { - cases := []struct { - name string - input []EventTraveller - wantParsed int - wantRejected int - checkParsedID bool - expectedID string - wantErrorText string - }{ - { - name: "valid event", - input: []EventTraveller{ - {JSON: validEventJSON()}, - }, - wantParsed: 1, - wantRejected: 0, - checkParsedID: true, - expectedID: "c7a702e6158744ca03508bbb4c90f9dbb0d6e88fefbfaa511d5ab24b4e3c48ad", - }, - { - name: "invalid event", - input: []EventTraveller{ - {JSON: invalidEventJSON()}, - }, - wantParsed: 0, - wantRejected: 1, - wantErrorText: "rejected: invalid event", - }, - { - name: "malformed json", - input: []EventTraveller{ - {JSON: malformedEventJSON()}, - }, - wantParsed: 0, - wantRejected: 1, - wantErrorText: "rejected: unrecognized event format", - }, - { - name: "mixed batch", - input: []EventTraveller{ - {JSON: invalidEventJSON()}, - {JSON: malformedEventJSON()}, - {JSON: validEventJSON()}, - }, - wantParsed: 1, - wantRejected: 2, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - parsed, rejected := parseEventJSON(tc.input) - - assert.Equal(t, tc.wantParsed, len(parsed)) - assert.Equal(t, tc.wantRejected, len(rejected)) - - // Smoke test first parsed id - if tc.checkParsedID && len(parsed) > 0 { - assert.Equal(t, tc.expectedID, parsed[0].ID) - assert.NotEmpty(t, parsed[0].Event.ID) - } - - // Check error text on first rejected event - if tc.wantErrorText != "" { - assert.ErrorContains(t, rejected[0].Error, tc.wantErrorText) - } - - for _, reject := range rejected { - assert.NotNil(t, reject.Error) - assert.Empty(t, reject.Event.ID()) - } - }) - } -} - func TestEnforcePolicyRules(t *testing.T) { db := tempDB(t) require.NoError(t, SetupBoltDB(db))