
[Draft] Fix off-by-one in reduction FLOP counts#57

Draft
spMohanty wants to merge 10 commits into main from fix-off-by-one-reduction-cost

Conversation

@spMohanty
Collaborator

Summary

A length-n reduction now charges n − 1 FLOPs, not n: the first input is a
free copy, and only the remaining n − 1 inputs are accumulated.

The fix applies to every reduction path:

  • plain axis reductions
  • tuple-axis reductions
  • symmetry-aware reductions on SymmetricTensor

Formula

cost = max(unique_out × (u_R − 1), 1)

Where:

  • unique_out — unique outputs after propagate_symmetry_reduce + Burnside.
  • u_R — unique inputs per output slice, via Burnside over inner-clean
    groups only (g.axes ⊆ R).
    • Split groups (axes span reduced and kept): no input savings.
    • Output-only groups (g.axes ⊆ K): only shrink unique_out.
  • max(..., 1) — clamps scalar / size-1-axis / empty shape to 1 FLOP.
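
As a rough illustration of how this formula yields the expected counts in the "Try it" examples, here is a minimal pure-Python sketch. It does not call whest: `cycles`, `burnside_unique`, and `reduction_cost` are hypothetical helpers, and the Burnside count only handles the full symmetric group acting on equal-length axes, which is an assumption made to keep the sketch short.

```python
from itertools import permutations


def cycles(perm):
    """Return the cycles of a permutation given as a tuple (i -> perm[i])."""
    seen, out = set(), []
    for start in range(len(perm)):
        if start in seen:
            continue
        cyc, j = [], start
        while j not in seen:
            seen.add(j)
            cyc.append(j)
            j = perm[j]
        out.append(cyc)
    return out


def burnside_unique(size, k):
    """Orbit count for S_k permuting k axes that all have length `size`.

    Burnside: the number of orbits is the average number of fixed points.
    An index tuple is fixed by a permutation iff it is constant on each
    cycle, so a permutation with c cycles fixes size**c index tuples.
    """
    group = list(permutations(range(k)))
    fixed = sum(size ** len(cycles(p)) for p in group)
    return fixed // len(group)


def reduction_cost(unique_out, u_R):
    """cost = max(unique_out * (u_R - 1), 1), as stated in the summary."""
    return max(unique_out * (u_R - 1), 1)


# Shapes mirror the symmetric examples in this PR description.
full = reduction_cost(1, burnside_unique(10, 2))         # sym(10,10), sum over all
preserving = reduction_cost(burnside_unique(5, 2), 10)   # sym survives on output
split = reduction_cost(5, 5)                             # split group: no savings
inner_clean = reduction_cost(10, burnside_unique(5, 2))  # sym inside reduced axes
degenerate = reduction_cost(1, 1)                        # clamped to 1 FLOP
print(full, preserving, split, inner_clean, degenerate)
```

The split case passes plain axis lengths rather than Burnside counts because, per the rule above, a group spanning both reduced and kept axes gives no input savings.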

Tests

uv run pytest tests/: 3188 passed, 87 skipped, 0 failures.
(4 pre-existing scipy ImportErrors are unrelated.)

Closes

Closes #56


Try it

Install from this branch:

uv venv .venv-whest
source .venv-whest/bin/activate
uv pip install "git+https://github.com/AIcrowd/whest.git@fix-off-by-one-reduction-cost"

Dense reductions — n − 1 per output slice:

import whest as we
import numpy as np

# 1-D sum: 50 − 1 = 49
with we.BudgetContext(flop_budget=10**6, quiet=True) as b:
    we.sum(we.ones(50))
    assert b.flops_used == 49

# Tuple axis: 5 outputs × (4*6 − 1) = 115
x = we.array(np.ones((4, 5, 6)))
with we.BudgetContext(flop_budget=10**6, quiet=True) as b:
    we.sum(x, axis=(0, 2))
    assert b.flops_used == 115

Symmetric reductions — full / preserving / split / inner-clean:

import whest as we
import numpy as np
from whest._symmetric import as_symmetric

S2      = as_symmetric(np.eye(10),          symmetric_axes=(0, 1))
S3      = as_symmetric(np.ones((5, 5, 10)), symmetric_axes=(0, 1))
S2_flat = as_symmetric(np.ones((5, 5)),     symmetric_axes=(0, 1))

cases = [
    # op,                              expected, note
    (lambda: we.sum(S2),                     54, "sym(10,10): 55 unique − 1"),
    (lambda: we.sum(S3, axis=2),            135, "preserving: 15 × (10 − 1)"),
    (lambda: we.sum(S2_flat, axis=0),        20, "split: 5 × (5 − 1), no inner savings"),
    (lambda: we.sum(S3, axis=(0, 1)),       140, "inner-clean: 10 × (15 − 1)"),
]

for op, expected, note in cases:
    with we.BudgetContext(flop_budget=10**6, quiet=True) as b:
        op()
        assert b.flops_used == expected, (b.flops_used, expected, note)
    print(f"OK  {expected:>4}  {note}")

Targeted test runs:

uv run pytest tests/test_flops.py -v -k reduction
uv run pytest tests/test_symmetric_pointwise.py::TestReductionSymmetry -v

Per code review: the parenthetical '(g.axes ⊂ K)' could be read as
the definition of inner-clean rather than the negation condition.
Rephrase to state the rule (g.axes ⊆ R) and why it fails here.

Reductions now charge unique_out × (u_R − 1) instead of
prod(input_shape). Symmetry-aware via propagate_symmetry_reduce
(output) and inner-clean Burnside (input per output slice).
Clamped to 1 for degenerate shapes.

Closes #56.

- Move propagate_symmetry_reduce to top-level import (no circular dep;
  _symmetric.py does not import _flops.py).
- Raise ValueError for scalar input with explicit axis instead of
  crashing with IndexError deeper in the call stack.
- Tighten type hints to list[PermutationGroup] | None via TYPE_CHECKING.
- Document the reduced_axes=None fallback in _compute_R_unique_count.
- Add regression test for the scalar+axis ValueError.

All downstream tests that embed the old reduction cost formula
(cost = numel) now expect the corrected cost (numel − 1 for
full reduction, unique_out × (u_R − 1) with clamping otherwise).

Minor polish from code review: explain where literals like 14, 12, 2
come from so future readers don't have to reverse-engineer them. Also
updates a stale "numel(input)" sub-bucket comment to match the new
formula.

The notes previously claimed cost_multiplier=2, but the code does
not pass cost_multiplier to the reduction cost path. The 2× factor
actually comes from weight=2.0 in weights.csv. Update notes to
reflect reality.

Three new tests in TestReductionSymmetry exercise the runtime path
(we.sum on SymmetricTensor with axis/tuple-axis) for the three sym
cases: split pair (no inner savings), preserving (sym survives on
output), and inner-clean (sym entirely within reduced axes).

The attached cost-description label on reduction wrappers (sum, mean,
std, etc.) now reflects the corrected off-by-one formula: "numel(input)
- 1" instead of "numel(input)". User-facing docstrings now match the
actual runtime behavior.

@spMohanty spMohanty changed the title Fix off-by-one in reduction FLOP counts [Draft] Fix off-by-one in reduction FLOP counts Apr 24, 2026
@spMohanty
Collaborator Author

@wiwu2390 Reviewing more carefully — the "first element is a free copy" argument from #56 only structurally applies to single-fold ops. For multi-phase reductions (mean, std/var, ptp, …) the structural savings are either 0 or multiple, and uniform n−1 is the wrong shape of fix.

The argument's actual scope

"Free first element" is a property of a single fold: acc = a[0]; for i ∈ 1..n−1: acc = acc ⊕ a[i]. It saves 1 FLOP per fold phase, not per op call.
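
A minimal sketch of that single fold, with an explicit operation counter, makes the n − 1 count concrete. `counted_fold` is a hypothetical helper written for this comment, not part of whest.

```python
import operator


def counted_fold(op, values):
    """Fold `values` with `op`, returning (result, number of op applications)."""
    values = list(values)
    if not values:
        raise ValueError("cannot fold an empty sequence without an initial value")
    flops = 0
    acc = values[0]          # free copy: no arithmetic performed
    for v in values[1:]:     # the remaining n - 1 inputs are accumulated
        acc = op(acc, v)
        flops += 1
    return acc, flops


total, flops = counted_fold(operator.add, range(1, 51))
print(total, flops)  # 1275 49
```

The saving is tied to the fold itself: an op that runs two folds gets two free copies, and an op that adds a trailing divide or subtract pays for it on top.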

Per-op structural check

| Op | Phases | Free-first savings | Correct analytical |
| --- | --- | --- | --- |
| sum, prod, max, min | 1 fold | 1 | n − 1 |
| argmax, argmin, any, all | 1 fold (compare) | 1 | n − 1 |
| cumsum, cumprod (and cumulative_*) | 1 fold (materialised) | 1 | n − 1 |
| nan* variants of all the above | 1 fold + NaN mask | 1 | n − 1 |
| mean, average, nanmean | 1 fold + 1 div | 1 (but div survives) | n |
| var, std, nanvar, nanstd | 2 folds + 2 pointwise + div [+ sqrt] | 2 | ~4n (weight-calibrated) |
| ptp = max − min | 2 folds + 1 sub | 2 | 2n − 1 |
| median, percentile, quantile (+ nan) | selection/sort | n/a (not a fold) | placeholder |

A uniform (n − 1) × weight happens to land right for var/std (weight=2.0 × 2 folds = 2 FLOPs saved, which matches theory), but that is a coincidence of the weight magnitude, not a structural result. For mean it undercounts by 1 FLOP per output slice; for ptp it under-saves by 1; for median/percentile/quantile the analytical cost was never fold-shaped to begin with.
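
The mean and ptp rows can be checked by counting operations directly. The sketch below is an assumption-laden illustration (the `OpCounter` class and `structural_flops` function are hypothetical, weights are ignored, and every binary operation counts as one FLOP), not whest's cost model.

```python
import operator


class OpCounter:
    """Counts every binary operation routed through `apply`."""

    def __init__(self):
        self.flops = 0

    def apply(self, op, a, b):
        self.flops += 1
        return op(a, b)


def fold(counter, op, values):
    acc = values[0]  # free copy
    for v in values[1:]:
        acc = counter.apply(op, acc, v)
    return acc


def structural_flops(name, values):
    """Unweighted operation count for a few reductions over one output slice."""
    c = OpCounter()
    if name == "sum":
        fold(c, operator.add, values)
    elif name == "mean":
        total = fold(c, operator.add, values)
        c.apply(operator.truediv, total, len(values))  # the divide survives
    elif name == "ptp":
        hi = fold(c, max, values)
        lo = fold(c, min, values)
        c.apply(operator.sub, hi, lo)
    else:
        raise ValueError(f"unsupported op: {name}")
    return c.flops


values = [float(i) for i in range(50)]  # n = 50
counts = {name: structural_flops(name, values) for name in ("sum", "mean", "ptp")}
print(counts)  # {'sum': 49, 'mean': 50, 'ptp': 99}
```

With n = 50, a uniform n − 1 would report 49 for all three, which matches sum but is one short for mean and well short of 2n − 1 = 99 for ptp before any weight is applied.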

Proposed narrowing

Keep n − 1 only for ops whose analytical cost is a single fold:

  • sum, prod, max, min
  • argmax, argmin, any, all
  • cumsum, cumprod, cumulative_sum, cumulative_prod
  • All nan* variants of the above
  • Their symmetric counterparts on SymmetricTensor

Revert to numel(input) for:

  • mean, average, nanmean
  • var, std, nanvar, nanstd
  • ptp
  • median, percentile, quantile, nanmedian, nanpercentile, nanquantile
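
One way the narrowing could be expressed is a lookup keyed on the op name. This is only a sketch under stated assumptions: `SINGLE_FOLD_OPS`, `NUMEL_OPS`, and `dense_reduction_cost` are hypothetical names, the nan* entries are the NumPy-style names for the single-fold ops that have nan variants, weights are not applied, and the symmetric counterparts are not modelled.

```python
# Ops whose analytical cost is a single fold: charge n - 1 per output slice.
SINGLE_FOLD_OPS = frozenset({
    "sum", "prod", "max", "min",
    "argmax", "argmin", "any", "all",
    "cumsum", "cumprod", "cumulative_sum", "cumulative_prod",
    "nansum", "nanprod", "nanmax", "nanmin",
    "nanargmax", "nanargmin", "nancumsum", "nancumprod",
})

# Ops that would revert to numel(input): multi-phase or not fold-shaped.
NUMEL_OPS = frozenset({
    "mean", "average", "nanmean",
    "var", "std", "nanvar", "nanstd",
    "ptp",
    "median", "percentile", "quantile",
    "nanmedian", "nanpercentile", "nanquantile",
})


def dense_reduction_cost(op_name, n):
    """Unweighted cost of reducing n dense inputs into one output."""
    if op_name in SINGLE_FOLD_OPS:
        return max(n - 1, 1)  # keep the 1-FLOP clamp for degenerate shapes
    if op_name in NUMEL_OPS:
        return n
    raise KeyError(f"no reduction cost rule for {op_name!r}")


print(dense_reduction_cost("sum", 50), dense_reduction_cost("mean", 50))  # 49 50
```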

@wilswu99
Collaborator

I think the formula cost = max(unique_out × (u_R − 1), 1) is incorrect for two reasons:

  • In general, the number of unique input orbits that get mapped to a given output orbit depends on that output orbit. There may not exist a u_R that is the same across all output orbits. Instead, we should use the same partition-counting logic as in the sum step of einsum.
  • When the operation is not of ufunc.reduce form, it's not clear how to operate over only the orbits of the input, instead of all entries. In that case, I think we should fall back to num_unique_out * num_reduced_elems, where num_reduced_elems is the total number of input elements per output orbit, not the number of input orbits. This is equivalent to naive_flop_count * unique_out / total_out.

Example:

import whest as we
 
with we.BudgetContext(flop_budget=1e20) as bc:
    n = 10
    with we.namespace('init'):
        A = we.random.randn(n, n, n, n)
        g = we.Permutation(we.Cycle(0, 1)(2, 3), size=4)
        G = we.PermutationGroup(g, axes=(0, 1, 2, 3))
        A = we.symmetrize(A, group=G)
    with we.namespace('sum2'):
        S2 = we.sum(A, axis=(0, 1))

get_flops = lambda k: bc.summary_dict(by_namespace=True)['by_namespace'].get(k, {}).get('flops_used', 0)
assert get_flops('sum2') == 4995     # Currently 5445
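
Both numbers in the assertion can be reproduced by brute-force orbit enumeration without whest. The sketch below hard-codes the two-element group {identity, (0 1)(2 3)} from the example; `input_orbit` and `output_orbit` are hypothetical helpers, and the last line takes naive_flop_count to mean numel(input), which is an assumption.

```python
from itertools import product

n = 10


def input_orbit(idx):
    """Canonical representative of (i, j, k, l) under (0 1)(2 3)."""
    i, j, k, l = idx
    return min((i, j, k, l), (j, i, l, k))


def output_orbit(kl):
    """Summing over axes (0, 1) leaves an output symmetric under k <-> l."""
    k, l = kl
    return min((k, l), (l, k))


# Group the unique input orbits by the output orbit they contribute to.
inputs_per_output = {}
for idx in product(range(n), repeat=4):
    inputs_per_output.setdefault(output_orbit(idx[2:]), set()).add(input_orbit(idx))

unique_out = len(inputs_per_output)

# Per-orbit counting: each output orbit folds its own number of input orbits.
per_orbit_cost = sum(len(orbits) - 1 for orbits in inputs_per_output.values())

# Uniform formula from the PR: split group, so u_R = n * n for every output.
uniform_cost = max(unique_out * (n * n - 1), 1)

# Fallback for non-ufunc.reduce ops: unique outputs x input elements per output.
fallback_cost = unique_out * n * n

print(unique_out, per_orbit_cost, uniform_cost, fallback_cost)  # 55 4995 5445 5500
```

The gap comes from the 10 diagonal output orbits (k = l), each of which sees only 55 unique input orbits rather than 100: 45 × 99 + 10 × 54 = 4995.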

Discussion here: https://alignmentresearch.slack.com/archives/C0AFESY004U/p1777064701467469?thread_ts=1776713328.350929&cid=C0AFESY004U

@wilswu99
Collaborator

I haven't checked this PR's behavior on non-ufunc.reduce operations (e.g. median), since they run into #35, which isn't fixed on this branch.

@spMohanty
Copy link
Copy Markdown
Collaborator Author

Putting this on hold until the updated einsum-cost-calculation PR lands on main: it uses the same partition-counting logic, so it is better to share it than duplicate it. This PR will need a careful re-evaluation on top of the shared helpers once the einsum-cost-calculation PR is in.

@spMohanty spMohanty marked this pull request as draft April 29, 2026 02:24


Development

Successfully merging this pull request may close these issues.

Off-by-one in sum/mean reductions

2 participants