From 269e88fe49a65756d38305b2c1ac8517a9c1991c Mon Sep 17 00:00:00 2001 From: Jay Date: Thu, 5 Mar 2026 00:28:40 -0500 Subject: [PATCH] Variety of refactors and optimizations. --- batchMerge.go | 46 +++++------ batchMerge_test.go | 9 +-- boltdb.go | 40 ++-------- cypher.go | 3 +- expanders.go | 82 ------------------- filters/filters.go | 4 +- graph.go | 13 ++- graph_test.go | 2 +- neo4j.go | 56 ++++++++----- schema.go | 44 +---------- schema_test.go | 4 +- set.go | 52 +++++++----- subgraph.go | 87 ++++++++++++++++++-- subgraph_test.go | 2 +- write.go | 192 ++++++++++++++++++--------------------------- 15 files changed, 268 insertions(+), 368 deletions(-) delete mode 100644 expanders.go diff --git a/batchMerge.go b/batchMerge.go index 816318f..2870d74 100644 --- a/batchMerge.go +++ b/batchMerge.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "github.com/neo4j/neo4j-go-driver/v6/neo4j" - "sort" "strings" ) @@ -45,15 +44,11 @@ func (s *BatchSubgraph) AddNode(node *Node) error { // Verify that the node has defined match property values. matchLabel, _, err := node.MatchProps(s.matchProvider) if err != nil { - return fmt.Errorf("invalid node: %s", err) + return fmt.Errorf("invalid node: %w", err) } // Determine the node's batch key. - batchKey := createNodeBatchKey(matchLabel, node.Labels.ToArray()) - - if _, exists := s.nodes[batchKey]; !exists { - s.nodes[batchKey] = []*Node{} - } + batchKey := createNodeBatchKey(matchLabel, node.Labels.AsSortedArray()) // Add the node to the sub s.nodes[batchKey] = append(s.nodes[batchKey], node) @@ -66,22 +61,18 @@ func (s *BatchSubgraph) AddRel(rel *Relationship) error { // Verify that the start node has defined match property values. startLabel, _, err := rel.Start.MatchProps(s.matchProvider) if err != nil { - return fmt.Errorf("invalid start node: %s", err) + return fmt.Errorf("invalid start node: %w", err) } // Verify that the end node has defined match property values. endLabel, _, err := rel.End.MatchProps(s.matchProvider) if err != nil { - return fmt.Errorf("invalid end node: %s", err) + return fmt.Errorf("invalid end node: %w", err) } // Determine the relationship's batch key. batchKey := createRelBatchKey(rel.Type, startLabel, endLabel) - if _, exists := s.rels[batchKey]; !exists { - s.rels[batchKey] = []*Relationship{} - } - // Add the relationship to the sub s.rels[batchKey] = append(s.rels[batchKey], rel) @@ -105,7 +96,7 @@ func (s *BatchSubgraph) RelCount() int { } func (s *BatchSubgraph) nodeKeys() []string { - keys := []string{} + keys := make([]string, 0, len(s.nodes)) for l := range s.nodes { keys = append(keys, l) } @@ -113,7 +104,7 @@ func (s *BatchSubgraph) nodeKeys() []string { } func (s *BatchSubgraph) relKeys() []string { - keys := []string{} + keys := make([]string, 0, len(s.rels)) for t := range s.rels { keys = append(keys, t) } @@ -121,7 +112,7 @@ func (s *BatchSubgraph) relKeys() []string { } func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) { - batches := []NodeBatch{} + batches := make([]NodeBatch, 0, len(s.nodeKeys())) for _, nodeKey := range s.nodeKeys() { matchLabel, labels, err := deserializeNodeBatchKey(nodeKey) @@ -146,7 +137,7 @@ func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) { } func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) { - batches := []RelBatch{} + batches := make([]RelBatch, 0, len(s.relKeys())) for _, relKey := range s.relKeys() { rtype, startLabel, endLabel, err := deserializeRelBatchKey(relKey) @@ -181,9 +172,8 @@ func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) { // Helpers -func createNodeBatchKey(matchLabel string, labels []string) string { - sort.Strings(labels) - serializedLabels := strings.Join(labels, ",") +func createNodeBatchKey(matchLabel string, sortedLabels []string) string { + serializedLabels := strings.Join(sortedLabels, ",") return fmt.Sprintf("%s:%s", matchLabel, serializedLabels) } @@ -245,7 +235,7 @@ func MergeSubgraph( ) } if nodeResultSummary != nil { - resultSummaries = append(resultSummaries, *nodeResultSummary) + resultSummaries = append(resultSummaries, nodeResultSummary) } } @@ -261,7 +251,7 @@ func MergeSubgraph( ) } if relResultSummary != nil { - resultSummaries = append(resultSummaries, *relResultSummary) + resultSummaries = append(resultSummaries, relResultSummary) } } @@ -284,11 +274,11 @@ func MergeNodes( ctx context.Context, tx neo4j.ManagedTransaction, batch NodeBatch, -) (*neo4j.ResultSummary, error) { +) (neo4j.ResultSummary, error) { cypherLabels := ToCypherLabels(batch.Labels) cypherProps := ToCypherProps(batch.MatchKeys, "node.") - serializedNodes := []*SerializedNode{} + serializedNodes := make([]*SerializedNode, 0, len(batch.Nodes)) for _, node := range batch.Nodes { serializedNodes = append(serializedNodes, node.Serialize()) } @@ -316,21 +306,21 @@ func MergeNodes( return nil, err } - return &summary, nil + return summary, nil } func MergeRels( ctx context.Context, tx neo4j.ManagedTransaction, batch RelBatch, -) (*neo4j.ResultSummary, error) { +) (neo4j.ResultSummary, error) { cypherType := ToCypherLabel(batch.Type) startCypherLabel := ToCypherLabel(batch.StartLabel) endCypherLabel := ToCypherLabel(batch.EndLabel) startCypherProps := ToCypherProps(batch.StartMatchKeys, "rel.start.") endCypherProps := ToCypherProps(batch.EndMatchKeys, "rel.end.") - serializedRels := []*SerializedRel{} + serializedRels := make([]*SerializedRel, 0, len(batch.Rels)) for _, rel := range batch.Rels { serializedRels = append(serializedRels, rel.Serialize()) } @@ -363,5 +353,5 @@ func MergeRels( return nil, err } - return &summary, nil + return summary, nil } diff --git a/batchMerge_test.go b/batchMerge_test.go index 489d3ef..165dac7 100644 --- a/batchMerge_test.go +++ b/batchMerge_test.go @@ -7,20 +7,19 @@ import ( func TestNodeBatchKey(t *testing.T) { matchLabel := "Event" - labels := []string{"Event", "AddressableEvent"} - - // labels should be batched by key generator + sortedLabels := []string{"AddressableEvent", "Event"} expectedKey := "Event:AddressableEvent,Event" // Test Serialization - batchKey := createNodeBatchKey(matchLabel, labels) + // labels are expected to be pre-sorted + batchKey := createNodeBatchKey(matchLabel, sortedLabels) assert.Equal(t, expectedKey, batchKey) // Test Deserialization returnedMatchLabel, returnedLabels, err := deserializeNodeBatchKey(batchKey) assert.NoError(t, err) assert.Equal(t, matchLabel, returnedMatchLabel) - assert.ElementsMatch(t, labels, returnedLabels) + assert.ElementsMatch(t, sortedLabels, returnedLabels) } func TestRelBatchKey(t *testing.T) { diff --git a/boltdb.go b/boltdb.go index 9b29f29..807312d 100644 --- a/boltdb.go +++ b/boltdb.go @@ -4,32 +4,11 @@ import ( "github.com/boltdb/bolt" ) -// Interface +const BucketName string = "events" -type BoltDB interface { - Setup() error - BatchCheckEventsExist(eventIDs []string) map[string]bool - BatchWriteEvents(events []EventBlob) error -} - -func NewKVDB(boltdb *bolt.DB) BoltDB { - return &boltDB{db: boltdb} -} - -type boltDB struct { - db *bolt.DB -} - -func (b *boltDB) Setup() error { - return SetupBoltDB(b.db) -} - -func (b *boltDB) BatchCheckEventsExist(eventIDs []string) map[string]bool { - return BatchCheckEventsExist(b.db, eventIDs) -} - -func (b *boltDB) BatchWriteEvents(events []EventBlob) error { - return BatchWriteEvents(b.db, events) +type EventBlob struct { + ID string + JSON string } func SetupBoltDB(boltdb *bolt.DB) error { @@ -39,19 +18,10 @@ func SetupBoltDB(boltdb *bolt.DB) error { }) } -// Functions - -const BucketName string = "events" - -type EventBlob struct { - ID string - JSON string -} - func BatchCheckEventsExist(boltdb *bolt.DB, eventIDs []string) map[string]bool { existsMap := make(map[string]bool) - boltdb.View(func(tx *bolt.Tx) error { + _ = boltdb.View(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(BucketName)) if bucket == nil { return nil diff --git a/cypher.go b/cypher.go index 271db62..fcdca76 100644 --- a/cypher.go +++ b/cypher.go @@ -30,7 +30,8 @@ func ToCypherProps(keys []string, prefix string) string { if prefix == "" { prefix = "$" } - cypherPropsParts := []string{} + + var cypherPropsParts []string for _, key := range keys { cypherPropsParts = append( cypherPropsParts, fmt.Sprintf("%s: %s%s", key, prefix, key)) diff --git a/expanders.go b/expanders.go deleted file mode 100644 index 66a691a..0000000 --- a/expanders.go +++ /dev/null @@ -1,82 +0,0 @@ -package heartwood - -import ( - roots "git.wisehodl.dev/jay/go-roots/events" -) - -type Expander func(e roots.Event, s *EventSubgraph) -type ExpanderPipeline []Expander - -func NewExpanderPipeline(expanders ...Expander) ExpanderPipeline { - return ExpanderPipeline(expanders) -} - -func DefaultExpanders() []Expander { - return []Expander{ - ExpandTaggedEvents, - ExpandTaggedUsers, - } -} - -// Default Expander Functions - -func ExpandTaggedEvents(e roots.Event, s *EventSubgraph) { - tagNodes := s.NodesByLabel("Tag") - for _, tag := range e.Tags { - if !isValidTag(tag) { - continue - } - name := tag[0] - value := tag[1] - - if name != "e" || !roots.Hex64Pattern.MatchString(value) { - continue - } - - tagNode := findTagNode(tagNodes, name, value) - if tagNode == nil { - continue - } - - referencedEvent := NewEventNode(value) - - s.AddNode(referencedEvent) - s.AddRel(NewReferencesEventRel(tagNode, referencedEvent, nil)) - } -} - -func ExpandTaggedUsers(e roots.Event, s *EventSubgraph) { - tagNodes := s.NodesByLabel("Tag") - for _, tag := range e.Tags { - if !isValidTag(tag) { - continue - } - name := tag[0] - value := tag[1] - - if name != "p" || !roots.Hex64Pattern.MatchString(value) { - continue - } - - tagNode := findTagNode(tagNodes, name, value) - if tagNode == nil { - continue - } - - referencedEvent := NewUserNode(value) - - s.AddNode(referencedEvent) - s.AddRel(NewReferencesUserRel(tagNode, referencedEvent, nil)) - } -} - -// Helpers - -func findTagNode(nodes []*Node, name, value string) *Node { - for _, node := range nodes { - if node.Props["name"] == name && node.Props["value"] == value { - return node - } - } - return nil -} diff --git a/filters/filters.go b/filters/filters.go index 0083c90..ad731b3 100644 --- a/filters/filters.go +++ b/filters/filters.go @@ -252,7 +252,7 @@ func UnmarshalGraphJSON(data []byte, f *GraphFilter) error { } func marshalGraphArray(filters []GraphFilter) ([]json.RawMessage, error) { - result := []json.RawMessage{} + result := make([]json.RawMessage, 0, len(filters)) for _, f := range filters { b, err := MarshalGraphJSON(f) if err != nil { @@ -268,7 +268,7 @@ func unmarshalGraphArray(raws json.RawMessage) ([]GraphFilter, error) { if err := json.Unmarshal(raws, &rawArray); err != nil { return nil, err } - var result []GraphFilter + result := make([]GraphFilter, 0, len(rawArray)) for _, raw := range rawArray { var f GraphFilter if err := UnmarshalGraphJSON(raw, &f); err != nil { diff --git a/graph.go b/graph.go index d167d9e..bac7275 100644 --- a/graph.go +++ b/graph.go @@ -2,7 +2,6 @@ package heartwood import ( "fmt" - "sort" ) // ======================================== @@ -33,7 +32,7 @@ type SimpleMatchKeys struct { } func (p *SimpleMatchKeys) GetLabels() []string { - labels := []string{} + labels := make([]string, 0, len(p.Keys)) for l := range p.Keys { labels = append(labels, l) } @@ -43,9 +42,8 @@ func (p *SimpleMatchKeys) GetLabels() []string { func (p *SimpleMatchKeys) GetKeys(label string) ([]string, bool) { if keys, exists := p.Keys[label]; exists { return keys, exists - } else { - return nil, exists } + return nil, false } // ======================================== @@ -56,7 +54,7 @@ func (p *SimpleMatchKeys) GetKeys(label string) ([]string, bool) { // properties. type Node struct { // Set of labels on the node. - Labels Set[string] + Labels *StringSet // Mapping of properties on the node. Props Properties } @@ -67,7 +65,7 @@ func NewNode(label string, props Properties) *Node { props = make(Properties) } return &Node{ - Labels: NewSet(label), + Labels: NewStringSet(label), Props: props, } } @@ -79,8 +77,7 @@ func (n *Node) MatchProps( // Iterate over each label on the node, checking whether each has match // keys associated with it. - labels := n.Labels.ToArray() - sort.Strings(labels) + labels := n.Labels.AsSortedArray() for _, label := range labels { if keys, exists := matchProvider.GetKeys(label); exists { props := make(Properties) diff --git a/graph_test.go b/graph_test.go index 92409c6..1c68e7a 100644 --- a/graph_test.go +++ b/graph_test.go @@ -71,7 +71,7 @@ func TestMatchProps(t *testing.T) { { name: "multiple labels, one matches", node: &Node{ - Labels: NewSet("Event", "Unknown"), + Labels: NewStringSet("Event", "Unknown"), Props: Properties{ "id": "abc123", }, diff --git a/neo4j.go b/neo4j.go index be19d42..88619d9 100644 --- a/neo4j.go +++ b/neo4j.go @@ -5,26 +5,6 @@ import ( "github.com/neo4j/neo4j-go-driver/v6/neo4j" ) -// Interface - -type GraphDB interface { - MergeSubgraph(ctx context.Context, subgraph *BatchSubgraph) ([]neo4j.ResultSummary, error) -} - -func NewGraphDriver(driver neo4j.Driver) GraphDB { - return &graphdb{driver: driver} -} - -type graphdb struct { - driver neo4j.Driver -} - -func (n *graphdb) MergeSubgraph(ctx context.Context, subgraph *BatchSubgraph) ([]neo4j.ResultSummary, error) { - return MergeSubgraph(ctx, n.driver, subgraph) -} - -// Functions - func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver, error) { driver, err := neo4j.NewDriver( uri, @@ -40,3 +20,39 @@ func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver return driver, nil } + +// SetNeo4jSchema ensures that the necessary indexes and constraints exist in +// the database +func SetNeo4jSchema(ctx context.Context, driver neo4j.Driver) error { + schemaQueries := []string{ + `CREATE CONSTRAINT user_pubkey IF NOT EXISTS + FOR (n:User) REQUIRE n.pubkey IS UNIQUE`, + + `CREATE INDEX user_pubkey IF NOT EXISTS + FOR (n:User) ON (n.pubkey)`, + + `CREATE INDEX event_id IF NOT EXISTS + FOR (n:Event) ON (n.id)`, + + `CREATE INDEX event_kind IF NOT EXISTS + FOR (n:Event) ON (n.kind)`, + + `CREATE INDEX tag_name_value IF NOT EXISTS + FOR (n:Tag) ON (n.name, n.value)`, + } + + // Create indexes and constraints + for _, query := range schemaQueries { + _, err := neo4j.ExecuteQuery(ctx, driver, + query, + nil, + neo4j.EagerResultTransformer, + neo4j.ExecuteQueryWithDatabase("neo4j")) + + if err != nil { + return err + } + } + + return nil +} diff --git a/schema.go b/schema.go index 39017b3..54d6845 100644 --- a/schema.go +++ b/schema.go @@ -1,9 +1,7 @@ package heartwood import ( - "context" "fmt" - "github.com/neo4j/neo4j-go-driver/v6/neo4j" ) // ======================================== @@ -80,7 +78,7 @@ func validateNodeLabel(node *Node, role string, expectedLabel string) { if !node.Labels.Contains(expectedLabel) { panic(fmt.Errorf( "expected %s node to have label %q. got %v", - role, expectedLabel, node.Labels.ToArray(), + role, expectedLabel, node.Labels.AsSortedArray(), )) } } @@ -98,43 +96,3 @@ func NewRelationshipWithValidation( return NewRelationship(rtype, start, end, props) } - -// ======================================== -// Schema Indexes and Constraints -// ======================================== - -// SetNeo4jSchema ensures that the necessary indexes and constraints exist in -// the database -func SetNeo4jSchema(ctx context.Context, driver neo4j.Driver) error { - schemaQueries := []string{ - `CREATE CONSTRAINT user_pubkey IF NOT EXISTS - FOR (n:User) REQUIRE n.pubkey IS UNIQUE`, - - `CREATE INDEX user_pubkey IF NOT EXISTS - FOR (n:User) ON (n.pubkey)`, - - `CREATE INDEX event_id IF NOT EXISTS - FOR (n:Event) ON (n.id)`, - - `CREATE INDEX event_kind IF NOT EXISTS - FOR (n:Event) ON (n.kind)`, - - `CREATE INDEX tag_name_value IF NOT EXISTS - FOR (n:Tag) ON (n.name, n.value)`, - } - - // Create indexes and constraints - for _, query := range schemaQueries { - _, err := neo4j.ExecuteQuery(ctx, driver, - query, - nil, - neo4j.EagerResultTransformer, - neo4j.ExecuteQueryWithDatabase("neo4j")) - - if err != nil { - return err - } - } - - return nil -} diff --git a/schema_test.go b/schema_test.go index 0f19065..f5f4f61 100644 --- a/schema_test.go +++ b/schema_test.go @@ -37,8 +37,8 @@ func TestNewRelationshipWithValidation(t *testing.T) { } rel := NewSignedRel(tc.start, tc.end, nil) assert.Equal(t, "SIGNED", rel.Type) - assert.Contains(t, rel.Start.Labels.ToArray(), "User") - assert.Contains(t, rel.End.Labels.ToArray(), "Event") + assert.Contains(t, rel.Start.Labels.AsSortedArray(), "User") + assert.Contains(t, rel.End.Labels.AsSortedArray(), "Event") }) } } diff --git a/set.go b/set.go index 72b0814..b3851cf 100644 --- a/set.go +++ b/set.go @@ -1,14 +1,20 @@ package heartwood +import ( + "sort" +) + // Sets -type Set[T comparable] struct { - inner map[T]struct{} +type StringSet struct { + inner map[string]struct{} + sorted []string } -func NewSet[T comparable](items ...T) Set[T] { - set := Set[T]{ - inner: make(map[T]struct{}), +func NewStringSet(items ...string) *StringSet { + set := &StringSet{ + inner: make(map[string]struct{}), + sorted: []string{}, } for _, i := range items { set.Add(i) @@ -16,20 +22,26 @@ func NewSet[T comparable](items ...T) Set[T] { return set } -func (s Set[T]) Add(item T) { - s.inner[item] = struct{}{} +func (s *StringSet) Add(item string) { + if _, exists := s.inner[item]; !exists { + s.inner[item] = struct{}{} + s.rebuildSorted() + } } -func (s Set[T]) Remove(item T) { - delete(s.inner, item) +func (s *StringSet) Remove(item string) { + if _, exists := s.inner[item]; exists { + delete(s.inner, item) + s.rebuildSorted() + } } -func (s Set[T]) Contains(item T) bool { +func (s *StringSet) Contains(item string) bool { _, exists := s.inner[item] return exists } -func (s Set[T]) Equal(other Set[T]) bool { +func (s *StringSet) Equal(other StringSet) bool { if len(s.inner) != len(other.inner) { return false } @@ -41,14 +53,18 @@ func (s Set[T]) Equal(other Set[T]) bool { return true } -func (s Set[T]) Length() int { +func (s *StringSet) Length() int { return len(s.inner) } -func (s Set[T]) ToArray() []T { - array := []T{} - for i := range s.inner { - array = append(array, i) - } - return array +func (s *StringSet) AsSortedArray() []string { + return s.sorted +} + +func (s *StringSet) rebuildSorted() { + s.sorted = make([]string, 0, len(s.inner)) + for item := range s.inner { + s.sorted = append(s.sorted, item) + } + sort.Strings(s.sorted) } diff --git a/subgraph.go b/subgraph.go index 177ace5..4f0577f 100644 --- a/subgraph.go +++ b/subgraph.go @@ -4,7 +4,7 @@ import ( roots "git.wisehodl.dev/jay/go-roots/events" ) -// Event subgraph struct +// Types type EventSubgraph struct { nodes []*Node @@ -35,7 +35,7 @@ func (s *EventSubgraph) Rels() []*Relationship { } func (s *EventSubgraph) NodesByLabel(label string) []*Node { - nodes := []*Node{} + var nodes []*Node for _, node := range s.nodes { if node.Labels.Contains(label) { nodes = append(nodes, node) @@ -58,7 +58,7 @@ func isValidTag(t roots.Tag) bool { return true } -// Event to subgraph conversion +// Event to subgraph pipeline func EventToSubgraph(e roots.Event, p ExpanderPipeline) *EventSubgraph { s := NewEventSubgraph() @@ -89,6 +89,8 @@ func EventToSubgraph(e roots.Event, p ExpanderPipeline) *EventSubgraph { return s } +// Core pipeline functions + func newEventNode(eventID string, createdAt int, kind int, content string) *Node { eventNode := NewEventNode(eventID) eventNode.Props["created_at"] = createdAt @@ -106,7 +108,7 @@ func newSignedRel(user, event *Node) *Relationship { } func newTagNodes(tags []roots.Tag) []*Node { - nodes := []*Node{} + nodes := make([]*Node, 0, len(tags)) for _, tag := range tags { if !isValidTag(tag) { continue @@ -117,9 +119,84 @@ func newTagNodes(tags []roots.Tag) []*Node { } func newTagRels(event *Node, tags []*Node) []*Relationship { - rels := []*Relationship{} + rels := make([]*Relationship, 0, len(tags)) for _, tag := range tags { rels = append(rels, NewTaggedRel(event, tag, nil)) } return rels } + +// Expander Pipeline + +type Expander func(e roots.Event, s *EventSubgraph) +type ExpanderPipeline []Expander + +func NewExpanderPipeline(expanders ...Expander) ExpanderPipeline { + return ExpanderPipeline(expanders) +} + +func DefaultExpanders() []Expander { + return []Expander{ + ExpandTaggedEvents, + ExpandTaggedUsers, + } +} + +func ExpandTaggedEvents(e roots.Event, s *EventSubgraph) { + tagNodes := s.NodesByLabel("Tag") + for _, tag := range e.Tags { + if !isValidTag(tag) { + continue + } + name := tag[0] + value := tag[1] + + if name != "e" || !roots.Hex64Pattern.MatchString(value) { + continue + } + + tagNode := findTagNode(tagNodes, name, value) + if tagNode == nil { + continue + } + + referencedEvent := NewEventNode(value) + + s.AddNode(referencedEvent) + s.AddRel(NewReferencesEventRel(tagNode, referencedEvent, nil)) + } +} + +func ExpandTaggedUsers(e roots.Event, s *EventSubgraph) { + tagNodes := s.NodesByLabel("Tag") + for _, tag := range e.Tags { + if !isValidTag(tag) { + continue + } + name := tag[0] + value := tag[1] + + if name != "p" || !roots.Hex64Pattern.MatchString(value) { + continue + } + + tagNode := findTagNode(tagNodes, name, value) + if tagNode == nil { + continue + } + + referencedEvent := NewUserNode(value) + + s.AddNode(referencedEvent) + s.AddRel(NewReferencesUserRel(tagNode, referencedEvent, nil)) + } +} + +func findTagNode(nodes []*Node, name, value string) *Node { + for _, node := range nodes { + if node.Props["name"] == name && node.Props["value"] == value { + return node + } + } + return nil +} diff --git a/subgraph_test.go b/subgraph_test.go index 7cf25af..8c12637 100644 --- a/subgraph_test.go +++ b/subgraph_test.go @@ -172,7 +172,7 @@ func nodesEqual(expected, got *Node) error { } // Compare label values - for _, label := range expected.Labels.ToArray() { + for _, label := range expected.Labels.AsSortedArray() { if !got.Labels.Contains(label) { return fmt.Errorf("missing label %q", label) } diff --git a/write.go b/write.go index e8f53bf..ca8e8ac 100644 --- a/write.go +++ b/write.go @@ -5,14 +5,15 @@ import ( "encoding/json" "fmt" roots "git.wisehodl.dev/jay/go-roots/events" + "github.com/boltdb/bolt" "github.com/neo4j/neo4j-go-driver/v6/neo4j" "sync" "time" ) type WriteOptions struct { - Expanders ExpanderPipeline - KVReadBatchSize int + Expanders ExpanderPipeline + BoltReadBatchSize int } type EventFollower struct { @@ -39,7 +40,7 @@ type WriteReport struct { func WriteEvents( events []string, - graphdb GraphDB, boltdb BoltDB, + driver neo4j.Driver, boltdb *bolt.DB, opts *WriteOptions, ) (WriteReport, error) { start := time.Now() @@ -50,7 +51,7 @@ func WriteEvents( setDefaultWriteOptions(opts) - err := boltdb.Setup() + err := SetupBoltDB(boltdb) if err != nil { return WriteReport{}, fmt.Errorf("error setting up bolt db: %w", err) } @@ -58,80 +59,55 @@ func WriteEvents( var wg sync.WaitGroup // Create Event Followers - jsonChan := make(chan string, 10) - eventChan := make(chan EventFollower, 10) + jsonChan := make(chan string) + eventChan := make(chan EventFollower) wg.Add(1) - go func() { - defer wg.Done() - createEventFollowers(jsonChan, eventChan) - }() + go createEventFollowers(&wg, jsonChan, eventChan) // Parse Event JSON - parsedChan := make(chan EventFollower, 10) - invalidChan := make(chan EventFollower, 10) + parsedChan := make(chan EventFollower) + invalidChan := make(chan EventFollower) wg.Add(1) - go func() { - defer wg.Done() - parseEventJSON(eventChan, parsedChan, invalidChan) - }() + go parseEventJSON(&wg, eventChan, parsedChan, invalidChan) // Collect Invalid Events collectedInvalidChan := make(chan []EventFollower) wg.Add(1) - go func() { - defer wg.Done() - collectEvents(invalidChan, collectedInvalidChan) - }() + go collectEvents(&wg, invalidChan, collectedInvalidChan) // Enforce Policy Rules - queuedChan := make(chan EventFollower, 10) - skippedChan := make(chan EventFollower, 10) + queuedChan := make(chan EventFollower) + skippedChan := make(chan EventFollower) wg.Add(1) - go func() { - defer wg.Done() - enforcePolicyRules( - graphdb, boltdb, - opts.KVReadBatchSize, - parsedChan, queuedChan, skippedChan) - }() + go enforcePolicyRules(&wg, driver, boltdb, opts.BoltReadBatchSize, + parsedChan, queuedChan, skippedChan) // Collect Skipped Events collectedSkippedChan := make(chan []EventFollower) wg.Add(1) - go func() { - defer wg.Done() - collectEvents(skippedChan, collectedSkippedChan) - }() + go collectEvents(&wg, skippedChan, collectedSkippedChan) // Convert Events To Subgraphs - convertedChan := make(chan EventFollower, 10) + convertedChan := make(chan EventFollower) wg.Add(1) - go func() { - defer wg.Done() - convertEventsToSubgraphs(opts.Expanders, queuedChan, convertedChan) - }() + go convertEventsToSubgraphs(&wg, opts.Expanders, queuedChan, convertedChan) // Write Events To Databases writeResultChan := make(chan WriteResult) wg.Add(1) - go func() { - defer wg.Done() - writeEventsToDatabases( - graphdb, boltdb, - convertedChan, writeResultChan) - }() + go writeEventsToDatabases(&wg, driver, boltdb, convertedChan, writeResultChan) // Send event jsons into pipeline go func() { - for _, json := range events { - jsonChan <- json + for _, raw := range events { + jsonChan <- raw } close(jsonChan) }() @@ -158,19 +134,21 @@ func setDefaultWriteOptions(opts *WriteOptions) { if opts.Expanders == nil { opts.Expanders = NewExpanderPipeline(DefaultExpanders()...) } - if opts.KVReadBatchSize == 0 { - opts.KVReadBatchSize = 100 + if opts.BoltReadBatchSize == 0 { + opts.BoltReadBatchSize = 100 } } -func createEventFollowers(jsonChan chan string, eventChan chan EventFollower) { +func createEventFollowers(wg *sync.WaitGroup, jsonChan chan string, eventChan chan EventFollower) { + defer wg.Done() for json := range jsonChan { eventChan <- EventFollower{JSON: json} } close(eventChan) } -func parseEventJSON(inChan, parsedChan, invalidChan chan EventFollower) { +func parseEventJSON(wg *sync.WaitGroup, inChan, parsedChan, invalidChan chan EventFollower) { + defer wg.Done() for follower := range inChan { var event roots.Event jsonBytes := []byte(follower.JSON) @@ -191,11 +169,13 @@ func parseEventJSON(inChan, parsedChan, invalidChan chan EventFollower) { } func enforcePolicyRules( - graphdb GraphDB, boltdb BoltDB, + wg *sync.WaitGroup, + driver neo4j.Driver, boltdb *bolt.DB, batchSize int, inChan, queuedChan, skippedChan chan EventFollower, ) { - batch := []EventFollower{} + defer wg.Done() + var batch []EventFollower for follower := range inChan { batch = append(batch, follower) @@ -215,17 +195,17 @@ func enforcePolicyRules( } func processPolicyRulesBatch( - boltdb BoltDB, + boltdb *bolt.DB, batch []EventFollower, queuedChan, skippedChan chan EventFollower, ) { - eventIDs := []string{} + eventIDs := make([]string, 0, len(batch)) for _, follower := range batch { eventIDs = append(eventIDs, follower.ID) } - existsMap := boltdb.BatchCheckEventsExist(eventIDs) + existsMap := BatchCheckEventsExist(boltdb, eventIDs) for _, follower := range batch { if existsMap[follower.ID] { @@ -237,9 +217,10 @@ func processPolicyRulesBatch( } func convertEventsToSubgraphs( - expanders ExpanderPipeline, + wg *sync.WaitGroup, expanders ExpanderPipeline, inChan, convertedChan chan EventFollower, ) { + defer wg.Done() for follower := range inChan { subgraph := EventToSubgraph(follower.Event, expanders) follower.Subgraph = subgraph @@ -249,93 +230,66 @@ func convertEventsToSubgraphs( } func writeEventsToDatabases( - graphdb GraphDB, boltdb BoltDB, + wg *sync.WaitGroup, + driver neo4j.Driver, boltdb *bolt.DB, inChan chan EventFollower, resultChan chan WriteResult, ) { - var wg sync.WaitGroup + defer wg.Done() + var localWg sync.WaitGroup - kvEventChan := make(chan EventFollower, 10) - graphEventChan := make(chan EventFollower, 10) + boltEventChan := make(chan EventFollower) + graphEventChan := make(chan EventFollower) - kvWriteDone := make(chan struct{}) - - kvErrorChan := make(chan error) + boltErrorChan := make(chan error) graphResultChan := make(chan WriteResult) - wg.Add(2) - go func() { - defer wg.Done() - writeEventsToKVStore( - boltdb, - kvEventChan, kvWriteDone, kvErrorChan) - }() - go func() { - defer wg.Done() - writeEventsToGraphDriver( - graphdb, - graphEventChan, kvWriteDone, graphResultChan) - }() + localWg.Add(2) + go writeEventsToBoltDB(&localWg, boltdb, boltEventChan, boltErrorChan) + go writeEventsToGraphDB(&localWg, driver, graphEventChan, boltErrorChan, graphResultChan) // Fan out events to both writers for follower := range inChan { - kvEventChan <- follower + boltEventChan <- follower graphEventChan <- follower } - close(kvEventChan) + close(boltEventChan) close(graphEventChan) - wg.Wait() + localWg.Wait() - kvError := <-kvErrorChan graphResult := <-graphResultChan - - var finalErr error - if kvError != nil && graphResult.Error != nil { - finalErr = fmt.Errorf("kvstore: %w; graphstore: %v", kvError, graphResult.Error) - } else if kvError != nil { - finalErr = fmt.Errorf("kvstore: %w", kvError) - } else if graphResult.Error != nil { - finalErr = fmt.Errorf("graphstore: %w", graphResult.Error) - } - - resultChan <- WriteResult{ - ResultSummaries: graphResult.ResultSummaries, - Error: finalErr, - } + resultChan <- graphResult } -func writeEventsToKVStore( - boltdb BoltDB, +func writeEventsToBoltDB( + wg *sync.WaitGroup, + boltdb *bolt.DB, inChan chan EventFollower, - done chan struct{}, - resultChan chan error, + errorChan chan error, ) { - events := []EventBlob{} + defer wg.Done() + var events []EventBlob for follower := range inChan { events = append(events, EventBlob{ID: follower.ID, JSON: follower.JSON}) } - err := boltdb.BatchWriteEvents(events) - if err != nil { - close(done) - } else { - done <- struct{}{} - close(done) - } + err := BatchWriteEvents(boltdb, events) - resultChan <- err - close(resultChan) + errorChan <- err + close(errorChan) } -func writeEventsToGraphDriver( - graphdb GraphDB, +func writeEventsToGraphDB( + wg *sync.WaitGroup, + driver neo4j.Driver, inChan chan EventFollower, - start chan struct{}, + boltErrorChan chan error, resultChan chan WriteResult, ) { + defer wg.Done() matchKeys := NewSimpleMatchKeys() batch := NewBatchSubgraph(matchKeys) @@ -348,14 +302,17 @@ func writeEventsToGraphDriver( } } - _, ok := <-start - if !ok { - resultChan <- WriteResult{Error: fmt.Errorf("kv write failed, aborting graph write")} + boltErr := <-boltErrorChan + if boltErr != nil { + resultChan <- WriteResult{ + Error: fmt.Errorf( + "boltdb write failed, aborting graph write: %w", boltErr, + )} close(resultChan) return } - summaries, err := graphdb.MergeSubgraph(context.Background(), batch) + summaries, err := MergeSubgraph(context.Background(), driver, batch) resultChan <- WriteResult{ ResultSummaries: summaries, Error: err, @@ -363,8 +320,9 @@ func writeEventsToGraphDriver( close(resultChan) } -func collectEvents(inChan chan EventFollower, resultChan chan []EventFollower) { - collected := []EventFollower{} +func collectEvents(wg *sync.WaitGroup, inChan chan EventFollower, resultChan chan []EventFollower) { + defer wg.Done() + var collected []EventFollower for follower := range inChan { collected = append(collected, follower) }