Migrated from neostr-brainstorm.

Added unit tests.
This commit is contained in:
Jay
2026-03-02 14:50:27 -05:00
commit f39f8786d9
13 changed files with 1100 additions and 0 deletions

198
batch.go Normal file
View File

@@ -0,0 +1,198 @@
package heartwood
import (
"context"
"fmt"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
)
func MergeSubgraph(
ctx context.Context,
driver neo4j.Driver,
subgraph *StructuredSubgraph,
) ([]neo4j.ResultSummary, error) {
// Validate subgraph
for _, nodeKey := range subgraph.NodeKeys() {
matchLabel, _, err := DeserializeNodeKey(nodeKey)
if err != nil {
return nil, err
}
_, exists := subgraph.matchProvider.GetKeys(matchLabel)
if !exists {
return nil, fmt.Errorf("unknown match label: %s", matchLabel)
}
}
for _, relKey := range subgraph.RelKeys() {
_, startLabel, endLabel, err := DeserializeRelKey(relKey)
if err != nil {
return nil, err
}
_, exists := subgraph.matchProvider.GetKeys(startLabel)
if !exists {
return nil, fmt.Errorf("unknown match label: %s", startLabel)
}
_, exists = subgraph.matchProvider.GetKeys(endLabel)
if !exists {
return nil, fmt.Errorf("unknown match label: %s", endLabel)
}
}
// Merge subgraph
session := driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: "neo4j"})
defer session.Close(ctx)
resultSummariesAny, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (any, error) {
var resultSummaries []neo4j.ResultSummary
for _, nodeKey := range subgraph.NodeKeys() {
matchLabel, labels, _ := DeserializeNodeKey(nodeKey)
nodeResultSummary, err := MergeNodes(
ctx, tx,
matchLabel,
labels,
subgraph.matchProvider,
subgraph.GetNodes(nodeKey),
)
if err != nil {
return nil, fmt.Errorf("failed to merge nodes for key %s: %w", nodeKey, err)
}
if nodeResultSummary != nil {
resultSummaries = append(resultSummaries, *nodeResultSummary)
}
}
for _, relKey := range subgraph.RelKeys() {
rtype, startLabel, endLabel, _ := DeserializeRelKey(relKey)
relResultSummary, err := MergeRels(
ctx, tx,
rtype,
startLabel,
endLabel,
subgraph.matchProvider,
subgraph.GetRels(relKey),
)
if err != nil {
return nil, fmt.Errorf("failed to merge relationships for key %s: %w", relKey, err)
}
if relResultSummary != nil {
resultSummaries = append(resultSummaries, *relResultSummary)
}
}
return resultSummaries, nil
})
if err != nil {
return nil, fmt.Errorf("subgraph merge transaction failed: %w", err)
}
resultSummaries, ok := resultSummariesAny.([]neo4j.ResultSummary)
if !ok {
return nil, fmt.Errorf("unexpected type returned from ExecuteWrite: got %T", resultSummariesAny)
}
return resultSummaries, nil
}
func MergeNodes(
ctx context.Context,
tx neo4j.ManagedTransaction,
matchLabel string,
nodeLabels []string,
matchProvider MatchKeysProvider,
nodes []*Node,
) (*neo4j.ResultSummary, error) {
cypherLabels := ToCypherLabels(nodeLabels)
matchKeys, _ := matchProvider.GetKeys(matchLabel)
cypherProps := ToCypherProps(matchKeys, "node.")
serializedNodes := []*SerializedNode{}
for _, node := range nodes {
serializedNodes = append(serializedNodes, node.Serialize())
}
query := fmt.Sprintf(`
UNWIND $nodes as node
MERGE (n%s { %s })
SET n += node
`,
cypherLabels, cypherProps,
)
result, err := tx.Run(ctx,
query,
map[string]any{
"nodes": serializedNodes,
})
if err != nil {
return nil, err
}
summary, err := result.Consume(ctx)
if err != nil {
return nil, err
}
return &summary, nil
}
func MergeRels(
ctx context.Context,
tx neo4j.ManagedTransaction,
rtype string,
startLabel string,
endLabel string,
matchProvider MatchKeysProvider,
rels []*Relationship,
) (*neo4j.ResultSummary, error) {
cypherType := ToCypherLabel(rtype)
startCypherLabel := ToCypherLabel(startLabel)
endCypherLabel := ToCypherLabel(endLabel)
matchKeys, _ := matchProvider.GetKeys(startLabel)
startCypherProps := ToCypherProps(matchKeys, "rel.start.")
matchKeys, _ = matchProvider.GetKeys(endLabel)
endCypherProps := ToCypherProps(matchKeys, "rel.end.")
serializedRels := []*SerializedRel{}
for _, rel := range rels {
serializedRels = append(serializedRels, rel.Serialize())
}
query := fmt.Sprintf(`
UNWIND $rels as rel
MATCH (start%s { %s })
MATCH (end%s { %s })
MERGE (start)-[r%s]->(end)
SET r += rel.props
`,
startCypherLabel, startCypherProps,
endCypherLabel, endCypherProps,
cypherType,
)
result, err := tx.Run(ctx,
query,
map[string]any{
"rels": serializedRels,
})
if err != nil {
return nil, err
}
summary, err := result.Consume(ctx)
if err != nil {
return nil, err
}
return &summary, nil
}

1
c2p Executable file
View File

@@ -0,0 +1 @@
code2prompt -c -e go.sum

39
cypher.go Normal file
View File

@@ -0,0 +1,39 @@
package heartwood
import (
"fmt"
"strings"
)
// ========================================
// Cypher Formatting Functions
// ========================================
// ToCypherLabel converts a node label or relationship type into its Cypher
// format.
func ToCypherLabel(label string) string {
return fmt.Sprintf(":`%s`", label)
}
// ToCypherLabels converts a list of node labels into its Cypher format.
func ToCypherLabels(labels []string) string {
var cypherLabels []string
for _, label := range labels {
cypherLabels = append(cypherLabels, ToCypherLabel(label))
}
return strings.Join(cypherLabels, "")
}
func ToCypherProps(keys []string, prefix string) string {
if prefix == "" {
prefix = "$"
}
cypherPropsParts := []string{}
for _, key := range keys {
cypherPropsParts = append(
cypherPropsParts, fmt.Sprintf("%s: %s%s", key, prefix, key))
}
return strings.Join(cypherPropsParts, ", ")
}

43
cypher_test.go Normal file
View File

@@ -0,0 +1,43 @@
package heartwood
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestToCypherLabel(t *testing.T) {
assert.Equal(t, ":`Event`", ToCypherLabel("Event"))
}
func TestToCypherLabels(t *testing.T) {
assert.Equal(t, ":`Event`:`ReplaceableEvent`",
ToCypherLabels([]string{"Event", "ReplaceableEvent"}))
}
func TestToCypherProps(t *testing.T) {
cases := []struct {
name string
keys []string
prefix string
expected string
}{
{
name: "default prefix",
keys: []string{"id", "name"},
prefix: "",
expected: "id: $id, name: $name",
},
{
name: "set prefix",
keys: []string{"id", "name"},
prefix: "entity.",
expected: "id: entity.id, name: entity.name",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, ToCypherProps(tc.keys, tc.prefix))
})
}
}

14
go.mod Normal file
View File

@@ -0,0 +1,14 @@
module git.wisehodl.dev/jay/go-heartwood
go 1.24
require (
github.com/neo4j/neo4j-go-driver/v6 v6.0.0
github.com/stretchr/testify v1.11.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

12
go.sum Normal file
View File

@@ -0,0 +1,12 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/neo4j/neo4j-go-driver/v6 v6.0.0 h1:xVAi6YLOfzXUx+1Lc/F2dUhpbN76BfKleZbAlnDFRiA=
github.com/neo4j/neo4j-go-driver/v6 v6.0.0/go.mod h1:hzSTfNfM31p1uRSzL1F/BAYOgaiTarE6OAQBajfsm+I=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

352
graph.go Normal file
View File

@@ -0,0 +1,352 @@
// This module defines types and functions for working with Neo4j graph
// entities.
package heartwood
import (
"fmt"
"sort"
"strings"
)
// ========================================
// Types
// ========================================
// Properties represents a map of node or relationship props.
type Properties map[string]any
// ========================================
// Match Key Provider
// ========================================
// MatchKeysProvider defines methods for querying a mapping of node labels and
// the property keys used to match nodes with them.
type MatchKeysProvider interface {
// GetLabels returns the array of node labels in the mapping.
GetLabels() []string
// GetKeys returns the node property keys used to match nodes with the
// given label and a boolean indicating the success of the lookup.
GetKeys(label string) ([]string, bool)
}
// MatchKeys is a simple implementation of the MatchKeysProvider interface.
type MatchKeys struct {
keys map[string][]string
}
func (p *MatchKeys) GetLabels() []string {
labels := []string{}
for l := range p.keys {
labels = append(labels, l)
}
return labels
}
func (p *MatchKeys) GetKeys(label string) ([]string, bool) {
if keys, exists := p.keys[label]; exists {
return keys, exists
} else {
return nil, exists
}
}
// ========================================
// Nodes
// ========================================
// Node represents a Neo4j node entity, encapsulating its labels and
// properties.
type Node struct {
// Set of labels on the node.
Labels Set[string]
// Mapping of properties on the node.
Props Properties
}
// NewNode creates a new node with the given label and properties.
func NewNode(label string, props Properties) *Node {
if props == nil {
props = make(Properties)
}
return &Node{
Labels: NewSet(label),
Props: props,
}
}
// MatchProps returns the node label and the property values to match it in the
// database.
func (n *Node) MatchProps(
matchProvider MatchKeysProvider) (string, Properties, error) {
// Iterate over each label on the node, checking whether each has match
// keys associated with it.
labels := n.Labels.ToArray()
sort.Strings(labels)
for _, label := range labels {
if keys, exists := matchProvider.GetKeys(label); exists {
props := make(Properties)
// Get the property values associated with each match key.
for _, key := range keys {
if value, exists := n.Props[key]; exists {
props[key] = value
} else {
// If any match property values are missing, return an
// error.
return label, nil,
fmt.Errorf(
"missing property %s for label %s", key, label)
}
}
// Return the label and match properties
return label, props, nil
}
}
// If none of the node labels have defined match keys, return an error.
return "", nil, fmt.Errorf("no recognized label found in %v", n.Labels)
}
type SerializedNode = Properties
func (n *Node) Serialize() *SerializedNode {
return &n.Props
}
// ========================================
// Relationships
// ========================================
// Relationship represents a Neo4j relationship between two nodes, including
// its type and properties.
type Relationship struct {
// The relationship type.
Type string
// The start node for the relationship.
Start *Node
// The end node for the relationship.
End *Node
// Mapping of properties on the relationship
Props Properties
}
// NewRelationship creates a new relationship with the given type, start node,
// end node, and properties
func NewRelationship(
rtype string, start *Node, end *Node, props Properties) *Relationship {
if props == nil {
props = make(Properties)
}
return &Relationship{
Type: rtype,
Start: start,
End: end,
Props: props,
}
}
type SerializedRel = map[string]Properties
func (r *Relationship) Serialize() *SerializedRel {
srel := make(map[string]Properties)
srel["props"] = r.Props
srel["start"] = r.Start.Props
srel["end"] = r.End.Props
return &srel
}
// ========================================
// Simple Subgraph
// ========================================
// Subgraph represents a simple collection of nodes and relationships.
type Subgraph struct {
// The nodes in the subgraph.
nodes []*Node
// The relationships in the subgraph.
rels []*Relationship
}
// NewSubgraph creates an empty subgraph.
func NewSubgraph() *Subgraph {
return &Subgraph{
nodes: []*Node{},
rels: []*Relationship{},
}
}
// AddNode adds a node to the subgraph
func (s *Subgraph) AddNode(node *Node) {
s.nodes = append(s.nodes, node)
}
// AddRel adds a relationship to the subgraph.
func (s *Subgraph) AddRel(rel *Relationship) {
s.rels = append(s.rels, rel)
}
// ========================================
// Structured Subgraph
// ========================================
// StructuredSubgraph is a structured collection of nodes and relationships for
// the purpose of conducting batch operations.
type StructuredSubgraph struct {
// A map of grouped nodes, sorted by their label combinations.
nodes map[string][]*Node
// A map of grouped relationships, sorted by their type and related node
// labels.
rels map[string][]*Relationship
// Provides node property keys used to match nodes with given labels in the
// database.
matchProvider MatchKeysProvider
}
// NewStructuredSubgraph creates an empty structured subgraph with the given
// match keys provider.
func NewStructuredSubgraph(matchProvider MatchKeysProvider) *StructuredSubgraph {
return &StructuredSubgraph{
nodes: make(map[string][]*Node),
rels: make(map[string][]*Relationship),
matchProvider: matchProvider,
}
}
// AddNode sorts a node into the subgraph.
func (s *StructuredSubgraph) AddNode(node *Node) error {
// Verify that the node has defined match property values.
matchLabel, _, err := node.MatchProps(s.matchProvider)
if err != nil {
return fmt.Errorf("invalid node: %s", err)
}
// Determine the node's sort key.
sortKey := createNodeSortKey(matchLabel, node.Labels.ToArray())
if _, exists := s.nodes[sortKey]; !exists {
s.nodes[sortKey] = []*Node{}
}
// Add the node to the subgraph.
s.nodes[sortKey] = append(s.nodes[sortKey], node)
return nil
}
// AddRel sorts a relationship into the subgraph.
func (s *StructuredSubgraph) AddRel(rel *Relationship) error {
// Verify that the start node has defined match property values.
startLabel, _, err := rel.Start.MatchProps(s.matchProvider)
if err != nil {
return fmt.Errorf("invalid start node: %s", err)
}
// Verify that the end node has defined match property values.
endLabel, _, err := rel.End.MatchProps(s.matchProvider)
if err != nil {
return fmt.Errorf("invalid end node: %s", err)
}
// Determine the relationship's sort key.
sortKey := createRelSortKey(rel.Type, startLabel, endLabel)
if _, exists := s.rels[sortKey]; !exists {
s.rels[sortKey] = []*Relationship{}
}
// Add the relationship to the subgraph.
s.rels[sortKey] = append(s.rels[sortKey], rel)
return nil
}
// GetNodes returns the nodes grouped under the given sort key.
func (s *StructuredSubgraph) GetNodes(nodeKey string) []*Node {
return s.nodes[nodeKey]
}
// GetRels returns the rels grouped under the given sort key.
func (s *StructuredSubgraph) GetRels(relKey string) []*Relationship {
return s.rels[relKey]
}
// NodeCount returns the number of nodes in the subgraph.
func (s *StructuredSubgraph) NodeCount() int {
count := 0
for l := range s.nodes {
count += len(s.nodes[l])
}
return count
}
// RelCount returns the number of relationships in the subgraph.
func (s *StructuredSubgraph) RelCount() int {
count := 0
for t := range s.rels {
count += len(s.rels[t])
}
return count
}
// NodeKeys returns the list of node sort keys in the subgraph.
func (s *StructuredSubgraph) NodeKeys() []string {
keys := []string{}
for l := range s.nodes {
keys = append(keys, l)
}
return keys
}
// RelKeys returns the list of relationship sort keys in the subgraph.
func (s *StructuredSubgraph) RelKeys() []string {
keys := []string{}
for t := range s.rels {
keys = append(keys, t)
}
return keys
}
// createNodeSortKey returns the serialized node labels for sorting.
func createNodeSortKey(matchLabel string, labels []string) string {
sort.Strings(labels)
serializedLabels := strings.Join(labels, ",")
return fmt.Sprintf("%s:%s", matchLabel, serializedLabels)
}
// createRelSortKey returns the serialized relationship type and start/end node
// labels for sorting.
func createRelSortKey(
rtype string, startLabel string, endLabel string) string {
return strings.Join([]string{rtype, startLabel, endLabel}, ",")
}
// DeserializeNodeKey returns the list of node labels from the serialized sort
// key.
func DeserializeNodeKey(sortKey string) (string, []string, error) {
parts := strings.Split(sortKey, ":")
if len(parts) != 2 {
return "", nil, fmt.Errorf("invalid node sort key: %s", sortKey)
}
matchLabel, serializedLabels := parts[0], parts[1]
labels := strings.Split(serializedLabels, ",")
return matchLabel, labels, nil
}
// DeserializeRelKey returns the relationship type, start node label, and end
// node label from the serialized sort key. Panics if the sort key is invalid.
func DeserializeRelKey(sortKey string) (string, string, string, error) {
parts := strings.Split(sortKey, ",")
if len(parts) != 3 {
return "", "", "", fmt.Errorf("invalid relationship sort key: %s", sortKey)
}
rtype, startLabel, endLabel := parts[0], parts[1], parts[2]
return rtype, startLabel, endLabel, nil
}

172
graph_test.go Normal file
View File

@@ -0,0 +1,172 @@
package heartwood
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestMatchKeys(t *testing.T) {
matchKeys := &MatchKeys{
keys: map[string][]string{
"User": {"pubkey"},
"Event": {"id"},
"Tag": {"name", "value"},
},
}
t.Run("get labels", func(t *testing.T) {
expectedLabels := []string{"Event", "Tag", "User"}
labels := matchKeys.GetLabels()
assert.ElementsMatch(t, expectedLabels, labels)
})
t.Run("get keys", func(t *testing.T) {
expectedKeys := []string{"id"}
keys, exists := matchKeys.GetKeys("Event")
assert.True(t, exists)
assert.ElementsMatch(t, expectedKeys, keys)
})
t.Run("unknown key", func(t *testing.T) {
keys, exists := matchKeys.GetKeys("Unknown")
assert.False(t, exists)
assert.Nil(t, keys)
})
}
func TestNodeSortKey(t *testing.T) {
matchLabel := "Event"
labels := []string{"Event", "AddressableEvent"}
// labels should be sorted by key generator
expectedKey := "Event:AddressableEvent,Event"
// Test Serialization
sortKey := createNodeSortKey(matchLabel, labels)
assert.Equal(t, expectedKey, sortKey)
// Test Deserialization
returnedMatchLabel, returnedLabels, err := DeserializeNodeKey(sortKey)
assert.NoError(t, err)
assert.Equal(t, matchLabel, returnedMatchLabel)
assert.ElementsMatch(t, labels, returnedLabels)
}
func TestRelSortKey(t *testing.T) {
rtype, startLabel, endLabel := "SIGNED", "User", "Event"
expectedKey := "SIGNED,User,Event"
// Test Serialization
sortKey := createRelSortKey(rtype, startLabel, endLabel)
assert.Equal(t, expectedKey, sortKey)
// Test Deserialization
returnedRtype, returnedStartLabel, returnedEndLabel, err := DeserializeRelKey(sortKey)
assert.NoError(t, err)
assert.Equal(t, rtype, returnedRtype)
assert.Equal(t, startLabel, returnedStartLabel)
assert.Equal(t, endLabel, returnedEndLabel)
}
func TestMatchProps(t *testing.T) {
matchKeys := &MatchKeys{
keys: map[string][]string{
"User": {"pubkey"},
"Event": {"id"},
},
}
cases := []struct {
name string
node *Node
wantMatchLabel string
wantMatchProps Properties
wantErr bool
wantErrText string
}{
{
name: "matching label, all props present",
node: NewEventNode("abc123"),
wantMatchLabel: "Event",
wantMatchProps: Properties{"id": "abc123"},
},
{
name: "matching label, required prop missing",
node: NewNode("Event", Properties{}),
wantErr: true,
wantErrText: "missing property",
},
{
name: "no recognized label",
node: NewNode("Tag", Properties{"name": "e", "value": "abc"}),
wantErr: true,
wantErrText: "no recognized label",
},
{
name: "multiple labels, one matches",
node: &Node{
Labels: NewSet("Event", "Unknown"),
Props: Properties{
"id": "abc123",
},
},
wantMatchLabel: "Event",
wantMatchProps: Properties{"id": "abc123"},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
matchLabel, props, err := tc.node.MatchProps(matchKeys)
if tc.wantErr {
assert.Error(t, err)
if tc.wantErrText != "" {
assert.ErrorContains(t, err, tc.wantErrText)
}
return
}
assert.NoError(t, err)
assert.Equal(t, tc.wantMatchLabel, matchLabel)
assert.Equal(t, tc.wantMatchProps, props)
})
}
}
func TestStructuredSubgraphAddNode(t *testing.T) {
matchKeys := NewMatchKeys()
subgraph := NewStructuredSubgraph(matchKeys)
node := NewEventNode("abc123")
err := subgraph.AddNode(node)
assert.NoError(t, err)
assert.Equal(t, 1, subgraph.NodeCount())
assert.Equal(t, []*Node{node}, subgraph.GetNodes("Event:Event"))
}
func TestStructuredSubgraphAddNodeInvalid(t *testing.T) {
matchKeys := NewMatchKeys()
subgraph := NewStructuredSubgraph(matchKeys)
node := NewNode("Event", Properties{})
err := subgraph.AddNode(node)
assert.ErrorContains(t, err, "invalid node: missing property id")
assert.Equal(t, 0, subgraph.NodeCount())
}
func TestStructuredSubgraphAddRel(t *testing.T) {
matchKeys := NewMatchKeys()
subgraph := NewStructuredSubgraph(matchKeys)
userNode := NewUserNode("pubkey1")
eventNode := NewEventNode("abc123")
rel, _ := NewSignedRel(userNode, eventNode, nil)
err := subgraph.AddRel(rel)
assert.NoError(t, err)
assert.Equal(t, 1, subgraph.RelCount())
assert.Equal(t, []*Relationship{rel}, subgraph.GetRels("SIGNED,User,Event"))
}

20
neo4j.go Normal file
View File

@@ -0,0 +1,20 @@
package heartwood
import (
"context"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
)
// ConnectNeo4j creates a new Neo4j driver and verifies its connectivity.
func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver, error) {
driver, err := neo4j.NewDriver(
uri,
neo4j.BasicAuth(user, password, ""))
err = driver.VerifyConnectivity(ctx)
if err != nil {
return driver, err
}
return driver, nil
}

154
schema.go Normal file
View File

@@ -0,0 +1,154 @@
// This module provides methods for creating nodes and relationships according
// to a defined schema.
package heartwood
import (
"context"
"fmt"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
)
// ========================================
// Schema Match Keys
// ========================================
func NewMatchKeys() *MatchKeys {
return &MatchKeys{
keys: map[string][]string{
"User": {"pubkey"},
"Relay": {"url"},
"Event": {"id"},
"Tag": {"name", "value"},
},
}
}
// ========================================
// Node Constructors
// ========================================
func NewUserNode(pubkey string) *Node {
return NewNode("User", Properties{"pubkey": pubkey})
}
func NewRelayNode(url string) *Node {
return NewNode("Relay", Properties{"url": url})
}
func NewEventNode(id string) *Node {
return NewNode("Event", Properties{"id": id})
}
func NewTagNode(name string, value string, rest []string) *Node {
return NewNode("Tag", Properties{
"name": name,
"value": value,
"rest": rest})
}
// ========================================
// Relationship Constructors
// ========================================
func NewSignedRel(
start *Node, end *Node, props Properties) (*Relationship, error) {
return NewRelationshipWithValidation(
"SIGNED", "User", "Event", start, end, props)
}
func NewTaggedRel(
start *Node, end *Node, props Properties) (*Relationship, error) {
return NewRelationshipWithValidation(
"TAGGED", "Event", "Tag", start, end, props)
}
func NewReferencesEventRel(
start *Node, end *Node, props Properties) (*Relationship, error) {
return NewRelationshipWithValidation(
"REFERENCES", "Event", "Event", start, end, props)
}
func NewReferencesUserRel(
start *Node, end *Node, props Properties) (*Relationship, error) {
return NewRelationshipWithValidation(
"REFERENCES", "Event", "User", start, end, props)
}
// ========================================
// Relationship Constructor Helpers
// ========================================
func validateNodeLabel(node *Node, role string, expectedLabel string) error {
if !node.Labels.Contains(expectedLabel) {
return fmt.Errorf(
"expected %s node to have label '%s'. got %v",
role, expectedLabel, node.Labels.ToArray(),
)
}
return nil
}
func NewRelationshipWithValidation(
rtype string,
startLabel string,
endLabel string,
start *Node,
end *Node,
props Properties) (*Relationship, error) {
var err error
err = validateNodeLabel(start, "start", startLabel)
if err != nil {
return nil, err
}
err = validateNodeLabel(end, "end", endLabel)
if err != nil {
return nil, err
}
return NewRelationship(rtype, start, end, props), nil
}
// ========================================
// Schema Indexes and Constaints
// ========================================
// SetNeo4jSchema ensures that the necessary indexes and constraints exist in
// the database
func SetNeo4jSchema(ctx context.Context, driver neo4j.Driver) error {
schemaQueries := []string{
`CREATE CONSTRAINT user_pubkey IF NOT EXISTS
FOR (n:User) REQUIRE n.pubkey IS UNIQUE`,
`CREATE INDEX user_pubkey IF NOT EXISTS
FOR (n:User) ON (n.pubkey)`,
`CREATE INDEX event_id IF NOT EXISTS
FOR (n:Event) ON (n.id)`,
`CREATE INDEX event_kind IF NOT EXISTS
FOR (n:Event) ON (n.kind)`,
`CREATE INDEX tag_name_value IF NOT EXISTS
FOR (n:Tag) ON (n.name, n.value)`,
}
// Create indexes and constraints
for _, query := range schemaQueries {
_, err := neo4j.ExecuteQuery(ctx, driver,
query,
nil,
neo4j.EagerResultTransformer,
neo4j.ExecuteQueryWithDatabase("neo4j"))
if err != nil {
return err
}
}
return nil
}

45
schema_test.go Normal file
View File

@@ -0,0 +1,45 @@
package heartwood
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewRelationshipWithValidation(t *testing.T) {
cases := []struct {
name string
start *Node
end *Node
wantErr bool
wantErrText string
}{
{
name: "valid start and end nodes",
start: NewUserNode("pubkey1"),
end: NewEventNode("abc123"),
},
{
name: "mismatched start node label",
start: NewEventNode("abc123"),
end: NewEventNode("abc123"),
wantErr: true,
wantErrText: "expected start node to have label 'User'",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
rel, err := NewSignedRel(tc.start, tc.end, nil)
if tc.wantErr {
assert.Error(t, err)
assert.ErrorContains(t, err, tc.wantErrText)
assert.Nil(t, rel)
return
}
assert.NoError(t, err)
assert.Equal(t, "SIGNED", rel.Type)
assert.Contains(t, rel.Start.Labels.ToArray(), "User")
assert.Contains(t, rel.End.Labels.ToArray(), "Event")
})
}
}

0
types.go Normal file
View File

50
util.go Normal file
View File

@@ -0,0 +1,50 @@
package heartwood
// Sets
type Set[T comparable] struct {
inner map[T]struct{}
}
func NewSet[T comparable](items ...T) Set[T] {
set := Set[T]{
inner: make(map[T]struct{}),
}
for _, i := range items {
set.Add(i)
}
return set
}
func (s Set[T]) Add(item T) {
s.inner[item] = struct{}{}
}
func (s Set[T]) Remove(item T) {
delete(s.inner, item)
}
func (s Set[T]) Contains(item T) bool {
_, exists := s.inner[item]
return exists
}
func (s Set[T]) ToArray() []T {
array := []T{}
for i := range s.inner {
array = append(array, i)
}
return array
}
// Operations
func Flatten[K comparable, V comparable](mapping map[K][]V) []V {
var values []V
for _, array := range mapping {
for _, v := range array {
values = append(values, v)
}
}
return values
}