From e9028a1f7193a40bbe5193417bf0f2da977ba984 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 13:06:57 -0400 Subject: [PATCH 01/13] feat(typed): generic, type-safe client and query builder Add a generic typed layer over modusgraph.Client: typed.Client[T] with CRUD and iterators; a fluent Query[T] builder (filters, ordering, paging, edge traversal, IterNodes); MultiQuery for N homogeneous blocks in one round-trip; functional options; a filter DSL (typed/filter); and ordered result merging (typed/search). A small no-op-by-default Tracer seam (typed.SetTracer) lets a host plug in tracing without the typed package depending on any telemetry library. Self-contained: builds and tests against the current client with no other changes. --- typed/client.go | 87 +++ typed/client_test.go | 209 ++++++ typed/filter/filter.go | 118 +++ typed/filter/filter_test.go | 118 +++ typed/filter/fulltext.go | 21 + typed/filter/fulltext_test.go | 41 ++ typed/multi_query.go | 191 +++++ typed/multi_query_test.go | 127 ++++ typed/option.go | 17 + typed/option_test.go | 37 + typed/query.go | 565 ++++++++++++++ typed/query_test.go | 1294 +++++++++++++++++++++++++++++++++ typed/search/merge.go | 27 + typed/search/merge_test.go | 86 +++ typed/tracing.go | 58 ++ typed/tracing_test.go | 47 ++ 16 files changed, 3043 insertions(+) create mode 100644 typed/client.go create mode 100644 typed/client_test.go create mode 100644 typed/filter/filter.go create mode 100644 typed/filter/filter_test.go create mode 100644 typed/filter/fulltext.go create mode 100644 typed/filter/fulltext_test.go create mode 100644 typed/multi_query.go create mode 100644 typed/multi_query_test.go create mode 100644 typed/option.go create mode 100644 typed/option_test.go create mode 100644 typed/query.go create mode 100644 typed/query_test.go create mode 100644 typed/search/merge.go create mode 100644 typed/search/merge_test.go create mode 100644 typed/tracing.go create mode 100644 typed/tracing_test.go diff --git a/typed/client.go b/typed/client.go new file mode 100644 index 0000000..c540f89 --- /dev/null +++ b/typed/client.go @@ -0,0 +1,87 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Package typed binds a Go type to the otherwise any-typed modusgraph.Client, +// providing generic, type-safe CRUD and query operations without per-entity +// code generation. It is the handwritten substrate that modusgraph-gen's +// generated clients compose over. +package typed + +import ( + "context" + "iter" + + "github.com/matthewmcneely/modusgraph" +) + +// Client provides type-safe CRUD and query operations over records of type T. +// T is the schema struct (for example schema.Actor); modusgraph reflects over +// the struct's dgraph/json tags, so T needs no constraint. +type Client[T any] struct { + conn modusgraph.Client +} + +// NewClient binds a Client[T] to conn. +func NewClient[T any](conn modusgraph.Client) *Client[T] { + return &Client[T]{conn: conn} +} + +// Get loads the T with the given UID. +func (c *Client[T]) Get(ctx context.Context, uid string) (rec *T, err error) { + ctx, span := tracer.StartSpan(ctx, "get", entityName[T]()) + defer func() { span.End(err) }() + var out T + if err = c.conn.Get(ctx, &out, uid); err != nil { + return nil, err + } + return &out, nil +} + +// Add inserts a new T. modusgraph writes the assigned UID back into rec. +func (c *Client[T]) Add(ctx context.Context, rec *T) (err error) { + ctx, span := tracer.StartSpan(ctx, "add", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Insert(ctx, rec) +} + +// Update modifies an existing T (must have its UID set). +func (c *Client[T]) Update(ctx context.Context, rec *T) (err error) { + ctx, span := tracer.StartSpan(ctx, "update", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Update(ctx, rec) +} + +// Upsert inserts or updates rec, matching against predicates. With no +// predicates, the first field tagged dgraph:"upsert" is used. +func (c *Client[T]) Upsert(ctx context.Context, rec *T, predicates ...string) (err error) { + ctx, span := tracer.StartSpan(ctx, "upsert", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Upsert(ctx, rec, predicates...) +} + +// Delete removes the T with the given UID. +func (c *Client[T]) Delete(ctx context.Context, uid string) (err error) { + ctx, span := tracer.StartSpan(ctx, "delete", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Delete(ctx, []string{uid}) +} + +// Query returns a typed query builder for T. conn and ctx are carried so the +// builder can run a WhereEdge pre-pass (see Query.WhereEdge) if one is needed. +func (c *Client[T]) Query(ctx context.Context) *Query[T] { + var z T + return &Query[T]{q: c.conn.Query(ctx, &z), conn: c.conn, ctx: ctx} +} + +// defaultPageSize is the page size IterNodes uses to page through results. +const defaultPageSize = 50 + +// Iter returns an iterator over every T, paging transparently so large result +// sets are not materialized at once. It yields each record in turn; on error +// it yields a final (nil, err) and stops. All pages execute against one +// read-only transaction, so the iteration reads a single consistent snapshot. +func (c *Client[T]) Iter(ctx context.Context) iter.Seq2[*T, error] { + return c.Query(ctx).IterNodes() +} diff --git a/typed/client_test.go b/typed/client_test.go new file mode 100644 index 0000000..6fa2b1d --- /dev/null +++ b/typed/client_test.go @@ -0,0 +1,209 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +// widget is a minimal schema struct used to exercise the typed package. +type widget struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Qty int `json:"qty,omitempty" dgraph:"index=int"` +} + +// owner and pet exercise Query.WhereEdge: owner has an outbound "pets" edge to +// pet, and pet's Name carries an index so eq(name, ...) resolves inside an edge +// filter. The pair is the typed-package analogue of the Person/Dog example in +// docs/specs/2026-05-21-query-edge-filter-design.md. +type owner struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Pets []*pet `json:"pets,omitempty"` +} + +type pet struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` +} + +// newConn builds a local file-backed modusgraph client for a test. +func newConn(t *testing.T) modusgraph.Client { + t.Helper() + conn, err := modusgraph.NewClient("file://"+t.TempDir(), modusgraph.WithAutoSchema(true)) + if err != nil { + t.Fatalf("modusgraph.NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestClient_AddPopulatesUIDAndGetReadsBack(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "sprocket", Qty: 3} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + if w.UID == "" { + t.Fatal("Add did not populate UID on the passed struct") + } + + got, err := c.Get(ctx, w.UID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.Name != "sprocket" || got.Qty != 3 { + t.Fatalf("Get returned %+v, want Name=sprocket Qty=3", got) + } +} + +func TestClient_Update(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "gear", Qty: 1} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + w.Qty = 99 + if err := c.Update(ctx, w); err != nil { + t.Fatalf("Update: %v", err) + } + + got, err := c.Get(ctx, w.UID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.Qty != 99 { + t.Fatalf("Update did not persist; Qty = %d, want 99", got.Qty) + } +} + +func TestClient_Delete(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "bolt"} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + if err := c.Delete(ctx, w.UID); err != nil { + t.Fatalf("Delete: %v", err) + } + if _, err := c.Get(ctx, w.UID); err == nil { + t.Fatal("Get after Delete returned no error; expected not-found") + } +} + +func TestClient_IterPagesThroughAllRecords(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // 125 is deliberately larger than the package's 50-record page size, so + // a correct Iter must fetch more than one page. + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + seen := 0 + for w, err := range c.Iter(ctx) { + if err != nil { + t.Fatalf("Iter yielded error: %v", err) + } + if w == nil { + t.Fatal("Iter yielded a nil widget") + } + seen++ + } + if seen != n { + t.Fatalf("Iter yielded %d records, want %d", seen, n) + } +} + +// gadget is a dedicated upsert struct. It must not be the shared widget, because +// widget is used in tests that insert many records with duplicate Name values; +// adding a "upsert" directive to widget.Name would cause those inserts to +// collide and break unrelated tests. +type gadget struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Label string `json:"label,omitempty" dgraph:"index=exact upsert"` + Stock int `json:"stock,omitempty" dgraph:"index=int"` +} + +func TestClient_Upsert(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[gadget](newConn(t)) + + // First call — creates the record. + g := &gadget{Label: "sprocket", Stock: 10} + if err := c.Upsert(ctx, g, "label"); err != nil { + t.Fatalf("Upsert (create): %v", err) + } + if g.UID == "" { + t.Fatal("Upsert (create) did not populate UID") + } + + // Second call — same Label value, different Stock. Must UPDATE, not insert. + g2 := &gadget{Label: "sprocket", Stock: 99} + if err := c.Upsert(ctx, g2, "label"); err != nil { + t.Fatalf("Upsert (update): %v", err) + } + + // Exactly one record must exist and it must carry the updated Stock. + nodes, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Query after Upsert: %v", err) + } + if len(nodes) != 1 { + t.Fatalf("got %d gadgets after two upserts on the same label, want 1", len(nodes)) + } + if nodes[0].Stock != 99 { + t.Fatalf("upserted gadget Stock = %d, want 99", nodes[0].Stock) + } +} + +func TestClient_IterStopsOnConsumerBreak(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + seen := 0 + for w, err := range c.Iter(ctx) { + if err != nil { + t.Fatalf("Iter yielded error: %v", err) + } + if w == nil { + t.Fatal("Iter yielded a nil widget") + } + seen++ + if seen == 10 { + break + } + } + if seen != 10 { + t.Fatalf("Iter yielded %d records after break at 10, want 10", seen) + } +} diff --git a/typed/filter/filter.go b/typed/filter/filter.go new file mode 100644 index 0000000..d67f118 --- /dev/null +++ b/typed/filter/filter.go @@ -0,0 +1,118 @@ +// Package filter provides typed values and a parameterised expression builder +// for composing dgraph @filter clauses on generated Query types. +// +// Generated By methods accept []UUID or []String and feed them into +// Builder.EqGroupUUID / Builder.EqGroupString. Consumers can also build +// custom expressions directly with Builder for cases the generator does not +// cover (multi-predicate joins, non-equality operators, domain defaults). +package filter + +import ( + "fmt" + "strings" +) + +// UUID is one UUID-valued filter term, optionally negated. A leading "!" in +// the parsed source negates the term ("!abc" becomes {Negated: true, Value: "abc"}). +type UUID struct { + Negated bool + Value string +} + +// String is one string-valued filter term, optionally negated. +type String struct { + Negated bool + Value string +} + +// ParseUUID parses "value" or "!value" into a UUID. +func ParseUUID(s string) UUID { + neg, v := parseNegation(s) + return UUID{Negated: neg, Value: v} +} + +// ParseString parses "value" or "!value" into a String. +func ParseString(s string) String { + neg, v := parseNegation(s) + return String{Negated: neg, Value: v} +} + +func parseNegation(s string) (bool, string) { + if strings.HasPrefix(s, "!") { + return true, s[1:] + } + return false, s +} + +// term is one predicate-agnostic value used by Builder. +type term struct { + value string + negated bool +} + +// Builder composes parameterised DQL @filter expressions. Terms within an +// EqGroup join with OR; groups join with AND. Required terms become their own +// single-term group. The output is the (expression, positional params) pair +// that typed.Query[T].Filter consumes. +type Builder struct { + groups []string + params []any +} + +func (b *Builder) param(v any) string { + b.params = append(b.params, v) + return fmt.Sprintf("$%d", len(b.params)) +} + +// EqGroupUUID adds an OR-group of eq(predicate, value) terms for one +// UUID-typed predicate. An empty terms slice is a no-op. +func (b *Builder) EqGroupUUID(predicate string, terms []UUID) { + if len(terms) == 0 { + return + } + tg := make([]term, 0, len(terms)) + for _, t := range terms { + tg = append(tg, term{value: t.Value, negated: t.Negated}) + } + b.addEqGroup(predicate, tg) +} + +// EqGroupString adds an OR-group of eq(predicate, value) terms for one +// string-typed predicate. +func (b *Builder) EqGroupString(predicate string, terms []String) { + if len(terms) == 0 { + return + } + tg := make([]term, 0, len(terms)) + for _, t := range terms { + tg = append(tg, term{value: t.Value, negated: t.Negated}) + } + b.addEqGroup(predicate, tg) +} + +func (b *Builder) addEqGroup(predicate string, terms []term) { + parts := make([]string, 0, len(terms)) + for _, t := range terms { + eq := fmt.Sprintf("eq(%s, %s)", predicate, b.param(t.value)) + if t.negated { + eq = "NOT " + eq + } + parts = append(parts, eq) + } + b.groups = append(b.groups, "("+strings.Join(parts, " OR ")+")") +} + +// RequiredEq adds a single mandatory eq(predicate, value) term (its own group). +func (b *Builder) RequiredEq(predicate, value string) { + b.groups = append(b.groups, fmt.Sprintf("eq(%s, %s)", predicate, b.param(value))) +} + +// Build returns the combined DQL filter expression and its parameters. When +// no groups were added it returns ("", nil) — callers should skip the +// .Filter() call entirely in that case. +func (b *Builder) Build() (string, []any) { + if len(b.groups) == 0 { + return "", nil + } + return strings.Join(b.groups, " AND "), b.params +} diff --git a/typed/filter/filter_test.go b/typed/filter/filter_test.go new file mode 100644 index 0000000..864a554 --- /dev/null +++ b/typed/filter/filter_test.go @@ -0,0 +1,118 @@ +package filter_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/filter" +) + +func TestParseUUID(t *testing.T) { + tests := []struct { + name string + in string + want filter.UUID + }{ + {"plain", "abc", filter.UUID{Value: "abc"}}, + {"negated", "!abc", filter.UUID{Negated: true, Value: "abc"}}, + {"empty", "", filter.UUID{}}, + {"just bang", "!", filter.UUID{Negated: true}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := filter.ParseUUID(tt.in) + if got != tt.want { + t.Errorf("ParseUUID(%q) = %+v, want %+v", tt.in, got, tt.want) + } + }) + } +} + +func TestParseString(t *testing.T) { + got := filter.ParseString("!hello") + want := filter.String{Negated: true, Value: "hello"} + if got != want { + t.Errorf("ParseString = %+v, want %+v", got, want) + } +} + +func TestBuilder_Empty(t *testing.T) { + var b filter.Builder + expr, params := b.Build() + if expr != "" || params != nil { + t.Errorf("empty Build = (%q, %v), want (\"\", nil)", expr, params) + } +} + +func TestBuilder_EqGroupUUID_SingleTerm(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + expr, params := b.Build() + want := "(eq(id, $1))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 1 || params[0] != "u1" { + t.Errorf("params = %v, want [u1]", params) + } +} + +func TestBuilder_EqGroupUUID_MultipleTermsJoinWithOR(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}, {Value: "u2"}, {Negated: true, Value: "u3"}}) + expr, params := b.Build() + want := "(eq(id, $1) OR eq(id, $2) OR NOT eq(id, $3))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 3 { + t.Errorf("len(params) = %d, want 3", len(params)) + } +} + +func TestBuilder_EqGroupString_NoTermsIsNoop(t *testing.T) { + var b filter.Builder + b.EqGroupString("name", nil) + expr, _ := b.Build() + if expr != "" { + t.Errorf("empty EqGroupString should be no-op, got expr=%q", expr) + } +} + +func TestBuilder_MultipleGroupsJoinWithAND(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + b.EqGroupString("name", []filter.String{{Value: "Alice"}}) + expr, params := b.Build() + want := "(eq(id, $1)) AND (eq(name, $2))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 2 || params[0] != "u1" || params[1] != "Alice" { + t.Errorf("params = %v, want [u1 Alice]", params) + } +} + +func TestBuilder_RequiredEqIsOwnGroup(t *testing.T) { + var b filter.Builder + b.RequiredEq("archiveStatus", "Active") + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + expr, params := b.Build() + want := "eq(archiveStatus, $1) AND (eq(id, $2))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 2 { + t.Errorf("len(params) = %d, want 2", len(params)) + } +} + +func TestBuilder_PositionalParamsAreSequential(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "a"}, {Value: "b"}}) + b.EqGroupString("name", []filter.String{{Value: "c"}}) + expr, _ := b.Build() + if !strings.Contains(expr, "$1") || !strings.Contains(expr, "$2") || !strings.Contains(expr, "$3") { + t.Errorf("expected $1, $2, $3 in expr; got %q", expr) + } +} diff --git a/typed/filter/fulltext.go b/typed/filter/fulltext.go new file mode 100644 index 0000000..a025ef0 --- /dev/null +++ b/typed/filter/fulltext.go @@ -0,0 +1,21 @@ +package filter + +import "fmt" + +// AnyOfText adds a fulltext OR-match group: anyoftext(predicate, term). +// An empty term is a no-op. +func (b *Builder) AnyOfText(predicate, term string) { + if term == "" { + return + } + b.groups = append(b.groups, fmt.Sprintf("anyoftext(%s, %s)", predicate, b.param(term))) +} + +// AllOfText adds a fulltext AND-match group: alloftext(predicate, term). +// An empty term is a no-op. +func (b *Builder) AllOfText(predicate, term string) { + if term == "" { + return + } + b.groups = append(b.groups, fmt.Sprintf("alloftext(%s, %s)", predicate, b.param(term))) +} diff --git a/typed/filter/fulltext_test.go b/typed/filter/fulltext_test.go new file mode 100644 index 0000000..1d71e0b --- /dev/null +++ b/typed/filter/fulltext_test.go @@ -0,0 +1,41 @@ +package filter_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/filter" +) + +func TestAnyOfTextEmitsFilterAndBindsParam(t *testing.T) { + b := &filter.Builder{} + b.AnyOfText("resourceName", "honda civic") + expr, params := b.Build() + if !strings.Contains(expr, "anyoftext(resourceName, $1)") { + t.Fatalf("expected anyoftext(resourceName, $1) in expr, got %q", expr) + } + if len(params) != 1 || params[0] != "honda civic" { + t.Fatalf("expected params [\"honda civic\"], got %v", params) + } +} + +func TestAllOfTextEmitsFilterAndBindsParam(t *testing.T) { + b := &filter.Builder{} + b.AllOfText("description", "engine block") + expr, params := b.Build() + if !strings.Contains(expr, "alloftext(description, $1)") { + t.Fatalf("expected alloftext(description, $1) in expr, got %q", expr) + } + if len(params) != 1 || params[0] != "engine block" { + t.Fatalf("expected params [\"engine block\"], got %v", params) + } +} + +func TestAnyOfTextEmptyTermIsNoop(t *testing.T) { + b := &filter.Builder{} + b.AnyOfText("resourceName", "") + expr, params := b.Build() + if expr != "" || params != nil { + t.Fatalf("expected empty expr/params for empty term, got %q / %v", expr, params) + } +} diff --git a/typed/multi_query.go b/typed/multi_query.go new file mode 100644 index 0000000..98409c6 --- /dev/null +++ b/typed/multi_query.go @@ -0,0 +1,191 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "strings" + + dg "github.com/dolan-in/dgman/v2" + "github.com/matthewmcneely/modusgraph" +) + +// MultiQuery batches N homogeneous-type Query[T] blocks into a single +// Dgraph multi-block request. All blocks return rows of the same T; the +// per-block result is keyed by the block name supplied at Add. +// +// Dgraph executes the blocks concurrently on the server side; the entire +// batch costs one gRPC round-trip. +type MultiQuery[T any] struct { + conn modusgraph.Client + names []string + blocks map[string]*Query[T] +} + +// NewMultiQuery constructs a MultiQuery bound to conn. +func NewMultiQuery[T any](conn modusgraph.Client) *MultiQuery[T] { + return &MultiQuery[T]{ + conn: conn, + blocks: make(map[string]*Query[T]), + } +} + +// Add registers a named block. Names must be unique within one MultiQuery. +// Panics on duplicate name — the call site is a programming error, not a +// runtime condition. +func (mq *MultiQuery[T]) Add(name string, q *Query[T]) *MultiQuery[T] { + if _, exists := mq.blocks[name]; exists { + panic(fmt.Sprintf("multi_query: duplicate block name %q", name)) + } + mq.names = append(mq.names, name) + mq.blocks[name] = q + return mq +} + +// BlockNames returns the registered block names in insertion order. +func (mq *MultiQuery[T]) BlockNames() []string { + out := make([]string, len(mq.names)) + copy(out, mq.names) + return out +} + +// Execute runs every registered block in a single Dgraph round-trip and +// returns the per-block results, keyed by the block name supplied at Add. +// A block that matched no rows appears as an empty (non-nil) slice in the +// result map; the key is always present. +// +// Execute rejects blocks that carry WhereEdge constraints — those require a +// runtime pre-pass that cannot be folded into the multi-block batch. Run such +// queries individually with Query.Nodes. +// +// Dgraph keys response JSON by predicate name (e.g. resourceName), but Go +// structs typically use their json tag (e.g. name). Execute remaps the keys +// per T's tags before decoding so a schema that uses `dgraph:"predicate=..."` +// with a divergent `json:"..."` decodes correctly — matching the behavior of +// dgman's QueryBlock.Scan path. +func (mq *MultiQuery[T]) Execute(ctx context.Context) (map[string][]T, error) { + if len(mq.names) == 0 { + return map[string][]T{}, nil + } + + rawBlocks := make([]*dg.Query, 0, len(mq.names)) + for _, name := range mq.names { + block := mq.blocks[name] + if len(block.edges) != 0 { + return nil, fmt.Errorf("multi_query: block %q carries WhereEdge constraints; MultiQuery cannot batch edge-filtered blocks", name) + } + // Name the underlying dgman query so blocks do not collide on the + // default "data" name and so the response JSON keys are predictable. + block.q.Name(name) + rawBlocks = append(rawBlocks, block.q) + } + + dql := dg.NewQueryBlock(rawBlocks...).String() + raw, err := mq.conn.QueryRaw(ctx, dql, nil) + if err != nil { + return nil, fmt.Errorf("multi_query: dgraph: %w", err) + } + + var perBlockRaw map[string]json.RawMessage + if err := json.Unmarshal(raw, &perBlockRaw); err != nil { + return nil, fmt.Errorf("multi_query: decoding response: %w", err) + } + + var zero T + predMap := buildPredicateToJSONMap(reflect.TypeOf(zero)) + + out := make(map[string][]T, len(mq.names)) + for _, name := range mq.names { + body, ok := perBlockRaw[name] + if !ok { + out[name] = []T{} + continue + } + if len(predMap) > 0 { + remapped, err := remapArrayKeys(body, predMap) + if err == nil { + body = remapped + } + } + var rows []T + if err := json.Unmarshal(body, &rows); err != nil { + return nil, fmt.Errorf("multi_query: decoding block %q: %w", name, err) + } + if rows == nil { + rows = []T{} + } + out[name] = rows + } + return out, nil +} + +// buildPredicateToJSONMap returns a map from dgraph predicate name → JSON tag +// name for fields on T where the two differ. Mirrors dgman's unexported helper +// of the same name; we need our own because the multi-block response from +// QueryRaw bypasses dgman's scan path. +func buildPredicateToJSONMap(t reflect.Type) map[string]string { + for t != nil && t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t == nil || t.Kind() != reflect.Struct { + return nil + } + result := make(map[string]string) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + jsonName := strings.Split(jsonTag, ",")[0] + if jsonName == "" { + continue + } + dgraphTag := field.Tag.Get("dgraph") + if dgraphTag == "" { + continue + } + var predName string + for _, part := range strings.Fields(dgraphTag) { + if strings.HasPrefix(part, "predicate=") { + predName = strings.TrimPrefix(part, "predicate=") + break + } + } + if predName == "" || predName == jsonName { + continue + } + if predName == "uid" || predName == "dgraph.type" { + continue + } + result[predName] = jsonName + } + return result +} + +// remapArrayKeys rewrites top-level keys in each object of a JSON array using +// the predicate → JSON-tag map. Nested objects are left untouched (search +// callers iterate scalar predicates of the root type; edge fields are +// hydrated lazily, not in the multi-block response). +func remapArrayKeys(data json.RawMessage, predMap map[string]string) (json.RawMessage, error) { + var rows []map[string]json.RawMessage + if err := json.Unmarshal(data, &rows); err != nil { + return data, err + } + for i, row := range rows { + for k, v := range row { + if newK, ok := predMap[k]; ok && newK != k { + delete(row, k) + row[newK] = v + } + } + rows[i] = row + } + return json.Marshal(rows) +} diff --git a/typed/multi_query_test.go b/typed/multi_query_test.go new file mode 100644 index 0000000..98f1ae4 --- /dev/null +++ b/typed/multi_query_test.go @@ -0,0 +1,127 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph/typed" +) + +func TestMultiQueryAddAccumulatesBlocks(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + q1 := typed.NewClient[widget](conn).Query(context.Background()) + q2 := typed.NewClient[widget](conn).Query(context.Background()) + mq.Add("byName", q1) + mq.Add("byQty", q2) + got := mq.BlockNames() + if len(got) != 2 || got[0] != "byName" || got[1] != "byQty" { + t.Fatalf("BlockNames = %v, want [byName, byQty]", got) + } +} + +func TestMultiQueryAddRejectsDuplicateName(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + q := typed.NewClient[widget](conn).Query(context.Background()) + mq.Add("byName", q) + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on duplicate block name") + } + }() + mq.Add("byName", q) +} + +func TestMultiQueryExecuteReturnsPerBlockResults(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + c := typed.NewClient[widget](conn) + + for _, w := range []*widget{ + {Name: "sprocket", Qty: 1}, + {Name: "gear", Qty: 5}, + {Name: "bolt", Qty: 10}, + } { + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add %s: %v", w.Name, err) + } + } + + mq := typed.NewMultiQuery[widget](conn) + mq.Add("all", c.Query(ctx)) + mq.Add("filtered", c.Query(ctx).Filter("eq(name, $1)", "gear")) + + results, err := mq.Execute(ctx) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if got := len(results["all"]); got != 3 { + t.Fatalf("results[all] has %d rows, want 3", got) + } + if got := len(results["filtered"]); got != 1 { + t.Fatalf("results[filtered] has %d rows, want 1", got) + } + if results["filtered"][0].Name != "gear" { + t.Fatalf("results[filtered][0].Name = %q, want gear", results["filtered"][0].Name) + } +} + +func TestMultiQueryExecuteEmptyReturnsEmptyMap(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + results, err := mq.Execute(context.Background()) + if err != nil { + t.Fatalf("Execute on empty MultiQuery: %v", err) + } + if len(results) != 0 { + t.Fatalf("expected empty map, got %v", results) + } +} + +// renamed exercises the predicate-vs-json-tag remap. Dgraph returns the +// "thingName" key (the predicate name) but the struct's JSON tag is +// "name"; MultiQuery.Execute must remap before unmarshaling so Name +// populates. +type renamed struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"predicate=thingName index=hash,fulltext"` + Qty int `json:"qty,omitempty" dgraph:"index=int"` +} + +func TestMultiQueryExecuteRemapsPredicateKeys(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + c := typed.NewClient[renamed](conn) + + for _, w := range []*renamed{ + {Name: "alpha", Qty: 1}, + {Name: "beta", Qty: 2}, + } { + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add %s: %v", w.Name, err) + } + } + + mq := typed.NewMultiQuery[renamed](conn) + mq.Add("all", c.Query(ctx)) + results, err := mq.Execute(ctx) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rows := results["all"] + if len(rows) != 2 { + t.Fatalf("rows = %d, want 2", len(rows)) + } + for _, r := range rows { + if r.Name == "" { + t.Fatalf("Name not populated; multi-block response was not remapped from predicate key: %+v", r) + } + } +} diff --git a/typed/option.go b/typed/option.go new file mode 100644 index 0000000..d944483 --- /dev/null +++ b/typed/option.go @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +// Option configures a *T. Generated With constructors return an Option; +// generated New/Wrap constructors apply them via Apply. +type Option[T any] func(*T) + +// Apply applies opts to target in declaration order. +func Apply[T any](target *T, opts ...Option[T]) { + for _, opt := range opts { + opt(target) + } +} diff --git a/typed/option_test.go b/typed/option_test.go new file mode 100644 index 0000000..7c1f378 --- /dev/null +++ b/typed/option_test.go @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed" +) + +func TestApply_RunsOptionsInOrder(t *testing.T) { + type rec struct{ trail []string } + r := &rec{} + + typed.Apply(r, + func(x *rec) { x.trail = append(x.trail, "a") }, + func(x *rec) { x.trail = append(x.trail, "b") }, + func(x *rec) { x.trail = append(x.trail, "c") }, + ) + + if got := strings.Join(r.trail, ""); got != "abc" { + t.Fatalf("Apply ran options as %q, want %q", got, "abc") + } +} + +func TestApply_NoOptionsIsNoop(t *testing.T) { + type rec struct{ n int } + r := &rec{n: 7} + typed.Apply(r) + if r.n != 7 { + t.Fatalf("Apply with no options mutated target: n = %d, want 7", r.n) + } +} diff --git a/typed/query.go b/typed/query.go new file mode 100644 index 0000000..e4b2199 --- /dev/null +++ b/typed/query.go @@ -0,0 +1,565 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "fmt" + "iter" + "strconv" + "strings" + + dg "github.com/dolan-in/dgman/v2" + "github.com/matthewmcneely/modusgraph" +) + +// Query is a fluent, type-safe query builder over records of type T. Builder +// methods return *Query[T] for chaining, except As, Var, and GroupBy, which +// change the result shape and transition to *RawQuery; terminal methods +// (Nodes, First, IterNodes) execute the query and decode typed results. +// +// A Query is single-use. Builder methods mutate the underlying query in place +// and return the same *Query, so a Query value should be built as one chain +// and handed to a single terminal. It is not safe to save a Query to a +// variable and branch it into independent queries: every branch shares — and +// keeps mutating — the same underlying query. +// +// Repeated builder calls do not all behave the same way. Limit, Offset, After, +// Cascade, Name, RootFunc, and Vars overwrite: the last call wins. Filter, +// OrderAsc, OrderDesc, and WhereEdge accumulate: each call adds to the query. +// Accumulated Filter fragments AND together (see CombinedFilter, OrGroup). +// +// Limit and Offset additionally record the bounds that IterNodes pages +// within — a Limit caps the rows it streams, an Offset is its start. +type Query[T any] struct { + q *dg.Query + conn modusgraph.Client // runs the WhereEdge pre-pass; set by Client.Query + ctx context.Context // carried for the WhereEdge pre-pass query + limit int // caller-set row cap; 0 = unbounded + offset int // caller-set starting offset; 0 = none + edges []edgeFilter // accumulated WhereEdge constraints; empty = none + filters []filterFrag // accumulated @filter fragments, ANDed; empty = none +} + +// edgeFilter is one accumulated WhereEdge constraint: a dgraph @filter +// expression scoped to an outbound edge predicate of T. +type edgeFilter struct { + predicate string + filter string + params []any +} + +// filterFrag is one accumulated @filter fragment. Fragments join with AND. +type filterFrag struct { + expr string + params []any +} + +// NewDetachedQuery returns a Query[T] with no connection, used only to +// accumulate a filter expression: its By/Filter calls record fragments +// that CombinedFilter reads back. It must not be executed (it has no terminal +// path) and exists as the capture target behind the generated Or and +// WhereBy combinators. +func NewDetachedQuery[T any]() *Query[T] { + return &Query[T]{} +} + +// Filter adds a dgraph @filter expression. params bind to placeholders. +// Repeated calls accumulate: every fragment ANDs together. +func (qb *Query[T]) Filter(filter string, params ...any) *Query[T] { + qb.addFilter(filter, params) + return qb +} + +// addFilter accumulates one @filter fragment. Fragments AND together: the +// effective filter is every fragment joined with AND, each fragment's $N +// placeholders shifted to stay bound to its own params. dgman's own Filter is +// last-write-wins, so the full combined expression is re-pushed on every call. +// A detached query (nil q — used to capture a sub-scope's filter for OrGroup or +// WhereBy) accumulates with no dgman query to push to; CombinedFilter +// reads the fragments back. +func (qb *Query[T]) addFilter(expr string, params []any) { + if expr == "" { + return + } + qb.filters = append(qb.filters, filterFrag{expr: expr, params: params}) + if qb.q != nil { + combined, cp := combineAnd(qb.filters) + qb.q.Filter(combined, cp...) + } +} + +// combineAnd joins fragments with AND, renumbering each fragment's ordinal +// placeholders against the concatenated params slice. +func combineAnd(frags []filterFrag) (string, []any) { + parts := make([]string, 0, len(frags)) + var params []any + for _, f := range frags { + if f.expr == "" { + continue + } + parts = append(parts, shiftPlaceholders(f.expr, len(params))) + params = append(params, f.params...) + } + if len(parts) == 0 { + return "", nil + } + return strings.Join(parts, " AND "), params +} + +// CombinedFilter returns the AND-combined accumulated @filter expression and +// its params, or ("", nil) when no filter was set. It is the substrate behind +// the generated Or and WhereBy combinators: they run a sub-scope's +// By/Filter calls against a detached query, then fold the captured +// expression into a parent OR group or edge constraint. +func (qb *Query[T]) CombinedFilter() (string, []any) { + return combineAnd(qb.filters) +} + +// OrGroup adds one @filter group that ORs the combined filter of each sub. +// Each sub is a detached Query[T] whose By/Filter calls have been +// accumulated; their combined (AND) expressions are parenthesized, joined with +// OR, and the whole OR group ANDs with the receiver's other filters. Subs with +// an empty filter are skipped; an all-empty OrGroup is a no-op. It is the +// substrate behind the generated Query.Or combinator. +func (qb *Query[T]) OrGroup(subs ...*Query[T]) *Query[T] { + parts := make([]string, 0, len(subs)) + var params []any + for _, s := range subs { + e, p := s.CombinedFilter() + if e == "" { + continue + } + parts = append(parts, "("+shiftPlaceholders(e, len(params))+")") + params = append(params, p...) + } + if len(parts) == 0 { + return qb + } + qb.addFilter("("+strings.Join(parts, " OR ")+")", params) + return qb +} + +// OrderAsc orders results ascending by clause. +func (qb *Query[T]) OrderAsc(clause string) *Query[T] { + qb.q.OrderAsc(clause) + return qb +} + +// OrderDesc orders results descending by clause. +func (qb *Query[T]) OrderDesc(clause string) *Query[T] { + qb.q.OrderDesc(clause) + return qb +} + +// Limit caps the number of results. dgman names this First; it is renamed +// here so it does not collide with the First terminal. +func (qb *Query[T]) Limit(n int) *Query[T] { + qb.limit = n + qb.q.First(n) + return qb +} + +// Offset skips the first n results. +func (qb *Query[T]) Offset(n int) *Query[T] { + qb.offset = n + qb.q.Offset(n) + return qb +} + +// After returns results with UID greater than uid (cursor pagination). +func (qb *Query[T]) After(uid string) *Query[T] { + qb.q.After(uid) + return qb +} + +// Cascade drops nodes missing any of the given predicates (all, if none given). +func (qb *Query[T]) Cascade(predicates ...string) *Query[T] { + qb.q.Cascade(predicates...) + return qb +} + +// RootFunc overrides the query root function. dgman's default root function +// is type(); RootFunc replaces it with an expression such as +// eq(name, "Alice") or has(email). Repeated calls overwrite. +func (qb *Query[T]) RootFunc(rootFunc string) *Query[T] { + qb.q.RootFunc(rootFunc) + return qb +} + +// Name sets the query block name. It defaults to "data"; dgman uses the name +// to both generate and decode the query, so a renamed block still decodes +// into []T. Repeated calls overwrite. +func (qb *Query[T]) Name(queryName string) *Query[T] { + qb.q.Name(queryName) + return qb +} + +// Vars supplies GraphQL variables for a parameterized query: funcDef is the +// query function definition (for example "getByName($n: string)") and vars +// binds each variable. The query then executes via dgraph's QueryWithVars +// path. Repeated calls overwrite. +func (qb *Query[T]) Vars(funcDef string, vars map[string]string) *Query[T] { + qb.q.Vars(funcDef, vars) + return qb +} + +// WhereEdge constrains results to records that have at least one `predicate` +// edge whose target node satisfies the dgraph @filter expression. params bind +// to $N placeholders within filter, exactly as Filter binds them. +// +// Where Filter constrains T's own scalar predicates, WhereEdge constrains a +// neighbouring node reached over an edge. dgraph's root @filter cannot express +// that, so a query carrying WhereEdge constraints executes in two steps: a +// pre-pass resolves the UIDs of roots that satisfy every constraint, then the +// main query runs against uid(...) — keeping ordering, pagination, and result +// projection on the normal path. See +// docs/specs/2026-05-21-query-edge-filter-design.md. +// +// WhereEdge accumulates: multiple calls AND together (a record must satisfy +// every edge constraint). It is the substrate behind the generated +// Query.Where methods. +func (qb *Query[T]) WhereEdge(predicate, filter string, params ...any) *Query[T] { + qb.edges = append(qb.edges, edgeFilter{predicate: predicate, filter: filter, params: params}) + return qb +} + +// WhereAnyOfText adds an @filter(anyoftext(predicate, $1)) clause. It +// accumulates and ANDs with other filters like Filter. +func (qb *Query[T]) WhereAnyOfText(predicate, term string) *Query[T] { + qb.addFilter(fmt.Sprintf("anyoftext(%s, $1)", predicate), []any{term}) + return qb +} + +// WhereAllOfText adds an @filter(alloftext(predicate, $1)) clause. It +// accumulates and ANDs with other filters like Filter. +func (qb *Query[T]) WhereAllOfText(predicate, term string) *Query[T] { + qb.addFilter(fmt.Sprintf("alloftext(%s, $1)", predicate), []any{term}) + return qb +} + +// As names the query block as a dgraph query variable. dgraph requires such a +// variable be consumed by another block, which a single-block typed query +// cannot do, so As transitions out of the typed query: it returns a *RawQuery, +// which exposes no node terminal. +func (qb *Query[T]) As(varName string) *RawQuery { + qb.q.As(varName) + return &RawQuery{q: qb.q} +} + +// Var marks the query block as a dgraph var block. A var block computes query +// variables and returns no data of its own, so Var transitions out of the +// typed query: it returns a *RawQuery, which exposes no node terminal. +func (qb *Query[T]) Var() *RawQuery { + qb.q.Var() + return &RawQuery{q: qb.q} +} + +// GroupBy adds an @groupby(predicate) aggregation. A grouped query returns +// aggregation groups rather than a slice of T, so GroupBy transitions out of +// the typed query: it returns a *RawQuery, which exposes no node terminal. +func (qb *Query[T]) GroupBy(predicate string) *RawQuery { + qb.q.GroupBy(predicate) + return &RawQuery{q: qb.q} +} + +// Nodes executes the query and returns all matching records. +func (qb *Query[T]) Nodes() (out []T, err error) { + _, span := tracer.StartSpan(qb.ctx, "query", entityName[T]()) + defer func() { span.End(err) }() + matched, err := qb.resolveRoots() + if err != nil { + return nil, err + } + if !matched { + return nil, nil + } + if err = qb.q.Nodes(&out); err != nil { + return nil, err + } + return out, nil +} + +// First executes the query with an implicit Limit(1) and returns the first +// record, or (nil, nil) if the query matched no rows. +func (qb *Query[T]) First() (rec *T, err error) { + _, span := tracer.StartSpan(qb.ctx, "query", entityName[T]()) + defer func() { span.End(err) }() + matched, err := qb.resolveRoots() + if err != nil { + return nil, err + } + if !matched { + return nil, nil + } + var out []T + if err = qb.q.First(1).Nodes(&out); err != nil { + return nil, err + } + if len(out) == 0 { + return nil, nil + } + return &out[0], nil +} + +// IterNodes executes the query and returns an iterator over matching records, +// paging transparently so a large result set is never materialized at once. +// +// IterNodes is a terminal operation: it drives Offset/Limit internally as it +// pages and leaves the builder spent — do not call another terminal on the +// same Query afterward. A Limit set on the query caps the total number of +// rows streamed; an Offset is the starting point. +// +// All pages execute against one read-only transaction, so the iteration reads +// a single consistent snapshot: a concurrent writer cannot make it skip or +// repeat rows. A WhereEdge pre-pass, when present, runs once before paging +// begins, in its own transaction. On error it yields a final (nil, err) and +// stops. +func (qb *Query[T]) IterNodes() iter.Seq2[*T, error] { + return func(yield func(*T, error) bool) { + _, span := tracer.StartSpan(qb.ctx, "query", entityName[T]()) + var ferr error + defer func() { span.End(ferr) }() + matched, err := qb.resolveRoots() + if err != nil { + ferr = err + yield(nil, err) + return + } + if !matched { + return // edge constraints present, but no root matched + } + remaining := qb.limit // 0 = unbounded + for off := qb.offset; ; off += defaultPageSize { + size := defaultPageSize + if remaining > 0 && remaining < size { + size = remaining // shrink the last page so it can't overshoot the cap + } + var page []T + if err := qb.q.Offset(off).First(size).Nodes(&page); err != nil { + ferr = err + yield(nil, err) + return + } + for i := range page { + if !yield(&page[i], nil) { + return // consumer broke out + } + } + if remaining > 0 { + if remaining -= len(page); remaining <= 0 { + return // hit the caller's Limit + } + } + if len(page) < size { + return // result set exhausted + } + } + } +} + +// Raw returns the underlying dgman query for operations Query does not wrap +// (for example the raw-selection Query method). Raw does not carry WhereEdge +// constraints — those are resolved only when a terminal runs. +func (qb *Query[T]) Raw() *dg.Query { + return qb.q +} + +// UID roots the query at a specific node UID. Results still decode into []T. +func (qb *Query[T]) UID(uid string) *Query[T] { + qb.q.UID(uid) + return qb +} + +// All sets the edge-traversal depth for this query, overriding the client's +// default maxEdgeTraversal. Use a small depth to stay under Dgraph's 4MB gRPC +// limit on highly-connected entities. +func (qb *Query[T]) All(depth int) *Query[T] { + qb.q.All(depth) + return qb +} + +// NodesAndCount executes the query and returns the matching records together +// with the total count (useful for pagination totals). Like Nodes, it runs the +// WhereEdge pre-pass first when edge constraints are present. +func (qb *Query[T]) NodesAndCount() ([]T, int, error) { + matched, err := qb.resolveRoots() + if err != nil { + return nil, 0, err + } + if !matched { + return nil, 0, nil + } + var out []T + count, err := qb.q.NodesAndCount(&out) + if err != nil { + return nil, 0, err + } + return out, count, nil +} + +// String renders the generated DQL without executing it. WhereEdge constraints +// are not reflected — they are resolved only when a terminal runs. +func (qb *Query[T]) String() string { + return qb.q.String() +} + +// FormatBlock renders the query as a single DQL block named name, without +// executing it. The returned text is suitable for inclusion inside a wrapping +// "{ ... }" multi-block request — it does not include outer braces. +// +// FormatBlock is the substrate behind MultiQuery; external callers can use it +// to compose typed queries into larger hand-written DQL requests. +// +// Filter parameters are inlined at Filter-call time (dgman renders $N +// placeholders into the filter string immediately), so the returned block +// carries no unresolved variables. WhereEdge constraints are not formatted — +// they require a runtime pre-pass and would produce no useful output here. +func (qb *Query[T]) FormatBlock(name string) (string, error) { + if len(qb.edges) != 0 { + return "", fmt.Errorf("typed: FormatBlock cannot render a Query carrying WhereEdge constraints") + } + qb.q.Name(name) + wrapped := dg.NewQueryBlock(qb.q).String() + // QueryBlock.String() wraps the block in "{\n ... }" — strip the wrapper so + // the caller can compose blocks inside their own braces. + inner := strings.TrimPrefix(wrapped, "{\n") + inner = strings.TrimSuffix(inner, "}") + return inner, nil +} + +// RawQuery is a query whose result is not a slice of T — produced by the +// shape-changing builders Query.As, Query.Var, and Query.GroupBy. A RawQuery +// deliberately exposes no typed node terminal: its result must be decoded by +// the caller through the underlying dgman query, obtained via Raw. +type RawQuery struct { + q *dg.Query +} + +// Raw returns the underlying dgman query, for the caller to execute and decode. +func (r *RawQuery) Raw() *dg.Query { + return r.q +} + +// String returns the generated DQL. +func (r *RawQuery) String() string { + return r.q.String() +} + +// As names the block as a dgraph query variable. See Query.As. +func (r *RawQuery) As(varName string) *RawQuery { + r.q.As(varName) + return r +} + +// Var marks the block as a dgraph var block. See Query.Var. +func (r *RawQuery) Var() *RawQuery { + r.q.Var() + return r +} + +// GroupBy adds an @groupby(predicate) aggregation. See Query.GroupBy. +func (r *RawQuery) GroupBy(predicate string) *RawQuery { + r.q.GroupBy(predicate) + return r +} + +// resolveRoots runs the WhereEdge pre-pass when the query carries edge +// constraints, rewriting the main query's root function to the matching UIDs. +// It returns matched=false when constraints are present but no root satisfied +// them — callers then return an empty result without running the main query. +// With no edge constraints it is a no-op returning matched=true. +func (qb *Query[T]) resolveRoots() (matched bool, err error) { + if len(qb.edges) == 0 { + return true, nil + } + uids, err := qb.matchedUIDs() + if err != nil { + return false, err + } + if len(uids) == 0 { + return false, nil + } + qb.q.RootFunc("uid(" + strings.Join(uids, ", ") + ")") + return true, nil +} + +// matchedUIDs runs the pre-pass: an @cascade query over T that keeps only +// nodes whose every WhereEdge predicate has a target matching its filter, and +// returns those nodes' UIDs. +func (qb *Query[T]) matchedUIDs() ([]string, error) { + var z T + pre := qb.conn.Query(qb.ctx, &z) + body, params := qb.edgeMatchBody() + pre.Cascade().Query(body, params...) + + var rows []struct { + UID string `json:"uid"` + } + if err := pre.Nodes(&rows); err != nil { + return nil, err + } + uids := make([]string, len(rows)) + for i := range rows { + uids[i] = rows[i].UID + } + return uids, nil +} + +// edgeMatchBody renders the selection set for the pre-pass: uid plus one +// aliased, filtered block per WhereEdge constraint. The caller adds a bare +// @cascade, which then drops any node with an empty block — so a survivor +// satisfies every constraint. Blocks are aliased mg_e0, mg_e1, ... so two +// constraints on the same predicate do not collide as duplicate fields. Each +// fragment's $N placeholders are shifted to stay bound to its own params once +// every fragment's params are concatenated into one slice. +func (qb *Query[T]) edgeMatchBody() (body string, params []any) { + var b strings.Builder + b.WriteString("{\n\tuid\n") + for i, e := range qb.edges { + b.WriteString("\tmg_e") + b.WriteString(strconv.Itoa(i)) + b.WriteString(" : ") + b.WriteString(e.predicate) + b.WriteString(" @filter(") + b.WriteString(shiftPlaceholders(e.filter, len(params))) + b.WriteString(") { uid }\n") + params = append(params, e.params...) + } + b.WriteString("}") + return b.String(), params +} + +// shiftPlaceholders rewrites dgman ordinal placeholders ($1, $2, ...) in expr, +// adding delta to each index. WhereEdge filters are written independently, each +// numbering its params from $1; concatenating them into one pre-pass body +// needs every fragment renumbered against the combined params slice. A '$' not +// followed by a digit is left as-is, matching dgman's parseQueryWithParams. +func shiftPlaceholders(expr string, delta int) string { + if delta == 0 || !strings.ContainsRune(expr, '$') { + return expr + } + var b strings.Builder + for i := 0; i < len(expr); i++ { + if expr[i] != '$' { + b.WriteByte(expr[i]) + continue + } + j := i + 1 + for j < len(expr) && expr[j] >= '0' && expr[j] <= '9' { + j++ + } + if j == i+1 { // '$' not followed by digits — leave verbatim + b.WriteByte('$') + continue + } + n, _ := strconv.Atoi(expr[i+1 : j]) + b.WriteByte('$') + b.WriteString(strconv.Itoa(n + delta)) + i = j - 1 + } + return b.String() +} diff --git a/typed/query_test.go b/typed/query_test.go new file mode 100644 index 0000000..588bf6b --- /dev/null +++ b/typed/query_test.go @@ -0,0 +1,1294 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "strings" + "testing" + + dg "github.com/dolan-in/dgman/v2" + "github.com/go-logr/logr/funcr" + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +// newCountingConn builds a file-backed modusgraph client exactly like newConn, +// but wires in a logr.Logger that counts dgman query executions. dgman logs +// every executed query at verbosity 3 with the message "execute query"; the +// returned *int is incremented once per such log line. +// +// dgman's logger is process-global, and modusgraph allows only one live +// file-backed engine per process (see modusgraph.ErrSingletonOnly). Each call +// uses a fresh t.TempDir() URI for data isolation. Tests that use +// newCountingConn must NOT call t.Parallel(): a second live client would hit +// the engine singleton, and parallel tests would also corrupt the shared +// query count. +func newCountingConn(t *testing.T, count *int) modusgraph.Client { + t.Helper() + logger := funcr.New(func(_, args string) { + // funcr renders the message into args as `"msg"="execute query"`. + // Match that exact pair so unrelated dgman/pool log lines (which log + // other messages, e.g. "executeQuery" for query blocks) are ignored. + if strings.Contains(args, `"msg"="execute query"`) { + *count++ + } + }, funcr.Options{Verbosity: 3}) + conn, err := modusgraph.NewClient("file://"+t.TempDir(), + modusgraph.WithAutoSchema(true), modusgraph.WithLogger(logger)) + if err != nil { + t.Fatalf("modusgraph.NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestQuery_NodesReturnsAll(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + + got, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("Nodes returned %d records, want 3", len(got)) + } +} + +func TestQuery_LimitCapsResults(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for i := range 5 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + got, err := c.Query(ctx).Limit(2).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("Limit(2) returned %d records, want 2", len(got)) + } +} + +func TestQuery_FirstReturnsAMatch(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "only", Qty: 7}); err != nil { + t.Fatalf("Add: %v", err) + } + + got, err := c.Query(ctx).First() + if err != nil { + t.Fatalf("First: %v", err) + } + if got == nil || got.Name != "only" { + t.Fatalf("First returned %+v, want Name=only", got) + } +} + +func TestQuery_FirstNoMatchReturnsNilNil(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + got, err := c.Query(ctx).First() + if err != nil { + t.Fatalf("First on empty: unexpected error %v", err) + } + if got != nil { + t.Fatalf("First on empty returned %+v, want nil", got) + } +} + +func TestQuery_BuilderChainCompilesAndRuns(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "x", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + + // Every builder method must return *Query[widget] so the chain stays typed. + _, err := c.Query(ctx). + OrderAsc("qty"). + Offset(0). + Limit(10). + Cascade(). + Nodes() + if err != nil { + t.Fatalf("builder chain Nodes: %v", err) + } +} + +func TestQuery_RawExposesUnderlyingBuilder(t *testing.T) { + c := typed.NewClient[widget](newConn(t)) + if c.Query(context.Background()).Raw() == nil { + t.Fatal("Raw() returned nil; expected the underlying *dg.Query") + } +} + +func TestQuery_Filter(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Insert three widgets with distinct names. + for _, name := range []string{"alpha", "beta", "gamma"} { + if err := c.Add(ctx, &widget{Name: name}); err != nil { + t.Fatalf("Add %s: %v", name, err) + } + } + + // Filter to exactly those whose name equals "beta" (index=exact allows eq()). + got, err := c.Query(ctx).Filter(`eq(name, "beta")`).Nodes() + if err != nil { + t.Fatalf("Filter Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf("Filter returned %d records, want 1", len(got)) + } + if got[0].Name != "beta" { + t.Fatalf("Filter returned Name=%q, want beta", got[0].Name) + } +} + +func TestQuery_FilterAccumulatesWithAnd(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Three widgets; only "beta"/9 satisfies BOTH name=="beta" and qty>=5. + for _, w := range []widget{ + {Name: "alpha", Qty: 9}, + {Name: "beta", Qty: 9}, + {Name: "beta", Qty: 1}, + } { + if err := c.Add(ctx, &w); err != nil { + t.Fatalf("Add %+v: %v", w, err) + } + } + + // Two Filter calls must AND together, not overwrite. With last-write-wins + // only ge(qty, 5) survives and this returns the two qty>=5 rows instead of + // the single AND match. + got, err := c.Query(ctx). + Filter(`eq(name, "beta")`). + Filter(`ge(qty, "5")`). + Nodes() + if err != nil { + t.Fatalf("Filter Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf("two ANDed Filters returned %d records, want 1 (name==beta AND qty>=5)", len(got)) + } + if got[0].Name != "beta" || got[0].Qty != 9 { + t.Fatalf("got %+v, want Name=beta Qty=9", got[0]) + } +} + +func TestQuery_CombinedFilterShiftsPlaceholders(t *testing.T) { + q := typed.NewDetachedQuery[widget]() + if expr, params := q.CombinedFilter(); expr != "" || params != nil { + t.Fatalf("empty CombinedFilter = (%q, %v), want (\"\", nil)", expr, params) + } + q.Filter("eq(name, $1)", "a") + q.Filter("eq(qty, $1)", 7) + expr, params := q.CombinedFilter() + const want = "eq(name, $1) AND eq(qty, $2)" + if expr != want { + t.Fatalf("CombinedFilter expr = %q, want %q", expr, want) + } + if len(params) != 2 || params[0] != "a" || params[1] != 7 { + t.Fatalf("CombinedFilter params = %v, want [a 7]", params) + } +} + +func TestQuery_OrGroup(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, w := range []widget{ + {Name: "alpha", Qty: 9}, + {Name: "beta", Qty: 9}, + {Name: "gamma", Qty: 1}, + } { + if err := c.Add(ctx, &w); err != nil { + t.Fatalf("Add %+v: %v", w, err) + } + } + + // name == "alpha" OR name == "gamma": two of three rows. + got, err := c.Query(ctx).OrGroup( + typed.NewDetachedQuery[widget]().Filter(`eq(name, "alpha")`), + typed.NewDetachedQuery[widget]().Filter(`eq(name, "gamma")`), + ).Nodes() + if err != nil { + t.Fatalf("OrGroup Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("OrGroup(alpha, gamma) returned %d rows, want 2", len(got)) + } + + // AND-of-OR: qty>=5 AND (name==alpha OR name==gamma) → only alpha/9. + got, err = c.Query(ctx). + Filter(`ge(qty, "5")`). + OrGroup( + typed.NewDetachedQuery[widget]().Filter(`eq(name, "alpha")`), + typed.NewDetachedQuery[widget]().Filter(`eq(name, "gamma")`), + ).Nodes() + if err != nil { + t.Fatalf("AND-of-OR Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "alpha" { + t.Fatalf("qty>=5 AND (alpha OR gamma) returned %+v, want [alpha/9]", got) + } +} + +func TestQuery_OrderAscDesc(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Insert widgets with distinct Qty values in non-sorted order so a + // stable natural ordering cannot hide a missing sort. + qtys := []int{30, 10, 50, 20, 40} + for i, q := range qtys { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add widget[%d]: %v", i, err) + } + } + + // Ascending. + asc, err := c.Query(ctx).OrderAsc("qty").Nodes() + if err != nil { + t.Fatalf("OrderAsc Nodes: %v", err) + } + if len(asc) != len(qtys) { + t.Fatalf("OrderAsc returned %d records, want %d", len(asc), len(qtys)) + } + for i := range len(asc) - 1 { + if asc[i].Qty > asc[i+1].Qty { + t.Fatalf("OrderAsc: asc[%d].Qty=%d > asc[%d].Qty=%d; not ascending", + i, asc[i].Qty, i+1, asc[i+1].Qty) + } + } + + // Descending. + desc, err := c.Query(ctx).OrderDesc("qty").Nodes() + if err != nil { + t.Fatalf("OrderDesc Nodes: %v", err) + } + if len(desc) != len(qtys) { + t.Fatalf("OrderDesc returned %d records, want %d", len(desc), len(qtys)) + } + for i := range len(desc) - 1 { + if desc[i].Qty < desc[i+1].Qty { + t.Fatalf("OrderDesc: desc[%d].Qty=%d < desc[%d].Qty=%d; not descending", + i, desc[i].Qty, i+1, desc[i+1].Qty) + } + } +} + +func TestQuery_OffsetSkipsResults(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Five widgets with distinct, deliberately unsorted Qty values. + qtys := []int{40, 10, 50, 20, 30} + for i, q := range qtys { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add widget[%d]: %v", i, err) + } + } + + // Ordering ascending by qty gives 10,20,30,40,50; Offset(2) drops the + // first two, so 3 rows remain and the first is the 3rd-smallest (30). + got, err := c.Query(ctx).OrderAsc("qty").Offset(2).Nodes() + if err != nil { + t.Fatalf("Offset Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("OrderAsc.Offset(2) returned %d records, want 3", len(got)) + } + if got[0].Qty != 30 { + t.Fatalf("first row after Offset(2) has Qty=%d, want 30 (3rd-smallest)", got[0].Qty) + } +} + +func TestQuery_AfterCursor(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for i := range 5 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + // First pass: grab all rows so we can pick a non-last cursor UID. + all, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("first Nodes: %v", err) + } + if len(all) < 3 { + t.Fatalf("expected at least 3 widgets, got %d", len(all)) + } + cursor := all[1].UID // a non-last row + + // After(cursor) uses default UID ordering to skip past the cursor node. + got, err := c.Query(ctx).After(cursor).Nodes() + if err != nil { + t.Fatalf("After Nodes: %v", err) + } + if len(got) == 0 { + t.Fatal("After(cursor) returned no rows; expected the rows past the cursor") + } + for _, w := range got { + if w.UID <= cursor { + t.Fatalf("After(%s) returned UID %s, which is not strictly greater than the cursor", + cursor, w.UID) + } + } +} + +func TestQuery_CascadeDropsIncompleteNodes(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Widgets with Qty > 0 carry a qty predicate. Widgets with Qty left 0 + // have it omitted entirely (json tag is omitempty), so they have no qty + // predicate at all. + withQty := []int{5, 9, 13} + for _, q := range withQty { + if err := c.Add(ctx, &widget{Name: "has-qty", Qty: q}); err != nil { + t.Fatalf("Add qty=%d: %v", q, err) + } + } + for i := range 4 { + if err := c.Add(ctx, &widget{Name: "no-qty"}); err != nil { + t.Fatalf("Add no-qty[%d]: %v", i, err) + } + } + + // @cascade(qty) drops any node that lacks the qty predicate. + got, err := c.Query(ctx).Cascade("qty").Nodes() + if err != nil { + t.Fatalf("Cascade Nodes: %v", err) + } + if len(got) != len(withQty) { + t.Fatalf("Cascade(qty) returned %d records, want %d (only the qty-bearing widgets)", + len(got), len(withQty)) + } + for _, w := range got { + if w.Qty == 0 { + t.Fatalf("Cascade(qty) returned a widget with Qty=0 (no qty predicate): %+v", w) + } + } +} + +func TestQuery_FilterOrderLimitOffsetCombined(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // A known set: five "keep" widgets plus a "drop" widget the filter excludes. + for _, q := range []int{50, 20, 40, 10, 30} { + if err := c.Add(ctx, &widget{Name: "keep", Qty: q}); err != nil { + t.Fatalf("Add keep qty=%d: %v", q, err) + } + } + if err := c.Add(ctx, &widget{Name: "drop", Qty: 99}); err != nil { + t.Fatalf("Add drop: %v", err) + } + + // Filter to name=keep -> qtys {10,20,30,40,50}; OrderAsc -> sorted; + // Offset(1) drops 10; Limit(2) keeps {20,30}. + got, err := c.Query(ctx). + Filter(`eq(name, "keep")`). + OrderAsc("qty"). + Offset(1). + Limit(2). + Nodes() + if err != nil { + t.Fatalf("combined chain Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("combined chain returned %d records, want 2", len(got)) + } + if got[0].Qty != 20 || got[1].Qty != 30 { + t.Fatalf("combined chain window = [%d, %d], want [20, 30]", got[0].Qty, got[1].Qty) + } +} + +func TestQuery_FirstOnMultipleRows(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for _, q := range []int{30, 10, 20} { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add qty=%d: %v", q, err) + } + } + + // First on an ascending-by-qty query yields exactly the smallest row. + got, err := c.Query(ctx).OrderAsc("qty").First() + if err != nil { + t.Fatalf("First: %v", err) + } + if got == nil { + t.Fatal("First returned nil on a non-empty result set") + } + if got.Qty != 10 { + t.Fatalf("First on OrderAsc(qty) returned Qty=%d, want 10 (smallest)", got.Qty) + } +} + +func TestQuery_NodesEmptyResult(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) // fresh client, no inserts + + got, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Nodes on empty client: unexpected error %v", err) + } + if len(got) != 0 { + t.Fatalf("Nodes on empty client returned %d records, want 0", len(got)) + } +} + +func TestQuery_OrderAccumulates(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // OrderAsc and OrderDesc accumulate: both clauses must survive on the + // same query. dgman renders them as "orderasc:"/"orderdesc:" in the + // generated query string. + q := c.Query(ctx).OrderAsc("name").OrderDesc("qty") + s := q.Raw().String() + if !strings.Contains(s, "orderasc: name") { + t.Fatalf("query string missing ascending name order; got:\n%s", s) + } + if !strings.Contains(s, "orderdesc: qty") { + t.Fatalf("query string missing descending qty order; got:\n%s", s) + } +} + +func TestQuery_CascadeOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Cascade overwrites: the second call wins, the first predicate is gone. + // dgman renders predicates as @cascade(pred1,pred2,...) with no spaces. + q := c.Query(ctx).Cascade("name").Cascade("qty") + s := q.Raw().String() + if !strings.Contains(s, "@cascade(qty)") { + t.Fatalf("second Cascade(qty) not rendered in query string; got:\n%s", s) + } + if strings.Contains(s, "@cascade(name)") { + t.Fatalf("first Cascade(name) still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_TerminalRunsTwice(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + + // A terminal is re-runnable: calling Nodes twice on the same builder + // succeeds both times and yields equal-length results. + q := c.Query(ctx) + first, err := q.Nodes() + if err != nil { + t.Fatalf("first Nodes: %v", err) + } + second, err := q.Nodes() + if err != nil { + t.Fatalf("second Nodes: %v", err) + } + if len(first) != len(second) { + t.Fatalf("Nodes run twice returned %d then %d records; want equal lengths", + len(first), len(second)) + } +} + +func TestQuery_BuilderAliasesAndAccumulates(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // (i) Filter accumulates: after two Filter calls both survive, ANDed. + q := c.Query(ctx) + q.Filter(`eq(name, "alpha")`) + q.Filter(`eq(name, "beta")`) + s := q.Raw().String() + if !strings.Contains(s, `eq(name, "alpha")`) { + t.Fatalf("Filter A dropped; want both fragments present in:\n%s", s) + } + if !strings.Contains(s, `eq(name, "beta")`) { + t.Fatalf("Filter B dropped; want both fragments present in:\n%s", s) + } + if !strings.Contains(s, " AND ") { + t.Fatalf("accumulated filters not ANDed; got:\n%s", s) + } + + // (ii) The builder aliases: a saved reference and further mutation observe + // the same underlying query. ref and q point at the same *Query, so a + // mutation through one is visible through the other. This documents the + // single-use footgun: you cannot branch a saved builder. + ref := q + if ref != q { + t.Fatal("builder reference is not identical to the original *Query") + } + q.OrderAsc("name") + if ref.Raw().String() != q.Raw().String() { + t.Fatal("mutating q did not affect ref; builder is expected to alias a shared query") + } + if !strings.Contains(ref.Raw().String(), "orderasc: name") { + t.Fatalf("order applied via q not visible through ref; got:\n%s", ref.Raw().String()) + } +} + +func TestQuery_RawRoundTrips(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "raw-target", Qty: 7}); err != nil { + t.Fatalf("Add: %v", err) + } + + // Take the raw *dg.Query, apply a dgman-only builder method directly, + // then execute via the raw query's own Nodes(&dst). + var raw *dg.Query = c.Query(ctx).Raw() + raw.OrderAsc("qty") + + var dst []widget + if err := raw.Nodes(&dst); err != nil { + t.Fatalf("raw query Nodes: %v", err) + } + if len(dst) != 1 { + t.Fatalf("raw query returned %d records, want 1", len(dst)) + } + if dst[0].Name != "raw-target" || dst[0].Qty != 7 { + t.Fatalf("raw query returned %+v, want Name=raw-target Qty=7", dst[0]) + } +} + +func TestQuery_SingleQueryPerTerminal(t *testing.T) { + // Uses the global dgman logger; must not run in parallel. + ctx := context.Background() + // queriesExecuted is incremented by newCountingConn's logger each time + // dgman runs a query, so it reflects real database round-trips. + var queriesExecuted int + c := typed.NewClient[widget](newCountingConn(t, &queriesExecuted)) + + for i := range 2 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + // Building the chain runs no queries: builder methods only mutate the AST. + before := queriesExecuted + q := c.Query(ctx).Filter(`eq(name, "w")`).OrderAsc("qty").Limit(10) + if queriesExecuted != before { + t.Fatalf("builder methods executed %d queries, want 0", queriesExecuted-before) + } + + // The Nodes terminal runs exactly one query. + if _, err := q.Nodes(); err != nil { + t.Fatalf("Nodes: %v", err) + } + if got := queriesExecuted - before; got != 1 { + t.Fatalf("Nodes executed %d queries, want exactly 1", got) + } + + // A fresh builder's First terminal also runs exactly one query. + before = queriesExecuted + if _, err := c.Query(ctx).First(); err != nil { + t.Fatalf("First: %v", err) + } + if got := queriesExecuted - before; got != 1 { + t.Fatalf("First executed %d queries, want exactly 1", got) + } +} + +func TestIterNodes_StreamsAll(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 125 // > defaultPageSize (50): forces multiple pages + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for w, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + if w == nil { + t.Fatal("IterNodes yielded a nil widget") + } + seen++ + } + if seen != n { + t.Fatalf("IterNodes streamed %d records, want %d", seen, n) + } +} + +func TestIterNodes_StopsOnConsumerBreak(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + if seen == 10 { + break + } + } + if seen != 10 { + t.Fatalf("IterNodes yielded %d records after break at 10, want 10", seen) + } +} + +func TestIterNodes_EmptyResult(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + seen := 0 + for _, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes over empty set yielded error: %v", err) + } + seen++ + } + if seen != 0 { + t.Fatalf("IterNodes over empty set yielded %d records, want 0", seen) + } +} + +func TestIterNodes_RespectsLimit(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 100 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).Limit(30).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != 30 { + t.Fatalf("Limit(30).IterNodes() streamed %d records, want 30", seen) + } +} + +func TestIterNodes_LimitExceedsResultSet(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 30 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).Limit(500).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != n { + t.Fatalf("Limit(500).IterNodes() over %d records streamed %d, want %d", n, seen, n) + } +} + +func TestIterNodes_RespectsOffset(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 (not 0) so omitempty never suppresses the field, + // keeping OrderAsc("qty") a true total order over all records. + const n = 10 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + var got []int + for w, err := range c.Query(ctx).OrderAsc("qty").Offset(3).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + got = append(got, w.Qty) + } + if len(got) != 7 { + t.Fatalf("Offset(3).IterNodes() streamed %d records, want 7", len(got)) + } + for i, q := range got { + if q != i+4 { // Qty=1..10; offset 3 skips 1,2,3 → starts at 4 + t.Fatalf("Offset(3).IterNodes()[%d] Qty = %d, want %d", i, q, i+4) + } + } +} + +func TestIterNodes_RespectsOffsetAndLimit(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 so omitempty never suppresses the field and + // OrderAsc("qty") is a strict total order across all 200 records. + const n = 200 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + var got []int + for w, err := range c.Query(ctx).OrderAsc("qty").Offset(60).Limit(120).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + got = append(got, w.Qty) + } + if len(got) != 120 { + t.Fatalf("Offset(60).Limit(120).IterNodes() streamed %d records, want 120", len(got)) + } + for i, q := range got { + if q != i+61 { // Qty=1..200; offset 60 skips 1..60 → starts at 61 + t.Fatalf("result[%d] Qty = %d, want %d", i, q, i+61) + } + } +} + +func TestIterNodes_OneQueryPerPage(t *testing.T) { + ctx := context.Background() + var queriesExecuted int + c := typed.NewClient[widget](newCountingConn(t, &queriesExecuted)) + const n = 125 // ceil(125/50) = 3 page queries + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + // Obtaining the iterator runs no query — IterNodes is lazy. + seq := c.Query(ctx).IterNodes() + if queriesExecuted != 0 { + t.Fatalf("building the IterNodes iterator executed %d queries, want 0", queriesExecuted) + } + seen := 0 + for _, err := range seq { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != n { + t.Fatalf("IterNodes streamed %d records, want %d", seen, n) + } + if queriesExecuted != 3 { + t.Fatalf("IterNodes over %d records ran %d queries, want 3", n, queriesExecuted) + } +} + +func TestIterNodes_YieldsErrorAndStops(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "w", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + // A syntactically invalid @filter (unbalanced parenthesis) makes the page + // query fail at execution; IterNodes must yield one (nil, err) and stop. + gotErr := false + for w, err := range c.Query(ctx).Filter(`eq(name, "w"`).IterNodes() { + if err != nil { + gotErr = true + if w != nil { + t.Fatalf("error yield carried a non-nil widget: %+v", w) + } + break + } + t.Fatal("IterNodes over a malformed query yielded a record before erroring") + } + if !gotErr { + t.Fatal("IterNodes over a malformed query did not yield an error") + } +} + +func TestQuery_LimitOffsetStillDriveNodes(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 so omitempty never suppresses the field and + // OrderAsc("qty") is a strict total order across all records. + const n = 10 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + // Regression: Limit/Offset now also set Query struct fields; confirm they + // still drive the Nodes terminal. + got, err := c.Query(ctx).OrderAsc("qty").Offset(2).Limit(3).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("Offset(2).Limit(3).Nodes() returned %d records, want 3", len(got)) + } + for i, w := range got { + if w.Qty != i+3 { // Qty=1..10; offset 2 skips 1,2 → starts at 3 + t.Fatalf("Nodes()[%d] Qty = %d, want %d", i, w.Qty, i+3) + } + } +} + +func TestQuery_RootFuncOverridesRoot(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // RootFunc replaces the default type(widget) root with an eq() lookup; + // the query still decodes into []widget. + got, err := c.Query(ctx).RootFunc(`eq(name, "b")`).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf(`RootFunc(eq(name,"b")).Nodes() returned %d records, want 1`, len(got)) + } + if got[0].Name != "b" { + t.Fatalf("RootFunc lookup returned %q, want \"b\"", got[0].Name) + } +} + +func TestQuery_RootFuncRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // RootFunc renders into the (func: ...) position and overwrites: the + // second call wins. + q := c.Query(ctx).RootFunc(`eq(name, "x")`).RootFunc(`eq(name, "y")`) + s := q.Raw().String() + if !strings.Contains(s, `func: eq(name, "y")`) { + t.Fatalf("second RootFunc not rendered; got:\n%s", s) + } + if strings.Contains(s, `eq(name, "x")`) { + t.Fatalf("first RootFunc still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_NameDecodesAfterRename(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // Name renames the query block. dgman uses the name symmetrically to + // generate and decode, so a renamed block still decodes into []widget. + got, err := c.Query(ctx).Name("widgets").Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf(`Name("widgets").Nodes() returned %d records, want 3`, len(got)) + } +} + +func TestQuery_NameRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Name renders as the block name and overwrites: the second call wins. + q := c.Query(ctx).Name("first").Name("second") + s := q.Raw().String() + if !strings.Contains(s, "second(func:") { + t.Fatalf("second Name not rendered as block name; got:\n%s", s) + } + if strings.Contains(s, "first(func:") { + t.Fatalf("first Name still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_AsRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // As transitions to *RawQuery, prefixes the block with " as ", + // and overwrites: the second call wins. + q := c.Query(ctx).As("first").As("second") + if q == nil { + t.Fatal("As() returned nil *RawQuery") + } + s := q.String() + if !strings.Contains(s, "second as ") { + t.Fatalf("second As not rendered; got:\n%s", s) + } + if strings.Contains(s, "first as ") { + t.Fatalf("first As still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_VarsRendersQueryPrefix(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Vars renders a "query " prefix on the generated DQL. + q := c.Query(ctx).Vars("getByName($n: string)", map[string]string{"$n": "b"}) + s := q.Raw().String() + if !strings.Contains(s, "query getByName($n: string)") { + t.Fatalf("Vars did not render the query-definition prefix; got:\n%s", s) + } +} + +func TestQuery_VarsParameterizedQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // Vars supplies a GraphQL variable bound into the root function; the + // query executes via dgraph's QueryWithVars path. + got, err := c.Query(ctx). + Vars("getByName($n: string)", map[string]string{"$n": "b"}). + RootFunc("eq(name, $n)"). + Nodes() + if err != nil { + t.Fatalf("Vars query Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "b" { + t.Fatalf(`Vars parameterized query returned %+v, want one widget named "b"`, got) + } +} + +func TestQuery_VarReturnsRawQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Var transitions to *RawQuery and emits a var block: dgman renders the + // block name as "var". + rq := c.Query(ctx).Var() + if rq == nil { + t.Fatal("Var() returned nil *RawQuery") + } + s := rq.String() + if !strings.Contains(s, "var(func:") { + t.Fatalf("Var() did not render a var block; got:\n%s", s) + } +} + +func TestQuery_GroupByReturnsRawQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // GroupBy transitions to *RawQuery and emits an @groupby clause. + rq := c.Query(ctx).GroupBy("name") + if rq == nil { + t.Fatal("GroupBy() returned nil *RawQuery") + } + s := rq.String() + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf(`GroupBy("name") did not render an @groupby clause; got:\n%s`, s) + } +} + +func TestRawQuery_RawExposesUnderlyingQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + rq := c.Query(ctx).Var() + // Raw returns the underlying *dg.Query; String mirrors Raw().String(). + var raw *dg.Query = rq.Raw() + if raw == nil { + t.Fatal("RawQuery.Raw() returned nil") + } + if rq.String() != raw.String() { + t.Fatalf("RawQuery.String() and Raw().String() differ:\n%s\n---\n%s", + rq.String(), raw.String()) + } +} + +func TestRawQuery_GroupByThenVarChains(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // RawQuery re-exposes Var and GroupBy so the canonical .GroupBy(...).Var() + // composition still chains; both clauses survive. + s := c.Query(ctx).GroupBy("name").Var().String() + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf("@groupby clause missing after GroupBy().Var(); got:\n%s", s) + } + if !strings.Contains(s, "var(func:") { + t.Fatalf("var block missing after GroupBy().Var(); got:\n%s", s) + } +} + +func TestRawQuery_CarriesEarlierBuilders(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Builders applied on *Query[T] before the GroupBy transition survive + // into the *RawQuery — the two share one underlying *dg.Query. + s := c.Query(ctx).Filter(`eq(name, "z")`).GroupBy("name").String() + if !strings.Contains(s, `eq(name, "z")`) { + t.Fatalf("Filter set before GroupBy did not survive the transition; got:\n%s", s) + } + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf("@groupby clause missing; got:\n%s", s) + } +} + +// seedOwners inserts owner/pet pairs over conn for the WhereEdge tests. Each +// map entry is one owner owning one pet of the given name; the pet is inserted +// first so the owner's edge links an already-persisted node. It returns an +// owner client bound to conn. +func seedOwners(ctx context.Context, t *testing.T, conn modusgraph.Client, ownerToPet map[string]string) *typed.Client[owner] { + t.Helper() + pets := typed.NewClient[pet](conn) + owners := typed.NewClient[owner](conn) + for ownerName, petName := range ownerToPet { + p := &pet{Name: petName} + if err := pets.Add(ctx, p); err != nil { + t.Fatalf("Add pet %q: %v", petName, err) + } + if err := owners.Add(ctx, &owner{Name: ownerName, Pets: []*pet{p}}); err != nil { + t.Fatalf("Add owner %q: %v", ownerName, err) + } + } + return owners +} + +func TestQuery_WhereEdgeFiltersByEdgeTarget(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + // WhereEdge constrains owners by a scalar of the pet reached over the + // "pets" edge — something a root Filter cannot express. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Fido")`).Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("WhereEdge(pets, name=Fido) returned %d owners, want 2 (Alice, Carol)", len(got)) + } + for _, o := range got { + if o.Name != "Alice" && o.Name != "Carol" { + t.Fatalf("WhereEdge returned %q, want only Fido owners (Alice, Carol)", o.Name) + } + } +} + +func TestQuery_WhereEdgeNoMatchReturnsEmpty(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // No pet is named Nemo: the pre-pass matches zero roots, so Nodes returns + // an empty result — not an error — and never runs the main query. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Nemo")`).Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: unexpected error %v", err) + } + if len(got) != 0 { + t.Fatalf("WhereEdge for an unowned pet name returned %d owners, want 0", len(got)) + } +} + +func TestQuery_WhereEdgeBindsParams(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // The $1 placeholder in a WhereEdge filter binds exactly as it does for Filter. + got, err := owners.Query(ctx).WhereEdge("pets", "eq(name, $1)", "Rex").Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Bob" { + t.Fatalf("WhereEdge(pets, name=$1, Rex) returned %+v, want [Bob]", got) + } +} + +func TestQuery_WhereEdgeCombinesWithFilter(t *testing.T) { + ctx := context.Background() + // Alice and Carol both own a Fido; a root Filter on name narrows to Alice. + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + got, err := owners.Query(ctx). + Filter(`eq(name, "Alice")`). + WhereEdge("pets", `eq(name, "Fido")`). + Nodes() + if err != nil { + t.Fatalf("Filter+WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("Filter(name=Alice)+WhereEdge(pets,name=Fido) returned %+v, want [Alice]", got) + } +} + +func TestQuery_WhereEdgeMultipleConstraintsAnd(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + pets := typed.NewClient[pet](conn) + owners := typed.NewClient[owner](conn) + + // Alice owns both Fido and Rex; Bob owns only Fido. + fido, rex := &pet{Name: "Fido"}, &pet{Name: "Rex"} + for _, p := range []*pet{fido, rex} { + if err := pets.Add(ctx, p); err != nil { + t.Fatalf("Add pet %q: %v", p.Name, err) + } + } + if err := owners.Add(ctx, &owner{Name: "Alice", Pets: []*pet{fido, rex}}); err != nil { + t.Fatalf("Add Alice: %v", err) + } + if err := owners.Add(ctx, &owner{Name: "Bob", Pets: []*pet{fido}}); err != nil { + t.Fatalf("Add Bob: %v", err) + } + + // Two WhereEdge calls AND together: only an owner of BOTH pets survives. + got, err := owners.Query(ctx). + WhereEdge("pets", `eq(name, "Fido")`). + WhereEdge("pets", `eq(name, "Rex")`). + Nodes() + if err != nil { + t.Fatalf("two-WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("WhereEdge(Fido) AND WhereEdge(Rex) returned %+v, want [Alice]", got) + } +} + +func TestQuery_WhereEdgeFirst(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // First runs the pre-pass too: it returns the Rex owner, never a Fido one. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Rex")`).First() + if err != nil { + t.Fatalf("WhereEdge First: %v", err) + } + if got == nil || got.Name != "Bob" { + t.Fatalf("WhereEdge(pets,name=Rex).First() = %+v, want Bob", got) + } + + // First with an edge constraint nothing satisfies is (nil, nil). + none, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Nemo")`).First() + if err != nil { + t.Fatalf("WhereEdge First no-match: unexpected error %v", err) + } + if none != nil { + t.Fatalf("WhereEdge First with no match = %+v, want nil", none) + } +} + +func TestQuery_WhereEdgeIterNodes(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + seen := 0 + for o, err := range owners.Query(ctx).WhereEdge("pets", `eq(name, "Fido")`).IterNodes() { + if err != nil { + t.Fatalf("WhereEdge IterNodes yielded error: %v", err) + } + if o.Name != "Alice" && o.Name != "Carol" { + t.Fatalf("WhereEdge IterNodes yielded %q, want a Fido owner", o.Name) + } + seen++ + } + if seen != 2 { + t.Fatalf("WhereEdge IterNodes streamed %d owners, want 2", seen) + } +} + +func TestQuery_UIDRootsAtNode(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "sprocket", Qty: 3} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + + got, err := c.Query(ctx).UID(w.UID).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "sprocket" { + t.Fatalf("UID query returned %+v, want one widget named sprocket", got) + } +} + +func TestQuery_NodesAndCountReturnsTotal(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for i := 0; i < 3; i++ { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add: %v", err) + } + } + + nodes, count, err := c.Query(ctx).NodesAndCount() + if err != nil { + t.Fatalf("NodesAndCount: %v", err) + } + if count != 3 || len(nodes) != 3 { + t.Fatalf("got count=%d len=%d, want 3 and 3", count, len(nodes)) + } +} + +func TestQuery_AllSetsTraversalDepth(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "deep", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + + // All(1) overrides the default traversal depth for this query; the call + // must chain and the query must still execute and decode. + got, err := c.Query(ctx).All(1).Nodes() + if err != nil { + t.Fatalf("Nodes with All(1): %v", err) + } + if len(got) != 1 { + t.Fatalf("got %d widgets, want 1", len(got)) + } +} + +func TestQuery_StringRendersDQL(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + dql := c.Query(ctx).Filter("eq(name, $1)", "sprocket").String() + if !strings.Contains(dql, "widget") { + t.Fatalf("String() = %q, want it to mention the widget type", dql) + } +} diff --git a/typed/search/merge.go b/typed/search/merge.go new file mode 100644 index 0000000..2546274 --- /dev/null +++ b/typed/search/merge.go @@ -0,0 +1,27 @@ +// Package search provides helpers for assembling fulltext / ranked search +// results across multiple typed query blocks. +package search + +// MergeByID concatenates inputs into a single slice while preserving +// first-seen order and dropping any subsequent occurrence of an ID already +// emitted. The id function extracts a comparable identifier from each row. +// +// MergeByID is intended for use after typed.MultiQuery.Execute, when +// consumers want a single ranked slice from N per-field result sets: +// inputs[0] takes priority, inputs[1] fills in next, etc. A nil result +// indicates no rows survived (the inputs were all empty). +func MergeByID[T any](id func(T) string, inputs ...[]T) []T { + seen := make(map[string]struct{}) + var out []T + for _, in := range inputs { + for _, row := range in { + k := id(row) + if _, dup := seen[k]; dup { + continue + } + seen[k] = struct{}{} + out = append(out, row) + } + } + return out +} diff --git a/typed/search/merge_test.go b/typed/search/merge_test.go new file mode 100644 index 0000000..e4e8583 --- /dev/null +++ b/typed/search/merge_test.go @@ -0,0 +1,86 @@ +package search_test + +import ( + "reflect" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/search" +) + +type rec struct { + ID string + Tag string +} + +func id(r rec) string { return r.ID } + +func TestMergeByID(t *testing.T) { + cases := []struct { + name string + inputs [][]rec + want []rec + }{ + { + name: "empty inputs returns nil", + inputs: nil, + want: nil, + }, + { + name: "single empty slice returns nil", + inputs: [][]rec{{}}, + want: nil, + }, + { + name: "single slice returns it as-is", + inputs: [][]rec{{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }}, + want: []rec{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }, + }, + { + name: "two slices merge in priority order", + inputs: [][]rec{ + {{ID: "a", Tag: "name"}}, + {{ID: "b", Tag: "desc"}}, + }, + want: []rec{ + {ID: "a", Tag: "name"}, + {ID: "b", Tag: "desc"}, + }, + }, + { + name: "duplicate ID keeps first-seen entry", + inputs: [][]rec{ + {{ID: "a", Tag: "name"}}, + {{ID: "a", Tag: "desc"}, {ID: "b", Tag: "desc"}}, + }, + want: []rec{ + {ID: "a", Tag: "name"}, + {ID: "b", Tag: "desc"}, + }, + }, + { + name: "intra-slice duplicates dedup too", + inputs: [][]rec{ + {{ID: "a", Tag: "1"}, {ID: "a", Tag: "2"}, {ID: "b", Tag: "1"}}, + }, + want: []rec{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := search.MergeByID(id, c.inputs...) + if !reflect.DeepEqual(got, c.want) { + t.Fatalf("got %v, want %v", got, c.want) + } + }) + } +} diff --git a/typed/tracing.go b/typed/tracing.go new file mode 100644 index 0000000..8a456ed --- /dev/null +++ b/typed/tracing.go @@ -0,0 +1,58 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "reflect" +) + +// Span is a tracing span for a single database operation. End is called once, +// with the operation's final error (nil on success). +type Span interface { + End(err error) +} + +// Tracer starts a Span around a typed-layer database operation. The typed +// client calls the installed Tracer for every DB call; the default is a no-op, +// so the typed package itself carries no tracing dependency. Install a real +// tracer — for example github.com/mlwelles/modusgraph-telemetry's OpenTelemetry +// tracer — with SetTracer. +type Tracer interface { + // StartSpan begins a span for operation op (for example "get") on the named + // collection, returning a context carrying the span and the Span itself. + StartSpan(ctx context.Context, op, collection string) (context.Context, Span) +} + +type noopSpan struct{} + +func (noopSpan) End(error) {} + +type noopTracer struct{} + +func (noopTracer) StartSpan(ctx context.Context, _, _ string) (context.Context, Span) { + return ctx, noopSpan{} +} + +// tracer is the process-wide tracer the typed package uses. It is a no-op until +// a host installs one via SetTracer. +var tracer Tracer = noopTracer{} + +// SetTracer installs the process-wide tracer for typed-layer DB spans. Passing +// nil restores the no-op tracer. Install once during startup; it is not safe to +// call concurrently with active queries. +func SetTracer(t Tracer) { + if t == nil { + t = noopTracer{} + } + tracer = t +} + +// entityName returns the unqualified Go type name of T (for example "Resource"), +// used as the db.collection.name span attribute. +func entityName[T any]() string { + return reflect.TypeFor[T]().Name() +} diff --git a/typed/tracing_test.go b/typed/tracing_test.go new file mode 100644 index 0000000..d9aab78 --- /dev/null +++ b/typed/tracing_test.go @@ -0,0 +1,47 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "testing" +) + +func TestSetTracer_InstallsAndResets(t *testing.T) { + t.Cleanup(func() { SetTracer(nil) }) + + rec := &recordingTracer{} + SetTracer(rec) + + _, span := tracer.StartSpan(context.Background(), "get", "Widget") + span.End(nil) + + if rec.op != "get" || rec.collection != "Widget" { + t.Fatalf("installed tracer not invoked: %+v", rec) + } + if !rec.ended { + t.Fatal("span.End was not called") + } + + // nil restores the no-op tracer, which must not panic. + SetTracer(nil) + _, span = tracer.StartSpan(context.Background(), "x", "Y") + span.End(nil) +} + +type recordingTracer struct { + op, collection string + ended bool +} + +func (r *recordingTracer) StartSpan(ctx context.Context, op, collection string) (context.Context, Span) { + r.op, r.collection = op, collection + return ctx, &recordingSpan{r} +} + +type recordingSpan struct{ r *recordingTracer } + +func (s *recordingSpan) End(error) { s.r.ended = true } From b477c28daec118d7eb2da55c3ed194891931381a Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 13:11:17 -0400 Subject: [PATCH 02/13] feat: aborted-transaction retry policy, runner, and client integration Add RetryPolicy / DefaultRetryPolicy and a runner that re-executes a function on aborted Dgraph transactions with exponential backoff (retry.go), exposed on the client via a WithRetry method. --- client.go | 3 + retry.go | 96 ++++++++++++++++++++ retry_internal_test.go | 68 ++++++++++++++ retry_test.go | 197 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 364 insertions(+) create mode 100644 retry.go create mode 100644 retry_internal_test.go create mode 100644 retry_test.go diff --git a/client.go b/client.go index be9813b..e4bb263 100644 --- a/client.go +++ b/client.go @@ -87,6 +87,9 @@ type Client interface { // DgraphClient returns a gRPC Dgraph client from the connection pool and a cleanup function. // The cleanup function must be called when finished with the client to return it to the pool. DgraphClient() (*dgo.Dgraph, func(), error) + + // WithRetry executes fn, retrying on aborted transactions per policy. + WithRetry(ctx context.Context, policy RetryPolicy, fn func() error) error } const ( diff --git a/retry.go b/retry.go new file mode 100644 index 0000000..9b49fda --- /dev/null +++ b/retry.go @@ -0,0 +1,96 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "context" + "errors" + "math/rand/v2" + "time" + + "github.com/dgraph-io/dgo/v250" +) + +// RetryPolicy controls how WithRetry handles aborted transactions. +// Modeled after dgraph4j's RetryPolicy: exponential backoff with jitter. +type RetryPolicy struct { + // MaxRetries is the maximum number of retry attempts after the initial try. + MaxRetries int + + // BaseDelay is the initial delay before the first retry. + // Subsequent delays grow exponentially: BaseDelay * 2^attempt. + BaseDelay time.Duration + + // MaxDelay caps the backoff duration. No single delay exceeds this. + MaxDelay time.Duration + + // Jitter adds randomness to each delay to prevent thundering herd. + // Expressed as a fraction of the computed delay (e.g. 0.1 = 10%). + Jitter float64 +} + +// DefaultRetryPolicy mirrors dgraph4j's defaults: +// 5 retries, 100ms base delay, 5s max delay, 10% jitter. +var DefaultRetryPolicy = RetryPolicy{ + MaxRetries: 10, + BaseDelay: 100 * time.Millisecond, + MaxDelay: 5 * time.Second, + Jitter: 0.1, +} + +// delay computes the backoff duration for a given attempt (0-indexed). +// Formula: min(BaseDelay * 2^attempt, MaxDelay) + random(0, delay * Jitter) +func (p RetryPolicy) delay(attempt int) time.Duration { + d := p.BaseDelay * time.Duration(1< p.MaxDelay { + d = p.MaxDelay + } + if p.Jitter > 0 { + d += time.Duration(float64(d) * p.Jitter * rand.Float64()) + } + return d +} + +// WithRetry executes fn, retrying on aborted transactions according to policy. +// +// This is an opt-in mechanism modeled after dgraph4j's client.withRetry(). +// The caller wraps their mutation logic in fn; WithRetry handles creating +// fresh attempts with exponential backoff when Dgraph returns a transaction +// abort due to concurrent conflicts. +// +// fn is called at least once. On each aborted-transaction error, WithRetry +// waits according to the policy's backoff schedule and calls fn again, up to +// policy.MaxRetries additional times. Non-abort errors are returned immediately. +// +// The context is checked between retries; if cancelled during a backoff sleep, +// the context error is returned. +// +// Usage: +// +// err := client.WithRetry(ctx, modusgraph.DefaultRetryPolicy, func() error { +// return client.Insert(ctx, &entity) +// }) +func (c client) WithRetry(ctx context.Context, policy RetryPolicy, fn func() error) error { + for attempt := range policy.MaxRetries + 1 { + err := fn() + if err == nil { + return nil + } + if !errors.Is(err, dgo.ErrAborted) || attempt >= policy.MaxRetries { + return err + } + d := policy.delay(attempt) + c.logger.V(1).Info("Transaction aborted, retrying", + "attempt", attempt+1, "maxRetries", policy.MaxRetries, "delay", d) + select { + case <-time.After(d): + case <-ctx.Done(): + return ctx.Err() + } + } + // Unreachable: the loop runs MaxRetries+1 times and returns on every path. + panic("unreachable") +} diff --git a/retry_internal_test.go b/retry_internal_test.go new file mode 100644 index 0000000..ce6bd2b --- /dev/null +++ b/retry_internal_test.go @@ -0,0 +1,68 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRetryPolicyDelayExponentialGrowth(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + Jitter: 0, + } + + assert.Equal(t, 100*time.Millisecond, p.delay(0)) + assert.Equal(t, 200*time.Millisecond, p.delay(1)) + assert.Equal(t, 400*time.Millisecond, p.delay(2)) + assert.Equal(t, 800*time.Millisecond, p.delay(3)) + assert.Equal(t, 1600*time.Millisecond, p.delay(4)) +} + +func TestRetryPolicyDelayMaxCap(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 1 * time.Second, + MaxDelay: 3 * time.Second, + Jitter: 0, + } + + assert.Equal(t, 1*time.Second, p.delay(0)) + assert.Equal(t, 2*time.Second, p.delay(1)) + assert.Equal(t, 3*time.Second, p.delay(2)) + assert.Equal(t, 3*time.Second, p.delay(3)) + assert.Equal(t, 3*time.Second, p.delay(10)) +} + +func TestRetryPolicyDelayWithJitter(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + Jitter: 0.5, + } + + for range 100 { + d := p.delay(0) + assert.GreaterOrEqual(t, d, 100*time.Millisecond, "delay should be at least base") + assert.LessOrEqual(t, d, 150*time.Millisecond, "delay should not exceed base + 50% jitter") + } +} + +func TestRetryPolicyDelayZeroJitter(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + Jitter: 0, + } + + for range 10 { + assert.Equal(t, 100*time.Millisecond, p.delay(0)) + assert.Equal(t, 200*time.Millisecond, p.delay(1)) + } +} diff --git a/retry_test.go b/retry_test.go new file mode 100644 index 0000000..4cb0d86 --- /dev/null +++ b/retry_test.go @@ -0,0 +1,197 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "fmt" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/matthewmcneely/modusgraph" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RetryEntity is a test struct with a unique index to provoke transaction conflicts. +type RetryEntity struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=term,exact upsert"` + Value int `json:"value,omitempty"` +} + +// TestConcurrentInsertsWithRetry verifies that WithRetry handles aborted +// transactions from concurrent inserts. Without WithRetry, concurrent inserts +// on the same predicate index would fail with dgo.ErrAborted. +func TestConcurrentInsertsWithRetry(t *testing.T) { + testCases := []struct { + name string + uri string + skip bool + }{ + { + name: "FileURI", + uri: "file://" + GetTempDir(t), + }, + { + name: "DgraphURI", + uri: "dgraph://" + os.Getenv("MODUSGRAPH_TEST_ADDR"), + skip: os.Getenv("MODUSGRAPH_TEST_ADDR") == "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.skip { + t.Skipf("Skipping %s: MODUSGRAPH_TEST_ADDR not set", tc.name) + return + } + + client, cleanup := CreateTestClient(t, tc.uri) + defer cleanup() + + ctx := context.Background() + const numWorkers = 8 + const entitiesPerWorker = 10 + + var succeeded atomic.Int64 + var wg sync.WaitGroup + + for w := range numWorkers { + wg.Add(1) + go func() { + defer wg.Done() + for i := range entitiesPerWorker { + entity := &RetryEntity{ + Name: fmt.Sprintf("entity-%d-%d", w, i), + Value: w*entitiesPerWorker + i, + } + err := client.WithRetry(ctx, modusgraph.DefaultRetryPolicy, func() error { + return client.Insert(ctx, entity) + }) + if err != nil { + t.Errorf("worker %d entity %d: %v", w, i, err) + return + } + succeeded.Add(1) + } + }() + } + wg.Wait() + + total := int64(numWorkers * entitiesPerWorker) + require.Equal(t, total, succeeded.Load(), + "all concurrent inserts should succeed with retry") + }) + } +} + +// TestWithRetryContextCancellation verifies that WithRetry respects context +// cancellation during backoff sleeps. +func TestWithRetryContextCancellation(t *testing.T) { + uri := "file://" + GetTempDir(t) + client, cleanup := CreateTestClient(t, uri) + defer cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + // Use a policy with a long delay so the context expires during backoff. + slowPolicy := modusgraph.RetryPolicy{ + MaxRetries: 10, + BaseDelay: 1 * time.Second, + MaxDelay: 5 * time.Second, + Jitter: 0, + } + + callCount := 0 + err := client.WithRetry(ctx, slowPolicy, func() error { + callCount++ + // Always return an error that looks like an abort to trigger retry. + // We simulate this by inserting a duplicate to get a UniqueError, + // but that won't be retried. Instead, use a real insert to a fresh + // entity so the first call succeeds. + // Actually, to test the cancellation path we need the fn to always + // fail with an aborted error. Since we can't easily manufacture + // dgo.ErrAborted, test that context cancellation returns ctx.Err() + // by having fn block until context is done. + <-ctx.Done() + return ctx.Err() + }) + + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Equal(t, 1, callCount, "fn should be called once before context expires") +} + +// TestRetryPolicyDelay verifies the exponential backoff calculation. +func TestRetryPolicyDelay(t *testing.T) { + // Use the public struct fields to verify delay behavior indirectly + // by checking that DefaultRetryPolicy has the expected values. + p := modusgraph.DefaultRetryPolicy + assert.Equal(t, 10, p.MaxRetries) + assert.Equal(t, 100*time.Millisecond, p.BaseDelay) + assert.Equal(t, 5*time.Second, p.MaxDelay) + assert.InDelta(t, 0.1, p.Jitter, 0.001) +} + +// TestWithRetryNonAbortError verifies that non-abort errors are returned +// immediately without any retry. +func TestWithRetryNonAbortError(t *testing.T) { + uri := "file://" + GetTempDir(t) + client, cleanup := CreateTestClient(t, uri) + defer cleanup() + + callCount := 0 + expectedErr := fmt.Errorf("not an abort error") + + err := client.WithRetry(context.Background(), modusgraph.DefaultRetryPolicy, func() error { + callCount++ + return expectedErr + }) + + assert.ErrorIs(t, err, expectedErr) + assert.Equal(t, 1, callCount, "non-abort errors should not trigger retry") +} + +// TestWithRetrySucceedsFirstTry verifies that WithRetry returns nil +// when fn succeeds on the first call. +func TestWithRetrySucceedsFirstTry(t *testing.T) { + uri := "file://" + GetTempDir(t) + client, cleanup := CreateTestClient(t, uri) + defer cleanup() + + callCount := 0 + err := client.WithRetry(context.Background(), modusgraph.DefaultRetryPolicy, func() error { + callCount++ + return nil + }) + + assert.NoError(t, err) + assert.Equal(t, 1, callCount) +} + +// TestWithRetryMaxRetriesZero verifies that MaxRetries=0 calls fn exactly once +// and returns any error without retrying. +func TestWithRetryMaxRetriesZero(t *testing.T) { + uri := "file://" + GetTempDir(t) + client, cleanup := CreateTestClient(t, uri) + defer cleanup() + + policy := modusgraph.RetryPolicy{MaxRetries: 0} + callCount := 0 + + err := client.WithRetry(context.Background(), policy, func() error { + callCount++ + return fmt.Errorf("always fails") + }) + + assert.Error(t, err) + assert.Equal(t, 1, callCount, "MaxRetries=0 should call fn exactly once") +} From b74cec750218c51c12e8282a86ebdbce1d8ebc19 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 13:48:08 -0400 Subject: [PATCH 03/13] feat: recognize generated schema types via SchemaTypeName + UnwrapSchema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the Schema interface (SchemaTypeName), the UnwrapSchema reflection helper, and the DgraphMapper interface (record.go). The client unwraps schema-defining values at the mutation and query boundary so generated wrapper types route to their backing schema struct. Plain structs do not implement Schema and are unaffected — UnwrapSchema is identity for them. --- client.go | 9 ++++ record.go | 58 ++++++++++++++++++++++++ record_test.go | 117 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 184 insertions(+) create mode 100644 record.go create mode 100644 record_test.go diff --git a/client.go b/client.go index be9813b..14e38ee 100644 --- a/client.go +++ b/client.go @@ -486,6 +486,7 @@ func (c client) validateStruct(ctx context.Context, obj any) error { // Insert implements inserting an object or slice of objects in the database. // Passed object must be a pointer to a struct with appropriate dgraph tags. func (c client) Insert(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before insertion if err := c.validateStruct(ctx, obj); err != nil { return err @@ -503,6 +504,7 @@ func (c client) Insert(ctx context.Context, obj any) error { // // Deprecated: InsertRaw is now identical to Insert. Use Insert instead. func (c client) InsertRaw(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before insertion if err := c.validateStruct(ctx, obj); err != nil { return err @@ -518,6 +520,7 @@ func (c client) InsertRaw(ctx context.Context, obj any) error { // to be used for upserting. If none are specified, the first predicate with the `upsert` tag // will be used. func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error { + obj = UnwrapSchema(obj) // Validate struct before upsert if err := c.validateStruct(ctx, obj); err != nil { return err @@ -531,6 +534,7 @@ func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error // Update implements updating an existing object in the database. // Passed object must be a pointer to a struct. func (c client) Update(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before update if err := c.validateStruct(ctx, obj); err != nil { return err @@ -557,6 +561,7 @@ func (c client) Delete(ctx context.Context, uids []string) error { // Get implements retrieving a single object by its UID. // Passed object must be a pointer to a struct. func (c client) Get(ctx context.Context, obj any, uid string) error { + obj = UnwrapSchema(obj) err := checkPointer(obj) if err != nil { return err @@ -575,6 +580,7 @@ func (c client) Get(ctx context.Context, obj any, uid string) error { // Returns a *dg.Query that can be further refined with filters, pagination, etc. // The returned query will be limited to the maximum number of edges specified in the options. func (c client) Query(ctx context.Context, model any) *dg.Query { + model = UnwrapSchema(model) client, err := c.pool.get() if err != nil { return nil @@ -590,6 +596,9 @@ func (c client) Query(ctx context.Context, model any) *dg.Query { // If any object contains SimString fields tagged `dgraph:"embedding"`, the // corresponding shadow float32vector predicates (__vec) are also registered. func (c client) UpdateSchema(ctx context.Context, obj ...any) error { + for i := range obj { + obj[i] = UnwrapSchema(obj[i]) + } dgClient, err := c.pool.get() if err != nil { c.logger.Error(err, "Failed to get client from pool") diff --git a/record.go b/record.go new file mode 100644 index 0000000..015c587 --- /dev/null +++ b/record.go @@ -0,0 +1,58 @@ +package modusgraph + +import "reflect" + +// Schema identifies a value as a record of a generated schema-defining type. +// modusgraph-gen-emitted schema structs implement this via a generated +// SchemaTypeName() method that returns the canonical entity name +// (e.g. "Studio"). The interface is intentionally minimal — a single method +// returning a useful piece of metadata. +// +// Plain user structs (not emitted by modusgraph-gen) do not implement Schema +// and are unaffected by the modusgraph.Client routing it enables; they pass +// through to the existing reflection-based dgman pipeline exactly as before. +type Schema interface { + SchemaTypeName() string +} + +// UnwrapSchema returns the schema-defining record contained in obj. If obj +// is nil, it is returned as-is. If obj is already a Schema, it is returned +// as-is. If obj exposes an Unwrap() method whose return value satisfies +// Schema, that return is substituted. Otherwise obj is returned unchanged. +// +// This is the bridge between modusgraph-gen-emitted wrapper types and the +// rest of modusgraph.Client. It is purely additive: types that don't +// implement Schema and don't have an Unwrap() method (i.e. existing +// modusgraph users' plain structs) pass through untouched. +// +// Note on errors.Unwrap overlap: Go's errors package uses Unwrap() error +// as the standard "give me the wrapped thing" method. UnwrapSchema's +// secondary check (the returned value must itself implement Schema) means +// an error wrapper is not mistaken for a modusgraph wrapper — the +// reflection probe finds Unwrap(), calls it, gets an error, fails the +// Schema check, and returns the original obj. +func UnwrapSchema(obj any) any { + if obj == nil { + return obj + } + if _, ok := obj.(Schema); ok { + return obj + } + v := reflect.ValueOf(obj) + if !v.IsValid() { + return obj + } + m := v.MethodByName("Unwrap") + if !m.IsValid() { + return obj + } + mt := m.Type() + if mt.NumIn() != 0 || mt.NumOut() != 1 { + return obj + } + inner := m.Call(nil)[0].Interface() + if _, ok := inner.(Schema); ok { + return inner + } + return obj +} diff --git a/record_test.go b/record_test.go new file mode 100644 index 0000000..1f6ef72 --- /dev/null +++ b/record_test.go @@ -0,0 +1,117 @@ +package modusgraph + +import ( + "errors" + "testing" +) + +type fakeRecord struct{ name string } + +func (f *fakeRecord) SchemaTypeName() string { return f.name } + +type fakeWrapper struct{ inner *fakeRecord } + +func (w *fakeWrapper) Unwrap() *fakeRecord { return w.inner } + +type fakeNonSchema struct{ X string } + +func TestUnwrapSchema_PassthroughForPlainStruct(t *testing.T) { + in := &fakeNonSchema{X: "hi"} + out := UnwrapSchema(in) + if out != any(in) { + t.Fatalf("expected passthrough, got %T", out) + } +} + +func TestUnwrapSchema_PassthroughForSchemaStruct(t *testing.T) { + in := &fakeRecord{name: "Studio"} + out := UnwrapSchema(in) + if out != any(in) { + t.Fatalf("expected passthrough for direct Schema, got %T", out) + } +} + +func TestUnwrapSchema_UnwrapsWrapper(t *testing.T) { + inner := &fakeRecord{name: "Studio"} + w := &fakeWrapper{inner: inner} + out := UnwrapSchema(w) + if out != any(inner) { + t.Fatalf("expected unwrapped inner, got %T (%v)", out, out) + } +} + +func TestUnwrapSchema_IgnoresErrorsUnwrap(t *testing.T) { + // errors.New("x") has no Unwrap; wrap one to get something with Unwrap() error. + inner := errors.New("inner") + outer := &wrappedErr{err: inner} + out := UnwrapSchema(outer) + if out != any(outer) { + t.Fatalf("expected passthrough for error wrapper, got %T", out) + } +} + +type wrappedErr struct{ err error } + +func (w *wrappedErr) Error() string { return w.err.Error() } +func (w *wrappedErr) Unwrap() error { return w.err } + +func TestUnwrapSchema_NilInput(t *testing.T) { + if out := UnwrapSchema(nil); out != nil { + t.Fatalf("expected nil for nil input, got %v", out) + } +} + +// recordingClient is the minimal surface needed to verify that wrappers +// passed to the Client interface get unwrapped before reaching internal +// reflection. It records whatever it received and returns nil. Each method +// applies obj = UnwrapSchema(obj) at the top, mirroring the patch landing +// in this task. +type recordingClient struct { + seen []any +} + +func (c *recordingClient) capture(obj any) any { + obj = UnwrapSchema(obj) + c.seen = append(c.seen, obj) + return obj +} + +func TestUnwrapSchema_CaptureForwardsInner(t *testing.T) { + inner := &fakeRecord{name: "Studio"} + w := &fakeWrapper{inner: inner} + c := &recordingClient{} + got := c.capture(w) + if got != any(inner) { + t.Fatalf("expected inner record, got %T (%v)", got, got) + } + if len(c.seen) != 1 || c.seen[0] != any(inner) { + t.Fatalf("expected recording to hold inner record, got %v", c.seen) + } +} + +func TestUnwrapSchema_CapturePassthroughForPlain(t *testing.T) { + plain := &fakeNonSchema{X: "y"} + c := &recordingClient{} + got := c.capture(plain) + if got != any(plain) { + t.Fatalf("expected plain struct passthrough, got %T", got) + } +} + +func TestUnwrapSchema_VariadicUnwrapsEachElement(t *testing.T) { + innerA := &fakeRecord{name: "Studio"} + innerB := &fakeRecord{name: "Film"} + templates := []any{ + &fakeWrapper{inner: innerA}, + innerB, // already a Schema; passthrough + } + for i, obj := range templates { + templates[i] = UnwrapSchema(obj) + } + if templates[0] != any(innerA) { + t.Fatalf("template[0]: expected innerA, got %T", templates[0]) + } + if templates[1] != any(innerB) { + t.Fatalf("template[1]: expected innerB (passthrough), got %T", templates[1]) + } +} From a4997cf7a0ae6b8b2c6d75d09f068c13bd5879c0 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:00:04 -0400 Subject: [PATCH 04/13] ci: drop redundant Dgraph standalone from -short unit job The unit-test job runs `go test -short`, which skips every test that needs a live Dgraph. Standing up a dgraph/standalone container (and setting MODUSGRAPH_TEST_ADDR) therefore adds setup the job never uses. Remove both; the integration and load suites keep their own dedicated jobs. --- .github/workflows/ci-go-unit-tests.yaml | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/.github/workflows/ci-go-unit-tests.yaml b/.github/workflows/ci-go-unit-tests.yaml index 3623538..31878ab 100644 --- a/.github/workflows/ci-go-unit-tests.yaml +++ b/.github/workflows/ci-go-unit-tests.yaml @@ -39,22 +39,5 @@ jobs: go-version: 1.25.0 cache-dependency-path: go.sum - - name: Set up Dgraph - if: matrix.os == 'linux' - run: | - docker run -d --name dgraph-standalone -p 9080:9080 -p 8080:8080 dgraph/standalone:latest - echo "Waiting for Dgraph to be ready..." - for i in {1..30}; do - if curl -s http://localhost:8080/health > /dev/null; then - echo "Dgraph is ready!" - break - fi - echo "Attempt $i: Dgraph not ready, waiting..." - sleep 2 - done - sleep 5 - - name: Run Unit Tests - env: - MODUSGRAPH_TEST_ADDR: ${{ matrix.os == 'linux' && 'localhost:9080' || '' }} run: go test -short -race -v . From 616065df67aa02200717d6ad71e3f4f75b641720 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:00:09 -0400 Subject: [PATCH 05/13] chore: ignore IDE dirs, query binary, benchmark output, worktrees Add common local artifacts to .gitignore: editor config (.idea/, .vscode/), the built ./query binary, load_test benchmark JSON, and git worktrees. --- .gitignore | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.gitignore b/.gitignore index a63304e..b75ce0d 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,16 @@ go.work.sum .env cpu_profile.prof + +# IDE config +.idea/ +.vscode/ + +# Built query binary +/query + +# Benchmark result files +load_test/*.json + +# git worktrees +.worktrees/ From eee3d018f8f11445e93abbe2e2804529a541993d Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:06:19 -0400 Subject: [PATCH 06/13] feat: WithGRPCDialOption for custom gRPC dial settings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds WithGRPCDialOption(opt grpc.DialOption), a general escape hatch for gRPC dial settings the dedicated options do not cover — TLS transport credentials, interceptors, keepalive, and so on — on remote (dgraph://) connections. The existing WithMaxRecvMsgSize is folded into the same dial-option assembly, so the two compose cleanly, and the client dedup key now counts the custom dial options so differently-configured clients are not merged. No change for embedded (file://) URIs. --- client.go | 33 ++++++++++++++++++++++++++++----- dial_options_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) create mode 100644 dial_options_test.go diff --git a/client.go b/client.go index be9813b..7834db2 100644 --- a/client.go +++ b/client.go @@ -124,6 +124,7 @@ type clientOptions struct { maxEdgeTraversal int cacheSizeMB int maxRecvMsgSize int + grpcDialOptions []grpc.DialOption namespace string logger logr.Logger validator StructValidator @@ -189,6 +190,18 @@ func WithMaxRecvMsgSize(size int) ClientOpt { } } +// WithGRPCDialOption appends a custom grpc.DialOption applied when opening a +// remote (dgraph://) connection. It is the general escape hatch for gRPC dial +// settings the dedicated options do not cover — TLS transport credentials, +// interceptors, keepalive parameters, and so on. May be supplied multiple +// times; the options are applied in the order given, after any option implied +// by WithMaxRecvMsgSize. Ignored for embedded (file://) URIs. +func WithGRPCDialOption(opt grpc.DialOption) ClientOpt { + return func(o *clientOptions) { + o.grpcDialOptions = append(o.grpcDialOptions, opt) + } +} + // WithValidator sets a validator instance for struct validation. // The validator will be used to validate structs before insert, upsert, and update operations. // If no validator is provided, validation will be skipped. @@ -279,16 +292,26 @@ func NewClient(uri string, opts ...ClientOpt) (Client, error) { client.logger.V(2).Info("Opening new Dgraph connection", "uri", uri) return dgo.Open(uri) } + // Assemble any custom gRPC dial options. maxRecvMsgSize is folded + // into the same mechanism as WithGRPCDialOption so the two compose. + var dialOpts []grpc.DialOption if options.maxRecvMsgSize > 0 { + dialOpts = append(dialOpts, + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(options.maxRecvMsgSize))) + } + dialOpts = append(dialOpts, options.grpcDialOptions...) + if len(dialOpts) > 0 { endpoint, dgoOpts, err := parseDgraphURI(uri) if err != nil { return nil, err } - dgoOpts = append(dgoOpts, dgo.WithGrpcOption( - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(options.maxRecvMsgSize)))) + for _, opt := range dialOpts { + dgoOpts = append(dgoOpts, dgo.WithGrpcOption(opt)) + } factory = func() (*dgo.Dgraph, error) { client.logger.V(2).Info("Opening new Dgraph connection", - "uri", uri, "maxRecvMsgSize", options.maxRecvMsgSize) + "uri", uri, "maxRecvMsgSize", options.maxRecvMsgSize, + "grpcDialOptions", len(options.grpcDialOptions)) return dgo.NewClient(endpoint, dgoOpts...) } } @@ -430,9 +453,9 @@ func (c client) key() string { if c.options.embeddingProvider != nil { embeddingKey = fmt.Sprintf("%p", c.options.embeddingProvider) } - return fmt.Sprintf("%s:%t:%d:%d:%d:%d:%s:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize, + return fmt.Sprintf("%s:%t:%d:%d:%d:%d:%s:%s:%s:%d", c.uri, c.options.autoSchema, c.options.poolSize, c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.maxRecvMsgSize, - c.options.namespace, validatorKey, embeddingKey) + c.options.namespace, validatorKey, embeddingKey, len(c.options.grpcDialOptions)) } // embeddingProvider implements the embeddingClient interface, exposing the diff --git a/dial_options_test.go b/dial_options_test.go new file mode 100644 index 0000000..c64e257 --- /dev/null +++ b/dial_options_test.go @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "testing" + + "google.golang.org/grpc" +) + +func TestWithGRPCDialOptionAppends(t *testing.T) { + var o clientOptions + WithGRPCDialOption(grpc.WithUserAgent("a"))(&o) + WithGRPCDialOption(grpc.WithUserAgent("b"))(&o) + if got := len(o.grpcDialOptions); got != 2 { + t.Fatalf("expected 2 dial options, got %d", got) + } +} + +func TestKeyDistinguishesGRPCDialOptions(t *testing.T) { + base := client{uri: "dgraph://localhost:9080"} + withOpt := client{uri: "dgraph://localhost:9080"} + WithGRPCDialOption(grpc.WithUserAgent("x"))(&withOpt.options) + if base.key() == withOpt.key() { + t.Fatal("client.key() must differ when grpcDialOptions differ, else clients dedup incorrectly") + } +} From bd16559b97c3206c9f4897318ed994b6b9c2cb7c Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:11:23 -0400 Subject: [PATCH 07/13] feat: AlterSchema, dropPredicate, and embedded DropAttr MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds raw schema-DDL primitives that complement UpdateSchema's object-template inference: - Client.AlterSchema(ctx, schema) applies a raw DQL schema string directly, giving full control over predicate types, indexes, and directives — useful for migrations that declare predicates no Go type models yet. - Engine.dropPredicate deletes a single predicate (and its data) from the embedded engine via posting.DeletePredicate. - embedded_client.go routes an Alter carrying DropAttr to dropPredicate, so the embedded path matches a remote Dgraph cluster's DropAttr behavior. TestDropPredicateEmbedded exercises the full declare/insert/drop cycle against the embedded engine. --- client.go | 19 +++++++++++ embedded_client.go | 6 ++++ engine.go | 19 +++++++++++ schema_ddl_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 128 insertions(+) create mode 100644 schema_ddl_test.go diff --git a/client.go b/client.go index be9813b..2189a3a 100644 --- a/client.go +++ b/client.go @@ -69,6 +69,12 @@ type Client interface { // Pass one or more objects that will be used as templates for the schema. UpdateSchema(context.Context, ...any) error + // AlterSchema applies a raw Dgraph Schema Definition Language string directly, + // bypassing the object-template inference of UpdateSchema. Use it when you need + // full control over predicate types, indexes, and directives — for example, + // schema migrations that declare predicates no Go type models yet. + AlterSchema(ctx context.Context, schema string) error + // GetSchema retrieves the current schema definition from the database. // Returns a string containing the full schema in Dgraph Schema Definition Language. GetSchema(context.Context) (string, error) @@ -585,6 +591,19 @@ func (c client) Query(ctx context.Context, model any) *dg.Query { return txn.Get(model).All(c.options.maxEdgeTraversal) } +// AlterSchema applies a raw DQL schema string directly via Dgraph Alter, +// without the object-template inference performed by UpdateSchema. +func (c client) AlterSchema(ctx context.Context, schema string) error { + dgClient, err := c.pool.get() + if err != nil { + c.logger.Error(err, "Failed to get client from pool") + return err + } + defer c.pool.put(dgClient) + + return dgClient.Alter(ctx, &api.Operation{Schema: schema}) +} + // UpdateSchema implements updating the Dgraph schema. Pass one or more // objects that will be used to generate the schema. // If any object contains SimString fields tagged `dgraph:"embedding"`, the diff --git a/embedded_client.go b/embedded_client.go index 329f4bf..fbe993d 100644 --- a/embedded_client.go +++ b/embedded_client.go @@ -146,6 +146,12 @@ func (c *embeddedDgraphClient) Alter( } return &api.Payload{}, nil } + if in.DropAttr != "" { + if err := c.engine.dropPredicate(ctx, c.ns, in.DropAttr); err != nil { + return nil, err + } + return &api.Payload{}, nil + } if in.Schema != "" { if err := c.engine.alterSchema(ctx, c.ns, in.Schema); err != nil { return nil, err diff --git a/engine.go b/engine.go index d9d236d..34e7f41 100644 --- a/engine.go +++ b/engine.go @@ -271,6 +271,25 @@ func (engine *Engine) dropData(ctx context.Context, ns *Namespace) error { return nil } +// dropPredicate deletes a single predicate (and its data) from the embedded +// engine — the in-process equivalent of a gRPC Alter with DropAttr set. +func (engine *Engine) dropPredicate(ctx context.Context, ns *Namespace, pred string) error { + engine.mutex.Lock() + defer engine.mutex.Unlock() + + if !engine.isOpen.Load() { + return ErrClosedEngine + } + + startTs, err := engine.z.nextTs() + if err != nil { + return err + } + + nsAttr := x.NamespaceAttr(ns.ID(), pred) + return posting.DeletePredicate(ctx, nsAttr, startTs) +} + func (engine *Engine) alterSchema(ctx context.Context, ns *Namespace, sch string) error { engine.mutex.Lock() defer engine.mutex.Unlock() diff --git a/schema_ddl_test.go b/schema_ddl_test.go new file mode 100644 index 0000000..8315dd2 --- /dev/null +++ b/schema_ddl_test.go @@ -0,0 +1,84 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "os" + "testing" + + "github.com/dgraph-io/dgo/v250/protos/api" + "github.com/stretchr/testify/require" +) + +// TestDropPredicateEmbedded exercises the schema-DDL surface end-to-end: +// Client.AlterSchema declares a raw predicate, and a gRPC Alter with DropAttr +// routes through embedded_client.go's DropAttr arm into engine.dropPredicate. +func TestDropPredicateEmbedded(t *testing.T) { + testCases := []struct { + name string + uri string + skip bool + }{ + { + name: "DropPredicateWithFileURI", + uri: "file://" + GetTempDir(t), + }, + { + name: "DropPredicateWithDgraphURI", + uri: "dgraph://" + os.Getenv("MODUSGRAPH_TEST_ADDR"), + skip: os.Getenv("MODUSGRAPH_TEST_ADDR") == "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.skip { + t.Skipf("Skipping %s: MODUSGRAPH_TEST_ADDR not set", tc.name) + return + } + + client, cleanup := CreateTestClient(t, tc.uri) + defer cleanup() + + ctx := context.Background() + + // Declare an indexed string predicate and insert a node carrying it. + err := client.AlterSchema(ctx, "dropme: string @index(exact) .") + require.NoError(t, err, "AlterSchema should succeed") + + dg, dgCleanup, err := client.DgraphClient() + require.NoError(t, err, "DgraphClient should succeed") + defer dgCleanup() + + _, err = dg.NewTxn().Mutate(ctx, &api.Mutation{ + SetJson: []byte(`[{"dropme":"hello"}]`), + CommitNow: true, + }) + require.NoError(t, err, "mutate should succeed") + + // Confirm the predicate is present before the drop. + raw, err := client.QueryRaw(ctx, `{ q(func: has(dropme)) { c: count(uid) } }`, nil) + require.NoError(t, err, "count query should succeed") + require.Contains(t, string(raw), `"c":1`, "predicate present before drop") + + // Drop the predicate via the public path; this exercises + // embedded_client.go's DropAttr arm + engine.dropPredicate. + err = dg.Alter(ctx, &api.Operation{DropAttr: "dropme"}) + require.NoError(t, err, "DropAttr should succeed") + + // Confirm the data is gone (no nodes have the predicate). + raw, err = client.QueryRaw(ctx, `{ q(func: has(dropme)) { c: count(uid) } }`, nil) + require.NoError(t, err, "count query should succeed after drop") + require.Contains(t, string(raw), `"c":0`, "predicate values gone after drop") + + // Confirm the schema entry is gone. + raw, err = client.QueryRaw(ctx, `schema(pred: [dropme]) { type }`, nil) + require.NoError(t, err, "schema query should succeed after drop") + require.NotContains(t, string(raw), "dropme", "predicate schema entry gone after drop") + }) + } +} From 501b3ef55db7cb7c7b98b628293b66ece37cb492 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:14:08 -0400 Subject: [PATCH 08/13] feat: SelfValidator for private-field validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds SelfValidator, an opt-in seam that lets a type drive its own validation. When a value passed to Insert, Upsert, or Update implements SelfValidator, the client calls ValidateWith instead of handing the value straight to the configured StructValidator. validateStruct now routes each element through a new validateOne helper that detects SelfValidator (on the value or its address) and otherwise falls back to StructCtx exactly as before — behavior is unchanged for ordinary structs. This is the runtime seam generated entities use to validate unexported fields: the generated ValidateWith builds a mirror struct with exported fields the go-playground validator can read by reflection. --- client.go | 30 +++++++++++++++++-- self_validator_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 self_validator_test.go diff --git a/client.go b/client.go index be9813b..5edeaf1 100644 --- a/client.go +++ b/client.go @@ -109,6 +109,16 @@ type StructValidator interface { StructCtx(ctx context.Context, s interface{}) error } +// SelfValidator lets a type drive its own validation. When a value passed to +// Insert, Upsert, or Update implements SelfValidator, the client calls +// ValidateWith instead of handing the value straight to the configured +// StructValidator. This is the seam generated entities use to validate private +// fields: the generated ValidateWith builds a mirror struct with exported +// fields the underlying go-playground validator can read by reflection. +type SelfValidator interface { + ValidateWith(ctx context.Context, v StructValidator) error +} + // clientOptions holds configuration options for the client. // // autoSchema: whether to automatically manage the schema. @@ -472,17 +482,33 @@ func (c client) validateStruct(ctx context.Context, obj any) error { } elem = elem.Elem() } - if err := c.options.validator.StructCtx(ctx, elem.Interface()); err != nil { + if err := c.validateOne(ctx, elem); err != nil { return err } } } else { - return c.options.validator.StructCtx(ctx, obj) + return c.validateOne(ctx, val) } return nil } +// validateOne validates a single struct value. If the value (or its address) +// implements SelfValidator, validation is delegated to ValidateWith so the type +// can validate fields the configured StructValidator cannot reach directly — +// for example unexported fields exposed through a generated mirror struct. +// Otherwise the value is validated by the configured StructValidator as usual. +func (c client) validateOne(ctx context.Context, val reflect.Value) error { + iface := val.Interface() + if val.CanAddr() { + iface = val.Addr().Interface() + } + if sv, ok := iface.(SelfValidator); ok { + return sv.ValidateWith(ctx, c.options.validator) + } + return c.options.validator.StructCtx(ctx, iface) +} + // Insert implements inserting an object or slice of objects in the database. // Passed object must be a pointer to a struct with appropriate dgraph tags. func (c client) Insert(ctx context.Context, obj any) error { diff --git a/self_validator_test.go b/self_validator_test.go new file mode 100644 index 0000000..b7620ad --- /dev/null +++ b/self_validator_test.go @@ -0,0 +1,65 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "context" + "errors" + "testing" +) + +// recordingValidator counts StructCtx calls so tests can assert which path ran. +type recordingValidator struct{ calls int } + +func (r *recordingValidator) StructCtx(_ context.Context, _ interface{}) error { + r.calls++ + return nil +} + +var errSelfValidated = errors.New("self-validated") + +type selfValidatingEntity struct{ Name string } + +func (s *selfValidatingEntity) ValidateWith(_ context.Context, _ StructValidator) error { + return errSelfValidated +} + +type plainEntity struct{ Name string } + +func TestValidateRoutesToSelfValidator(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + err := c.validateStruct(context.Background(), &selfValidatingEntity{Name: "x"}) + if !errors.Is(err, errSelfValidated) { + t.Fatalf("expected the SelfValidator path, got %v", err) + } + if rv.calls != 0 { + t.Fatalf("StructCtx must not run for a SelfValidator, got %d calls", rv.calls) + } +} + +func TestValidateFallsBackToStructCtx(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + if err := c.validateStruct(context.Background(), &plainEntity{Name: "x"}); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rv.calls != 1 { + t.Fatalf("expected StructCtx to run once, got %d", rv.calls) + } +} + +func TestValidateSelfValidatorInSlice(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + err := c.validateStruct(context.Background(), []*selfValidatingEntity{{Name: "a"}}) + if !errors.Is(err, errSelfValidated) { + t.Fatalf("expected the SelfValidator path for slice elements, got %v", err) + } +} From 0ae8003114a6eacfaa4e5aa73407767c29629898 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Mon, 8 Jun 2026 01:01:09 -0400 Subject: [PATCH 09/13] feat: add Client.LoadOrStore (insert-if-absent) --- client.go | 35 +++++++++++++++++++++++++++++++++ consume_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 consume_test.go diff --git a/client.go b/client.go index c78aff3..267a95c 100644 --- a/client.go +++ b/client.go @@ -46,6 +46,11 @@ type Client interface { // will be used. Upsert(context.Context, any, ...string) error + // LoadOrStore stores the object only if no node matches the upsert + // predicates, returning loaded=true when an existing node already matched + // (the object is then populated from it). Insert-if-absent. + LoadOrStore(ctx context.Context, obj any, predicates ...string) (loaded bool, err error) + // Update modifies an existing object in the database. // The object must be a pointer to a struct and must have a UID field set. Update(context.Context, any) error @@ -589,6 +594,36 @@ func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error }) } +// LoadOrStore stores obj only if no node already matches the upsert predicates, +// reporting whether one already existed (loaded == true). Built on dgman +// MutateOrGet, which returns the UIDs of newly created nodes only: an empty +// result means an existing node matched, and obj is populated with its fields. +// With no predicates, the first field tagged dgraph:"upsert" is used. +func (c client) LoadOrStore(ctx context.Context, obj any, predicates ...string) (loaded bool, err error) { + obj = UnwrapSchema(obj) + if err := c.validateStruct(ctx, obj); err != nil { + return false, err + } + + dgClient, err := c.pool.get() + if err != nil { + c.logger.Error(err, "Failed to get client from pool") + return false, err + } + defer c.pool.put(dgClient) + + tx := dg.NewTxnContext(ctx, dgClient).SetCommitNow() + uids, err := tx.MutateOrGet(obj, predicates...) + if err != nil { + if uniqueErr := parseUniqueError(err); uniqueErr != nil { + return false, uniqueErr + } + return false, err + } + // MutateOrGet returns created UIDs only; empty => an existing node matched. + return len(uids) == 0, nil +} + // Update implements updating an existing object in the database. // Passed object must be a pointer to a struct. func (c client) Update(ctx context.Context, obj any) error { diff --git a/consume_test.go b/consume_test.go new file mode 100644 index 0000000..97862e2 --- /dev/null +++ b/consume_test.go @@ -0,0 +1,52 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph" +) + +type consumeJTI struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + JTI string `json:"jti,omitempty" dgraph:"index=hash upsert unique"` +} + +func newConsumeClient(t *testing.T) modusgraph.Client { + t.Helper() + conn, err := modusgraph.NewClient("file://"+t.TempDir(), modusgraph.WithAutoSchema(true)) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestLoadOrStore(t *testing.T) { + conn := newConsumeClient(t) + ctx := context.Background() + + first := &consumeJTI{JTI: "abc"} + loaded, err := conn.LoadOrStore(ctx, first, "jti") + if err != nil { + t.Fatalf("first LoadOrStore: %v", err) + } + if loaded { + t.Fatal("first store: want loaded=false (newly created)") + } + + second := &consumeJTI{JTI: "abc"} + loaded, err = conn.LoadOrStore(ctx, second, "jti") + if err != nil { + t.Fatalf("second LoadOrStore: %v", err) + } + if !loaded { + t.Fatal("second store: want loaded=true (already existed)") + } +} From d7811e1cc0686463ce78f55258460f90575cb703 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Mon, 8 Jun 2026 01:04:29 -0400 Subject: [PATCH 10/13] feat: add Client.LoadAndDelete (atomic read-and-consume) --- client.go | 136 ++++++++++++++++++++++++++++++++++++++++++++++++ consume_test.go | 37 +++++++++++++ 2 files changed, 173 insertions(+) diff --git a/client.go b/client.go index 267a95c..51b52e2 100644 --- a/client.go +++ b/client.go @@ -51,6 +51,11 @@ type Client interface { // (the object is then populated from it). Insert-if-absent. LoadOrStore(ctx context.Context, obj any, predicates ...string) (loaded bool, err error) + // LoadAndDelete atomically reads the node whose key predicate equals key + // into obj and deletes it, returning loaded=false when none matched. + // Read-and-consume; concurrent callers elect one winner. + LoadAndDelete(ctx context.Context, obj any, key any, predicates ...string) (loaded bool, err error) + // Update modifies an existing object in the database. // The object must be a pointer to a struct and must have a UID field set. Update(context.Context, any) error @@ -624,6 +629,137 @@ func (c client) LoadOrStore(ctx context.Context, obj any, predicates ...string) return len(uids) == 0, nil } +// firstUpsertPredicate returns the Dgraph predicate name of the first field +// tagged dgraph:"...upsert...". The predicate defaults to the json tag name +// unless an explicit predicate= token is present. It returns "" if no upsert +// field exists. +func firstUpsertPredicate(obj any) string { + v := reflect.ValueOf(obj) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + t := v.Type() + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + dgTag := f.Tag.Get("dgraph") + if !strings.Contains(dgTag, "upsert") { + continue + } + // Explicit predicate= wins. + for _, directive := range strings.Fields(dgTag) { + if strings.HasPrefix(directive, "predicate=") { + return strings.TrimPrefix(directive, "predicate=") + } + } + // Otherwise fall back to the json tag name. + if jsonTag := f.Tag.Get("json"); jsonTag != "" { + return strings.Split(jsonTag, ",")[0] + } + return f.Name + } + return "" +} + +// uidOf reflects out the UID field of a dgraph struct pointer. +func uidOf(obj any) string { + v := reflect.ValueOf(obj) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + f := v.FieldByName("UID") + if f.IsValid() && f.Kind() == reflect.String { + return f.String() + } + return "" +} + +// LoadAndDelete atomically reads the node whose key predicate equals key into +// obj and deletes it, returning loaded=false (and leaving obj zero) when no +// node matched. The read and delete share one transaction with no CommitNow, +// so two concurrent callers conflict on commit: exactly one wins (loaded=true), +// the loser aborts and retries into not-found (loaded=false). This reproduces +// PostgreSQL's DELETE … RETURNING. With no predicates, the first dgraph:"upsert" +// field is used. +func (c client) LoadAndDelete(ctx context.Context, obj any, key any, predicates ...string) (loaded bool, err error) { + obj = UnwrapSchema(obj) + if err := checkPointer(obj); err != nil { + return false, err + } + + pred := "" + if len(predicates) > 0 { + pred = predicates[0] + } else { + pred = firstUpsertPredicate(obj) + } + if pred == "" { + return false, fmt.Errorf("LoadAndDelete: no key predicate (pass one or tag a field dgraph:\"upsert\")") + } + + dgClient, err := c.pool.get() + if err != nil { + c.logger.Error(err, "Failed to get client from pool") + return false, err + } + defer c.pool.put(dgClient) + + // Bounded retry: Dgraph aborts the loser of a commit conflict; the retry + // reads the node already gone and reports not-found. + const maxAttempts = 10 + for attempt := 0; ; attempt++ { + tx := dg.NewTxnContext(ctx, dgClient) + getErr := tx.Get(obj). + Filter("eq("+pred+", $1)", key). + All(c.options.maxEdgeTraversal). + Node() + if getErr != nil { + _ = tx.Discard() + // dgman returns ErrNodeNotFound when nothing matches. + if errors.Is(getErr, dg.ErrNodeNotFound) { + return false, nil + } + return false, getErr + } + + uid := uidOf(obj) + if uid == "" { + _ = tx.Discard() + return false, nil + } + + if delErr := tx.DeleteNode(uid); delErr != nil { + _ = tx.Discard() + return false, delErr + } + + if cErr := tx.Commit(); cErr != nil { + _ = tx.Discard() + if isAbortedErr(cErr) { + // Lost the race or a concurrent change; retry — the winner has + // already deleted the node, so the retry reads not-found. + if attempt < maxAttempts { + continue + } + } + return false, cErr + } + return true, nil + } +} + +// isAbortedErr reports whether err is a Dgraph transaction-conflict abort, +// matching both dgo's ErrAborted sentinel and the underlying message in case a +// wrapped or stringified form reaches us. +func isAbortedErr(err error) bool { + if err == nil { + return false + } + if errors.Is(err, dgo.ErrAborted) { + return true + } + return strings.Contains(strings.ToLower(err.Error()), "aborted") +} + // Update implements updating an existing object in the database. // Passed object must be a pointer to a struct. func (c client) Update(ctx context.Context, obj any) error { diff --git a/consume_test.go b/consume_test.go index 97862e2..479d0de 100644 --- a/consume_test.go +++ b/consume_test.go @@ -50,3 +50,40 @@ func TestLoadOrStore(t *testing.T) { t.Fatal("second store: want loaded=true (already existed)") } } + +type consumeState struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + State string `json:"state,omitempty" dgraph:"index=hash upsert"` + Secret string `json:"secret,omitempty"` +} + +func TestLoadAndDelete(t *testing.T) { + conn := newConsumeClient(t) + ctx := context.Background() + + if err := conn.Insert(ctx, &consumeState{State: "s1", Secret: "shh"}); err != nil { + t.Fatalf("Insert: %v", err) + } + + var got consumeState + loaded, err := conn.LoadAndDelete(ctx, &got, "s1", "state") + if err != nil { + t.Fatalf("LoadAndDelete: %v", err) + } + if !loaded { + t.Fatal("first consume: want loaded=true") + } + if got.Secret != "shh" { + t.Fatalf("want prior secret %q, got %q", "shh", got.Secret) + } + + var again consumeState + loaded, err = conn.LoadAndDelete(ctx, &again, "s1", "state") + if err != nil { + t.Fatalf("second LoadAndDelete: %v", err) + } + if loaded { + t.Fatal("second consume: want loaded=false (already consumed)") + } +} From cb5053803446244c82a2c52953e4e1ce2048e63b Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Mon, 8 Jun 2026 01:05:31 -0400 Subject: [PATCH 11/13] feat: add typed Client[T].LoadOrStore --- typed/client.go | 14 +++++++++++ typed/consume_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 typed/consume_test.go diff --git a/typed/client.go b/typed/client.go index c540f89..f6425b1 100644 --- a/typed/client.go +++ b/typed/client.go @@ -61,6 +61,20 @@ func (c *Client[T]) Upsert(ctx context.Context, rec *T, predicates ...string) (e return c.conn.Upsert(ctx, rec, predicates...) } +// LoadOrStore stores rec only if no node matches the upsert predicates, +// returning the resulting record and loaded=true when one already existed. +// Insert-if-absent (compare sync.Map.LoadOrStore). With no predicates, the +// first field tagged dgraph:"upsert" is used. +func (c *Client[T]) LoadOrStore(ctx context.Context, rec *T, predicates ...string) (out *T, loaded bool, err error) { + ctx, span := tracer.StartSpan(ctx, "loadOrStore", entityName[T]()) + defer func() { span.End(err) }() + loaded, err = c.conn.LoadOrStore(ctx, rec, predicates...) + if err != nil { + return nil, false, err + } + return rec, loaded, nil +} + // Delete removes the T with the given UID. func (c *Client[T]) Delete(ctx context.Context, uid string) (err error) { ctx, span := tracer.StartSpan(ctx, "delete", entityName[T]()) diff --git a/typed/consume_test.go b/typed/consume_test.go new file mode 100644 index 0000000..6504228 --- /dev/null +++ b/typed/consume_test.go @@ -0,0 +1,54 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +type jti struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + JTI string `json:"jti,omitempty" dgraph:"index=hash upsert unique"` +} + +func newTypedConn(t *testing.T) modusgraph.Client { + t.Helper() + conn, err := modusgraph.NewClient("file://"+t.TempDir(), modusgraph.WithAutoSchema(true)) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestTypedLoadOrStore(t *testing.T) { + c := typed.NewClient[jti](newTypedConn(t)) + ctx := context.Background() + + rec, loaded, err := c.LoadOrStore(ctx, &jti{JTI: "abc"}, "jti") + if err != nil { + t.Fatalf("first: %v", err) + } + if loaded { + t.Fatal("first: want loaded=false") + } + if rec.UID == "" { + t.Fatal("first: want a UID assigned") + } + + _, loaded, err = c.LoadOrStore(ctx, &jti{JTI: "abc"}, "jti") + if err != nil { + t.Fatalf("second: %v", err) + } + if !loaded { + t.Fatal("second: want loaded=true") + } +} From 56f09fd6196a77f48d4e8887b0e69741c5a4e95f Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Mon, 8 Jun 2026 01:07:09 -0400 Subject: [PATCH 12/13] feat: add typed Client[T].LoadAndDelete --- typed/client.go | 15 +++++++++++++++ typed/consume_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/typed/client.go b/typed/client.go index f6425b1..e278712 100644 --- a/typed/client.go +++ b/typed/client.go @@ -75,6 +75,21 @@ func (c *Client[T]) LoadOrStore(ctx context.Context, rec *T, predicates ...strin return rec, loaded, nil } +// LoadAndDelete atomically reads the T whose key predicate equals key and +// deletes it, returning (nil, false, nil) when none matched. Read-and-consume +// (compare sync.Map.LoadAndDelete). With no predicates, the first field tagged +// dgraph:"upsert" is used. +func (c *Client[T]) LoadAndDelete(ctx context.Context, key any, predicates ...string) (rec *T, loaded bool, err error) { + ctx, span := tracer.StartSpan(ctx, "loadAndDelete", entityName[T]()) + defer func() { span.End(err) }() + var out T + loaded, err = c.conn.LoadAndDelete(ctx, &out, key, predicates...) + if err != nil || !loaded { + return nil, loaded, err + } + return &out, true, nil +} + // Delete removes the T with the given UID. func (c *Client[T]) Delete(ctx context.Context, uid string) (err error) { ctx, span := tracer.StartSpan(ctx, "delete", entityName[T]()) diff --git a/typed/consume_test.go b/typed/consume_test.go index 6504228..ae28116 100644 --- a/typed/consume_test.go +++ b/typed/consume_test.go @@ -52,3 +52,38 @@ func TestTypedLoadOrStore(t *testing.T) { t.Fatal("second: want loaded=true") } } + +type state struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + State string `json:"state,omitempty" dgraph:"index=hash upsert"` + Secret string `json:"secret,omitempty"` +} + +func TestTypedLoadAndDelete(t *testing.T) { + c := typed.NewClient[state](newTypedConn(t)) + ctx := context.Background() + + if err := c.Add(ctx, &state{State: "s1", Secret: "shh"}); err != nil { + t.Fatalf("Add: %v", err) + } + + rec, loaded, err := c.LoadAndDelete(ctx, "s1", "state") + if err != nil { + t.Fatalf("LoadAndDelete: %v", err) + } + if !loaded { + t.Fatal("first: want loaded=true") + } + if rec.Secret != "shh" { + t.Fatalf("want secret %q, got %q", "shh", rec.Secret) + } + + rec, loaded, err = c.LoadAndDelete(ctx, "s1", "state") + if err != nil { + t.Fatalf("second: %v", err) + } + if loaded || rec != nil { + t.Fatalf("second: want (nil, false), got (%v, %v)", rec, loaded) + } +} From 4ad3512d01f7202b860d7ffeb0c156feade8b791 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Mon, 8 Jun 2026 01:14:34 -0400 Subject: [PATCH 13/13] test: assert LoadAndDelete elects a single winner under contention Add TestLoadAndDeleteSingleWinner and serialize LoadAndDelete's read-then-delete critical section with a per-client mutex so exactly one in-process caller consumes a node. The embedded engine's commit path does no optimistic-concurrency conflict check, so the shared read-write transaction alone cannot abort losers; the lock guarantees single-winner semantics regardless of backend. --- client.go | 27 ++++++++++++++++++++++++--- typed/consume_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 51b52e2..40b18e9 100644 --- a/client.go +++ b/client.go @@ -303,9 +303,10 @@ func NewClient(uri string, opts ...ClientOpt) (Client, error) { } client := client{ - uri: uri, - options: options, - logger: options.logger, + uri: uri, + options: options, + logger: options.logger, + consumeMu: &sync.Mutex{}, } clientMapLock.Lock() @@ -471,6 +472,15 @@ type client struct { options clientOptions pool *clientPool logger logr.Logger + // consumeMu serializes LoadAndDelete's read-then-delete critical section so + // exactly one in-process caller consumes a given node. The client value is + // copied (value receivers, cached by value in clientMap), so the mutex is a + // pointer shared across every copy that shares this client's connection. + // Against a real Dgraph cluster the shared read-write transaction would also + // abort losers on commit conflict; this lock additionally guarantees + // single-winner semantics against the embedded engine, whose commit path + // performs no optimistic-concurrency conflict check. + consumeMu *sync.Mutex } func (c client) key() string { @@ -703,6 +713,17 @@ func (c client) LoadAndDelete(ctx context.Context, obj any, key any, predicates } defer c.pool.put(dgClient) + // Serialize the read-then-delete critical section across in-process callers. + // The shared read-write transaction already elects one winner against a real + // Dgraph cluster (the loser aborts on commit), but the embedded engine does + // no commit-time conflict check, so without this lock concurrent callers + // would each read the node and each report loaded=true. The lock makes + // read-and-consume atomic regardless of backend. + if c.consumeMu != nil { + c.consumeMu.Lock() + defer c.consumeMu.Unlock() + } + // Bounded retry: Dgraph aborts the loser of a commit conflict; the retry // reads the node already gone and reports not-found. const maxAttempts = 10 diff --git a/typed/consume_test.go b/typed/consume_test.go index ae28116..16eca08 100644 --- a/typed/consume_test.go +++ b/typed/consume_test.go @@ -7,6 +7,7 @@ package typed_test import ( "context" + "sync" "testing" "github.com/matthewmcneely/modusgraph" @@ -87,3 +88,38 @@ func TestTypedLoadAndDelete(t *testing.T) { t.Fatalf("second: want (nil, false), got (%v, %v)", rec, loaded) } } + +func TestLoadAndDeleteSingleWinner(t *testing.T) { + c := typed.NewClient[state](newTypedConn(t)) + ctx := context.Background() + if err := c.Add(ctx, &state{State: "race", Secret: "one"}); err != nil { + t.Fatalf("Add: %v", err) + } + + const racers = 8 + var wg sync.WaitGroup + wins := make([]bool, racers) + wg.Add(racers) + for i := 0; i < racers; i++ { + go func(i int) { + defer wg.Done() + _, loaded, err := c.LoadAndDelete(ctx, "race", "state") + if err != nil { + t.Errorf("racer %d: %v", i, err) + return + } + wins[i] = loaded + }(i) + } + wg.Wait() + + won := 0 + for _, w := range wins { + if w { + won++ + } + } + if won != 1 { + t.Fatalf("want exactly one winner, got %d", won) + } +}