Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ runtime.python_library(
deps = [
":core",
":arm_pass_manager_base" if runtime.is_oss else ":arm_pass_manager_fb",
"//executorch/backends/transforms:aten_to_dialect_pass",
"//executorch/backends/arm/tosa:utils",
"//executorch/backends/arm/tosa/dialect:lib",
"//executorch/backends/transforms:fuse_view_copy",
Expand Down
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
from .deduplicate_get_attr_pass import DeduplicateGetAttrPass # noqa
from .ensure_unique_output_nodes_pass import EnsureUniqueOutputNodesPass # noqa
from .exir_to_tosa_pass import ExirToTosaPass # noqa
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
FoldAndAnnotateQParamsPass,
QuantizeClampArgumentsPass,
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
DecorateFp32toInt32CastingPass,
DeduplicateGetAttrPass,
EnsureUniqueOutputNodesPass,
ExirToTosaPass,
FoldAndAnnotateQParamsPass,
FuseBatchNorm2dPass,
FuseConsecutiveConcatShapesPass,
Expand Down Expand Up @@ -622,6 +623,7 @@ def _tosa_pipeline(
DecomposePermuteForU55Pass(),
RewriteSlicePass(),
InsertConstShapesPass(),
ExirToTosaPass(exported_program),
]
)

Expand Down
130 changes: 130 additions & 0 deletions backends/arm/_passes/aten_to_tosa_activation_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.tosa.specification import get_context_spec
from executorch.backends.transforms.aten_to_dialect_pass import (
AtenToDialectPass,
DialectNodeSpec,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import Node


# Each rewrite returns the TOSA dialect node spec for one supported ATen
# activation op, preserving args unless TOSA requires normalized attributes.
def rewrite_erf(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
    """Build the TOSA ``ERF`` dialect node spec for ``node``.

    Args:
        node: FX node being rewritten. Its positional args are forwarded
            unchanged and its kwargs are forwarded as a shallow ``dict`` copy.
        pass_: The pass driving the rewrite; not used by this function.

    Returns:
        A ``DialectNodeSpec`` targeting ``exir_ops.backend.tosa.ERF.default``.
    """
    tosa_target = exir_ops.backend.tosa.ERF.default
    forwarded_args = node.args
    # Copy the kwargs so the spec does not share the node's mapping object.
    forwarded_kwargs = dict(node.kwargs)
    return DialectNodeSpec(tosa_target, forwarded_args, forwarded_kwargs)


def rewrite_sigmoid(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
    """Build the TOSA ``SIGMOID`` dialect node spec for ``node``.

    Args:
        node: FX node being rewritten. Its positional args are forwarded
            unchanged and its kwargs are forwarded as a shallow ``dict`` copy.
        pass_: The pass driving the rewrite; not used by this function.

    Returns:
        A ``DialectNodeSpec`` targeting
        ``exir_ops.backend.tosa.SIGMOID.default``.
    """
    tosa_target = exir_ops.backend.tosa.SIGMOID.default
    forwarded_args = node.args
    # Copy the kwargs so the spec does not share the node's mapping object.
    forwarded_kwargs = dict(node.kwargs)
    return DialectNodeSpec(tosa_target, forwarded_args, forwarded_kwargs)


def rewrite_tanh(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
    """Build the TOSA ``TANH`` dialect node spec for ``node``.

    Args:
        node: FX node being rewritten. Its positional args are forwarded
            unchanged and its kwargs are forwarded as a shallow ``dict`` copy.
        pass_: The pass driving the rewrite; not used by this function.

    Returns:
        A ``DialectNodeSpec`` targeting ``exir_ops.backend.tosa.TANH.default``.
    """
    tosa_target = exir_ops.backend.tosa.TANH.default
    forwarded_args = node.args
    # Copy the kwargs so the spec does not share the node's mapping object.
    forwarded_kwargs = dict(node.kwargs)
    return DialectNodeSpec(tosa_target, forwarded_args, forwarded_kwargs)


def _extract_dtype(node: Node) -> torch.dtype | None:
value = node.meta.get("val")
if isinstance(value, tuple):
value = value[0]
if isinstance(value, list):
if not value:
return None
value = value[0]
return getattr(value, "dtype", None)


def _dtype_bounds(dtype: torch.dtype) -> tuple[int | float, int | float]:
if dtype.is_floating_point:
fp_info = torch.finfo(dtype)
return fp_info.min, fp_info.max

int_info = torch.iinfo(dtype)
return int_info.min, int_info.max


def _is_tosa_clamp_dtype_supported(dtype: torch.dtype) -> bool:
    """Report whether the context TOSA spec supports ``dtype`` for clamp.

    The spec is fetched with ``get_context_spec()`` before the dtype is
    examined. Supported combinations:

    * ``int8``: integer support.
    * ``int16``: integer support plus the ``"int16"`` extension.
    * ``float16`` / ``float32``: float support.
    * ``bfloat16``: float support plus the ``"bf16"`` extension.

    Any other dtype is reported as unsupported.
    """
    spec = get_context_spec()

    # The branches are mutually exclusive, so their order does not matter.
    if dtype in (torch.float16, torch.float32):
        return spec.support_float()

    if dtype == torch.bfloat16:
        return spec.support_float() and spec.support_extension("bf16")

    if dtype == torch.int16:
        return spec.support_integer() and spec.support_extension("int16")

    if dtype == torch.int8:
        return spec.support_integer()

    return False


def _normalize_clamp_bound(
bound,
*,
dtype: torch.dtype,
default: int | float,
) -> int | float | None:
if bound is None:
return default
if isinstance(bound, bool):
return None
if dtype.is_floating_point:
if isinstance(bound, (int, float)):
return float(bound)
return None
if isinstance(bound, int):
return bound
return None


def _get_min_max_arguments(
    node: Node, dtype: torch.dtype
) -> tuple[int | float, int | float] | None:
    """Resolve the clamp ``(min, max)`` pair for ``node``.

    Each bound is read from the positional args when present (index 1 for
    min, index 2 for max) and otherwise from the ``"min"`` / ``"max"`` kwargs.
    A missing bound falls back to the corresponding limit of ``dtype``.

    Returns:
        The normalized ``(min, max)`` tuple, or ``None`` when either bound
        cannot be normalized by ``_normalize_clamp_bound``.
    """
    lowest, highest = _dtype_bounds(dtype)
    positional = node.args

    raw_min = positional[1] if len(positional) > 1 else node.kwargs.get("min")
    lower = _normalize_clamp_bound(raw_min, dtype=dtype, default=lowest)

    raw_max = positional[2] if len(positional) > 2 else node.kwargs.get("max")
    upper = _normalize_clamp_bound(raw_max, dtype=dtype, default=highest)

    if lower is None or upper is None:
        return None
    return lower, upper


def rewrite_clamp(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec | None:
    """Build the TOSA ``CLAMP`` dialect node spec for ``node`` when possible.

    Args:
        node: FX node being rewritten; ``node.args[0]`` is the clamped input.
        pass_: The pass driving the rewrite; not used by this function.

    Returns:
        A ``DialectNodeSpec`` targeting ``exir_ops.backend.tosa.CLAMP.default``
        with args ``(input, min, max)``, or ``None`` when the node's dtype is
        unknown or unsupported, or when either bound cannot be normalized.
        NOTE(review): ``None`` presumably tells the pass to leave the node
        untouched -- confirm against ``AtenToDialectPass``.
    """
    dtype = _extract_dtype(node)
    if dtype is not None and _is_tosa_clamp_dtype_supported(dtype):
        bounds = _get_min_max_arguments(node, dtype)
        if bounds is not None:
            lower, upper = bounds
            # min/max are passed as explicit positional args; kwargs are dropped.
            return DialectNodeSpec(
                exir_ops.backend.tosa.CLAMP.default,
                (node.args[0], lower, upper),
            )
    return None
41 changes: 41 additions & 0 deletions backends/arm/_passes/exir_to_tosa_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.arm.tosa.dialect # noqa: F401
from executorch.backends.arm._passes.aten_to_tosa_activation_functions import (
rewrite_clamp,
rewrite_erf,
rewrite_sigmoid,
rewrite_tanh,
)
from executorch.backends.transforms.aten_to_dialect_pass import AtenToDialectPass
from executorch.exir.dialects._ops import ops as exir_ops


class ExirToTosaPass(AtenToDialectPass):
    """Rewrite simple EXIR ops to equivalent backend TOSA dialect ops.

    The class body adds nothing of its own on top of ``AtenToDialectPass``.
    Rewrite functions are grouped by op category and registered with the
    shared ATen-to-dialect pass infrastructure at module level, through
    ``ExirToTosaPass.register_dialect_substitution``.

    """


# Edge-dialect activation targets mapped to the function that builds the
# matching TOSA dialect node spec.
_ACTIVATION_FUNCTION_REWRITES = {
    exir_ops.edge.aten.clamp.default: rewrite_clamp,
    exir_ops.edge.aten.erf.default: rewrite_erf,
    exir_ops.edge.aten.sigmoid.default: rewrite_sigmoid,
    exir_ops.edge.aten.tanh.default: rewrite_tanh,
}

# Rewrite tables keyed by op category. Only the values are read by the
# registration loop below; the keys act as labels.
_DIRECT_REWRITE_CATEGORIES = {
    "activation_functions": _ACTIVATION_FUNCTION_REWRITES,
}

# Register each category's ATen targets with the function that builds the
# corresponding TOSA dialect node spec.
# register_dialect_substitution(target) is called first and its result is
# then applied to the rewrite function, i.e. it is used like a decorator.
for _rewrite_category in _DIRECT_REWRITE_CATEGORIES.values():
    for _edge_target, _rewrite_fn in _rewrite_category.items():
        ExirToTosaPass.register_dialect_substitution(_edge_target)(_rewrite_fn)
8 changes: 4 additions & 4 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
op_bitwise_not,
op_cat,
op_ceil,
op_clamp,
op_cond_if,
op_cos,
op_eq,
op_erf,
op_exp,
op_floor,
op_ge,
Expand All @@ -40,18 +38,18 @@
op_repeat,
op_rshift_tensor,
op_rsqrt,
op_sigmoid,
op_sin,
op_sub,
op_sum,
op_tanh,
op_to_dim_order_copy,
op_tosa_avg_pool2d,
op_tosa_avg_pool2d_adaptive,
op_tosa_clamp,
op_tosa_conv2d,
op_tosa_conv3d,
op_tosa_custom,
op_tosa_depthwise_conv2d,
op_tosa_erf,
op_tosa_gather,
op_tosa_identity,
op_tosa_matmul,
Expand All @@ -62,8 +60,10 @@
op_tosa_resize,
op_tosa_scatter,
op_tosa_shapes,
op_tosa_sigmoid,
op_tosa_slice,
op_tosa_table,
op_tosa_tanh,
op_tosa_transpose_conv2d,
op_view,
op_where,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree


from typing import Any, List, Tuple
from typing import Any, cast, List

import torch
import tosa_serializer as ts
Expand All @@ -18,49 +17,19 @@
validate_same_dtype,
validate_valid_dtype,
)

from executorch.backends.arm.tosa.mapping import TosaArg
from torch.fx import Node


@register_node_visitor
class ClampVisitor(NodeVisitor):
target = "aten.clamp.default"

def __init__(self, *args):
super().__init__(*args)

def _get_min_max_arguments(
self, node: Node, dtype: torch.dtype
) -> Tuple[int | float, int | float]:
def cast_type(value: Any) -> int | float:
if isinstance(value, int):
return value
else:
# Attempt to cast to float
return float(value)

if dtype.is_floating_point:
dtype_min = torch.finfo(dtype).min
dtype_max = torch.finfo(dtype).max
else:
dtype_min = torch.iinfo(dtype).min
dtype_max = torch.iinfo(dtype).max
target = "tosa.CLAMP.default"

min_arg = dtype_min
max_arg = dtype_max

if node.args[1] is not None:
min_arg = cast_type(node.args[1])

if len(node.args) > 2:
if node.args[2] is not None:
max_arg = cast_type(node.args[2])

return min_arg, max_arg

def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes:
return torch.full((1,), value, dtype=dtype).view(torch.uint8).numpy().tolist()
def _to_bytes(self, value: int | float, dtype: torch.dtype) -> List[int]:
return cast(
List[int],
torch.full((1,), value, dtype=dtype).view(torch.uint8).numpy().tolist(),
)

def define_node(
self,
Expand All @@ -71,12 +40,7 @@ def define_node(
) -> None:
validate_num_inputs(self.target, inputs, [2, 3])
validate_same_dtype(self.target, [inputs[0], output], ts)
supported_dtypes = [
ts.DType.INT8,
ts.DType.FP16,
ts.DType.BF16,
ts.DType.FP32,
]
supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.BF16, ts.DType.FP32]
if self.tosa_spec.support_extension("int16"):
supported_dtypes.append(ts.DType.INT16)
validate_valid_dtype(
Expand All @@ -87,8 +51,8 @@ def define_node(
)

node_input_dtype = node.meta["val"].dtype
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
min_val, max_val = self._get_min_max_arguments(node, node_input_dtype)
min_val = cast(int | float, node.args[1])
max_val = cast(int | float, node.args[2])

attr = ts.TosaSerializerAttribute()
attr.ClampAttribute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@register_node_visitor
class ErfVisitor(SimpleNodeVisitor):
target = "aten.erf.default"
target = "tosa.ERF.default"
tosa_specs = FP_SPECS

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@register_node_visitor
class SigmoidVisitor(SimpleNodeVisitor):
target = "aten.sigmoid.default"
target = "tosa.SIGMOID.default"
tosa_specs = FP_SPECS

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@register_node_visitor
class TanhVisitor(SimpleNodeVisitor):
target = "aten.tanh.default"
target = "tosa.TANH.default"
tosa_specs = FP_SPECS

@classmethod
Expand Down
6 changes: 2 additions & 4 deletions backends/arm/tosa/dialect/ops/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@ def _validate_clamp_dtype(dtype: torch.dtype, op: str) -> None:
)
return

_validate_float_dtype(dtype, op)
return

raise TosaValueError(f"Unsupported dtype {dtype} for {op}", op=op)
_validate_float_dtype(dtype, op)
return


def _validate_float_dtype(dtype: torch.dtype, op: str) -> None:
Expand Down
Loading