Add assumption for unique indices #2225
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
|
|
||
| import pytensor | ||
| from pytensor import compile | ||
| + from pytensor.assumptions.core import UNIQUE_INDICES, check_assumption | ||
| from pytensor.compile import optdb | ||
| from pytensor.graph.basic import Constant, Variable | ||
| from pytensor.graph.rewriting.basic import ( | ||
|
|
@@ -231,6 +232,14 @@ def _constant_has_unique_indices(idx) -> bool: | |
| return result | ||
|
|
||
|
|
||
| + def _has_unique_indices(fgraph, idx) -> bool: | ||
| + """Whether ``idx``'s entries are provably duplicate-free: a constant with | ||
| + unique entries, or a variable asserted ``unique_indices`` by the user.""" | ||
| + return _constant_has_unique_indices(idx) or check_assumption( | ||
|
Member
Do you have a sense of which of these is cheaper? It's the assumption check if the Feature is already attached, but in live code I don't know.
Member, Author
I don't think it matters; people aren't going to put assumptions on constant indices (no need anyway). Either way, after the first check it's cached.
||
| + fgraph, idx, UNIQUE_INDICES | ||
| + ) | ||
|
|
||
|
|
||
| def _constant_is_arange(idx) -> tuple[int, int, int] | None: | ||
| """Match ``idx`` to ``np.arange(offset, offset + d * step, step)`` | ||
| and return ``(d, offset, step)``, else ``None``. | ||
|
|
@@ -1169,7 +1178,7 @@ def local_add_of_sparse_write(fgraph, node): | |
| # duplicate-free. Basic (slice/scalar) indexing is always unique; | ||
| # advanced integer-array indices must be checked. | ||
| if not inner_op.set_instead_of_inc and not isinstance(inner_op, IncSubtensor): | ||
| - if not all(_constant_has_unique_indices(idx) for idx in idx_vars): | ||
| + if not all(_has_unique_indices(fgraph, idx) for idx in idx_vars): | ||
| continue | ||
|
|
||
| others = [node.inputs[j] for j in range(len(node.inputs)) if j != i] | ||
|
|
@@ -2001,7 +2010,7 @@ def local_read_of_write_same_indices(fgraph, node): | |
| indices = indices_from_subtensor(outer_idx_vars, node.op.idx_list) | ||
| for idx in indices: | ||
| if isinstance(idx, TensorVariable) and idx.type.ndim > 0: | ||
| - if not _constant_has_unique_indices(idx): | ||
| + if not _has_unique_indices(fgraph, idx): | ||
| return None | ||
|
|
||
| x_at_idx = x[tuple(indices)] | ||
|
|
@@ -2363,7 +2372,7 @@ def local_write_of_write_same_indices(fgraph, node): | |
| # sufficient: it guarantees no duplicates in the joint cross-product | ||
| # after broadcasting. | ||
| if not isinstance(node.op, IncSubtensor): | ||
| - if not all(_constant_has_unique_indices(v) for v in outer_idx_vars): | ||
| + if not all(_has_unique_indices(fgraph, v) for v in outer_idx_vars): | ||
| return | ||
| new_val = a + b | ||
| if ( | ||
|
|
||
Should we reject this flag on non-1d inputs for now? Otherwise we need to know which axis/axes have unique indices, and we don't have the machinery for parameterized assumptions yet. If you want to add it though I'd be happy, I want it for matrix rank among other things.
This assumes uniqueness over the whole data; I don't care about dims.
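For anyone following along, the difference between the two readings of "unique" can be sketched in plain NumPy. This is only an illustration, not PyTensor code: `is_unique_overall` and `is_unique_along_axis` are hypothetical helper names, and the example index matrix is made up.

```python
import numpy as np


def is_unique_overall(idx: np.ndarray) -> bool:
    """True if no value appears twice anywhere in ``idx`` (hypothetical helper)."""
    return np.unique(idx).size == idx.size


def is_unique_along_axis(idx: np.ndarray, axis: int) -> bool:
    """True if every 1-D slice of ``idx`` along ``axis`` is duplicate-free
    (hypothetical helper)."""
    sorted_idx = np.sort(idx, axis=axis)
    # After sorting, duplicates sit next to each other, so a zero difference
    # between neighbours means a repeated value in that slice.
    return bool(np.all(np.diff(sorted_idx, axis=axis) != 0))


# Each row is duplicate-free, but the value 1 appears in both rows.
idx = np.array([[0, 1], [1, 2]])

print(is_unique_overall(idx))        # False
print(is_unique_along_axis(idx, 1))  # True
```

The whole-array check is the stricter one: it implies uniqueness along every axis, so it needs no axis parameter.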
(Not yet.) The guard against repeated computations still needs some work, but I don't think people are going to do (symbolic) advanced matrix indexing and want it to lift; they could always index flat and reshape.
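As a rough illustration of why these rewrites need duplicate-free indices, and of the flat-and-reshape workaround, here is a NumPy-only sketch. It assumes the increment accumulates on repeated indices the way `np.add.at` does; `read_after_inc` is a hypothetical helper, not a PyTensor function.

```python
import numpy as np


def read_after_inc(x, idx, y):
    """Eagerly compute "increment x at idx by y, then read x at idx"
    (hypothetical helper, NumPy semantics)."""
    out = x.copy()
    np.add.at(out, idx, y)  # accumulates when an index is repeated
    return out[idx]


x = np.zeros(4)
y = np.array([1.0, 2.0, 3.0])
unique_idx = np.array([0, 2, 3])
dup_idx = np.array([0, 2, 2])

# The shortcut a read-of-write rewrite would produce is x[idx] + y.
ok = np.array_equal(read_after_inc(x, unique_idx, y), x[unique_idx] + y)
bad = np.array_equal(read_after_inc(x, dup_idx, y), x[dup_idx] + y)
print(ok, bad)  # True False

# Flat-and-reshape: a matrix of indices into a vector gives the same result
# as indexing with the flattened indices and restoring the shape afterwards.
v = np.arange(10.0, 14.0)
idx2d = np.array([[0, 1], [2, 3]])
direct = v[idx2d]
via_flat = v[idx2d.ravel()].reshape(idx2d.shape)
print(np.array_equal(direct, via_flat))  # True
```

With the duplicated index, the two increments at position 2 are summed, so the value read back is no longer `x[idx] + y` and the shortcut would be wrong.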