diff --git a/pytensor/assumptions/__init__.py b/pytensor/assumptions/__init__.py index 38ca208867..8b580139a8 100644 --- a/pytensor/assumptions/__init__.py +++ b/pytensor/assumptions/__init__.py @@ -19,11 +19,13 @@ DIAGONAL, IMPLIES, LOWER_TRIANGULAR, + MATRIX_KEYS, ORTHOGONAL, PERMUTATION, POSITIVE_DEFINITE, SELECTION, SYMMETRIC, + UNIQUE_INDICES, UPPER_TRIANGULAR, AssumptionFeature, AssumptionKey, diff --git a/pytensor/assumptions/core.py b/pytensor/assumptions/core.py index a4799106f9..9b1195fc5d 100644 --- a/pytensor/assumptions/core.py +++ b/pytensor/assumptions/core.py @@ -103,8 +103,9 @@ def register_constant_inference(key: AssumptionKey, fn: ConstantInferFn) -> None ORTHOGONAL = AssumptionKey("orthogonal", short_name="orth") SELECTION = AssumptionKey("selection", short_name="sel") PERMUTATION = AssumptionKey("permutation", short_name="perm") +UNIQUE_INDICES = AssumptionKey("unique_indices", short_name="uniq") -ALL_KEYS = ( +MATRIX_KEYS = ( DIAGONAL, LOWER_TRIANGULAR, UPPER_TRIANGULAR, @@ -115,6 +116,8 @@ def register_constant_inference(key: AssumptionKey, fn: ConstantInferFn) -> None PERMUTATION, ) +ALL_KEYS = (*MATRIX_KEYS, UNIQUE_INDICES) + # Implications about structural properties derivably from other structural properties register_implies(DIAGONAL, LOWER_TRIANGULAR, UPPER_TRIANGULAR, SYMMETRIC) register_implies(POSITIVE_DEFINITE, SYMMETRIC) diff --git a/pytensor/assumptions/specify.py b/pytensor/assumptions/specify.py index 201e7d6282..bb9f77d6ae 100644 --- a/pytensor/assumptions/specify.py +++ b/pytensor/assumptions/specify.py @@ -69,6 +69,7 @@ def assume( orthogonal: bool | None = None, selection: bool | None = None, permutation: bool | None = None, + unique_indices: bool | None = None, ): """Attach structural assumptions to a symbolic tensor. @@ -96,6 +97,12 @@ def assume( Assert that *x* is (or is not) a selection matrix permutation : bool, optional Assert that *x* is (or is not) a permutation matrix + unique_indices : bool, optional + Assert that *x*'s entries address distinct positions when used as an + index (or that they do not): no value repeats, and no negative entry + aliases a non-negative one (e.g. ``-1`` and ``n-1``). Such an index can + never enlarge the axis it indexes, so it can be lifted earlier through + operations without risk of duplicating computation. Returns ------- @@ -121,6 +128,7 @@ def assume( "orthogonal": orthogonal, "selection": selection, "permutation": permutation, + "unique_indices": unique_indices, } assumptions = { name: FactState.TRUE if value else FactState.FALSE diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index e045648816..04856ce74f 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -6,6 +6,7 @@ import pytensor from pytensor import compile +from pytensor.assumptions.core import UNIQUE_INDICES, check_assumption from pytensor.compile import optdb from pytensor.graph.basic import Constant, Variable from pytensor.graph.rewriting.basic import ( @@ -231,6 +232,14 @@ def _constant_has_unique_indices(idx) -> bool: return result +def _has_unique_indices(fgraph, idx) -> bool: + """Whether ``idx``'s entries are provably duplicate-free: a constant with + unique entries, or a variable asserted ``unique_indices`` by the user.""" + return _constant_has_unique_indices(idx) or check_assumption( + fgraph, idx, UNIQUE_INDICES + ) + + def _constant_is_arange(idx) -> tuple[int, int, int] | None: """Match ``idx`` to ``np.arange(offset, offset + d * step, step)`` and return ``(d, offset, step)``, else ``None``. @@ -1169,7 +1178,7 @@ def local_add_of_sparse_write(fgraph, node): # duplicate-free. Basic (slice/scalar) indexing is always unique; # advanced integer-array indices must be checked. if not inner_op.set_instead_of_inc and not isinstance(inner_op, IncSubtensor): - if not all(_constant_has_unique_indices(idx) for idx in idx_vars): + if not all(_has_unique_indices(fgraph, idx) for idx in idx_vars): continue others = [node.inputs[j] for j in range(len(node.inputs)) if j != i] @@ -2001,7 +2010,7 @@ def local_read_of_write_same_indices(fgraph, node): indices = indices_from_subtensor(outer_idx_vars, node.op.idx_list) for idx in indices: if isinstance(idx, TensorVariable) and idx.type.ndim > 0: - if not _constant_has_unique_indices(idx): + if not _has_unique_indices(fgraph, idx): return None x_at_idx = x[tuple(indices)] @@ -2363,7 +2372,7 @@ def local_write_of_write_same_indices(fgraph, node): # sufficient: it guarantees no duplicates in the joint cross-product # after broadcasting. if not isinstance(node.op, IncSubtensor): - if not all(_constant_has_unique_indices(v) for v in outer_idx_vars): + if not all(_has_unique_indices(fgraph, v) for v in outer_idx_vars): return new_val = a + b if ( diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 59dc8eb7dc..c0e92fe89d 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -4,6 +4,7 @@ import numpy as np from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple +from pytensor.assumptions.core import UNIQUE_INDICES, check_assumption from pytensor.compile import optdb from pytensor.graph import ( Constant, @@ -214,8 +215,8 @@ def _lift_subtensor_non_axis( return None -def _index_provably_smaller(idx, val_static_dim) -> bool: - # Per-axis check: non-repeating indices can't expand a single axis. +def _index_provably_not_larger(idx, val_static_dim, fgraph=None) -> bool: + # Per-axis check: an index that can't repeat a position can't enlarge that axis. # Does not account for cross-axis broadcast expansion from outer indexing. if isinstance(idx, slice) or idx.ndim == 0: return True @@ -225,11 +226,13 @@ def _index_provably_smaller(idx, val_static_dim) -> bool: return True if _constant_has_unique_indices(idx): return True + if check_assumption(fgraph, idx, UNIQUE_INDICES): + return True if isinstance(idx.owner_op, ARange): return True if isinstance(idx.owner_op, Reshape | DimShuffle): # Views that don't add dimensions - if _index_provably_smaller(idx.owner.inputs[0], val_static_dim): + if _index_provably_not_larger(idx.owner.inputs[0], val_static_dim, fgraph): return True # Fallback to static shape analysis @@ -351,7 +354,7 @@ def local_subtensor_of_batch_dims(fgraph, node): continue if inp.type.broadcastable[axis]: continue - if not _index_provably_smaller(idx, inp.type.shape[axis]): + if not _index_provably_not_larger(idx, inp.type.shape[axis], fgraph): return None batch_ndim = ( @@ -742,7 +745,7 @@ def lift_subtensor_through_alloc(fgraph, node): dangerous_index_reaches_val = any( not val.type.broadcastable[axis] # Per-axis check; doesn't account for net effect across all axes. - and not _index_provably_smaller(idx, val.type.shape[axis]) + and not _index_provably_not_larger(idx, val.type.shape[axis], fgraph) for axis, idx in enumerate(val_indexer) ) diff --git a/tests/assumptions/test_alloc.py b/tests/assumptions/test_alloc.py index 57f37eefb6..c248770800 100644 --- a/tests/assumptions/test_alloc.py +++ b/tests/assumptions/test_alloc.py @@ -3,9 +3,9 @@ import pytensor.tensor as pt from pytensor.assumptions import ( - ALL_KEYS, DIAGONAL, LOWER_TRIANGULAR, + MATRIX_KEYS, ORTHOGONAL, PERMUTATION, POSITIVE_DEFINITE, @@ -20,7 +20,7 @@ def test_eye_identity_has_all_properties(): e = pt.eye(5) _, af = make_fgraph(e) - for key in ALL_KEYS: + for key in MATRIX_KEYS: assert af.check(e, key), f"Eye should be {key}" diff --git a/tests/assumptions/test_specify.py b/tests/assumptions/test_specify.py index 6d3ee3992f..0847f48bd5 100644 --- a/tests/assumptions/test_specify.py +++ b/tests/assumptions/test_specify.py @@ -7,6 +7,7 @@ ORTHOGONAL, POSITIVE_DEFINITE, SYMMETRIC, + UNIQUE_INDICES, UPPER_TRIANGULAR, ConflictingAssumptionsError, FactState, @@ -34,6 +35,13 @@ def test_assume_records_false_assertions(key): assert af.get(x_not, key) is FactState.FALSE +def test_assume_unique_indices_on_vector(): + idx = pt.vector("idx", dtype="int64") + idx_uniq = assume(idx, unique_indices=True) + _, af = make_fgraph(idx_uniq) + assert af.check(idx_uniq, UNIQUE_INDICES) + + def test_assume_multiple_true_kwargs(): x = pt.matrix("x", shape=(3, 3)) x_both = assume(x, diagonal=True, positive_definite=True) diff --git a/tests/assumptions/test_subtensor.py b/tests/assumptions/test_subtensor.py index 367bf4ce3e..e27436049c 100644 --- a/tests/assumptions/test_subtensor.py +++ b/tests/assumptions/test_subtensor.py @@ -2,9 +2,9 @@ import pytensor.tensor as pt from pytensor.assumptions import ( - ALL_KEYS, DIAGONAL, LOWER_TRIANGULAR, + MATRIX_KEYS, ORTHOGONAL, POSITIVE_DEFINITE, SYMMETRIC, @@ -17,7 +17,7 @@ class TestSubtensorMatrixPropertyPropagation: """A ``Subtensor`` that leaves the trailing two axes alone preserves matrix properties.""" - @pytest.mark.parametrize("key", ALL_KEYS) + @pytest.mark.parametrize("key", MATRIX_KEYS) def test_scalar_index_strips_leading_axis(self, key): x = pt.tensor3("x", shape=(5, 4, 4)) x_tagged = assume(x, **{key.name: True}) @@ -25,7 +25,7 @@ def test_scalar_index_strips_leading_axis(self, key): _, af = make_fgraph(y) assert af.check(y, key) - @pytest.mark.parametrize("key", ALL_KEYS) + @pytest.mark.parametrize("key", MATRIX_KEYS) def test_explicit_full_slices(self, key): x = pt.tensor3("x", shape=(5, 4, 4)) x_tagged = assume(x, **{key.name: True}) @@ -33,7 +33,7 @@ def test_explicit_full_slices(self, key): _, af = make_fgraph(y) assert af.check(y, key) - @pytest.mark.parametrize("key", ALL_KEYS) + @pytest.mark.parametrize("key", MATRIX_KEYS) def test_partial_slice_on_batch_axis(self, key): x = pt.tensor3("x", shape=(5, 4, 4)) x_tagged = assume(x, **{key.name: True}) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 4f7503d1d7..1c1cbb7bb0 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -5,6 +5,7 @@ import pytensor.scalar as ps import pytensor.tensor as pt from pytensor import shared +from pytensor.assumptions import assume from pytensor.compile.maker import function from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp @@ -1395,6 +1396,28 @@ def test_inc_symbolic_idx_not_rewritten(self): isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo ) + def test_inc_asserted_unique_idx_rewritten(self): + """A symbolic index asserted unique_indices is duplicate-free, so inc is rewritten.""" + x = matrix("x") + y = matrix("y") + idx = ivector("idx") + idx_unique = assume(idx, unique_indices=True) + + o = x[idx_unique].inc(y)[idx_unique] + + result = utt.RewriteTester( + [x, y, idx], + [o], + include="fast_run", + exclude=( + "fusion", + "fuse_indexed_into_elemwise", + "inplace", + "local_replace_AdvancedSubtensor", + ), + ) + result.assert_graph(x[idx] + y) + def test_shape_unsafe_excluded(self): """When shape_unsafe is excluded the rewrite must not fire, so out-of-bounds and shape errors are still caught at runtime.""" diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index b245f53237..5db04799b6 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -11,6 +11,7 @@ ) from pytensor import scalar as ps from pytensor import tensor as pt +from pytensor.assumptions import assume from pytensor.compile import get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp from pytensor.graph import ( @@ -44,6 +45,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import Dot from pytensor.tensor.math import sum as pt_sum +from pytensor.tensor.rewriting.assumptions import DrainSpecifyAssumptions from pytensor.tensor.rewriting.subtensor import ( local_adv_idx_to_diagonal, ) @@ -62,7 +64,7 @@ AdvancedSubtensor, Subtensor, ) -from tests.unittest_tools import assert_equal_computations +from tests.unittest_tools import RewriteTester, assert_equal_computations mode_opt = config.mode @@ -216,6 +218,22 @@ def test_elemwise_adv_index_not_provably_smaller_bails(self): rewritten = rewrite_graph(out) assert equal_computations([rewritten], [out]) + def test_elemwise_adv_index_assumed_unique_lifts(self): + """An unbounded adv index asserted unique_indices can never enlarge, so it lifts.""" + x = pt.matrix("x") + y = pt.matrix("y") + idx = pt.lvector("idx") + idx_unique = assume(idx, unique_indices=True) + out = (x + y)[idx_unique] + # Drain resolves the asserted fact onto idx, then canonicalize lifts the index. + result = RewriteTester( + [x, y, idx], + [out], + include="canonicalize", + custom_rewrite=DrainSpecifyAssumptions(), + ) + result.assert_graph(x[idx] + y[idx]) + def test_blockwise(self): class CoreTestOp(Op): itypes = [dvector, dvector]