Variety of refactors and optimizations.
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -45,15 +44,11 @@ func (s *BatchSubgraph) AddNode(node *Node) error {
|
||||
// Verify that the node has defined match property values.
|
||||
matchLabel, _, err := node.MatchProps(s.matchProvider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid node: %s", err)
|
||||
return fmt.Errorf("invalid node: %w", err)
|
||||
}
|
||||
|
||||
// Determine the node's batch key.
|
||||
batchKey := createNodeBatchKey(matchLabel, node.Labels.ToArray())
|
||||
|
||||
if _, exists := s.nodes[batchKey]; !exists {
|
||||
s.nodes[batchKey] = []*Node{}
|
||||
}
|
||||
batchKey := createNodeBatchKey(matchLabel, node.Labels.AsSortedArray())
|
||||
|
||||
// Add the node to the subgraph
|
||||
s.nodes[batchKey] = append(s.nodes[batchKey], node)
|
||||
@@ -66,22 +61,18 @@ func (s *BatchSubgraph) AddRel(rel *Relationship) error {
|
||||
// Verify that the start node has defined match property values.
|
||||
startLabel, _, err := rel.Start.MatchProps(s.matchProvider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid start node: %s", err)
|
||||
return fmt.Errorf("invalid start node: %w", err)
|
||||
}
|
||||
|
||||
// Verify that the end node has defined match property values.
|
||||
endLabel, _, err := rel.End.MatchProps(s.matchProvider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid end node: %s", err)
|
||||
return fmt.Errorf("invalid end node: %w", err)
|
||||
}
|
||||
|
||||
// Determine the relationship's batch key.
|
||||
batchKey := createRelBatchKey(rel.Type, startLabel, endLabel)
|
||||
|
||||
if _, exists := s.rels[batchKey]; !exists {
|
||||
s.rels[batchKey] = []*Relationship{}
|
||||
}
|
||||
|
||||
// Add the relationship to the subgraph
|
||||
s.rels[batchKey] = append(s.rels[batchKey], rel)
|
||||
|
||||
@@ -105,7 +96,7 @@ func (s *BatchSubgraph) RelCount() int {
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) nodeKeys() []string {
|
||||
keys := []string{}
|
||||
keys := make([]string, 0, len(s.nodes))
|
||||
for l := range s.nodes {
|
||||
keys = append(keys, l)
|
||||
}
|
||||
@@ -113,7 +104,7 @@ func (s *BatchSubgraph) nodeKeys() []string {
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) relKeys() []string {
|
||||
keys := []string{}
|
||||
keys := make([]string, 0, len(s.rels))
|
||||
for t := range s.rels {
|
||||
keys = append(keys, t)
|
||||
}
|
||||
@@ -121,7 +112,7 @@ func (s *BatchSubgraph) relKeys() []string {
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) {
|
||||
batches := []NodeBatch{}
|
||||
batches := make([]NodeBatch, 0, len(s.nodeKeys()))
|
||||
|
||||
for _, nodeKey := range s.nodeKeys() {
|
||||
matchLabel, labels, err := deserializeNodeBatchKey(nodeKey)
|
||||
@@ -146,7 +137,7 @@ func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) {
|
||||
}
|
||||
|
||||
func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) {
|
||||
batches := []RelBatch{}
|
||||
batches := make([]RelBatch, 0, len(s.relKeys()))
|
||||
|
||||
for _, relKey := range s.relKeys() {
|
||||
rtype, startLabel, endLabel, err := deserializeRelBatchKey(relKey)
|
||||
@@ -181,9 +172,8 @@ func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) {
|
||||
|
||||
// Helpers
|
||||
|
||||
func createNodeBatchKey(matchLabel string, labels []string) string {
|
||||
sort.Strings(labels)
|
||||
serializedLabels := strings.Join(labels, ",")
|
||||
func createNodeBatchKey(matchLabel string, sortedLabels []string) string {
|
||||
serializedLabels := strings.Join(sortedLabels, ",")
|
||||
return fmt.Sprintf("%s:%s", matchLabel, serializedLabels)
|
||||
}
|
||||
|
||||
@@ -245,7 +235,7 @@ func MergeSubgraph(
|
||||
)
|
||||
}
|
||||
if nodeResultSummary != nil {
|
||||
resultSummaries = append(resultSummaries, *nodeResultSummary)
|
||||
resultSummaries = append(resultSummaries, nodeResultSummary)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,7 +251,7 @@ func MergeSubgraph(
|
||||
)
|
||||
}
|
||||
if relResultSummary != nil {
|
||||
resultSummaries = append(resultSummaries, *relResultSummary)
|
||||
resultSummaries = append(resultSummaries, relResultSummary)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,11 +274,11 @@ func MergeNodes(
|
||||
ctx context.Context,
|
||||
tx neo4j.ManagedTransaction,
|
||||
batch NodeBatch,
|
||||
) (*neo4j.ResultSummary, error) {
|
||||
) (neo4j.ResultSummary, error) {
|
||||
cypherLabels := ToCypherLabels(batch.Labels)
|
||||
cypherProps := ToCypherProps(batch.MatchKeys, "node.")
|
||||
|
||||
serializedNodes := []*SerializedNode{}
|
||||
serializedNodes := make([]*SerializedNode, 0, len(batch.Nodes))
|
||||
for _, node := range batch.Nodes {
|
||||
serializedNodes = append(serializedNodes, node.Serialize())
|
||||
}
|
||||
@@ -316,21 +306,21 @@ func MergeNodes(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &summary, nil
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
func MergeRels(
|
||||
ctx context.Context,
|
||||
tx neo4j.ManagedTransaction,
|
||||
batch RelBatch,
|
||||
) (*neo4j.ResultSummary, error) {
|
||||
) (neo4j.ResultSummary, error) {
|
||||
cypherType := ToCypherLabel(batch.Type)
|
||||
startCypherLabel := ToCypherLabel(batch.StartLabel)
|
||||
endCypherLabel := ToCypherLabel(batch.EndLabel)
|
||||
startCypherProps := ToCypherProps(batch.StartMatchKeys, "rel.start.")
|
||||
endCypherProps := ToCypherProps(batch.EndMatchKeys, "rel.end.")
|
||||
|
||||
serializedRels := []*SerializedRel{}
|
||||
serializedRels := make([]*SerializedRel, 0, len(batch.Rels))
|
||||
for _, rel := range batch.Rels {
|
||||
serializedRels = append(serializedRels, rel.Serialize())
|
||||
}
|
||||
@@ -363,5 +353,5 @@ func MergeRels(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &summary, nil
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
@@ -7,20 +7,19 @@ import (
|
||||
|
||||
func TestNodeBatchKey(t *testing.T) {
|
||||
matchLabel := "Event"
|
||||
labels := []string{"Event", "AddressableEvent"}
|
||||
|
||||
// labels should be batched by key generator
|
||||
sortedLabels := []string{"AddressableEvent", "Event"}
|
||||
expectedKey := "Event:AddressableEvent,Event"
|
||||
|
||||
// Test Serialization
|
||||
batchKey := createNodeBatchKey(matchLabel, labels)
|
||||
// labels are expected to be pre-sorted
|
||||
batchKey := createNodeBatchKey(matchLabel, sortedLabels)
|
||||
assert.Equal(t, expectedKey, batchKey)
|
||||
|
||||
// Test Deserialization
|
||||
returnedMatchLabel, returnedLabels, err := deserializeNodeBatchKey(batchKey)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, matchLabel, returnedMatchLabel)
|
||||
assert.ElementsMatch(t, labels, returnedLabels)
|
||||
assert.ElementsMatch(t, sortedLabels, returnedLabels)
|
||||
}
|
||||
|
||||
func TestRelBatchKey(t *testing.T) {
|
||||
|
||||
40
boltdb.go
40
boltdb.go
@@ -4,32 +4,11 @@ import (
|
||||
"github.com/boltdb/bolt"
|
||||
)
|
||||
|
||||
// Interface
|
||||
const BucketName string = "events"
|
||||
|
||||
type BoltDB interface {
|
||||
Setup() error
|
||||
BatchCheckEventsExist(eventIDs []string) map[string]bool
|
||||
BatchWriteEvents(events []EventBlob) error
|
||||
}
|
||||
|
||||
func NewKVDB(boltdb *bolt.DB) BoltDB {
|
||||
return &boltDB{db: boltdb}
|
||||
}
|
||||
|
||||
type boltDB struct {
|
||||
db *bolt.DB
|
||||
}
|
||||
|
||||
func (b *boltDB) Setup() error {
|
||||
return SetupBoltDB(b.db)
|
||||
}
|
||||
|
||||
func (b *boltDB) BatchCheckEventsExist(eventIDs []string) map[string]bool {
|
||||
return BatchCheckEventsExist(b.db, eventIDs)
|
||||
}
|
||||
|
||||
func (b *boltDB) BatchWriteEvents(events []EventBlob) error {
|
||||
return BatchWriteEvents(b.db, events)
|
||||
type EventBlob struct {
|
||||
ID string
|
||||
JSON string
|
||||
}
|
||||
|
||||
func SetupBoltDB(boltdb *bolt.DB) error {
|
||||
@@ -39,19 +18,10 @@ func SetupBoltDB(boltdb *bolt.DB) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Functions
|
||||
|
||||
const BucketName string = "events"
|
||||
|
||||
type EventBlob struct {
|
||||
ID string
|
||||
JSON string
|
||||
}
|
||||
|
||||
func BatchCheckEventsExist(boltdb *bolt.DB, eventIDs []string) map[string]bool {
|
||||
existsMap := make(map[string]bool)
|
||||
|
||||
boltdb.View(func(tx *bolt.Tx) error {
|
||||
_ = boltdb.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(BucketName))
|
||||
if bucket == nil {
|
||||
return nil
|
||||
|
||||
@@ -30,7 +30,8 @@ func ToCypherProps(keys []string, prefix string) string {
|
||||
if prefix == "" {
|
||||
prefix = "$"
|
||||
}
|
||||
cypherPropsParts := []string{}
|
||||
|
||||
var cypherPropsParts []string
|
||||
for _, key := range keys {
|
||||
cypherPropsParts = append(
|
||||
cypherPropsParts, fmt.Sprintf("%s: %s%s", key, prefix, key))
|
||||
|
||||
82
expanders.go
82
expanders.go
@@ -1,82 +0,0 @@
|
||||
package heartwood
|
||||
|
||||
import (
|
||||
roots "git.wisehodl.dev/jay/go-roots/events"
|
||||
)
|
||||
|
||||
type Expander func(e roots.Event, s *EventSubgraph)
|
||||
type ExpanderPipeline []Expander
|
||||
|
||||
func NewExpanderPipeline(expanders ...Expander) ExpanderPipeline {
|
||||
return ExpanderPipeline(expanders)
|
||||
}
|
||||
|
||||
func DefaultExpanders() []Expander {
|
||||
return []Expander{
|
||||
ExpandTaggedEvents,
|
||||
ExpandTaggedUsers,
|
||||
}
|
||||
}
|
||||
|
||||
// Default Expander Functions
|
||||
|
||||
func ExpandTaggedEvents(e roots.Event, s *EventSubgraph) {
|
||||
tagNodes := s.NodesByLabel("Tag")
|
||||
for _, tag := range e.Tags {
|
||||
if !isValidTag(tag) {
|
||||
continue
|
||||
}
|
||||
name := tag[0]
|
||||
value := tag[1]
|
||||
|
||||
if name != "e" || !roots.Hex64Pattern.MatchString(value) {
|
||||
continue
|
||||
}
|
||||
|
||||
tagNode := findTagNode(tagNodes, name, value)
|
||||
if tagNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
referencedEvent := NewEventNode(value)
|
||||
|
||||
s.AddNode(referencedEvent)
|
||||
s.AddRel(NewReferencesEventRel(tagNode, referencedEvent, nil))
|
||||
}
|
||||
}
|
||||
|
||||
func ExpandTaggedUsers(e roots.Event, s *EventSubgraph) {
|
||||
tagNodes := s.NodesByLabel("Tag")
|
||||
for _, tag := range e.Tags {
|
||||
if !isValidTag(tag) {
|
||||
continue
|
||||
}
|
||||
name := tag[0]
|
||||
value := tag[1]
|
||||
|
||||
if name != "p" || !roots.Hex64Pattern.MatchString(value) {
|
||||
continue
|
||||
}
|
||||
|
||||
tagNode := findTagNode(tagNodes, name, value)
|
||||
if tagNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
referencedEvent := NewUserNode(value)
|
||||
|
||||
s.AddNode(referencedEvent)
|
||||
s.AddRel(NewReferencesUserRel(tagNode, referencedEvent, nil))
|
||||
}
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
func findTagNode(nodes []*Node, name, value string) *Node {
|
||||
for _, node := range nodes {
|
||||
if node.Props["name"] == name && node.Props["value"] == value {
|
||||
return node
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -252,7 +252,7 @@ func UnmarshalGraphJSON(data []byte, f *GraphFilter) error {
|
||||
}
|
||||
|
||||
func marshalGraphArray(filters []GraphFilter) ([]json.RawMessage, error) {
|
||||
result := []json.RawMessage{}
|
||||
result := make([]json.RawMessage, 0, len(filters))
|
||||
for _, f := range filters {
|
||||
b, err := MarshalGraphJSON(f)
|
||||
if err != nil {
|
||||
@@ -268,7 +268,7 @@ func unmarshalGraphArray(raws json.RawMessage) ([]GraphFilter, error) {
|
||||
if err := json.Unmarshal(raws, &rawArray); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result []GraphFilter
|
||||
result := make([]GraphFilter, 0, len(rawArray))
|
||||
for _, raw := range rawArray {
|
||||
var f GraphFilter
|
||||
if err := UnmarshalGraphJSON(raw, &f); err != nil {
|
||||
|
||||
13
graph.go
13
graph.go
@@ -2,7 +2,6 @@ package heartwood
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// ========================================
|
||||
@@ -33,7 +32,7 @@ type SimpleMatchKeys struct {
|
||||
}
|
||||
|
||||
func (p *SimpleMatchKeys) GetLabels() []string {
|
||||
labels := []string{}
|
||||
labels := make([]string, 0, len(p.Keys))
|
||||
for l := range p.Keys {
|
||||
labels = append(labels, l)
|
||||
}
|
||||
@@ -43,9 +42,8 @@ func (p *SimpleMatchKeys) GetLabels() []string {
|
||||
func (p *SimpleMatchKeys) GetKeys(label string) ([]string, bool) {
|
||||
if keys, exists := p.Keys[label]; exists {
|
||||
return keys, exists
|
||||
} else {
|
||||
return nil, exists
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ========================================
|
||||
@@ -56,7 +54,7 @@ func (p *SimpleMatchKeys) GetKeys(label string) ([]string, bool) {
|
||||
// properties.
|
||||
type Node struct {
|
||||
// Set of labels on the node.
|
||||
Labels Set[string]
|
||||
Labels *StringSet
|
||||
// Mapping of properties on the node.
|
||||
Props Properties
|
||||
}
|
||||
@@ -67,7 +65,7 @@ func NewNode(label string, props Properties) *Node {
|
||||
props = make(Properties)
|
||||
}
|
||||
return &Node{
|
||||
Labels: NewSet(label),
|
||||
Labels: NewStringSet(label),
|
||||
Props: props,
|
||||
}
|
||||
}
|
||||
@@ -79,8 +77,7 @@ func (n *Node) MatchProps(
|
||||
|
||||
// Iterate over each label on the node, checking whether each has match
|
||||
// keys associated with it.
|
||||
labels := n.Labels.ToArray()
|
||||
sort.Strings(labels)
|
||||
labels := n.Labels.AsSortedArray()
|
||||
for _, label := range labels {
|
||||
if keys, exists := matchProvider.GetKeys(label); exists {
|
||||
props := make(Properties)
|
||||
|
||||
@@ -71,7 +71,7 @@ func TestMatchProps(t *testing.T) {
|
||||
{
|
||||
name: "multiple labels, one matches",
|
||||
node: &Node{
|
||||
Labels: NewSet("Event", "Unknown"),
|
||||
Labels: NewStringSet("Event", "Unknown"),
|
||||
Props: Properties{
|
||||
"id": "abc123",
|
||||
},
|
||||
|
||||
56
neo4j.go
56
neo4j.go
@@ -5,26 +5,6 @@ import (
|
||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
||||
)
|
||||
|
||||
// Interface
|
||||
|
||||
type GraphDB interface {
|
||||
MergeSubgraph(ctx context.Context, subgraph *BatchSubgraph) ([]neo4j.ResultSummary, error)
|
||||
}
|
||||
|
||||
func NewGraphDriver(driver neo4j.Driver) GraphDB {
|
||||
return &graphdb{driver: driver}
|
||||
}
|
||||
|
||||
type graphdb struct {
|
||||
driver neo4j.Driver
|
||||
}
|
||||
|
||||
func (n *graphdb) MergeSubgraph(ctx context.Context, subgraph *BatchSubgraph) ([]neo4j.ResultSummary, error) {
|
||||
return MergeSubgraph(ctx, n.driver, subgraph)
|
||||
}
|
||||
|
||||
// Functions
|
||||
|
||||
func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver, error) {
|
||||
driver, err := neo4j.NewDriver(
|
||||
uri,
|
||||
@@ -40,3 +20,39 @@ func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver
|
||||
|
||||
return driver, nil
|
||||
}
|
||||
|
||||
// SetNeo4jSchema ensures that the necessary indexes and constraints exist in
|
||||
// the database
|
||||
func SetNeo4jSchema(ctx context.Context, driver neo4j.Driver) error {
|
||||
schemaQueries := []string{
|
||||
`CREATE CONSTRAINT user_pubkey IF NOT EXISTS
|
||||
FOR (n:User) REQUIRE n.pubkey IS UNIQUE`,
|
||||
|
||||
`CREATE INDEX user_pubkey IF NOT EXISTS
|
||||
FOR (n:User) ON (n.pubkey)`,
|
||||
|
||||
`CREATE INDEX event_id IF NOT EXISTS
|
||||
FOR (n:Event) ON (n.id)`,
|
||||
|
||||
`CREATE INDEX event_kind IF NOT EXISTS
|
||||
FOR (n:Event) ON (n.kind)`,
|
||||
|
||||
`CREATE INDEX tag_name_value IF NOT EXISTS
|
||||
FOR (n:Tag) ON (n.name, n.value)`,
|
||||
}
|
||||
|
||||
// Create indexes and constraints
|
||||
for _, query := range schemaQueries {
|
||||
_, err := neo4j.ExecuteQuery(ctx, driver,
|
||||
query,
|
||||
nil,
|
||||
neo4j.EagerResultTransformer,
|
||||
neo4j.ExecuteQueryWithDatabase("neo4j"))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
44
schema.go
44
schema.go
@@ -1,9 +1,7 @@
|
||||
package heartwood
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
||||
)
|
||||
|
||||
// ========================================
|
||||
@@ -80,7 +78,7 @@ func validateNodeLabel(node *Node, role string, expectedLabel string) {
|
||||
if !node.Labels.Contains(expectedLabel) {
|
||||
panic(fmt.Errorf(
|
||||
"expected %s node to have label %q. got %v",
|
||||
role, expectedLabel, node.Labels.ToArray(),
|
||||
role, expectedLabel, node.Labels.AsSortedArray(),
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -98,43 +96,3 @@ func NewRelationshipWithValidation(
|
||||
|
||||
return NewRelationship(rtype, start, end, props)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Schema Indexes and Constraints
|
||||
// ========================================
|
||||
|
||||
// SetNeo4jSchema ensures that the necessary indexes and constraints exist in
|
||||
// the database
|
||||
func SetNeo4jSchema(ctx context.Context, driver neo4j.Driver) error {
|
||||
schemaQueries := []string{
|
||||
`CREATE CONSTRAINT user_pubkey IF NOT EXISTS
|
||||
FOR (n:User) REQUIRE n.pubkey IS UNIQUE`,
|
||||
|
||||
`CREATE INDEX user_pubkey IF NOT EXISTS
|
||||
FOR (n:User) ON (n.pubkey)`,
|
||||
|
||||
`CREATE INDEX event_id IF NOT EXISTS
|
||||
FOR (n:Event) ON (n.id)`,
|
||||
|
||||
`CREATE INDEX event_kind IF NOT EXISTS
|
||||
FOR (n:Event) ON (n.kind)`,
|
||||
|
||||
`CREATE INDEX tag_name_value IF NOT EXISTS
|
||||
FOR (n:Tag) ON (n.name, n.value)`,
|
||||
}
|
||||
|
||||
// Create indexes and constraints
|
||||
for _, query := range schemaQueries {
|
||||
_, err := neo4j.ExecuteQuery(ctx, driver,
|
||||
query,
|
||||
nil,
|
||||
neo4j.EagerResultTransformer,
|
||||
neo4j.ExecuteQueryWithDatabase("neo4j"))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -37,8 +37,8 @@ func TestNewRelationshipWithValidation(t *testing.T) {
|
||||
}
|
||||
rel := NewSignedRel(tc.start, tc.end, nil)
|
||||
assert.Equal(t, "SIGNED", rel.Type)
|
||||
assert.Contains(t, rel.Start.Labels.ToArray(), "User")
|
||||
assert.Contains(t, rel.End.Labels.ToArray(), "Event")
|
||||
assert.Contains(t, rel.Start.Labels.AsSortedArray(), "User")
|
||||
assert.Contains(t, rel.End.Labels.AsSortedArray(), "Event")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
52
set.go
52
set.go
@@ -1,14 +1,20 @@
|
||||
package heartwood
|
||||
|
||||
import (
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Sets
|
||||
|
||||
type Set[T comparable] struct {
|
||||
inner map[T]struct{}
|
||||
type StringSet struct {
|
||||
inner map[string]struct{}
|
||||
sorted []string
|
||||
}
|
||||
|
||||
func NewSet[T comparable](items ...T) Set[T] {
|
||||
set := Set[T]{
|
||||
inner: make(map[T]struct{}),
|
||||
func NewStringSet(items ...string) *StringSet {
|
||||
set := &StringSet{
|
||||
inner: make(map[string]struct{}),
|
||||
sorted: []string{},
|
||||
}
|
||||
for _, i := range items {
|
||||
set.Add(i)
|
||||
@@ -16,20 +22,26 @@ func NewSet[T comparable](items ...T) Set[T] {
|
||||
return set
|
||||
}
|
||||
|
||||
func (s Set[T]) Add(item T) {
|
||||
s.inner[item] = struct{}{}
|
||||
func (s *StringSet) Add(item string) {
|
||||
if _, exists := s.inner[item]; !exists {
|
||||
s.inner[item] = struct{}{}
|
||||
s.rebuildSorted()
|
||||
}
|
||||
}
|
||||
|
||||
func (s Set[T]) Remove(item T) {
|
||||
delete(s.inner, item)
|
||||
func (s *StringSet) Remove(item string) {
|
||||
if _, exists := s.inner[item]; exists {
|
||||
delete(s.inner, item)
|
||||
s.rebuildSorted()
|
||||
}
|
||||
}
|
||||
|
||||
func (s Set[T]) Contains(item T) bool {
|
||||
func (s *StringSet) Contains(item string) bool {
|
||||
_, exists := s.inner[item]
|
||||
return exists
|
||||
}
|
||||
|
||||
func (s Set[T]) Equal(other Set[T]) bool {
|
||||
func (s *StringSet) Equal(other StringSet) bool {
|
||||
if len(s.inner) != len(other.inner) {
|
||||
return false
|
||||
}
|
||||
@@ -41,14 +53,18 @@ func (s Set[T]) Equal(other Set[T]) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s Set[T]) Length() int {
|
||||
func (s *StringSet) Length() int {
|
||||
return len(s.inner)
|
||||
}
|
||||
|
||||
func (s Set[T]) ToArray() []T {
|
||||
array := []T{}
|
||||
for i := range s.inner {
|
||||
array = append(array, i)
|
||||
}
|
||||
return array
|
||||
func (s *StringSet) AsSortedArray() []string {
|
||||
return s.sorted
|
||||
}
|
||||
|
||||
func (s *StringSet) rebuildSorted() {
|
||||
s.sorted = make([]string, 0, len(s.inner))
|
||||
for item := range s.inner {
|
||||
s.sorted = append(s.sorted, item)
|
||||
}
|
||||
sort.Strings(s.sorted)
|
||||
}
|
||||
|
||||
87
subgraph.go
87
subgraph.go
@@ -4,7 +4,7 @@ import (
|
||||
roots "git.wisehodl.dev/jay/go-roots/events"
|
||||
)
|
||||
|
||||
// Event subgraph struct
|
||||
// Types
|
||||
|
||||
type EventSubgraph struct {
|
||||
nodes []*Node
|
||||
@@ -35,7 +35,7 @@ func (s *EventSubgraph) Rels() []*Relationship {
|
||||
}
|
||||
|
||||
func (s *EventSubgraph) NodesByLabel(label string) []*Node {
|
||||
nodes := []*Node{}
|
||||
var nodes []*Node
|
||||
for _, node := range s.nodes {
|
||||
if node.Labels.Contains(label) {
|
||||
nodes = append(nodes, node)
|
||||
@@ -58,7 +58,7 @@ func isValidTag(t roots.Tag) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Event to subgraph conversion
|
||||
// Event to subgraph pipeline
|
||||
|
||||
func EventToSubgraph(e roots.Event, p ExpanderPipeline) *EventSubgraph {
|
||||
s := NewEventSubgraph()
|
||||
@@ -89,6 +89,8 @@ func EventToSubgraph(e roots.Event, p ExpanderPipeline) *EventSubgraph {
|
||||
return s
|
||||
}
|
||||
|
||||
// Core pipeline functions
|
||||
|
||||
func newEventNode(eventID string, createdAt int, kind int, content string) *Node {
|
||||
eventNode := NewEventNode(eventID)
|
||||
eventNode.Props["created_at"] = createdAt
|
||||
@@ -106,7 +108,7 @@ func newSignedRel(user, event *Node) *Relationship {
|
||||
}
|
||||
|
||||
func newTagNodes(tags []roots.Tag) []*Node {
|
||||
nodes := []*Node{}
|
||||
nodes := make([]*Node, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
if !isValidTag(tag) {
|
||||
continue
|
||||
@@ -117,9 +119,84 @@ func newTagNodes(tags []roots.Tag) []*Node {
|
||||
}
|
||||
|
||||
func newTagRels(event *Node, tags []*Node) []*Relationship {
|
||||
rels := []*Relationship{}
|
||||
rels := make([]*Relationship, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
rels = append(rels, NewTaggedRel(event, tag, nil))
|
||||
}
|
||||
return rels
|
||||
}
|
||||
|
||||
// Expander Pipeline
|
||||
|
||||
type Expander func(e roots.Event, s *EventSubgraph)
|
||||
type ExpanderPipeline []Expander
|
||||
|
||||
func NewExpanderPipeline(expanders ...Expander) ExpanderPipeline {
|
||||
return ExpanderPipeline(expanders)
|
||||
}
|
||||
|
||||
func DefaultExpanders() []Expander {
|
||||
return []Expander{
|
||||
ExpandTaggedEvents,
|
||||
ExpandTaggedUsers,
|
||||
}
|
||||
}
|
||||
|
||||
func ExpandTaggedEvents(e roots.Event, s *EventSubgraph) {
|
||||
tagNodes := s.NodesByLabel("Tag")
|
||||
for _, tag := range e.Tags {
|
||||
if !isValidTag(tag) {
|
||||
continue
|
||||
}
|
||||
name := tag[0]
|
||||
value := tag[1]
|
||||
|
||||
if name != "e" || !roots.Hex64Pattern.MatchString(value) {
|
||||
continue
|
||||
}
|
||||
|
||||
tagNode := findTagNode(tagNodes, name, value)
|
||||
if tagNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
referencedEvent := NewEventNode(value)
|
||||
|
||||
s.AddNode(referencedEvent)
|
||||
s.AddRel(NewReferencesEventRel(tagNode, referencedEvent, nil))
|
||||
}
|
||||
}
|
||||
|
||||
func ExpandTaggedUsers(e roots.Event, s *EventSubgraph) {
|
||||
tagNodes := s.NodesByLabel("Tag")
|
||||
for _, tag := range e.Tags {
|
||||
if !isValidTag(tag) {
|
||||
continue
|
||||
}
|
||||
name := tag[0]
|
||||
value := tag[1]
|
||||
|
||||
if name != "p" || !roots.Hex64Pattern.MatchString(value) {
|
||||
continue
|
||||
}
|
||||
|
||||
tagNode := findTagNode(tagNodes, name, value)
|
||||
if tagNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
referencedEvent := NewUserNode(value)
|
||||
|
||||
s.AddNode(referencedEvent)
|
||||
s.AddRel(NewReferencesUserRel(tagNode, referencedEvent, nil))
|
||||
}
|
||||
}
|
||||
|
||||
func findTagNode(nodes []*Node, name, value string) *Node {
|
||||
for _, node := range nodes {
|
||||
if node.Props["name"] == name && node.Props["value"] == value {
|
||||
return node
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -172,7 +172,7 @@ func nodesEqual(expected, got *Node) error {
|
||||
}
|
||||
|
||||
// Compare label values
|
||||
for _, label := range expected.Labels.ToArray() {
|
||||
for _, label := range expected.Labels.AsSortedArray() {
|
||||
if !got.Labels.Contains(label) {
|
||||
return fmt.Errorf("missing label %q", label)
|
||||
}
|
||||
|
||||
192
write.go
192
write.go
@@ -5,14 +5,15 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
roots "git.wisehodl.dev/jay/go-roots/events"
|
||||
"github.com/boltdb/bolt"
|
||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WriteOptions struct {
|
||||
Expanders ExpanderPipeline
|
||||
KVReadBatchSize int
|
||||
Expanders ExpanderPipeline
|
||||
BoltReadBatchSize int
|
||||
}
|
||||
|
||||
type EventFollower struct {
|
||||
@@ -39,7 +40,7 @@ type WriteReport struct {
|
||||
|
||||
func WriteEvents(
|
||||
events []string,
|
||||
graphdb GraphDB, boltdb BoltDB,
|
||||
driver neo4j.Driver, boltdb *bolt.DB,
|
||||
opts *WriteOptions,
|
||||
) (WriteReport, error) {
|
||||
start := time.Now()
|
||||
@@ -50,7 +51,7 @@ func WriteEvents(
|
||||
|
||||
setDefaultWriteOptions(opts)
|
||||
|
||||
err := boltdb.Setup()
|
||||
err := SetupBoltDB(boltdb)
|
||||
if err != nil {
|
||||
return WriteReport{}, fmt.Errorf("error setting up bolt db: %w", err)
|
||||
}
|
||||
@@ -58,80 +59,55 @@ func WriteEvents(
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Create Event Followers
|
||||
jsonChan := make(chan string, 10)
|
||||
eventChan := make(chan EventFollower, 10)
|
||||
jsonChan := make(chan string)
|
||||
eventChan := make(chan EventFollower)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
createEventFollowers(jsonChan, eventChan)
|
||||
}()
|
||||
go createEventFollowers(&wg, jsonChan, eventChan)
|
||||
|
||||
// Parse Event JSON
|
||||
parsedChan := make(chan EventFollower, 10)
|
||||
invalidChan := make(chan EventFollower, 10)
|
||||
parsedChan := make(chan EventFollower)
|
||||
invalidChan := make(chan EventFollower)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
parseEventJSON(eventChan, parsedChan, invalidChan)
|
||||
}()
|
||||
go parseEventJSON(&wg, eventChan, parsedChan, invalidChan)
|
||||
|
||||
// Collect Invalid Events
|
||||
collectedInvalidChan := make(chan []EventFollower)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
collectEvents(invalidChan, collectedInvalidChan)
|
||||
}()
|
||||
go collectEvents(&wg, invalidChan, collectedInvalidChan)
|
||||
|
||||
// Enforce Policy Rules
|
||||
queuedChan := make(chan EventFollower, 10)
|
||||
skippedChan := make(chan EventFollower, 10)
|
||||
queuedChan := make(chan EventFollower)
|
||||
skippedChan := make(chan EventFollower)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
enforcePolicyRules(
|
||||
graphdb, boltdb,
|
||||
opts.KVReadBatchSize,
|
||||
parsedChan, queuedChan, skippedChan)
|
||||
}()
|
||||
go enforcePolicyRules(&wg, driver, boltdb, opts.BoltReadBatchSize,
|
||||
parsedChan, queuedChan, skippedChan)
|
||||
|
||||
// Collect Skipped Events
|
||||
collectedSkippedChan := make(chan []EventFollower)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
collectEvents(skippedChan, collectedSkippedChan)
|
||||
}()
|
||||
go collectEvents(&wg, skippedChan, collectedSkippedChan)
|
||||
|
||||
// Convert Events To Subgraphs
|
||||
convertedChan := make(chan EventFollower, 10)
|
||||
convertedChan := make(chan EventFollower)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
convertEventsToSubgraphs(opts.Expanders, queuedChan, convertedChan)
|
||||
}()
|
||||
go convertEventsToSubgraphs(&wg, opts.Expanders, queuedChan, convertedChan)
|
||||
|
||||
// Write Events To Databases
|
||||
writeResultChan := make(chan WriteResult)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
writeEventsToDatabases(
|
||||
graphdb, boltdb,
|
||||
convertedChan, writeResultChan)
|
||||
}()
|
||||
go writeEventsToDatabases(&wg, driver, boltdb, convertedChan, writeResultChan)
|
||||
|
||||
// Send event jsons into pipeline
|
||||
go func() {
|
||||
for _, json := range events {
|
||||
jsonChan <- json
|
||||
for _, raw := range events {
|
||||
jsonChan <- raw
|
||||
}
|
||||
close(jsonChan)
|
||||
}()
|
||||
@@ -158,19 +134,21 @@ func setDefaultWriteOptions(opts *WriteOptions) {
|
||||
if opts.Expanders == nil {
|
||||
opts.Expanders = NewExpanderPipeline(DefaultExpanders()...)
|
||||
}
|
||||
if opts.KVReadBatchSize == 0 {
|
||||
opts.KVReadBatchSize = 100
|
||||
if opts.BoltReadBatchSize == 0 {
|
||||
opts.BoltReadBatchSize = 100
|
||||
}
|
||||
}
|
||||
|
||||
func createEventFollowers(jsonChan chan string, eventChan chan EventFollower) {
|
||||
func createEventFollowers(wg *sync.WaitGroup, jsonChan chan string, eventChan chan EventFollower) {
|
||||
defer wg.Done()
|
||||
for json := range jsonChan {
|
||||
eventChan <- EventFollower{JSON: json}
|
||||
}
|
||||
close(eventChan)
|
||||
}
|
||||
|
||||
func parseEventJSON(inChan, parsedChan, invalidChan chan EventFollower) {
|
||||
func parseEventJSON(wg *sync.WaitGroup, inChan, parsedChan, invalidChan chan EventFollower) {
|
||||
defer wg.Done()
|
||||
for follower := range inChan {
|
||||
var event roots.Event
|
||||
jsonBytes := []byte(follower.JSON)
|
||||
@@ -191,11 +169,13 @@ func parseEventJSON(inChan, parsedChan, invalidChan chan EventFollower) {
|
||||
}
|
||||
|
||||
func enforcePolicyRules(
|
||||
graphdb GraphDB, boltdb BoltDB,
|
||||
wg *sync.WaitGroup,
|
||||
driver neo4j.Driver, boltdb *bolt.DB,
|
||||
batchSize int,
|
||||
inChan, queuedChan, skippedChan chan EventFollower,
|
||||
) {
|
||||
batch := []EventFollower{}
|
||||
defer wg.Done()
|
||||
var batch []EventFollower
|
||||
|
||||
for follower := range inChan {
|
||||
batch = append(batch, follower)
|
||||
@@ -215,17 +195,17 @@ func enforcePolicyRules(
|
||||
}
|
||||
|
||||
func processPolicyRulesBatch(
|
||||
boltdb BoltDB,
|
||||
boltdb *bolt.DB,
|
||||
batch []EventFollower,
|
||||
queuedChan, skippedChan chan EventFollower,
|
||||
) {
|
||||
eventIDs := []string{}
|
||||
eventIDs := make([]string, 0, len(batch))
|
||||
|
||||
for _, follower := range batch {
|
||||
eventIDs = append(eventIDs, follower.ID)
|
||||
}
|
||||
|
||||
existsMap := boltdb.BatchCheckEventsExist(eventIDs)
|
||||
existsMap := BatchCheckEventsExist(boltdb, eventIDs)
|
||||
|
||||
for _, follower := range batch {
|
||||
if existsMap[follower.ID] {
|
||||
@@ -237,9 +217,10 @@ func processPolicyRulesBatch(
|
||||
}
|
||||
|
||||
func convertEventsToSubgraphs(
|
||||
expanders ExpanderPipeline,
|
||||
wg *sync.WaitGroup, expanders ExpanderPipeline,
|
||||
inChan, convertedChan chan EventFollower,
|
||||
) {
|
||||
defer wg.Done()
|
||||
for follower := range inChan {
|
||||
subgraph := EventToSubgraph(follower.Event, expanders)
|
||||
follower.Subgraph = subgraph
|
||||
@@ -249,93 +230,66 @@ func convertEventsToSubgraphs(
|
||||
}
|
||||
|
||||
func writeEventsToDatabases(
|
||||
graphdb GraphDB, boltdb BoltDB,
|
||||
wg *sync.WaitGroup,
|
||||
driver neo4j.Driver, boltdb *bolt.DB,
|
||||
inChan chan EventFollower,
|
||||
resultChan chan WriteResult,
|
||||
) {
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Done()
|
||||
var localWg sync.WaitGroup
|
||||
|
||||
kvEventChan := make(chan EventFollower, 10)
|
||||
graphEventChan := make(chan EventFollower, 10)
|
||||
boltEventChan := make(chan EventFollower)
|
||||
graphEventChan := make(chan EventFollower)
|
||||
|
||||
kvWriteDone := make(chan struct{})
|
||||
|
||||
kvErrorChan := make(chan error)
|
||||
boltErrorChan := make(chan error)
|
||||
graphResultChan := make(chan WriteResult)
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
writeEventsToKVStore(
|
||||
boltdb,
|
||||
kvEventChan, kvWriteDone, kvErrorChan)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
writeEventsToGraphDriver(
|
||||
graphdb,
|
||||
graphEventChan, kvWriteDone, graphResultChan)
|
||||
}()
|
||||
localWg.Add(2)
|
||||
go writeEventsToBoltDB(&localWg, boltdb, boltEventChan, boltErrorChan)
|
||||
go writeEventsToGraphDB(&localWg, driver, graphEventChan, boltErrorChan, graphResultChan)
|
||||
|
||||
// Fan out events to both writers
|
||||
for follower := range inChan {
|
||||
kvEventChan <- follower
|
||||
boltEventChan <- follower
|
||||
graphEventChan <- follower
|
||||
}
|
||||
close(kvEventChan)
|
||||
close(boltEventChan)
|
||||
close(graphEventChan)
|
||||
|
||||
wg.Wait()
|
||||
localWg.Wait()
|
||||
|
||||
kvError := <-kvErrorChan
|
||||
graphResult := <-graphResultChan
|
||||
|
||||
var finalErr error
|
||||
if kvError != nil && graphResult.Error != nil {
|
||||
finalErr = fmt.Errorf("kvstore: %w; graphstore: %v", kvError, graphResult.Error)
|
||||
} else if kvError != nil {
|
||||
finalErr = fmt.Errorf("kvstore: %w", kvError)
|
||||
} else if graphResult.Error != nil {
|
||||
finalErr = fmt.Errorf("graphstore: %w", graphResult.Error)
|
||||
}
|
||||
|
||||
resultChan <- WriteResult{
|
||||
ResultSummaries: graphResult.ResultSummaries,
|
||||
Error: finalErr,
|
||||
}
|
||||
resultChan <- graphResult
|
||||
}
|
||||
|
||||
func writeEventsToKVStore(
|
||||
boltdb BoltDB,
|
||||
func writeEventsToBoltDB(
|
||||
wg *sync.WaitGroup,
|
||||
boltdb *bolt.DB,
|
||||
inChan chan EventFollower,
|
||||
done chan struct{},
|
||||
resultChan chan error,
|
||||
errorChan chan error,
|
||||
) {
|
||||
events := []EventBlob{}
|
||||
defer wg.Done()
|
||||
var events []EventBlob
|
||||
|
||||
for follower := range inChan {
|
||||
events = append(events,
|
||||
EventBlob{ID: follower.ID, JSON: follower.JSON})
|
||||
}
|
||||
|
||||
err := boltdb.BatchWriteEvents(events)
|
||||
if err != nil {
|
||||
close(done)
|
||||
} else {
|
||||
done <- struct{}{}
|
||||
close(done)
|
||||
}
|
||||
err := BatchWriteEvents(boltdb, events)
|
||||
|
||||
resultChan <- err
|
||||
close(resultChan)
|
||||
errorChan <- err
|
||||
close(errorChan)
|
||||
}
|
||||
|
||||
func writeEventsToGraphDriver(
|
||||
graphdb GraphDB,
|
||||
func writeEventsToGraphDB(
|
||||
wg *sync.WaitGroup,
|
||||
driver neo4j.Driver,
|
||||
inChan chan EventFollower,
|
||||
start chan struct{},
|
||||
boltErrorChan chan error,
|
||||
resultChan chan WriteResult,
|
||||
) {
|
||||
defer wg.Done()
|
||||
matchKeys := NewSimpleMatchKeys()
|
||||
batch := NewBatchSubgraph(matchKeys)
|
||||
|
||||
@@ -348,14 +302,17 @@ func writeEventsToGraphDriver(
|
||||
}
|
||||
}
|
||||
|
||||
_, ok := <-start
|
||||
if !ok {
|
||||
resultChan <- WriteResult{Error: fmt.Errorf("kv write failed, aborting graph write")}
|
||||
boltErr := <-boltErrorChan
|
||||
if boltErr != nil {
|
||||
resultChan <- WriteResult{
|
||||
Error: fmt.Errorf(
|
||||
"boltdb write failed, aborting graph write: %w", boltErr,
|
||||
)}
|
||||
close(resultChan)
|
||||
return
|
||||
}
|
||||
|
||||
summaries, err := graphdb.MergeSubgraph(context.Background(), batch)
|
||||
summaries, err := MergeSubgraph(context.Background(), driver, batch)
|
||||
resultChan <- WriteResult{
|
||||
ResultSummaries: summaries,
|
||||
Error: err,
|
||||
@@ -363,8 +320,9 @@ func writeEventsToGraphDriver(
|
||||
close(resultChan)
|
||||
}
|
||||
|
||||
func collectEvents(inChan chan EventFollower, resultChan chan []EventFollower) {
|
||||
collected := []EventFollower{}
|
||||
func collectEvents(wg *sync.WaitGroup, inChan chan EventFollower, resultChan chan []EventFollower) {
|
||||
defer wg.Done()
|
||||
var collected []EventFollower
|
||||
for follower := range inChan {
|
||||
collected = append(collected, follower)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user