Move StructuredSubgraph alongside batch merge function.
The `graphstore` package now owns the entire batch merge operation. Added tests for node and relationship batching.
This commit is contained in:
@@ -6,41 +6,227 @@ import (
|
||||
"git.wisehodl.dev/jay/go-heartwood/cypher"
|
||||
"git.wisehodl.dev/jay/go-heartwood/graph"
|
||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func MergeSubgraph(
|
||||
ctx context.Context,
|
||||
driver neo4j.Driver,
|
||||
subgraph *graph.StructuredSubgraph,
|
||||
) ([]neo4j.ResultSummary, error) {
|
||||
// Validate subgraph
|
||||
for _, nodeKey := range subgraph.NodeKeys() {
|
||||
matchLabel, _, err := graph.DeserializeNodeBatchKey(nodeKey)
|
||||
// Structs
|
||||
|
||||
type NodeBatch struct {
|
||||
MatchLabel string
|
||||
Labels []string
|
||||
MatchKeys []string
|
||||
Nodes []*graph.Node
|
||||
}
|
||||
|
||||
type RelBatch struct {
|
||||
Type string
|
||||
StartLabel string
|
||||
StartMatchKeys []string
|
||||
EndLabel string
|
||||
EndMatchKeys []string
|
||||
Rels []*graph.Relationship
|
||||
}
|
||||
|
||||
type BatchSubgraph struct {
|
||||
nodes map[string][]*graph.Node
|
||||
rels map[string][]*graph.Relationship
|
||||
matchProvider graph.MatchKeysProvider
|
||||
}
|
||||
|
||||
func NewBatchSubgraph(matchProvider graph.MatchKeysProvider) *BatchSubgraph {
|
||||
return &BatchSubgraph{
|
||||
nodes: make(map[string][]*graph.Node),
|
||||
rels: make(map[string][]*graph.Relationship),
|
||||
matchProvider: matchProvider,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) AddNode(node *graph.Node) error {
|
||||
|
||||
// Verify that the node has defined match property values.
|
||||
matchLabel, _, err := node.MatchProps(s.matchProvider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid node: %s", err)
|
||||
}
|
||||
|
||||
// Determine the node's batch key.
|
||||
batchKey := createNodeBatchKey(matchLabel, node.Labels.ToArray())
|
||||
|
||||
if _, exists := s.nodes[batchKey]; !exists {
|
||||
s.nodes[batchKey] = []*graph.Node{}
|
||||
}
|
||||
|
||||
// Add the node to the subgraph.
|
||||
s.nodes[batchKey] = append(s.nodes[batchKey], node)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) AddRel(rel *graph.Relationship) error {
|
||||
|
||||
// Verify that the start node has defined match property values.
|
||||
startLabel, _, err := rel.Start.MatchProps(s.matchProvider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid start node: %s", err)
|
||||
}
|
||||
|
||||
// Verify that the end node has defined match property values.
|
||||
endLabel, _, err := rel.End.MatchProps(s.matchProvider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid end node: %s", err)
|
||||
}
|
||||
|
||||
// Determine the relationship's batch key.
|
||||
batchKey := createRelBatchKey(rel.Type, startLabel, endLabel)
|
||||
|
||||
if _, exists := s.rels[batchKey]; !exists {
|
||||
s.rels[batchKey] = []*graph.Relationship{}
|
||||
}
|
||||
|
||||
// Add the relationship to the subgraph.
|
||||
s.rels[batchKey] = append(s.rels[batchKey], rel)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) NodeCount() int {
|
||||
count := 0
|
||||
for l := range s.nodes {
|
||||
count += len(s.nodes[l])
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) RelCount() int {
|
||||
count := 0
|
||||
for t := range s.rels {
|
||||
count += len(s.rels[t])
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) nodeKeys() []string {
|
||||
keys := []string{}
|
||||
for l := range s.nodes {
|
||||
keys = append(keys, l)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) relKeys() []string {
|
||||
keys := []string{}
|
||||
for t := range s.rels {
|
||||
keys = append(keys, t)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) {
|
||||
batches := []NodeBatch{}
|
||||
|
||||
for _, nodeKey := range s.nodeKeys() {
|
||||
matchLabel, labels, err := deserializeNodeBatchKey(nodeKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, exists := subgraph.MatchProvider().GetKeys(matchLabel)
|
||||
matchKeys, exists := s.matchProvider.GetKeys(matchLabel)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unknown match label: %s", matchLabel)
|
||||
}
|
||||
nodes := s.nodes[nodeKey]
|
||||
batch := NodeBatch{
|
||||
MatchLabel: matchLabel,
|
||||
Labels: labels,
|
||||
MatchKeys: matchKeys,
|
||||
Nodes: nodes,
|
||||
}
|
||||
batches = append(batches, batch)
|
||||
}
|
||||
|
||||
for _, relKey := range subgraph.RelKeys() {
|
||||
_, startLabel, endLabel, err := graph.DeserializeRelBatchKey(relKey)
|
||||
return batches, nil
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) {
|
||||
batches := []RelBatch{}
|
||||
|
||||
for _, relKey := range s.relKeys() {
|
||||
rtype, startLabel, endLabel, err := deserializeRelBatchKey(relKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, exists := subgraph.MatchProvider().GetKeys(startLabel)
|
||||
startMatchKeys, exists := s.matchProvider.GetKeys(startLabel)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unknown match label: %s", startLabel)
|
||||
}
|
||||
|
||||
_, exists = subgraph.MatchProvider().GetKeys(endLabel)
|
||||
endMatchKeys, exists := s.matchProvider.GetKeys(endLabel)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unknown match label: %s", endLabel)
|
||||
}
|
||||
|
||||
rels := s.rels[relKey]
|
||||
batch := RelBatch{
|
||||
Type: rtype,
|
||||
StartLabel: startLabel,
|
||||
StartMatchKeys: startMatchKeys,
|
||||
EndLabel: endLabel,
|
||||
EndMatchKeys: endMatchKeys,
|
||||
Rels: rels,
|
||||
}
|
||||
batches = append(batches, batch)
|
||||
}
|
||||
|
||||
return batches, nil
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
// createNodeBatchKey builds the batch key "<matchLabel>:<l1>,<l2>,..." with
// the labels in sorted order, so the same label set always yields the same
// key regardless of input order.
//
// The labels slice is copied before sorting so the caller's slice is left
// untouched.
func createNodeBatchKey(matchLabel string, labels []string) string {
	sorted := make([]string, len(labels))
	copy(sorted, labels)
	sort.Strings(sorted)
	return matchLabel + ":" + strings.Join(sorted, ",")
}
|
||||
|
||||
// createRelBatchKey builds the batch key "<rtype>,<startLabel>,<endLabel>"
// for a relationship group. deserializeRelBatchKey reverses it.
func createRelBatchKey(rtype string, startLabel string, endLabel string) string {
	return rtype + "," + startLabel + "," + endLabel
}
|
||||
|
||||
// deserializeNodeBatchKey splits a key produced by createNodeBatchKey into
// its match label and label list.
//
// It returns an error unless the key contains exactly one ":" separator. A key
// with an empty label section yields no labels rather than a single empty
// label, so it round-trips with createNodeBatchKey for an empty label set.
func deserializeNodeBatchKey(batchKey string) (string, []string, error) {
	parts := strings.Split(batchKey, ":")
	if len(parts) != 2 {
		return "", nil, fmt.Errorf("invalid node batch key: %s", batchKey)
	}

	matchLabel, serializedLabels := parts[0], parts[1]
	if serializedLabels == "" {
		// strings.Split("", ",") would return [""], i.e. one empty label.
		return matchLabel, nil, nil
	}

	return matchLabel, strings.Split(serializedLabels, ","), nil
}
|
||||
|
||||
// deserializeRelBatchKey splits a key produced by createRelBatchKey into the
// relationship type, start label and end label, in that order.
//
// It returns an error unless the key has exactly three comma-separated parts.
func deserializeRelBatchKey(batchKey string) (string, string, string, error) {
	parts := strings.Split(batchKey, ",")
	if len(parts) != 3 {
		return "", "", "", fmt.Errorf("invalid relationship batch key: %s", batchKey)
	}
	return parts[0], parts[1], parts[2], nil
}
|
||||
|
||||
// Merge functions
|
||||
|
||||
func MergeSubgraph(
|
||||
ctx context.Context,
|
||||
driver neo4j.Driver,
|
||||
subgraph *BatchSubgraph,
|
||||
) ([]neo4j.ResultSummary, error) {
|
||||
nodeBatches, err := subgraph.NodeBatches()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
relBatches, err := subgraph.RelBatches()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Merge subgraph
|
||||
@@ -50,35 +236,31 @@ func MergeSubgraph(
|
||||
resultSummariesAny, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (any, error) {
|
||||
var resultSummaries []neo4j.ResultSummary
|
||||
|
||||
for _, nodeKey := range subgraph.NodeKeys() {
|
||||
matchLabel, labels, _ := graph.DeserializeNodeBatchKey(nodeKey)
|
||||
nodeResultSummary, err := MergeNodes(
|
||||
ctx, tx,
|
||||
matchLabel,
|
||||
labels,
|
||||
subgraph.MatchProvider(),
|
||||
subgraph.GetNodes(nodeKey),
|
||||
)
|
||||
for _, nodeBatch := range nodeBatches {
|
||||
nodeResultSummary, err := MergeNodes(ctx, tx, nodeBatch)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to merge nodes for key %s: %w", nodeKey, err)
|
||||
return nil, fmt.Errorf(
|
||||
"failed to merge nodes on label %q (labels %v): %w",
|
||||
nodeBatch.MatchLabel,
|
||||
nodeBatch.Labels,
|
||||
err,
|
||||
)
|
||||
}
|
||||
if nodeResultSummary != nil {
|
||||
resultSummaries = append(resultSummaries, *nodeResultSummary)
|
||||
}
|
||||
}
|
||||
|
||||
for _, relKey := range subgraph.RelKeys() {
|
||||
rtype, startLabel, endLabel, _ := graph.DeserializeRelBatchKey(relKey)
|
||||
relResultSummary, err := MergeRels(
|
||||
ctx, tx,
|
||||
rtype,
|
||||
startLabel,
|
||||
endLabel,
|
||||
subgraph.MatchProvider(),
|
||||
subgraph.GetRels(relKey),
|
||||
)
|
||||
for _, relBatch := range relBatches {
|
||||
relResultSummary, err := MergeRels(ctx, tx, relBatch)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to merge relationships for key %s: %w", relKey, err)
|
||||
return nil, fmt.Errorf(
|
||||
"failed to merge relationships on (%s)-[%s]->(%s): %w",
|
||||
relBatch.StartLabel,
|
||||
relBatch.Type,
|
||||
relBatch.EndLabel,
|
||||
err,
|
||||
)
|
||||
}
|
||||
if relResultSummary != nil {
|
||||
resultSummaries = append(resultSummaries, *relResultSummary)
|
||||
@@ -103,18 +285,13 @@ func MergeSubgraph(
|
||||
func MergeNodes(
|
||||
ctx context.Context,
|
||||
tx neo4j.ManagedTransaction,
|
||||
matchLabel string,
|
||||
nodeLabels []string,
|
||||
matchProvider graph.MatchKeysProvider,
|
||||
nodes []*graph.Node,
|
||||
batch NodeBatch,
|
||||
) (*neo4j.ResultSummary, error) {
|
||||
cypherLabels := cypher.ToCypherLabels(nodeLabels)
|
||||
|
||||
matchKeys, _ := matchProvider.GetKeys(matchLabel)
|
||||
cypherProps := cypher.ToCypherProps(matchKeys, "node.")
|
||||
cypherLabels := cypher.ToCypherLabels(batch.Labels)
|
||||
cypherProps := cypher.ToCypherProps(batch.MatchKeys, "node.")
|
||||
|
||||
serializedNodes := []*graph.SerializedNode{}
|
||||
for _, node := range nodes {
|
||||
for _, node := range batch.Nodes {
|
||||
serializedNodes = append(serializedNodes, node.Serialize())
|
||||
}
|
||||
|
||||
@@ -147,24 +324,16 @@ func MergeNodes(
|
||||
func MergeRels(
|
||||
ctx context.Context,
|
||||
tx neo4j.ManagedTransaction,
|
||||
rtype string,
|
||||
startLabel string,
|
||||
endLabel string,
|
||||
matchProvider graph.MatchKeysProvider,
|
||||
rels []*graph.Relationship,
|
||||
batch RelBatch,
|
||||
) (*neo4j.ResultSummary, error) {
|
||||
cypherType := cypher.ToCypherLabel(rtype)
|
||||
startCypherLabel := cypher.ToCypherLabel(startLabel)
|
||||
endCypherLabel := cypher.ToCypherLabel(endLabel)
|
||||
|
||||
matchKeys, _ := matchProvider.GetKeys(startLabel)
|
||||
startCypherProps := cypher.ToCypherProps(matchKeys, "rel.start.")
|
||||
|
||||
matchKeys, _ = matchProvider.GetKeys(endLabel)
|
||||
endCypherProps := cypher.ToCypherProps(matchKeys, "rel.end.")
|
||||
cypherType := cypher.ToCypherLabel(batch.Type)
|
||||
startCypherLabel := cypher.ToCypherLabel(batch.StartLabel)
|
||||
endCypherLabel := cypher.ToCypherLabel(batch.EndLabel)
|
||||
startCypherProps := cypher.ToCypherProps(batch.StartMatchKeys, "rel.start.")
|
||||
endCypherProps := cypher.ToCypherProps(batch.EndMatchKeys, "rel.end.")
|
||||
|
||||
serializedRels := []*graph.SerializedRel{}
|
||||
for _, rel := range rels {
|
||||
for _, rel := range batch.Rels {
|
||||
serializedRels = append(serializedRels, rel.Serialize())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user