From 4c00826e594e3caf4b5d06bf51bfb37f0b56230e Mon Sep 17 00:00:00 2001 From: Sharada Mohanty Date: Fri, 24 Apr 2026 15:59:10 +0200 Subject: [PATCH 01/10] test(flops): add failing tests for off-by-one reduction cost fix --- tests/test_flops.py | 116 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 112 insertions(+), 4 deletions(-) diff --git a/tests/test_flops.py b/tests/test_flops.py index 811f24bbdb..2bf8aef360 100644 --- a/tests/test_flops.py +++ b/tests/test_flops.py @@ -90,7 +90,8 @@ def test_analytical_pointwise_cost_scalar(): def test_analytical_reduction_cost(): - assert analytical_reduction_cost(input_shape=(256, 256), axis=None) == 256 * 256 + # Dense full reduction: numel − 1 accumulations (first value is a free copy). + assert analytical_reduction_cost(input_shape=(256, 256), axis=None) == 256 * 256 - 1 def test_analytical_svd_cost(): @@ -116,15 +117,120 @@ def test_analytical_pointwise_cost_no_symmetry_unchanged(): def test_analytical_reduction_cost_symmetric(): + # sym(5,5) has 15 unique elements; full reduction costs 15 − 1 = 14. info = SymmetryInfo(symmetric_axes=[(0, 1)], shape=(5, 5)) assert ( analytical_reduction_cost(input_shape=(5, 5), axis=None, symmetry_info=info) - == 15 + == 14 ) def test_analytical_reduction_cost_no_symmetry_unchanged(): - assert analytical_reduction_cost(input_shape=(5, 5), axis=None) == 25 + # (5,5) dense full reduction: 25 − 1 = 24. + assert analytical_reduction_cost(input_shape=(5, 5), axis=None) == 24 + + +def test_analytical_reduction_cost_axis_nonsymmetric(): + # (10, 20) axis=0: 20 outputs, each reducing 10 values → 20 * (10−1) = 180. + assert analytical_reduction_cost(input_shape=(10, 20), axis=0) == 180 + + +def test_analytical_reduction_cost_axis_nonsymmetric_last_axis(): + # (10, 20) axis=1: 10 outputs, each reducing 20 values → 10 * (20−1) = 190. 
+ assert analytical_reduction_cost(input_shape=(10, 20), axis=1) == 190 + + +def test_analytical_reduction_cost_axis_negative(): + # axis=-1 is equivalent to axis=ndim-1. + assert analytical_reduction_cost(input_shape=(10, 20), axis=-1) == 10 * (20 - 1) + + +def test_analytical_reduction_cost_tuple_axis(): + # (4, 5, 6) reducing axes (0, 2) → kept axis 1 has 5 outputs, + # each reduces 4*6 = 24 values → 5 * (24 − 1) = 115. + assert analytical_reduction_cost(input_shape=(4, 5, 6), axis=(0, 2)) == 115 + + +def test_analytical_reduction_cost_tuple_axis_full(): + # Tuple covering all axes is equivalent to axis=None. + assert analytical_reduction_cost(input_shape=(4, 5), axis=(0, 1)) == 4 * 5 - 1 + + +def test_analytical_reduction_cost_axis_sym_preserving(): + # sym(5,5,10) with symmetric axes (0,1) reducing axis=2. + # K = {0,1} preserves the S_2 symmetry → 15 unique outputs. + # R = {2} is disjoint from sym group's axes (0,1), so that group is NOT + # inner-clean (g.axes ⊂ K). u_R = 10 (no inner savings). + # Cost = 15 * (10 − 1) = 135. + info = SymmetryInfo(symmetric_axes=[(0, 1)], shape=(5, 5, 10)) + assert ( + analytical_reduction_cost( + input_shape=(5, 5, 10), axis=2, symmetry_info=info + ) + == 135 + ) + + +def test_analytical_reduction_cost_axis_sym_split_pair(): + # sym(5,5) with symmetric axes (0,1) reducing axis=0. + # The sym group spans both R={0} and K={1} → split. No inner savings. + # After propagate_symmetry_reduce, S_2 does not survive onto a single axis + # → 5 unique outputs, each reducing 5 values → 5 * (5 − 1) = 20. + info = SymmetryInfo(symmetric_axes=[(0, 1)], shape=(5, 5)) + assert ( + analytical_reduction_cost( + input_shape=(5, 5), axis=0, symmetry_info=info + ) + == 20 + ) + + +def test_analytical_reduction_cost_axis_sym_split_s3(): + # sym(5,5,5) with symmetric axes (0,1,2) reducing axis=0. + # Setwise stabilizer of {0} restricted to {1,2} is S_2 → 15 unique outputs. + # Split group → no inner savings. u_R = 5. + # Cost = 15 * (5 − 1) = 60. 
+ info = SymmetryInfo(symmetric_axes=[(0, 1, 2)], shape=(5, 5, 5)) + assert ( + analytical_reduction_cost( + input_shape=(5, 5, 5), axis=0, symmetry_info=info + ) + == 60 + ) + + +def test_analytical_reduction_cost_axis_sym_inner_clean(): + # sym(5,5,10) with symmetric axes (0,1) reducing axes (0,1). + # Group acts entirely within R = {0,1} → inner-clean. + # u_R = 15 (Burnside on the S_2 group over 5×5). 10 outputs (axis 2 kept). + # Cost = 10 * (15 − 1) = 140. + info = SymmetryInfo(symmetric_axes=[(0, 1)], shape=(5, 5, 10)) + assert ( + analytical_reduction_cost( + input_shape=(5, 5, 10), axis=(0, 1), symmetry_info=info + ) + == 140 + ) + + +def test_analytical_reduction_cost_scalar(): + # Scalar input: max(1 − 1, 1) = 1 (clamped floor). + assert analytical_reduction_cost(input_shape=(), axis=None) == 1 + + +def test_analytical_reduction_cost_size_one_axis(): + # axis of size 1 has 0 accumulations → clamped to 1. + assert analytical_reduction_cost(input_shape=(1, 10), axis=0) == 1 + + +def test_analytical_reduction_cost_empty_shape(): + # Shape containing 0: degenerate but should not crash; clamped to 1. + assert analytical_reduction_cost(input_shape=(0,), axis=None) == 1 + + +def test_analytical_reduction_cost_single_element(): + # (1,) full reduction: 1 − 1 = 0 → clamped to 1. + assert analytical_reduction_cost(input_shape=(1,), axis=None) == 1 def test_analytical_einsum_cost_symmetric_input(): @@ -156,8 +262,10 @@ def test_public_pointwise_cost_is_weighted(tmp_path): def test_public_reduction_cost_is_weighted(tmp_path): + # Analytical cost for (4, 5) full reduction: 4*5 − 1 = 19. + # Weighted: int(19 * 3.25) = 61. 
load_weights(_write_weights(tmp_path, {"sum": 3.25}), use_packaged_default=False) - assert public_flops.reduction_cost("sum", input_shape=(4, 5), axis=None) == 65 + assert public_flops.reduction_cost("sum", input_shape=(4, 5), axis=None) == 61 def test_public_einsum_cost_is_weighted(tmp_path): From c6adde26b108ad4260342545920cc8f63d173193 Mon Sep 17 00:00:00 2001 From: Sharada Mohanty Date: Fri, 24 Apr 2026 16:05:34 +0200 Subject: [PATCH 02/10] test: clarify inner-clean condition in sym_preserving comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per code review: the parenthetical '(g.axes ⊂ K)' could be read as the definition of inner-clean rather than the negation condition. Rephrase to state the rule (g.axes ⊆ R) and why it fails here. --- tests/test_flops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_flops.py b/tests/test_flops.py index 2bf8aef360..01a405ae36 100644 --- a/tests/test_flops.py +++ b/tests/test_flops.py @@ -159,8 +159,8 @@ def test_analytical_reduction_cost_tuple_axis_full(): def test_analytical_reduction_cost_axis_sym_preserving(): # sym(5,5,10) with symmetric axes (0,1) reducing axis=2. # K = {0,1} preserves the S_2 symmetry → 15 unique outputs. - # R = {2} is disjoint from sym group's axes (0,1), so that group is NOT - # inner-clean (g.axes ⊂ K). u_R = 10 (no inner savings). + # Inner-clean requires g.axes ⊆ R. Here g.axes = {0,1} ⊄ R = {2}, + # so the group is output-only, not inner-clean. u_R = 10 (no inner savings). # Cost = 15 * (10 − 1) = 135. 
info = SymmetryInfo(symmetric_axes=[(0, 1)], shape=(5, 5, 10)) assert ( From 4d6462fa4039d84aab667cb8b8d4cd88a9ef1b4d Mon Sep 17 00:00:00 2001 From: Sharada Mohanty Date: Fri, 24 Apr 2026 16:07:57 +0200 Subject: [PATCH 03/10] fix(flops): correct off-by-one in analytical_reduction_cost MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reductions now charge unique_out × (u_R − 1) instead of prod(input_shape). Symmetry-aware via propagate_symmetry_reduce (output) and inner-clean Burnside (input per output slice). Clamped to 1 for degenerate shapes. Closes #56. --- src/whest/_flops.py | 147 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 130 insertions(+), 17 deletions(-) diff --git a/src/whest/_flops.py b/src/whest/_flops.py index 1ad7f1f252..b21179e51d 100644 --- a/src/whest/_flops.py +++ b/src/whest/_flops.py @@ -121,39 +121,152 @@ def analytical_pointwise_cost( return max(result, 1) +def _normalize_axis( + axis: int | tuple[int, ...] | None, ndim: int +) -> tuple[int, ...] | None: + """Return normalized reduction axes as a sorted tuple, or None for full reduction.""" + if axis is None: + return None + if isinstance(axis, int): + axis = (axis,) + normalized = tuple(sorted((a % ndim) if ndim > 0 else a for a in axis)) + return normalized + + +def _compute_output_unique_count( + groups: list | None, + input_shape: tuple[int, ...], + reduced_axes: tuple[int, ...] | None, +) -> int: + """Number of unique outputs after reducing ``reduced_axes`` of ``input_shape``. + + Uses ``propagate_symmetry_reduce`` to obtain the surviving output symmetry + groups, then multiplies Burnside counts (over symmetric axes) by free-axis + sizes (kept non-symmetric axes). + """ + # Runtime import to avoid circular dependency at module load. + from whest._symmetric import propagate_symmetry_reduce + + ndim = len(input_shape) + if reduced_axes is None: + # Full reduction → scalar output. 
+ return 1 + kept_axes = [d for d in range(ndim) if d not in set(reduced_axes)] + if not kept_axes: + return 1 + + # Propagate to get the surviving output symmetry groups. + # (Operates on input-axis numbering internally; returns output-axis numbering.) + out_groups = None + if groups: + out_groups = propagate_symmetry_reduce( + groups, ndim, reduced_axes, keepdims=False + ) + + # Output shape in output-axis numbering. + output_shape = tuple(input_shape[d] for d in kept_axes) + + # Multiply Burnside unique counts over the output groups; multiply free-dim sizes. + accounted: set[int] = set() + total = 1 + if out_groups: + for group in out_groups: + axes = group.axes + if axes is None: + continue + size_dict = {i: output_shape[axes[i]] for i in range(group.degree)} + total *= group.burnside_unique_count(size_dict) + accounted.update(axes) + for i, size in enumerate(output_shape): + if i not in accounted: + total *= size + return total + + +def _compute_R_unique_count( + groups: list | None, + input_shape: tuple[int, ...], + reduced_axes: tuple[int, ...] | None, +) -> int: + """Number of unique inputs feeding one output slice. + + Only **inner-clean** sym groups (``g.axes ⊆ R``) contribute input-side + savings — they act entirely within the reduced axes and combine equivalent + values. Split groups (spanning both R and K) and output-only groups + (``g.axes ⊆ K``) contribute no savings here. + """ + ndim = len(input_shape) + if reduced_axes is None: + reduced_axes = tuple(range(ndim)) + reduced_set = set(reduced_axes) + + accounted: set[int] = set() + total = 1 + if groups: + for group in groups: + axes = group.axes + if axes is None: + continue + axes_set = set(axes) + if not axes_set.issubset(reduced_set): + # Not inner-clean: either output-only (g.axes ⊆ K) — doesn't + # touch reduced axes — or split. Neither contributes to u_R. 
+ continue + size_dict = {i: input_shape[axes[i]] for i in range(group.degree)} + total *= group.burnside_unique_count(size_dict) + accounted.update(axes) + for d in reduced_axes: + if d not in accounted: + total *= input_shape[d] + return total + + def analytical_reduction_cost( input_shape: tuple[int, ...], - axis: int | None = None, + axis: int | tuple[int, ...] | None = None, symmetry_info: SymmetryInfo | None = None, ) -> int: """FLOP cost of a reduction operation. + The cost of a reduction is the number of accumulations performed: + for each output, the first input value is a free copy, and the remaining + ``u_R − 1`` values are accumulated in. Total cost: + + .. math:: + + \\text{cost} = \\text{unique\\_out} \\times (u_R - 1) + + where ``unique_out`` is the number of unique outputs (accounting for + output-side symmetry) and ``u_R`` is the number of unique inputs feeding + one output slice (accounting for inner-clean input symmetry). + Parameters ---------- input_shape : tuple of int Shape of the input array. - axis : int or None, optional - Axis along which to reduce. If None, reduce over all elements. + axis : int, tuple of int, or None, optional + Axis or axes along which to reduce. If None, reduce over all elements. symmetry_info : SymmetryInfo or None, optional - If provided, only unique elements are counted. + If provided, symmetry is used to count unique outputs and unique + per-output inputs. Only inner-clean groups (g.axes ⊆ reduced axes) + contribute per-output input savings; split groups do not. Returns ------- int - Estimated FLOP count (one per element). - - Notes - ----- - The ``axis`` parameter is accepted for API consistency but does not - affect the result: a reduction always touches every element regardless - of which axis is reduced, so the cost is always ``prod(input_shape)``. + Estimated FLOP count. 
Clamped to a minimum of 1 for degenerate shapes + (scalar, size-1 axis, empty shape) so every reduction registers at + least 1 flop for budget tracking purposes. """ - if symmetry_info is not None: - return max(symmetry_info.unique_elements, 1) - result = 1 - for dim in input_shape: - result *= dim - return max(result, 1) + ndim = len(input_shape) + reduced_axes = _normalize_axis(axis, ndim) + groups = symmetry_info.groups if symmetry_info is not None else None + + unique_out = _compute_output_unique_count(groups, input_shape, reduced_axes) + u_R = _compute_R_unique_count(groups, input_shape, reduced_axes) + + cost = unique_out * (u_R - 1) + return max(cost, 1) # Backward-compatible internal aliases. The public weighted API lives in From bf799d3a7e021cb83777554ea99e2893dabd3deb Mon Sep 17 00:00:00 2001 From: Sharada Mohanty Date: Fri, 24 Apr 2026 16:18:45 +0200 Subject: [PATCH 04/10] refactor(flops): address code review feedback on reduction cost - Move propagate_symmetry_reduce to top-level import (no circular dep; _symmetric.py does not import _flops.py). - Raise ValueError for scalar input with explicit axis instead of crashing with IndexError deeper in the call stack. - Tighten type hints to list[PermutationGroup] | None via TYPE_CHECKING. - Document the reduced_axes=None fallback in _compute_R_unique_count. - Add regression test for the scalar+axis ValueError. 
--- src/whest/_flops.py | 28 +++++++++++++++++++++------- tests/test_flops.py | 6 ++++++ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/whest/_flops.py b/src/whest/_flops.py index b21179e51d..3a5b41b3da 100644 --- a/src/whest/_flops.py +++ b/src/whest/_flops.py @@ -6,7 +6,10 @@ from collections import Counter from typing import TYPE_CHECKING +from whest._symmetric import propagate_symmetry_reduce + if TYPE_CHECKING: + from whest._perm_group import PermutationGroup from whest._symmetric import SymmetryInfo @@ -124,17 +127,28 @@ def analytical_pointwise_cost( def _normalize_axis( axis: int | tuple[int, ...] | None, ndim: int ) -> tuple[int, ...] | None: - """Return normalized reduction axes as a sorted tuple, or None for full reduction.""" + """Return normalized reduction axes as a sorted tuple, or None for full reduction. + + Raises + ------ + ValueError + If ``axis`` is not ``None`` but ``ndim == 0`` (scalar input has no axes). + """ if axis is None: return None + if ndim == 0: + raise ValueError( + f"axis={axis!r} is out of bounds for scalar input (ndim=0); " + "use axis=None for full reduction of a scalar" + ) if isinstance(axis, int): axis = (axis,) - normalized = tuple(sorted((a % ndim) if ndim > 0 else a for a in axis)) + normalized = tuple(sorted(a % ndim for a in axis)) return normalized def _compute_output_unique_count( - groups: list | None, + groups: list[PermutationGroup] | None, input_shape: tuple[int, ...], reduced_axes: tuple[int, ...] | None, ) -> int: @@ -144,9 +158,6 @@ def _compute_output_unique_count( groups, then multiplies Burnside counts (over symmetric axes) by free-axis sizes (kept non-symmetric axes). """ - # Runtime import to avoid circular dependency at module load. - from whest._symmetric import propagate_symmetry_reduce - ndim = len(input_shape) if reduced_axes is None: # Full reduction → scalar output. 
@@ -184,7 +195,7 @@ def _compute_output_unique_count( def _compute_R_unique_count( - groups: list | None, + groups: list[PermutationGroup] | None, input_shape: tuple[int, ...], reduced_axes: tuple[int, ...] | None, ) -> int: @@ -194,6 +205,9 @@ def _compute_R_unique_count( savings — they act entirely within the reduced axes and combine equivalent values. Split groups (spanning both R and K) and output-only groups (``g.axes ⊆ K``) contribute no savings here. + + When ``reduced_axes is None`` (full reduction), every axis is treated as + reduced, so ``u_R`` equals the total number of unique input elements. """ ndim = len(input_shape) if reduced_axes is None: diff --git a/tests/test_flops.py b/tests/test_flops.py index 01a405ae36..f355af384e 100644 --- a/tests/test_flops.py +++ b/tests/test_flops.py @@ -218,6 +218,12 @@ def test_analytical_reduction_cost_scalar(): assert analytical_reduction_cost(input_shape=(), axis=None) == 1 +def test_analytical_reduction_cost_scalar_with_axis_raises(): + # Scalar input with an explicit axis should raise ValueError (no axes exist). + with pytest.raises(ValueError, match="scalar"): + analytical_reduction_cost(input_shape=(), axis=0) + + def test_analytical_reduction_cost_size_one_axis(): # axis of size 1 has 0 accumulations → clamped to 1. assert analytical_reduction_cost(input_shape=(1, 10), axis=0) == 1 From d7f21f95fc7fddbd9dd6409275b8eaa0b7cc9d08 Mon Sep 17 00:00:00 2001 From: Sharada Mohanty Date: Fri, 24 Apr 2026 16:20:07 +0200 Subject: [PATCH 05/10] feat(flops): widen reduction_cost axis hint to accept tuple[int, ...] --- src/whest/flops.py | 11 ++++++----- tests/test_flops.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/whest/flops.py b/src/whest/flops.py index 64046a38d9..0896935963 100644 --- a/src/whest/flops.py +++ b/src/whest/flops.py @@ -217,7 +217,7 @@ def reduction_cost( op_name: str, *, input_shape: tuple[int, ...], - axis: int | None = None, + axis: int | tuple[int, ...] 
| None = None, symmetry_info: SymmetryInfo | None = None, ) -> int: """Weighted FLOP cost of a reduction operation. @@ -228,11 +228,12 @@ def reduction_cost( Operation name used for weight lookup, e.g. ``"sum"`` or ``"max"``. input_shape : tuple of int Shape of the reduction input. - axis : int or None, optional - Reduction axis. Accepted for API consistency with the analytical helper. + axis : int, tuple of int, or None, optional + Reduction axis or axes. When a tuple is given, the cost is computed + with those axes treated as the reduced set. symmetry_info : SymmetryInfo or None, optional - If provided, only unique elements are counted analytically before the - operation weight is applied. + If provided, symmetry is used to count unique outputs and inputs + (see ``_flops.analytical_reduction_cost``). Returns ------- diff --git a/tests/test_flops.py b/tests/test_flops.py index f355af384e..b4dad1ba00 100644 --- a/tests/test_flops.py +++ b/tests/test_flops.py @@ -274,6 +274,16 @@ def test_public_reduction_cost_is_weighted(tmp_path): assert public_flops.reduction_cost("sum", input_shape=(4, 5), axis=None) == 61 +def test_public_reduction_cost_tuple_axis(tmp_path): + # (4, 5, 6) with axis=(0, 2): kept axis 1 has 5 outputs, each reduces 24. + # Analytical: 5 * (24 − 1) = 115. Weighted with weight=1.0: 115. 
+ load_weights(_write_weights(tmp_path, {"sum": 1.0}), use_packaged_default=False) + assert ( + public_flops.reduction_cost("sum", input_shape=(4, 5, 6), axis=(0, 2)) + == 115 + ) + + def test_public_einsum_cost_is_weighted(tmp_path): load_weights(_write_weights(tmp_path, {"einsum": 2.0}), use_packaged_default=False) assert public_flops.einsum_cost("ij,jk->ik", shapes=[(3, 4), (4, 5)]) == 120 From 6df071ccc8b429881ffaba4512beb1f95c5f9766 Mon Sep 17 00:00:00 2001 From: Sharada Mohanty Date: Fri, 24 Apr 2026 16:27:00 +0200 Subject: [PATCH 06/10] test: update reduction expectations for off-by-one fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All downstream tests that embed the old reduction cost formula (cost = numel) now expect the corrected cost (numel − 1 for full reduction, unique_out × (u_R − 1) with clamping otherwise). --- tests/test_cost_formula_vs_code.py | 10 +++++----- tests/test_integration.py | 6 +++--- tests/test_methodology_consistency.py | 9 +++++---- tests/test_pointwise.py | 8 ++++---- tests/test_pointwise_coverage.py | 2 +- tests/test_symmetric_pointwise.py | 3 ++- 6 files changed, 20 insertions(+), 18 deletions(-) diff --git a/tests/test_cost_formula_vs_code.py b/tests/test_cost_formula_vs_code.py index ab65cadf37..98d5ad58c5 100644 --- a/tests/test_cost_formula_vs_code.py +++ b/tests/test_cost_formula_vs_code.py @@ -239,7 +239,7 @@ def test_vecdot_batch_times_k(we): # --------------------------------------------------------------------------- -# Counted Reduction — numel(input) +# Counted Reduction — numel(input) − 1 (first value is a free copy) # --------------------------------------------------------------------------- _REDUCTION_NUMEL = [ @@ -281,28 +281,28 @@ def test_reduction_numel(name, we): a = numpy.random.rand(10, 10) fn = getattr(we, name) cost = _cost_of(fn, a) - assert cost == 100, f"{name}: expected numel(input)=100, got {cost}" + assert cost == 99, f"{name}: expected 
numel(input)-1=99, got {cost}" @pytest.mark.parametrize("name", ["percentile", "nanpercentile"]) def test_percentile_numel(name, we): a = numpy.random.rand(10, 10) cost = _cost_of(getattr(we, name), a, q=50) - assert cost == 100, f"{name}: expected numel(input)=100, got {cost}" + assert cost == 99, f"{name}: expected numel(input)-1=99, got {cost}" @pytest.mark.parametrize("name", ["quantile", "nanquantile"]) def test_quantile_numel(name, we): a = numpy.random.rand(10, 10) cost = _cost_of(getattr(we, name), a, q=0.5) - assert cost == 100, f"{name}: expected numel(input)=100, got {cost}" + assert cost == 99, f"{name}: expected numel(input)-1=99, got {cost}" @pytest.mark.parametrize("name", ["cumulative_sum", "cumulative_prod"]) def test_cumulative_numel(name, we): a = numpy.random.rand(10, 10) cost = _cost_of(getattr(we, name), a, axis=0) - assert cost == 100, f"{name}: expected numel(input)=100, got {cost}" + assert cost == 90, f"{name}: expected 10*(10-1)=90, got {cost}" # --------------------------------------------------------------------------- diff --git a/tests/test_integration.py b/tests/test_integration.py index b475fb1bfd..8668271b66 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -41,9 +41,9 @@ def test_budget_tracking_accuracy(): with we.BudgetContext(flop_budget=10**8) as budget: we.einsum("ij,jk->ik", A, B) # 10 * 20 * 30 = 6000 (FMA=1) we.exp(we.ones((100,))) # 100 - we.sum(we.ones((50,))) # 50 - assert budget.flops_used == 6000 + 100 + 50 - assert budget.flops_remaining == 10**8 - 6150 + we.sum(we.ones((50,))) # 50 − 1 = 49 + assert budget.flops_used == 6000 + 100 + 49 + assert budget.flops_remaining == 10**8 - 6149 def test_flop_query_matches_execution(): diff --git a/tests/test_methodology_consistency.py b/tests/test_methodology_consistency.py index e717fe59b3..6c9ee124c7 100644 --- a/tests/test_methodology_consistency.py +++ b/tests/test_methodology_consistency.py @@ -124,20 +124,21 @@ def test_exp(self): class 
TestReductionConsistency: - """Reductions: cost = numel(input).""" + """Reductions: cost = numel(input) − 1 (first value is a free copy).""" def test_sum(self): n = 1000 a = np.random.rand(n) runtime_cost = _run_and_get_cost(we.sum, a) - assert runtime_cost == n + assert runtime_cost == n - 1 def test_mean(self): n = 1000 a = np.random.rand(n) runtime_cost = _run_and_get_cost(we.mean, a) - # mean charges n+1 (sum + divide) or just n depending on impl - assert runtime_cost >= n + # mean charges (n−1) + possibly a divide; impl-dependent but must be + # at least n−1. + assert runtime_cost >= n - 1 # --------------------------------------------------------------------------- diff --git a/tests/test_pointwise.py b/tests/test_pointwise.py index 9bb45e8bc9..484e1edac5 100644 --- a/tests/test_pointwise.py +++ b/tests/test_pointwise.py @@ -76,7 +76,7 @@ def test_sum_full(): x = numpy.ones((5, 3)) with BudgetContext(flop_budget=10**6) as budget: result = sum(x) - assert budget.flops_used == 15 + assert budget.flops_used == 14 assert float(result) == 15.0 @@ -85,21 +85,21 @@ def test_sum_axis(): with BudgetContext(flop_budget=10**6) as budget: result = sum(x, axis=0) assert result.shape == (3,) - assert budget.flops_used == 15 + assert budget.flops_used == 12 def test_mean_cost(): x = numpy.ones((10, 20)) with BudgetContext(flop_budget=10**6) as budget: mean(x, axis=0) - assert budget.flops_used == 200 # numel(input) = 200 + assert budget.flops_used == 180 # 20 outputs × (10−1) def test_std_cost(): x = numpy.ones((10, 20)) with BudgetContext(flop_budget=10**6) as budget: std(x, axis=0) - assert budget.flops_used == 200 # numel(input) = 200 + assert budget.flops_used == 180 # 20 outputs × (10−1) def test_argmax_result(): diff --git a/tests/test_pointwise_coverage.py b/tests/test_pointwise_coverage.py index ebfa91f365..45814cc2f0 100644 --- a/tests/test_pointwise_coverage.py +++ b/tests/test_pointwise_coverage.py @@ -784,7 +784,7 @@ def test_sum_on_list(self): with 
BudgetContext(flop_budget=10**6) as budget: result = sum([1.0, 2.0, 3.0]) assert numpy.isclose(result, 6.0) - assert budget.flops_used == 3 + assert budget.flops_used == 2 def test_argmax_on_list(self): with BudgetContext(flop_budget=10**6): diff --git a/tests/test_symmetric_pointwise.py b/tests/test_symmetric_pointwise.py index 53fecc5006..f1d7ac86c7 100644 --- a/tests/test_symmetric_pointwise.py +++ b/tests/test_symmetric_pointwise.py @@ -83,7 +83,8 @@ def test_sum_symmetric_cost(self): S = as_symmetric(data, symmetric_axes=(0, 1)) with BudgetContext(flop_budget=10**6, quiet=True) as budget: we.sum(S) - assert budget.flops_used == 55 + # sym(10,10) has 55 unique elements; full reduction: 55 − 1 = 54. + assert budget.flops_used == 54 def test_sum_returns_plain(self): import whest as we From 19915c1bd5e26cb90b1d8235388c762b9cb990f4 Mon Sep 17 00:00:00 2001 From: Sharada Mohanty Date: Fri, 24 Apr 2026 16:31:59 +0200 Subject: [PATCH 07/10] test: add inline arithmetic comments for updated reduction expectations Minor polish from code review: explain where literals like 14, 12, 2 come from so future readers don't have to reverse-engineer them. Also updates a stale "numel(input)" sub-bucket comment to match the new formula. 
--- tests/test_cost_formula_vs_code.py | 2 +- tests/test_pointwise.py | 4 ++-- tests/test_pointwise_coverage.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_cost_formula_vs_code.py b/tests/test_cost_formula_vs_code.py index 98d5ad58c5..18b80d5658 100644 --- a/tests/test_cost_formula_vs_code.py +++ b/tests/test_cost_formula_vs_code.py @@ -265,7 +265,7 @@ def test_vecdot_batch_times_k(we): "nansum", "median", "nanmedian", - # mean/std/var also cost numel(input) per sheet + # mean/std/var also cost numel(input) − 1 per sheet "mean", "average", "nanmean", diff --git a/tests/test_pointwise.py b/tests/test_pointwise.py index 484e1edac5..b416469592 100644 --- a/tests/test_pointwise.py +++ b/tests/test_pointwise.py @@ -76,7 +76,7 @@ def test_sum_full(): x = numpy.ones((5, 3)) with BudgetContext(flop_budget=10**6) as budget: result = sum(x) - assert budget.flops_used == 14 + assert budget.flops_used == 14 # 5*3 − 1 assert float(result) == 15.0 @@ -85,7 +85,7 @@ def test_sum_axis(): with BudgetContext(flop_budget=10**6) as budget: result = sum(x, axis=0) assert result.shape == (3,) - assert budget.flops_used == 12 + assert budget.flops_used == 12 # 3 outputs × (5−1) def test_mean_cost(): diff --git a/tests/test_pointwise_coverage.py b/tests/test_pointwise_coverage.py index 45814cc2f0..c8ad691188 100644 --- a/tests/test_pointwise_coverage.py +++ b/tests/test_pointwise_coverage.py @@ -784,7 +784,7 @@ def test_sum_on_list(self): with BudgetContext(flop_budget=10**6) as budget: result = sum([1.0, 2.0, 3.0]) assert numpy.isclose(result, 6.0) - assert budget.flops_used == 2 + assert budget.flops_used == 2 # 3 inputs − 1 def test_argmax_on_list(self): with BudgetContext(flop_budget=10**6): From 7778f104974b03f7ceaccb97792e1bd6ce1db421 Mon Sep 17 00:00:00 2001 From: Sharada Mohanty Date: Fri, 24 Apr 2026 16:34:13 +0200 Subject: [PATCH 08/10] docs: correct stale std/var notes (weight, not cost_multiplier) MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit The notes previously claimed cost_multiplier=2, but the code does not pass cost_multiplier to the reduction cost path. The 2× factor actually comes from weight=2.0 in weights.csv. Update notes to reflect reality. --- src/whest/_registry.py | 4 ++-- src/whest/data/weights.csv | 4 ++-- tests/test_pointwise_coverage.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/whest/_registry.py b/src/whest/_registry.py index 7543de2fcc..f8b4c1e398 100644 --- a/src/whest/_registry.py +++ b/src/whest/_registry.py @@ -674,12 +674,12 @@ "std": { "category": "counted_reduction", "module": "numpy", - "notes": "Standard deviation; cost_multiplier=2 (two passes).", + "notes": "Standard deviation; weight=2.0 accounts for two-pass algorithm.", }, "var": { "category": "counted_reduction", "module": "numpy", - "notes": "Variance; cost_multiplier=2 (two passes).", + "notes": "Variance; weight=2.0 accounts for two-pass algorithm.", }, "argmax": { "category": "counted_reduction", diff --git a/src/whest/data/weights.csv b/src/whest/data/weights.csv index 253aef8fa8..8818c2cbe9 100644 --- a/src/whest/data/weights.csv +++ b/src/whest/data/weights.csv @@ -80,8 +80,8 @@ copysign,benchmarked,Pointwise Binary,numel(output),1.0000,1.1021,,"1,000 FLOPs nextafter,benchmarked,Pointwise Binary,numel(output),1.0000,5.7999,,"1,000 FLOPs (1000 elements)",low,Timing-based weight (fp_arith_inst_retired blind to this op). Ratio vs add=5.7999,,0.60,5.7999,0.1724,0.3715,"np.nextafter(a, b, out=_out)","a: (10000000,), b: (10000000,)","[28214641, 60009654, 60009598]",1135600279,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L460,baseline,10 ldexp,benchmarked,Pointwise Binary,numel(output),1.0000,3.3667,,"1,000 FLOPs (1000 elements)",high,Timing-based (fp_arith_inst_retired blind to comparisons). 
Ratio vs add=3.3667,,0.60,0.0000,N/A,0.0014,"np.ldexp(a, b, out=_out)","a: (10000000,), b: (10000000,)","[28215481, 60010494, 60010438]",,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L451,baseline,10 heaviside,benchmarked,Pointwise Binary,numel(output),1.0000,1.3916,,"1,000 FLOPs (1000 elements)",high,Timing-based (fp_arith_inst_retired blind to comparisons). Ratio vs add=1.3916,,0.30,1.3789,0.7252,0.0015,"np.heaviside(x, 0.5)","x: (10000000,), h=0.5","[14116734, 30009661, 30009620]",269979729,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L448,baseline,10 -std,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Standard deviation; cost_multiplier=2 (two passes).,,4.30,1.6020,1.2484,0.0216,np.std(x),"x: (10000000,)","[414116751, 430009678, 430009637]",313671080,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L533,moderate,10 -var,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Variance; cost_multiplier=2 (two passes).,,4.30,1.6012,1.2491,0.0216,np.var(x),"x: (10000000,)","[414116731, 430009658, 430009617]",313505568,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L534,moderate,10 +std,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Standard deviation; weight=2.0 accounts for two-pass algorithm.,,4.30,1.6020,1.2484,0.0216,np.std(x),"x: (10000000,)","[414116751, 430009678, 430009637]",313671080,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L533,moderate,10 +var,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Variance; weight=2.0 accounts for two-pass algorithm.,,4.30,1.6012,1.2491,0.0216,np.var(x),"x: (10000000,)","[414116731, 430009658, 430009617]",313505568,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L534,moderate,10 nanstd,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 
elements)",high,Standard deviation ignoring NaNs.,,4.30,3.1232,0.6404,0.0216,np.nanstd(x),"x: (10000000,)","[414116751, 430009678, 430009637]",611516405,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L564,moderate,10 nanvar,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Variance ignoring NaNs.,,4.30,3.0977,0.6456,0.0216,np.nanvar(x),"x: (10000000,)","[414116731, 430009658, 430009617]",606531077,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L566,moderate,10 sum,benchmarked,Reductions,numel(input),1.0000,1.0000,,"1,000 FLOPs (1000 elements)",medium,Sum of array elements.,,1.30,0.2440,4.0984,0.0736,np.sum(x),"x: (10000000,)","[114116711, 130009638, 130009597]",47773697,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L528,baseline,10 diff --git a/tests/test_pointwise_coverage.py b/tests/test_pointwise_coverage.py index c8ad691188..fa39233042 100644 --- a/tests/test_pointwise_coverage.py +++ b/tests/test_pointwise_coverage.py @@ -282,7 +282,7 @@ def test_reduction_keepdims_reduces_group(self): # axis 0 is in group (0,1); remaining single axis => group lost def test_std_on_symmetric_tensor(self): - """std() (cost_multiplier=2, extra_output=True) on a symmetric tensor.""" + """std() (weight=2.0, extra_output=True) on a symmetric tensor.""" st = _make_symmetric_3d(4) with BudgetContext(flop_budget=10**6): result = std(st, axis=2) From e83887d4517c4db92ceef0b23f445e9ac988a315 Mon Sep 17 00:00:00 2001 From: Sharada Mohanty Date: Fri, 24 Apr 2026 16:38:57 +0200 Subject: [PATCH 09/10] test(symmetric): add runtime tests for axis-reduction cost paths Three new tests in TestReductionSymmetry exercise the runtime path (we.sum on SymmetricTensor with axis/tuple-axis) for the three sym cases: split pair (no inner savings), preserving (sym survives on output), and inner-clean (sym entirely within reduced axes). 
--- tests/test_symmetric_pointwise.py | 33 +++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_symmetric_pointwise.py b/tests/test_symmetric_pointwise.py index f1d7ac86c7..489f4a6ccf 100644 --- a/tests/test_symmetric_pointwise.py +++ b/tests/test_symmetric_pointwise.py @@ -94,3 +94,36 @@ def test_sum_returns_plain(self): with BudgetContext(flop_budget=10**6, quiet=True): result = we.sum(S) assert not isinstance(result, SymmetricTensor) + + def test_sum_axis_sym_split_pair_runtime(self): + import whest as we + + # sym(5,5) reducing axis=0 → split group → 5 outputs × (5−1) = 20. + data = numpy.ones((5, 5)) + S = as_symmetric(data, symmetric_axes=(0, 1)) + with BudgetContext(flop_budget=10**6, quiet=True) as budget: + we.sum(S, axis=0) + assert budget.flops_used == 20 + + def test_sum_axis_sym_preserving_runtime(self): + import whest as we + + # sym(5,5,10) with sym axes (0,1) reducing axis=2: + # S_2 survives on output (kept axes {0,1}) → 15 unique outputs × (10−1) = 135. + data = numpy.ones((5, 5, 10)) + S = as_symmetric(data, symmetric_axes=(0, 1)) + with BudgetContext(flop_budget=10**6, quiet=True) as budget: + we.sum(S, axis=2) + assert budget.flops_used == 135 + + def test_sum_tuple_axis_inner_clean_runtime(self): + import whest as we + + # sym(5,5,10) reducing both (0,1) → inner-clean group (g.axes ⊆ R). + # u_R = 15 (Burnside on S_2 over 5×5); 10 outputs (axis 2 kept). + # Cost = 10 × (15−1) = 140. + data = numpy.ones((5, 5, 10)) + S = as_symmetric(data, symmetric_axes=(0, 1)) + with BudgetContext(flop_budget=10**6, quiet=True) as budget: + we.sum(S, axis=(0, 1)) + assert budget.flops_used == 140 From f21f607f8d45b2db4a9f2221f399a0e9d3454fc8 Mon Sep 17 00:00:00 2001 From: Sharada Mohanty Date: Fri, 24 Apr 2026 16:43:42 +0200 Subject: [PATCH 10/10] docs(_pointwise): update counted-reduction cost docstring label The attached cost-description label on reduction wrappers (sum, mean, std, etc.) 
now reflects the corrected off-by-one formula: "numel(input) - 1" instead of "numel(input)". User-facing docstrings now match the actual runtime behavior for full reductions (axis=None); note that for a dense reduction over a subset of axes the runtime cost is num_outputs * (reduced_size - 1), i.e. numel(input) - numel(output), which the label does not spell out. --- src/whest/_pointwise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/whest/_pointwise.py b/src/whest/_pointwise.py index 95214cf8ef..80ebf88040 100644 --- a/src/whest/_pointwise.py +++ b/src/whest/_pointwise.py @@ -273,9 +273,9 @@ def wrapper(a, axis=None, **kwargs): wrapper.__name__ = op_name wrapper.__qualname__ = op_name cost_desc = ( - f"numel(input) * {cost_multiplier} FLOPs" + f"(numel(input) - 1) * {cost_multiplier} FLOPs" if cost_multiplier > 1 - else "numel(input) FLOPs" + else "numel(input) - 1 FLOPs" ) if extra_output: cost_desc += " + numel(output)"