From 806cef661bbc58fe76d55ccc34889848148f6676 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 5 May 2026 13:25:04 +0100 Subject: [PATCH] Arm backend: Add TOSA dialect binary elementwise ops Signed-off-by: Saoirse Stewart --- .../test_tosa_dialect_binary_ops.py | 439 ++++++++++++++++++ backends/arm/tosa/dialect/__init__.py | 1 + backends/arm/tosa/dialect/ops/_common.py | 20 + .../tosa/dialect/ops/binary_elementwise.py | 336 ++++++++++++++ 4 files changed, 796 insertions(+) create mode 100644 backends/arm/test/misc/tosa_dialect/test_tosa_dialect_binary_ops.py create mode 100644 backends/arm/tosa/dialect/ops/binary_elementwise.py diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_binary_ops.py b/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_binary_ops.py new file mode 100644 index 00000000000..f886e29c834 --- /dev/null +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_binary_ops.py @@ -0,0 +1,439 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 +import pytest +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import ( + get_registered_tosa_ops, +) +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def _to_fake(mode: FakeTensorMode, *values): + return [ + mode.from_tensor(value) if isinstance(value, torch.Tensor) else value + for value in values + ] + + +@pytest.mark.parametrize( + ( + "op_name", + "spec", + "input1", + "input2", + "kwargs", + "expected_shape", + "expected_dtype", + ), + [ + pytest.param( + "ADD", + "TOSA-1.1+FP", + torch.randn((2, 1, 3), dtype=torch.float32), + torch.randn((1, 4, 3), dtype=torch.float32), + {}, + (2, 4, 3), + torch.float32, + ), + pytest.param( + "ARITHMETIC_RIGHT_SHIFT", + "TOSA-1.1+INT", + torch.randint(-8, 8, (2, 3), dtype=torch.int8), + torch.ones((2, 3), dtype=torch.int8), + {"round": True}, + (2, 3), + torch.int8, + ), + pytest.param( + "BITWISE_AND", + "TOSA-1.1+INT", + torch.randint(-8, 8, (2, 3), dtype=torch.int8), + torch.randint(-8, 8, (2, 3), dtype=torch.int8), + {}, + (2, 3), + torch.int8, + ), + pytest.param( + "BITWISE_OR", + "TOSA-1.1+INT", + torch.randint(-8, 8, (2, 3), dtype=torch.int8), + torch.randint(-8, 8, (2, 3), dtype=torch.int8), + {}, + (2, 3), + torch.int8, + ), + pytest.param( + "BITWISE_XOR", + "TOSA-1.1+INT", + torch.randint(-8, 8, (2, 3), dtype=torch.int8), + torch.randint(-8, 8, (2, 3), dtype=torch.int8), + {}, + (2, 3), + torch.int8, + ), + pytest.param( + "EQUAL", + "TOSA-1.1+INT", + torch.randint(1, 16, (2, 1, 3), dtype=torch.int32), + torch.randint(1, 8, (1, 4, 3), dtype=torch.int32), + {}, + (2, 4, 3), + torch.bool, + ), + pytest.param( + "GREATER", + "TOSA-1.1+FP", + torch.randn((2, 1, 3), dtype=torch.float32), + torch.randn((1, 4, 3), dtype=torch.float32), + {}, + (2, 4, 3), + torch.bool, + ), + pytest.param( + "GREATER_EQUAL", + "TOSA-1.1+INT", + torch.randint(1, 16, (2, 1, 3), dtype=torch.int32), + torch.randint(1, 8, (1, 4, 3), dtype=torch.int32), + {}, + (2, 4, 3), + torch.bool, + ), + pytest.param( + "INTDIV", + "TOSA-1.1+INT", + torch.randint(1, 16, (2, 3), dtype=torch.int32), + torch.randint(1, 8, (2, 3), dtype=torch.int32), + {}, + (2, 3), + torch.int32, + ), + pytest.param( + "LOGICAL_AND", + "TOSA-1.1+FP", + torch.tensor([[True, False], [True, True]], dtype=torch.bool), + torch.tensor([[True, True], [False, True]], dtype=torch.bool), + {}, + (2, 2), + torch.bool, + ), + pytest.param( + "LOGICAL_LEFT_SHIFT", + "TOSA-1.1+INT", + torch.randint(0, 8, (2, 3), dtype=torch.int8), + torch.ones((2, 3), dtype=torch.int8), + {}, + (2, 3), + torch.int8, + ), + pytest.param( + "LOGICAL_RIGHT_SHIFT", + "TOSA-1.1+INT", + torch.randint(0, 8, (2, 3), dtype=torch.int8), + torch.ones((2, 3), dtype=torch.int8), + {}, + (2, 3), + torch.int8, + ), + pytest.param( + "LOGICAL_OR", + "TOSA-1.1+FP", + torch.tensor([[True, False], [True, True]], dtype=torch.bool), + torch.tensor([[True, True], [False, True]], dtype=torch.bool), + {}, + (2, 2), + torch.bool, + ), + pytest.param( + "LOGICAL_XOR", + "TOSA-1.1+FP", + torch.tensor([[True, False], [True, True]], dtype=torch.bool), + torch.tensor([[True, True], [False, True]], dtype=torch.bool), + {}, + (2, 2), + torch.bool, + ), + pytest.param( + "MAXIMUM", + "TOSA-1.1+FP", + torch.randn((2, 1, 3), dtype=torch.float32), + torch.randn((1, 4, 3), dtype=torch.float32), + {}, + (2, 4, 3), + torch.float32, + ), + pytest.param( + "MINIMUM", + "TOSA-1.1+INT", + torch.randint(1, 16, (2, 1, 3), dtype=torch.int32), + torch.randint(1, 8, (1, 4, 3), dtype=torch.int32), + {}, + (2, 4, 3), + torch.int32, + ), + pytest.param( + "MUL", + "TOSA-1.1+INT", + torch.randint(-8, 8, (2, 3), dtype=torch.int8), + torch.randint(-8, 8, (2, 3), dtype=torch.int8), + {}, + (2, 3), + torch.int32, + ), + pytest.param( + "POW", + "TOSA-1.1+FP", + torch.randn((2, 3), dtype=torch.float32), + torch.randn((2, 3), dtype=torch.float32), + {}, + (2, 3), + torch.float32, + ), + pytest.param( + "SUB", + "TOSA-1.1+INT", + torch.randint(1, 16, (2, 1, 3), dtype=torch.int32), + torch.randint(1, 8, (1, 4, 3), dtype=torch.int32), + {}, + (2, 4, 3), + torch.int32, + ), + ], +) +def test_tosa_binary_ops( + op_name: str, + spec: str, + input1: torch.Tensor, + input2: torch.Tensor, + kwargs: dict[str, object], + expected_shape: tuple[int, ...], + expected_dtype: torch.dtype, +) -> None: + with TosaLoweringContext( + TosaSpecification.create_from_string(spec) + ), FakeTensorMode() as mode: + output = getattr(exir_ops.backend.tosa, op_name).default( + *_to_fake(mode, input1, input2), + **kwargs, + ) + + assert output.dtype == expected_dtype + assert tuple(output.shape) == expected_shape + + +@pytest.mark.parametrize("op_name", ["LOGICAL_LEFT_SHIFT", "LOGICAL_RIGHT_SHIFT"]) +@pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32]) +def test_logical_shift_supports_int_dtype_on_fp_profile( + op_name: str, + dtype: torch.dtype, +) -> None: + input1 = torch.randint(0, 8, (2, 3), dtype=dtype) + input2 = torch.ones((2, 3), dtype=dtype) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + output = getattr(exir_ops.backend.tosa, op_name).default( + *_to_fake(mode, input1, input2) + ) + + assert output.dtype == dtype + assert tuple(output.shape) == tuple(input1.shape) + + +@pytest.mark.parametrize( + "spec", + [ + pytest.param("TOSA-1.1+INT+int64", id="int_profile"), + pytest.param("TOSA-1.1+FP+int64", id="fp_profile"), + ], +) +def test_bitwise_and_supports_int64_extension(spec: str) -> None: + input1 = torch.randint(0, 8, (2, 3), dtype=torch.int64) + input2 = torch.ones((2, 3), dtype=torch.int64) + + with TosaLoweringContext( + TosaSpecification.create_from_string(spec) + ), FakeTensorMode() as mode: + output = exir_ops.backend.tosa.BITWISE_AND.default( + *_to_fake(mode, input1, input2) + ) + + assert output.dtype == torch.int64 + assert tuple(output.shape) == tuple(input1.shape) + + +def test_bitwise_and_rejects_int64_without_extension() -> None: + input1 = torch.randint(0, 8, (2, 3), dtype=torch.int64) + input2 = torch.ones((2, 3), dtype=torch.int64) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+INT") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="doesn't support int64"): + exir_ops.backend.tosa.BITWISE_AND.default(*_to_fake(mode, input1, input2)) + + +@pytest.mark.parametrize( + ("op", "spec", "expected"), + [ + pytest.param( + exir_ops.backend.tosa.ARITHMETIC_RIGHT_SHIFT.default, + "TOSA-1.1+INT", + True, + id="arithmetic_right_shift_int", + ), + pytest.param( + exir_ops.backend.tosa.ARITHMETIC_RIGHT_SHIFT.default, + "TOSA-1.1+FP", + False, + id="arithmetic_right_shift_fp", + ), + pytest.param( + exir_ops.backend.tosa.BITWISE_OR.default, + "TOSA-1.1+INT", + True, + id="bitwise_or_int", + ), + pytest.param( + exir_ops.backend.tosa.BITWISE_OR.default, + "TOSA-1.1+FP", + False, + id="bitwise_or_fp", + ), + pytest.param( + exir_ops.backend.tosa.BITWISE_XOR.default, + "TOSA-1.1+INT", + True, + id="bitwise_xor_int", + ), + pytest.param( + exir_ops.backend.tosa.BITWISE_XOR.default, + "TOSA-1.1+FP", + False, + id="bitwise_xor_fp", + ), + ], +) +def test_tosa_integer_shift_and_bitwise_ops_registered_for_int_profile_only( + op, + spec: str, + expected: bool, +) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string(spec)): + registered_ops = get_registered_tosa_ops() + + assert (op in registered_ops) is expected + + +@pytest.mark.parametrize("spec", ["TOSA-1.1+INT", "TOSA-1.1+FP"]) +def test_tosa_bitwise_and_registered_for_all_profiles(spec: str) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string(spec)): + registered_ops = get_registered_tosa_ops() + + assert exir_ops.backend.tosa.BITWISE_AND.default in registered_ops + + +@pytest.mark.parametrize( + ("spec", "expected"), + [ + pytest.param("TOSA-1.1+INT", False, id="pow_int"), + pytest.param("TOSA-1.1+FP", True, id="pow_fp"), + ], +) +def test_tosa_pow_registered_for_fp_profile_only(spec: str, expected: bool) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string(spec)): + registered_ops = get_registered_tosa_ops() + + assert (exir_ops.backend.tosa.POW.default in registered_ops) is expected + + +def test_pow_accepts_bfloat16_with_bf16_extension() -> None: + input1 = torch.randn((2, 3), dtype=torch.bfloat16) + input2 = torch.randn((2, 3), dtype=torch.bfloat16) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+bf16") + ), FakeTensorMode() as mode: + output = exir_ops.backend.tosa.POW.default(*_to_fake(mode, input1, input2)) + + assert output.dtype == torch.bfloat16 + assert tuple(output.shape) == tuple(input1.shape) + + +def test_mul_rejects_non_zero_shift_for_non_int32() -> None: + input1 = torch.randint(-8, 8, (2, 3), dtype=torch.int8) + input2 = torch.randint(-8, 8, (2, 3), dtype=torch.int8) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+INT") + ), FakeTensorMode() as mode: + with pytest.raises( + TosaValueError, + match="Only int32 MUL supports a non-zero shift value", + ): + exir_ops.backend.tosa.MUL.default( + *_to_fake(mode, input1, input2), + shift=3, + ) + + +def test_intdiv_supports_int32_on_fp_profile() -> None: + input1 = torch.randint(1, 16, (2, 3), dtype=torch.int32) + input2 = torch.randint(1, 8, (2, 3), dtype=torch.int32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + output = exir_ops.backend.tosa.INTDIV.default(*_to_fake(mode, input1, input2)) + + assert output.dtype == torch.int32 + assert tuple(output.shape) == tuple(input1.shape) + + +def test_equal_rejects_int8() -> None: + input1 = torch.randint(-8, 8, (2, 3), dtype=torch.int8) + input2 = torch.randint(-8, 8, (2, 3), dtype=torch.int8) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+INT") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="Unsupported dtype"): + exir_ops.backend.tosa.EQUAL.default(*_to_fake(mode, input1, input2)) + + +@pytest.mark.parametrize("op_name", ["EQUAL", "GREATER", "GREATER_EQUAL"]) +def test_compare_ops_reject_int32_on_fp_profile(op_name: str) -> None: + input1 = torch.randint(1, 16, (2, 3), dtype=torch.int32) + input2 = torch.randint(1, 8, (2, 3), dtype=torch.int32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="doesn't support int32"): + getattr(exir_ops.backend.tosa, op_name).default( + *_to_fake(mode, input1, input2) + ) + + +@pytest.mark.parametrize("op_name", ["MAXIMUM", "MINIMUM"]) +def test_extrema_ops_reject_int32_on_fp_profile(op_name: str) -> None: + input1 = torch.randint(1, 16, (2, 3), dtype=torch.int32) + input2 = torch.randint(1, 8, (2, 3), dtype=torch.int32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="doesn't support int32"): + getattr(exir_ops.backend.tosa, op_name).default( + *_to_fake(mode, input1, input2) + ) diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index 601a8ab41d1..53c88167f69 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -8,6 +8,7 @@ argmax, avg_pool2d, avg_pool2d_adaptive, + binary_elementwise, conv2d, conv3d, custom, diff --git a/backends/arm/tosa/dialect/ops/_common.py b/backends/arm/tosa/dialect/ops/_common.py index c05e1a9d173..daeef30b097 100644 --- a/backends/arm/tosa/dialect/ops/_common.py +++ b/backends/arm/tosa/dialect/ops/_common.py @@ -9,6 +9,26 @@ _VALID_NAN_MODES = {"PROPAGATE", "IGNORE"} +def broadcast_shape( + input1: torch.Tensor, input2: torch.Tensor, op: str +) -> tuple[int | torch.SymInt, ...]: + try: + return tuple(torch.broadcast_shapes(input1.shape, input2.shape)) + except (RuntimeError, ValueError) as err: + raise TosaValueError( + f"Failed to broadcast shapes {tuple(input1.shape)} and {tuple(input2.shape)}", + op=op, + ) from err + + +def require_same_dtype(input1: torch.Tensor, input2: torch.Tensor, op: str) -> None: + if input1.dtype != input2.dtype: + raise TosaValueError( + f"Expected matching dtypes but got {input1.dtype} and {input2.dtype}", + op=op, + ) + + def validate_nan_mode(nan_mode: str, op: str) -> None: if nan_mode not in _VALID_NAN_MODES: raise TosaValueError( diff --git a/backends/arm/tosa/dialect/ops/binary_elementwise.py b/backends/arm/tosa/dialect/ops/binary_elementwise.py new file mode 100644 index 00000000000..0b62cc49867 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/binary_elementwise.py @@ -0,0 +1,336 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops._common import ( + broadcast_shape, + require_same_dtype, + validate_nan_mode, +) +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) + +FP_SPECS = TosaSpecification.all_versions_for_profile("FP") +INT_SPECS = TosaSpecification.all_versions_for_profile("INT") +INT_DTYPES = (torch.int8, torch.int16, torch.int32) +FP_DTYPES = (torch.float16, torch.float32) + + +def _dtype_name(dtype: torch.dtype) -> str: + return str(dtype).removeprefix("torch.") + + +def _raise_unsupported_dtype(dtype: torch.dtype, op: str) -> None: + raise TosaValueError(f"Unsupported dtype {dtype} for {op}", op=op) + + +def _raise_unsupported_profile(dtype: torch.dtype, op: str) -> None: + raise TosaValueError( + f"TOSA spec {get_context_spec()} doesn't support {_dtype_name(dtype)} for {op}", + op=op, + ) + + +def _binary_meta( + input1: torch.Tensor, + input2: torch.Tensor, + op: str, + *, + output_dtype: torch.dtype | None = None, +) -> torch.Tensor: + require_same_dtype(input1, input2, op) + output_shape = broadcast_shape(input1, input2, op) + return torch.empty(output_shape, dtype=output_dtype or input1.dtype) + + +def _require_int_profile_support(dtype: torch.dtype, op: str) -> None: + if not get_context_spec().support_integer(): + _raise_unsupported_profile(dtype, op) + + +def _validate_fp_dtype(dtype: torch.dtype, op: str) -> None: + tosa_spec = get_context_spec() + + if dtype in FP_DTYPES: + if not tosa_spec.support_float(): + _raise_unsupported_profile(dtype, op) + return + + if dtype == torch.bfloat16: + if not (tosa_spec.support_float() and tosa_spec.support_extension("bf16")): + _raise_unsupported_profile(dtype, op) + return + + _raise_unsupported_dtype(dtype, op) + + +def _validate_int_dtype(dtype: torch.dtype, op: str) -> None: + if dtype in INT_DTYPES: + _require_int_profile_support(dtype, op) + return + + _raise_unsupported_dtype(dtype, op) + + +def _validate_any_profile_int_dtype(dtype: torch.dtype, op: str) -> None: + if dtype not in INT_DTYPES: + _raise_unsupported_dtype(dtype, op) + + +def _validate_bitwise_and_dtype(dtype: torch.dtype) -> None: + if dtype in INT_DTYPES: + _require_int_profile_support(dtype, "BITWISE_AND") + return + + if dtype == torch.int64: + if not get_context_spec().support_extension("int64"): + _raise_unsupported_profile(dtype, "BITWISE_AND") + return + + _raise_unsupported_dtype(dtype, "BITWISE_AND") + + +def _validate_add_sub_dtype(dtype: torch.dtype, op: str) -> None: + if dtype == torch.int32: + return + + _validate_fp_dtype(dtype, op) + + +def _validate_profile_int32_or_fp_dtype(dtype: torch.dtype, op: str) -> None: + if dtype == torch.int32: + _require_int_profile_support(dtype, op) + return + + _validate_fp_dtype(dtype, op) + + +def _validate_bool_dtype(dtype: torch.dtype, op: str) -> None: + if dtype != torch.bool: + _raise_unsupported_dtype(dtype, op) + + +def _validate_int32_dtype(dtype: torch.dtype, op: str) -> None: + if dtype != torch.int32: + _raise_unsupported_dtype(dtype, op) + + +def _validate_and_infer_mul_output_dtype(dtype: torch.dtype) -> torch.dtype: # type: ignore[return] + if dtype in FP_DTYPES or dtype == torch.bfloat16: + _validate_fp_dtype(dtype, "MUL") + return dtype + + if dtype in INT_DTYPES: + if dtype != torch.int32: + _require_int_profile_support(dtype, "MUL") + return torch.int32 + + _raise_unsupported_dtype(dtype, "MUL") + + +@register_fake_tosa_op( + "ADD(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def ADD(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_add_sub_dtype(input1.dtype, "ADD") + return _binary_meta(input1, input2, "ADD") + + +@register_fake_tosa_op( + "ARITHMETIC_RIGHT_SHIFT(Tensor input1, Tensor input2, *, bool round=False) -> Tensor", + INT_SPECS, +) +def ARITHMETIC_RIGHT_SHIFT( + input1: torch.Tensor, + input2: torch.Tensor, + *, + round: bool = False, +) -> torch.Tensor: + _validate_int_dtype(input1.dtype, "ARITHMETIC_RIGHT_SHIFT") + return _binary_meta(input1, input2, "ARITHMETIC_RIGHT_SHIFT") + + +@register_fake_tosa_op( + "BITWISE_AND(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def BITWISE_AND(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_bitwise_and_dtype(input1.dtype) + return _binary_meta(input1, input2, "BITWISE_AND") + + +@register_fake_tosa_op( + "BITWISE_OR(Tensor input1, Tensor input2) -> Tensor", + INT_SPECS, +) +def BITWISE_OR(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_int_dtype(input1.dtype, "BITWISE_OR") + return _binary_meta(input1, input2, "BITWISE_OR") + + +@register_fake_tosa_op( + "BITWISE_XOR(Tensor input1, Tensor input2) -> Tensor", + INT_SPECS, +) +def BITWISE_XOR(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_int_dtype(input1.dtype, "BITWISE_XOR") + return _binary_meta(input1, input2, "BITWISE_XOR") + + +@register_fake_tosa_op( + "EQUAL(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def EQUAL(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_profile_int32_or_fp_dtype(input1.dtype, "EQUAL") + return _binary_meta(input1, input2, "EQUAL", output_dtype=torch.bool) + + +@register_fake_tosa_op( + "GREATER(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def GREATER(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_profile_int32_or_fp_dtype(input1.dtype, "GREATER") + return _binary_meta(input1, input2, "GREATER", output_dtype=torch.bool) + + +@register_fake_tosa_op( + "GREATER_EQUAL(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def GREATER_EQUAL(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_profile_int32_or_fp_dtype(input1.dtype, "GREATER_EQUAL") + return _binary_meta(input1, input2, "GREATER_EQUAL", output_dtype=torch.bool) + + +@register_fake_tosa_op( + "INTDIV(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def INTDIV(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_int32_dtype(input1.dtype, "INTDIV") + return _binary_meta(input1, input2, "INTDIV") + + +@register_fake_tosa_op( + "LOGICAL_AND(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def LOGICAL_AND(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_bool_dtype(input1.dtype, "LOGICAL_AND") + return _binary_meta(input1, input2, "LOGICAL_AND") + + +@register_fake_tosa_op( + "LOGICAL_LEFT_SHIFT(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def LOGICAL_LEFT_SHIFT(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_any_profile_int_dtype(input1.dtype, "LOGICAL_LEFT_SHIFT") + return _binary_meta(input1, input2, "LOGICAL_LEFT_SHIFT") + + +@register_fake_tosa_op( + "LOGICAL_RIGHT_SHIFT(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def LOGICAL_RIGHT_SHIFT(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_any_profile_int_dtype(input1.dtype, "LOGICAL_RIGHT_SHIFT") + return _binary_meta(input1, input2, "LOGICAL_RIGHT_SHIFT") + + +@register_fake_tosa_op( + "LOGICAL_OR(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def LOGICAL_OR(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_bool_dtype(input1.dtype, "LOGICAL_OR") + return _binary_meta(input1, input2, "LOGICAL_OR") + + +@register_fake_tosa_op( + "LOGICAL_XOR(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def LOGICAL_XOR(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_bool_dtype(input1.dtype, "LOGICAL_XOR") + return _binary_meta(input1, input2, "LOGICAL_XOR") + + +@register_fake_tosa_op( + "MAXIMUM(Tensor input1, Tensor input2, *, str nan_mode='PROPAGATE') -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def MAXIMUM( + input1: torch.Tensor, + input2: torch.Tensor, + *, + nan_mode: str = "PROPAGATE", +) -> torch.Tensor: + validate_nan_mode(nan_mode, "MAXIMUM") + _validate_profile_int32_or_fp_dtype(input1.dtype, "MAXIMUM") + return _binary_meta(input1, input2, "MAXIMUM") + + +@register_fake_tosa_op( + "MINIMUM(Tensor input1, Tensor input2, *, str nan_mode='PROPAGATE') -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def MINIMUM( + input1: torch.Tensor, + input2: torch.Tensor, + *, + nan_mode: str = "PROPAGATE", +) -> torch.Tensor: + validate_nan_mode(nan_mode, "MINIMUM") + _validate_profile_int32_or_fp_dtype(input1.dtype, "MINIMUM") + return _binary_meta(input1, input2, "MINIMUM") + + +@register_fake_tosa_op( + "MUL(Tensor input1, Tensor input2, *, int shift=0) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def MUL( + input1: torch.Tensor, + input2: torch.Tensor, + *, + shift: int = 0, +) -> torch.Tensor: + output_dtype = _validate_and_infer_mul_output_dtype(input1.dtype) + + if shift < 0 or shift > 63: + raise TosaValueError("shift must be in the range [0, 63]", op="MUL") + if input1.dtype != torch.int32 and shift != 0: + raise TosaValueError( + "Only int32 MUL supports a non-zero shift value", + op="MUL", + ) + + return _binary_meta(input1, input2, "MUL", output_dtype=output_dtype) + + +@register_fake_tosa_op( + "POW(Tensor input1, Tensor input2) -> Tensor", + FP_SPECS, +) +def POW(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_fp_dtype(input1.dtype, "POW") + return _binary_meta(input1, input2, "POW") + + +@register_fake_tosa_op( + "SUB(Tensor input1, Tensor input2) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def SUB(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + _validate_add_sub_dtype(input1.dtype, "SUB") + return _binary_meta(input1, input2, "SUB")