Skip to content

Generalize local_sum_prod_alloc#2198

Draft
ricardoV94 wants to merge 1 commit into
pymc-devs:mainfrom
ricardoV94:alloc_sum_to_mul
Draft

Generalize local_sum_prod_alloc#2198
ricardoV94 wants to merge 1 commit into
pymc-devs:mainfrom
ricardoV94:alloc_sum_to_mul

Conversation

@ricardoV94

@ricardoV94 ricardoV94 commented Jun 8, 2026

Copy link
Copy Markdown
Member

Sum(Alloc(x, n)) (e.g. broadcast_to(...).sum(), or sum(ones((5, 3)) * x)) wasn't simplified to x * n / sum(x) * 5 unless the alloc'd value was a scalar constant. The old local_opt_alloc extracted a constant via get_underlying_scalar_constant_value, which was an inherited Theano restriction rather than a real requirement.

Rewrite changes (pytensor/tensor/rewriting/math.py)

Renamed local_opt_alloc -> local_careduce_of_alloc and generalized it along three axes:

  1. Any value, not just scalar constants. The old code only fired when the alloc'd value was a scalar constant (it called get_underlying_scalar_constant_value). The new code has no value precondition at all: the only gate is structural (the reduction's input must be an Alloc). It fires for any value (scalar, symbolic, or holding real data on the kept axes, e.g. a batch dim from vectorization), and uses each reduced axis's broadcastability to decide whether to factor it out as a size multiplier or reduce it on the value.

  2. Push the reduction through the alloc. Reduced axes are split into the ones the alloc broadcasts (a brand-new dim, or a size-1 dim the alloc expands) vs. the ones where the value has real data. Broadcast axes are factored out; real axes are reduced on the smaller value via op.clone(axis=...). It never bails except on non-Alloc input, and kept broadcast dims are re-materialized by an output alloc.

  3. All seven CAReduce ops, not just Sum/Prod: Sum -> multiply by prod(broadcast sizes); Prod/ProdWithoutZeros -> raise to that power; Max/Min/All/Any (idempotent) -> broadcast axes simply dropped, no factor.

Closes #378

@ricardoV94 ricardoV94 force-pushed the alloc_sum_to_mul branch 2 times, most recently from 79b99ce to c5d1cdb Compare June 9, 2026 12:57
Applies to non-constant values and other CAReduce types
@ricardoV94 ricardoV94 marked this pull request as ready for review June 9, 2026 14:06
@ricardoV94 ricardoV94 requested a review from jessegrabowski June 9, 2026 14:07
@ricardoV94 ricardoV94 changed the title Generalize local_sum__prod_alloc Generalize local_sum_prod_alloc Jun 9, 2026
@ricardoV94

Copy link
Copy Markdown
Member Author

I'll tweak the new tests to be more in line with #2103

@ricardoV94 ricardoV94 marked this pull request as draft June 9, 2026 21:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Optimize reductions of broadcasting (alloc)

1 participant