Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 144 additions & 17 deletions src/whest/_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from collections import Counter
from typing import TYPE_CHECKING

from whest._symmetric import propagate_symmetry_reduce

if TYPE_CHECKING:
from whest._perm_group import PermutationGroup
from whest._symmetric import SymmetryInfo


Expand Down Expand Up @@ -121,39 +124,163 @@ def analytical_pointwise_cost(
return max(result, 1)


def _normalize_axis(
    axis: int | tuple[int, ...] | None, ndim: int
) -> tuple[int, ...] | None:
    """Return normalized reduction axes as a sorted tuple, or None for full reduction.

    Negative axes are interpreted relative to ``ndim`` (``-1`` is the last
    axis). Every axis must lie in ``[-ndim, ndim)`` and may be named only once.

    Parameters
    ----------
    axis : int, tuple of int, or None
        Axis or axes to reduce over. ``None`` means "reduce over everything".
    ndim : int
        Number of dimensions of the input.

    Returns
    -------
    tuple of int or None
        Non-negative axes in ascending order, or ``None`` when ``axis`` is
        ``None``.

    Raises
    ------
    ValueError
        If ``axis`` is not ``None`` but ``ndim == 0`` (scalar input has no
        axes), if an axis is outside ``[-ndim, ndim)``, or if the same axis
        is given more than once.
    """
    if axis is None:
        return None
    if ndim == 0:
        raise ValueError(
            f"axis={axis!r} is out of bounds for scalar input (ndim=0); "
            "use axis=None for full reduction of a scalar"
        )
    axes = (axis,) if isinstance(axis, int) else tuple(axis)

    normalized = []
    for a in axes:
        # A bare ``a % ndim`` would silently wrap an out-of-range axis onto a
        # valid one (e.g. axis=5 with ndim=3 -> 2), so check the range first.
        if not -ndim <= a < ndim:
            raise ValueError(
                f"axis {a} is out of bounds for input of dimension {ndim}"
            )
        normalized.append(a % ndim)

    # A repeated axis (e.g. (0, -ndim)) would be counted twice by callers
    # that multiply per-axis sizes.
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"duplicate value in axis={axes!r}")
    return tuple(sorted(normalized))


def _compute_output_unique_count(
    groups: list[PermutationGroup] | None,
    input_shape: tuple[int, ...],
    reduced_axes: tuple[int, ...] | None,
) -> int:
    """Number of unique outputs after reducing ``reduced_axes`` of ``input_shape``.

    Uses ``propagate_symmetry_reduce`` to obtain the surviving output symmetry
    groups, then multiplies the Burnside counts of those groups (over the
    symmetric output axes) by the sizes of the kept axes no group covers.

    Parameters
    ----------
    groups : list of PermutationGroup or None
        Symmetry groups of the input; ``None`` or empty means no symmetry.
    input_shape : tuple of int
        Shape of the input array.
    reduced_axes : tuple of int or None
        Axes being reduced, in input-axis numbering. ``None`` means full
        reduction.

    Returns
    -------
    int
        Count of unique output elements; ``1`` when nothing is kept
        (full reduction produces a scalar).
    """
    if reduced_axes is None:
        # Full reduction -> scalar output.
        return 1

    ndim = len(input_shape)
    # Build the membership set once rather than once per candidate axis.
    reduced_set = set(reduced_axes)
    kept_axes = [d for d in range(ndim) if d not in reduced_set]
    if not kept_axes:
        return 1

    # Propagate to get the surviving output symmetry groups.
    # (Operates on input-axis numbering internally; returns output-axis numbering.)
    out_groups = None
    if groups:
        out_groups = propagate_symmetry_reduce(
            groups, ndim, reduced_axes, keepdims=False
        )

    # Output shape in output-axis numbering.
    output_shape = tuple(input_shape[d] for d in kept_axes)

    # Multiply the Burnside unique counts of the output groups, then the
    # sizes of every output axis that no group accounted for.
    accounted: set[int] = set()
    total = 1
    for group in out_groups or ():
        axes = group.axes
        if axes is None:
            continue
        size_dict = {i: output_shape[axes[i]] for i in range(group.degree)}
        total *= group.burnside_unique_count(size_dict)
        accounted.update(axes)
    for i, size in enumerate(output_shape):
        if i not in accounted:
            total *= size
    return total


def _compute_R_unique_count(
    groups: list[PermutationGroup] | None,
    input_shape: tuple[int, ...],
    reduced_axes: tuple[int, ...] | None,
) -> int:
    """Number of unique inputs feeding one output slice.

    Only **inner-clean** sym groups (``g.axes ⊆ R``) contribute input-side
    savings — they act entirely within the reduced axes and combine equivalent
    values. Split groups (spanning both R and K) and output-only groups
    (``g.axes ⊆ K``) contribute no savings here.

    When ``reduced_axes is None`` (full reduction), every axis is treated as
    reduced, so ``u_R`` equals the total number of unique input elements.

    Parameters
    ----------
    groups : list of PermutationGroup or None
        Symmetry groups of the input; ``None`` or empty means no symmetry.
    input_shape : tuple of int
        Shape of the input array.
    reduced_axes : tuple of int or None
        Axes being reduced, in input-axis numbering. A repeated axis is
        counted once.

    Returns
    -------
    int
        Product of the Burnside counts of the inner-clean groups and the
        sizes of the remaining reduced axes.
    """
    if reduced_axes is None:
        reduced_axes = tuple(range(len(input_shape)))
    reduced_set = set(reduced_axes)

    accounted: set[int] = set()
    total = 1
    for group in groups or ():
        axes = group.axes
        if axes is None:
            continue
        if not reduced_set.issuperset(axes):
            # Not inner-clean: either output-only (g.axes ⊆ K) — doesn't
            # touch reduced axes — or split. Neither contributes to u_R.
            continue
        size_dict = {i: input_shape[axes[i]] for i in range(group.degree)}
        total *= group.burnside_unique_count(size_dict)
        accounted.update(axes)

    # Iterate the de-duplicated set so an axis listed twice in
    # ``reduced_axes`` does not have its size multiplied in twice.
    for d in reduced_set - accounted:
        total *= input_shape[d]
    return total


def analytical_reduction_cost(
    input_shape: tuple[int, ...],
    axis: int | tuple[int, ...] | None = None,
    symmetry_info: SymmetryInfo | None = None,
) -> int:
    """FLOP cost of a reduction operation.

    The cost of a reduction is the number of accumulations performed:
    for each output, the first input value is a free copy, and the remaining
    ``u_R − 1`` values are accumulated in. Total cost:

    .. math::

        \\text{cost} = \\text{unique\\_out} \\times (u_R - 1)

    where ``unique_out`` is the number of unique outputs (accounting for
    output-side symmetry) and ``u_R`` is the number of unique inputs feeding
    one output slice (accounting for inner-clean input symmetry).

    Parameters
    ----------
    input_shape : tuple of int
        Shape of the input array.
    axis : int, tuple of int, or None, optional
        Axis or axes along which to reduce. If None, reduce over all elements.
    symmetry_info : SymmetryInfo or None, optional
        If provided, symmetry is used to count unique outputs and unique
        per-output inputs. Only inner-clean groups (g.axes ⊆ reduced axes)
        contribute per-output input savings; split groups do not.

    Returns
    -------
    int
        Estimated FLOP count. Clamped to a minimum of 1 for degenerate shapes
        (scalar, size-1 axis, empty shape) so every reduction registers at
        least 1 flop for budget tracking purposes.
    """
    ndim = len(input_shape)
    reduced_axes = _normalize_axis(axis, ndim)
    groups = symmetry_info.groups if symmetry_info is not None else None

    unique_out = _compute_output_unique_count(groups, input_shape, reduced_axes)
    u_R = _compute_R_unique_count(groups, input_shape, reduced_axes)

    # ``u_R - 1`` because the first value of each output slice is a copy,
    # not an accumulation; the clamp keeps degenerate cases at 1.
    cost = unique_out * (u_R - 1)
    return max(cost, 1)


# Backward-compatible internal aliases. The public weighted API lives in
Expand Down
4 changes: 2 additions & 2 deletions src/whest/_pointwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ def wrapper(a, axis=None, **kwargs):
wrapper.__name__ = op_name
wrapper.__qualname__ = op_name
cost_desc = (
f"numel(input) * {cost_multiplier} FLOPs"
f"(numel(input) - 1) * {cost_multiplier} FLOPs"
if cost_multiplier > 1
else "numel(input) FLOPs"
else "numel(input) - 1 FLOPs"
)
if extra_output:
cost_desc += " + numel(output)"
Expand Down
4 changes: 2 additions & 2 deletions src/whest/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,12 +674,12 @@
"std": {
"category": "counted_reduction",
"module": "numpy",
"notes": "Standard deviation; cost_multiplier=2 (two passes).",
"notes": "Standard deviation; weight=2.0 accounts for two-pass algorithm.",
},
"var": {
"category": "counted_reduction",
"module": "numpy",
"notes": "Variance; cost_multiplier=2 (two passes).",
"notes": "Variance; weight=2.0 accounts for two-pass algorithm.",
},
"argmax": {
"category": "counted_reduction",
Expand Down
4 changes: 2 additions & 2 deletions src/whest/data/weights.csv
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ copysign,benchmarked,Pointwise Binary,numel(output),1.0000,1.1021,,"1,000 FLOPs
nextafter,benchmarked,Pointwise Binary,numel(output),1.0000,5.7999,,"1,000 FLOPs (1000 elements)",low,Timing-based weight (fp_arith_inst_retired blind to this op). Ratio vs add=5.7999,,0.60,5.7999,0.1724,0.3715,"np.nextafter(a, b, out=_out)","a: (10000000,), b: (10000000,)","[28214641, 60009654, 60009598]",1135600279,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L460,baseline,10
ldexp,benchmarked,Pointwise Binary,numel(output),1.0000,3.3667,,"1,000 FLOPs (1000 elements)",high,Timing-based (fp_arith_inst_retired blind to comparisons). Ratio vs add=3.3667,,0.60,0.0000,N/A,0.0014,"np.ldexp(a, b, out=_out)","a: (10000000,), b: (10000000,)","[28215481, 60010494, 60010438]",,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L451,baseline,10
heaviside,benchmarked,Pointwise Binary,numel(output),1.0000,1.3916,,"1,000 FLOPs (1000 elements)",high,Timing-based (fp_arith_inst_retired blind to comparisons). Ratio vs add=1.3916,,0.30,1.3789,0.7252,0.0015,"np.heaviside(x, 0.5)","x: (10000000,), h=0.5","[14116734, 30009661, 30009620]",269979729,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L448,baseline,10
std,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Standard deviation; cost_multiplier=2 (two passes).,,4.30,1.6020,1.2484,0.0216,np.std(x),"x: (10000000,)","[414116751, 430009678, 430009637]",313671080,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L533,moderate,10
var,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Variance; cost_multiplier=2 (two passes).,,4.30,1.6012,1.2491,0.0216,np.var(x),"x: (10000000,)","[414116731, 430009658, 430009617]",313505568,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L534,moderate,10
std,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Standard deviation; weight=2.0 accounts for two-pass algorithm.,,4.30,1.6020,1.2484,0.0216,np.std(x),"x: (10000000,)","[414116751, 430009678, 430009637]",313671080,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L533,moderate,10
var,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Variance; weight=2.0 accounts for two-pass algorithm.,,4.30,1.6012,1.2491,0.0216,np.var(x),"x: (10000000,)","[414116731, 430009658, 430009617]",313505568,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L534,moderate,10
nanstd,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Standard deviation ignoring NaNs.,,4.30,3.1232,0.6404,0.0216,np.nanstd(x),"x: (10000000,)","[414116751, 430009678, 430009637]",611516405,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L564,moderate,10
nanvar,benchmarked,Reductions,numel(input),2.0000,4.0000,,"2,000 FLOPs (1000 elements)",high,Variance ignoring NaNs.,,4.30,3.0977,0.6456,0.0216,np.nanvar(x),"x: (10000000,)","[414116731, 430009658, 430009617]",606531077,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L566,moderate,10
sum,benchmarked,Reductions,numel(input),1.0000,1.0000,,"1,000 FLOPs (1000 elements)",medium,Sum of array elements.,,1.30,0.2440,4.0984,0.0736,np.sum(x),"x: (10000000,)","[114116711, 130009638, 130009597]",47773697,https://github.com/AIcrowd/whest/blob/main/src/whest/_pointwise.py#L528,baseline,10
Expand Down
11 changes: 6 additions & 5 deletions src/whest/flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def reduction_cost(
op_name: str,
*,
input_shape: tuple[int, ...],
axis: int | None = None,
axis: int | tuple[int, ...] | None = None,
symmetry_info: SymmetryInfo | None = None,
) -> int:
"""Weighted FLOP cost of a reduction operation.
Expand All @@ -228,11 +228,12 @@ def reduction_cost(
Operation name used for weight lookup, e.g. ``"sum"`` or ``"max"``.
input_shape : tuple of int
Shape of the reduction input.
axis : int or None, optional
Reduction axis. Accepted for API consistency with the analytical helper.
axis : int, tuple of int, or None, optional
Reduction axis or axes. When a tuple is given, the cost is computed
with those axes treated as the reduced set.
symmetry_info : SymmetryInfo or None, optional
If provided, only unique elements are counted analytically before the
operation weight is applied.
If provided, symmetry is used to count unique outputs and inputs
(see ``_flops.analytical_reduction_cost``).

Returns
-------
Expand Down
12 changes: 6 additions & 6 deletions tests/test_cost_formula_vs_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def test_vecdot_batch_times_k(we):


# ---------------------------------------------------------------------------
# Counted Reduction — numel(input)
# Counted Reduction — numel(input) − 1 (first value is a free copy)
# ---------------------------------------------------------------------------

_REDUCTION_NUMEL = [
Expand All @@ -265,7 +265,7 @@ def test_vecdot_batch_times_k(we):
"nansum",
"median",
"nanmedian",
# mean/std/var also cost numel(input) per sheet
# mean/std/var also cost numel(input) − 1 per sheet
"mean",
"average",
"nanmean",
Expand All @@ -281,28 +281,28 @@ def test_reduction_numel(name, we):
a = numpy.random.rand(10, 10)
fn = getattr(we, name)
cost = _cost_of(fn, a)
assert cost == 100, f"{name}: expected numel(input)=100, got {cost}"
assert cost == 99, f"{name}: expected numel(input)-1=99, got {cost}"


@pytest.mark.parametrize("name", ["percentile", "nanpercentile"])
def test_percentile_numel(name, we):
    """Percentile-style reductions over a 10x10 input cost numel(input) - 1."""
    a = numpy.random.rand(10, 10)
    cost = _cost_of(getattr(we, name), a, q=50)
    # Only the post-change expectation is kept; the stale ``== 100`` check
    # contradicted it and could never pass alongside it.
    assert cost == 99, f"{name}: expected numel(input)-1=99, got {cost}"


@pytest.mark.parametrize("name", ["quantile", "nanquantile"])
def test_quantile_numel(name, we):
    """Quantile-style reductions over a 10x10 input cost numel(input) - 1."""
    a = numpy.random.rand(10, 10)
    cost = _cost_of(getattr(we, name), a, q=0.5)
    # Only the post-change expectation is kept; the stale ``== 100`` check
    # contradicted it and could never pass alongside it.
    assert cost == 99, f"{name}: expected numel(input)-1=99, got {cost}"


@pytest.mark.parametrize("name", ["cumulative_sum", "cumulative_prod"])
def test_cumulative_numel(name, we):
    """Cumulative ops along axis 0 of a 10x10 input cost 10 * (10 - 1)."""
    a = numpy.random.rand(10, 10)
    cost = _cost_of(getattr(we, name), a, axis=0)
    # Only the post-change expectation is kept; the stale ``== 100`` check
    # contradicted it and could never pass alongside it.
    assert cost == 90, f"{name}: expected 10*(10-1)=90, got {cost}"


# ---------------------------------------------------------------------------
Expand Down
Loading
Loading