From e43ed4af5503151aa91f70aa94fc7ee1b33dd7a9 Mon Sep 17 00:00:00 2001 From: Jay Date: Tue, 3 Mar 2026 17:24:39 -0500 Subject: [PATCH] Move StructuredSubgraph alongside batch merge function. `graphstore` package owns entire batch merge operation. Added tests for node and rel batching. --- graph/graph.go | 165 ---------------------- graph/graph_test.go | 73 ---------- graphstore/batch.go | 287 +++++++++++++++++++++++++++++++-------- graphstore/batch_test.go | 115 ++++++++++++++++ 4 files changed, 343 insertions(+), 297 deletions(-) create mode 100644 graphstore/batch_test.go diff --git a/graph/graph.go b/graph/graph.go index a150e03..030a19d 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -6,7 +6,6 @@ package graph import ( "fmt" "sort" - "strings" ) // ======================================== @@ -160,167 +159,3 @@ func (r *Relationship) Serialize() *SerializedRel { srel["end"] = r.End.Props return &srel } - -// ======================================== -// Structured Subgraph -// ======================================== - -// StructuredSubgraph is a structured collection of nodes and relationships for -// the purpose of conducting batch operations. -type StructuredSubgraph struct { - // A map of grouped nodes, batched by their label combinations. - nodes map[string][]*Node - // A map of grouped relationships, batched by their type and related node - // labels. - rels map[string][]*Relationship - // Provides node property keys used to match nodes with given labels in the - // database. - matchProvider MatchKeysProvider -} - -// NewStructuredSubgraph creates an empty structured subgraph with the given -// match keys provider. -func NewStructuredSubgraph(matchProvider MatchKeysProvider) *StructuredSubgraph { - return &StructuredSubgraph{ - nodes: make(map[string][]*Node), - rels: make(map[string][]*Relationship), - matchProvider: matchProvider, - } -} - -// AddNode adds a node into the subgraph. -func (s *StructuredSubgraph) AddNode(node *Node) error { - - // Verify that the node has defined match property values. - matchLabel, _, err := node.MatchProps(s.matchProvider) - if err != nil { - return fmt.Errorf("invalid node: %s", err) - } - - // Determine the node's batch key. - batchKey := createNodeBatchKey(matchLabel, node.Labels.ToArray()) - - if _, exists := s.nodes[batchKey]; !exists { - s.nodes[batchKey] = []*Node{} - } - - // Add the node to the subgraph. - s.nodes[batchKey] = append(s.nodes[batchKey], node) - - return nil -} - -// AddRel adds a relationship into the subgraph. -func (s *StructuredSubgraph) AddRel(rel *Relationship) error { - - // Verify that the start node has defined match property values. - startLabel, _, err := rel.Start.MatchProps(s.matchProvider) - if err != nil { - return fmt.Errorf("invalid start node: %s", err) - } - - // Verify that the end node has defined match property values. - endLabel, _, err := rel.End.MatchProps(s.matchProvider) - if err != nil { - return fmt.Errorf("invalid end node: %s", err) - } - - // Determine the relationship's batch key. - batchKey := createRelBatchKey(rel.Type, startLabel, endLabel) - - if _, exists := s.rels[batchKey]; !exists { - s.rels[batchKey] = []*Relationship{} - } - - // Add the relationship to the subgraph. - s.rels[batchKey] = append(s.rels[batchKey], rel) - - return nil -} - -// GetNodes returns the nodes grouped under the given batch key. -func (s *StructuredSubgraph) GetNodes(nodeKey string) []*Node { - return s.nodes[nodeKey] -} - -// GetRels returns the rels grouped under the given batch key. -func (s *StructuredSubgraph) GetRels(relKey string) []*Relationship { - return s.rels[relKey] -} - -func (s *StructuredSubgraph) MatchProvider() MatchKeysProvider { - return s.matchProvider -} - -// NodeCount returns the number of nodes in the subgraph. -func (s *StructuredSubgraph) NodeCount() int { - count := 0 - for l := range s.nodes { - count += len(s.nodes[l]) - } - return count -} - -// RelCount returns the number of relationships in the subgraph. -func (s *StructuredSubgraph) RelCount() int { - count := 0 - for t := range s.rels { - count += len(s.rels[t]) - } - return count -} - -// NodeKeys returns the list of node batch keys in the subgraph. -func (s *StructuredSubgraph) NodeKeys() []string { - keys := []string{} - for l := range s.nodes { - keys = append(keys, l) - } - return keys -} - -// RelKeys returns the list of relationship batch keys in the subgraph. -func (s *StructuredSubgraph) RelKeys() []string { - keys := []string{} - for t := range s.rels { - keys = append(keys, t) - } - return keys -} - -// createNodeBatchKey returns the serialized node labels for batching. -func createNodeBatchKey(matchLabel string, labels []string) string { - sort.Strings(labels) - serializedLabels := strings.Join(labels, ",") - return fmt.Sprintf("%s:%s", matchLabel, serializedLabels) -} - -// createRelBatchKey returns the serialized relationship type and start/end node -// labels for batching. -func createRelBatchKey( - rtype string, startLabel string, endLabel string) string { - return strings.Join([]string{rtype, startLabel, endLabel}, ",") -} - -// DeserializeNodeBatchKey returns the list of node labels from the serialized batch -// key. -func DeserializeNodeBatchKey(batchKey string) (string, []string, error) { - parts := strings.Split(batchKey, ":") - if len(parts) != 2 { - return "", nil, fmt.Errorf("invalid node batch key: %s", batchKey) - } - matchLabel, serializedLabels := parts[0], parts[1] - labels := strings.Split(serializedLabels, ",") - return matchLabel, labels, nil -} - -// DeserializeRelBatchKey returns the relationship type, start node label, and end -// node label from the serialized batch key. Panics if the batch key is invalid. -func DeserializeRelBatchKey(batchKey string) (string, string, string, error) { - parts := strings.Split(batchKey, ",") - if len(parts) != 3 { - return "", "", "", fmt.Errorf("invalid relationship batch key: %s", batchKey) - } - rtype, startLabel, endLabel := parts[0], parts[1], parts[2] - return rtype, startLabel, endLabel, nil -} diff --git a/graph/graph_test.go b/graph/graph_test.go index f72d869..19aac44 100644 --- a/graph/graph_test.go +++ b/graph/graph_test.go @@ -34,40 +34,6 @@ func TestMatchKeys(t *testing.T) { }) } -func TestNodeBatchKey(t *testing.T) { - matchLabel := "Event" - labels := []string{"Event", "AddressableEvent"} - - // labels should be batched by key generator - expectedKey := "Event:AddressableEvent,Event" - - // Test Serialization - batchKey := createNodeBatchKey(matchLabel, labels) - assert.Equal(t, expectedKey, batchKey) - - // Test Deserialization - returnedMatchLabel, returnedLabels, err := DeserializeNodeBatchKey(batchKey) - assert.NoError(t, err) - assert.Equal(t, matchLabel, returnedMatchLabel) - assert.ElementsMatch(t, labels, returnedLabels) -} - -func TestRelBatchKey(t *testing.T) { - rtype, startLabel, endLabel := "SIGNED", "User", "Event" - expectedKey := "SIGNED,User,Event" - - // Test Serialization - batchKey := createRelBatchKey(rtype, startLabel, endLabel) - assert.Equal(t, expectedKey, batchKey) - - // Test Deserialization - returnedRtype, returnedStartLabel, returnedEndLabel, err := DeserializeRelBatchKey(batchKey) - assert.NoError(t, err) - assert.Equal(t, rtype, returnedRtype) - assert.Equal(t, startLabel, returnedStartLabel) - assert.Equal(t, endLabel, returnedEndLabel) -} - func TestMatchProps(t *testing.T) { matchKeys := &SimpleMatchKeys{ Keys: map[string][]string{ @@ -131,42 +97,3 @@ func TestMatchProps(t *testing.T) { }) } } - -func TestStructuredSubgraphAddNode(t *testing.T) { - matchKeys := NewSimpleMatchKeys() - subgraph := NewStructuredSubgraph(matchKeys) - node := NewEventNode("abc123") - - err := subgraph.AddNode(node) - - assert.NoError(t, err) - assert.Equal(t, 1, subgraph.NodeCount()) - assert.Equal(t, []*Node{node}, subgraph.GetNodes("Event:Event")) -} - -func TestStructuredSubgraphAddNodeInvalid(t *testing.T) { - matchKeys := NewSimpleMatchKeys() - subgraph := NewStructuredSubgraph(matchKeys) - node := NewNode("Event", Properties{}) - - err := subgraph.AddNode(node) - - assert.ErrorContains(t, err, "invalid node: missing property id") - assert.Equal(t, 0, subgraph.NodeCount()) -} - -func TestStructuredSubgraphAddRel(t *testing.T) { - matchKeys := NewSimpleMatchKeys() - subgraph := NewStructuredSubgraph(matchKeys) - - userNode := NewUserNode("pubkey1") - eventNode := NewEventNode("abc123") - rel := NewSignedRel(userNode, eventNode, nil) - - err := subgraph.AddRel(rel) - - assert.NoError(t, err) - assert.Equal(t, 1, subgraph.RelCount()) - assert.Equal(t, []*Relationship{rel}, subgraph.GetRels("SIGNED,User,Event")) - -} diff --git a/graphstore/batch.go b/graphstore/batch.go index 04f1e80..b2ea2fa 100644 --- a/graphstore/batch.go +++ b/graphstore/batch.go @@ -6,41 +6,227 @@ import ( "git.wisehodl.dev/jay/go-heartwood/cypher" "git.wisehodl.dev/jay/go-heartwood/graph" "github.com/neo4j/neo4j-go-driver/v6/neo4j" + "sort" + "strings" ) -func MergeSubgraph( - ctx context.Context, - driver neo4j.Driver, - subgraph *graph.StructuredSubgraph, -) ([]neo4j.ResultSummary, error) { - // Validate subgraph - for _, nodeKey := range subgraph.NodeKeys() { - matchLabel, _, err := graph.DeserializeNodeBatchKey(nodeKey) +// Structs + +type NodeBatch struct { + MatchLabel string + Labels []string + MatchKeys []string + Nodes []*graph.Node +} + +type RelBatch struct { + Type string + StartLabel string + StartMatchKeys []string + EndLabel string + EndMatchKeys []string + Rels []*graph.Relationship +} + +type BatchSubgraph struct { + nodes map[string][]*graph.Node + rels map[string][]*graph.Relationship + matchProvider graph.MatchKeysProvider +} + +func NewBatchSubgraph(matchProvider graph.MatchKeysProvider) *BatchSubgraph { + return &BatchSubgraph{ + nodes: make(map[string][]*graph.Node), + rels: make(map[string][]*graph.Relationship), + matchProvider: matchProvider, + } +} + +func (s *BatchSubgraph) AddNode(node *graph.Node) error { + + // Verify that the node has defined match property values. + matchLabel, _, err := node.MatchProps(s.matchProvider) + if err != nil { + return fmt.Errorf("invalid node: %s", err) + } + + // Determine the node's batch key. + batchKey := createNodeBatchKey(matchLabel, node.Labels.ToArray()) + + if _, exists := s.nodes[batchKey]; !exists { + s.nodes[batchKey] = []*graph.Node{} + } + + // Add the node to the subgraph. + s.nodes[batchKey] = append(s.nodes[batchKey], node) + + return nil +} + +func (s *BatchSubgraph) AddRel(rel *graph.Relationship) error { + + // Verify that the start node has defined match property values. + startLabel, _, err := rel.Start.MatchProps(s.matchProvider) + if err != nil { + return fmt.Errorf("invalid start node: %s", err) + } + + // Verify that the end node has defined match property values. + endLabel, _, err := rel.End.MatchProps(s.matchProvider) + if err != nil { + return fmt.Errorf("invalid end node: %s", err) + } + + // Determine the relationship's batch key. + batchKey := createRelBatchKey(rel.Type, startLabel, endLabel) + + if _, exists := s.rels[batchKey]; !exists { + s.rels[batchKey] = []*graph.Relationship{} + } + + // Add the relationship to the subgraph. + s.rels[batchKey] = append(s.rels[batchKey], rel) + + return nil +} + +func (s *BatchSubgraph) NodeCount() int { + count := 0 + for l := range s.nodes { + count += len(s.nodes[l]) + } + return count +} + +func (s *BatchSubgraph) RelCount() int { + count := 0 + for t := range s.rels { + count += len(s.rels[t]) + } + return count +} + +func (s *BatchSubgraph) nodeKeys() []string { + keys := []string{} + for l := range s.nodes { + keys = append(keys, l) + } + return keys +} + +func (s *BatchSubgraph) relKeys() []string { + keys := []string{} + for t := range s.rels { + keys = append(keys, t) + } + return keys +} + +func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) { + batches := []NodeBatch{} + + for _, nodeKey := range s.nodeKeys() { + matchLabel, labels, err := deserializeNodeBatchKey(nodeKey) if err != nil { return nil, err } - - _, exists := subgraph.MatchProvider().GetKeys(matchLabel) + matchKeys, exists := s.matchProvider.GetKeys(matchLabel) if !exists { return nil, fmt.Errorf("unknown match label: %s", matchLabel) } + nodes := s.nodes[nodeKey] + batch := NodeBatch{ + MatchLabel: matchLabel, + Labels: labels, + MatchKeys: matchKeys, + Nodes: nodes, + } + batches = append(batches, batch) } - for _, relKey := range subgraph.RelKeys() { - _, startLabel, endLabel, err := graph.DeserializeRelBatchKey(relKey) + return batches, nil +} + +func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) { + batches := []RelBatch{} + + for _, relKey := range s.relKeys() { + rtype, startLabel, endLabel, err := deserializeRelBatchKey(relKey) if err != nil { return nil, err } - _, exists := subgraph.MatchProvider().GetKeys(startLabel) + startMatchKeys, exists := s.matchProvider.GetKeys(startLabel) if !exists { return nil, fmt.Errorf("unknown match label: %s", startLabel) } - _, exists = subgraph.MatchProvider().GetKeys(endLabel) + endMatchKeys, exists := s.matchProvider.GetKeys(endLabel) if !exists { return nil, fmt.Errorf("unknown match label: %s", endLabel) } + + rels := s.rels[relKey] + batch := RelBatch{ + Type: rtype, + StartLabel: startLabel, + StartMatchKeys: startMatchKeys, + EndLabel: endLabel, + EndMatchKeys: endMatchKeys, + Rels: rels, + } + batches = append(batches, batch) + } + + return batches, nil +} + +// Helpers + +func createNodeBatchKey(matchLabel string, labels []string) string { + sort.Strings(labels) + serializedLabels := strings.Join(labels, ",") + return fmt.Sprintf("%s:%s", matchLabel, serializedLabels) +} + +func createRelBatchKey( + rtype string, startLabel string, endLabel string) string { + return strings.Join([]string{rtype, startLabel, endLabel}, ",") +} + +func deserializeNodeBatchKey(batchKey string) (string, []string, error) { + parts := strings.Split(batchKey, ":") + if len(parts) != 2 { + return "", nil, fmt.Errorf("invalid node batch key: %s", batchKey) + } + matchLabel, serializedLabels := parts[0], parts[1] + labels := strings.Split(serializedLabels, ",") + return matchLabel, labels, nil +} + +func deserializeRelBatchKey(batchKey string) (string, string, string, error) { + parts := strings.Split(batchKey, ",") + if len(parts) != 3 { + return "", "", "", fmt.Errorf("invalid relationship batch key: %s", batchKey) + } + rtype, startLabel, endLabel := parts[0], parts[1], parts[2] + return rtype, startLabel, endLabel, nil +} + +// Merge functions + +func MergeSubgraph( + ctx context.Context, + driver neo4j.Driver, + subgraph *BatchSubgraph, +) ([]neo4j.ResultSummary, error) { + nodeBatches, err := subgraph.NodeBatches() + if err != nil { + return nil, err + } + relBatches, err := subgraph.RelBatches() + if err != nil { + return nil, err } // Merge subgraph @@ -50,35 +236,31 @@ func MergeSubgraph( resultSummariesAny, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (any, error) { var resultSummaries []neo4j.ResultSummary - for _, nodeKey := range subgraph.NodeKeys() { - matchLabel, labels, _ := graph.DeserializeNodeBatchKey(nodeKey) - nodeResultSummary, err := MergeNodes( - ctx, tx, - matchLabel, - labels, - subgraph.MatchProvider(), - subgraph.GetNodes(nodeKey), - ) + for _, nodeBatch := range nodeBatches { + nodeResultSummary, err := MergeNodes(ctx, tx, nodeBatch) if err != nil { - return nil, fmt.Errorf("failed to merge nodes for key %s: %w", nodeKey, err) + return nil, fmt.Errorf( + "failed to merge nodes on label %q (labels %v): %w", + nodeBatch.MatchLabel, + nodeBatch.Labels, + err, + ) } if nodeResultSummary != nil { resultSummaries = append(resultSummaries, *nodeResultSummary) } } - for _, relKey := range subgraph.RelKeys() { - rtype, startLabel, endLabel, _ := graph.DeserializeRelBatchKey(relKey) - relResultSummary, err := MergeRels( - ctx, tx, - rtype, - startLabel, - endLabel, - subgraph.MatchProvider(), - subgraph.GetRels(relKey), - ) + for _, relBatch := range relBatches { + relResultSummary, err := MergeRels(ctx, tx, relBatch) if err != nil { - return nil, fmt.Errorf("failed to merge relationships for key %s: %w", relKey, err) + return nil, fmt.Errorf( + "failed to merge relationships on (%s)-[%s]->(%s): %w", + relBatch.StartLabel, + relBatch.Type, + relBatch.EndLabel, + err, + ) } if relResultSummary != nil { resultSummaries = append(resultSummaries, *relResultSummary) @@ -103,18 +285,13 @@ func MergeSubgraph( func MergeNodes( ctx context.Context, tx neo4j.ManagedTransaction, - matchLabel string, - nodeLabels []string, - matchProvider graph.MatchKeysProvider, - nodes []*graph.Node, + batch NodeBatch, ) (*neo4j.ResultSummary, error) { - cypherLabels := cypher.ToCypherLabels(nodeLabels) - - matchKeys, _ := matchProvider.GetKeys(matchLabel) - cypherProps := cypher.ToCypherProps(matchKeys, "node.") + cypherLabels := cypher.ToCypherLabels(batch.Labels) + cypherProps := cypher.ToCypherProps(batch.MatchKeys, "node.") serializedNodes := []*graph.SerializedNode{} - for _, node := range nodes { + for _, node := range batch.Nodes { serializedNodes = append(serializedNodes, node.Serialize()) } @@ -147,24 +324,16 @@ func MergeNodes( func MergeRels( ctx context.Context, tx neo4j.ManagedTransaction, - rtype string, - startLabel string, - endLabel string, - matchProvider graph.MatchKeysProvider, - rels []*graph.Relationship, + batch RelBatch, ) (*neo4j.ResultSummary, error) { - cypherType := cypher.ToCypherLabel(rtype) - startCypherLabel := cypher.ToCypherLabel(startLabel) - endCypherLabel := cypher.ToCypherLabel(endLabel) - - matchKeys, _ := matchProvider.GetKeys(startLabel) - startCypherProps := cypher.ToCypherProps(matchKeys, "rel.start.") - - matchKeys, _ = matchProvider.GetKeys(endLabel) - endCypherProps := cypher.ToCypherProps(matchKeys, "rel.end.") + cypherType := cypher.ToCypherLabel(batch.Type) + startCypherLabel := cypher.ToCypherLabel(batch.StartLabel) + endCypherLabel := cypher.ToCypherLabel(batch.EndLabel) + startCypherProps := cypher.ToCypherProps(batch.StartMatchKeys, "rel.start.") + endCypherProps := cypher.ToCypherProps(batch.EndMatchKeys, "rel.end.") serializedRels := []*graph.SerializedRel{} - for _, rel := range rels { + for _, rel := range batch.Rels { serializedRels = append(serializedRels, rel.Serialize()) } diff --git a/graphstore/batch_test.go b/graphstore/batch_test.go new file mode 100644 index 0000000..9622c13 --- /dev/null +++ b/graphstore/batch_test.go @@ -0,0 +1,115 @@ +package graphstore + +import ( + "git.wisehodl.dev/jay/go-heartwood/graph" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNodeBatchKey(t *testing.T) { + matchLabel := "Event" + labels := []string{"Event", "AddressableEvent"} + + // labels should be batched by key generator + expectedKey := "Event:AddressableEvent,Event" + + // Test Serialization + batchKey := createNodeBatchKey(matchLabel, labels) + assert.Equal(t, expectedKey, batchKey) + + // Test Deserialization + returnedMatchLabel, returnedLabels, err := deserializeNodeBatchKey(batchKey) + assert.NoError(t, err) + assert.Equal(t, matchLabel, returnedMatchLabel) + assert.ElementsMatch(t, labels, returnedLabels) +} + +func TestRelBatchKey(t *testing.T) { + rtype, startLabel, endLabel := "SIGNED", "User", "Event" + expectedKey := "SIGNED,User,Event" + + // Test Serialization + batchKey := createRelBatchKey(rtype, startLabel, endLabel) + assert.Equal(t, expectedKey, batchKey) + + // Test Deserialization + returnedRtype, returnedStartLabel, returnedEndLabel, err := deserializeRelBatchKey(batchKey) + assert.NoError(t, err) + assert.Equal(t, rtype, returnedRtype) + assert.Equal(t, startLabel, returnedStartLabel) + assert.Equal(t, endLabel, returnedEndLabel) +} + +func TestBatchSubgraphAddNode(t *testing.T) { + matchKeys := graph.NewSimpleMatchKeys() + subgraph := NewBatchSubgraph(matchKeys) + node := graph.NewEventNode("abc123") + + err := subgraph.AddNode(node) + + assert.NoError(t, err) + assert.Equal(t, 1, subgraph.NodeCount()) + assert.Equal(t, []*graph.Node{node}, subgraph.nodes["Event:Event"]) +} + +func TestBatchSubgraphAddNodeInvalid(t *testing.T) { + matchKeys := graph.NewSimpleMatchKeys() + subgraph := NewBatchSubgraph(matchKeys) + node := graph.NewNode("Event", graph.Properties{}) + + err := subgraph.AddNode(node) + + assert.ErrorContains(t, err, "invalid node: missing property id") + assert.Equal(t, 0, subgraph.NodeCount()) +} + +func TestBatchSubgraphAddRel(t *testing.T) { + matchKeys := graph.NewSimpleMatchKeys() + subgraph := NewBatchSubgraph(matchKeys) + + userNode := graph.NewUserNode("pubkey1") + eventNode := graph.NewEventNode("abc123") + rel := graph.NewSignedRel(userNode, eventNode, nil) + + err := subgraph.AddRel(rel) + + assert.NoError(t, err) + assert.Equal(t, 1, subgraph.RelCount()) + assert.Equal(t, []*graph.Relationship{rel}, subgraph.rels["SIGNED,User,Event"]) +} + +func TestNodeBatches(t *testing.T) { + matchKeys := graph.NewSimpleMatchKeys() + subgraph := NewBatchSubgraph(matchKeys) + node := graph.NewEventNode("abc123") + subgraph.AddNode(node) + + batches, err := subgraph.NodeBatches() + + assert.NoError(t, err) + assert.Len(t, batches, 1) + assert.Equal(t, "Event", batches[0].MatchLabel) + assert.ElementsMatch(t, []string{"Event"}, batches[0].Labels) + assert.ElementsMatch(t, []string{"id"}, batches[0].MatchKeys) + assert.Equal(t, []*graph.Node{node}, batches[0].Nodes) +} + +func TestRelBatches(t *testing.T) { + matchKeys := graph.NewSimpleMatchKeys() + subgraph := NewBatchSubgraph(matchKeys) + userNode := graph.NewUserNode("pubkey1") + eventNode := graph.NewEventNode("abc123") + rel := graph.NewSignedRel(userNode, eventNode, nil) + subgraph.AddRel(rel) + + batches, err := subgraph.RelBatches() + + assert.NoError(t, err) + assert.Len(t, batches, 1) + assert.Equal(t, "SIGNED", batches[0].Type) + assert.Equal(t, "User", batches[0].StartLabel) + assert.ElementsMatch(t, []string{"pubkey"}, batches[0].StartMatchKeys) + assert.Equal(t, "Event", batches[0].EndLabel) + assert.ElementsMatch(t, []string{"id"}, batches[0].EndMatchKeys) + assert.Equal(t, []*graph.Relationship{rel}, batches[0].Rels) +}