Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytensor/assumptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
DIAGONAL,
IMPLIES,
LOWER_TRIANGULAR,
MATRIX_KEYS,
ORTHOGONAL,
PERMUTATION,
POSITIVE_DEFINITE,
SELECTION,
SYMMETRIC,
UNIQUE_INDICES,
UPPER_TRIANGULAR,
AssumptionFeature,
AssumptionKey,
Expand Down
5 changes: 4 additions & 1 deletion pytensor/assumptions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ def register_constant_inference(key: AssumptionKey, fn: ConstantInferFn) -> None
ORTHOGONAL = AssumptionKey("orthogonal", short_name="orth")
SELECTION = AssumptionKey("selection", short_name="sel")
PERMUTATION = AssumptionKey("permutation", short_name="perm")
UNIQUE_INDICES = AssumptionKey("unique_indices", short_name="uniq")

ALL_KEYS = (
MATRIX_KEYS = (
DIAGONAL,
LOWER_TRIANGULAR,
UPPER_TRIANGULAR,
Expand All @@ -115,6 +116,8 @@ def register_constant_inference(key: AssumptionKey, fn: ConstantInferFn) -> None
PERMUTATION,
)

ALL_KEYS = (*MATRIX_KEYS, UNIQUE_INDICES)

# Implications about structural properties derivable from other structural properties
register_implies(DIAGONAL, LOWER_TRIANGULAR, UPPER_TRIANGULAR, SYMMETRIC)
register_implies(POSITIVE_DEFINITE, SYMMETRIC)
Expand Down
8 changes: 8 additions & 0 deletions pytensor/assumptions/specify.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def assume(
orthogonal: bool | None = None,
selection: bool | None = None,
permutation: bool | None = None,
unique_indices: bool | None = None,
):
"""Attach structural assumptions to a symbolic tensor.

Expand Down Expand Up @@ -96,6 +97,12 @@ def assume(
Assert that *x* is (or is not) a selection matrix
permutation : bool, optional
Assert that *x* is (or is not) a permutation matrix
unique_indices : bool, optional

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we reject this flag on non-1d inputs for now? Otherwise we need to know which axis/axes have unique indices, and we don't have the machinery for parameterized assumptions yet. If you want to add it though I'd be happy, I want it for matrix rank among other things.

@ricardoV94 ricardoV94 Jun 15, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this assumes unique over the whole data, I don't care about dims

@ricardoV94 ricardoV94 Jun 15, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(not yet), the guard against repeated computations still needs some work, but I don't think people are gonna do (symbolic) advanced matrix indexing and want it to lift, they could always do flat and reshape

Assert that *x*'s entries address distinct positions when used as an
index (or that they do not): no value repeats, and no negative entry
aliases a non-negative one (e.g. ``-1`` and ``n-1``). Such an index can
never enlarge the axis it indexes, so it can be lifted earlier through
operations without risk of duplicating computation.

Returns
-------
Expand All @@ -121,6 +128,7 @@ def assume(
"orthogonal": orthogonal,
"selection": selection,
"permutation": permutation,
"unique_indices": unique_indices,
}
assumptions = {
name: FactState.TRUE if value else FactState.FALSE
Expand Down
15 changes: 12 additions & 3 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytensor
from pytensor import compile
from pytensor.assumptions.core import UNIQUE_INDICES, check_assumption
from pytensor.compile import optdb
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import (
Expand Down Expand Up @@ -231,6 +232,14 @@ def _constant_has_unique_indices(idx) -> bool:
return result


def _has_unique_indices(fgraph, idx) -> bool:
"""Whether ``idx``'s entries are provably duplicate-free: a constant with
unique entries, or a variable asserted ``unique_indices`` by the user."""
return _constant_has_unique_indices(idx) or check_assumption(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a sense of which of these is cheaper? It's the assumption check if the Feature is already attached, but in live code I don't know.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it matters; people aren't going to put assumptions on constant indices (no need anyway). Either way, after the first check it's cached.

fgraph, idx, UNIQUE_INDICES
)


def _constant_is_arange(idx) -> tuple[int, int, int] | None:
"""Match ``idx`` to ``np.arange(offset, offset + d * step, step)``
and return ``(d, offset, step)``, else ``None``.
Expand Down Expand Up @@ -1169,7 +1178,7 @@ def local_add_of_sparse_write(fgraph, node):
# duplicate-free. Basic (slice/scalar) indexing is always unique;
# advanced integer-array indices must be checked.
if not inner_op.set_instead_of_inc and not isinstance(inner_op, IncSubtensor):
if not all(_constant_has_unique_indices(idx) for idx in idx_vars):
if not all(_has_unique_indices(fgraph, idx) for idx in idx_vars):
continue

others = [node.inputs[j] for j in range(len(node.inputs)) if j != i]
Expand Down Expand Up @@ -2001,7 +2010,7 @@ def local_read_of_write_same_indices(fgraph, node):
indices = indices_from_subtensor(outer_idx_vars, node.op.idx_list)
for idx in indices:
if isinstance(idx, TensorVariable) and idx.type.ndim > 0:
if not _constant_has_unique_indices(idx):
if not _has_unique_indices(fgraph, idx):
return None

x_at_idx = x[tuple(indices)]
Expand Down Expand Up @@ -2363,7 +2372,7 @@ def local_write_of_write_same_indices(fgraph, node):
# sufficient: it guarantees no duplicates in the joint cross-product
# after broadcasting.
if not isinstance(node.op, IncSubtensor):
if not all(_constant_has_unique_indices(v) for v in outer_idx_vars):
if not all(_has_unique_indices(fgraph, v) for v in outer_idx_vars):
return
new_val = a + b
if (
Expand Down
13 changes: 8 additions & 5 deletions pytensor/tensor/rewriting/subtensor_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple

from pytensor.assumptions.core import UNIQUE_INDICES, check_assumption
from pytensor.compile import optdb
from pytensor.graph import (
Constant,
Expand Down Expand Up @@ -214,8 +215,8 @@ def _lift_subtensor_non_axis(
return None


def _index_provably_smaller(idx, val_static_dim) -> bool:
# Per-axis check: non-repeating indices can't expand a single axis.
def _index_provably_not_larger(idx, val_static_dim, fgraph=None) -> bool:
# Per-axis check: an index that can't repeat a position can't enlarge that axis.
# Does not account for cross-axis broadcast expansion from outer indexing.
if isinstance(idx, slice) or idx.ndim == 0:
return True
Expand All @@ -225,11 +226,13 @@ def _index_provably_smaller(idx, val_static_dim) -> bool:
return True
if _constant_has_unique_indices(idx):
return True
if check_assumption(fgraph, idx, UNIQUE_INDICES):
return True
if isinstance(idx.owner_op, ARange):
return True
if isinstance(idx.owner_op, Reshape | DimShuffle):
# Views that don't add dimensions
if _index_provably_smaller(idx.owner.inputs[0], val_static_dim):
if _index_provably_not_larger(idx.owner.inputs[0], val_static_dim, fgraph):
return True

# Fallback to static shape analysis
Expand Down Expand Up @@ -351,7 +354,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
continue
if inp.type.broadcastable[axis]:
continue
if not _index_provably_smaller(idx, inp.type.shape[axis]):
if not _index_provably_not_larger(idx, inp.type.shape[axis], fgraph):
return None

batch_ndim = (
Expand Down Expand Up @@ -742,7 +745,7 @@ def lift_subtensor_through_alloc(fgraph, node):
dangerous_index_reaches_val = any(
not val.type.broadcastable[axis]
# Per-axis check; doesn't account for net effect across all axes.
and not _index_provably_smaller(idx, val.type.shape[axis])
and not _index_provably_not_larger(idx, val.type.shape[axis], fgraph)
for axis, idx in enumerate(val_indexer)
)

Expand Down
4 changes: 2 additions & 2 deletions tests/assumptions/test_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import pytensor.tensor as pt
from pytensor.assumptions import (
ALL_KEYS,
DIAGONAL,
LOWER_TRIANGULAR,
MATRIX_KEYS,
ORTHOGONAL,
PERMUTATION,
POSITIVE_DEFINITE,
Expand All @@ -20,7 +20,7 @@
def test_eye_identity_has_all_properties():
e = pt.eye(5)
_, af = make_fgraph(e)
for key in ALL_KEYS:
for key in MATRIX_KEYS:
assert af.check(e, key), f"Eye should be {key}"


Expand Down
8 changes: 8 additions & 0 deletions tests/assumptions/test_specify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ORTHOGONAL,
POSITIVE_DEFINITE,
SYMMETRIC,
UNIQUE_INDICES,
UPPER_TRIANGULAR,
ConflictingAssumptionsError,
FactState,
Expand Down Expand Up @@ -34,6 +35,13 @@ def test_assume_records_false_assertions(key):
assert af.get(x_not, key) is FactState.FALSE


def test_assume_unique_indices_on_vector():
    """Tagging an int64 vector with ``unique_indices=True`` makes the
    assumption feature report ``UNIQUE_INDICES`` for the tagged variable."""
    raw_index = pt.vector("idx", dtype="int64")
    tagged_index = assume(raw_index, unique_indices=True)

    # ``make_fgraph`` yields a pair; only the second item is queried here.
    _fgraph, feature = make_fgraph(tagged_index)

    assert feature.check(tagged_index, UNIQUE_INDICES)


def test_assume_multiple_true_kwargs():
x = pt.matrix("x", shape=(3, 3))
x_both = assume(x, diagonal=True, positive_definite=True)
Expand Down
8 changes: 4 additions & 4 deletions tests/assumptions/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import pytensor.tensor as pt
from pytensor.assumptions import (
ALL_KEYS,
DIAGONAL,
LOWER_TRIANGULAR,
MATRIX_KEYS,
ORTHOGONAL,
POSITIVE_DEFINITE,
SYMMETRIC,
Expand All @@ -17,23 +17,23 @@
class TestSubtensorMatrixPropertyPropagation:
"""A ``Subtensor`` that leaves the trailing two axes alone preserves matrix properties."""

@pytest.mark.parametrize("key", ALL_KEYS)
@pytest.mark.parametrize("key", MATRIX_KEYS)
def test_scalar_index_strips_leading_axis(self, key):
x = pt.tensor3("x", shape=(5, 4, 4))
x_tagged = assume(x, **{key.name: True})
y = x_tagged[2]
_, af = make_fgraph(y)
assert af.check(y, key)

@pytest.mark.parametrize("key", ALL_KEYS)
@pytest.mark.parametrize("key", MATRIX_KEYS)
def test_explicit_full_slices(self, key):
x = pt.tensor3("x", shape=(5, 4, 4))
x_tagged = assume(x, **{key.name: True})
y = x_tagged[2, :, :]
_, af = make_fgraph(y)
assert af.check(y, key)

@pytest.mark.parametrize("key", ALL_KEYS)
@pytest.mark.parametrize("key", MATRIX_KEYS)
def test_partial_slice_on_batch_axis(self, key):
x = pt.tensor3("x", shape=(5, 4, 4))
x_tagged = assume(x, **{key.name: True})
Expand Down
23 changes: 23 additions & 0 deletions tests/tensor/rewriting/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor import shared
from pytensor.assumptions import assume
from pytensor.compile.maker import function
from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp
Expand Down Expand Up @@ -1395,6 +1396,28 @@ def test_inc_symbolic_idx_not_rewritten(self):
isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo
)

def test_inc_asserted_unique_idx_rewritten(self):
    """An ``inc`` read back through the same symbolic index is rewritten when
    that index carries the ``unique_indices`` assumption.

    The graph expected after rewriting is ``x[idx] + y``.
    """
    base = matrix("x")
    increment = matrix("y")
    raw_idx = ivector("idx")
    tagged_idx = assume(raw_idx, unique_indices=True)

    # Increment at the tagged index, then index the result with the same index.
    incremented = base[tagged_idx].inc(increment)
    read_back = incremented[tagged_idx]

    excluded_rewrites = (
        "fusion",
        "fuse_indexed_into_elemwise",
        "inplace",
        "local_replace_AdvancedSubtensor",
    )
    tester = utt.RewriteTester(
        [base, increment, raw_idx],
        [read_back],
        include="fast_run",
        exclude=excluded_rewrites,
    )

    tester.assert_graph(base[raw_idx] + increment)

def test_shape_unsafe_excluded(self):
"""When shape_unsafe is excluded the rewrite must not fire, so
out-of-bounds and shape errors are still caught at runtime."""
Expand Down
20 changes: 19 additions & 1 deletion tests/tensor/rewriting/test_subtensor_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from pytensor import scalar as ps
from pytensor import tensor as pt
from pytensor.assumptions import assume
from pytensor.compile import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph import (
Expand Down Expand Up @@ -44,6 +45,7 @@
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.rewriting.assumptions import DrainSpecifyAssumptions
from pytensor.tensor.rewriting.subtensor import (
local_adv_idx_to_diagonal,
)
Expand All @@ -62,7 +64,7 @@
AdvancedSubtensor,
Subtensor,
)
from tests.unittest_tools import assert_equal_computations
from tests.unittest_tools import RewriteTester, assert_equal_computations


mode_opt = config.mode
Expand Down Expand Up @@ -216,6 +218,22 @@ def test_elemwise_adv_index_not_provably_smaller_bails(self):
rewritten = rewrite_graph(out)
assert equal_computations([rewritten], [out])

def test_elemwise_adv_index_assumed_unique_lifts(self):
    """An advanced index tagged ``unique_indices`` is lifted through the
    elementwise add, giving ``x[idx] + y[idx]``."""
    lhs = pt.matrix("x")
    rhs = pt.matrix("y")
    raw_idx = pt.lvector("idx")
    tagged_idx = assume(raw_idx, unique_indices=True)

    indexed_sum = (lhs + rhs)[tagged_idx]

    # Drain resolves the asserted fact onto idx, then canonicalize lifts the index.
    drain = DrainSpecifyAssumptions()
    tester = RewriteTester(
        [lhs, rhs, raw_idx],
        [indexed_sum],
        include="canonicalize",
        custom_rewrite=drain,
    )

    tester.assert_graph(lhs[raw_idx] + rhs[raw_idx])

def test_blockwise(self):
class CoreTestOp(Op):
itypes = [dvector, dvector]
Expand Down
Loading