Shape ops (reshape, concatenate, stack, split, ...) downgrade SymmetricTensor instead of remapping symmetry #68

@spMohanty

Description

Identified in #67. Currently a subset of shape ops correctly remap a SymmetricTensor's symmetry to its new axis layout, but the rest silently downgrade to a plain WhestArray and lose the metadata.

What works (post-#67)

we.swapaxes, we.transpose, we.moveaxis: they call wrap_with_symmetry(remap_group_axes(group, axis_map)) and preserve the symmetry on the remapped axes.

import whest as we
A = we.symmetrize(we.random.randn(3, 3),
                  symmetry=we.SymmetryGroup.symmetric(axes=(0, 1)))
print(we.swapaxes(A, 0, 1).symmetry)   # SymmetryGroup([1, 0], axes=(1, 0))  ✓
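The transport these ops rely on amounts to building an input-axis → output-axis map and relabelling the group's axes through it. A minimal pure-Python sketch for moveaxis and swapaxes, assuming a group is summarised by the tuple of axes it acts on; moveaxis_axis_map, swapaxes_axis_map and remap_axes are hypothetical stand-ins, not whest's remap_group_axes:

```python
def moveaxis_axis_map(ndim, source, destination):
    """Map each input axis to its position after np.moveaxis(a, source, destination)."""
    source %= ndim
    destination %= ndim
    # Output order: all other axes in their original order, with `source` inserted at `destination`.
    order = [ax for ax in range(ndim) if ax != source]
    order.insert(destination, source)
    # order[i] is the input axis that ends up at output position i; invert it.
    return {in_ax: out_ax for out_ax, in_ax in enumerate(order)}


def swapaxes_axis_map(ndim, axis1, axis2):
    """Map each input axis to its position after np.swapaxes(a, axis1, axis2)."""
    axis1 %= ndim
    axis2 %= ndim
    axis_map = {ax: ax for ax in range(ndim)}
    axis_map[axis1], axis_map[axis2] = axis2, axis1
    return axis_map


def remap_axes(group_axes, axis_map):
    """Relabel the axes a symmetry group acts on (hypothetical stand-in for remap_group_axes)."""
    return tuple(axis_map[ax] for ax in group_axes)


# S_2 on axes (0, 1) of a 3-D array; moving axis 0 to the end sends input axes 1, 2, 0 to 0, 1, 2.
print(remap_axes((0, 1), moveaxis_axis_map(3, 0, 2)))  # (2, 0)
# The swapaxes example above: axes (0, 1) become (1, 0).
print(remap_axes((0, 1), swapaxes_axis_map(2, 0, 1)))  # (1, 0)
```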

What doesn't

Anything that changes the number of axes, merges or splits axes, or reorders or replicates elements along an axis:

import whest as we
A = we.symmetrize(we.random.randn(3, 3),
                  symmetry=we.SymmetryGroup.symmetric(axes=(0, 1)))
print(type(we.reshape(A, (9,))).__name__)         # WhestArray (lost symmetry)
print(getattr(we.reshape(A, (9,)), 'symmetry', None))   # None

Other affected ops (subset): reshape, concatenate, stack / vstack / hstack / column_stack, split / hsplit / vsplit / dsplit, atleast_1d / atleast_2d / atleast_3d, broadcast_to, expand_dims, squeeze, flip, roll, tile, repeat.

Why it matters

Users who pass a SymmetricTensor through any reasonable pipeline (e.g. we.stack([A, B]), we.reshape(A, (-1, A.shape[-1]))) silently lose their symmetry tag, so downstream FLOP accounting reverts to the dense formula even when the symmetry would have survived.
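A quick check with plain NumPy arrays (no whest involved) that the symmetry really does survive both example pipelines: stacking two symmetric matrices yields an array that is symmetric in its last two axes, and reshape(A, (-1, A.shape[-1])) on a square matrix is an identity reshape.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_symmetric(n):
    """Return a random n x n matrix with M[i, j] == M[j, i] (exactly, since float + is commutative)."""
    m = rng.standard_normal((n, n))
    return m + m.T


A = random_symmetric(3)
B = random_symmetric(3)

# stack inserts a new leading axis, so the original symmetric axes (0, 1) become (1, 2).
S = np.stack([A, B])
print(S.shape)                              # (2, 3, 3)
print(np.array_equal(S, S.swapaxes(1, 2)))  # True: symmetric in axes (1, 2)

# For a square matrix, (-1, A.shape[-1]) resolves to A.shape, so no axis is merged or split.
R = np.reshape(A, (-1, A.shape[-1]))
print(R.shape)                              # (3, 3)
print(np.array_equal(R, R.T))               # True
```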

Suggested approach

Add per-op decision logic to the __array_function__ allowlist binding, calling helpers from whest._symmetry_utils:

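| Op | Symmetry survival rule |
|---|---|
| reshape(a, new_shape) | Survives only if new_shape (after resolving any -1) keeps every symmetric axis intact: same length, not merged with or split into other axes. Otherwise downgrade. |
| concatenate(..., axis=k) | Axis k changes length, so restrict every input's group to the axes other than k (restrict_group_to_axes), then intersect the results across inputs (intersect_groups). |
| stack(..., axis=k) | Intersect the inputs' groups (intersect_groups), then lift axes ≥ k by 1 to make room for the new axis. No restriction is needed, since every input axis keeps its length. |
| transpose / swapaxes / moveaxis | Already handled; keep as the reference implementation. |
| squeeze, expand_dims | Drop (squeeze) or lift (expand_dims) axes via remap_group_axes. |
| broadcast_to | If the new shape only prepends axes and leaves the existing axes unchanged, shift the symmetry axes by the number of prepended axes via remap_group_axes. |
| flip(a, axis=...) | Only the group elements that map the set of flipped axes onto itself survive. Flipping every axis of a symmetric block (or axis=None) keeps the whole group; flipping a single axis k keeps only the elements that fix k, so a symmetric matrix flipped along one axis is no longer symmetric. The same rule covers S_n, cyclic, and dihedral groups. |
| roll, tile, repeat | Generally destroy permutation invariance on the axes they act on; drop. |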
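One way to decide the reshape rule is to walk the old and new shapes together (C order), pairing an old axis with a new axis whenever their lengths match at aligned prefix products; such an axis has the same stride in both shapes and is carried over one-to-one, while everything else is merged or split. A minimal pure-Python sketch; the function names are hypothetical, zero-size arrays are ignored, and the rule is conservative (it drops the whole group if any of its axes is merged or split, where a refinement could restrict the group to the surviving axes):

```python
from math import prod


def resolve_shape(old_shape, new_shape):
    """Replace a single -1 entry in new_shape by the inferred length, as np.reshape does."""
    new_shape = list(new_shape)
    if new_shape.count(-1) > 1:
        raise ValueError("can only specify one unknown dimension")
    if -1 in new_shape:
        known = prod(n for n in new_shape if n != -1)
        new_shape[new_shape.index(-1)] = prod(old_shape) // known
    if prod(new_shape) != prod(old_shape):
        raise ValueError(f"cannot reshape {tuple(old_shape)} into {tuple(new_shape)}")
    return tuple(new_shape)


def reshape_axis_map(old_shape, new_shape):
    """Map each old axis carried over unchanged by a C-order reshape to its new position.

    Old axes that are merged with or split into other axes are absent from the result.
    """
    new_shape = resolve_shape(old_shape, new_shape)
    axis_map = {}
    i = j = 0
    while i < len(old_shape) and j < len(new_shape):
        if old_shape[i] == new_shape[j]:
            # Prefix products agree before and after this pair: same length and same stride.
            axis_map[i] = j
            i += 1
            j += 1
        elif new_shape[j] == 1:
            j += 1  # a length-1 axis inserted by the reshape
        elif old_shape[i] == 1:
            i += 1  # a length-1 axis dropped by the reshape
        else:
            # Merge/split block: consume axes on both sides until the products agree again.
            prod_old, prod_new = old_shape[i], new_shape[j]
            i += 1
            j += 1
            while prod_old != prod_new:
                if prod_old < prod_new:
                    prod_old *= old_shape[i]
                    i += 1
                else:
                    prod_new *= new_shape[j]
                    j += 1
    return axis_map


def reshape_symmetry_axes(old_shape, new_shape, group_axes):
    """Return the group's axes in the reshaped array, or None if the symmetry must be dropped."""
    axis_map = reshape_axis_map(old_shape, new_shape)
    if all(ax in axis_map for ax in group_axes):
        return tuple(axis_map[ax] for ax in group_axes)
    return None


print(reshape_symmetry_axes((3, 3), (-1, 3), (0, 1)))          # (0, 1): identity reshape
print(reshape_symmetry_axes((3, 3), (9,), (0, 1)))             # None: axes merged
print(reshape_symmetry_axes((2, 3, 3), (2, 1, 3, 3), (1, 2)))  # (2, 3): length-1 axis inserted
```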
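A small NumPy check of how flip interacts with an S_2 symmetry, plus a sketch of the surviving-subgroup rule. Flipping one axis of a symmetric matrix breaks the symmetry, while flipping both axes preserves it; in general, the permutations that survive are those mapping the set of flipped axes onto itself. The representation (a group as a list of full axis permutations, with g[ax] the axis that ax is sent to) and the helper flip_surviving_elements are assumptions for illustration, not whest's SymmetryGroup API:

```python
import numpy as np

rng = np.random.default_rng(0)
m = rng.standard_normal((4, 4))
A = m + m.T  # exactly symmetric: A[i, j] == A[j, i]


def is_symmetric(x):
    return np.array_equal(x, x.T)


print(is_symmetric(A))                        # True
print(is_symmetric(np.flip(A, axis=0)))       # False: only one axis of the pair is reversed
print(is_symmetric(np.flip(A, axis=(0, 1))))  # True: both axes of the pair are reversed


def flip_surviving_elements(group, flipped_axes):
    """Keep the permutations that map the set of flipped axes onto itself (setwise stabilizer)."""
    flipped = set(flipped_axes)
    return [g for g in group if {g[ax] for ax in flipped} == flipped]


s2 = [(0, 1), (1, 0)]  # identity and the swap of axes 0 and 1
print(flip_surviving_elements(s2, [0]))     # [(0, 1)]: only the identity survives
print(flip_surviving_elements(s2, [0, 1]))  # [(0, 1), (1, 0)]: the full group survives
```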

whest._symmetry_utils already exposes remap_group_axes, restrict_group_to_axes, intersect_groups, direct_product_groups, embed_group — most rules can be expressed with these.
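To illustrate how the concatenate and stack rules compose from restriction and intersection, here is a minimal pure-Python sketch in which a group is a frozenset of full axis permutations (g[ax] is the axis that ax is sent to). symmetric_group, restrict_to_axes, intersect and lift are hypothetical helpers that mirror the roles of restrict_group_to_axes and intersect_groups; whest's actual signatures and group representation may differ:

```python
from itertools import permutations


def symmetric_group(ndim, axes):
    """All permutations of `axes` (S_n on those axes), as tuples over range(ndim)."""
    elems = set()
    for images in permutations(axes):
        g = list(range(ndim))
        for src, dst in zip(axes, images):
            g[src] = dst
        elems.add(tuple(g))
    return frozenset(elems)


def restrict_to_axes(group, keep_axes):
    """Pointwise stabilizer: keep the elements that fix every axis not in keep_axes."""
    keep = set(keep_axes)
    return frozenset(g for g in group if all(g[ax] == ax for ax in range(len(g)) if ax not in keep))


def intersect(*groups):
    """Elements common to every input's group."""
    return frozenset.intersection(*groups)


def lift(group, k):
    """Insert a fresh, fixed axis at position k (for stack), shifting axes >= k up by one."""
    def shift(ax):
        return ax + 1 if ax >= k else ax

    lifted = set()
    for g in group:
        new = list(range(len(g) + 1))  # the new axis k maps to itself
        for src, dst in enumerate(g):
            new[shift(src)] = shift(dst)
        lifted.add(tuple(new))
    return frozenset(lifted)


def concatenate_group(groups, k):
    """Axis k changes length: restrict each input to the other axes, then intersect."""
    ndim = len(next(iter(groups[0])))
    others = [ax for ax in range(ndim) if ax != k]
    return intersect(*(restrict_to_axes(g, others) for g in groups))


def stack_group(groups, k):
    """Every input axis is kept: intersect, then make room for the new axis k."""
    return lift(intersect(*groups), k)


G = symmetric_group(2, (0, 1))         # two symmetric 3 x 3 matrices
print(sorted(concatenate_group([G, G], 0)))  # [(0, 1)]: only the identity survives
print(sorted(stack_group([G, G], 0)))        # [(0, 1, 2), (0, 2, 1)]: symmetric in axes (1, 2)
```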

Acceptance criteria

  • For each shape op listed above, add an explicit code path that either remaps and preserves the symmetry, or — when survival is genuinely impossible — downgrades with a justification comment.
  • Tests in tests/test_symmetry_transport.py (or a new tests/test_shape_op_symmetry.py) covering:
    • reshape survival on identity reshapes and downgrade on axis-merging reshapes.
    • stack / concatenate preserving the surviving common subgroup.
    • expand_dims / squeeze axis-shift correctness.
    • broadcast_to with prepended length-1 axes.
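The axis bookkeeping those expand_dims / squeeze / broadcast_to tests would check can be sketched as input-axis → output-axis maps, assuming a group is summarised by the tuple of axes it acts on. All function names here are hypothetical, valid axis arguments are assumed, and broadcast_to is only handled when it prepends axes (of any length, since every slice along a prepended axis is a copy of the input):

```python
def expand_dims_axis_map(ndim, axis):
    """Input axis -> output axis after np.expand_dims(a, axis) with a single int axis."""
    axis %= ndim + 1  # negative axis is counted against the output ndim, as in NumPy
    return {ax: (ax + 1 if ax >= axis else ax) for ax in range(ndim)}


def squeeze_axis_map(shape, axis=None):
    """Input axis -> output axis after np.squeeze(a, axis); squeezed axes are absent."""
    ndim = len(shape)
    if axis is None:
        dropped = {ax for ax, n in enumerate(shape) if n == 1}
    else:
        axes = (axis,) if isinstance(axis, int) else axis
        dropped = {ax % ndim for ax in axes}
        if any(shape[ax] != 1 for ax in dropped):
            raise ValueError("cannot squeeze an axis whose length is not 1")
    kept = [ax for ax in range(ndim) if ax not in dropped]
    return {ax: pos for pos, ax in enumerate(kept)}


def broadcast_to_axis_map(shape, new_shape):
    """Input axis -> output axis when broadcast_to only prepends axes; None otherwise."""
    lead = len(new_shape) - len(shape)
    if lead < 0 or tuple(new_shape[lead:]) != tuple(shape):
        return None  # an existing axis is stretched: not covered by this sketch
    return {ax: ax + lead for ax in range(len(shape))}


def transport(group_axes, axis_map):
    """Remap the group's axes, or return None if any of them was removed."""
    if axis_map is None or any(ax not in axis_map for ax in group_axes):
        return None
    return tuple(axis_map[ax] for ax in group_axes)


print(transport((0, 1), expand_dims_axis_map(2, 0)))              # (1, 2)
print(transport((1, 2), squeeze_axis_map((1, 3, 3))))             # (0, 1)
print(transport((0, 1), broadcast_to_axis_map((3, 3), (5, 3, 3))))  # (1, 2)
```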

Metadata

Labels: area:core, enhancement, priority:p2, topic:symmetry
