From 991368801756f86ead175a218c15554007376aed Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Tue, 2 Jun 2026 10:25:24 +0200 Subject: [PATCH] Arm backend: Add CanonicalizeViewCopyPermutePass This pass replaces the PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView and FuseCascadedTransposeOrPermuteOps passes, reordering and fusing in the same step, and improves performance by using a more general condition for which combinations may be fused, see docstring. + remove stale comments in mha test Signed-off-by: Adrian Lundell Change-Id: I5e9d28f2fd9a71aa277039d66fe644d988227729 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 8 +- .../canonicalize_view_copy_permute_pass.py | 398 ++++++++++++ .../test_high_rank_permute_view_invariants.py | 2 +- .../arm/test/misc/test_transpose_counts.py | 10 +- .../arm/test/ops/test_multihead_attention.py | 9 +- ...est_canonicalize_view_copy_permute_pass.py | 573 ++++++++++++++++++ .../arm/test/passes/test_permute_view_swap.py | 230 +++++++ .../arm/test/passes/test_view_permute_swap.py | 227 +++++++ 9 files changed, 1438 insertions(+), 20 deletions(-) create mode 100644 backends/arm/_passes/canonicalize_view_copy_permute_pass.py create mode 100644 backends/arm/test/passes/test_canonicalize_view_copy_permute_pass.py create mode 100644 backends/arm/test/passes/test_permute_view_swap.py create mode 100644 backends/arm/test/passes/test_view_permute_swap.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index ea4d49a79bb..e0d35563527 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -9,6 +9,7 @@ from .accumulate_index_put_pass import AccumulateIndexPutPass # noqa from .broadcast_args_pass import BroadcastArgsPass # noqa from .canonicalize_gather_pass import CanonicalizeGatherPass # noqa +from .canonicalize_view_copy_permute_pass import CanonicalizeViewCopyPermutePass # noqa from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa 
from .cast_to_int32_pass import CastToInt32Pass # noqa from .constant_folding_pass import ConstantFoldingPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 700b58f6c85..97857cbf5c1 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -15,6 +15,7 @@ AccumulateIndexPutPass, BroadcastArgsPass, CanonicalizeGatherPass, + CanonicalizeViewCopyPermutePass, CastInt64BuffersToInt32Pass, CastToInt32Pass, ComputeConstantOpsAOTPass, @@ -29,7 +30,6 @@ ConvertInt64OutputOpsToInt32Pass, ConvertMinMaxPass, ConvertMmToBmmPass, - ConvertPermuteSingletonToViewPass, ConvertSplitToSlicePass, ConvertSqueezesToViewPass, ConvertToClampPass, @@ -164,9 +164,6 @@ from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import ( FuseCascadedTransposeOrPermuteOps, ) -from executorch.backends.transforms.postpone_permute_below_squeeze_view import ( - PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, -) from executorch.exir import ExportedProgram from executorch.exir._program_utils import _get_updated_graph_signature @@ -615,9 +612,8 @@ def _tosa_pipeline( RewritePadPass(), FuseViewCopyTransformPass(), RemovePermutesAroundElementwiseTosaOps(), - PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(), + CanonicalizeViewCopyPermutePass(), FuseCascadedTransposeOrPermuteOps(), - ConvertPermuteSingletonToViewPass(), RewriteHighRankSingletonPermutePass(), DecomposePermuteForU55Pass(), RewriteSlicePass(), diff --git a/backends/arm/_passes/canonicalize_view_copy_permute_pass.py b/backends/arm/_passes/canonicalize_view_copy_permute_pass.py new file mode 100644 index 00000000000..ce2cf71e54e --- /dev/null +++ b/backends/arm/_passes/canonicalize_view_copy_permute_pass.py @@ -0,0 +1,398 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +from typing import cast, Sequence, Set, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.dim_maps import ( + _dim_equals, + _is_permutation, + _normalize_dim, + _normalize_dims, + ViewMap, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +from torch.fx import GraphModule, Node +from torch.fx.node import Target +from torch.fx.passes.infra.pass_base import PassResult + +_Dim = int | torch.SymInt + + +class CanonicalizeViewCopyPermutePass(ArmPass): + """Canonicalize view/permute chains. + + The pass repeatedly fuses adjacent compatible ops and swaps adjacent + view/permute pairs when the swap exposes more fusing. + + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + _VIEW_TARGET = exir_ops.edge.aten.view_copy.default + _PERMUTE_TARGET = exir_ops.edge.aten.permute_copy.default + _TARGETS = {_VIEW_TARGET, _PERMUTE_TARGET} + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + + for chain in self._collect_chains(graph_module): + updated_chain = chain + + while True: + updated_chain, fused = self._fuse_sequential_ops( + graph_module, updated_chain + ) + modified = modified or fused + + if len(updated_chain) > 2: + op1: Node = updated_chain[0] + op2: Node | None = updated_chain[1] + swapped_args = None + i = 2 + while swapped_args is None and op2 is not None: + swapped_args = self._maybe_swap_args(op1, op2) + if swapped_args is None: + op1 = op2 + op2 = updated_chain[i] if i < len(updated_chain) else None + i += 1 + + if swapped_args is not None: + input_node = op1.args[0] + assert isinstance(input_node, Node) + assert op2 is not None + op1_target = op1.target + op2_target = op2.target + self._set_node_op(op1, op2_target, input_node, swapped_args[0]) + self._set_node_op(op2, op1_target, op1, swapped_args[1]) + modified = True + else: + break + + else: + 
break + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) + + def _collect_chains(self, graph_module: GraphModule) -> list[list[Node]]: + """Returns a list of linear chains of view/permutes in the graph.""" + chains: list[list[Node]] = [] + + view_permute_nodes = [ + node for node in graph_module.graph.nodes if node.target in self._TARGETS + ] + + while view_permute_nodes: + node = view_permute_nodes.pop(0) + chain = [node] + current = node + + while len(current.users) == 1: + user = next(iter(current.users)) + if user.target not in self._TARGETS: + break + view_permute_nodes.remove(user) + chain.append(user) + current = user + + chains.append(chain) + + return chains + + def _fuse_sequential_ops( + self, graph_module: GraphModule, chain: list[Node] + ) -> tuple[list[Node], bool]: + """Loop over chain and fuse adjacent ops and remove no-op views/permutes + until no more fusions are possible. + + Returns the updated chain and whether any changes were made. 
+ + """ + + updated_chain = list(chain) + any_changed = False + + while True: + changed = False + index = 0 + + while index < len(updated_chain): + node = updated_chain[index] + input_node = cast(Node, node.args[0]) + + if node.target == self._VIEW_TARGET and self._shapes_equal( + self._shape(input_node), self._shape(node) + ): + # Identity view + self._remove_node( + graph_module, updated_chain, index, replacement=input_node + ) + changed = True + any_changed = True + continue + + if node.target == self._PERMUTE_TARGET: + + dims = self._permute_dims(node) + + # Normalize dims + if any( + dim < 0 or dim >= len(self._shape(input_node)) for dim in dims + ): + dims = [ + _normalize_dim(dim, len(self._shape(input_node))) + for dim in dims + ] + self._set_node_op(node, self._PERMUTE_TARGET, input_node, dims) + changed = True + any_changed = True + + # Identity permute + if dims == list(range(len(dims))): + self._remove_node( + graph_module, updated_chain, index, replacement=input_node + ) + changed = True + any_changed = True + continue + + input_shape = self._shape(input_node) + + # Permute w/o data movement e.g. 
[1, 2] -> [2, 1] decomposes to view + # Dynamic views are not supported by TOSA, so only check for static shapes + if not any( + isinstance(dim, torch.SymInt) for dim in input_shape + ) and self._is_singleton_permutation(input_shape, dims): + self._set_node_op( + node, self._VIEW_TARGET, input_node, self._shape(node) + ) + changed = True + any_changed = True + continue + + if index + 1 < len(updated_chain): + next_node = updated_chain[index + 1] + if ( + node.target == self._VIEW_TARGET + and next_node.target == self._VIEW_TARGET + ): + # Fuse consecutive views + self._set_node_op( + node, self._VIEW_TARGET, input_node, self._shape(next_node) + ) + self._remove_node( + graph_module, updated_chain, index + 1, replacement=node + ) + changed = True + any_changed = True + continue + + if ( + node.target == self._PERMUTE_TARGET + and next_node.target == self._PERMUTE_TARGET + ): + # Fuse consecutive permutes + dims = self._permute_dims(node) + next_dims = self._permute_dims(next_node) + self._set_node_op( + node, + self._PERMUTE_TARGET, + input_node, + [dims[dim] for dim in next_dims], + ) + self._remove_node( + graph_module, updated_chain, index + 1, replacement=node + ) + changed = True + any_changed = True + continue + + index += 1 + + if not changed: + return updated_chain, any_changed + + def _maybe_swap_args( + self, op1: Node, op2: Node + ) -> tuple[Sequence[_Dim], Sequence[_Dim]] | None: + """Returns updated arguments for a valid op swap, or None if the ops + cannot be swapped. 
+ """ + input_node = op1.args[0] + assert isinstance(input_node, Node) + input_val = input_node.meta["val"] + + if op1.target == self._PERMUTE_TARGET and op2.target == self._VIEW_TARGET: + return self._permute_view_swap(input_val, op1, op2) + + if op1.target == self._VIEW_TARGET and op2.target == self._PERMUTE_TARGET: + return self._view_permute_swap(op1, op2) + + return None + + def _permute_view_swap( + self, input_val: torch.Tensor, permute_node: Node, view_node: Node + ) -> tuple[list[_Dim], list[int]] | None: + """Return updated args for swapping permute(P).view(S) -> + view(S').permute(P') + + P describes each permuted output axis in terms of an input axis. Use + inverse(P) to recover where each original input axis landed in the + permuted tensor, then map that axis order through the view to construct + S' and the final restoring permutation P'. + + """ + x_shape = cast(list[_Dim], list(input_val.shape)) + permute_dims = _normalize_dims(self._permute_dims(permute_node), len(x_shape)) + + view_map = self._view_map(view_node) + if view_map is None: + return None + if not view_map.is_valid_map: + return None + + permuted_axis = self._inverse_permutation(permute_dims) + target_axis_order = view_map.map_permutation(permuted_axis) + if target_axis_order is None: + return None + + view_shape_before_permute = [ + view_map.target_shape[target_axis] for target_axis in target_axis_order + ] + + return ( + view_shape_before_permute, + self._inverse_permutation(target_axis_order), + ) + + def _view_permute_swap( + self, view_node: Node, permute_node: Node + ) -> tuple[list[int], list[_Dim]] | None: + """Return updated args for swapping view(S).permute(P) -> + permute(P').view(S') + + where S' is the shape of the original permute output, and P' is computed + using the view_map helper. 
+ + """ + view_map = self._view_map(view_node) + if view_map is None: + return None + if not view_map.is_valid_map: + return None + + permute_dims = self._permute_dims(permute_node) + mapped_dims = view_map.map_permutation_inverse(permute_dims) + if mapped_dims is None: + return None + + return mapped_dims, self._shape(permute_node) + + @staticmethod + def _inverse_permutation(permutation: Sequence[int]) -> list[int]: + inverse = [0] * len(permutation) + for index, dim in enumerate(permutation): + inverse[dim] = index + return inverse + + @classmethod + def _is_singleton_permutation( + cls, shape: Sequence[_Dim], permutation: Sequence[int] + ) -> bool: + rank = len(shape) + normalized_perm = [_normalize_dim(dim, rank) for dim in permutation] + if not _is_permutation(normalized_perm, rank): + return False + + non_singleton_axes = [ + axis for axis, dim in enumerate(shape) if not _dim_equals(dim, 1) + ] + permuted_non_singleton_axes = [ + axis for axis in normalized_perm if not _dim_equals(shape[axis], 1) + ] + return permuted_non_singleton_axes == non_singleton_axes + + @staticmethod + def _view_map(view_node: Node) -> ViewMap | None: + try: + return ViewMap(view_node) + except AssertionError: + return None + + @staticmethod + def _shapes_equal(lhs: Sequence[_Dim], rhs: Sequence[_Dim]) -> bool: + return len(lhs) == len(rhs) and all( + _dim_equals(lhs_dim, rhs_dim) for lhs_dim, rhs_dim in zip(lhs, rhs) + ) + + def _remove_node( + self, + graph_module: GraphModule, + chain: list[Node], + index: int, + replacement: Node, + ) -> None: + node = chain[index] + assert node is not replacement + + node.replace_all_uses_with(replacement) + graph_module.graph.erase_node(node) + del chain[index] + + def _set_node_op( + self, + node: Node, + target: Target, + input_node: Node, + arg: Sequence[_Dim], + ) -> None: + node.target = target + node.args = (input_node, list(arg)) + self._refresh_meta(node) + + def _permute_dims(self, node: Node) -> list[int]: + assert node.target == 
self._PERMUTE_TARGET, "Expected permute node" + return list(cast(Sequence[int], node.args[1])) + + @classmethod + def _refresh_meta(cls, node: Node) -> None: + input_node = node.args[0] + assert isinstance(input_node, Node) + input_val = input_node.meta.get("val") + if input_val is None or node.target not in cls._TARGETS: + return + + # Compute new meta shapes to preserve SymInts. + if isinstance(input_val, torch.Tensor): + if node.target == cls._VIEW_TARGET: + node.meta["val"] = input_val.new_empty( + tuple(cast(Sequence[_Dim], node.args[1])) + ) + return + + if node.target == cls._PERMUTE_TARGET: + dims = _normalize_dims( + cast(Sequence[int], node.args[1]), len(input_val.shape) + ) + node.meta["val"] = input_val.new_empty( + tuple(input_val.shape[dim] for dim in dims) + ) + return + + node.meta["val"] = node.target(input_val, *node.args[1:]) # type: ignore[operator] + + @staticmethod + def _shape(node: Node) -> list[_Dim]: + return cast(list[_Dim], list(node.meta["val"].shape)) diff --git a/backends/arm/test/misc/test_high_rank_permute_view_invariants.py b/backends/arm/test/misc/test_high_rank_permute_view_invariants.py index 6004553141a..d3500f09068 100644 --- a/backends/arm/test/misc/test_high_rank_permute_view_invariants.py +++ b/backends/arm/test/misc/test_high_rank_permute_view_invariants.py @@ -159,7 +159,7 @@ def _build_high_rank_permute_cases() -> dict[str, TransposeInvariantCase]: 20260225 ) # nosec B311: deterministic RNG for test case generation start_shape = [1, 16, 16, 64] - expected_transpose_counts = [4, 3, 3, 3, 2, 3, 3, 3, 3, 2] + expected_transpose_counts = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2] cases: dict[str, TransposeInvariantCase] = {} for idx in range(10): ops = _generate_chain(rng, start_shape, steps=8) diff --git a/backends/arm/test/misc/test_transpose_counts.py b/backends/arm/test/misc/test_transpose_counts.py index 8ce032058bf..fc05f5b2717 100644 --- a/backends/arm/test/misc/test_transpose_counts.py +++ 
b/backends/arm/test/misc/test_transpose_counts.py @@ -408,7 +408,7 @@ def forward(self, x: torch.Tensor): "maxpool2d_dilation": TransposeCountCase( MaxPool2dDilatedModule(), (torch.randn(1, 2, 8, 8),), - 4, + 2, ), "lstm": TransposeCountCase( LstmModule(), @@ -428,7 +428,7 @@ def forward(self, x: torch.Tensor): "multihead_attention_rank3": TransposeCountCase( MultiheadAttentionModule(), (torch.randn(2, 4, 8),), - 8, + 7, ), "cumsum_rank3_dim0": TransposeCountCase( CumsumModule(), @@ -444,7 +444,7 @@ def forward(self, x: torch.Tensor): Model1ConvMaxPoolResidualLinear(), (torch.randn(2, 8, 64),), 5 ), "model_2_conv_mha_linear_layernorm": TransposeCountCase( - Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 9 + Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 8 ), "model_3_lstm_linear": TransposeCountCase( Model3LstmLinear(), (torch.randn(2, 16, 8),), 2 @@ -513,7 +513,7 @@ def forward(self, x: torch.Tensor): "pixel_shuffle_channels_last": TransposeCountCase( PixelShuffleModule(), (torch.randn(1, 8, 2, 2).to(memory_format=torch.channels_last),), - 3, + 1, ), "grouped_conv_channels_last": TransposeCountCase( GroupedConvModule(), @@ -538,7 +538,7 @@ def forward(self, x: torch.Tensor): "maxpool2d_dilation_channels_last": TransposeCountCase( MaxPool2dDilatedModule(), (torch.randn(1, 2, 8, 8).to(memory_format=torch.channels_last),), - 6, + 4, ), "groupnorm_channels_last": TransposeCountCase( GroupNormModule(), diff --git a/backends/arm/test/ops/test_multihead_attention.py b/backends/arm/test/ops/test_multihead_attention.py index 50dcaae4635..206013a0498 100644 --- a/backends/arm/test/ops/test_multihead_attention.py +++ b/backends/arm/test/ops/test_multihead_attention.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -58,8 +58,6 @@ def test_multihead_attention_tosa_INT(test_data): (*test_data, *test_data, *test_data), [], [], - # TODO: Per-channel quantization is broken (MLETORCH-1144) - per_channel_quantization=False, ) pipeline.run() @@ -76,9 +74,6 @@ def test_multihead_attention_u55_INT(test_data: input_t1): (*test_data, *test_data, *test_data), [], [], - use_to_edge_transform_and_lower=True, - # TODO: Per-channel quantization is broken (MLETORCH-1144) - per_channel_quantization=False, ) pipeline.pop_stage("check_count.exir") pipeline.run() @@ -97,8 +92,6 @@ def test_multihead_attention_u85_INT(test_data: input_t1): [], [], use_to_edge_transform_and_lower=True, - # TODO: Per-channel quantization is broken (MLETORCH-1144) - per_channel_quantization=False, ) pipeline.run() diff --git a/backends/arm/test/passes/test_canonicalize_view_copy_permute_pass.py b/backends/arm/test/passes/test_canonicalize_view_copy_permute_pass.py new file mode 100644 index 00000000000..38a55c8ba10 --- /dev/null +++ b/backends/arm/test/passes/test_canonicalize_view_copy_permute_pass.py @@ -0,0 +1,573 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +from typing import cast + +import sympy # type: ignore[import-untyped] +import torch +from executorch.backends.arm._passes import CanonicalizeViewCopyPermutePass +from executorch.backends.test.graph_builder import GraphBuilder +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils import _pytree as pytree + + +def _make_symint( + shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64 +) -> torch.SymInt: + symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint) + assert isinstance(symint, torch.SymInt) + shape_env.constrain_symbol_range( + symint.node.expr, compiler_min=min, compiler_max=max + ) + return symint + + +def _count_node( + graph_module: torch.fx.GraphModule, target: torch.fx.node.Target +) -> int: + return sum( + node.op == "call_function" and node.target == target + for node in graph_module.graph.nodes + ) + + +def _compute_nodes(graph_module: torch.fx.GraphModule) -> list[torch.fx.node.Target]: + return [ + node.target for node in graph_module.graph.nodes if node.op == "call_function" + ] + + +def _validate_numerics( + original: torch.fx.GraphModule, + modified: torch.fx.GraphModule, + inputs: tuple[torch.Tensor, ...], +) -> None: + original.eval() + modified.eval() + with torch.no_grad(): + original_output = original(*inputs) + modified_output = modified(*inputs) + + flat_original, _ = pytree.tree_flatten(original_output) + flat_modified, _ = pytree.tree_flatten(modified_output) + for original_tensor, modified_tensor in zip(flat_original, flat_modified): + torch.testing.assert_close(original_tensor, modified_tensor) + + +def test_canonicalize_direct_permute_chain() -> None: + builder = GraphBuilder() + x_data = torch.randn(2, 3, 4, 5) + x = builder.placeholder("x", x_data) + p1 = builder.call_operator( + 
op=exir_ops.edge.aten.permute_copy.default, + args=(x, [0, 2, 3, 1]), + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(p1, [0, 3, 1, 2]), + ) + builder.output([p2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 0 + ) + _validate_numerics(gm_before, result.graph_module, (x_data,)) + + +def test_canonicalize_pixel_shuffle_view_permute_chain() -> None: + builder = GraphBuilder() + x_data = torch.randn(1, 2, 2, 8) + x = builder.placeholder("x", x_data) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [0, 3, 1, 2]), + ) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p1, [1, 2, 2, 2, 2, 2]), + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v1, [0, 1, 4, 2, 5, 3]), + ) + v2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p2, [1, 2, 4, 4]), + ) + p3 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v2, [0, 2, 3, 1]), + ) + builder.output([p3]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 1 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 2 + assert _compute_nodes(result.graph_module) == [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.view_copy.default, + ] + _validate_numerics(gm_before, result.graph_module, (x_data,)) + + +def 
test_canonicalize_direct_output_axis_permute() -> None: + builder = GraphBuilder() + x_data = torch.randn(2, 12, 5) + x = builder.placeholder("x", x_data) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(x, [2, 3, 4, 5]), + ) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v1, [2, 0, 1, 3]), + ) + v2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p1, [4, 6, 5]), + ) + builder.output([v2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 1 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 1 + + compute_nodes = [ + node for node in result.graph_module.graph.nodes if node.op == "call_function" + ] + assert [node.target for node in compute_nodes] == [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + ] + assert compute_nodes[0].args[1] == [6, 4, 5] + assert compute_nodes[1].args[1] == [1, 0, 2] + _validate_numerics(gm_before, result.graph_module, (x_data,)) + + +def test_canonicalize_reordered_view_with_grouped_permute() -> None: + builder = GraphBuilder() + x_data = torch.randn(2, 3, 4) + x = builder.placeholder("x", x_data) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(x, [3, 2, 4]), + ) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v1, [2, 0, 1]), + ) + v2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p1, [4, 6]), + ) + builder.output([v2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert 
result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 1 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 1 + + compute_nodes = [ + node for node in result.graph_module.graph.nodes if node.op == "call_function" + ] + assert [node.target for node in compute_nodes] == [ + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.view_copy.default, + ] + assert compute_nodes[0].args[1] == [2, 0, 1] + assert compute_nodes[1].args[1] == [4, 6] + _validate_numerics(gm_before, result.graph_module, (x_data,)) + + +def test_canonicalize_moves_view_before_permute() -> None: + builder = GraphBuilder() + x_data = torch.randn(2, 3, 4) + x = builder.placeholder("x", x_data) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [2, 0, 1]), + ) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p1, [4, 6]), + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v1, [1, 0]), + ) + builder.output([p2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 0 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 1 + + compute_nodes = [ + node for node in result.graph_module.graph.nodes if node.op == "call_function" + ] + assert [node.target for node in compute_nodes] == [ + exir_ops.edge.aten.view_copy.default, + ] + assert compute_nodes[0].args[1] == [6, 4] + _validate_numerics(gm_before, result.graph_module, (x_data,)) + + +def test_canonicalize_moves_permute_before_view() -> None: + builder = GraphBuilder() + x_data = torch.randn(1, 2, 10, 10) + x = builder.placeholder("x", x_data) + v1 = builder.call_operator( + 
op=exir_ops.edge.aten.view_copy.default, + args=(x, [1, 2, 5, 2, 5, 2]), + ) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v1, [3, 5, 0, 1, 2, 4]), + ) + v2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p1, [4, 2, 5, 5]), + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v2, [0, 2, 3, 1]), + ) + builder.output([p2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 1 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 2 + + compute_nodes = [ + node for node in result.graph_module.graph.nodes if node.op == "call_function" + ] + assert [node.target for node in compute_nodes] == [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.view_copy.default, + ] + assert compute_nodes[0].args[1] == [1, 2, 5, 2, 5, 2] + assert compute_nodes[1].args[1] == [3, 5, 0, 2, 4, 1] + assert compute_nodes[2].args[1] == [4, 5, 5, 2] + _validate_numerics(gm_before, result.graph_module, (x_data,)) + + +def test_canonicalize_follows_interleaved_chain_users() -> None: + builder = GraphBuilder() + x_data = torch.randn(4, 2, 4) + y_data = torch.randn(2, 3) + x = builder.placeholder("x", x_data) + y = builder.placeholder("y", y_data) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [1, 0, 2]), + ) + unrelated = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(y, [3, 2]), + ) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p1, [1, 2, 4, 4]), + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v1, [0, 2, 1, 3]), + ) + 
builder.output([p2, unrelated]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 0 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 2 + + compute_nodes = [ + node for node in result.graph_module.graph.nodes if node.op == "call_function" + ] + assert compute_nodes[0].target == exir_ops.edge.aten.view_copy.default + assert compute_nodes[0].args[1] == [1, 4, 2, 4] + _validate_numerics(gm_before, result.graph_module, (x_data, y_data)) + + +def test_canonicalize_reordered_view_rejects_separating_permute() -> None: + builder = GraphBuilder() + x_data = torch.randn(2, 3, 4) + x = builder.placeholder("x", x_data) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(x, [3, 2, 4]), + ) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v1, [0, 2, 1]), + ) + builder.output([p1]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert not result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 1 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 1 + _validate_numerics(gm_before, result.graph_module, (x_data,)) + + +def test_canonicalize_symbolic_pixel_shuffle_view_permute_chain() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=2) + + with FakeTensorMode(shape_env=shape_env, allow_non_fake_inputs=True) as mode: + builder = GraphBuilder(fake_tensor_mode=mode) + x = builder.placeholder("x", torch.empty(size=(batch, 2, 2, 8))) + p1 = builder.call_operator( + 
op=exir_ops.edge.aten.permute_copy.default, + args=(x, [0, 3, 1, 2]), + ) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p1, [batch, 2, 2, 2, 2, 2]), + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v1, [0, 1, 4, 2, 5, 3]), + ) + v2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p2, [batch, 2, 4, 4]), + ) + p3 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v2, [0, 2, 3, 1]), + ) + builder.output([p3]) + original = builder.get_graph_module() + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 1 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 2 + + compute_nodes = [ + node for node in result.graph_module.graph.nodes if node.op == "call_function" + ] + assert compute_nodes[0].args[1] == [batch, 2, 2, 2, 2, 2] + assert compute_nodes[1].args[1] == [0, 1, 4, 2, 5, 3] + assert compute_nodes[2].args[1] == [batch, 4, 4, 2] + + +def test_canonicalize_symbolic_singleton_permute_stays_permute() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=2) + + with FakeTensorMode(shape_env=shape_env, allow_non_fake_inputs=True) as mode: + builder = GraphBuilder(fake_tensor_mode=mode) + x = builder.placeholder("x", torch.empty(size=(batch, 1, 4))) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [1, 0, 2]), + ) + builder.output([p1]) + original = builder.get_graph_module() + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert not result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 1 + ) + assert _count_node(result.graph_module, 
exir_ops.edge.aten.view_copy.default) == 0 + + +def test_canonicalize_view_permute_swap_uses_factor_order() -> None: + x = torch.empty((2, 3, 4, 5, 6, 7, 8, 9), device="meta") + graph = torch.fx.Graph() + input_node = graph.placeholder("x") + input_node.meta["val"] = x + view = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(input_node, [1, 2, 3, 4, 5, 6, 7, 8, 9]), + ) + view.meta["val"] = torch.empty((1, 2, 3, 4, 5, 6, 7, 8, 9), device="meta") + permute = graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(view, [0, 8, 7, 6, 5, 4, 3, 2, 1]), + ) + permute.meta["val"] = torch.empty((1, 9, 8, 7, 6, 5, 4, 3, 2), device="meta") + + assert CanonicalizeViewCopyPermutePass()._view_permute_swap(view, permute) == ( + [7, 6, 5, 4, 3, 2, 1, 0], + [1, 9, 8, 7, 6, 5, 4, 3, 2], + ) + + +def test_canonicalize_view_permute_swap_rejects_reordered_split_axis() -> None: + x = torch.empty((4,), device="meta") + graph = torch.fx.Graph() + input_node = graph.placeholder("x") + input_node.meta["val"] = x + view = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(input_node, [2, 2]), + ) + view.meta["val"] = torch.empty((2, 2), device="meta") + permute = graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(view, [1, 0]), + ) + permute.meta["val"] = torch.empty((2, 2), device="meta") + + assert CanonicalizeViewCopyPermutePass()._view_permute_swap(view, permute) is None + + +def test_canonicalize_does_not_cross_multi_user_chain_node() -> None: + builder = GraphBuilder() + x_data = torch.randn(2, 3, 4) + x = builder.placeholder("x", x_data) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [0, 2, 1]), + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(p1, [0, 2, 1]), + ) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p1, [2, 12]), + ) + builder.output([p2, v1]) + original = builder.get_graph_module() + 
gm_before = copy.deepcopy(original) + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert not result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 2 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 1 + _validate_numerics(gm_before, result.graph_module, (x_data,)) + + +def test_canonicalize_unsupported_start_view_does_not_block_suffix() -> None: + builder = GraphBuilder() + x_data = torch.randn(2, 3) + x = builder.placeholder("x", x_data) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(x, [3, 2]), + ) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(v1, [1, 0]), + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(p1, [1, 0]), + ) + builder.output([p2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 0 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 1 + _validate_numerics(gm_before, result.graph_module, (x_data,)) + + +def test_canonicalize_unsupported_end_view_does_not_block_prefix() -> None: + builder = GraphBuilder() + x_data = torch.randn(2, 3) + x = builder.placeholder("x", x_data) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [1, 0]), + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(p1, [1, 0]), + ) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(p2, [3, 2]), + ) + builder.output([v1]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = 
CanonicalizeViewCopyPermutePass() + result = cast(PassResult, pass_instance.call(original)) + + assert result.modified + assert ( + _count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default) == 0 + ) + assert _count_node(result.graph_module, exir_ops.edge.aten.view_copy.default) == 1 + _validate_numerics(gm_before, result.graph_module, (x_data,)) diff --git a/backends/arm/test/passes/test_permute_view_swap.py b/backends/arm/test/passes/test_permute_view_swap.py new file mode 100644 index 00000000000..a140865d3d9 --- /dev/null +++ b/backends/arm/test/passes/test_permute_view_swap.py @@ -0,0 +1,230 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import permutations +from typing import cast, Sequence + +import pytest +import sympy # type: ignore[import-untyped] +import torch +from executorch.backends.arm._passes import CanonicalizeViewCopyPermutePass +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.symbolic_shapes import ShapeEnv + +_Dim = int | torch.SymInt + + +def _numel(shape: list[int]) -> int: + numel = 1 + for dim in shape: + numel *= dim + return numel + + +def _factorizations(numel: int, max_rank: int) -> list[list[int]]: + shapes: list[list[int]] = [] + + def recurse(remaining: int, rank: int, shape: list[int]) -> None: + if rank == 0: + if remaining == 1: + shapes.append(list(shape)) + return + + for dim in range(1, remaining + 1): + if remaining % dim == 0: + shape.append(dim) + recurse(remaining // dim, rank - 1, shape) + shape.pop() + + for rank in range(1, max_rank + 1): + recurse(numel, rank, []) + + return shapes + + +def _meta_tensor(shape: list[int]) -> torch.Tensor: + return torch.empty(tuple(shape), device="meta") + + +def _permute_view_swap( + x: torch.Tensor, permute_dims: 
list[int], output_shape: Sequence[_Dim] +) -> tuple[list[_Dim], list[int]] | None: + graph = torch.fx.Graph() + input_node = graph.placeholder("x") + input_node.meta["val"] = x + normalized_dims = [dim if dim >= 0 else dim + len(x.shape) for dim in permute_dims] + permute = graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(input_node, permute_dims), + ) + permute.meta["val"] = x.new_empty(tuple(x.shape[dim] for dim in normalized_dims)) + view = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(permute, output_shape), + ) + view.meta["val"] = x.new_empty(tuple(output_shape)) + return CanonicalizeViewCopyPermutePass()._permute_view_swap(x, permute, view) + + +def _assert_swap_matches_tensor_behavior( + input_shape: list[int], + permute_dims: list[int], + output_shape: list[int], + swapped_args: tuple[Sequence[_Dim], Sequence[int]], +) -> None: + view_shape, swapped_permute_dims = swapped_args + normalized_dims = [ + dim if dim >= 0 else dim + len(input_shape) for dim in permute_dims + ] + + data = torch.arange(_numel(input_shape)).reshape(input_shape) + original = data.permute(normalized_dims).contiguous().reshape(output_shape) + swapped = ( + data.reshape(cast(list[int], list(view_shape))) + .permute(list(swapped_permute_dims)) + .contiguous() + ) + + assert list(swapped.shape) == output_shape + torch.testing.assert_close(swapped, original) + + +@pytest.mark.parametrize( + "input_shape, permute_dims, output_shape, expected", + [ + ([2, 3, 4], [2, 0, 1], [4, 6], ([6, 4], [1, 0])), + ([2, 3, 4], [1, 2, 0], [12, 2], ([2, 12], [1, 0])), + ([2, 3, 4], [0, 2, 1], [2, 2, 2, 3], ([2, 3, 2, 2], [0, 2, 3, 1])), + ([2, 3, 4], [-1, 0, 1], [4, 6], ([6, 4], [1, 0])), + ([2, 3, 4], [2, 0, 1], [1, 4, 6], ([1, 6, 4], [0, 2, 1])), + ([2, 3, 4], [2, 0, 1], [4, 1, 6], ([1, 6, 4], [2, 0, 1])), + ([2, 3, 4], [2, 0, 1], [4, 6, 1], ([6, 4, 1], [1, 0, 2])), + ( + [2, 3, 4], + [2, 0, 1], + [2, 2, 2, 3], + ([2, 3, 2, 2], [2, 3, 0, 1]), + ), + ([2, 3, 
4], [1, 2, 0], [3, 2, 2, 2], ([2, 3, 2, 2], [1, 2, 3, 0])), + ([2, 3, 4, 5], [0, 2, 3, 1], [2, 20, 3], ([2, 3, 20], [0, 2, 1])), + ([2, 3, 4, 5], [2, 3, 0, 1], [20, 2, 3], ([2, 3, 20], [2, 0, 1])), + ([2, 3, 4, 5], [2, 3, 0, 1], [4, 5, 6], ([6, 4, 5], [1, 2, 0])), + ([2, 3, 4, 5], [3, 0, 1, 2], [5, 24], ([24, 5], [1, 0])), + ([1, 2, 3, 4], [2, 3, 0, 1], [12, 2], ([2, 12], [1, 0])), + ([2, 1, 3, 4], [1, 3, 0, 2], [4, 6], ([6, 4], [1, 0])), + ([2, 3, 1, 4], [2, 0, 3, 1], [2, 2, 2, 3], ([2, 3, 2, 2], [0, 2, 3, 1])), + ([2, 2, 3], [2, 0, 1], [3, 2, 2], ([2, 2, 3], [2, 0, 1])), + ([2, 2, 3], [1, 2, 0], [2, 3, 2], ([2, 2, 3], [1, 2, 0])), + ([2, 2, 2, 3], [3, 0, 1, 2], [3, 8], ([8, 3], [1, 0])), + ([2, 2, 2, 3], [3, 0, 1, 2], [3, 2, 2, 2], ([2, 2, 2, 3], [3, 0, 1, 2])), + ], +) +def test_permute_view_swap_expected_rewrites( + input_shape: list[int], + permute_dims: list[int], + output_shape: list[int], + expected: tuple[list[int], list[int]], +) -> None: + swapped_args = _permute_view_swap( + _meta_tensor(input_shape), permute_dims, output_shape + ) + + assert swapped_args == expected + _assert_swap_matches_tensor_behavior( + input_shape, permute_dims, output_shape, swapped_args + ) + + +@pytest.mark.parametrize( + "input_shape, permute_dims, output_shape", + [ + ([2, 3, 4], [1, 0, 2], [3, 8]), + ([2, 3, 4], [1, 0, 2], [6, 4]), + ([2, 3, 4], [2, 1, 0], [4, 6]), + ([2, 3, 4], [0, 2, 1], [8, 3]), + ([2, 3, 4, 5], [3, 0, 1, 2], [10, 12]), + ([2, 3, 4, 5], [1, 3, 0, 2], [15, 8]), + ([2, 3, 4], [2, 0, 1], [5, 5]), + ], +) +def test_permute_view_swap_rejects_unsupported_rewrites( + input_shape: list[int], permute_dims: list[int], output_shape: list[int] +) -> None: + assert ( + _permute_view_swap(_meta_tensor(input_shape), permute_dims, output_shape) + is None + ) + + +def test_permute_view_swap_generated_cases_are_semantically_valid() -> None: + input_shapes = [ + [2, 3], + [2, 4], + [2, 3, 4], + [2, 2, 3], + [1, 2, 3], + [2, 1, 3], + [2, 3, 1], + [2, 2, 2, 3], + ] + + 
total_cases = 0 + accepted_cases = 0 + rejected_cases = 0 + + for input_shape in input_shapes: + output_shapes = _factorizations(_numel(input_shape), max_rank=4) + for permute_dims in permutations(range(len(input_shape))): + for output_shape in output_shapes: + total_cases += 1 + swapped_args = _permute_view_swap( + _meta_tensor(input_shape), list(permute_dims), output_shape + ) + if swapped_args is None: + rejected_cases += 1 + continue + + accepted_cases += 1 + _assert_swap_matches_tensor_behavior( + input_shape, list(permute_dims), output_shape, swapped_args + ) + + assert total_cases > 1000 + assert accepted_cases > 200 + assert rejected_cases > 200 + + +def _make_symint( + shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64 +) -> torch.SymInt: + symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint) + assert isinstance(symint, torch.SymInt) + shape_env.constrain_symbol_range( + symint.node.expr, compiler_min=min, compiler_max=max + ) + return symint + + +def test_permute_view_swap_preserves_symbolic_dimensions() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=2) + + with FakeTensorMode(shape_env=shape_env, allow_non_fake_inputs=True): + x = torch.empty((batch, 2, 3, 4), device="cpu") + assert _permute_view_swap(x, [2, 3, 0, 1], [12, batch, 2]) == ( + [batch, 2, 12], + [2, 0, 1], + ) + assert _permute_view_swap(x, [0, 2, 3, 1], [batch, 12, 2]) == ( + [batch, 2, 12], + [0, 2, 1], + ) + + x = torch.empty((batch, 2, 10, 10), device="cpu") + assert _permute_view_swap(x, [0, 1, 2, 3], [batch, 2, 5, 2, 5, 2]) == ( + [batch, 2, 5, 2, 5, 2], + [0, 1, 2, 3, 4, 5], + ) diff --git a/backends/arm/test/passes/test_view_permute_swap.py b/backends/arm/test/passes/test_view_permute_swap.py new file mode 100644 index 00000000000..3bfb1459ecf --- /dev/null +++ b/backends/arm/test/passes/test_view_permute_swap.py @@ -0,0 +1,227 @@ +# Copyright 2026 Arm Limited and/or its affiliates. 
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Unit tests for ``CanonicalizeViewCopyPermutePass._view_permute_swap``.

Each test builds a tiny ``view_copy -> permute_copy`` FX graph, asks the pass
for the arguments of an equivalent ``permute_copy -> view_copy`` pair and,
where the shapes are static, checks the rewrite numerically on real tensors.
"""

import math
from itertools import permutations
from typing import cast, Sequence

import pytest
import sympy  # type: ignore[import-untyped]
import torch
from executorch.backends.arm._passes import CanonicalizeViewCopyPermutePass
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# A single tensor dimension: either a static int or a symbolic size.
_Dim = int | torch.SymInt


def _numel(shape: list[int]) -> int:
    """Return the product of ``shape`` (1 for an empty shape)."""
    return math.prod(shape)


def _normalize_dims(dims: Sequence[int], rank: int) -> list[int]:
    """Map negative entries of ``dims`` to their non-negative equivalent."""
    return [dim if dim >= 0 else dim + rank for dim in dims]


def _factorizations(numel: int, max_rank: int) -> list[list[int]]:
    """Enumerate every shape of rank 1..``max_rank`` whose product is ``numel``.

    Dimensions of size 1 are included, and shapes that differ only in the
    order of their dimensions are listed separately.
    """
    shapes: list[list[int]] = []

    def recurse(remaining: int, rank: int, shape: list[int]) -> None:
        # ``shape`` is a shared scratch list; a copy is stored on success.
        if rank == 0:
            if remaining == 1:
                shapes.append(list(shape))
            return

        for dim in range(1, remaining + 1):
            if remaining % dim == 0:
                shape.append(dim)
                recurse(remaining // dim, rank - 1, shape)
                shape.pop()

    for rank in range(1, max_rank + 1):
        recurse(numel, rank, [])

    return shapes


def _meta_tensor(shape: list[int]) -> torch.Tensor:
    """Create a tensor of ``shape`` on the ``meta`` device (no storage)."""
    return torch.empty(tuple(shape), device="meta")


def _view_permute_swap(
    x: torch.Tensor, view_shape: Sequence[_Dim], permute_dims: list[int]
) -> tuple[list[int], list[_Dim]] | None:
    """Run the pass helper on a freshly built ``view_copy -> permute_copy`` graph.

    Args:
        x: Value stored as ``meta["val"]`` of the graph input.
        view_shape: Target shape passed to ``view_copy``.
        permute_dims: Dimension order passed to ``permute_copy``; negative
            entries are passed through unchanged to the node and only
            normalized for computing the permute's ``meta["val"]`` shape.

    Returns:
        Whatever ``CanonicalizeViewCopyPermutePass._view_permute_swap``
        returns for the two nodes. The tests in this module expect either
        ``None`` or a ``(permute dims, view shape)`` pair.
    """
    graph = torch.fx.Graph()
    input_node = graph.placeholder("x")
    input_node.meta["val"] = x

    view = graph.call_function(
        exir_ops.edge.aten.view_copy.default,
        args=(input_node, view_shape),
    )
    view.meta["val"] = x.new_empty(tuple(view_shape))

    normalized_dims = _normalize_dims(permute_dims, len(view_shape))
    permute = graph.call_function(
        exir_ops.edge.aten.permute_copy.default,
        args=(view, permute_dims),
    )
    permute.meta["val"] = x.new_empty(tuple(view_shape[dim] for dim in normalized_dims))

    return CanonicalizeViewCopyPermutePass()._view_permute_swap(view, permute)


def _assert_swap_matches_tensor_behavior(
    input_shape: list[int],
    view_shape: list[int],
    permute_dims: list[int],
    swapped_args: tuple[Sequence[int], Sequence[_Dim]],
) -> None:
    """Assert that the swapped op pair computes the same tensor as the original.

    The original order is ``reshape(view_shape)`` followed by
    ``permute(permute_dims)``; the swapped order is ``permute`` followed by
    ``reshape`` using the arguments in ``swapped_args``. Both are evaluated
    on ``arange`` data so any element reordering is detected.
    """
    swapped_permute_dims, output_shape = swapped_args
    normalized_dims = _normalize_dims(permute_dims, len(view_shape))

    data = torch.arange(_numel(input_shape)).reshape(input_shape)
    original = data.reshape(view_shape).permute(normalized_dims).contiguous()
    swapped = (
        data.permute(list(swapped_permute_dims))
        .contiguous()
        .reshape(cast(list[int], list(output_shape)))
    )

    assert list(swapped.shape) == list(original.shape)
    torch.testing.assert_close(swapped, original)


@pytest.mark.parametrize(
    "input_shape, view_shape, permute_dims, expected",
    [
        ([2, 3, 4], [2, 3, 2, 2], [0, 2, 3, 1], ([0, 2, 1], [2, 2, 2, 3])),
        ([2, 3, 4], [2, 3, 2, 2], [2, 3, 0, 1], ([2, 0, 1], [2, 2, 2, 3])),
        ([2, 3, 4], [2, 12], [1, 0], ([1, 2, 0], [12, 2])),
        ([2, 3, 4], [6, 4], [1, 0], ([2, 0, 1], [4, 6])),
        ([2, 3, 4], [6, 4], [-1, 0], ([2, 0, 1], [4, 6])),
        ([2, 3, 4], [1, 6, 4], [0, 2, 1], ([2, 0, 1], [1, 4, 6])),
        ([2, 3, 4], [1, 6, 4], [2, 0, 1], ([2, 0, 1], [4, 1, 6])),
        ([2, 3, 4], [1, 6, 4], [2, 1, 0], ([2, 0, 1], [4, 6, 1])),
        ([2, 3, 4], [2, 3, 2, 2], [1, 2, 3, 0], ([1, 2, 0], [3, 2, 2, 2])),
        ([2, 3, 4, 5], [2, 3, 20], [0, 2, 1], ([0, 2, 3, 1], [2, 20, 3])),
        ([2, 3, 4, 5], [2, 3, 20], [2, 0, 1], ([2, 3, 0, 1], [20, 2, 3])),
        ([2, 3, 4, 5], [6, 4, 5], [1, 2, 0], ([2, 3, 0, 1], [4, 5, 6])),
        ([2, 3, 4, 5], [24, 5], [1, 0], ([3, 0, 1, 2], [5, 24])),
        ([1, 2, 3, 4], [2, 12], [1, 0], ([0, 2, 3, 1], [12, 2])),
        ([2, 1, 3, 4], [6, 4], [1, 0], ([1, 3, 0, 2], [4, 6])),
        ([2, 3, 1, 4], [2, 3, 2, 2], [0, 2, 3, 1], ([0, 2, 3, 1], [2, 2, 2, 3])),
        ([2, 2, 3], [2, 2, 3], [2, 0, 1], ([2, 0, 1], [3, 2, 2])),
        ([2, 2, 3], [2, 2, 3], [1, 2, 0], ([1, 2, 0], [2, 3, 2])),
        ([2, 2, 2, 3], [8, 3], [1, 0], ([3, 0, 1, 2], [3, 8])),
        ([2, 2, 2, 3], [2, 2, 2, 3], [3, 0, 1, 2], ([3, 0, 1, 2], [3, 2, 2, 2])),
        (
            [2, 3, 4, 5, 6, 7, 8, 9],
            [1, 2, 3, 4, 5, 6, 7, 8, 9],
            [0, 8, 7, 6, 5, 4, 3, 2, 1],
            ([7, 6, 5, 4, 3, 2, 1, 0], [1, 9, 8, 7, 6, 5, 4, 3, 2]),
        ),
    ],
)
def test_view_permute_swap_expected_rewrites(
    input_shape: list[int],
    view_shape: list[int],
    permute_dims: list[int],
    expected: tuple[list[int], list[int]],
) -> None:
    """The helper returns the expected arguments and they are numerically valid."""
    swapped_args = _view_permute_swap(
        _meta_tensor(input_shape), view_shape, permute_dims
    )

    # Fail with a clear message (and narrow the Optional type) before the
    # result is unpacked by the numeric check below.
    assert swapped_args is not None
    assert swapped_args == expected
    _assert_swap_matches_tensor_behavior(
        input_shape, view_shape, permute_dims, swapped_args
    )


@pytest.mark.parametrize(
    "input_shape, view_shape, permute_dims",
    [
        ([2, 3, 4], [2, 4, 3], [0, 2, 1]),
        ([4], [2, 2], [1, 0]),
    ],
)
def test_view_permute_swap_rejects_unsupported_rewrites(
    input_shape: list[int], view_shape: list[int], permute_dims: list[int]
) -> None:
    """The helper returns ``None`` for these view/permute combinations."""
    assert (
        _view_permute_swap(_meta_tensor(input_shape), view_shape, permute_dims) is None
    )


def test_view_permute_swap_generated_cases_are_semantically_valid() -> None:
    """Exhaustively check every accepted rewrite over a family of small shapes.

    For each input shape, every factorization of its element count up to
    rank 4 is used as the view shape and combined with every permutation of
    that view's dimensions. Accepted rewrites must be numerically identical;
    the final counters guard against the sweep degenerating into all-accept
    or all-reject.
    """
    input_shapes = [
        [2, 3],
        [2, 4],
        [2, 3, 4],
        [2, 2, 3],
        [1, 2, 3],
        [2, 1, 3],
        [2, 3, 1],
        [2, 2, 2, 3],
    ]

    total_cases = 0
    accepted_cases = 0
    rejected_cases = 0

    for input_shape in input_shapes:
        view_shapes = _factorizations(_numel(input_shape), max_rank=4)
        for view_shape in view_shapes:
            for permute_dims in permutations(range(len(view_shape))):
                total_cases += 1
                swapped_args = _view_permute_swap(
                    _meta_tensor(input_shape), view_shape, list(permute_dims)
                )
                if swapped_args is None:
                    rejected_cases += 1
                    continue

                accepted_cases += 1
                _assert_swap_matches_tensor_behavior(
                    input_shape, view_shape, list(permute_dims), swapped_args
                )

    assert total_cases > 1000
    assert accepted_cases > 200
    assert rejected_cases > 200


def _make_symint(
    shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64
) -> torch.SymInt:
    """Create a ``SymInt`` named ``symbol`` in ``shape_env`` with a bounded range.

    ``min`` and ``max`` shadow the builtins inside this function only; they
    are forwarded as ``compiler_min`` / ``compiler_max``.
    """
    symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
    assert isinstance(symint, torch.SymInt)
    shape_env.constrain_symbol_range(
        symint.node.expr, compiler_min=min, compiler_max=max
    )
    return symint


def test_view_permute_swap_preserves_symbolic_dimensions() -> None:
    """Symbolic dimensions are carried through to the swapped arguments."""
    shape_env = ShapeEnv()
    batch = _make_symint(shape_env, "batch", hint=2)

    with FakeTensorMode(shape_env=shape_env, allow_non_fake_inputs=True):
        x = torch.empty((batch, 2, 3, 4), device="cpu")
        assert _view_permute_swap(x, [batch, 2, 12], [0, 2, 1]) == (
            [0, 2, 3, 1],
            [batch, 12, 2],
        )
        assert _view_permute_swap(x, [batch, 2, 3, 4], [2, 3, 0, 1]) == (
            [2, 3, 0, 1],
            [3, 4, batch, 2],
        )

        x = torch.empty((batch, 2, 10, 10), device="cpu")
        assert _view_permute_swap(x, [batch, 2, 5, 2, 5, 2], [0, 1, 2, 3, 4, 5]) == (
            [0, 1, 2, 3],
            [batch, 2, 5, 2, 5, 2],
        )