Identified in #67. Currently a subset of shape ops correctly remap a SymmetricTensor's symmetry to its new axis layout, but the rest silently downgrade to a plain WhestArray and lose the metadata.
What works (post-#67)
we.swapaxes, we.transpose, and we.moveaxis all call wrap_with_symmetry(remap_group_axes(group, axis_map)) and preserve the symmetry on the new axes.
import whest as we
A = we.symmetrize(we.random.randn(3, 3),
                  symmetry=we.SymmetryGroup.symmetric(axes=(0, 1)))
print(we.swapaxes(A, 0, 1).symmetry) # SymmetryGroup([1, 0], axes=(1, 0)) ✓
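The issue does not show how axis_map is built. As a minimal sketch (not whest's actual implementation), one way to derive an old-to-new axis map for any op that only permutes axes is to run the op on a zero-size NumPy probe whose axis i has length i, then read the permutation off the output shape. The names probe_axis_map and remap_axes are hypothetical, and a symmetry is simplified here to a plain tuple of axes rather than a SymmetryGroup.

```python
import numpy as np


def probe_axis_map(op, ndim):
    """Return {old_axis: new_axis} for an op that only permutes axes.

    Hypothetical helper: axis i of the probe has length i, so each length
    in the output shape names the input axis that landed at that position.
    """
    probe = np.empty(tuple(range(ndim)))  # contains a length-0 axis, so no data
    return {old: new for new, old in enumerate(op(probe).shape)}


def remap_axes(sym_axes, axis_map):
    """Relabel a tuple of symmetric axes (simplified stand-in for a group)."""
    return tuple(axis_map[ax] for ax in sym_axes)


swap = probe_axis_map(lambda a: np.swapaxes(a, 0, 1), ndim=2)
move = probe_axis_map(lambda a: np.moveaxis(a, 0, -1), ndim=3)
print(remap_axes((0, 1), swap))  # (1, 0)
print(remap_axes((0, 1), move))  # (2, 0)
```

The same probe works for transpose, swapaxes, and moveaxis, which avoids maintaining a hand-written map per op. It does not apply to ops that add, drop, split, or merge axes.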
What doesn't
Anything that changes the axis count, splits an axis, or merges axes:
import whest as we
A = we.symmetrize(we.random.randn(3, 3),
                  symmetry=we.SymmetryGroup.symmetric(axes=(0, 1)))
flat = we.reshape(A, (9,))
print(type(flat).__name__)  # WhestArray (lost symmetry)
print(getattr(flat, 'symmetry', None))  # None
Other affected ops (subset): reshape, concatenate, stack / vstack / hstack / column_stack, split / hsplit / vsplit / dsplit, atleast_1d / atleast_2d / atleast_3d, broadcast_to, expand_dims, squeeze, flip, roll, tile, repeat.
Why it matters
Users who pass a SymmetricTensor through any reasonable pipeline (e.g. we.stack([A, B]), we.reshape(A, (-1, A.shape[-1]))) silently lose their symmetry tag, so downstream FLOP accounting reverts to the dense formula even when the symmetry would have survived.
Suggested approach
Per-op decision logic in the __array_function__ allowlist binding, calling helpers in whest._symmetry_utils:

| Op | Symmetry survival rule |
| --- | --- |
| reshape(a, new_shape) | Survives only if new_shape keeps the contiguous block of symmetric axes intact (including -1 cases). Otherwise downgrade. |
| concatenate / stack along axis k | For concatenate, restrict every input's group to axes other than k (use restrict_group_to_axes). For stack, axis k is newly created, so no restriction is needed. Intersect across inputs (intersect_groups); on stack, lift axes ≥ k by 1. |
| transpose / swapaxes / moveaxis | Already handled; keep as reference. |
| squeeze, expand_dims | Lift / drop axes via remap_group_axes. |
| broadcast_to | If the new shape only adds prepended length-1 axes, shift symmetry axes by the lift count via remap_group_axes. |
| flip(a, axis=k) | Flipping axis k alone keeps only the group elements that leave k fixed, so restrict the group to the other axes (restrict_group_to_axes). For example, flipping one axis of a symmetric matrix gives B[i, j] = A[n-1-i, j], which is not symmetric in general. The full group survives when every axis of each affected orbit is flipped together. |
| roll, tile, repeat | Generally destroy permutation invariance; drop. |

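The concatenate / stack rule can be sketched in pure Python under a strong simplification: each input's symmetry is modelled as a single set of fully symmetric axes (or None), not a real SymmetryGroup. The function names are hypothetical. One assumption to flag: for stack the sketch skips the restriction step, because axis k does not exist in the inputs, so their symmetry is intersected and then lifted.

```python
def _normalize(block):
    """A block with fewer than two axes carries no symmetry."""
    if block is None or len(block) < 2:
        return None
    return frozenset(block)


def concatenate_symmetry(input_blocks, k):
    """Axes still fully symmetric after concatenating along axis k (k >= 0)."""
    common = None
    for block in input_blocks:
        if block is None:
            return None  # one plain input removes the common symmetry
        restricted = frozenset(block) - {k}  # axis k changes length, so drop it
        common = restricted if common is None else common & restricted
    return _normalize(common)


def stack_symmetry(input_blocks, k):
    """Axes still fully symmetric after stacking along a new axis k (k >= 0)."""
    common = None
    for block in input_blocks:
        if block is None:
            return None
        common = frozenset(block) if common is None else common & frozenset(block)
    if common is None:
        return None  # no inputs
    # Input axes at or after k move one position to the right.
    return _normalize(ax + 1 if ax >= k else ax for ax in common) \
        if False else _normalize({ax + 1 if ax >= k else ax for ax in common})


print(stack_symmetry([{0, 1}, {0, 1}], k=0))              # frozenset({1, 2})
print(concatenate_symmetry([{0, 1}, {0, 1}], k=0))        # None
print(concatenate_symmetry([{0, 1, 2}, {0, 1, 2}], k=0))  # frozenset({1, 2})
```

Intersecting sets is only valid for this one-block model; general groups would go through intersect_groups.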
whest._symmetry_utils already exposes remap_group_axes, restrict_group_to_axes, intersect_groups, direct_product_groups, embed_group — most rules can be expressed with these.
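The reshape rule is the least obvious one, so here is a minimal sketch of one possible survival check for C-order reshapes. It assumes the symmetric axes form one contiguous block and treats the symmetry as a plain tuple of axes; resolve_shape and reshape_block_axes are hypothetical names, not part of whest. The block survives when the new shape contains the same block lengths preceded by the same number of leading elements, because the flat index then decomposes identically around the block.

```python
from math import prod


def resolve_shape(old_shape, new_shape):
    """Fill in a single -1 entry the way a C-order reshape would."""
    new_shape = tuple(new_shape)
    total = prod(old_shape)
    if new_shape.count(-1) > 1:
        raise ValueError("only one dimension may be -1")
    if -1 in new_shape:
        known = prod(n for n in new_shape if n != -1)
        if known == 0 or total % known:
            raise ValueError("cannot infer the -1 dimension")
        new_shape = tuple(total // known if n == -1 else n for n in new_shape)
    if prod(new_shape) != total:
        raise ValueError("total size must be unchanged")
    return new_shape


def reshape_block_axes(old_shape, new_shape, sym_axes):
    """Return the block's new axes, or None when the op should downgrade."""
    sym_axes = tuple(sorted(sym_axes))
    if len(sym_axes) < 2:
        return None
    start, stop = sym_axes[0], sym_axes[-1] + 1
    if sym_axes != tuple(range(start, stop)):
        return None  # non-contiguous block: conservatively downgrade
    new_shape = resolve_shape(old_shape, new_shape)
    block = tuple(old_shape[start:stop])
    lead = prod(old_shape[:start])  # elements indexed before the block
    width = stop - start
    for t in range(len(new_shape) - width + 1):
        if prod(new_shape[:t]) == lead and new_shape[t:t + width] == block:
            return tuple(range(t, t + width))
    return None


print(reshape_block_axes((4, 3, 3), (2, 2, 3, 3), (1, 2)))  # (2, 3)
print(reshape_block_axes((3, 3), (-1, 3), (0, 1)))          # (0, 1)
print(reshape_block_axes((3, 3), (9,), (0, 1)))             # None
```

The check is deliberately conservative: it returns None for anything it cannot prove safe, which matches the "otherwise downgrade" rule.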
Acceptance criteria
- For each shape op listed above, add an explicit code path that either remaps and preserves the symmetry, or — when survival is genuinely impossible — downgrades with a justification comment.
- Tests in tests/test_symmetry_transport.py (or a new tests/test_shape_op_symmetry.py) covering:
  - reshape survival on identity reshapes and downgrade on axis-merging reshapes.
  - stack / concatenate preserving the surviving common subgroup.
  - expand_dims / squeeze axis-shift correctness.
  - broadcast_to with prepended length-1 axes.
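For these tests it may help to have a ground truth that does not depend on the symmetry metadata. The sketch below is a brute-force NumPy oracle (hypothetical helper name, plain ndarrays rather than whest types) that reports which pairs of axes an array is numerically invariant under swapping. A test could compare the tag produced by a shape op against this result.

```python
import itertools

import numpy as np


def invariant_axis_swaps(arr):
    """Return the set of axis pairs (i, j) whose swap leaves arr unchanged."""
    pairs = set()
    for i, j in itertools.combinations(range(arr.ndim), 2):
        same_length = arr.shape[i] == arr.shape[j]
        if same_length and np.array_equal(arr, np.swapaxes(arr, i, j)):
            pairs.add((i, j))
    return pairs


rng = np.random.default_rng(0)
M = rng.standard_normal((3, 3))
A = M + M.T  # exactly symmetric in axes (0, 1)
N = rng.standard_normal((3, 3))
B = N + N.T

print(invariant_axis_swaps(A))                      # {(0, 1)}
print(invariant_axis_swaps(np.expand_dims(A, 0)))   # {(1, 2)}
print(invariant_axis_swaps(np.stack([A, B])))       # {(1, 2)}
print(invariant_axis_swaps(np.reshape(A, (9,))))    # set()
print(invariant_axis_swaps(np.roll(A, 1, axis=0)))  # set()
```

The oracle only detects pairwise swaps, which is enough for fully symmetric groups; cyclic or other subgroups would need a check over their own generators.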
Related
- __array_function__ where this work would slot in.
- wrap_with_symmetry / remap_group_axes / restrict_group_to_axes.