From 194bb191a2eadb7723f3b06c4eb594587d754e4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Mon, 8 Jun 2026 16:59:08 +0200 Subject: [PATCH] XNNPACK: Fix cat qparams for quantized ViT token paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug: Cat annotation always used the first input as the shared qparam source. In DeiT Tiny, the first input to token concatenation is the class-token path: cls_token -> expand -> cat The patch-token path is a later cat input: conv -> flatten -> transpose -> cat The conv output has annotated activation qparams, and flatten/transpose are qparam-preserving view/layout ops. They should carry the same qparams with SharedQuantizationSpec. Anchoring cat qparams on the class-token path can leave the patch-token static transpose with different input/output qparams. XNNPACK rejects this during runtime initialization because static transpose only reorders bytes. Fix: Choose the first cat input that traces through qparam-preserving ops to annotated output qparams, falling back to the first input otherwise. Also propagate shared qparams through reshape and transpose so static transpose nodes keep identical input and output qparams. Add a quantized XNNPACK regression covering conv, flatten, transpose, cat, and add. Change-Id: I86fafd584c1cb561bd2d4444ea70c1a1b0650066 Signed-off-by: Måns Nilsson --- .../quantizer/xnnpack_quantizer_utils.py | 65 +++++++++++++++++-- backends/xnnpack/test/ops/test_cat.py | 43 ++++++++++++ 2 files changed, 101 insertions(+), 7 deletions(-) diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py index 90ddfaaf01f..6e19d38cdf9 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py @@ -1,3 +1,8 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + # mypy: allow-untyped-defs import itertools from typing import Callable, Optional @@ -995,6 +1000,48 @@ def _annotate_mul( return annotated_partitions +def _has_annotated_qparam_source(node: Node) -> bool: + """Check whether a node traces to annotated output qparams. + + Walk backward through qparam-preserving view/layout ops to find an already + annotated activation producer. + """ + visited: set[Node] = set() + while node not in visited: + visited.add(node) + + quantization_annotation = node.meta.get(Q_ANNOTATION_KEY, None) + if ( + quantization_annotation is not None + and quantization_annotation._annotated + and quantization_annotation.output_qspec is not None + ): + return True + + if node.op != "call_function" or not _is_share_obs_or_fq_op(node.target): + return False + + prev_node = node.args[0] + if not isinstance(prev_node, Node): + return False + node = prev_node + + return False + + +def _get_cat_qparam_source(inputs) -> object: + """Choose the input that should own a cat node's shared qparams. + + Prefer the first input that traces to annotated output qparams. Fall back to + the first input otherwise. + """ + for input_act in inputs: + if isinstance(input_act, Node) and _has_annotated_qparam_source(input_act): + return input_act + + return inputs[0] + + # TODO: remove Optional in return type, fix annotated_partitions logic @register_annotator("cat") def _annotate_cat( @@ -1014,18 +1061,20 @@ def _annotate_cat( input_act_qspec = get_input_act_qspec(quantization_config) inputs = cat_node.args[0] + input_act_qparam_source = _get_cat_qparam_source(inputs) input_qspec_map = {} - input_act0 = inputs[0] # type: ignore[index] - if isinstance(input_act0, Node): - input_qspec_map[input_act0] = input_act_qspec + if isinstance(input_act_qparam_source, Node): + input_qspec_map[input_act_qparam_source] = input_act_qspec - shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: ignore[arg-type] - for input_act in inputs[1:]: # type: ignore[index, union-attr] + shared_with_source_qspec = SharedQuantizationSpec( + (input_act_qparam_source, cat_node) # type: ignore[arg-type] + ) + for input_act in inputs: # type: ignore[union-attr] if input_act not in input_qspec_map: - input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index] + input_qspec_map[input_act] = shared_with_source_qspec # type: ignore[index] - output_act_qspec = shared_with_input0_qspec + output_act_qspec = shared_with_source_qspec cat_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, @@ -1045,12 +1094,14 @@ def _is_share_obs_or_fq_op(op: Callable) -> bool: torch.ops.aten.mean.dim, torch.ops.aten.permute.default, torch.ops.aten.permute_copy.default, + torch.ops.aten.transpose.int, torch.ops.aten.squeeze.dim, torch.ops.aten.squeeze_copy.dim, # TODO: remove? torch.ops.aten.adaptive_avg_pool2d.default, torch.ops.aten.view_copy.default, torch.ops.aten.view.default, + torch.ops.aten.reshape.default, torch.ops.aten.slice_copy.Tensor, torch.ops.aten.flatten.using_ints, ] diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py index 11e246f541a..f3fa7e06276 100644 --- a/backends/xnnpack/test/ops/test_cat.py +++ b/backends/xnnpack/test/ops/test_cat.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -199,6 +200,48 @@ def test_qs8_cat_with_empty_tensor(self): ) self._test_cat(self.Cat(), inputs, cat_num=4, quant=True) + class CatAfterConvAndTranspose(torch.nn.Module): + def __init__(self): + super().__init__() + self.proj = torch.nn.Conv2d( + in_channels=3, + out_channels=8, + kernel_size=(4, 4), + stride=(4, 4), + bias=False, + ) + self.cls_token = torch.nn.Parameter(torch.full((1, 1, 8), 4.0)) + self.pos_embed = torch.nn.Parameter(torch.full((1, 5, 8), 0.125)) + + with torch.no_grad(): + self.proj.weight.fill_(0.025) + + def forward(self, x): + patch_tokens = self.proj(x).flatten(2).transpose(1, 2) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) + tokens = torch.cat((cls_token, patch_tokens), dim=1) + return tokens + self.pos_embed + + def test_qs8_cat_uses_annotated_transpose_path_qparams(self): + inputs = (torch.randn(1, 3, 8, 8),) + ( + Tester(self.CatAfterConvAndTranspose(), inputs) + .quantize() + .export() + .check_count({"torch.ops.aten.cat": 1}) + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_cat", + "torch.ops.quantized_decomposed", + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(inputs=inputs) + ) + class CatNegativeDim(torch.nn.Module): def __init__(self): super().__init__()