diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py index 90ddfaaf01f..6e19d38cdf9 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py @@ -1,3 +1,8 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + # mypy: allow-untyped-defs import itertools from typing import Callable, Optional @@ -995,6 +1000,48 @@ def _annotate_mul( return annotated_partitions +def _has_annotated_qparam_source(node: Node) -> bool: + """Check whether a node traces to annotated output qparams. + + Walk backward through qparam-preserving view/layout ops to find an already + annotated activation producer. + """ + visited: set[Node] = set() + while node not in visited: + visited.add(node) + + quantization_annotation = node.meta.get(Q_ANNOTATION_KEY, None) + if ( + quantization_annotation is not None + and quantization_annotation._annotated + and quantization_annotation.output_qspec is not None + ): + return True + + if node.op != "call_function" or not _is_share_obs_or_fq_op(node.target): + return False + + prev_node = node.args[0] + if not isinstance(prev_node, Node): + return False + node = prev_node + + return False + + +def _get_cat_qparam_source(inputs) -> object: + """Choose the input that should own a cat node's shared qparams. + + Prefer the first input that traces to annotated output qparams. Fall back to + the first input otherwise. + """ + for input_act in inputs: + if isinstance(input_act, Node) and _has_annotated_qparam_source(input_act): + return input_act + + return inputs[0] + + # TODO: remove Optional in return type, fix annotated_partitions logic @register_annotator("cat") def _annotate_cat( @@ -1014,18 +1061,20 @@ def _annotate_cat( input_act_qspec = get_input_act_qspec(quantization_config) inputs = cat_node.args[0] + input_act_qparam_source = _get_cat_qparam_source(inputs) input_qspec_map = {} - input_act0 = inputs[0] # type: ignore[index] - if isinstance(input_act0, Node): - input_qspec_map[input_act0] = input_act_qspec + if isinstance(input_act_qparam_source, Node): + input_qspec_map[input_act_qparam_source] = input_act_qspec - shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: ignore[arg-type] - for input_act in inputs[1:]: # type: ignore[index, union-attr] + shared_with_source_qspec = SharedQuantizationSpec( + (input_act_qparam_source, cat_node) # type: ignore[arg-type] + ) + for input_act in inputs: # type: ignore[union-attr] if input_act not in input_qspec_map: - input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index] + input_qspec_map[input_act] = shared_with_source_qspec # type: ignore[index] - output_act_qspec = shared_with_input0_qspec + output_act_qspec = shared_with_source_qspec cat_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, @@ -1045,12 +1094,14 @@ def _is_share_obs_or_fq_op(op: Callable) -> bool: torch.ops.aten.mean.dim, torch.ops.aten.permute.default, torch.ops.aten.permute_copy.default, + torch.ops.aten.transpose.int, torch.ops.aten.squeeze.dim, torch.ops.aten.squeeze_copy.dim, # TODO: remove? torch.ops.aten.adaptive_avg_pool2d.default, torch.ops.aten.view_copy.default, torch.ops.aten.view.default, + torch.ops.aten.reshape.default, torch.ops.aten.slice_copy.Tensor, torch.ops.aten.flatten.using_ints, ] diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py index 11e246f541a..f3fa7e06276 100644 --- a/backends/xnnpack/test/ops/test_cat.py +++ b/backends/xnnpack/test/ops/test_cat.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -199,6 +200,48 @@ def test_qs8_cat_with_empty_tensor(self): ) self._test_cat(self.Cat(), inputs, cat_num=4, quant=True) + class CatAfterConvAndTranspose(torch.nn.Module): + def __init__(self): + super().__init__() + self.proj = torch.nn.Conv2d( + in_channels=3, + out_channels=8, + kernel_size=(4, 4), + stride=(4, 4), + bias=False, + ) + self.cls_token = torch.nn.Parameter(torch.full((1, 1, 8), 4.0)) + self.pos_embed = torch.nn.Parameter(torch.full((1, 5, 8), 0.125)) + + with torch.no_grad(): + self.proj.weight.fill_(0.025) + + def forward(self, x): + patch_tokens = self.proj(x).flatten(2).transpose(1, 2) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) + tokens = torch.cat((cls_token, patch_tokens), dim=1) + return tokens + self.pos_embed + + def test_qs8_cat_uses_annotated_transpose_path_qparams(self): + inputs = (torch.randn(1, 3, 8, 8),) + ( + Tester(self.CatAfterConvAndTranspose(), inputs) + .quantize() + .export() + .check_count({"torch.ops.aten.cat": 1}) + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_cat", + "torch.ops.quantized_decomposed", + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(inputs=inputs) + ) + class CatNegativeDim(torch.nn.Module): def __init__(self): super().__init__()