Variety of refactors and optimizations.

This commit is contained in:
Jay
2026-03-05 00:28:40 -05:00
parent 894eab5405
commit 269e88fe49
15 changed files with 268 additions and 368 deletions

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
"sort"
"strings"
)
@@ -45,15 +44,11 @@ func (s *BatchSubgraph) AddNode(node *Node) error {
// Verify that the node has defined match property values.
matchLabel, _, err := node.MatchProps(s.matchProvider)
if err != nil {
return fmt.Errorf("invalid node: %s", err)
return fmt.Errorf("invalid node: %w", err)
}
// Determine the node's batch key.
batchKey := createNodeBatchKey(matchLabel, node.Labels.ToArray())
if _, exists := s.nodes[batchKey]; !exists {
s.nodes[batchKey] = []*Node{}
}
batchKey := createNodeBatchKey(matchLabel, node.Labels.AsSortedArray())
// Add the node to the subgraph
s.nodes[batchKey] = append(s.nodes[batchKey], node)
@@ -66,22 +61,18 @@ func (s *BatchSubgraph) AddRel(rel *Relationship) error {
// Verify that the start node has defined match property values.
startLabel, _, err := rel.Start.MatchProps(s.matchProvider)
if err != nil {
return fmt.Errorf("invalid start node: %s", err)
return fmt.Errorf("invalid start node: %w", err)
}
// Verify that the end node has defined match property values.
endLabel, _, err := rel.End.MatchProps(s.matchProvider)
if err != nil {
return fmt.Errorf("invalid end node: %s", err)
return fmt.Errorf("invalid end node: %w", err)
}
// Determine the relationship's batch key.
batchKey := createRelBatchKey(rel.Type, startLabel, endLabel)
if _, exists := s.rels[batchKey]; !exists {
s.rels[batchKey] = []*Relationship{}
}
// Add the relationship to the subgraph
s.rels[batchKey] = append(s.rels[batchKey], rel)
@@ -105,7 +96,7 @@ func (s *BatchSubgraph) RelCount() int {
}
func (s *BatchSubgraph) nodeKeys() []string {
keys := []string{}
keys := make([]string, 0, len(s.nodes))
for l := range s.nodes {
keys = append(keys, l)
}
@@ -113,7 +104,7 @@ func (s *BatchSubgraph) nodeKeys() []string {
}
func (s *BatchSubgraph) relKeys() []string {
keys := []string{}
keys := make([]string, 0, len(s.rels))
for t := range s.rels {
keys = append(keys, t)
}
@@ -121,7 +112,7 @@ func (s *BatchSubgraph) relKeys() []string {
}
func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) {
batches := []NodeBatch{}
batches := make([]NodeBatch, 0, len(s.nodeKeys()))
for _, nodeKey := range s.nodeKeys() {
matchLabel, labels, err := deserializeNodeBatchKey(nodeKey)
@@ -146,7 +137,7 @@ func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) {
}
func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) {
batches := []RelBatch{}
batches := make([]RelBatch, 0, len(s.relKeys()))
for _, relKey := range s.relKeys() {
rtype, startLabel, endLabel, err := deserializeRelBatchKey(relKey)
@@ -181,9 +172,8 @@ func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) {
// Helpers
// createNodeBatchKey serializes a node batch key of the form
// "<matchLabel>:<label1>,<label2>,...".
//
// sortedLabels must already be in sorted order; the function joins them as
// given and never reorders or otherwise modifies the caller's slice, so two
// nodes only share a key when their label slices are ordered identically.
// An empty or nil sortedLabels yields "<matchLabel>:".
func createNodeBatchKey(matchLabel string, sortedLabels []string) string {
	// Plain concatenation produces the same text as the "%s:%s" format without
	// routing both operands through fmt.
	return matchLabel + ":" + strings.Join(sortedLabels, ",")
}
@@ -245,7 +235,7 @@ func MergeSubgraph(
)
}
if nodeResultSummary != nil {
resultSummaries = append(resultSummaries, *nodeResultSummary)
resultSummaries = append(resultSummaries, nodeResultSummary)
}
}
@@ -261,7 +251,7 @@ func MergeSubgraph(
)
}
if relResultSummary != nil {
resultSummaries = append(resultSummaries, *relResultSummary)
resultSummaries = append(resultSummaries, relResultSummary)
}
}
@@ -284,11 +274,11 @@ func MergeNodes(
ctx context.Context,
tx neo4j.ManagedTransaction,
batch NodeBatch,
) (*neo4j.ResultSummary, error) {
) (neo4j.ResultSummary, error) {
cypherLabels := ToCypherLabels(batch.Labels)
cypherProps := ToCypherProps(batch.MatchKeys, "node.")
serializedNodes := []*SerializedNode{}
serializedNodes := make([]*SerializedNode, 0, len(batch.Nodes))
for _, node := range batch.Nodes {
serializedNodes = append(serializedNodes, node.Serialize())
}
@@ -316,21 +306,21 @@ func MergeNodes(
return nil, err
}
return &summary, nil
return summary, nil
}
func MergeRels(
ctx context.Context,
tx neo4j.ManagedTransaction,
batch RelBatch,
) (*neo4j.ResultSummary, error) {
) (neo4j.ResultSummary, error) {
cypherType := ToCypherLabel(batch.Type)
startCypherLabel := ToCypherLabel(batch.StartLabel)
endCypherLabel := ToCypherLabel(batch.EndLabel)
startCypherProps := ToCypherProps(batch.StartMatchKeys, "rel.start.")
endCypherProps := ToCypherProps(batch.EndMatchKeys, "rel.end.")
serializedRels := []*SerializedRel{}
serializedRels := make([]*SerializedRel, 0, len(batch.Rels))
for _, rel := range batch.Rels {
serializedRels = append(serializedRels, rel.Serialize())
}
@@ -363,5 +353,5 @@ func MergeRels(
return nil, err
}
return &summary, nil
return summary, nil
}

View File

@@ -7,20 +7,19 @@ import (
func TestNodeBatchKey(t *testing.T) {
matchLabel := "Event"
labels := []string{"Event", "AddressableEvent"}
// labels should be batched by key generator
sortedLabels := []string{"AddressableEvent", "Event"}
expectedKey := "Event:AddressableEvent,Event"
// Test Serialization
batchKey := createNodeBatchKey(matchLabel, labels)
// labels are expected to be pre-sorted
batchKey := createNodeBatchKey(matchLabel, sortedLabels)
assert.Equal(t, expectedKey, batchKey)
// Test Deserialization
returnedMatchLabel, returnedLabels, err := deserializeNodeBatchKey(batchKey)
assert.NoError(t, err)
assert.Equal(t, matchLabel, returnedMatchLabel)
assert.ElementsMatch(t, labels, returnedLabels)
assert.ElementsMatch(t, sortedLabels, returnedLabels)
}
func TestRelBatchKey(t *testing.T) {

View File

@@ -4,32 +4,11 @@ import (
"github.com/boltdb/bolt"
)
// Interface
const BucketName string = "events"
type BoltDB interface {
Setup() error
BatchCheckEventsExist(eventIDs []string) map[string]bool
BatchWriteEvents(events []EventBlob) error
}
func NewKVDB(boltdb *bolt.DB) BoltDB {
return &boltDB{db: boltdb}
}
type boltDB struct {
db *bolt.DB
}
func (b *boltDB) Setup() error {
return SetupBoltDB(b.db)
}
func (b *boltDB) BatchCheckEventsExist(eventIDs []string) map[string]bool {
return BatchCheckEventsExist(b.db, eventIDs)
}
func (b *boltDB) BatchWriteEvents(events []EventBlob) error {
return BatchWriteEvents(b.db, events)
type EventBlob struct {
ID string
JSON string
}
func SetupBoltDB(boltdb *bolt.DB) error {
@@ -39,19 +18,10 @@ func SetupBoltDB(boltdb *bolt.DB) error {
})
}
// Functions
const BucketName string = "events"
type EventBlob struct {
ID string
JSON string
}
func BatchCheckEventsExist(boltdb *bolt.DB, eventIDs []string) map[string]bool {
existsMap := make(map[string]bool)
boltdb.View(func(tx *bolt.Tx) error {
_ = boltdb.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(BucketName))
if bucket == nil {
return nil

View File

@@ -30,7 +30,8 @@ func ToCypherProps(keys []string, prefix string) string {
if prefix == "" {
prefix = "$"
}
cypherPropsParts := []string{}
var cypherPropsParts []string
for _, key := range keys {
cypherPropsParts = append(
cypherPropsParts, fmt.Sprintf("%s: %s%s", key, prefix, key))

View File

@@ -1,82 +0,0 @@
package heartwood
import (
roots "git.wisehodl.dev/jay/go-roots/events"
)
type Expander func(e roots.Event, s *EventSubgraph)
type ExpanderPipeline []Expander
func NewExpanderPipeline(expanders ...Expander) ExpanderPipeline {
return ExpanderPipeline(expanders)
}
func DefaultExpanders() []Expander {
return []Expander{
ExpandTaggedEvents,
ExpandTaggedUsers,
}
}
// Default Expander Functions
func ExpandTaggedEvents(e roots.Event, s *EventSubgraph) {
tagNodes := s.NodesByLabel("Tag")
for _, tag := range e.Tags {
if !isValidTag(tag) {
continue
}
name := tag[0]
value := tag[1]
if name != "e" || !roots.Hex64Pattern.MatchString(value) {
continue
}
tagNode := findTagNode(tagNodes, name, value)
if tagNode == nil {
continue
}
referencedEvent := NewEventNode(value)
s.AddNode(referencedEvent)
s.AddRel(NewReferencesEventRel(tagNode, referencedEvent, nil))
}
}
func ExpandTaggedUsers(e roots.Event, s *EventSubgraph) {
tagNodes := s.NodesByLabel("Tag")
for _, tag := range e.Tags {
if !isValidTag(tag) {
continue
}
name := tag[0]
value := tag[1]
if name != "p" || !roots.Hex64Pattern.MatchString(value) {
continue
}
tagNode := findTagNode(tagNodes, name, value)
if tagNode == nil {
continue
}
referencedEvent := NewUserNode(value)
s.AddNode(referencedEvent)
s.AddRel(NewReferencesUserRel(tagNode, referencedEvent, nil))
}
}
// Helpers
func findTagNode(nodes []*Node, name, value string) *Node {
for _, node := range nodes {
if node.Props["name"] == name && node.Props["value"] == value {
return node
}
}
return nil
}

View File

@@ -252,7 +252,7 @@ func UnmarshalGraphJSON(data []byte, f *GraphFilter) error {
}
func marshalGraphArray(filters []GraphFilter) ([]json.RawMessage, error) {
result := []json.RawMessage{}
result := make([]json.RawMessage, 0, len(filters))
for _, f := range filters {
b, err := MarshalGraphJSON(f)
if err != nil {
@@ -268,7 +268,7 @@ func unmarshalGraphArray(raws json.RawMessage) ([]GraphFilter, error) {
if err := json.Unmarshal(raws, &rawArray); err != nil {
return nil, err
}
var result []GraphFilter
result := make([]GraphFilter, 0, len(rawArray))
for _, raw := range rawArray {
var f GraphFilter
if err := UnmarshalGraphJSON(raw, &f); err != nil {

View File

@@ -2,7 +2,6 @@ package heartwood
import (
"fmt"
"sort"
)
// ========================================
@@ -33,7 +32,7 @@ type SimpleMatchKeys struct {
}
func (p *SimpleMatchKeys) GetLabels() []string {
labels := []string{}
labels := make([]string, 0, len(p.Keys))
for l := range p.Keys {
labels = append(labels, l)
}
@@ -43,9 +42,8 @@ func (p *SimpleMatchKeys) GetLabels() []string {
func (p *SimpleMatchKeys) GetKeys(label string) ([]string, bool) {
if keys, exists := p.Keys[label]; exists {
return keys, exists
} else {
return nil, exists
}
return nil, false
}
// ========================================
@@ -56,7 +54,7 @@ func (p *SimpleMatchKeys) GetKeys(label string) ([]string, bool) {
// properties.
type Node struct {
// Set of labels on the node.
Labels Set[string]
Labels *StringSet
// Mapping of properties on the node.
Props Properties
}
@@ -67,7 +65,7 @@ func NewNode(label string, props Properties) *Node {
props = make(Properties)
}
return &Node{
Labels: NewSet(label),
Labels: NewStringSet(label),
Props: props,
}
}
@@ -79,8 +77,7 @@ func (n *Node) MatchProps(
// Iterate over each label on the node, checking whether each has match
// keys associated with it.
labels := n.Labels.ToArray()
sort.Strings(labels)
labels := n.Labels.AsSortedArray()
for _, label := range labels {
if keys, exists := matchProvider.GetKeys(label); exists {
props := make(Properties)

View File

@@ -71,7 +71,7 @@ func TestMatchProps(t *testing.T) {
{
name: "multiple labels, one matches",
node: &Node{
Labels: NewSet("Event", "Unknown"),
Labels: NewStringSet("Event", "Unknown"),
Props: Properties{
"id": "abc123",
},

View File

@@ -5,26 +5,6 @@ import (
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
)
// Interface
type GraphDB interface {
MergeSubgraph(ctx context.Context, subgraph *BatchSubgraph) ([]neo4j.ResultSummary, error)
}
func NewGraphDriver(driver neo4j.Driver) GraphDB {
return &graphdb{driver: driver}
}
type graphdb struct {
driver neo4j.Driver
}
func (n *graphdb) MergeSubgraph(ctx context.Context, subgraph *BatchSubgraph) ([]neo4j.ResultSummary, error) {
return MergeSubgraph(ctx, n.driver, subgraph)
}
// Functions
func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver, error) {
driver, err := neo4j.NewDriver(
uri,
@@ -40,3 +20,39 @@ func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver
return driver, nil
}
// SetNeo4jSchema ensures that the necessary indexes and constraints exist in
// the database.
//
// Every statement in schemaQueries carries IF NOT EXISTS and is executed, in
// slice order, through neo4j.ExecuteQuery against the database named "neo4j".
// The first statement that fails stops the loop and its error is returned
// unchanged; nothing in this function undoes statements that already ran.
// Returns nil once every statement has executed without error.
func SetNeo4jSchema(ctx context.Context, driver neo4j.Driver) error {
schemaQueries := []string{
// NOTE(review): the constraint and the first index below share the name
// "user_pubkey" and target the same (n:User).pubkey property — confirm the
// explicit index is still wanted alongside the uniqueness constraint.
`CREATE CONSTRAINT user_pubkey IF NOT EXISTS
FOR (n:User) REQUIRE n.pubkey IS UNIQUE`,
`CREATE INDEX user_pubkey IF NOT EXISTS
FOR (n:User) ON (n.pubkey)`,
`CREATE INDEX event_id IF NOT EXISTS
FOR (n:Event) ON (n.id)`,
`CREATE INDEX event_kind IF NOT EXISTS
FOR (n:Event) ON (n.kind)`,
`CREATE INDEX tag_name_value IF NOT EXISTS
FOR (n:Tag) ON (n.name, n.value)`,
}
// Create indexes and constraints
for _, query := range schemaQueries {
// The statements take no parameters (nil) and their results are discarded;
// only the error is inspected.
_, err := neo4j.ExecuteQuery(ctx, driver,
query,
nil,
neo4j.EagerResultTransformer,
neo4j.ExecuteQueryWithDatabase("neo4j"))
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,9 +1,7 @@
package heartwood
import (
"context"
"fmt"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
)
// ========================================
@@ -80,7 +78,7 @@ func validateNodeLabel(node *Node, role string, expectedLabel string) {
if !node.Labels.Contains(expectedLabel) {
panic(fmt.Errorf(
"expected %s node to have label %q. got %v",
role, expectedLabel, node.Labels.ToArray(),
role, expectedLabel, node.Labels.AsSortedArray(),
))
}
}
@@ -98,43 +96,3 @@ func NewRelationshipWithValidation(
return NewRelationship(rtype, start, end, props)
}
// ========================================
// Schema Indexes and Constraints
// ========================================
// SetNeo4jSchema ensures that the necessary indexes and constraints exist in
// the database
func SetNeo4jSchema(ctx context.Context, driver neo4j.Driver) error {
schemaQueries := []string{
`CREATE CONSTRAINT user_pubkey IF NOT EXISTS
FOR (n:User) REQUIRE n.pubkey IS UNIQUE`,
`CREATE INDEX user_pubkey IF NOT EXISTS
FOR (n:User) ON (n.pubkey)`,
`CREATE INDEX event_id IF NOT EXISTS
FOR (n:Event) ON (n.id)`,
`CREATE INDEX event_kind IF NOT EXISTS
FOR (n:Event) ON (n.kind)`,
`CREATE INDEX tag_name_value IF NOT EXISTS
FOR (n:Tag) ON (n.name, n.value)`,
}
// Create indexes and constraints
for _, query := range schemaQueries {
_, err := neo4j.ExecuteQuery(ctx, driver,
query,
nil,
neo4j.EagerResultTransformer,
neo4j.ExecuteQueryWithDatabase("neo4j"))
if err != nil {
return err
}
}
return nil
}

View File

@@ -37,8 +37,8 @@ func TestNewRelationshipWithValidation(t *testing.T) {
}
rel := NewSignedRel(tc.start, tc.end, nil)
assert.Equal(t, "SIGNED", rel.Type)
assert.Contains(t, rel.Start.Labels.ToArray(), "User")
assert.Contains(t, rel.End.Labels.ToArray(), "Event")
assert.Contains(t, rel.Start.Labels.AsSortedArray(), "User")
assert.Contains(t, rel.End.Labels.AsSortedArray(), "Event")
})
}
}

52
set.go
View File

@@ -1,14 +1,20 @@
package heartwood
import (
"sort"
)
// Sets
type Set[T comparable] struct {
inner map[T]struct{}
type StringSet struct {
inner map[string]struct{}
sorted []string
}
func NewSet[T comparable](items ...T) Set[T] {
set := Set[T]{
inner: make(map[T]struct{}),
func NewStringSet(items ...string) *StringSet {
set := &StringSet{
inner: make(map[string]struct{}),
sorted: []string{},
}
for _, i := range items {
set.Add(i)
@@ -16,20 +22,26 @@ func NewSet[T comparable](items ...T) Set[T] {
return set
}
func (s Set[T]) Add(item T) {
s.inner[item] = struct{}{}
func (s *StringSet) Add(item string) {
if _, exists := s.inner[item]; !exists {
s.inner[item] = struct{}{}
s.rebuildSorted()
}
}
func (s Set[T]) Remove(item T) {
delete(s.inner, item)
func (s *StringSet) Remove(item string) {
if _, exists := s.inner[item]; exists {
delete(s.inner, item)
s.rebuildSorted()
}
}
func (s Set[T]) Contains(item T) bool {
func (s *StringSet) Contains(item string) bool {
_, exists := s.inner[item]
return exists
}
func (s Set[T]) Equal(other Set[T]) bool {
func (s *StringSet) Equal(other StringSet) bool {
if len(s.inner) != len(other.inner) {
return false
}
@@ -41,14 +53,18 @@ func (s Set[T]) Equal(other Set[T]) bool {
return true
}
func (s Set[T]) Length() int {
func (s *StringSet) Length() int {
return len(s.inner)
}
func (s Set[T]) ToArray() []T {
array := []T{}
for i := range s.inner {
array = append(array, i)
}
return array
// AsSortedArray returns the set's cached slice of members (s.sorted), which
// rebuildSorted keeps in sort.Strings order.
//
// The cached slice itself is returned, not a copy, so callers should treat it
// as read-only: writing through it would change what later callers see.
func (s *StringSet) AsSortedArray() []string {
return s.sorted
}
// rebuildSorted regenerates the cached sorted view of the set from the
// membership map. A fresh slice is allocated on every call, filled with the
// current members, ordered with sort.Strings, and stored in s.sorted.
func (s *StringSet) rebuildSorted() {
	members := make([]string, 0, len(s.inner))
	for member := range s.inner {
		members = append(members, member)
	}
	sort.Strings(members)
	s.sorted = members
}

View File

@@ -4,7 +4,7 @@ import (
roots "git.wisehodl.dev/jay/go-roots/events"
)
// Event subgraph struct
// Types
type EventSubgraph struct {
nodes []*Node
@@ -35,7 +35,7 @@ func (s *EventSubgraph) Rels() []*Relationship {
}
func (s *EventSubgraph) NodesByLabel(label string) []*Node {
nodes := []*Node{}
var nodes []*Node
for _, node := range s.nodes {
if node.Labels.Contains(label) {
nodes = append(nodes, node)
@@ -58,7 +58,7 @@ func isValidTag(t roots.Tag) bool {
return true
}
// Event to subgraph conversion
// Event to subgraph pipeline
func EventToSubgraph(e roots.Event, p ExpanderPipeline) *EventSubgraph {
s := NewEventSubgraph()
@@ -89,6 +89,8 @@ func EventToSubgraph(e roots.Event, p ExpanderPipeline) *EventSubgraph {
return s
}
// Core pipeline functions
func newEventNode(eventID string, createdAt int, kind int, content string) *Node {
eventNode := NewEventNode(eventID)
eventNode.Props["created_at"] = createdAt
@@ -106,7 +108,7 @@ func newSignedRel(user, event *Node) *Relationship {
}
func newTagNodes(tags []roots.Tag) []*Node {
nodes := []*Node{}
nodes := make([]*Node, 0, len(tags))
for _, tag := range tags {
if !isValidTag(tag) {
continue
@@ -117,9 +119,84 @@ func newTagNodes(tags []roots.Tag) []*Node {
}
func newTagRels(event *Node, tags []*Node) []*Relationship {
rels := []*Relationship{}
rels := make([]*Relationship, 0, len(tags))
for _, tag := range tags {
rels = append(rels, NewTaggedRel(event, tag, nil))
}
return rels
}
// Expander Pipeline
// Expander is a function that receives an event e together with the subgraph s
// being built for it. It returns nothing, so any contribution it makes is
// through s.
type Expander func(e roots.Event, s *EventSubgraph)
// ExpanderPipeline is a slice of Expanders.
// NOTE(review): presumably applied in slice order by EventToSubgraph — confirm
// against that function.
type ExpanderPipeline []Expander
// NewExpanderPipeline builds an ExpanderPipeline from the given expanders,
// keeping them in argument order. Calling it with no arguments yields an
// empty (nil) pipeline.
func NewExpanderPipeline(expanders ...Expander) ExpanderPipeline {
	var pipeline ExpanderPipeline = expanders
	return pipeline
}
// DefaultExpanders returns a new slice holding the built-in expanders, in
// order: ExpandTaggedEvents, then ExpandTaggedUsers.
func DefaultExpanders() []Expander {
	defaults := make([]Expander, 0, 2)
	defaults = append(defaults, ExpandTaggedEvents, ExpandTaggedUsers)
	return defaults
}
// ExpandTaggedEvents adds, for every valid "e" tag on e whose value matches
// roots.Hex64Pattern, a referenced event node plus a references-event
// relationship from the matching Tag node to it. Tags with no matching Tag
// node already present in s are skipped.
func ExpandTaggedEvents(e roots.Event, s *EventSubgraph) {
	expandTaggedRefs(e, s, "e",
		func(id string) *Node { return NewEventNode(id) },
		func(tag, ref *Node) *Relationship {
			return NewReferencesEventRel(tag, ref, nil)
		},
	)
}

// ExpandTaggedUsers adds, for every valid "p" tag on e whose value matches
// roots.Hex64Pattern, a referenced user node plus a references-user
// relationship from the matching Tag node to it. Tags with no matching Tag
// node already present in s are skipped.
func ExpandTaggedUsers(e roots.Event, s *EventSubgraph) {
	expandTaggedRefs(e, s, "p",
		func(id string) *Node { return NewUserNode(id) },
		func(tag, ref *Node) *Relationship {
			return NewReferencesUserRel(tag, ref, nil)
		},
	)
}

// expandTaggedRefs is the shared body of the tag expanders. For each tag on e
// that passes isValidTag, is named tagName, and has a value matching
// roots.Hex64Pattern, it looks up the corresponding Tag node in s and, when
// one exists, adds newNode(value) to s followed by newRel(tagNode, node).
func expandTaggedRefs(
	e roots.Event,
	s *EventSubgraph,
	tagName string,
	newNode func(id string) *Node,
	newRel func(tag, ref *Node) *Relationship,
) {
	// Snapshot the Tag nodes once; the nodes added below are never Tag lookups.
	tagNodes := s.NodesByLabel("Tag")

	for _, tag := range e.Tags {
		if !isValidTag(tag) {
			continue
		}

		name, value := tag[0], tag[1]
		if name != tagName || !roots.Hex64Pattern.MatchString(value) {
			continue
		}

		tagNode := findTagNode(tagNodes, name, value)
		if tagNode == nil {
			continue
		}

		ref := newNode(value)
		s.AddNode(ref)
		s.AddRel(newRel(tagNode, ref))
	}
}
// findTagNode scans nodes in order and returns the first one whose "name" and
// "value" properties equal the given name and value. It returns nil when no
// node matches.
func findTagNode(nodes []*Node, name, value string) *Node {
	for i := range nodes {
		candidate := nodes[i]
		if candidate.Props["name"] != name || candidate.Props["value"] != value {
			continue
		}
		return candidate
	}
	return nil
}

View File

@@ -172,7 +172,7 @@ func nodesEqual(expected, got *Node) error {
}
// Compare label values
for _, label := range expected.Labels.ToArray() {
for _, label := range expected.Labels.AsSortedArray() {
if !got.Labels.Contains(label) {
return fmt.Errorf("missing label %q", label)
}

192
write.go
View File

@@ -5,14 +5,15 @@ import (
"encoding/json"
"fmt"
roots "git.wisehodl.dev/jay/go-roots/events"
"github.com/boltdb/bolt"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
"sync"
"time"
)
type WriteOptions struct {
Expanders ExpanderPipeline
KVReadBatchSize int
Expanders ExpanderPipeline
BoltReadBatchSize int
}
type EventFollower struct {
@@ -39,7 +40,7 @@ type WriteReport struct {
func WriteEvents(
events []string,
graphdb GraphDB, boltdb BoltDB,
driver neo4j.Driver, boltdb *bolt.DB,
opts *WriteOptions,
) (WriteReport, error) {
start := time.Now()
@@ -50,7 +51,7 @@ func WriteEvents(
setDefaultWriteOptions(opts)
err := boltdb.Setup()
err := SetupBoltDB(boltdb)
if err != nil {
return WriteReport{}, fmt.Errorf("error setting up bolt db: %w", err)
}
@@ -58,80 +59,55 @@ func WriteEvents(
var wg sync.WaitGroup
// Create Event Followers
jsonChan := make(chan string, 10)
eventChan := make(chan EventFollower, 10)
jsonChan := make(chan string)
eventChan := make(chan EventFollower)
wg.Add(1)
go func() {
defer wg.Done()
createEventFollowers(jsonChan, eventChan)
}()
go createEventFollowers(&wg, jsonChan, eventChan)
// Parse Event JSON
parsedChan := make(chan EventFollower, 10)
invalidChan := make(chan EventFollower, 10)
parsedChan := make(chan EventFollower)
invalidChan := make(chan EventFollower)
wg.Add(1)
go func() {
defer wg.Done()
parseEventJSON(eventChan, parsedChan, invalidChan)
}()
go parseEventJSON(&wg, eventChan, parsedChan, invalidChan)
// Collect Invalid Events
collectedInvalidChan := make(chan []EventFollower)
wg.Add(1)
go func() {
defer wg.Done()
collectEvents(invalidChan, collectedInvalidChan)
}()
go collectEvents(&wg, invalidChan, collectedInvalidChan)
// Enforce Policy Rules
queuedChan := make(chan EventFollower, 10)
skippedChan := make(chan EventFollower, 10)
queuedChan := make(chan EventFollower)
skippedChan := make(chan EventFollower)
wg.Add(1)
go func() {
defer wg.Done()
enforcePolicyRules(
graphdb, boltdb,
opts.KVReadBatchSize,
parsedChan, queuedChan, skippedChan)
}()
go enforcePolicyRules(&wg, driver, boltdb, opts.BoltReadBatchSize,
parsedChan, queuedChan, skippedChan)
// Collect Skipped Events
collectedSkippedChan := make(chan []EventFollower)
wg.Add(1)
go func() {
defer wg.Done()
collectEvents(skippedChan, collectedSkippedChan)
}()
go collectEvents(&wg, skippedChan, collectedSkippedChan)
// Convert Events To Subgraphs
convertedChan := make(chan EventFollower, 10)
convertedChan := make(chan EventFollower)
wg.Add(1)
go func() {
defer wg.Done()
convertEventsToSubgraphs(opts.Expanders, queuedChan, convertedChan)
}()
go convertEventsToSubgraphs(&wg, opts.Expanders, queuedChan, convertedChan)
// Write Events To Databases
writeResultChan := make(chan WriteResult)
wg.Add(1)
go func() {
defer wg.Done()
writeEventsToDatabases(
graphdb, boltdb,
convertedChan, writeResultChan)
}()
go writeEventsToDatabases(&wg, driver, boltdb, convertedChan, writeResultChan)
// Send event jsons into pipeline
go func() {
for _, json := range events {
jsonChan <- json
for _, raw := range events {
jsonChan <- raw
}
close(jsonChan)
}()
@@ -158,19 +134,21 @@ func setDefaultWriteOptions(opts *WriteOptions) {
if opts.Expanders == nil {
opts.Expanders = NewExpanderPipeline(DefaultExpanders()...)
}
if opts.KVReadBatchSize == 0 {
opts.KVReadBatchSize = 100
if opts.BoltReadBatchSize == 0 {
opts.BoltReadBatchSize = 100
}
}
func createEventFollowers(jsonChan chan string, eventChan chan EventFollower) {
func createEventFollowers(wg *sync.WaitGroup, jsonChan chan string, eventChan chan EventFollower) {
defer wg.Done()
for json := range jsonChan {
eventChan <- EventFollower{JSON: json}
}
close(eventChan)
}
func parseEventJSON(inChan, parsedChan, invalidChan chan EventFollower) {
func parseEventJSON(wg *sync.WaitGroup, inChan, parsedChan, invalidChan chan EventFollower) {
defer wg.Done()
for follower := range inChan {
var event roots.Event
jsonBytes := []byte(follower.JSON)
@@ -191,11 +169,13 @@ func parseEventJSON(inChan, parsedChan, invalidChan chan EventFollower) {
}
func enforcePolicyRules(
graphdb GraphDB, boltdb BoltDB,
wg *sync.WaitGroup,
driver neo4j.Driver, boltdb *bolt.DB,
batchSize int,
inChan, queuedChan, skippedChan chan EventFollower,
) {
batch := []EventFollower{}
defer wg.Done()
var batch []EventFollower
for follower := range inChan {
batch = append(batch, follower)
@@ -215,17 +195,17 @@ func enforcePolicyRules(
}
func processPolicyRulesBatch(
boltdb BoltDB,
boltdb *bolt.DB,
batch []EventFollower,
queuedChan, skippedChan chan EventFollower,
) {
eventIDs := []string{}
eventIDs := make([]string, 0, len(batch))
for _, follower := range batch {
eventIDs = append(eventIDs, follower.ID)
}
existsMap := boltdb.BatchCheckEventsExist(eventIDs)
existsMap := BatchCheckEventsExist(boltdb, eventIDs)
for _, follower := range batch {
if existsMap[follower.ID] {
@@ -237,9 +217,10 @@ func processPolicyRulesBatch(
}
func convertEventsToSubgraphs(
expanders ExpanderPipeline,
wg *sync.WaitGroup, expanders ExpanderPipeline,
inChan, convertedChan chan EventFollower,
) {
defer wg.Done()
for follower := range inChan {
subgraph := EventToSubgraph(follower.Event, expanders)
follower.Subgraph = subgraph
@@ -249,93 +230,66 @@ func convertEventsToSubgraphs(
}
func writeEventsToDatabases(
graphdb GraphDB, boltdb BoltDB,
wg *sync.WaitGroup,
driver neo4j.Driver, boltdb *bolt.DB,
inChan chan EventFollower,
resultChan chan WriteResult,
) {
var wg sync.WaitGroup
defer wg.Done()
var localWg sync.WaitGroup
kvEventChan := make(chan EventFollower, 10)
graphEventChan := make(chan EventFollower, 10)
boltEventChan := make(chan EventFollower)
graphEventChan := make(chan EventFollower)
kvWriteDone := make(chan struct{})
kvErrorChan := make(chan error)
boltErrorChan := make(chan error)
graphResultChan := make(chan WriteResult)
wg.Add(2)
go func() {
defer wg.Done()
writeEventsToKVStore(
boltdb,
kvEventChan, kvWriteDone, kvErrorChan)
}()
go func() {
defer wg.Done()
writeEventsToGraphDriver(
graphdb,
graphEventChan, kvWriteDone, graphResultChan)
}()
localWg.Add(2)
go writeEventsToBoltDB(&localWg, boltdb, boltEventChan, boltErrorChan)
go writeEventsToGraphDB(&localWg, driver, graphEventChan, boltErrorChan, graphResultChan)
// Fan out events to both writers
for follower := range inChan {
kvEventChan <- follower
boltEventChan <- follower
graphEventChan <- follower
}
close(kvEventChan)
close(boltEventChan)
close(graphEventChan)
wg.Wait()
localWg.Wait()
kvError := <-kvErrorChan
graphResult := <-graphResultChan
var finalErr error
if kvError != nil && graphResult.Error != nil {
finalErr = fmt.Errorf("kvstore: %w; graphstore: %v", kvError, graphResult.Error)
} else if kvError != nil {
finalErr = fmt.Errorf("kvstore: %w", kvError)
} else if graphResult.Error != nil {
finalErr = fmt.Errorf("graphstore: %w", graphResult.Error)
}
resultChan <- WriteResult{
ResultSummaries: graphResult.ResultSummaries,
Error: finalErr,
}
resultChan <- graphResult
}
func writeEventsToKVStore(
boltdb BoltDB,
func writeEventsToBoltDB(
wg *sync.WaitGroup,
boltdb *bolt.DB,
inChan chan EventFollower,
done chan struct{},
resultChan chan error,
errorChan chan error,
) {
events := []EventBlob{}
defer wg.Done()
var events []EventBlob
for follower := range inChan {
events = append(events,
EventBlob{ID: follower.ID, JSON: follower.JSON})
}
err := boltdb.BatchWriteEvents(events)
if err != nil {
close(done)
} else {
done <- struct{}{}
close(done)
}
err := BatchWriteEvents(boltdb, events)
resultChan <- err
close(resultChan)
errorChan <- err
close(errorChan)
}
func writeEventsToGraphDriver(
graphdb GraphDB,
func writeEventsToGraphDB(
wg *sync.WaitGroup,
driver neo4j.Driver,
inChan chan EventFollower,
start chan struct{},
boltErrorChan chan error,
resultChan chan WriteResult,
) {
defer wg.Done()
matchKeys := NewSimpleMatchKeys()
batch := NewBatchSubgraph(matchKeys)
@@ -348,14 +302,17 @@ func writeEventsToGraphDriver(
}
}
_, ok := <-start
if !ok {
resultChan <- WriteResult{Error: fmt.Errorf("kv write failed, aborting graph write")}
boltErr := <-boltErrorChan
if boltErr != nil {
resultChan <- WriteResult{
Error: fmt.Errorf(
"boltdb write failed, aborting graph write: %w", boltErr,
)}
close(resultChan)
return
}
summaries, err := graphdb.MergeSubgraph(context.Background(), batch)
summaries, err := MergeSubgraph(context.Background(), driver, batch)
resultChan <- WriteResult{
ResultSummaries: summaries,
Error: err,
@@ -363,8 +320,9 @@ func writeEventsToGraphDriver(
close(resultChan)
}
func collectEvents(inChan chan EventFollower, resultChan chan []EventFollower) {
collected := []EventFollower{}
func collectEvents(wg *sync.WaitGroup, inChan chan EventFollower, resultChan chan []EventFollower) {
defer wg.Done()
var collected []EventFollower
for follower := range inChan {
collected = append(collected, follower)
}