Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 58 additions & 7 deletions backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# mypy: allow-untyped-defs
import itertools
from typing import Callable, Optional
Expand Down Expand Up @@ -995,6 +1000,48 @@ def _annotate_mul(
return annotated_partitions


def _has_annotated_qparam_source(node: Node) -> bool:
    """Report whether ``node`` leads back to an annotated output qspec.

    Starting at ``node``, repeatedly step to the first positional argument of
    each ``call_function`` node whose target is accepted by
    ``_is_share_obs_or_fq_op``, stopping as soon as a node is found whose
    quantization annotation is marked annotated and has an ``output_qspec``.

    Args:
        node: Node at which the backward walk starts.

    Returns:
        True if such an annotated node is reached (``node`` itself counts).
        False if the walk reaches a node that is not an accepted pass-through
        op, a first argument that is not a ``Node``, or a node that was
        already visited.
    """
    seen: set[Node] = set()
    current = node
    while current not in seen:
        seen.add(current)

        annotation = current.meta.get(Q_ANNOTATION_KEY)
        if annotation is not None:
            if annotation._annotated and annotation.output_qspec is not None:
                return True

        # Keep walking only through ops accepted by _is_share_obs_or_fq_op.
        passes_through = (
            current.op == "call_function"
            and _is_share_obs_or_fq_op(current.target)
        )
        if not passes_through:
            return False

        producer = current.args[0]
        if not isinstance(producer, Node):
            return False
        current = producer

    return False


def _get_cat_qparam_source(inputs) -> object:
    """Pick the cat input whose qparams the other inputs should share.

    Args:
        inputs: Indexable collection of the cat node's input activations.

    Returns:
        The first ``Node`` in ``inputs`` for which
        ``_has_annotated_qparam_source`` is true; ``inputs[0]`` when no input
        qualifies.
    """
    annotated_candidates = (
        candidate
        for candidate in inputs
        if isinstance(candidate, Node) and _has_annotated_qparam_source(candidate)
    )
    chosen = next(annotated_candidates, None)
    return inputs[0] if chosen is None else chosen


# TODO: remove Optional in return type, fix annotated_partitions logic
@register_annotator("cat")
def _annotate_cat(
Expand All @@ -1014,18 +1061,20 @@ def _annotate_cat(

input_act_qspec = get_input_act_qspec(quantization_config)
inputs = cat_node.args[0]
input_act_qparam_source = _get_cat_qparam_source(inputs)

input_qspec_map = {}
input_act0 = inputs[0] # type: ignore[index]
if isinstance(input_act0, Node):
input_qspec_map[input_act0] = input_act_qspec
if isinstance(input_act_qparam_source, Node):
input_qspec_map[input_act_qparam_source] = input_act_qspec

shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: ignore[arg-type]
for input_act in inputs[1:]: # type: ignore[index, union-attr]
shared_with_source_qspec = SharedQuantizationSpec(
(input_act_qparam_source, cat_node) # type: ignore[arg-type]
)
for input_act in inputs: # type: ignore[union-attr]
if input_act not in input_qspec_map:
input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index]
input_qspec_map[input_act] = shared_with_source_qspec # type: ignore[index]

output_act_qspec = shared_with_input0_qspec
output_act_qspec = shared_with_source_qspec

cat_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
Expand All @@ -1045,12 +1094,14 @@ def _is_share_obs_or_fq_op(op: Callable) -> bool:
torch.ops.aten.mean.dim,
torch.ops.aten.permute.default,
torch.ops.aten.permute_copy.default,
torch.ops.aten.transpose.int,
torch.ops.aten.squeeze.dim,
torch.ops.aten.squeeze_copy.dim,
# TODO: remove?
torch.ops.aten.adaptive_avg_pool2d.default,
torch.ops.aten.view_copy.default,
torch.ops.aten.view.default,
torch.ops.aten.reshape.default,
torch.ops.aten.slice_copy.Tensor,
torch.ops.aten.flatten.using_ints,
]
Expand Down
43 changes: 43 additions & 0 deletions backends/xnnpack/test/ops/test_cat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -199,6 +200,48 @@ def test_qs8_cat_with_empty_tensor(self):
)
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)

class CatAfterConvAndTranspose(torch.nn.Module):
    """Prepend a class token to conv-derived patch tokens and add an embedding.

    The conv output is flattened from dim 2 and transposed (dims 1 and 2)
    before being concatenated with the expanded class token along dim 1.
    All parameters are filled with fixed constants.
    """

    def __init__(self):
        super().__init__()
        patch_size = (4, 4)
        self.proj = torch.nn.Conv2d(
            3,
            8,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.cls_token = torch.nn.Parameter(torch.full((1, 1, 8), 4.0))
        self.pos_embed = torch.nn.Parameter(torch.full((1, 5, 8), 0.125))

        # Overwrite the default conv initialisation with a constant weight.
        torch.nn.init.constant_(self.proj.weight, 0.025)

    def forward(self, x):
        """Return the class token plus patch tokens, offset by ``pos_embed``."""
        feature_map = self.proj(x)
        patch_tokens = torch.transpose(torch.flatten(feature_map, 2), 1, 2)
        batch_size = x.shape[0]
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_token, patch_tokens), dim=1)
        return tokens + self.pos_embed

def test_qs8_cat_uses_annotated_transpose_path_qparams(self):
    """Run ``CatAfterConvAndTranspose`` through the quantized Tester pipeline.

    Expects exactly one ``aten.cat`` after export, exactly one delegate call
    after lowering with no edge cat or ``quantized_decomposed`` op left, and
    then compares the serialized program's outputs on the same inputs.
    """
    sample_inputs = (torch.randn(1, 3, 8, 8),)
    forbidden_after_lowering = [
        "executorch_exir_dialects_edge__ops_aten_cat",
        "torch.ops.quantized_decomposed",
    ]

    exported = (
        Tester(self.CatAfterConvAndTranspose(), sample_inputs)
        .quantize()
        .export()
        .check_count({"torch.ops.aten.cat": 1})
    )
    lowered = (
        exported.to_edge_transform_and_lower()
        .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
        .check_not(forbidden_after_lowering)
    )
    lowered.to_executorch().serialize().run_method_and_compare_outputs(
        inputs=sample_inputs
    )

class CatNegativeDim(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
Loading