From a5d0f985161a0b41663049daebb00539d266befb Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Thu, 11 Jun 2026 11:41:26 +0100 Subject: [PATCH 1/2] Arm backend: Integrate exir to Tosa pass Adds pass to arm pass manager to integrate exir_to_dialect_pass. Also adds TOSA dialect activation node visitors. Signed-off-by: Saoirse Stewart --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 + .../aten_to_tosa_activation_functions.py | 130 ++++++++++++++++++ backends/arm/_passes/exir_to_tosa_pass.py | 41 ++++++ backends/arm/operators/__init__.py | 8 +- .../{op_clamp.py => op_tosa_clamp.py} | 56 ++------ .../operators/{op_erf.py => op_tosa_erf.py} | 2 +- .../{op_sigmoid.py => op_tosa_sigmoid.py} | 2 +- .../operators/{op_tanh.py => op_tosa_tanh.py} | 2 +- backends/arm/tosa/dialect/ops/activation.py | 6 +- 10 files changed, 193 insertions(+), 57 deletions(-) create mode 100644 backends/arm/_passes/aten_to_tosa_activation_functions.py create mode 100644 backends/arm/_passes/exir_to_tosa_pass.py rename backends/arm/operators/{op_clamp.py => op_tosa_clamp.py} (54%) rename backends/arm/operators/{op_erf.py => op_tosa_erf.py} (96%) rename backends/arm/operators/{op_sigmoid.py => op_tosa_sigmoid.py} (96%) rename backends/arm/operators/{op_tanh.py => op_tosa_tanh.py} (96%) diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index ea4d49a79bb..e265147e125 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -104,6 +104,7 @@ from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa from .deduplicate_get_attr_pass import DeduplicateGetAttrPass # noqa from .ensure_unique_output_nodes_pass import EnsureUniqueOutputNodesPass # noqa +from .exir_to_tosa_pass import ExirToTosaPass # noqa from .fold_qdq_with_annotated_qparams_pass import ( # noqa FoldAndAnnotateQParamsPass, QuantizeClampArgumentsPass, diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 700b58f6c85..6d5cf0d3847 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -104,6 +104,7 @@ DecorateFp32toInt32CastingPass, DeduplicateGetAttrPass, EnsureUniqueOutputNodesPass, + ExirToTosaPass, FoldAndAnnotateQParamsPass, FuseBatchNorm2dPass, FuseConsecutiveConcatShapesPass, @@ -622,6 +623,7 @@ def _tosa_pipeline( DecomposePermuteForU55Pass(), RewriteSlicePass(), InsertConstShapesPass(), + ExirToTosaPass(exported_program), ] ) diff --git a/backends/arm/_passes/aten_to_tosa_activation_functions.py b/backends/arm/_passes/aten_to_tosa_activation_functions.py new file mode 100644 index 00000000000..9b92b31e630 --- /dev/null +++ b/backends/arm/_passes/aten_to_tosa_activation_functions.py @@ -0,0 +1,130 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.tosa.specification import get_context_spec +from executorch.backends.transforms.aten_to_dialect_pass import ( + AtenToDialectPass, + DialectNodeSpec, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch.fx import Node + + +# Each rewrite returns the TOSA dialect node spec for one supported ATen +# activation op, preserving args unless TOSA requires normalized attributes. +def rewrite_erf(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec: + return DialectNodeSpec( + exir_ops.backend.tosa.ERF.default, + node.args, + dict(node.kwargs), + ) + + +def rewrite_sigmoid(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec: + return DialectNodeSpec( + exir_ops.backend.tosa.SIGMOID.default, + node.args, + dict(node.kwargs), + ) + + +def rewrite_tanh(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec: + return DialectNodeSpec( + exir_ops.backend.tosa.TANH.default, + node.args, + dict(node.kwargs), + ) + + +def _extract_dtype(node: Node) -> torch.dtype | None: + value = node.meta.get("val") + if isinstance(value, tuple): + value = value[0] + if isinstance(value, list): + if not value: + return None + value = value[0] + return getattr(value, "dtype", None) + + +def _dtype_bounds(dtype: torch.dtype) -> tuple[int | float, int | float]: + if dtype.is_floating_point: + fp_info = torch.finfo(dtype) + return fp_info.min, fp_info.max + + int_info = torch.iinfo(dtype) + return int_info.min, int_info.max + + +def _is_tosa_clamp_dtype_supported(dtype: torch.dtype) -> bool: + tosa_spec = get_context_spec() + + if dtype == torch.int8: + return tosa_spec.support_integer() + + if dtype == torch.int16: + return tosa_spec.support_integer() and tosa_spec.support_extension("int16") + + if dtype in (torch.float16, torch.float32): + return tosa_spec.support_float() + + if dtype == torch.bfloat16: + return tosa_spec.support_float() and tosa_spec.support_extension("bf16") + + return False + + +def _normalize_clamp_bound( + bound, + *, + dtype: torch.dtype, + default: int | float, +) -> int | float | None: + if bound is None: + return default + if isinstance(bound, bool): + return None + if dtype.is_floating_point: + if isinstance(bound, (int, float)): + return float(bound) + return None + if isinstance(bound, int): + return bound + return None + + +def _get_min_max_arguments( + node: Node, dtype: torch.dtype +) -> tuple[int | float, int | float] | None: + dtype_min, dtype_max = _dtype_bounds(dtype) + min_val = _normalize_clamp_bound( + node.args[1] if len(node.args) > 1 else node.kwargs.get("min"), + dtype=dtype, + default=dtype_min, + ) + max_val = _normalize_clamp_bound( + node.args[2] if len(node.args) > 2 else node.kwargs.get("max"), + dtype=dtype, + default=dtype_max, + ) + if min_val is None or max_val is None: + return None + return min_val, max_val + + +def rewrite_clamp(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec | None: + dtype = _extract_dtype(node) + if dtype is None or not _is_tosa_clamp_dtype_supported(dtype): + return None + + min_max_args = _get_min_max_arguments(node, dtype) + if min_max_args is None: + return None + + return DialectNodeSpec( + exir_ops.backend.tosa.CLAMP.default, + (node.args[0], *min_max_args), + ) diff --git a/backends/arm/_passes/exir_to_tosa_pass.py b/backends/arm/_passes/exir_to_tosa_pass.py new file mode 100644 index 00000000000..b77171b9eaf --- /dev/null +++ b/backends/arm/_passes/exir_to_tosa_pass.py @@ -0,0 +1,41 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 +from executorch.backends.arm._passes.aten_to_tosa_activation_functions import ( + rewrite_clamp, + rewrite_erf, + rewrite_sigmoid, + rewrite_tanh, +) +from executorch.backends.transforms.aten_to_dialect_pass import AtenToDialectPass +from executorch.exir.dialects._ops import ops as exir_ops + + +class ExirToTosaPass(AtenToDialectPass): + """Rewrite simple EXIR ops to equivalent backend TOSA dialect ops. + + Rewrite functions are grouped by op category and registered with the shared + ATen-to-dialect pass infrastructure. + + """ + + +_ACTIVATION_FUNCTION_REWRITES = { + exir_ops.edge.aten.clamp.default: rewrite_clamp, + exir_ops.edge.aten.erf.default: rewrite_erf, + exir_ops.edge.aten.sigmoid.default: rewrite_sigmoid, + exir_ops.edge.aten.tanh.default: rewrite_tanh, +} + +_DIRECT_REWRITE_CATEGORIES = { + "activation_functions": _ACTIVATION_FUNCTION_REWRITES, +} + +# Register each category's ATen targets with the function that builds the +# corresponding TOSA dialect node spec. +for _rewrite_category in _DIRECT_REWRITE_CATEGORIES.values(): + for _edge_target, _rewrite_fn in _rewrite_category.items(): + ExirToTosaPass.register_dialect_substitution(_edge_target)(_rewrite_fn) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 9436bfe2ab3..d2c2846b68c 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -19,11 +19,9 @@ op_bitwise_not, op_cat, op_ceil, - op_clamp, op_cond_if, op_cos, op_eq, - op_erf, op_exp, op_floor, op_ge, @@ -40,18 +38,18 @@ op_repeat, op_rshift_tensor, op_rsqrt, - op_sigmoid, op_sin, op_sub, op_sum, - op_tanh, op_to_dim_order_copy, op_tosa_avg_pool2d, op_tosa_avg_pool2d_adaptive, + op_tosa_clamp, op_tosa_conv2d, op_tosa_conv3d, op_tosa_custom, op_tosa_depthwise_conv2d, + op_tosa_erf, op_tosa_gather, op_tosa_identity, op_tosa_matmul, @@ -62,8 +60,10 @@ op_tosa_resize, op_tosa_scatter, op_tosa_shapes, + op_tosa_sigmoid, op_tosa_slice, op_tosa_table, + op_tosa_tanh, op_tosa_transpose_conv2d, op_view, op_where, diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_tosa_clamp.py similarity index 54% rename from backends/arm/operators/op_clamp.py rename to backends/arm/operators/op_tosa_clamp.py index 5792e6647cf..9ca53c5b15a 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_tosa_clamp.py @@ -3,8 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree - -from typing import Any, List, Tuple +from typing import Any, cast, List import torch import tosa_serializer as ts @@ -18,49 +17,19 @@ validate_same_dtype, validate_valid_dtype, ) - from executorch.backends.arm.tosa.mapping import TosaArg from torch.fx import Node @register_node_visitor class ClampVisitor(NodeVisitor): - target = "aten.clamp.default" - - def __init__(self, *args): - super().__init__(*args) - - def _get_min_max_arguments( - self, node: Node, dtype: torch.dtype - ) -> Tuple[int | float, int | float]: - def cast_type(value: Any) -> int | float: - if isinstance(value, int): - return value - else: - # Attempt to cast to float - return float(value) - - if dtype.is_floating_point: - dtype_min = torch.finfo(dtype).min - dtype_max = torch.finfo(dtype).max - else: - dtype_min = torch.iinfo(dtype).min - dtype_max = torch.iinfo(dtype).max + target = "tosa.CLAMP.default" - min_arg = dtype_min - max_arg = dtype_max - - if node.args[1] is not None: - min_arg = cast_type(node.args[1]) - - if len(node.args) > 2: - if node.args[2] is not None: - max_arg = cast_type(node.args[2]) - - return min_arg, max_arg - - def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes: - return torch.full((1,), value, dtype=dtype).view(torch.uint8).numpy().tolist() + def _to_bytes(self, value: int | float, dtype: torch.dtype) -> List[int]: + return cast( + List[int], + torch.full((1,), value, dtype=dtype).view(torch.uint8).numpy().tolist(), + ) def define_node( self, @@ -71,12 +40,7 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, [2, 3]) validate_same_dtype(self.target, [inputs[0], output], ts) - supported_dtypes = [ - ts.DType.INT8, - ts.DType.FP16, - ts.DType.BF16, - ts.DType.FP32, - ] + supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.BF16, ts.DType.FP32] if self.tosa_spec.support_extension("int16"): supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( @@ -87,8 +51,8 @@ def define_node( ) node_input_dtype = node.meta["val"].dtype - # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments - min_val, max_val = self._get_min_max_arguments(node, node_input_dtype) + min_val = cast(int | float, node.args[1]) + max_val = cast(int | float, node.args[2]) attr = ts.TosaSerializerAttribute() attr.ClampAttribute( diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_tosa_erf.py similarity index 96% rename from backends/arm/operators/op_erf.py rename to backends/arm/operators/op_tosa_erf.py index 8ccb54d1178..030d81d6c92 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_tosa_erf.py @@ -17,7 +17,7 @@ @register_node_visitor class ErfVisitor(SimpleNodeVisitor): - target = "aten.erf.default" + target = "tosa.ERF.default" tosa_specs = FP_SPECS @classmethod diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_tosa_sigmoid.py similarity index 96% rename from backends/arm/operators/op_sigmoid.py rename to backends/arm/operators/op_tosa_sigmoid.py index a59210276c5..211bbad49ad 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_tosa_sigmoid.py @@ -17,7 +17,7 @@ @register_node_visitor class SigmoidVisitor(SimpleNodeVisitor): - target = "aten.sigmoid.default" + target = "tosa.SIGMOID.default" tosa_specs = FP_SPECS @classmethod diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tosa_tanh.py similarity index 96% rename from backends/arm/operators/op_tanh.py rename to backends/arm/operators/op_tosa_tanh.py index 35a79bf75a4..80711b7ab62 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tosa_tanh.py @@ -17,7 +17,7 @@ @register_node_visitor class TanhVisitor(SimpleNodeVisitor): - target = "aten.tanh.default" + target = "tosa.TANH.default" tosa_specs = FP_SPECS @classmethod diff --git a/backends/arm/tosa/dialect/ops/activation.py b/backends/arm/tosa/dialect/ops/activation.py index 333ab0e52d4..3c3fbffe176 100644 --- a/backends/arm/tosa/dialect/ops/activation.py +++ b/backends/arm/tosa/dialect/ops/activation.py @@ -36,10 +36,8 @@ def _validate_clamp_dtype(dtype: torch.dtype, op: str) -> None: ) return - _validate_float_dtype(dtype, op) - return - - raise TosaValueError(f"Unsupported dtype {dtype} for {op}", op=op) + _validate_float_dtype(dtype, op) + return def _validate_float_dtype(dtype: torch.dtype, op: str) -> None: From c225a252ecd585bee66541022c85bd243465fb13 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Mon, 15 Jun 2026 14:11:37 +0100 Subject: [PATCH 2/2] Arm backends: Add aten_to_dialect_pass to TARGETS --- backends/arm/_passes/TARGETS | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS index f029c6b79ab..8b7340fc56b 100644 --- a/backends/arm/_passes/TARGETS +++ b/backends/arm/_passes/TARGETS @@ -51,6 +51,7 @@ runtime.python_library( deps = [ ":core", ":arm_pass_manager_base" if runtime.is_oss else ":arm_pass_manager_fb", + "//executorch/backends/transforms:aten_to_dialect_pass", "//executorch/backends/arm/tosa:utils", "//executorch/backends/arm/tosa/dialect:lib", "//executorch/backends/transforms:fuse_view_copy",