Generalize local_sum_prod_alloc #2198
Draft
ricardoV94 wants to merge 1 commit into
Applies to non-constant values and other CAReduce types
I'll tweak the new tests to be more in line with #2103.
`Sum(Alloc(x, n))` (e.g. `broadcast_to(...).sum()`, or `sum(ones((5, 3)) * x)`) wasn't simplified to `x * n` / `sum(x) * 5` unless the alloc'd value was a scalar constant. The old `local_opt_alloc` extracted a constant via `get_underlying_scalar_constant_value`, which was an inherited Theano restriction rather than a real requirement.

Rewrite changes (`pytensor/tensor/rewriting/math.py`)

Renamed `local_opt_alloc` -> `local_careduce_of_alloc` and generalized it along three axes:

- Any value, not just scalar constants. The old code only fired when the alloc'd value was a scalar constant (it called `get_underlying_scalar_constant_value`). The new code has no value precondition at all: the only gate is structural (the reduction's input must be an `Alloc`). It fires for any value (scalar, symbolic, or holding real data on the kept axes, e.g. a batch dim from vectorization), and uses each reduced axis's broadcastability to decide whether to factor it out as a size multiplier or reduce it on the value.
- Push the reduction through the alloc. Reduced axes are split into the ones the alloc broadcasts (a brand-new dim, or a size-1 dim the alloc expands) vs. the ones where the value has real data. Broadcast axes are factored out; real axes are reduced on the smaller value via `op.clone(axis=...)`. It never bails except on non-`Alloc` input, and kept broadcast dims are re-materialized by an output alloc.
- All seven CAReduce ops, not just `Sum`/`Prod`: `Sum` -> multiply by `prod(broadcast sizes)`; `Prod`/`ProdWithoutZeros` -> raise to that power; `Max`/`Min`/`All`/`Any` (idempotent) -> broadcast axes simply dropped, no factor.

Closes #378
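For reviewers, the axis split and the per-op factor rule described above can be sketched in plain Python. This is only an illustration of the arithmetic, not the code in this PR: `split_reduced_axes` and `apply_broadcast_factor` are hypothetical names, shapes are plain tuples right-aligned as in NumPy-style broadcasting, and all broadcast sizes are assumed to be nonzero.

```python
import math

# Hypothetical helpers for illustration only; none of these are PyTensor APIs.
IDEMPOTENT = {"max", "min", "all", "any"}


def split_reduced_axes(value_shape, out_shape, axes):
    """Classify reduced output axes as alloc-broadcast or carrying real data.

    Returns (broadcast_axes, real_axes_on_value, broadcast_size), where
    real_axes_on_value uses the value's own axis numbering.
    """
    offset = len(out_shape) - len(value_shape)  # leading dims the alloc adds
    broadcast_axes, real_axes = [], []
    for axis in axes:
        value_axis = axis - offset
        if value_axis < 0 or value_shape[value_axis] == 1:
            # Brand-new dim, or a size-1 dim of the value that gets expanded.
            broadcast_axes.append(axis)
        else:
            # The value holds real data here: reduce it on the small value.
            real_axes.append(value_axis)
    broadcast_size = math.prod(out_shape[a] for a in broadcast_axes)
    return broadcast_axes, real_axes, broadcast_size


def apply_broadcast_factor(kind, reduced_value, broadcast_size):
    """Account for the broadcast copies after reducing the small value."""
    if kind == "sum":
        return reduced_value * broadcast_size
    if kind in ("prod", "prod_without_zeros"):
        return reduced_value ** broadcast_size
    if kind in IDEMPOTENT:
        return reduced_value  # repeating identical entries changes nothing
    raise ValueError(f"unknown reduction kind: {kind}")


# Value of shape (3,) allocated to shape (5, 3) and fully reduced.
value = [2, 3, 4]
_, _, size = split_reduced_axes((3,), (5, 3), (0, 1))
materialized = value * size  # the 15 entries the alloc would produce
assert apply_broadcast_factor("sum", sum(value), size) == sum(materialized)
assert apply_broadcast_factor("prod", math.prod(value), size) == math.prod(materialized)
assert apply_broadcast_factor("max", max(value), size) == max(materialized)
```

The closing assertions compare each factored result against reducing the fully materialized data, which is the equivalence the rewrite relies on.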