From 04e6c4fc57d68f3d5dc197b39583ca2ca379924a Mon Sep 17 00:00:00 2001 From: Jay Date: Fri, 22 May 2026 16:31:52 -0400 Subject: [PATCH] Phase 4: serialize outer pipeline; remove channels and goroutines from WriteEvents; add TestEnforcePolicyRules --- write.go | 204 +++++++++++------------------------------------ write_test.go | 214 ++++++++++++++++++++------------------------------ 2 files changed, 129 insertions(+), 289 deletions(-) diff --git a/write.go b/write.go index 6d275ee..a4a23aa 100644 --- a/write.go +++ b/write.go @@ -8,7 +8,6 @@ import ( roots "git.wisehodl.dev/jay/go-roots/events" "github.com/boltdb/bolt" "github.com/neo4j/neo4j-go-driver/v6/neo4j" - "sync" "time" ) @@ -60,69 +59,11 @@ func WriteEvents( return WriteReport{Error: fmt.Errorf("error setting up bolt db: %w", err)} } - var wg sync.WaitGroup - - // Create Event Travellers - jsonChan := make(chan []byte) - eventChan := make(chan EventTraveller) - - wg.Add(1) - go createEventTravellers(&wg, jsonChan, eventChan) - - // Parse Event JSON - parsedChan := make(chan EventTraveller) - parseExcludedChan := make(chan EventTraveller) - - wg.Add(1) - go parseEventJSON(&wg, eventChan, parsedChan, parseExcludedChan) - - // Collect Rejected Events - collectedParseExcludedChan := make(chan []EventTraveller, 1) - - wg.Add(1) - go collectTravellers(&wg, parseExcludedChan, collectedParseExcludedChan) - - // Enforce Policy Rules - queuedChan := make(chan EventTraveller) - policyExcludedChan := make(chan EventTraveller) - - wg.Add(1) - go enforcePolicyRules(&wg, driver, boltdb, opts.BoltReadBatchSize, - parsedChan, queuedChan, policyExcludedChan) - - // Collect Skipped Events - collectedPolicyExcludedChan := make(chan []EventTraveller, 1) - - wg.Add(1) - go collectTravellers(&wg, policyExcludedChan, collectedPolicyExcludedChan) - - // Convert Events To Subgraphs - convertedChan := make(chan EventTraveller) - - wg.Add(1) - go convertEventsToSubgraphs(&wg, opts.Expanders, queuedChan, convertedChan) - - // Write Events To Databases - writeResultChan := make(chan WriteResult, 1) - - wg.Add(1) - go writeEventsToDatabases(&wg, driver, boltdb, convertedChan, writeResultChan) - - // Send event jsons into pipeline - go func() { - for _, raw := range events { - jsonChan <- raw - } - close(jsonChan) - }() - - // Wait for pipeline to complete - wg.Wait() - - // Collect results - parseExcluded := <-collectedParseExcludedChan - policyExcluded := <-collectedPolicyExcludedChan - writeResult := <-writeResultChan + travellers := createEventTravellers(events) + parsed, parseExcluded := parseEventJSON(travellers) + queued, policyExcluded := enforcePolicyRules(parsed, boltdb, opts.BoltReadBatchSize) + converted := convertEventsToSubgraphs(queued, opts.Expanders) + writeResult := writeEventsToDatabases(driver, boltdb, converted) excluded := append(parseExcluded, policyExcluded...) @@ -144,129 +85,84 @@ func setDefaultWriteOptions(opts *WriteOptions) { } } -func createEventTravellers(wg *sync.WaitGroup, jsonChan chan []byte, eventChan chan EventTraveller) { - defer wg.Done() - for json := range jsonChan { - eventChan <- EventTraveller{JSON: json} +func createEventTravellers(jsons [][]byte) []EventTraveller { + travellers := make([]EventTraveller, 0, len(jsons)) + for _, j := range jsons { + travellers = append(travellers, EventTraveller{JSON: j}) } - close(eventChan) + return travellers } -func parseEventJSON(wg *sync.WaitGroup, inChan, parsedChan, excludedChan chan EventTraveller) { - defer wg.Done() - for traveller := range inChan { +func parseEventJSON(in []EventTraveller) (parsed []EventTraveller, excluded []EventTraveller) { + for _, traveller := range in { var event roots.Event - jsonBytes := traveller.JSON - err := json.Unmarshal(jsonBytes, &event) + err := json.Unmarshal(traveller.JSON, &event) if err != nil { traveller.Error = fmt.Errorf("rejected: %w: %w", ErrMalformedJSON, err) - excludedChan <- traveller + excluded = append(excluded, traveller) continue } err = roots.Validate(event) if err != nil { traveller.Error = fmt.Errorf("rejected: %w: %w", ErrInvalidEvent, err) - excludedChan <- traveller + excluded = append(excluded, traveller) continue } traveller.ID = event.ID traveller.Event = event - parsedChan <- traveller + parsed = append(parsed, traveller) } - - close(parsedChan) - close(excludedChan) + return parsed, excluded } -func enforcePolicyRules( - wg *sync.WaitGroup, - driver neo4j.Driver, boltdb *bolt.DB, - batchSize int, - inChan, queuedChan, excludedChan chan EventTraveller, -) { - defer wg.Done() - var batch []EventTraveller +func enforcePolicyRules(in []EventTraveller, boltdb *bolt.DB, batchSize int) (queued []EventTraveller, excluded []EventTraveller) { + for i := 0; i < len(in); i += batchSize { + end := i + batchSize + if end > len(in) { + end = len(in) + } + batch := in[i:end] - for traveller := range inChan { - batch = append(batch, traveller) + eventIDs := make([]string, 0, len(batch)) + for _, traveller := range batch { + eventIDs = append(eventIDs, traveller.ID) + } - if len(batch) >= batchSize { - processPolicyRulesBatch(boltdb, batch, queuedChan, excludedChan) - batch = []EventTraveller{} + existsMap := BatchCheckEventsExist(boltdb, eventIDs) + + for _, traveller := range batch { + if existsMap[traveller.ID] { + traveller.Error = fmt.Errorf("skipped: %w", ErrDuplicate) + excluded = append(excluded, traveller) + } else { + queued = append(queued, traveller) + } } } - - if len(batch) > 0 { - processPolicyRulesBatch(boltdb, batch, queuedChan, excludedChan) - } - - close(queuedChan) - close(excludedChan) + return queued, excluded } -func processPolicyRulesBatch( - boltdb *bolt.DB, - batch []EventTraveller, - queuedChan, skippedChan chan EventTraveller, -) { - eventIDs := make([]string, 0, len(batch)) - - for _, traveller := range batch { - eventIDs = append(eventIDs, traveller.ID) - } - - existsMap := BatchCheckEventsExist(boltdb, eventIDs) - - for _, traveller := range batch { - if existsMap[traveller.ID] { - traveller.Error = fmt.Errorf("skipped: %w", ErrDuplicate) - skippedChan <- traveller - } else { - queuedChan <- traveller - } - } -} - -func convertEventsToSubgraphs( - wg *sync.WaitGroup, expanders ExpanderPipeline, - inChan, convertedChan chan EventTraveller, -) { - defer wg.Done() - for traveller := range inChan { +func convertEventsToSubgraphs(in []EventTraveller, expanders ExpanderPipeline) []EventTraveller { + for i, traveller := range in { // TODO: temporary adapter — removed in Phase 5 validated, _ := roots.NewValidatedEvent(traveller.Event) - subgraph := EventToSubgraph(validated, expanders) - traveller.Subgraph = subgraph - convertedChan <- traveller + in[i].Subgraph = EventToSubgraph(validated, expanders) } - close(convertedChan) + return in } -func writeEventsToDatabases( - wg *sync.WaitGroup, - driver neo4j.Driver, boltdb *bolt.DB, - inChan chan EventTraveller, - resultChan chan WriteResult, -) { - defer wg.Done() - - var travellers []EventTraveller - for traveller := range inChan { - travellers = append(travellers, traveller) - } - +func writeEventsToDatabases(driver neo4j.Driver, boltdb *bolt.DB, travellers []EventTraveller) WriteResult { boltErr := writeEventsToBoltDB(boltdb, travellers) if boltErr != nil { - resultChan <- WriteResult{ + return WriteResult{ Error: fmt.Errorf("boltdb write failed, aborting graph write: %w", boltErr), } - return } summaries, err := writeEventsToGraphDB(driver, travellers) - resultChan <- WriteResult{ + return WriteResult{ ResultSummaries: summaries, Error: err, } @@ -297,12 +193,4 @@ func writeEventsToGraphDB(driver neo4j.Driver, travellers []EventTraveller) ([]n return MergeSubgraph(context.Background(), driver, batch) } -func collectTravellers(wg *sync.WaitGroup, inChan chan EventTraveller, resultChan chan []EventTraveller) { - defer wg.Done() - var collected []EventTraveller - for traveller := range inChan { - collected = append(collected, traveller) - } - resultChan <- collected - close(resultChan) -} + diff --git a/write_test.go b/write_test.go index 880c47e..deacd4c 100644 --- a/write_test.go +++ b/write_test.go @@ -1,9 +1,8 @@ package heartwood import ( - roots "git.wisehodl.dev/jay/go-roots/events" "github.com/stretchr/testify/assert" - "sync" + "github.com/stretchr/testify/require" "testing" ) @@ -54,26 +53,7 @@ func TestCreateEventTravellers(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - var wg sync.WaitGroup - jsonChan := make(chan []byte) - eventChan := make(chan EventTraveller) - - wg.Add(1) - go createEventTravellers(&wg, jsonChan, eventChan) - - go func() { - for _, raw := range tc.input { - jsonChan <- raw - } - close(jsonChan) - }() - - var result []EventTraveller - for traveller := range eventChan { - result = append(result, traveller) - } - - wg.Wait() + result := createEventTravellers(tc.input) assert.Equal(t, len(tc.expected), len(result)) for i := range tc.expected { @@ -81,7 +61,6 @@ func TestCreateEventTravellers(t *testing.T) { } }) } - } func TestParseEventJSON(t *testing.T) { @@ -136,43 +115,7 @@ func TestParseEventJSON(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - var wg sync.WaitGroup - inChan := make(chan EventTraveller) - parsedChan := make(chan EventTraveller) - rejectedChan := make(chan EventTraveller) - - wg.Add(1) - go parseEventJSON(&wg, inChan, parsedChan, rejectedChan) - - go func() { - for _, traveller := range tc.input { - inChan <- traveller - } - close(inChan) - }() - - var parsed []EventTraveller - var rejected []EventTraveller - var collectWg sync.WaitGroup - - collectWg.Add(2) - - go func() { - defer collectWg.Done() - for f := range parsedChan { - parsed = append(parsed, f) - } - }() - - go func() { - defer collectWg.Done() - for f := range rejectedChan { - rejected = append(rejected, f) - } - }() - - collectWg.Wait() - wg.Wait() + parsed, rejected := parseEventJSON(tc.input) assert.Equal(t, tc.wantParsed, len(parsed)) assert.Equal(t, tc.wantRejected, len(rejected)) @@ -196,20 +139,92 @@ func TestParseEventJSON(t *testing.T) { } } -// Skip `enforcePolicyRules` -- requires BoltDB +func TestEnforcePolicyRules(t *testing.T) { + db := tempDB(t) + require.NoError(t, SetupBoltDB(db)) + fx := LoadFixtures(t) + + // Pre-write bare and generic_tag as existing events + bareJSON, _ := fx.ValidatedEvent(t, "bare").MarshalJSON() + genericJSON, _ := fx.ValidatedEvent(t, "generic_tag").MarshalJSON() + bareID := fx.ValidatedEvent(t, "bare").ID() + genericID := fx.ValidatedEvent(t, "generic_tag").ID() + + err := BatchWriteEvents(db, []EventBlob{ + {ID: []byte(bareID), JSON: bareJSON}, + {ID: []byte(genericID), JSON: genericJSON}, + }) + assert.NoError(t, err) + + e_tag_id := fx.ValidatedEvent(t, "e_tag_valid").ID() + p_tag_id := fx.ValidatedEvent(t, "p_tag_valid").ID() + + cases := []struct { + name string + input []EventTraveller + wantQueued int + wantExcluded int + }{ + { + name: "empty input", + input: []EventTraveller{}, + wantQueued: 0, + wantExcluded: 0, + }, + { + name: "no duplicates", + input: []EventTraveller{ + {ID: e_tag_id}, + {ID: p_tag_id}, + }, + wantQueued: 2, + wantExcluded: 0, + }, + { + name: "some duplicates", + input: []EventTraveller{ + {ID: bareID}, + {ID: e_tag_id}, + }, + wantQueued: 1, + wantExcluded: 1, + }, + { + name: "all duplicates", + input: []EventTraveller{ + {ID: bareID}, + {ID: genericID}, + }, + wantQueued: 0, + wantExcluded: 2, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + queued, excluded := enforcePolicyRules(tc.input, db, 100) + + assert.Equal(t, tc.wantQueued, len(queued)) + assert.Equal(t, tc.wantExcluded, len(excluded)) + for _, ex := range excluded { + assert.ErrorIs(t, ex.Error, ErrDuplicate) + } + }) + } +} func TestConvertEventsToSubgraphs(t *testing.T) { fx := LoadFixtures(t) cases := []struct { name string - event roots.ValidatedEvent + traveller EventTraveller wantNodeCount int wantRelCount int }{ { name: "event with no tags", - event: fx.ValidatedEvent(t, "bare"), + traveller: EventTraveller{Event: fx.ValidatedEvent(t, "bare").Event()}, wantNodeCount: 2, // event + user wantRelCount: 1, // signed }, @@ -217,78 +232,15 @@ func TestConvertEventsToSubgraphs(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - var wg sync.WaitGroup - inChan := make(chan EventTraveller) - convertedChan := make(chan EventTraveller) - expanders := NewExpanderPipeline(DefaultExpanders()...) + results := convertEventsToSubgraphs([]EventTraveller{tc.traveller}, expanders) - wg.Add(1) - go convertEventsToSubgraphs(&wg, expanders, inChan, convertedChan) - - go func() { - inChan <- EventTraveller{Event: tc.event.Event()} - close(inChan) - }() - - var result EventTraveller - for f := range convertedChan { - result = f - } - - wg.Wait() - - assert.NotNil(t, result.Subgraph) - assert.Equal(t, tc.wantNodeCount, len(result.Subgraph.Nodes())) - assert.Equal(t, tc.wantRelCount, len(result.Subgraph.Rels())) + assert.Len(t, results, 1) + assert.NotNil(t, results[0].Subgraph) + assert.Equal(t, tc.wantNodeCount, len(results[0].Subgraph.Nodes())) + assert.Equal(t, tc.wantRelCount, len(results[0].Subgraph.Rels())) }) } } // Skip `writeEventsToDatabases` tests -- requires BoltDB + Neo4j - -func TestCollectEvents(t *testing.T) { - cases := []struct { - name string - input []EventTraveller - expected int - }{ - { - name: "empty channel", - input: []EventTraveller{}, - expected: 0, - }, - { - name: "multiple travellers", - input: []EventTraveller{ - {ID: "id1"}, - {ID: "id2"}, - {ID: "id3"}, - }, - expected: 3, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - var wg sync.WaitGroup - inChan := make(chan EventTraveller) - resultChan := make(chan []EventTraveller) - - wg.Add(1) - go collectTravellers(&wg, inChan, resultChan) - - go func() { - for _, f := range tc.input { - inChan <- f - } - close(inChan) - }() - - result := <-resultChan - wg.Wait() - - assert.Equal(t, tc.expected, len(result)) - }) - } -}