From f39f8786d9a797b46ccc04ac4201a13db27534eb Mon Sep 17 00:00:00 2001 From: Jay Date: Mon, 2 Mar 2026 14:50:27 -0500 Subject: [PATCH] Migrated from neostr-brainstorm. Added unit tests. --- batch.go | 198 ++++++++++++++++++++++++++++ c2p | 1 + cypher.go | 39 ++++++ cypher_test.go | 43 ++++++ go.mod | 14 ++ go.sum | 12 ++ graph.go | 352 +++++++++++++++++++++++++++++++++++++++++++++++++ graph_test.go | 172 ++++++++++++++++++++++++ neo4j.go | 20 +++ schema.go | 154 ++++++++++++++++++++++ schema_test.go | 45 +++++++ types.go | 0 util.go | 50 +++++++ 13 files changed, 1100 insertions(+) create mode 100644 batch.go create mode 100755 c2p create mode 100644 cypher.go create mode 100644 cypher_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 graph.go create mode 100644 graph_test.go create mode 100644 neo4j.go create mode 100644 schema.go create mode 100644 schema_test.go create mode 100644 types.go create mode 100644 util.go diff --git a/batch.go b/batch.go new file mode 100644 index 0000000..b084c5e --- /dev/null +++ b/batch.go @@ -0,0 +1,198 @@ +package heartwood + +import ( + "context" + "fmt" + "github.com/neo4j/neo4j-go-driver/v6/neo4j" +) + +func MergeSubgraph( + ctx context.Context, + driver neo4j.Driver, + subgraph *StructuredSubgraph, +) ([]neo4j.ResultSummary, error) { + // Validate subgraph + for _, nodeKey := range subgraph.NodeKeys() { + matchLabel, _, err := DeserializeNodeKey(nodeKey) + if err != nil { + return nil, err + } + + _, exists := subgraph.matchProvider.GetKeys(matchLabel) + if !exists { + return nil, fmt.Errorf("unknown match label: %s", matchLabel) + } + } + + for _, relKey := range subgraph.RelKeys() { + _, startLabel, endLabel, err := DeserializeRelKey(relKey) + if err != nil { + return nil, err + } + + _, exists := subgraph.matchProvider.GetKeys(startLabel) + if !exists { + return nil, fmt.Errorf("unknown match label: %s", startLabel) + } + + _, exists = subgraph.matchProvider.GetKeys(endLabel) + if !exists { + return nil, fmt.Errorf("unknown match label: %s", endLabel) + } + } + + // Merge subgraph + session := driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: "neo4j"}) + defer session.Close(ctx) + + resultSummariesAny, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (any, error) { + var resultSummaries []neo4j.ResultSummary + + for _, nodeKey := range subgraph.NodeKeys() { + matchLabel, labels, _ := DeserializeNodeKey(nodeKey) + nodeResultSummary, err := MergeNodes( + ctx, tx, + matchLabel, + labels, + subgraph.matchProvider, + subgraph.GetNodes(nodeKey), + ) + if err != nil { + return nil, fmt.Errorf("failed to merge nodes for key %s: %w", nodeKey, err) + } + if nodeResultSummary != nil { + resultSummaries = append(resultSummaries, *nodeResultSummary) + } + } + + for _, relKey := range subgraph.RelKeys() { + rtype, startLabel, endLabel, _ := DeserializeRelKey(relKey) + relResultSummary, err := MergeRels( + ctx, tx, + rtype, + startLabel, + endLabel, + subgraph.matchProvider, + subgraph.GetRels(relKey), + ) + if err != nil { + return nil, fmt.Errorf("failed to merge relationships for key %s: %w", relKey, err) + } + if relResultSummary != nil { + resultSummaries = append(resultSummaries, *relResultSummary) + } + } + + return resultSummaries, nil + }) + + if err != nil { + return nil, fmt.Errorf("subgraph merge transaction failed: %w", err) + } + + resultSummaries, ok := resultSummariesAny.([]neo4j.ResultSummary) + if !ok { + return nil, fmt.Errorf("unexpected type returned from ExecuteWrite: got %T", resultSummariesAny) + } + + return resultSummaries, nil +} + +func MergeNodes( + ctx context.Context, + tx neo4j.ManagedTransaction, + matchLabel string, + nodeLabels []string, + matchProvider MatchKeysProvider, + nodes []*Node, +) (*neo4j.ResultSummary, error) { + cypherLabels := ToCypherLabels(nodeLabels) + + matchKeys, _ := matchProvider.GetKeys(matchLabel) + cypherProps := ToCypherProps(matchKeys, "node.") + + serializedNodes := []*SerializedNode{} + for _, node := range nodes { + serializedNodes = append(serializedNodes, node.Serialize()) + } + + query := fmt.Sprintf(` + UNWIND $nodes as node + + MERGE (n%s { %s }) + SET n += node + `, + cypherLabels, cypherProps, + ) + + result, err := tx.Run(ctx, + query, + map[string]any{ + "nodes": serializedNodes, + }) + if err != nil { + return nil, err + } + + summary, err := result.Consume(ctx) + if err != nil { + return nil, err + } + + return &summary, nil +} + +func MergeRels( + ctx context.Context, + tx neo4j.ManagedTransaction, + rtype string, + startLabel string, + endLabel string, + matchProvider MatchKeysProvider, + rels []*Relationship, +) (*neo4j.ResultSummary, error) { + cypherType := ToCypherLabel(rtype) + startCypherLabel := ToCypherLabel(startLabel) + endCypherLabel := ToCypherLabel(endLabel) + + matchKeys, _ := matchProvider.GetKeys(startLabel) + startCypherProps := ToCypherProps(matchKeys, "rel.start.") + + matchKeys, _ = matchProvider.GetKeys(endLabel) + endCypherProps := ToCypherProps(matchKeys, "rel.end.") + + serializedRels := []*SerializedRel{} + for _, rel := range rels { + serializedRels = append(serializedRels, rel.Serialize()) + } + + query := fmt.Sprintf(` + UNWIND $rels as rel + + MATCH (start%s { %s }) + MATCH (end%s { %s }) + + MERGE (start)-[r%s]->(end) + SET r += rel.props + `, + startCypherLabel, startCypherProps, + endCypherLabel, endCypherProps, + cypherType, + ) + + result, err := tx.Run(ctx, + query, + map[string]any{ + "rels": serializedRels, + }) + if err != nil { + return nil, err + } + + summary, err := result.Consume(ctx) + if err != nil { + return nil, err + } + + return &summary, nil +} diff --git a/c2p b/c2p new file mode 100755 index 0000000..87189c3 --- /dev/null +++ b/c2p @@ -0,0 +1 @@ +code2prompt -c -e go.sum diff --git a/cypher.go b/cypher.go new file mode 100644 index 0000000..271db62 --- /dev/null +++ b/cypher.go @@ -0,0 +1,39 @@ +package heartwood + +import ( + "fmt" + "strings" +) + +// ======================================== +// Cypher Formatting Functions +// ======================================== + +// ToCypherLabel converts a node label or relationship type into its Cypher +// format. +func ToCypherLabel(label string) string { + return fmt.Sprintf(":`%s`", label) +} + +// ToCypherLabels converts a list of node labels into its Cypher format. +func ToCypherLabels(labels []string) string { + var cypherLabels []string + + for _, label := range labels { + cypherLabels = append(cypherLabels, ToCypherLabel(label)) + } + + return strings.Join(cypherLabels, "") +} + +func ToCypherProps(keys []string, prefix string) string { + if prefix == "" { + prefix = "$" + } + cypherPropsParts := []string{} + for _, key := range keys { + cypherPropsParts = append( + cypherPropsParts, fmt.Sprintf("%s: %s%s", key, prefix, key)) + } + return strings.Join(cypherPropsParts, ", ") +} diff --git a/cypher_test.go b/cypher_test.go new file mode 100644 index 0000000..a5402ea --- /dev/null +++ b/cypher_test.go @@ -0,0 +1,43 @@ +package heartwood + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestToCypherLabel(t *testing.T) { + assert.Equal(t, ":`Event`", ToCypherLabel("Event")) +} + +func TestToCypherLabels(t *testing.T) { + assert.Equal(t, ":`Event`:`ReplaceableEvent`", + ToCypherLabels([]string{"Event", "ReplaceableEvent"})) +} + +func TestToCypherProps(t *testing.T) { + cases := []struct { + name string + keys []string + prefix string + expected string + }{ + { + name: "default prefix", + keys: []string{"id", "name"}, + prefix: "", + expected: "id: $id, name: $name", + }, + { + name: "set prefix", + keys: []string{"id", "name"}, + prefix: "entity.", + expected: "id: entity.id, name: entity.name", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, ToCypherProps(tc.keys, tc.prefix)) + }) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..dfb0ce8 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module git.wisehodl.dev/jay/go-heartwood + +go 1.24 + +require ( + github.com/neo4j/neo4j-go-driver/v6 v6.0.0 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..d95611d --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/neo4j/neo4j-go-driver/v6 v6.0.0 h1:xVAi6YLOfzXUx+1Lc/F2dUhpbN76BfKleZbAlnDFRiA= +github.com/neo4j/neo4j-go-driver/v6 v6.0.0/go.mod h1:hzSTfNfM31p1uRSzL1F/BAYOgaiTarE6OAQBajfsm+I= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/graph.go b/graph.go new file mode 100644 index 0000000..ef32432 --- /dev/null +++ b/graph.go @@ -0,0 +1,352 @@ +// This module defines types and functions for working with Neo4j graph +// entities. + +package heartwood + +import ( + "fmt" + "sort" + "strings" +) + +// ======================================== +// Types +// ======================================== + +// Properties represents a map of node or relationship props. +type Properties map[string]any + +// ======================================== +// Match Key Provider +// ======================================== + +// MatchKeysProvider defines methods for querying a mapping of node labels and +// the property keys used to match nodes with them. +type MatchKeysProvider interface { + // GetLabels returns the array of node labels in the mapping. + GetLabels() []string + + // GetKeys returns the node property keys used to match nodes with the + // given label and a boolean indicating the success of the lookup. + GetKeys(label string) ([]string, bool) +} + +// MatchKeys is a simple implementation of the MatchKeysProvider interface. +type MatchKeys struct { + keys map[string][]string +} + +func (p *MatchKeys) GetLabels() []string { + labels := []string{} + for l := range p.keys { + labels = append(labels, l) + } + return labels +} + +func (p *MatchKeys) GetKeys(label string) ([]string, bool) { + if keys, exists := p.keys[label]; exists { + return keys, exists + } else { + return nil, exists + } +} + +// ======================================== +// Nodes +// ======================================== + +// Node represents a Neo4j node entity, encapsulating its labels and +// properties. +type Node struct { + // Set of labels on the node. + Labels Set[string] + // Mapping of properties on the node. + Props Properties +} + +// NewNode creates a new node with the given label and properties. +func NewNode(label string, props Properties) *Node { + if props == nil { + props = make(Properties) + } + return &Node{ + Labels: NewSet(label), + Props: props, + } +} + +// MatchProps returns the node label and the property values to match it in the +// database. +func (n *Node) MatchProps( + matchProvider MatchKeysProvider) (string, Properties, error) { + + // Iterate over each label on the node, checking whether each has match + // keys associated with it. + labels := n.Labels.ToArray() + sort.Strings(labels) + for _, label := range labels { + if keys, exists := matchProvider.GetKeys(label); exists { + props := make(Properties) + + // Get the property values associated with each match key. + for _, key := range keys { + if value, exists := n.Props[key]; exists { + props[key] = value + } else { + + // If any match property values are missing, return an + // error. + return label, nil, + fmt.Errorf( + "missing property %s for label %s", key, label) + } + } + + // Return the label and match properties + return label, props, nil + } + } + + // If none of the node labels have defined match keys, return an error. + return "", nil, fmt.Errorf("no recognized label found in %v", n.Labels) +} + +type SerializedNode = Properties + +func (n *Node) Serialize() *SerializedNode { + return &n.Props +} + +// ======================================== +// Relationships +// ======================================== + +// Relationship represents a Neo4j relationship between two nodes, including +// its type and properties. +type Relationship struct { + // The relationship type. + Type string + // The start node for the relationship. + Start *Node + // The end node for the relationship. + End *Node + // Mapping of properties on the relationship + Props Properties +} + +// NewRelationship creates a new relationship with the given type, start node, +// end node, and properties +func NewRelationship( + rtype string, start *Node, end *Node, props Properties) *Relationship { + + if props == nil { + props = make(Properties) + } + return &Relationship{ + Type: rtype, + Start: start, + End: end, + Props: props, + } +} + +type SerializedRel = map[string]Properties + +func (r *Relationship) Serialize() *SerializedRel { + srel := make(map[string]Properties) + srel["props"] = r.Props + srel["start"] = r.Start.Props + srel["end"] = r.End.Props + return &srel +} + +// ======================================== +// Simple Subgraph +// ======================================== + +// Subgraph represents a simple collection of nodes and relationships. +type Subgraph struct { + // The nodes in the subgraph. + nodes []*Node + // The relationships in the subgraph. + rels []*Relationship +} + +// NewSubgraph creates an empty subgraph. +func NewSubgraph() *Subgraph { + return &Subgraph{ + nodes: []*Node{}, + rels: []*Relationship{}, + } +} + +// AddNode adds a node to the subgraph +func (s *Subgraph) AddNode(node *Node) { + s.nodes = append(s.nodes, node) +} + +// AddRel adds a relationship to the subgraph. +func (s *Subgraph) AddRel(rel *Relationship) { + s.rels = append(s.rels, rel) +} + +// ======================================== +// Structured Subgraph +// ======================================== + +// StructuredSubgraph is a structured collection of nodes and relationships for +// the purpose of conducting batch operations. +type StructuredSubgraph struct { + // A map of grouped nodes, sorted by their label combinations. + nodes map[string][]*Node + // A map of grouped relationships, sorted by their type and related node + // labels. + rels map[string][]*Relationship + // Provides node property keys used to match nodes with given labels in the + // database. + matchProvider MatchKeysProvider +} + +// NewStructuredSubgraph creates an empty structured subgraph with the given +// match keys provider. +func NewStructuredSubgraph(matchProvider MatchKeysProvider) *StructuredSubgraph { + return &StructuredSubgraph{ + nodes: make(map[string][]*Node), + rels: make(map[string][]*Relationship), + matchProvider: matchProvider, + } +} + +// AddNode sorts a node into the subgraph. +func (s *StructuredSubgraph) AddNode(node *Node) error { + + // Verify that the node has defined match property values. + matchLabel, _, err := node.MatchProps(s.matchProvider) + if err != nil { + return fmt.Errorf("invalid node: %s", err) + } + + // Determine the node's sort key. + sortKey := createNodeSortKey(matchLabel, node.Labels.ToArray()) + + if _, exists := s.nodes[sortKey]; !exists { + s.nodes[sortKey] = []*Node{} + } + + // Add the node to the subgraph. + s.nodes[sortKey] = append(s.nodes[sortKey], node) + + return nil +} + +// AddRel sorts a relationship into the subgraph. +func (s *StructuredSubgraph) AddRel(rel *Relationship) error { + + // Verify that the start node has defined match property values. + startLabel, _, err := rel.Start.MatchProps(s.matchProvider) + if err != nil { + return fmt.Errorf("invalid start node: %s", err) + } + + // Verify that the end node has defined match property values. + endLabel, _, err := rel.End.MatchProps(s.matchProvider) + if err != nil { + return fmt.Errorf("invalid end node: %s", err) + } + + // Determine the relationship's sort key. + sortKey := createRelSortKey(rel.Type, startLabel, endLabel) + + if _, exists := s.rels[sortKey]; !exists { + s.rels[sortKey] = []*Relationship{} + } + + // Add the relationship to the subgraph. + s.rels[sortKey] = append(s.rels[sortKey], rel) + + return nil +} + +// GetNodes returns the nodes grouped under the given sort key. +func (s *StructuredSubgraph) GetNodes(nodeKey string) []*Node { + return s.nodes[nodeKey] +} + +// GetRels returns the rels grouped under the given sort key. +func (s *StructuredSubgraph) GetRels(relKey string) []*Relationship { + return s.rels[relKey] +} + +// NodeCount returns the number of nodes in the subgraph. +func (s *StructuredSubgraph) NodeCount() int { + count := 0 + for l := range s.nodes { + count += len(s.nodes[l]) + } + return count +} + +// RelCount returns the number of relationships in the subgraph. +func (s *StructuredSubgraph) RelCount() int { + count := 0 + for t := range s.rels { + count += len(s.rels[t]) + } + return count +} + +// NodeKeys returns the list of node sort keys in the subgraph. +func (s *StructuredSubgraph) NodeKeys() []string { + keys := []string{} + for l := range s.nodes { + keys = append(keys, l) + } + return keys +} + +// RelKeys returns the list of relationship sort keys in the subgraph. +func (s *StructuredSubgraph) RelKeys() []string { + keys := []string{} + for t := range s.rels { + keys = append(keys, t) + } + return keys +} + +// createNodeSortKey returns the serialized node labels for sorting. +func createNodeSortKey(matchLabel string, labels []string) string { + sort.Strings(labels) + serializedLabels := strings.Join(labels, ",") + return fmt.Sprintf("%s:%s", matchLabel, serializedLabels) +} + +// createRelSortKey returns the serialized relationship type and start/end node +// labels for sorting. +func createRelSortKey( + rtype string, startLabel string, endLabel string) string { + return strings.Join([]string{rtype, startLabel, endLabel}, ",") +} + +// DeserializeNodeKey returns the list of node labels from the serialized sort +// key. +func DeserializeNodeKey(sortKey string) (string, []string, error) { + parts := strings.Split(sortKey, ":") + if len(parts) != 2 { + return "", nil, fmt.Errorf("invalid node sort key: %s", sortKey) + } + matchLabel, serializedLabels := parts[0], parts[1] + labels := strings.Split(serializedLabels, ",") + return matchLabel, labels, nil +} + +// DeserializeRelKey returns the relationship type, start node label, and end +// node label from the serialized sort key. Panics if the sort key is invalid. +func DeserializeRelKey(sortKey string) (string, string, string, error) { + parts := strings.Split(sortKey, ",") + if len(parts) != 3 { + return "", "", "", fmt.Errorf("invalid relationship sort key: %s", sortKey) + } + rtype, startLabel, endLabel := parts[0], parts[1], parts[2] + return rtype, startLabel, endLabel, nil +} diff --git a/graph_test.go b/graph_test.go new file mode 100644 index 0000000..b2f33ec --- /dev/null +++ b/graph_test.go @@ -0,0 +1,172 @@ +package heartwood + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestMatchKeys(t *testing.T) { + matchKeys := &MatchKeys{ + keys: map[string][]string{ + "User": {"pubkey"}, + "Event": {"id"}, + "Tag": {"name", "value"}, + }, + } + + t.Run("get labels", func(t *testing.T) { + expectedLabels := []string{"Event", "Tag", "User"} + labels := matchKeys.GetLabels() + assert.ElementsMatch(t, expectedLabels, labels) + }) + + t.Run("get keys", func(t *testing.T) { + expectedKeys := []string{"id"} + keys, exists := matchKeys.GetKeys("Event") + assert.True(t, exists) + assert.ElementsMatch(t, expectedKeys, keys) + }) + + t.Run("unknown key", func(t *testing.T) { + keys, exists := matchKeys.GetKeys("Unknown") + assert.False(t, exists) + assert.Nil(t, keys) + }) +} + +func TestNodeSortKey(t *testing.T) { + matchLabel := "Event" + labels := []string{"Event", "AddressableEvent"} + + // labels should be sorted by key generator + expectedKey := "Event:AddressableEvent,Event" + + // Test Serialization + sortKey := createNodeSortKey(matchLabel, labels) + assert.Equal(t, expectedKey, sortKey) + + // Test Deserialization + returnedMatchLabel, returnedLabels, err := DeserializeNodeKey(sortKey) + assert.NoError(t, err) + assert.Equal(t, matchLabel, returnedMatchLabel) + assert.ElementsMatch(t, labels, returnedLabels) +} + +func TestRelSortKey(t *testing.T) { + rtype, startLabel, endLabel := "SIGNED", "User", "Event" + expectedKey := "SIGNED,User,Event" + + // Test Serialization + sortKey := createRelSortKey(rtype, startLabel, endLabel) + assert.Equal(t, expectedKey, sortKey) + + // Test Deserialization + returnedRtype, returnedStartLabel, returnedEndLabel, err := DeserializeRelKey(sortKey) + assert.NoError(t, err) + assert.Equal(t, rtype, returnedRtype) + assert.Equal(t, startLabel, returnedStartLabel) + assert.Equal(t, endLabel, returnedEndLabel) +} + +func TestMatchProps(t *testing.T) { + matchKeys := &MatchKeys{ + keys: map[string][]string{ + "User": {"pubkey"}, + "Event": {"id"}, + }, + } + + cases := []struct { + name string + node *Node + wantMatchLabel string + wantMatchProps Properties + wantErr bool + wantErrText string + }{ + { + name: "matching label, all props present", + node: NewEventNode("abc123"), + wantMatchLabel: "Event", + wantMatchProps: Properties{"id": "abc123"}, + }, + { + name: "matching label, required prop missing", + node: NewNode("Event", Properties{}), + wantErr: true, + wantErrText: "missing property", + }, + { + name: "no recognized label", + node: NewNode("Tag", Properties{"name": "e", "value": "abc"}), + wantErr: true, + wantErrText: "no recognized label", + }, + { + name: "multiple labels, one matches", + node: &Node{ + Labels: NewSet("Event", "Unknown"), + Props: Properties{ + "id": "abc123", + }, + }, + wantMatchLabel: "Event", + wantMatchProps: Properties{"id": "abc123"}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + matchLabel, props, err := tc.node.MatchProps(matchKeys) + if tc.wantErr { + assert.Error(t, err) + if tc.wantErrText != "" { + assert.ErrorContains(t, err, tc.wantErrText) + } + return + } + assert.NoError(t, err) + assert.Equal(t, tc.wantMatchLabel, matchLabel) + assert.Equal(t, tc.wantMatchProps, props) + }) + } +} + +func TestStructuredSubgraphAddNode(t *testing.T) { + matchKeys := NewMatchKeys() + subgraph := NewStructuredSubgraph(matchKeys) + node := NewEventNode("abc123") + + err := subgraph.AddNode(node) + + assert.NoError(t, err) + assert.Equal(t, 1, subgraph.NodeCount()) + assert.Equal(t, []*Node{node}, subgraph.GetNodes("Event:Event")) +} + +func TestStructuredSubgraphAddNodeInvalid(t *testing.T) { + matchKeys := NewMatchKeys() + subgraph := NewStructuredSubgraph(matchKeys) + node := NewNode("Event", Properties{}) + + err := subgraph.AddNode(node) + + assert.ErrorContains(t, err, "invalid node: missing property id") + assert.Equal(t, 0, subgraph.NodeCount()) +} + +func TestStructuredSubgraphAddRel(t *testing.T) { + matchKeys := NewMatchKeys() + subgraph := NewStructuredSubgraph(matchKeys) + + userNode := NewUserNode("pubkey1") + eventNode := NewEventNode("abc123") + rel, _ := NewSignedRel(userNode, eventNode, nil) + + err := subgraph.AddRel(rel) + + assert.NoError(t, err) + assert.Equal(t, 1, subgraph.RelCount()) + assert.Equal(t, []*Relationship{rel}, subgraph.GetRels("SIGNED,User,Event")) + +} diff --git a/neo4j.go b/neo4j.go new file mode 100644 index 0000000..7f64381 --- /dev/null +++ b/neo4j.go @@ -0,0 +1,20 @@ +package heartwood + +import ( + "context" + "github.com/neo4j/neo4j-go-driver/v6/neo4j" +) + +// ConnectNeo4j creates a new Neo4j driver and verifies its connectivity. +func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver, error) { + driver, err := neo4j.NewDriver( + uri, + neo4j.BasicAuth(user, password, "")) + + err = driver.VerifyConnectivity(ctx) + if err != nil { + return driver, err + } + + return driver, nil +} diff --git a/schema.go b/schema.go new file mode 100644 index 0000000..c821dc8 --- /dev/null +++ b/schema.go @@ -0,0 +1,154 @@ +// This module provides methods for creating nodes and relationships according +// to a defined schema. + +package heartwood + +import ( + "context" + "fmt" + "github.com/neo4j/neo4j-go-driver/v6/neo4j" +) + +// ======================================== +// Schema Match Keys +// ======================================== + +func NewMatchKeys() *MatchKeys { + return &MatchKeys{ + keys: map[string][]string{ + "User": {"pubkey"}, + "Relay": {"url"}, + "Event": {"id"}, + "Tag": {"name", "value"}, + }, + } +} + +// ======================================== +// Node Constructors +// ======================================== + +func NewUserNode(pubkey string) *Node { + return NewNode("User", Properties{"pubkey": pubkey}) +} + +func NewRelayNode(url string) *Node { + return NewNode("Relay", Properties{"url": url}) +} + +func NewEventNode(id string) *Node { + return NewNode("Event", Properties{"id": id}) +} + +func NewTagNode(name string, value string, rest []string) *Node { + return NewNode("Tag", Properties{ + "name": name, + "value": value, + "rest": rest}) +} + +// ======================================== +// Relationship Constructors +// ======================================== + +func NewSignedRel( + start *Node, end *Node, props Properties) (*Relationship, error) { + return NewRelationshipWithValidation( + "SIGNED", "User", "Event", start, end, props) + +} + +func NewTaggedRel( + start *Node, end *Node, props Properties) (*Relationship, error) { + return NewRelationshipWithValidation( + "TAGGED", "Event", "Tag", start, end, props) +} + +func NewReferencesEventRel( + start *Node, end *Node, props Properties) (*Relationship, error) { + return NewRelationshipWithValidation( + "REFERENCES", "Event", "Event", start, end, props) +} + +func NewReferencesUserRel( + start *Node, end *Node, props Properties) (*Relationship, error) { + return NewRelationshipWithValidation( + "REFERENCES", "Event", "User", start, end, props) +} + +// ======================================== +// Relationship Constructor Helpers +// ======================================== + +func validateNodeLabel(node *Node, role string, expectedLabel string) error { + if !node.Labels.Contains(expectedLabel) { + return fmt.Errorf( + "expected %s node to have label '%s'. got %v", + role, expectedLabel, node.Labels.ToArray(), + ) + } + + return nil +} + +func NewRelationshipWithValidation( + rtype string, + startLabel string, + endLabel string, + start *Node, + end *Node, + props Properties) (*Relationship, error) { + var err error + + err = validateNodeLabel(start, "start", startLabel) + if err != nil { + return nil, err + } + + err = validateNodeLabel(end, "end", endLabel) + if err != nil { + return nil, err + } + + return NewRelationship(rtype, start, end, props), nil +} + +// ======================================== +// Schema Indexes and Constaints +// ======================================== + +// SetNeo4jSchema ensures that the necessary indexes and constraints exist in +// the database +func SetNeo4jSchema(ctx context.Context, driver neo4j.Driver) error { + schemaQueries := []string{ + `CREATE CONSTRAINT user_pubkey IF NOT EXISTS + FOR (n:User) REQUIRE n.pubkey IS UNIQUE`, + + `CREATE INDEX user_pubkey IF NOT EXISTS + FOR (n:User) ON (n.pubkey)`, + + `CREATE INDEX event_id IF NOT EXISTS + FOR (n:Event) ON (n.id)`, + + `CREATE INDEX event_kind IF NOT EXISTS + FOR (n:Event) ON (n.kind)`, + + `CREATE INDEX tag_name_value IF NOT EXISTS + FOR (n:Tag) ON (n.name, n.value)`, + } + + // Create indexes and constraints + for _, query := range schemaQueries { + _, err := neo4j.ExecuteQuery(ctx, driver, + query, + nil, + neo4j.EagerResultTransformer, + neo4j.ExecuteQueryWithDatabase("neo4j")) + + if err != nil { + return err + } + } + + return nil +} diff --git a/schema_test.go b/schema_test.go new file mode 100644 index 0000000..025fbce --- /dev/null +++ b/schema_test.go @@ -0,0 +1,45 @@ +package heartwood + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNewRelationshipWithValidation(t *testing.T) { + cases := []struct { + name string + start *Node + end *Node + wantErr bool + wantErrText string + }{ + { + name: "valid start and end nodes", + start: NewUserNode("pubkey1"), + end: NewEventNode("abc123"), + }, + { + name: "mismatched start node label", + start: NewEventNode("abc123"), + end: NewEventNode("abc123"), + wantErr: true, + wantErrText: "expected start node to have label 'User'", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + rel, err := NewSignedRel(tc.start, tc.end, nil) + if tc.wantErr { + assert.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrText) + assert.Nil(t, rel) + return + } + assert.NoError(t, err) + assert.Equal(t, "SIGNED", rel.Type) + assert.Contains(t, rel.Start.Labels.ToArray(), "User") + assert.Contains(t, rel.End.Labels.ToArray(), "Event") + }) + } +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..e69de29 diff --git a/util.go b/util.go new file mode 100644 index 0000000..6a944f0 --- /dev/null +++ b/util.go @@ -0,0 +1,50 @@ +package heartwood + +// Sets + +type Set[T comparable] struct { + inner map[T]struct{} +} + +func NewSet[T comparable](items ...T) Set[T] { + set := Set[T]{ + inner: make(map[T]struct{}), + } + for _, i := range items { + set.Add(i) + } + return set +} + +func (s Set[T]) Add(item T) { + s.inner[item] = struct{}{} +} + +func (s Set[T]) Remove(item T) { + delete(s.inner, item) +} + +func (s Set[T]) Contains(item T) bool { + _, exists := s.inner[item] + return exists +} + +func (s Set[T]) ToArray() []T { + array := []T{} + for i := range s.inner { + array = append(array, i) + } + return array +} + +// Operations + +func Flatten[K comparable, V comparable](mapping map[K][]V) []V { + var values []V + for _, array := range mapping { + for _, v := range array { + values = append(values, v) + } + } + return values +}