Variety of refactors and optimizations.
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,15 +44,11 @@ func (s *BatchSubgraph) AddNode(node *Node) error {
|
|||||||
// Verify that the node has defined match property values.
|
// Verify that the node has defined match property values.
|
||||||
matchLabel, _, err := node.MatchProps(s.matchProvider)
|
matchLabel, _, err := node.MatchProps(s.matchProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid node: %s", err)
|
return fmt.Errorf("invalid node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine the node's batch key.
|
// Determine the node's batch key.
|
||||||
batchKey := createNodeBatchKey(matchLabel, node.Labels.ToArray())
|
batchKey := createNodeBatchKey(matchLabel, node.Labels.AsSortedArray())
|
||||||
|
|
||||||
if _, exists := s.nodes[batchKey]; !exists {
|
|
||||||
s.nodes[batchKey] = []*Node{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the node to the sub
|
// Add the node to the sub
|
||||||
s.nodes[batchKey] = append(s.nodes[batchKey], node)
|
s.nodes[batchKey] = append(s.nodes[batchKey], node)
|
||||||
@@ -66,22 +61,18 @@ func (s *BatchSubgraph) AddRel(rel *Relationship) error {
|
|||||||
// Verify that the start node has defined match property values.
|
// Verify that the start node has defined match property values.
|
||||||
startLabel, _, err := rel.Start.MatchProps(s.matchProvider)
|
startLabel, _, err := rel.Start.MatchProps(s.matchProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid start node: %s", err)
|
return fmt.Errorf("invalid start node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the end node has defined match property values.
|
// Verify that the end node has defined match property values.
|
||||||
endLabel, _, err := rel.End.MatchProps(s.matchProvider)
|
endLabel, _, err := rel.End.MatchProps(s.matchProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid end node: %s", err)
|
return fmt.Errorf("invalid end node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine the relationship's batch key.
|
// Determine the relationship's batch key.
|
||||||
batchKey := createRelBatchKey(rel.Type, startLabel, endLabel)
|
batchKey := createRelBatchKey(rel.Type, startLabel, endLabel)
|
||||||
|
|
||||||
if _, exists := s.rels[batchKey]; !exists {
|
|
||||||
s.rels[batchKey] = []*Relationship{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the relationship to the sub
|
// Add the relationship to the sub
|
||||||
s.rels[batchKey] = append(s.rels[batchKey], rel)
|
s.rels[batchKey] = append(s.rels[batchKey], rel)
|
||||||
|
|
||||||
@@ -105,7 +96,7 @@ func (s *BatchSubgraph) RelCount() int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *BatchSubgraph) nodeKeys() []string {
|
func (s *BatchSubgraph) nodeKeys() []string {
|
||||||
keys := []string{}
|
keys := make([]string, 0, len(s.nodes))
|
||||||
for l := range s.nodes {
|
for l := range s.nodes {
|
||||||
keys = append(keys, l)
|
keys = append(keys, l)
|
||||||
}
|
}
|
||||||
@@ -113,7 +104,7 @@ func (s *BatchSubgraph) nodeKeys() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *BatchSubgraph) relKeys() []string {
|
func (s *BatchSubgraph) relKeys() []string {
|
||||||
keys := []string{}
|
keys := make([]string, 0, len(s.rels))
|
||||||
for t := range s.rels {
|
for t := range s.rels {
|
||||||
keys = append(keys, t)
|
keys = append(keys, t)
|
||||||
}
|
}
|
||||||
@@ -121,7 +112,7 @@ func (s *BatchSubgraph) relKeys() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) {
|
func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) {
|
||||||
batches := []NodeBatch{}
|
batches := make([]NodeBatch, 0, len(s.nodeKeys()))
|
||||||
|
|
||||||
for _, nodeKey := range s.nodeKeys() {
|
for _, nodeKey := range s.nodeKeys() {
|
||||||
matchLabel, labels, err := deserializeNodeBatchKey(nodeKey)
|
matchLabel, labels, err := deserializeNodeBatchKey(nodeKey)
|
||||||
@@ -146,7 +137,7 @@ func (s *BatchSubgraph) NodeBatches() ([]NodeBatch, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) {
|
func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) {
|
||||||
batches := []RelBatch{}
|
batches := make([]RelBatch, 0, len(s.relKeys()))
|
||||||
|
|
||||||
for _, relKey := range s.relKeys() {
|
for _, relKey := range s.relKeys() {
|
||||||
rtype, startLabel, endLabel, err := deserializeRelBatchKey(relKey)
|
rtype, startLabel, endLabel, err := deserializeRelBatchKey(relKey)
|
||||||
@@ -181,9 +172,8 @@ func (s *BatchSubgraph) RelBatches() ([]RelBatch, error) {
|
|||||||
|
|
||||||
// Helpers
|
// Helpers
|
||||||
|
|
||||||
func createNodeBatchKey(matchLabel string, labels []string) string {
|
func createNodeBatchKey(matchLabel string, sortedLabels []string) string {
|
||||||
sort.Strings(labels)
|
serializedLabels := strings.Join(sortedLabels, ",")
|
||||||
serializedLabels := strings.Join(labels, ",")
|
|
||||||
return fmt.Sprintf("%s:%s", matchLabel, serializedLabels)
|
return fmt.Sprintf("%s:%s", matchLabel, serializedLabels)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,7 +235,7 @@ func MergeSubgraph(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
if nodeResultSummary != nil {
|
if nodeResultSummary != nil {
|
||||||
resultSummaries = append(resultSummaries, *nodeResultSummary)
|
resultSummaries = append(resultSummaries, nodeResultSummary)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,7 +251,7 @@ func MergeSubgraph(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
if relResultSummary != nil {
|
if relResultSummary != nil {
|
||||||
resultSummaries = append(resultSummaries, *relResultSummary)
|
resultSummaries = append(resultSummaries, relResultSummary)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,11 +274,11 @@ func MergeNodes(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
tx neo4j.ManagedTransaction,
|
tx neo4j.ManagedTransaction,
|
||||||
batch NodeBatch,
|
batch NodeBatch,
|
||||||
) (*neo4j.ResultSummary, error) {
|
) (neo4j.ResultSummary, error) {
|
||||||
cypherLabels := ToCypherLabels(batch.Labels)
|
cypherLabels := ToCypherLabels(batch.Labels)
|
||||||
cypherProps := ToCypherProps(batch.MatchKeys, "node.")
|
cypherProps := ToCypherProps(batch.MatchKeys, "node.")
|
||||||
|
|
||||||
serializedNodes := []*SerializedNode{}
|
serializedNodes := make([]*SerializedNode, 0, len(batch.Nodes))
|
||||||
for _, node := range batch.Nodes {
|
for _, node := range batch.Nodes {
|
||||||
serializedNodes = append(serializedNodes, node.Serialize())
|
serializedNodes = append(serializedNodes, node.Serialize())
|
||||||
}
|
}
|
||||||
@@ -316,21 +306,21 @@ func MergeNodes(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &summary, nil
|
return summary, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MergeRels(
|
func MergeRels(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
tx neo4j.ManagedTransaction,
|
tx neo4j.ManagedTransaction,
|
||||||
batch RelBatch,
|
batch RelBatch,
|
||||||
) (*neo4j.ResultSummary, error) {
|
) (neo4j.ResultSummary, error) {
|
||||||
cypherType := ToCypherLabel(batch.Type)
|
cypherType := ToCypherLabel(batch.Type)
|
||||||
startCypherLabel := ToCypherLabel(batch.StartLabel)
|
startCypherLabel := ToCypherLabel(batch.StartLabel)
|
||||||
endCypherLabel := ToCypherLabel(batch.EndLabel)
|
endCypherLabel := ToCypherLabel(batch.EndLabel)
|
||||||
startCypherProps := ToCypherProps(batch.StartMatchKeys, "rel.start.")
|
startCypherProps := ToCypherProps(batch.StartMatchKeys, "rel.start.")
|
||||||
endCypherProps := ToCypherProps(batch.EndMatchKeys, "rel.end.")
|
endCypherProps := ToCypherProps(batch.EndMatchKeys, "rel.end.")
|
||||||
|
|
||||||
serializedRels := []*SerializedRel{}
|
serializedRels := make([]*SerializedRel, 0, len(batch.Rels))
|
||||||
for _, rel := range batch.Rels {
|
for _, rel := range batch.Rels {
|
||||||
serializedRels = append(serializedRels, rel.Serialize())
|
serializedRels = append(serializedRels, rel.Serialize())
|
||||||
}
|
}
|
||||||
@@ -363,5 +353,5 @@ func MergeRels(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &summary, nil
|
return summary, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,20 +7,19 @@ import (
|
|||||||
|
|
||||||
func TestNodeBatchKey(t *testing.T) {
|
func TestNodeBatchKey(t *testing.T) {
|
||||||
matchLabel := "Event"
|
matchLabel := "Event"
|
||||||
labels := []string{"Event", "AddressableEvent"}
|
sortedLabels := []string{"AddressableEvent", "Event"}
|
||||||
|
|
||||||
// labels should be batched by key generator
|
|
||||||
expectedKey := "Event:AddressableEvent,Event"
|
expectedKey := "Event:AddressableEvent,Event"
|
||||||
|
|
||||||
// Test Serialization
|
// Test Serialization
|
||||||
batchKey := createNodeBatchKey(matchLabel, labels)
|
// labels are expected to be pre-sorted
|
||||||
|
batchKey := createNodeBatchKey(matchLabel, sortedLabels)
|
||||||
assert.Equal(t, expectedKey, batchKey)
|
assert.Equal(t, expectedKey, batchKey)
|
||||||
|
|
||||||
// Test Deserialization
|
// Test Deserialization
|
||||||
returnedMatchLabel, returnedLabels, err := deserializeNodeBatchKey(batchKey)
|
returnedMatchLabel, returnedLabels, err := deserializeNodeBatchKey(batchKey)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, matchLabel, returnedMatchLabel)
|
assert.Equal(t, matchLabel, returnedMatchLabel)
|
||||||
assert.ElementsMatch(t, labels, returnedLabels)
|
assert.ElementsMatch(t, sortedLabels, returnedLabels)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRelBatchKey(t *testing.T) {
|
func TestRelBatchKey(t *testing.T) {
|
||||||
|
|||||||
40
boltdb.go
40
boltdb.go
@@ -4,32 +4,11 @@ import (
|
|||||||
"github.com/boltdb/bolt"
|
"github.com/boltdb/bolt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Interface
|
const BucketName string = "events"
|
||||||
|
|
||||||
type BoltDB interface {
|
type EventBlob struct {
|
||||||
Setup() error
|
ID string
|
||||||
BatchCheckEventsExist(eventIDs []string) map[string]bool
|
JSON string
|
||||||
BatchWriteEvents(events []EventBlob) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewKVDB(boltdb *bolt.DB) BoltDB {
|
|
||||||
return &boltDB{db: boltdb}
|
|
||||||
}
|
|
||||||
|
|
||||||
type boltDB struct {
|
|
||||||
db *bolt.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *boltDB) Setup() error {
|
|
||||||
return SetupBoltDB(b.db)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *boltDB) BatchCheckEventsExist(eventIDs []string) map[string]bool {
|
|
||||||
return BatchCheckEventsExist(b.db, eventIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *boltDB) BatchWriteEvents(events []EventBlob) error {
|
|
||||||
return BatchWriteEvents(b.db, events)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetupBoltDB(boltdb *bolt.DB) error {
|
func SetupBoltDB(boltdb *bolt.DB) error {
|
||||||
@@ -39,19 +18,10 @@ func SetupBoltDB(boltdb *bolt.DB) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Functions
|
|
||||||
|
|
||||||
const BucketName string = "events"
|
|
||||||
|
|
||||||
type EventBlob struct {
|
|
||||||
ID string
|
|
||||||
JSON string
|
|
||||||
}
|
|
||||||
|
|
||||||
func BatchCheckEventsExist(boltdb *bolt.DB, eventIDs []string) map[string]bool {
|
func BatchCheckEventsExist(boltdb *bolt.DB, eventIDs []string) map[string]bool {
|
||||||
existsMap := make(map[string]bool)
|
existsMap := make(map[string]bool)
|
||||||
|
|
||||||
boltdb.View(func(tx *bolt.Tx) error {
|
_ = boltdb.View(func(tx *bolt.Tx) error {
|
||||||
bucket := tx.Bucket([]byte(BucketName))
|
bucket := tx.Bucket([]byte(BucketName))
|
||||||
if bucket == nil {
|
if bucket == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -30,7 +30,8 @@ func ToCypherProps(keys []string, prefix string) string {
|
|||||||
if prefix == "" {
|
if prefix == "" {
|
||||||
prefix = "$"
|
prefix = "$"
|
||||||
}
|
}
|
||||||
cypherPropsParts := []string{}
|
|
||||||
|
var cypherPropsParts []string
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
cypherPropsParts = append(
|
cypherPropsParts = append(
|
||||||
cypherPropsParts, fmt.Sprintf("%s: %s%s", key, prefix, key))
|
cypherPropsParts, fmt.Sprintf("%s: %s%s", key, prefix, key))
|
||||||
|
|||||||
82
expanders.go
82
expanders.go
@@ -1,82 +0,0 @@
|
|||||||
package heartwood
|
|
||||||
|
|
||||||
import (
|
|
||||||
roots "git.wisehodl.dev/jay/go-roots/events"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Expander func(e roots.Event, s *EventSubgraph)
|
|
||||||
type ExpanderPipeline []Expander
|
|
||||||
|
|
||||||
func NewExpanderPipeline(expanders ...Expander) ExpanderPipeline {
|
|
||||||
return ExpanderPipeline(expanders)
|
|
||||||
}
|
|
||||||
|
|
||||||
func DefaultExpanders() []Expander {
|
|
||||||
return []Expander{
|
|
||||||
ExpandTaggedEvents,
|
|
||||||
ExpandTaggedUsers,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default Expander Functions
|
|
||||||
|
|
||||||
func ExpandTaggedEvents(e roots.Event, s *EventSubgraph) {
|
|
||||||
tagNodes := s.NodesByLabel("Tag")
|
|
||||||
for _, tag := range e.Tags {
|
|
||||||
if !isValidTag(tag) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name := tag[0]
|
|
||||||
value := tag[1]
|
|
||||||
|
|
||||||
if name != "e" || !roots.Hex64Pattern.MatchString(value) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
tagNode := findTagNode(tagNodes, name, value)
|
|
||||||
if tagNode == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
referencedEvent := NewEventNode(value)
|
|
||||||
|
|
||||||
s.AddNode(referencedEvent)
|
|
||||||
s.AddRel(NewReferencesEventRel(tagNode, referencedEvent, nil))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExpandTaggedUsers(e roots.Event, s *EventSubgraph) {
|
|
||||||
tagNodes := s.NodesByLabel("Tag")
|
|
||||||
for _, tag := range e.Tags {
|
|
||||||
if !isValidTag(tag) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name := tag[0]
|
|
||||||
value := tag[1]
|
|
||||||
|
|
||||||
if name != "p" || !roots.Hex64Pattern.MatchString(value) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
tagNode := findTagNode(tagNodes, name, value)
|
|
||||||
if tagNode == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
referencedEvent := NewUserNode(value)
|
|
||||||
|
|
||||||
s.AddNode(referencedEvent)
|
|
||||||
s.AddRel(NewReferencesUserRel(tagNode, referencedEvent, nil))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helpers
|
|
||||||
|
|
||||||
func findTagNode(nodes []*Node, name, value string) *Node {
|
|
||||||
for _, node := range nodes {
|
|
||||||
if node.Props["name"] == name && node.Props["value"] == value {
|
|
||||||
return node
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -252,7 +252,7 @@ func UnmarshalGraphJSON(data []byte, f *GraphFilter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func marshalGraphArray(filters []GraphFilter) ([]json.RawMessage, error) {
|
func marshalGraphArray(filters []GraphFilter) ([]json.RawMessage, error) {
|
||||||
result := []json.RawMessage{}
|
result := make([]json.RawMessage, 0, len(filters))
|
||||||
for _, f := range filters {
|
for _, f := range filters {
|
||||||
b, err := MarshalGraphJSON(f)
|
b, err := MarshalGraphJSON(f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -268,7 +268,7 @@ func unmarshalGraphArray(raws json.RawMessage) ([]GraphFilter, error) {
|
|||||||
if err := json.Unmarshal(raws, &rawArray); err != nil {
|
if err := json.Unmarshal(raws, &rawArray); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var result []GraphFilter
|
result := make([]GraphFilter, 0, len(rawArray))
|
||||||
for _, raw := range rawArray {
|
for _, raw := range rawArray {
|
||||||
var f GraphFilter
|
var f GraphFilter
|
||||||
if err := UnmarshalGraphJSON(raw, &f); err != nil {
|
if err := UnmarshalGraphJSON(raw, &f); err != nil {
|
||||||
|
|||||||
13
graph.go
13
graph.go
@@ -2,7 +2,6 @@ package heartwood
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ========================================
|
// ========================================
|
||||||
@@ -33,7 +32,7 @@ type SimpleMatchKeys struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *SimpleMatchKeys) GetLabels() []string {
|
func (p *SimpleMatchKeys) GetLabels() []string {
|
||||||
labels := []string{}
|
labels := make([]string, 0, len(p.Keys))
|
||||||
for l := range p.Keys {
|
for l := range p.Keys {
|
||||||
labels = append(labels, l)
|
labels = append(labels, l)
|
||||||
}
|
}
|
||||||
@@ -43,9 +42,8 @@ func (p *SimpleMatchKeys) GetLabels() []string {
|
|||||||
func (p *SimpleMatchKeys) GetKeys(label string) ([]string, bool) {
|
func (p *SimpleMatchKeys) GetKeys(label string) ([]string, bool) {
|
||||||
if keys, exists := p.Keys[label]; exists {
|
if keys, exists := p.Keys[label]; exists {
|
||||||
return keys, exists
|
return keys, exists
|
||||||
} else {
|
|
||||||
return nil, exists
|
|
||||||
}
|
}
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ========================================
|
// ========================================
|
||||||
@@ -56,7 +54,7 @@ func (p *SimpleMatchKeys) GetKeys(label string) ([]string, bool) {
|
|||||||
// properties.
|
// properties.
|
||||||
type Node struct {
|
type Node struct {
|
||||||
// Set of labels on the node.
|
// Set of labels on the node.
|
||||||
Labels Set[string]
|
Labels *StringSet
|
||||||
// Mapping of properties on the node.
|
// Mapping of properties on the node.
|
||||||
Props Properties
|
Props Properties
|
||||||
}
|
}
|
||||||
@@ -67,7 +65,7 @@ func NewNode(label string, props Properties) *Node {
|
|||||||
props = make(Properties)
|
props = make(Properties)
|
||||||
}
|
}
|
||||||
return &Node{
|
return &Node{
|
||||||
Labels: NewSet(label),
|
Labels: NewStringSet(label),
|
||||||
Props: props,
|
Props: props,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -79,8 +77,7 @@ func (n *Node) MatchProps(
|
|||||||
|
|
||||||
// Iterate over each label on the node, checking whether each has match
|
// Iterate over each label on the node, checking whether each has match
|
||||||
// keys associated with it.
|
// keys associated with it.
|
||||||
labels := n.Labels.ToArray()
|
labels := n.Labels.AsSortedArray()
|
||||||
sort.Strings(labels)
|
|
||||||
for _, label := range labels {
|
for _, label := range labels {
|
||||||
if keys, exists := matchProvider.GetKeys(label); exists {
|
if keys, exists := matchProvider.GetKeys(label); exists {
|
||||||
props := make(Properties)
|
props := make(Properties)
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func TestMatchProps(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "multiple labels, one matches",
|
name: "multiple labels, one matches",
|
||||||
node: &Node{
|
node: &Node{
|
||||||
Labels: NewSet("Event", "Unknown"),
|
Labels: NewStringSet("Event", "Unknown"),
|
||||||
Props: Properties{
|
Props: Properties{
|
||||||
"id": "abc123",
|
"id": "abc123",
|
||||||
},
|
},
|
||||||
|
|||||||
56
neo4j.go
56
neo4j.go
@@ -5,26 +5,6 @@ import (
|
|||||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Interface
|
|
||||||
|
|
||||||
type GraphDB interface {
|
|
||||||
MergeSubgraph(ctx context.Context, subgraph *BatchSubgraph) ([]neo4j.ResultSummary, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGraphDriver(driver neo4j.Driver) GraphDB {
|
|
||||||
return &graphdb{driver: driver}
|
|
||||||
}
|
|
||||||
|
|
||||||
type graphdb struct {
|
|
||||||
driver neo4j.Driver
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *graphdb) MergeSubgraph(ctx context.Context, subgraph *BatchSubgraph) ([]neo4j.ResultSummary, error) {
|
|
||||||
return MergeSubgraph(ctx, n.driver, subgraph)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Functions
|
|
||||||
|
|
||||||
func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver, error) {
|
func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver, error) {
|
||||||
driver, err := neo4j.NewDriver(
|
driver, err := neo4j.NewDriver(
|
||||||
uri,
|
uri,
|
||||||
@@ -40,3 +20,39 @@ func ConnectNeo4j(ctx context.Context, uri, user, password string) (neo4j.Driver
|
|||||||
|
|
||||||
return driver, nil
|
return driver, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNeo4jSchema ensures that the necessary indexes and constraints exist in
|
||||||
|
// the database
|
||||||
|
func SetNeo4jSchema(ctx context.Context, driver neo4j.Driver) error {
|
||||||
|
schemaQueries := []string{
|
||||||
|
`CREATE CONSTRAINT user_pubkey IF NOT EXISTS
|
||||||
|
FOR (n:User) REQUIRE n.pubkey IS UNIQUE`,
|
||||||
|
|
||||||
|
`CREATE INDEX user_pubkey IF NOT EXISTS
|
||||||
|
FOR (n:User) ON (n.pubkey)`,
|
||||||
|
|
||||||
|
`CREATE INDEX event_id IF NOT EXISTS
|
||||||
|
FOR (n:Event) ON (n.id)`,
|
||||||
|
|
||||||
|
`CREATE INDEX event_kind IF NOT EXISTS
|
||||||
|
FOR (n:Event) ON (n.kind)`,
|
||||||
|
|
||||||
|
`CREATE INDEX tag_name_value IF NOT EXISTS
|
||||||
|
FOR (n:Tag) ON (n.name, n.value)`,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create indexes and constraints
|
||||||
|
for _, query := range schemaQueries {
|
||||||
|
_, err := neo4j.ExecuteQuery(ctx, driver,
|
||||||
|
query,
|
||||||
|
nil,
|
||||||
|
neo4j.EagerResultTransformer,
|
||||||
|
neo4j.ExecuteQueryWithDatabase("neo4j"))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
44
schema.go
44
schema.go
@@ -1,9 +1,7 @@
|
|||||||
package heartwood
|
package heartwood
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ========================================
|
// ========================================
|
||||||
@@ -80,7 +78,7 @@ func validateNodeLabel(node *Node, role string, expectedLabel string) {
|
|||||||
if !node.Labels.Contains(expectedLabel) {
|
if !node.Labels.Contains(expectedLabel) {
|
||||||
panic(fmt.Errorf(
|
panic(fmt.Errorf(
|
||||||
"expected %s node to have label %q. got %v",
|
"expected %s node to have label %q. got %v",
|
||||||
role, expectedLabel, node.Labels.ToArray(),
|
role, expectedLabel, node.Labels.AsSortedArray(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -98,43 +96,3 @@ func NewRelationshipWithValidation(
|
|||||||
|
|
||||||
return NewRelationship(rtype, start, end, props)
|
return NewRelationship(rtype, start, end, props)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ========================================
|
|
||||||
// Schema Indexes and Constraints
|
|
||||||
// ========================================
|
|
||||||
|
|
||||||
// SetNeo4jSchema ensures that the necessary indexes and constraints exist in
|
|
||||||
// the database
|
|
||||||
func SetNeo4jSchema(ctx context.Context, driver neo4j.Driver) error {
|
|
||||||
schemaQueries := []string{
|
|
||||||
`CREATE CONSTRAINT user_pubkey IF NOT EXISTS
|
|
||||||
FOR (n:User) REQUIRE n.pubkey IS UNIQUE`,
|
|
||||||
|
|
||||||
`CREATE INDEX user_pubkey IF NOT EXISTS
|
|
||||||
FOR (n:User) ON (n.pubkey)`,
|
|
||||||
|
|
||||||
`CREATE INDEX event_id IF NOT EXISTS
|
|
||||||
FOR (n:Event) ON (n.id)`,
|
|
||||||
|
|
||||||
`CREATE INDEX event_kind IF NOT EXISTS
|
|
||||||
FOR (n:Event) ON (n.kind)`,
|
|
||||||
|
|
||||||
`CREATE INDEX tag_name_value IF NOT EXISTS
|
|
||||||
FOR (n:Tag) ON (n.name, n.value)`,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create indexes and constraints
|
|
||||||
for _, query := range schemaQueries {
|
|
||||||
_, err := neo4j.ExecuteQuery(ctx, driver,
|
|
||||||
query,
|
|
||||||
nil,
|
|
||||||
neo4j.EagerResultTransformer,
|
|
||||||
neo4j.ExecuteQueryWithDatabase("neo4j"))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -37,8 +37,8 @@ func TestNewRelationshipWithValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
rel := NewSignedRel(tc.start, tc.end, nil)
|
rel := NewSignedRel(tc.start, tc.end, nil)
|
||||||
assert.Equal(t, "SIGNED", rel.Type)
|
assert.Equal(t, "SIGNED", rel.Type)
|
||||||
assert.Contains(t, rel.Start.Labels.ToArray(), "User")
|
assert.Contains(t, rel.Start.Labels.AsSortedArray(), "User")
|
||||||
assert.Contains(t, rel.End.Labels.ToArray(), "Event")
|
assert.Contains(t, rel.End.Labels.AsSortedArray(), "Event")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
52
set.go
52
set.go
@@ -1,14 +1,20 @@
|
|||||||
package heartwood
|
package heartwood
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
// Sets
|
// Sets
|
||||||
|
|
||||||
type Set[T comparable] struct {
|
type StringSet struct {
|
||||||
inner map[T]struct{}
|
inner map[string]struct{}
|
||||||
|
sorted []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSet[T comparable](items ...T) Set[T] {
|
func NewStringSet(items ...string) *StringSet {
|
||||||
set := Set[T]{
|
set := &StringSet{
|
||||||
inner: make(map[T]struct{}),
|
inner: make(map[string]struct{}),
|
||||||
|
sorted: []string{},
|
||||||
}
|
}
|
||||||
for _, i := range items {
|
for _, i := range items {
|
||||||
set.Add(i)
|
set.Add(i)
|
||||||
@@ -16,20 +22,26 @@ func NewSet[T comparable](items ...T) Set[T] {
|
|||||||
return set
|
return set
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s Set[T]) Add(item T) {
|
func (s *StringSet) Add(item string) {
|
||||||
s.inner[item] = struct{}{}
|
if _, exists := s.inner[item]; !exists {
|
||||||
|
s.inner[item] = struct{}{}
|
||||||
|
s.rebuildSorted()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s Set[T]) Remove(item T) {
|
func (s *StringSet) Remove(item string) {
|
||||||
delete(s.inner, item)
|
if _, exists := s.inner[item]; exists {
|
||||||
|
delete(s.inner, item)
|
||||||
|
s.rebuildSorted()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s Set[T]) Contains(item T) bool {
|
func (s *StringSet) Contains(item string) bool {
|
||||||
_, exists := s.inner[item]
|
_, exists := s.inner[item]
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s Set[T]) Equal(other Set[T]) bool {
|
func (s *StringSet) Equal(other StringSet) bool {
|
||||||
if len(s.inner) != len(other.inner) {
|
if len(s.inner) != len(other.inner) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -41,14 +53,18 @@ func (s Set[T]) Equal(other Set[T]) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s Set[T]) Length() int {
|
func (s *StringSet) Length() int {
|
||||||
return len(s.inner)
|
return len(s.inner)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s Set[T]) ToArray() []T {
|
func (s *StringSet) AsSortedArray() []string {
|
||||||
array := []T{}
|
return s.sorted
|
||||||
for i := range s.inner {
|
}
|
||||||
array = append(array, i)
|
|
||||||
}
|
func (s *StringSet) rebuildSorted() {
|
||||||
return array
|
s.sorted = make([]string, 0, len(s.inner))
|
||||||
|
for item := range s.inner {
|
||||||
|
s.sorted = append(s.sorted, item)
|
||||||
|
}
|
||||||
|
sort.Strings(s.sorted)
|
||||||
}
|
}
|
||||||
|
|||||||
87
subgraph.go
87
subgraph.go
@@ -4,7 +4,7 @@ import (
|
|||||||
roots "git.wisehodl.dev/jay/go-roots/events"
|
roots "git.wisehodl.dev/jay/go-roots/events"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Event subgraph struct
|
// Types
|
||||||
|
|
||||||
type EventSubgraph struct {
|
type EventSubgraph struct {
|
||||||
nodes []*Node
|
nodes []*Node
|
||||||
@@ -35,7 +35,7 @@ func (s *EventSubgraph) Rels() []*Relationship {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *EventSubgraph) NodesByLabel(label string) []*Node {
|
func (s *EventSubgraph) NodesByLabel(label string) []*Node {
|
||||||
nodes := []*Node{}
|
var nodes []*Node
|
||||||
for _, node := range s.nodes {
|
for _, node := range s.nodes {
|
||||||
if node.Labels.Contains(label) {
|
if node.Labels.Contains(label) {
|
||||||
nodes = append(nodes, node)
|
nodes = append(nodes, node)
|
||||||
@@ -58,7 +58,7 @@ func isValidTag(t roots.Tag) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Event to subgraph conversion
|
// Event to subgraph pipeline
|
||||||
|
|
||||||
func EventToSubgraph(e roots.Event, p ExpanderPipeline) *EventSubgraph {
|
func EventToSubgraph(e roots.Event, p ExpanderPipeline) *EventSubgraph {
|
||||||
s := NewEventSubgraph()
|
s := NewEventSubgraph()
|
||||||
@@ -89,6 +89,8 @@ func EventToSubgraph(e roots.Event, p ExpanderPipeline) *EventSubgraph {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Core pipeline functions
|
||||||
|
|
||||||
func newEventNode(eventID string, createdAt int, kind int, content string) *Node {
|
func newEventNode(eventID string, createdAt int, kind int, content string) *Node {
|
||||||
eventNode := NewEventNode(eventID)
|
eventNode := NewEventNode(eventID)
|
||||||
eventNode.Props["created_at"] = createdAt
|
eventNode.Props["created_at"] = createdAt
|
||||||
@@ -106,7 +108,7 @@ func newSignedRel(user, event *Node) *Relationship {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newTagNodes(tags []roots.Tag) []*Node {
|
func newTagNodes(tags []roots.Tag) []*Node {
|
||||||
nodes := []*Node{}
|
nodes := make([]*Node, 0, len(tags))
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
if !isValidTag(tag) {
|
if !isValidTag(tag) {
|
||||||
continue
|
continue
|
||||||
@@ -117,9 +119,84 @@ func newTagNodes(tags []roots.Tag) []*Node {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newTagRels(event *Node, tags []*Node) []*Relationship {
|
func newTagRels(event *Node, tags []*Node) []*Relationship {
|
||||||
rels := []*Relationship{}
|
rels := make([]*Relationship, 0, len(tags))
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
rels = append(rels, NewTaggedRel(event, tag, nil))
|
rels = append(rels, NewTaggedRel(event, tag, nil))
|
||||||
}
|
}
|
||||||
return rels
|
return rels
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expander Pipeline
|
||||||
|
|
||||||
|
type Expander func(e roots.Event, s *EventSubgraph)
|
||||||
|
type ExpanderPipeline []Expander
|
||||||
|
|
||||||
|
func NewExpanderPipeline(expanders ...Expander) ExpanderPipeline {
|
||||||
|
return ExpanderPipeline(expanders)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultExpanders() []Expander {
|
||||||
|
return []Expander{
|
||||||
|
ExpandTaggedEvents,
|
||||||
|
ExpandTaggedUsers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpandTaggedEvents(e roots.Event, s *EventSubgraph) {
|
||||||
|
tagNodes := s.NodesByLabel("Tag")
|
||||||
|
for _, tag := range e.Tags {
|
||||||
|
if !isValidTag(tag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := tag[0]
|
||||||
|
value := tag[1]
|
||||||
|
|
||||||
|
if name != "e" || !roots.Hex64Pattern.MatchString(value) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tagNode := findTagNode(tagNodes, name, value)
|
||||||
|
if tagNode == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
referencedEvent := NewEventNode(value)
|
||||||
|
|
||||||
|
s.AddNode(referencedEvent)
|
||||||
|
s.AddRel(NewReferencesEventRel(tagNode, referencedEvent, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpandTaggedUsers(e roots.Event, s *EventSubgraph) {
|
||||||
|
tagNodes := s.NodesByLabel("Tag")
|
||||||
|
for _, tag := range e.Tags {
|
||||||
|
if !isValidTag(tag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := tag[0]
|
||||||
|
value := tag[1]
|
||||||
|
|
||||||
|
if name != "p" || !roots.Hex64Pattern.MatchString(value) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tagNode := findTagNode(tagNodes, name, value)
|
||||||
|
if tagNode == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
referencedEvent := NewUserNode(value)
|
||||||
|
|
||||||
|
s.AddNode(referencedEvent)
|
||||||
|
s.AddRel(NewReferencesUserRel(tagNode, referencedEvent, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findTagNode(nodes []*Node, name, value string) *Node {
|
||||||
|
for _, node := range nodes {
|
||||||
|
if node.Props["name"] == name && node.Props["value"] == value {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ func nodesEqual(expected, got *Node) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compare label values
|
// Compare label values
|
||||||
for _, label := range expected.Labels.ToArray() {
|
for _, label := range expected.Labels.AsSortedArray() {
|
||||||
if !got.Labels.Contains(label) {
|
if !got.Labels.Contains(label) {
|
||||||
return fmt.Errorf("missing label %q", label)
|
return fmt.Errorf("missing label %q", label)
|
||||||
}
|
}
|
||||||
|
|||||||
192
write.go
192
write.go
@@ -5,14 +5,15 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
roots "git.wisehodl.dev/jay/go-roots/events"
|
roots "git.wisehodl.dev/jay/go-roots/events"
|
||||||
|
"github.com/boltdb/bolt"
|
||||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WriteOptions struct {
|
type WriteOptions struct {
|
||||||
Expanders ExpanderPipeline
|
Expanders ExpanderPipeline
|
||||||
KVReadBatchSize int
|
BoltReadBatchSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
type EventFollower struct {
|
type EventFollower struct {
|
||||||
@@ -39,7 +40,7 @@ type WriteReport struct {
|
|||||||
|
|
||||||
func WriteEvents(
|
func WriteEvents(
|
||||||
events []string,
|
events []string,
|
||||||
graphdb GraphDB, boltdb BoltDB,
|
driver neo4j.Driver, boltdb *bolt.DB,
|
||||||
opts *WriteOptions,
|
opts *WriteOptions,
|
||||||
) (WriteReport, error) {
|
) (WriteReport, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
@@ -50,7 +51,7 @@ func WriteEvents(
|
|||||||
|
|
||||||
setDefaultWriteOptions(opts)
|
setDefaultWriteOptions(opts)
|
||||||
|
|
||||||
err := boltdb.Setup()
|
err := SetupBoltDB(boltdb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WriteReport{}, fmt.Errorf("error setting up bolt db: %w", err)
|
return WriteReport{}, fmt.Errorf("error setting up bolt db: %w", err)
|
||||||
}
|
}
|
||||||
@@ -58,80 +59,55 @@ func WriteEvents(
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
// Create Event Followers
|
// Create Event Followers
|
||||||
jsonChan := make(chan string, 10)
|
jsonChan := make(chan string)
|
||||||
eventChan := make(chan EventFollower, 10)
|
eventChan := make(chan EventFollower)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go createEventFollowers(&wg, jsonChan, eventChan)
|
||||||
defer wg.Done()
|
|
||||||
createEventFollowers(jsonChan, eventChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Parse Event JSON
|
// Parse Event JSON
|
||||||
parsedChan := make(chan EventFollower, 10)
|
parsedChan := make(chan EventFollower)
|
||||||
invalidChan := make(chan EventFollower, 10)
|
invalidChan := make(chan EventFollower)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go parseEventJSON(&wg, eventChan, parsedChan, invalidChan)
|
||||||
defer wg.Done()
|
|
||||||
parseEventJSON(eventChan, parsedChan, invalidChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Collect Invalid Events
|
// Collect Invalid Events
|
||||||
collectedInvalidChan := make(chan []EventFollower)
|
collectedInvalidChan := make(chan []EventFollower)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go collectEvents(&wg, invalidChan, collectedInvalidChan)
|
||||||
defer wg.Done()
|
|
||||||
collectEvents(invalidChan, collectedInvalidChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Enforce Policy Rules
|
// Enforce Policy Rules
|
||||||
queuedChan := make(chan EventFollower, 10)
|
queuedChan := make(chan EventFollower)
|
||||||
skippedChan := make(chan EventFollower, 10)
|
skippedChan := make(chan EventFollower)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go enforcePolicyRules(&wg, driver, boltdb, opts.BoltReadBatchSize,
|
||||||
defer wg.Done()
|
parsedChan, queuedChan, skippedChan)
|
||||||
enforcePolicyRules(
|
|
||||||
graphdb, boltdb,
|
|
||||||
opts.KVReadBatchSize,
|
|
||||||
parsedChan, queuedChan, skippedChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Collect Skipped Events
|
// Collect Skipped Events
|
||||||
collectedSkippedChan := make(chan []EventFollower)
|
collectedSkippedChan := make(chan []EventFollower)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go collectEvents(&wg, skippedChan, collectedSkippedChan)
|
||||||
defer wg.Done()
|
|
||||||
collectEvents(skippedChan, collectedSkippedChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Convert Events To Subgraphs
|
// Convert Events To Subgraphs
|
||||||
convertedChan := make(chan EventFollower, 10)
|
convertedChan := make(chan EventFollower)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go convertEventsToSubgraphs(&wg, opts.Expanders, queuedChan, convertedChan)
|
||||||
defer wg.Done()
|
|
||||||
convertEventsToSubgraphs(opts.Expanders, queuedChan, convertedChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Write Events To Databases
|
// Write Events To Databases
|
||||||
writeResultChan := make(chan WriteResult)
|
writeResultChan := make(chan WriteResult)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go writeEventsToDatabases(&wg, driver, boltdb, convertedChan, writeResultChan)
|
||||||
defer wg.Done()
|
|
||||||
writeEventsToDatabases(
|
|
||||||
graphdb, boltdb,
|
|
||||||
convertedChan, writeResultChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Send event jsons into pipeline
|
// Send event jsons into pipeline
|
||||||
go func() {
|
go func() {
|
||||||
for _, json := range events {
|
for _, raw := range events {
|
||||||
jsonChan <- json
|
jsonChan <- raw
|
||||||
}
|
}
|
||||||
close(jsonChan)
|
close(jsonChan)
|
||||||
}()
|
}()
|
||||||
@@ -158,19 +134,21 @@ func setDefaultWriteOptions(opts *WriteOptions) {
|
|||||||
if opts.Expanders == nil {
|
if opts.Expanders == nil {
|
||||||
opts.Expanders = NewExpanderPipeline(DefaultExpanders()...)
|
opts.Expanders = NewExpanderPipeline(DefaultExpanders()...)
|
||||||
}
|
}
|
||||||
if opts.KVReadBatchSize == 0 {
|
if opts.BoltReadBatchSize == 0 {
|
||||||
opts.KVReadBatchSize = 100
|
opts.BoltReadBatchSize = 100
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createEventFollowers(jsonChan chan string, eventChan chan EventFollower) {
|
func createEventFollowers(wg *sync.WaitGroup, jsonChan chan string, eventChan chan EventFollower) {
|
||||||
|
defer wg.Done()
|
||||||
for json := range jsonChan {
|
for json := range jsonChan {
|
||||||
eventChan <- EventFollower{JSON: json}
|
eventChan <- EventFollower{JSON: json}
|
||||||
}
|
}
|
||||||
close(eventChan)
|
close(eventChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseEventJSON(inChan, parsedChan, invalidChan chan EventFollower) {
|
func parseEventJSON(wg *sync.WaitGroup, inChan, parsedChan, invalidChan chan EventFollower) {
|
||||||
|
defer wg.Done()
|
||||||
for follower := range inChan {
|
for follower := range inChan {
|
||||||
var event roots.Event
|
var event roots.Event
|
||||||
jsonBytes := []byte(follower.JSON)
|
jsonBytes := []byte(follower.JSON)
|
||||||
@@ -191,11 +169,13 @@ func parseEventJSON(inChan, parsedChan, invalidChan chan EventFollower) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func enforcePolicyRules(
|
func enforcePolicyRules(
|
||||||
graphdb GraphDB, boltdb BoltDB,
|
wg *sync.WaitGroup,
|
||||||
|
driver neo4j.Driver, boltdb *bolt.DB,
|
||||||
batchSize int,
|
batchSize int,
|
||||||
inChan, queuedChan, skippedChan chan EventFollower,
|
inChan, queuedChan, skippedChan chan EventFollower,
|
||||||
) {
|
) {
|
||||||
batch := []EventFollower{}
|
defer wg.Done()
|
||||||
|
var batch []EventFollower
|
||||||
|
|
||||||
for follower := range inChan {
|
for follower := range inChan {
|
||||||
batch = append(batch, follower)
|
batch = append(batch, follower)
|
||||||
@@ -215,17 +195,17 @@ func enforcePolicyRules(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func processPolicyRulesBatch(
|
func processPolicyRulesBatch(
|
||||||
boltdb BoltDB,
|
boltdb *bolt.DB,
|
||||||
batch []EventFollower,
|
batch []EventFollower,
|
||||||
queuedChan, skippedChan chan EventFollower,
|
queuedChan, skippedChan chan EventFollower,
|
||||||
) {
|
) {
|
||||||
eventIDs := []string{}
|
eventIDs := make([]string, 0, len(batch))
|
||||||
|
|
||||||
for _, follower := range batch {
|
for _, follower := range batch {
|
||||||
eventIDs = append(eventIDs, follower.ID)
|
eventIDs = append(eventIDs, follower.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
existsMap := boltdb.BatchCheckEventsExist(eventIDs)
|
existsMap := BatchCheckEventsExist(boltdb, eventIDs)
|
||||||
|
|
||||||
for _, follower := range batch {
|
for _, follower := range batch {
|
||||||
if existsMap[follower.ID] {
|
if existsMap[follower.ID] {
|
||||||
@@ -237,9 +217,10 @@ func processPolicyRulesBatch(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func convertEventsToSubgraphs(
|
func convertEventsToSubgraphs(
|
||||||
expanders ExpanderPipeline,
|
wg *sync.WaitGroup, expanders ExpanderPipeline,
|
||||||
inChan, convertedChan chan EventFollower,
|
inChan, convertedChan chan EventFollower,
|
||||||
) {
|
) {
|
||||||
|
defer wg.Done()
|
||||||
for follower := range inChan {
|
for follower := range inChan {
|
||||||
subgraph := EventToSubgraph(follower.Event, expanders)
|
subgraph := EventToSubgraph(follower.Event, expanders)
|
||||||
follower.Subgraph = subgraph
|
follower.Subgraph = subgraph
|
||||||
@@ -249,93 +230,66 @@ func convertEventsToSubgraphs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func writeEventsToDatabases(
|
func writeEventsToDatabases(
|
||||||
graphdb GraphDB, boltdb BoltDB,
|
wg *sync.WaitGroup,
|
||||||
|
driver neo4j.Driver, boltdb *bolt.DB,
|
||||||
inChan chan EventFollower,
|
inChan chan EventFollower,
|
||||||
resultChan chan WriteResult,
|
resultChan chan WriteResult,
|
||||||
) {
|
) {
|
||||||
var wg sync.WaitGroup
|
defer wg.Done()
|
||||||
|
var localWg sync.WaitGroup
|
||||||
|
|
||||||
kvEventChan := make(chan EventFollower, 10)
|
boltEventChan := make(chan EventFollower)
|
||||||
graphEventChan := make(chan EventFollower, 10)
|
graphEventChan := make(chan EventFollower)
|
||||||
|
|
||||||
kvWriteDone := make(chan struct{})
|
boltErrorChan := make(chan error)
|
||||||
|
|
||||||
kvErrorChan := make(chan error)
|
|
||||||
graphResultChan := make(chan WriteResult)
|
graphResultChan := make(chan WriteResult)
|
||||||
|
|
||||||
wg.Add(2)
|
localWg.Add(2)
|
||||||
go func() {
|
go writeEventsToBoltDB(&localWg, boltdb, boltEventChan, boltErrorChan)
|
||||||
defer wg.Done()
|
go writeEventsToGraphDB(&localWg, driver, graphEventChan, boltErrorChan, graphResultChan)
|
||||||
writeEventsToKVStore(
|
|
||||||
boltdb,
|
|
||||||
kvEventChan, kvWriteDone, kvErrorChan)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
writeEventsToGraphDriver(
|
|
||||||
graphdb,
|
|
||||||
graphEventChan, kvWriteDone, graphResultChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Fan out events to both writers
|
// Fan out events to both writers
|
||||||
for follower := range inChan {
|
for follower := range inChan {
|
||||||
kvEventChan <- follower
|
boltEventChan <- follower
|
||||||
graphEventChan <- follower
|
graphEventChan <- follower
|
||||||
}
|
}
|
||||||
close(kvEventChan)
|
close(boltEventChan)
|
||||||
close(graphEventChan)
|
close(graphEventChan)
|
||||||
|
|
||||||
wg.Wait()
|
localWg.Wait()
|
||||||
|
|
||||||
kvError := <-kvErrorChan
|
|
||||||
graphResult := <-graphResultChan
|
graphResult := <-graphResultChan
|
||||||
|
resultChan <- graphResult
|
||||||
var finalErr error
|
|
||||||
if kvError != nil && graphResult.Error != nil {
|
|
||||||
finalErr = fmt.Errorf("kvstore: %w; graphstore: %v", kvError, graphResult.Error)
|
|
||||||
} else if kvError != nil {
|
|
||||||
finalErr = fmt.Errorf("kvstore: %w", kvError)
|
|
||||||
} else if graphResult.Error != nil {
|
|
||||||
finalErr = fmt.Errorf("graphstore: %w", graphResult.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
resultChan <- WriteResult{
|
|
||||||
ResultSummaries: graphResult.ResultSummaries,
|
|
||||||
Error: finalErr,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeEventsToKVStore(
|
func writeEventsToBoltDB(
|
||||||
boltdb BoltDB,
|
wg *sync.WaitGroup,
|
||||||
|
boltdb *bolt.DB,
|
||||||
inChan chan EventFollower,
|
inChan chan EventFollower,
|
||||||
done chan struct{},
|
errorChan chan error,
|
||||||
resultChan chan error,
|
|
||||||
) {
|
) {
|
||||||
events := []EventBlob{}
|
defer wg.Done()
|
||||||
|
var events []EventBlob
|
||||||
|
|
||||||
for follower := range inChan {
|
for follower := range inChan {
|
||||||
events = append(events,
|
events = append(events,
|
||||||
EventBlob{ID: follower.ID, JSON: follower.JSON})
|
EventBlob{ID: follower.ID, JSON: follower.JSON})
|
||||||
}
|
}
|
||||||
|
|
||||||
err := boltdb.BatchWriteEvents(events)
|
err := BatchWriteEvents(boltdb, events)
|
||||||
if err != nil {
|
|
||||||
close(done)
|
|
||||||
} else {
|
|
||||||
done <- struct{}{}
|
|
||||||
close(done)
|
|
||||||
}
|
|
||||||
|
|
||||||
resultChan <- err
|
errorChan <- err
|
||||||
close(resultChan)
|
close(errorChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeEventsToGraphDriver(
|
func writeEventsToGraphDB(
|
||||||
graphdb GraphDB,
|
wg *sync.WaitGroup,
|
||||||
|
driver neo4j.Driver,
|
||||||
inChan chan EventFollower,
|
inChan chan EventFollower,
|
||||||
start chan struct{},
|
boltErrorChan chan error,
|
||||||
resultChan chan WriteResult,
|
resultChan chan WriteResult,
|
||||||
) {
|
) {
|
||||||
|
defer wg.Done()
|
||||||
matchKeys := NewSimpleMatchKeys()
|
matchKeys := NewSimpleMatchKeys()
|
||||||
batch := NewBatchSubgraph(matchKeys)
|
batch := NewBatchSubgraph(matchKeys)
|
||||||
|
|
||||||
@@ -348,14 +302,17 @@ func writeEventsToGraphDriver(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := <-start
|
boltErr := <-boltErrorChan
|
||||||
if !ok {
|
if boltErr != nil {
|
||||||
resultChan <- WriteResult{Error: fmt.Errorf("kv write failed, aborting graph write")}
|
resultChan <- WriteResult{
|
||||||
|
Error: fmt.Errorf(
|
||||||
|
"boltdb write failed, aborting graph write: %w", boltErr,
|
||||||
|
)}
|
||||||
close(resultChan)
|
close(resultChan)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
summaries, err := graphdb.MergeSubgraph(context.Background(), batch)
|
summaries, err := MergeSubgraph(context.Background(), driver, batch)
|
||||||
resultChan <- WriteResult{
|
resultChan <- WriteResult{
|
||||||
ResultSummaries: summaries,
|
ResultSummaries: summaries,
|
||||||
Error: err,
|
Error: err,
|
||||||
@@ -363,8 +320,9 @@ func writeEventsToGraphDriver(
|
|||||||
close(resultChan)
|
close(resultChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
func collectEvents(inChan chan EventFollower, resultChan chan []EventFollower) {
|
func collectEvents(wg *sync.WaitGroup, inChan chan EventFollower, resultChan chan []EventFollower) {
|
||||||
collected := []EventFollower{}
|
defer wg.Done()
|
||||||
|
var collected []EventFollower
|
||||||
for follower := range inChan {
|
for follower := range inChan {
|
||||||
collected = append(collected, follower)
|
collected = append(collected, follower)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user