// Batch grouping and Neo4j merge helpers for subgraphs.

package heartwood
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/neo4j/neo4j-go-driver/v6/neo4j"
)

// Structs

// NodeBatch is a group of nodes that share a match label and a label set, so
// that the whole group can be merged with a single query.
type NodeBatch struct {
	// MatchLabel is the label whose match keys identify each node.
	MatchLabel string
	// Labels are the labels written into the MERGE pattern for every node.
	Labels []string
	// MatchKeys are the property keys used in the MERGE pattern to find
	// existing nodes.
	MatchKeys []string
	// Nodes are the nodes serialized and sent as the batch payload.
	Nodes []*Node
}
// RelBatch is a group of relationships that share a type and the match labels
// of their start and end nodes, so that the whole group can be merged with a
// single query.
type RelBatch struct {
	// Type is the relationship type used in the MERGE pattern.
	Type string
	// StartLabel is the match label used to look up the start node.
	StartLabel string
	// StartMatchKeys are the property keys used to match the start node.
	StartMatchKeys []string
	// EndLabel is the match label used to look up the end node.
	EndLabel string
	// EndMatchKeys are the property keys used to match the end node.
	EndMatchKeys []string
	// Rels are the relationships serialized and sent as the batch payload.
	Rels []*Relationship
}
// BatchSubgraph collects nodes and relationships and groups them under batch
// keys so they can later be turned into NodeBatch and RelBatch values.
type BatchSubgraph struct {
	// nodes maps a node batch key (match label plus label set) to its nodes.
	nodes map[string][]*Node
	// rels maps a relationship batch key (type, start label, end label) to
	// its relationships.
	rels map[string][]*Relationship
	// matchProvider resolves match properties and match keys for labels.
	matchProvider MatchKeysProvider
}
func NewBatchSubgraph(matchProvider MatchKeysProvider) *BatchSubgraph {
|
|
return &BatchSubgraph{
|
|
nodes: make(map[string][]*Node),
|
|
rels: make(map[string][]*Relationship),
|
|
matchProvider: matchProvider,
|
|
}
|
|
}
func (s *BatchSubgraph) AddNode(node *Node) error {
|
|
|
|
// Verify that the node has defined match property values.
|
|
matchLabel, _, err := node.MatchProps(s.matchProvider)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid node: %w", err)
|
|
}
|
|
|
|
// Determine the node's batch key.
|
|
batchKey := createNodeBatchKey(matchLabel, node.Labels.AsSortedArray())
|
|
|
|
// Add the node to the sub
|
|
s.nodes[batchKey] = append(s.nodes[batchKey], node)
|
|
|
|
return nil
|
|
}
func (s *BatchSubgraph) AddRel(rel *Relationship) error {
|
|
|
|
// Verify that the start node has defined match property values.
|
|
startLabel, _, err := rel.Start.MatchProps(s.matchProvider)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid start node: %w", err)
|
|
}
|
|
|
|
// Verify that the end node has defined match property values.
|
|
endLabel, _, err := rel.End.MatchProps(s.matchProvider)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid end node: %w", err)
|
|
}
|
|
|
|
// Determine the relationship's batch key.
|
|
batchKey := createRelBatchKey(rel.Type, startLabel, endLabel)
|
|
|
|
// Add the relationship to the sub
|
|
s.rels[batchKey] = append(s.rels[batchKey], rel)
|
|
|
|
return nil
|
|
}
func (s *BatchSubgraph) NodeCount() int {
|
|
count := 0
|
|
for l := range s.nodes {
|
|
count += len(s.nodes[l])
|
|
}
|
|
return count
|
|
}
func (s *BatchSubgraph) RelCount() int {
|
|
count := 0
|
|
for t := range s.rels {
|
|
count += len(s.rels[t])
|
|
}
|
|
return count
|
|
}
func (s *BatchSubgraph) nodeKeys() []string {
|
|
keys := make([]string, 0, len(s.nodes))
|
|
for l := range s.nodes {
|
|
keys = append(keys, l)
|
|
}
|
|
return keys
|
|
}
func (s *BatchSubgraph) relKeys() []string {
|
|
keys := make([]string, 0, len(s.rels))
|
|
for t := range s.rels {
|
|
keys = append(keys, t)
|
|
}
|
|
return keys
|
|
}
func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) {
|
|
batches := make([]NodeBatch, 0, len(s.nodeKeys()))
|
|
|
|
for _, nodeKey := range s.nodeKeys() {
|
|
matchLabel, labels, err := deserializeNodeBatchKey(nodeKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
matchKeys, exists := s.matchProvider.GetKeys(matchLabel)
|
|
if !exists {
|
|
return nil, fmt.Errorf("unknown match label: %s", matchLabel)
|
|
}
|
|
nodes := s.nodes[nodeKey]
|
|
batch := NodeBatch{
|
|
MatchLabel: matchLabel,
|
|
Labels: labels,
|
|
MatchKeys: matchKeys,
|
|
Nodes: nodes,
|
|
}
|
|
batches = append(batches, batch)
|
|
}
|
|
|
|
return batches, nil
|
|
}
func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) {
|
|
batches := make([]RelBatch, 0, len(s.relKeys()))
|
|
|
|
for _, relKey := range s.relKeys() {
|
|
rtype, startLabel, endLabel, err := deserializeRelBatchKey(relKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
startMatchKeys, exists := s.matchProvider.GetKeys(startLabel)
|
|
if !exists {
|
|
return nil, fmt.Errorf("unknown match label: %s", startLabel)
|
|
}
|
|
|
|
endMatchKeys, exists := s.matchProvider.GetKeys(endLabel)
|
|
if !exists {
|
|
return nil, fmt.Errorf("unknown match label: %s", endLabel)
|
|
}
|
|
|
|
rels := s.rels[relKey]
|
|
batch := RelBatch{
|
|
Type: rtype,
|
|
StartLabel: startLabel,
|
|
StartMatchKeys: startMatchKeys,
|
|
EndLabel: endLabel,
|
|
EndMatchKeys: endMatchKeys,
|
|
Rels: rels,
|
|
}
|
|
batches = append(batches, batch)
|
|
}
|
|
|
|
return batches, nil
|
|
}

// Helpers

// createNodeBatchKey builds the key under which nodes are grouped: the match
// label, a colon, then the labels joined by commas. The caller supplies the
// labels already sorted; this function does not reorder them.
func createNodeBatchKey(matchLabel string, sortedLabels []string) string {
	return matchLabel + ":" + strings.Join(sortedLabels, ",")
}
// createRelBatchKey builds the key under which relationships are grouped: the
// relationship type, the start label and the end label, separated by commas.
func createRelBatchKey(
	rtype string, startLabel string, endLabel string) string {
	fields := [...]string{rtype, startLabel, endLabel}
	return strings.Join(fields[:], ",")
}
// deserializeNodeBatchKey splits a node batch key of the form
// "matchLabel:label1,label2" back into its match label and labels.
//
// The key must contain exactly one colon; otherwise an error is returned. A
// key with nothing after the colon yields an empty label slice rather than a
// slice holding one empty label, so a key built from zero labels round-trips.
func deserializeNodeBatchKey(batchKey string) (string, []string, error) {
	matchLabel, serializedLabels, found := strings.Cut(batchKey, ":")
	if !found || strings.Contains(serializedLabels, ":") {
		return "", nil, fmt.Errorf("invalid node batch key: %s", batchKey)
	}

	// strings.Split("", ",") returns [""], which would read as one label
	// with an empty name.
	if serializedLabels == "" {
		return matchLabel, []string{}, nil
	}

	return matchLabel, strings.Split(serializedLabels, ","), nil
}
// deserializeRelBatchKey splits a relationship batch key of the form
// "type,startLabel,endLabel" into its three parts, in that order.
//
// The key must contain exactly two commas; otherwise an error is returned.
func deserializeRelBatchKey(batchKey string) (string, string, string, error) {
	// Exactly two separators means exactly three fields.
	if strings.Count(batchKey, ",") != 2 {
		return "", "", "", fmt.Errorf("invalid relationship batch key: %s", batchKey)
	}

	fields := strings.Split(batchKey, ",")
	return fields[0], fields[1], fields[2], nil
}

// Merge functions

func MergeSubgraph(
|
|
ctx context.Context,
|
|
driver neo4j.Driver,
|
|
subgraph *BatchSubgraph,
|
|
) ([]neo4j.ResultSummary, error) {
|
|
nodeBatches, err := subgraph.NodeBatches()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
relBatches, err := subgraph.RelBatches()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Merge subgraph
|
|
session := driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: "neo4j"})
|
|
defer session.Close(ctx)
|
|
|
|
resultSummariesAny, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (any, error) {
|
|
var resultSummaries []neo4j.ResultSummary
|
|
|
|
for _, nodeBatch := range nodeBatches {
|
|
nodeResultSummary, err := MergeNodes(ctx, tx, nodeBatch)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"failed to merge nodes on label %q (labels %v): %w",
|
|
nodeBatch.MatchLabel,
|
|
nodeBatch.Labels,
|
|
err,
|
|
)
|
|
}
|
|
if nodeResultSummary != nil {
|
|
resultSummaries = append(resultSummaries, nodeResultSummary)
|
|
}
|
|
}
|
|
|
|
for _, relBatch := range relBatches {
|
|
relResultSummary, err := MergeRels(ctx, tx, relBatch)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"failed to merge relationships on (%s)-[%s]->(%s): %w",
|
|
relBatch.StartLabel,
|
|
relBatch.Type,
|
|
relBatch.EndLabel,
|
|
err,
|
|
)
|
|
}
|
|
if relResultSummary != nil {
|
|
resultSummaries = append(resultSummaries, relResultSummary)
|
|
}
|
|
}
|
|
|
|
return resultSummaries, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("subgraph merge transaction failed: %w", err)
|
|
}
|
|
|
|
resultSummaries, ok := resultSummariesAny.([]neo4j.ResultSummary)
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected type returned from ExecuteWrite: got %T", resultSummariesAny)
|
|
}
|
|
|
|
return resultSummaries, nil
|
|
}
func MergeNodes(
|
|
ctx context.Context,
|
|
tx neo4j.ManagedTransaction,
|
|
batch NodeBatch,
|
|
) (neo4j.ResultSummary, error) {
|
|
cypherLabels := ToCypherLabels(batch.Labels)
|
|
cypherProps := ToCypherProps(batch.MatchKeys, "node.")
|
|
|
|
serializedNodes := make([]*SerializedNode, 0, len(batch.Nodes))
|
|
for _, node := range batch.Nodes {
|
|
serializedNodes = append(serializedNodes, node.Serialize())
|
|
}
|
|
|
|
query := fmt.Sprintf(`
|
|
UNWIND $nodes as node
|
|
|
|
MERGE (n%s { %s })
|
|
SET n += node
|
|
`,
|
|
cypherLabels, cypherProps,
|
|
)
|
|
|
|
result, err := tx.Run(ctx,
|
|
query,
|
|
map[string]any{
|
|
"nodes": serializedNodes,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
summary, err := result.Consume(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return summary, nil
|
|
}
func MergeRels(
|
|
ctx context.Context,
|
|
tx neo4j.ManagedTransaction,
|
|
batch RelBatch,
|
|
) (neo4j.ResultSummary, error) {
|
|
cypherType := ToCypherLabel(batch.Type)
|
|
startCypherLabel := ToCypherLabel(batch.StartLabel)
|
|
endCypherLabel := ToCypherLabel(batch.EndLabel)
|
|
startCypherProps := ToCypherProps(batch.StartMatchKeys, "rel.start.")
|
|
endCypherProps := ToCypherProps(batch.EndMatchKeys, "rel.end.")
|
|
|
|
serializedRels := make([]*SerializedRel, 0, len(batch.Rels))
|
|
for _, rel := range batch.Rels {
|
|
serializedRels = append(serializedRels, rel.Serialize())
|
|
}
|
|
|
|
query := fmt.Sprintf(`
|
|
UNWIND $rels as rel
|
|
|
|
MATCH (start%s { %s })
|
|
MATCH (end%s { %s })
|
|
|
|
MERGE (start)-[r%s]->(end)
|
|
SET r += rel.props
|
|
`,
|
|
startCypherLabel, startCypherProps,
|
|
endCypherLabel, endCypherProps,
|
|
cypherType,
|
|
)
|
|
|
|
result, err := tx.Run(ctx,
|
|
query,
|
|
map[string]any{
|
|
"rels": serializedRels,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
summary, err := result.Consume(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return summary, nil
|
|
}