From d06094868d8dcc7c7206393844b0eb0c4ac8dfe6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 15:09:11 +0000 Subject: [PATCH 1/2] Initial plan From 588f9b02eff96131fe12911a62493acedce5b0b1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 15:17:13 +0000 Subject: [PATCH 2/2] fix: implement scatter_reduce mean reduction correctly for ONNX export ONNX ScatterElements doesn't support 'mean' reduction. Previously the code mapped 'mean' to 'none' (no reduction), which just overwrote values instead of computing the mean. The fix implements mean as sum/count: - scatter_sum: ScatterElements with reduction='add' onto zeros - scatter_count: ScatterElements of ones with reduction='add' onto zeros - For include_self=True: add self to sum and 1 to count - For include_self=False: use max(count, 1) to avoid div-by-zero (positions with count=0 also have sum=0, so 0/1=0 is correct) Also removes the xfail for scatter_reduce mean in ops_test_data.py and adds e2e tests for both include_self=True and include_self=False cases. Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- .../function_libs/torch_lib/ops/core.py | 56 ++++++++++++++++--- .../function_libs/torch_lib/e2e_ops_tests.py | 35 ++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 1 - 3 files changed, 82 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 67de7076fa..6b33cace86 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8691,14 +8691,6 @@ def aten_scatter_reduce( ): """scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor""" - reduce_mode = { # convert torch string name to onnx string name - "mean": "none", # 'mean' doesn't support in ONNX 1.14 definition - "sum": "add", - "prod": "mul", - "amin": "min", - "amax": "max", - } - onnx_reduce = reduce_mode[reduce] dtype = src.dtype or self.dtype assert dtype is not None, "dtype should be not None" @@ -8709,13 +8701,59 @@ def aten_scatter_reduce( index = op.Reshape(index, neg_1) src = op.Reshape(src, neg_1) + if reduce == "mean": + # ONNX ScatterElements does not support "mean" reduction. + # Implement mean as: mean = sum(src) / count(src), with optional include_self. + zero_val = ir.tensor([0], dtype=dtype) + one_val = ir.tensor([1], dtype=dtype) + # Scatter sum of src values onto a zeros tensor + scatter_sum = op.ScatterElements( + op.ConstantOfShape(op.Shape(self), value=zero_val), + index, + src, + axis=dim, + reduction="add", + ) + # Scatter count of src contributions onto a zeros tensor + scatter_count = op.ScatterElements( + op.ConstantOfShape(op.Shape(self), value=zero_val), + index, + op.ConstantOfShape(op.Shape(src), value=one_val), + axis=dim, + reduction="add", + ) + if include_self: + # Include self in both sum and count + total_sum = op.Add(self, scatter_sum) + total_count = op.Add( + op.ConstantOfShape(op.Shape(self), value=one_val), scatter_count + ) + else: + total_sum = scatter_sum + # Avoid division by zero: where count == 0, sum is also 0, so 0/1 = 0 is correct + total_count = op.Max( + scatter_count, + op.ConstantOfShape(op.Shape(scatter_count), value=one_val), + ) + result = op.Div(total_sum, total_count) + if self_is_scalar: + result = op.Squeeze(result) + return result + + reduce_mode = { # convert torch string name to onnx string name + "sum": "add", + "prod": "mul", + "amin": "min", + "amax": "max", + } + onnx_reduce = reduce_mode[reduce] + if not include_self: # onnx standard always assume the value from self is part of the reduction. # A first step is added to replace the impacted value by another one # chosen in a way that the results of the reduction is not changed # whether or not it takes part in it. # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. - # mean is not supported. if onnx_reduce == "max": if dtype in { ir.DataType.FLOAT16, diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 019e6f7fe5..eae1a631e8 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -39,6 +39,41 @@ def forward(self, x, indices, updates): onnx_program = torch.onnx.export(model, xs, dynamo=True) _testing.assert_onnx_program(onnx_program) + def test_scatter_reduce_mean_include_self_false(self): + """Test scatter_reduce with reduce='mean' and include_self=False (GitHub issue).""" + + class ScatterMeanModel(torch.nn.Module): + def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: + index = batch.unsqueeze(1).repeat(1, h.shape[1]) + groups = batch.max().int() + 1 + out = torch.zeros(groups, h.shape[1], dtype=h.dtype, device=h.device) + out = out.scatter_reduce_(0, index, h, reduce="mean", include_self=False) + return out + + h = torch.tensor( + [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0], [7.0, 70.0], [2.0, 20.0], [4.0, 40.0]], + dtype=torch.float32, + ) + batch = torch.tensor([0, 0, 1, 1, 2, 2], dtype=torch.int64) + onnx_program = torch.onnx.export(ScatterMeanModel(), (h, batch), dynamo=True) + _testing.assert_onnx_program(onnx_program) + + def test_scatter_reduce_mean_include_self_true(self): + """Test scatter_reduce with reduce='mean' and include_self=True.""" + + class ScatterMeanIncludeSelfModel(torch.nn.Module): + def forward(self, x: torch.Tensor, index: torch.Tensor, src: torch.Tensor) -> torch.Tensor: + x = x.clone() + return x.scatter_reduce(0, index, src, reduce="mean", include_self=True) + + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float32) + index = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64) + src = torch.tensor([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]], dtype=torch.float32) + onnx_program = torch.onnx.export( + ScatterMeanIncludeSelfModel(), (x, index, src), dynamo=True + ) + _testing.assert_onnx_program(onnx_program) + def test_pow_tensor_scalar_int_float(self): class PowModel(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index a40535f4ba..b67297e33a 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1841,7 +1841,6 @@ def _where_input_wrangler( core_ops.aten_scatter_reduce, input_wrangler=_scatter_reduce_input_wrangler, ) - .xfail(variant_name="mean", reason="ONNX doesn't support reduce='mean' option") .xfail( variant_name="prod", dtypes=(torch.float16, torch.float64),