diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 67de7076fa..6b33cace86 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8691,14 +8691,6 @@ def aten_scatter_reduce( ): """scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor""" - reduce_mode = { # convert torch string name to onnx string name - "mean": "none", # 'mean' doesn't support in ONNX 1.14 definition - "sum": "add", - "prod": "mul", - "amin": "min", - "amax": "max", - } - onnx_reduce = reduce_mode[reduce] dtype = src.dtype or self.dtype assert dtype is not None, "dtype should be not None" @@ -8709,13 +8701,59 @@ def aten_scatter_reduce( index = op.Reshape(index, neg_1) src = op.Reshape(src, neg_1) + if reduce == "mean": + # ONNX ScatterElements does not support "mean" reduction. + # Implement mean as: mean = sum(src) / count(src), with optional include_self. 
+        zero_val = ir.tensor([0], dtype=dtype)
+        one_val = ir.tensor([1], dtype=dtype)
+        # Scatter sum of src values onto a zeros tensor
+        scatter_sum = op.ScatterElements(
+            op.ConstantOfShape(op.Shape(self), value=zero_val),
+            index,
+            src,
+            axis=dim,
+            reduction="add",
+        )
+        # Scatter count of src contributions onto a zeros tensor
+        scatter_count = op.ScatterElements(
+            op.ConstantOfShape(op.Shape(self), value=zero_val),
+            index,
+            op.ConstantOfShape(op.Shape(src), value=one_val),
+            axis=dim,
+            reduction="add",
+        )
+        if include_self:
+            # Include self in both sum and count
+            total_sum = op.Add(self, scatter_sum)
+            total_count = op.Add(
+                op.ConstantOfShape(op.Shape(self), value=one_val), scatter_count
+            )
+        else:
+            # PyTorch keeps self's original value at positions that receive no
+            # src contribution, so substitute self (with count 1) there instead
+            # of dividing a zero sum; this also avoids division by zero.
+            no_src = op.Equal(scatter_count, op.ConstantOfShape(op.Shape(self), value=zero_val))
+            total_sum = op.Where(no_src, self, scatter_sum)
+            total_count = op.Where(no_src, op.ConstantOfShape(op.Shape(self), value=one_val), scatter_count)
+        result = op.Div(total_sum, total_count)
+        if self_is_scalar:
+            result = op.Squeeze(result)
+        return result
+
+    reduce_mode = {  # convert torch string name to onnx string name
+        "sum": "add",
+        "prod": "mul",
+        "amin": "min",
+        "amax": "max",
+    }
+    onnx_reduce = reduce_mode[reduce]
+
     if not include_self:
         # onnx standard always assume the value from self is part of the reduction.
         # A first step is added to replace the impacted value by another one
         # chosen in a way that the results of the reduction is not changed
         # whether or not it takes part in it.
         # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul.
-        # mean is not supported.
if onnx_reduce == "max": if dtype in { ir.DataType.FLOAT16, diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 019e6f7fe5..eae1a631e8 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -39,6 +39,41 @@ def forward(self, x, indices, updates): onnx_program = torch.onnx.export(model, xs, dynamo=True) _testing.assert_onnx_program(onnx_program) + def test_scatter_reduce_mean_include_self_false(self): + """Test scatter_reduce with reduce='mean' and include_self=False (GitHub issue).""" + + class ScatterMeanModel(torch.nn.Module): + def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: + index = batch.unsqueeze(1).repeat(1, h.shape[1]) + groups = batch.max().int() + 1 + out = torch.zeros(groups, h.shape[1], dtype=h.dtype, device=h.device) + out = out.scatter_reduce_(0, index, h, reduce="mean", include_self=False) + return out + + h = torch.tensor( + [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0], [7.0, 70.0], [2.0, 20.0], [4.0, 40.0]], + dtype=torch.float32, + ) + batch = torch.tensor([0, 0, 1, 1, 2, 2], dtype=torch.int64) + onnx_program = torch.onnx.export(ScatterMeanModel(), (h, batch), dynamo=True) + _testing.assert_onnx_program(onnx_program) + + def test_scatter_reduce_mean_include_self_true(self): + """Test scatter_reduce with reduce='mean' and include_self=True.""" + + class ScatterMeanIncludeSelfModel(torch.nn.Module): + def forward(self, x: torch.Tensor, index: torch.Tensor, src: torch.Tensor) -> torch.Tensor: + x = x.clone() + return x.scatter_reduce(0, index, src, reduce="mean", include_self=True) + + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float32) + index = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64) + src = torch.tensor([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]], dtype=torch.float32) + onnx_program = torch.onnx.export( + ScatterMeanIncludeSelfModel(), (x, index, src), dynamo=True + 
) + _testing.assert_onnx_program(onnx_program) + def test_pow_tensor_scalar_int_float(self): class PowModel(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index a40535f4ba..b67297e33a 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1841,7 +1841,6 @@ def _where_input_wrangler( core_ops.aten_scatter_reduce, input_wrangler=_scatter_reduce_input_wrangler, ) - .xfail(variant_name="mean", reason="ONNX doesn't support reduce='mean' option") .xfail( variant_name="prod", dtypes=(torch.float16, torch.float64),