diff --git a/backends/nxp/backend/edge_helper.py b/backends/nxp/backend/edge_helper.py index c4c4e984f2d..f9755f40e21 100644 --- a/backends/nxp/backend/edge_helper.py +++ b/backends/nxp/backend/edge_helper.py @@ -109,6 +109,43 @@ def node_is_effectively_static_tensor( ) +def weights_are_effectively_static( + node: Node, parameters_mapping: dict[str, Parameter], weight_index: int = 1 +) -> bool: + """Neutron IR sometimes requires some weights to be static. This method checks if this is the case for the + provided `node`. + + Sometimes a `permute_copy` is inserted to transpose the weights during edge lowering. The `permute_copy` is + then removed during conversion to Neutron IR if it transposes static data. In those cases, the weights will be + static. Therefore, it is ok if the weights are produced by a `permute_copy` with a static input. + + :param node: Operator node whose weight input (`node.args[weight_index]`) is checked for static data. + :param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the + `state_dict` attribute of an edge program. + :param weight_index: Index to the `node.args` where the weight is located. Defaults to 1. + :return: True if the weight at the given index is effectively static. + """ + + def _is_permute_copy(node_: Node) -> bool: + return hasattr(node_, "target") and node_.target == PermuteCopy + + if ( + _is_dequantize(dq_node := node.args[weight_index]) + and _is_quantize(q_node := dq_node.args[0]) + and _is_permute_copy(permute_copy_node := q_node.args[0]) + ): + # The weights are produced by a `permute_copy`. Its input (the weights) must be static. + return node_is_effectively_static_tensor( + permute_copy_node.args[0], parameters_mapping + ) + + else: + # There is no `permute_copy`. The weights must be static directly. 
+ return node_is_effectively_static_tensor( + node.args[weight_index], parameters_mapping + ) + + def try_get_tensor_constant_from_node( graph_module: GraphModule, node: Node ) -> Parameter | None: diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/addmm_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/addmm_converter.py index 0df41526da2..0b80be695f4 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/addmm_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/addmm_converter.py @@ -1,9 +1,15 @@ -# Copyright 2024-2025 NXP +# Copyright 2024-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from executorch.backends.nxp.backend.edge_helper import input_rank +import torch + +from executorch.backends.nxp.backend.edge_helper import ( + input_rank, + node_is_effectively_static_tensor, + weights_are_effectively_static, +) from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, @@ -12,10 +18,18 @@ from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( fully_connected_options, ) + +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter +# The edge operator signature is: aten.addmm(bias, input, weight, *, beta=1, alpha=1) +MAIN_INPUT_IDX = 1 +WEIGHT_IDX = 2 +BIAS_IDX = 0 + + class AddMMConverter(NodeConverter): """Convert the `aten.addmm` operator to TFLite `FullyConnected` with a bias input.""" @@ -29,12 +43,67 @@ def _is_supported_in_IR( return False # The weights must be 2D. 
- if input_rank(node, 2) != 2: + if input_rank(node, WEIGHT_IDX) != 2: + return False + + alpha, beta = node.kwargs.get("alpha", 1), node.kwargs.get("beta", 1) + if alpha != 1 or beta != 1: + # As these cases seem rare, conversion is not implemented for the time being. + return False + + return True + + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + # Main input and output must be `int8` or `uint8`. + if not NodeConverter.uses_quantization_type_for_io( + node, [torch.int8, torch.uint8], [MAIN_INPUT_IDX], [0] + ): + return False + + # Weights must be `int8`. + if not NodeConverter.uses_quantization_type_for_io( + node, [torch.int8], [WEIGHT_IDX], [] + ): + return False + + # Bias must be `int32`. + if not NodeConverter.uses_quantization_type_for_io( + node, [torch.int32], [BIAS_IDX], [] + ): + return False + + # Weights must be constant. + if not weights_are_effectively_static( + node, parameters_mapping, weight_index=WEIGHT_IDX + ): + return False + + # The bias must be constant. + if not node_is_effectively_static_tensor( + node.args[BIAS_IDX], parameters_mapping + ): return False return True def convert(self, node: Node): + """Convert the `aten.addmm` operator to NeutronIR `FullyConnected`. + The schema is: + addmm( + Tensor self, + Tensor mat1, + Tensor mat2, + *, + Scalar beta=1, + Scalar alpha=1 + ) -> Tensor + """ self.assert_convertible(node) t_op = self._create_tflite_op_with_io_tensors(node) @@ -47,14 +116,14 @@ def convert(self, node: Node): w = t_op.tmp_inputs[2] y = t_op.tmp_outputs[0] - # Assign the operator its TFLite inputs and outputs + # Assign the operator its Neutron IR inputs and outputs t_op.tmp_inputs = [x, w, bias] t_op.tmp_outputs = [y] ops = OpsList(middle_op=t_op) # The `aten.addmm` uses main input with shape [M, N] and the weights have the shape [N, O]. 
- # TFLite `FullyConnected` requires the weights to have shape [O, N] (if the main input has shape [M, N]). + # Neutron IR `FullyConnected` requires the weights to have shape [O, N] (if the main input has shape [M, N]). # Insert a `Transpose` operator to permute the weights to achieve correct conversion. (The `Transpose` will not # be present in the output model if the weights are static.) ops.add_pre(self.builder.create_transpose_operator_before(t_op, 1, [1, 0])) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mm_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mm_converter.py index dd9e3e2da54..d35f5437be7 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mm_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mm_converter.py @@ -1,9 +1,14 @@ -# Copyright 2024-2025 NXP +# Copyright 2024-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from executorch.backends.nxp.backend.edge_helper import input_rank +import torch + +from executorch.backends.nxp.backend.edge_helper import ( + input_rank, + weights_are_effectively_static, +) from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, @@ -12,6 +17,7 @@ from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( fully_connected_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -33,8 +39,37 @@ def _is_supported_in_IR( return True + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + # Main input and output must be `int8` or `uint8`. + if not NodeConverter.uses_quantization_type_for_io( + node, [torch.int8, torch.uint8], [0], [0] + ): + return False + + # Weights must be `int8`. + if not NodeConverter.uses_quantization_type_for_io(node, [torch.int8], [1], []): + return False + + # Weights must be static. + if not weights_are_effectively_static(node, parameters_mapping): + return False + + return True + def convert(self, node: Node): - """Convert the `aten.mm` operator to TFLite `FullyConnected` without a bias input.""" + """Convert the `aten.mm` operator to Neutron IR `FullyConnected` without a bias input. 
+ The schema is: + mm( + Tensor self, + Tensor mat2 + ) -> Tensor + """ self.assert_convertible(node) t_op = self._create_tflite_op_with_io_tensors(node) @@ -44,14 +79,14 @@ def convert(self, node: Node): w = t_op.tmp_inputs[1] y = t_op.tmp_outputs[0] - # Assign the operator its TFLite inputs and outputs + # Assign the operator its Neutron IR inputs and outputs t_op.tmp_inputs = [x, w] t_op.tmp_outputs = [y] ops = OpsList(middle_op=t_op) # The `aten.mm` uses main input with shape [M, N] and the weights have the shape [N, O]. - # TFLite `FullyConnected` requires the weights to have shape [O, N] (if the main input has shape [M, N]). + # Neutron IR `FullyConnected` requires the weights to have shape [O, N] (if the main input has shape [M, N]). # Insert a `Transpose` operator to permute the weights to achieve correct conversion. (The `Transpose` will not # be present in the output model if the weights are static.) ops.add_pre(self.builder.create_transpose_operator_before(t_op, 1, [1, 0])) diff --git a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py index 0a0f6641f4b..dafea5d259b 100644 --- a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py +++ b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py @@ -9,6 +9,7 @@ from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer +from executorch.backends.nxp.tests.ops_aliases import PermuteCopy # noinspection PyProtectedMember from executorch.exir.dialects._ops import ops as exir_ops @@ -109,9 +110,11 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): main_cluster_node_to_auxiliary_nodes = { AddMM: [ ViewCopy, + PermuteCopy, ], MM: [ ViewCopy, + PermuteCopy, ], ViewCopy: [Clone, CloneDimOrder], Conv: [ diff --git 
a/backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py index a8cdee41830..d095b5a5237 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py @@ -1,96 +1,158 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest - -import kgb import numpy as np + +# noinspection PyUnusedImports +import pytest import torch -from executorch.backends.nxp.backend.edge_program_converter import ( - EdgeProgramToIRConverter, -) +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program -from executorch.backends.nxp.tests.executors import ( - convert_run_compare, - graph_contains_any_of_ops, -) +from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier, Operator from executorch.backends.nxp.tests.models import AddmmModule, LinearModule -from executorch.exir.dialects._ops import ops as exir_ops -from parameterized import parameterized -from torch.export import ExportedProgram - - -class TestAddmmConversion(unittest.TestCase): - @classmethod - def setUpClass(cls): - torch.manual_seed(23) - np.random.seed(42) - - @parameterized.expand([("QAT", True), ("PTQ", False)]) - def test_addmm_conversion(self, _, use_qat: bool): - with kgb.spy_on( - EdgeProgramToIRConverter.convert_program, - call_original=True, - owner=EdgeProgramToIRConverter, - ) as converter_spy: - input_shape = (1, 32) - model = AddmmModule(input_shape[1]) - - edge_program = to_quantized_edge_program( - model, input_shape, use_qat=use_qat - ).exported_program() - - # Make sure that all 
nodes were delegated. - assert not graph_contains_any_of_ops( - graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default] - ) - assert any( - "lowered_module" in node.name for node in edge_program.graph.nodes - ) - - tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value - exported_program: ExportedProgram = converter_spy.calls[-1].args[0] - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( - np.int8 - ) - convert_run_compare( - exported_program, - input_data, - tfl_model=tflite_flatbuffers_model, - ) - - @parameterized.expand([("QAT", True), ("PTQ", False)]) - def test_linear_conversion__with_bias(self, _, use_qat: bool): - with kgb.spy_on( - EdgeProgramToIRConverter.convert_program, - call_original=True, - owner=EdgeProgramToIRConverter, - ) as converter_spy: - input_shape = (10, 32) - model = LinearModule(bias=True) - - edge_program = to_quantized_edge_program( - model, input_shape, use_qat=use_qat - ).exported_program() - - # Make sure that all nodes were delegated. 
- assert not graph_contains_any_of_ops( - graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default] - ) - assert any( - "lowered_module" in node.name for node in edge_program.graph.nodes - ) - - tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value - exported_program: ExportedProgram = converter_spy.calls[-1].args[0] - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( - np.int8 - ) - convert_run_compare( - exported_program, - input_data, - tfl_model=tflite_flatbuffers_model, - ) +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + AddMM, + ExecutorchDelegateCall, + MM, + PermuteCopy, + ViewCopy, +) +from executorch.backends.nxp.tests.use_qat import * # noqa F403 + + +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(42) + np.random.seed(23) + + +class TestMM: + + # noinspection PyMethodMayBeStatic + def assert_delegated( + self, + model, + input_shape, + mocker, + use_qat=False, + expected_delegated_ops: dict[Operator, int] | None = None, + ): + if expected_delegated_ops is None: + expected_delegated_ops = {AddMM: 1} + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops=expected_delegated_ops, + expected_non_delegated_ops={}, + ) + + # Create a RandomDatasetCreator that also covers negative numbers to properly test the operator. 
+ dataset_creator = RandomDatasetCreator(low=-2, high=2) + + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset_creator, + use_qat=use_qat, + ) + + # noinspection PyMethodMayBeStatic + def assert_not_delegated(self, model, input_shape): + delegated_ep = to_quantized_edge_program(model, input_shape).exported_program() + + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [AddMM, MM]) + + @pytest.mark.parametrize( + "input_shape", + [ + # PyTorch allows only 2D inputs. + (1, 32), + (3, 11), + ], + ids=lambda s: f"input_shape = {s}", + ) + def test__from_addmm(self, mocker, use_qat, input_shape: tuple[int, ...]): + model = AddmmModule(input_shape[-1]) + self.assert_delegated(model, input_shape, mocker, use_qat=use_qat) + + def test__from_addmm__unsupported_alpha(self): + input_shape = (1, 8) + model = AddmmModule(input_shape[-1], alpha=0.42) + self.assert_not_delegated(model, input_shape) + + def test__from_addmm__unsupported_beta(self): + input_shape = (1, 8) + model = AddmmModule(input_shape[-1], beta=0.42) + self.assert_not_delegated(model, input_shape) + + @pytest.mark.parametrize( + "alpha", + [1, 1.0], + ids=lambda a: f"alpha = {a}", + ) + def test__from_addmm__supported_alpha(self, mocker, use_qat, alpha): + input_shape = (1, 8) + model = AddmmModule(input_shape[-1], alpha=alpha) + self.assert_delegated(model, input_shape, mocker, use_qat) + + @pytest.mark.parametrize( + "beta", + [1, 1.0], + ids=lambda b: f"beta = {b}", + ) + def test__from_addmm__supported_beta(self, mocker, use_qat, beta): + input_shape = (1, 8) + model = AddmmModule(input_shape[-1], beta=beta) + self.assert_delegated(model, input_shape, mocker, use_qat) + + @pytest.mark.parametrize( + "input_shape", + [ + (1, 32), + (3, 11), + ], + ids=lambda s: f"input_shape = {s}", + ) + def test__from_linear_with_bias__2d( + self, mocker, use_qat, input_shape: tuple[int, ...] 
+ ): + model = LinearModule(bias=True, in_features=input_shape[-1], out_features=7) + self.assert_delegated( + model, + input_shape, + mocker, + use_qat=use_qat, + expected_delegated_ops={AddMM: 1, PermuteCopy: 1}, + ) + + @pytest.mark.parametrize( + "input_shape", + [ + (1, 3, 8), + (2, 3, 5), + (2, 3, 3, 3), + ], + ids=lambda s: f"input_shape = {s}", + ) + def test__from_linear_with_bias__higher_ranks( + self, mocker, use_qat, input_shape: tuple[int, ...] + ): + # More than 2D cases get reshaped to 2D, so two extra view_copy nodes are delegated. + + model = LinearModule(bias=True, in_features=input_shape[-1], out_features=7) + self.assert_delegated( + model, + input_shape, + mocker, + use_qat=use_qat, + expected_delegated_ops={AddMM: 1, PermuteCopy: 1, ViewCopy: 2}, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py index 60dbfd1b215..ca59d8c579a 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py @@ -1,97 +1,111 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import unittest - -import kgb import numpy as np + +# noinspection PyUnusedImports +import pytest import torch -from executorch.backends.nxp.backend.edge_program_converter import ( - EdgeProgramToIRConverter, -) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program -from executorch.backends.nxp.tests.executors import ( - convert_run_compare, - graph_contains_any_of_ops, -) +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier, Operator from executorch.backends.nxp.tests.models import LinearModule, MmModule -from executorch.exir.dialects._ops import ops as exir_ops -from parameterized import parameterized -from torch.export import ExportedProgram +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import MM, PermuteCopy, ViewCopy +from executorch.backends.nxp.tests.use_qat import * # noqa F403 + + +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(42) + np.random.seed(23) -class TestMmConversion(unittest.TestCase): - @classmethod - def setUpClass(cls): - torch.manual_seed(23) - np.random.seed(42) +class TestMM: - @parameterized.expand([("QAT", True), ("PTQ", False)]) - def test_mm_conversion(self, _, use_qat: bool): - with kgb.spy_on( - EdgeProgramToIRConverter.convert_program, - call_original=True, - owner=EdgeProgramToIRConverter, - ) as converter_spy: - input_shape = (1, 32) - model = MmModule(input_shape[1]) + # noinspection PyMethodMayBeStatic + def assert_delegated( + self, + model, + input_shape, + mocker, + use_qat=False, + expected_delegated_ops: dict[Operator, int] | None = None, + ): + if expected_delegated_ops is None: + expected_delegated_ops = {MM: 1} - edge_program = to_quantized_edge_program( - model, input_shape, use_qat=use_qat - ).exported_program() + graph_verifier = DetailedGraphVerifier( + mocker, + 
expected_delegated_ops=expected_delegated_ops, + expected_non_delegated_ops={}, + ) - # Make sure that all nodes were delegated. - assert not graph_contains_any_of_ops( - graph=edge_program.graph, ops=[exir_ops.edge.aten.mm.default] - ) - assert any( - "lowered_module" in node.name for node in edge_program.graph.nodes - ) + # Create a RandomDatasetCreator that also covers negative numbers to properly test the operator. + dataset_creator = RandomDatasetCreator(low=-2, high=2) - tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value - exported_program: ExportedProgram = converter_spy.calls[-1].args[0] - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( - np.int8 - ) - convert_run_compare( - exported_program, - input_data, - tfl_model=tflite_flatbuffers_model, - atol=1.0, - ) + lower_run_compare( + model, + input_shape, + graph_verifier, + dataset_creator, + use_qat=use_qat, + ) - @parameterized.expand([("QAT", True), ("PTQ", False)]) - def test_linear_conversion__without_bias(self, _, use_qat: bool): - with kgb.spy_on( - EdgeProgramToIRConverter.convert_program, - call_original=True, - owner=EdgeProgramToIRConverter, - ) as converter_spy: - input_shape = (10, 32) - model = LinearModule(bias=False) + @pytest.mark.parametrize( + "input_shape", + [ + # PyTorch allows only 2D inputs. + (1, 32), + (3, 11), + ], + ids=lambda s: f"input_shape = {s}", + ) + def test__from_mm(self, mocker, use_qat, input_shape: tuple[int, ...]): + model = MmModule(input_shape[-1]) + self.assert_delegated(model, input_shape, mocker, use_qat=use_qat) - edge_program = to_quantized_edge_program( - model, input_shape, use_qat=use_qat - ).exported_program() + @pytest.mark.parametrize( + "input_shape", + [ + (1, 32), + (3, 11), + ], + ids=lambda s: f"input_shape = {s}", + ) + def test__from_linear_without_bias( + self, mocker, use_qat, input_shape: tuple[int, ...] 
+ ): + model = LinearModule(bias=False, in_features=input_shape[-1], out_features=7) + self.assert_delegated( + model, + input_shape, + mocker, + use_qat=use_qat, + expected_delegated_ops={MM: 1, PermuteCopy: 1}, + ) - # Make sure that all nodes were delegated. - assert not graph_contains_any_of_ops( - graph=edge_program.graph, ops=[exir_ops.edge.aten.mm.default] - ) - assert any( - "lowered_module" in node.name for node in edge_program.graph.nodes - ) + @pytest.mark.parametrize( + "input_shape", + [ + (1, 3, 8), + (2, 3, 5), + (2, 3, 3, 3), + ], + ids=lambda s: f"input_shape = {s}", + ) + def test__from_linear_without_bias__higher_ranks( + self, mocker, use_qat, input_shape: tuple[int, ...] + ): + # More than 2D cases get reshaped to 2D, so two extra view_copy nodes are delegated. - tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value - exported_program: ExportedProgram = converter_spy.calls[-1].args[0] - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( - np.int8 - ) - convert_run_compare( - exported_program, - input_data, - tfl_model=tflite_flatbuffers_model, - ) + model = LinearModule(bias=False, in_features=input_shape[-1], out_features=7) + self.assert_delegated( + model, + input_shape, + mocker, + use_qat=use_qat, + expected_delegated_ops={MM: 1, PermuteCopy: 1, ViewCopy: 2}, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py index ab42560f075..b93a33d04bd 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py @@ -14,7 +14,7 @@ from executorch.backends.nxp.tests.models import Conv2dModule, LinearModule, ReLUModule from executorch.backends.nxp.tests.nsys_testing import lower_run_compare from executorch.backends.nxp.tests.ops_aliases import ( - AddMm, + AddMM, Convolution, DequantizePerChannel, 
DequantizePerTensor, @@ -105,7 +105,7 @@ def test_relu_conversion__full_pipeline(self, mocker, model, input_shape): graph_verifier = DetailedGraphVerifier( mocker=mocker, expected_delegated_ops=( - {Convolution: 1, Relu: 1} if is_conv_module else {AddMm: 1, Relu: 1} + {Convolution: 1, Relu: 1} if is_conv_module else {AddMM: 1, Relu: 1} ), expected_non_delegated_ops={}, ops_to_ignore=[ diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 7545dd940f2..f7fc6918554 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -252,24 +252,36 @@ def forward(self, x): class AddmmModule(torch.nn.Module): - def __init__(self, in_channels: int): + def __init__( + self, + in_channels: int, + out_channels: int = 7, + alpha: float | None = None, + beta: float | None = None, + ): super().__init__() - self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels)) - self.bias = torch.nn.Parameter(torch.empty(in_channels)) + self.weight = torch.nn.Parameter(torch.empty(in_channels, out_channels)) + self.bias = torch.nn.Parameter(torch.empty(out_channels)) torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) torch.nn.init.uniform_(self.bias, -bound, bound) self.eval() + self.kwargs = {} + if alpha is not None: + self.kwargs["alpha"] = alpha + if beta is not None: + self.kwargs["beta"] = beta + def forward(self, x): - return torch.addmm(self.bias, x, self.weight) + return torch.addmm(self.bias, x, self.weight, **self.kwargs) class MmModule(torch.nn.Module): - def __init__(self, in_channels: int): + def __init__(self, in_channels: int, out_channels: int = 7): super().__init__() - self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels)) + self.weight = torch.nn.Parameter(torch.empty(in_channels, out_channels)) torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) self.eval() diff --git 
a/backends/nxp/tests/ops_aliases.py b/backends/nxp/tests/ops_aliases.py index efb1147c292..6f02c5e0e23 100644 --- a/backends/nxp/tests/ops_aliases.py +++ b/backends/nxp/tests/ops_aliases.py @@ -13,7 +13,7 @@ Abs = exir_ops.edge.aten.abs.default AdaptiveAvgPool2D = exir_ops.edge.aten._adaptive_avg_pool2d.default -AddMm = exir_ops.edge.aten.addmm.default +AddMM = exir_ops.edge.aten.addmm.default AddTensor = exir_ops.edge.aten.add.Tensor AvgPool2D = exir_ops.edge.aten.avg_pool2d.default Bmm = exir_ops.edge.aten.bmm.default @@ -30,13 +30,13 @@ HardTanh = exir_ops.edge.aten.hardtanh.default HardTanh_ = exir_ops.edge.aten.hardtanh_.default LeakyRelu = exir_ops.edge.aten.leaky_relu.default +MM = exir_ops.edge.aten.mm.default MaxPool2DWithIndices = exir_ops.edge.aten.max_pool2d_with_indices.default MeanDim = exir_ops.edge.aten.mean.dim MulTensor = exir_ops.edge.aten.mul.Tensor PermuteCopy = exir_ops.edge.aten.permute_copy.default QuantizePerChannel = exir_ops.edge.quantized_decomposed.quantize_per_channel.default QuantizePerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default -PermuteCopy = exir_ops.edge.aten.permute_copy.default Relu = exir_ops.edge.aten.relu.default Sigmoid = exir_ops.edge.aten.sigmoid.default Slice = exir_ops.edge.aten.slice.Tensor diff --git a/backends/nxp/tests/use_qat.py b/backends/nxp/tests/use_qat.py index 5994d5aa193..54592a3a96c 100644 --- a/backends/nxp/tests/use_qat.py +++ b/backends/nxp/tests/use_qat.py @@ -8,4 +8,9 @@ def use_qat(request): def pytest_generate_tests(metafunc): if "use_qat" in metafunc.fixturenames: - metafunc.parametrize("use_qat", [True, False], indirect=True) + metafunc.parametrize( + "use_qat", + [True, False], + indirect=True, + ids=lambda use_qat: "QAT" if use_qat else "PTQ", + )