Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8691,14 +8691,6 @@ def aten_scatter_reduce(
):
"""scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor"""

reduce_mode = { # convert torch string name to onnx string name
"mean": "none", # 'mean' doesn't support in ONNX 1.14 definition
"sum": "add",
"prod": "mul",
"amin": "min",
"amax": "max",
}
onnx_reduce = reduce_mode[reduce]
dtype = src.dtype or self.dtype
assert dtype is not None, "dtype should be not None"

Expand All @@ -8709,13 +8701,59 @@ def aten_scatter_reduce(
index = op.Reshape(index, neg_1)
src = op.Reshape(src, neg_1)

if reduce == "mean":
# ONNX ScatterElements does not support "mean" reduction.
# Implement mean as: mean = sum(src) / count(src), with optional include_self.
zero_val = ir.tensor([0], dtype=dtype)
one_val = ir.tensor([1], dtype=dtype)
# Scatter sum of src values onto a zeros tensor
scatter_sum = op.ScatterElements(
op.ConstantOfShape(op.Shape(self), value=zero_val),
index,
src,
axis=dim,
reduction="add",
)
# Scatter count of src contributions onto a zeros tensor
scatter_count = op.ScatterElements(
op.ConstantOfShape(op.Shape(self), value=zero_val),
index,
op.ConstantOfShape(op.Shape(src), value=one_val),
axis=dim,
reduction="add",
)
if include_self:
# Include self in both sum and count
total_sum = op.Add(self, scatter_sum)
total_count = op.Add(
op.ConstantOfShape(op.Shape(self), value=one_val), scatter_count
)
else:
total_sum = scatter_sum
# Avoid division by zero: where count == 0, sum is also 0, so 0/1 = 0 is correct
total_count = op.Max(
scatter_count,
op.ConstantOfShape(op.Shape(scatter_count), value=one_val),
)
result = op.Div(total_sum, total_count)
if self_is_scalar:
result = op.Squeeze(result)
return result

reduce_mode = { # convert torch string name to onnx string name
"sum": "add",
"prod": "mul",
"amin": "min",
"amax": "max",
}
onnx_reduce = reduce_mode[reduce]

if not include_self:
# onnx standard always assume the value from self is part of the reduction.
# A first step is added to replace the impacted value by another one
# chosen in a way that the results of the reduction is not changed
# whether or not it takes part in it.
# It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul.
# mean is not supported.
if onnx_reduce == "max":
if dtype in {
ir.DataType.FLOAT16,
Expand Down
35 changes: 35 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,41 @@ def forward(self, x, indices, updates):
onnx_program = torch.onnx.export(model, xs, dynamo=True)
_testing.assert_onnx_program(onnx_program)

def test_scatter_reduce_mean_include_self_false(self):
    """Export scatter_reduce_ with reduce='mean', include_self=False (GitHub issue repro)."""

    class SegmentMeanModel(torch.nn.Module):
        def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
            # Broadcast the 1-D segment ids across the feature dimension so
            # every feature column scatters with the same row index.
            gather_index = batch.unsqueeze(1).repeat(1, h.shape[1])
            num_segments = batch.max().int() + 1
            accumulator = torch.zeros(
                num_segments, h.shape[1], dtype=h.dtype, device=h.device
            )
            # include_self=False: the zeros in `accumulator` must not
            # contribute to the per-segment mean.
            return accumulator.scatter_reduce_(
                0, gather_index, h, reduce="mean", include_self=False
            )

    features = torch.tensor(
        [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0], [7.0, 70.0], [2.0, 20.0], [4.0, 40.0]],
        dtype=torch.float32,
    )
    segment_ids = torch.tensor([0, 0, 1, 1, 2, 2], dtype=torch.int64)
    onnx_program = torch.onnx.export(
        SegmentMeanModel(), (features, segment_ids), dynamo=True
    )
    _testing.assert_onnx_program(onnx_program)

def test_scatter_reduce_mean_include_self_true(self):
    """Export scatter_reduce with reduce='mean', include_self=True."""

    class MeanWithSelfModel(torch.nn.Module):
        def forward(
            self, x: torch.Tensor, index: torch.Tensor, src: torch.Tensor
        ) -> torch.Tensor:
            # Clone so the reduction does not mutate the export input;
            # include_self=True folds the original values into the mean.
            base = x.clone()
            return base.scatter_reduce(0, index, src, reduce="mean", include_self=True)

    x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float32)
    index = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
    src = torch.tensor([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]], dtype=torch.float32)
    onnx_program = torch.onnx.export(
        MeanWithSelfModel(), (x, index, src), dynamo=True
    )
    _testing.assert_onnx_program(onnx_program)

def test_pow_tensor_scalar_int_float(self):
class PowModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down
1 change: 0 additions & 1 deletion tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,7 +1841,6 @@ def _where_input_wrangler(
core_ops.aten_scatter_reduce,
input_wrangler=_scatter_reduce_input_wrangler,
)
.xfail(variant_name="mean", reason="ONNX doesn't support reduce='mean' option")
.xfail(
variant_name="prod",
dtypes=(torch.float16, torch.float64),
Expand Down
Loading