
Fix scatter_reduce(reduce="mean") producing incorrect ONNX export results #2829

Draft · Copilot wants to merge 2 commits into main from copilot/fix-scatter-reduce-onnx-export

Conversation

Contributor

Copilot AI commented Feb 25, 2026

aten::scatter_reduce.two with reduce="mean" silently produced wrong results after ONNX export because the implementation mapped "mean" → ONNX reduction="none" (a plain overwrite), with a comment acknowledging it was unsupported but no correct fallback.

Root cause

ScatterElements in ONNX has no native mean reduction. The previous code fell through to reduction="none", which simply overwrites each output position with the last scattered value instead of averaging.
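The last-write-wins behavior is easy to reproduce side by side with eager PyTorch. Below is a minimal illustration; `scatter_elements_none` is a hypothetical pure-Python emulation of ONNX `ScatterElements(reduction="none")`, not code from this repository:

```python
import numpy as np
import torch

def scatter_elements_none(out, index, src):
    # Emulates ONNX ScatterElements with reduction="none" along axis 0:
    # each src element is written to out[index[i, j], j]; when several
    # src rows target the same output row, the last write wins.
    out = out.copy()
    for i in range(src.shape[0]):
        for j in range(src.shape[1]):
            out[index[i, j], j] = src[i, j]
    return out

h = torch.tensor([[1.0, 10.0], [3.0, 30.0]])
index = torch.tensor([[0, 0], [0, 0]])  # both rows scatter into row 0

# Eager PyTorch averages the contributions: (1+3)/2 and (10+30)/2.
pt = torch.zeros(1, 2).scatter_reduce_(0, index, h, reduce="mean", include_self=False)
# pt -> [[2., 20.]]

# The old export path just overwrote, keeping only the last row.
onnx_like = scatter_elements_none(np.zeros((1, 2)), index.numpy(), h.numpy())
# onnx_like -> [[3., 30.]]
```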

Fix

Implements mean explicitly as sum / count using two ScatterElements(reduction="add") passes:

  • scatter_sum: accumulates src values at each output index
  • scatter_count: accumulates ones to get the number of contributions per position
  • include_self=True: adds self to both sum (+self) and count (+1)
  • include_self=False: uses max(count, 1) to avoid 0/0 — positions with no incoming values have sum=0 so 0/1=0 matches PyTorch semantics
# Before fix
h = [[1,10],[3,30],[5,50],[7,70],[2,20],[4,40]]
batch = [0,0,1,1,2,2]
# PyTorch:      [[2,20],[6,60],[3,30]]
# ONNX (wrong): [[3,30],[7,70],[4,40]]  ← last-write wins

# After fix
# ONNX (fixed): [[2,20],[6,60],[3,30]]  ← correct mean
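The sum/count decomposition above can be sketched in eager PyTorch. This is a simplified illustration of the idea (it assumes a zero-initialized `self`, as in the repro, and `scatter_mean_via_sum_count` is a hypothetical helper name, not the exporter's actual code):

```python
import torch

def scatter_mean_via_sum_count(self_t, dim, index, src, include_self):
    # Two "add"-reduction scatters, mirroring the two ScatterElements passes.
    total = torch.zeros_like(self_t).scatter_add_(dim, index, src)
    count = torch.zeros_like(self_t).scatter_add_(dim, index, torch.ones_like(src))
    if include_self:
        total = total + self_t  # self contributes its value to the sum...
        count = count + 1       # ...and one extra contribution to the count
    else:
        # Positions that receive nothing have sum=0, so 0/1 = 0 here.
        count = count.clamp(min=1)
    return total / count

h = torch.tensor([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0],
                  [7.0, 70.0], [2.0, 20.0], [4.0, 40.0]])
batch = torch.tensor([0, 0, 1, 1, 2, 2])
index = batch.unsqueeze(1).repeat(1, 2)
out = scatter_mean_via_sum_count(torch.zeros(3, 2), 0, index, h, include_self=False)
# out -> [[2., 20.], [6., 60.], [3., 30.]], matching eager PyTorch
```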

Test changes

  • Removed xfail(variant_name="mean") from ops_test_data.py — this reduction mode is now correctly supported
  • Added two e2e tests covering include_self=False and include_self=True
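For reference, the two `include_self` modes differ like this in eager PyTorch (a small illustrative check, not the tests added by this PR):

```python
import torch

self_t = torch.tensor([[2.0, 2.0]])
src = torch.tensor([[4.0, 8.0]])
index = torch.tensor([[0, 0]])  # both columns scatter into row 0

# include_self=True: self participates in the mean -> (2+4)/2 and (2+8)/2
with_self = self_t.clone().scatter_reduce_(
    0, index, src, reduce="mean", include_self=True)
# with_self -> [[3., 5.]]

# include_self=False: only the scattered values are averaged -> 4 and 8
without_self = self_t.clone().scatter_reduce_(
    0, index, src, reduce="mean", include_self=False)
# without_self -> [[4., 8.]]
```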
Original prompt

This section details the original issue to resolve

<issue_title>torch.export + torch.onnx.export(dynamo=True) gives incorrect results for scatter_reduce_(reduce="mean")</issue_title>
<issue_description>### 🐛 Describe the bug

Observed behaviour:

  • scatter_reduce(mean) shows a large mismatch after export:
    • max_abs_diff: 10.0
    • mean_abs_diff: 5.5
  • Equivalent sum/count control matches exactly:
    • max_abs_diff: 0.0
    • mean_abs_diff: 0.0

Expected behaviour:

  • ONNX output should match eager PyTorch semantics (within normal floating-point tolerance) for:
    • scatter_reduce_(reduce="mean", include_self=False)

Impact:

  • Silent numerical correctness issue (wrong predictions without a crash).

Environment:

  • torch: 2.7.1
  • onnxruntime: 1.24.1
  • python: 3.12
  • os: macOS (darwin 24.6.0)

Code example:

import numpy as np
import onnxruntime as ort
import torch


class ScatterMeanModel(torch.nn.Module):
    def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        # h: [N, F], batch: [N] with group IDs in [0, G-1]
        index = batch.unsqueeze(1).repeat(1, h.shape[1])
        groups = batch.max().int() + 1
        out = torch.zeros(groups, h.shape[1], dtype=h.dtype, device=h.device)
        out = out.scatter_reduce_(0, index, h, reduce="mean", include_self=False)
        return out


class ScatterSumDivCountModel(torch.nn.Module):
    # Mathematically equivalent grouped mean = sum / count.
    def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        index = batch.unsqueeze(1).repeat(1, h.shape[1])
        groups = batch.max().int() + 1

        sums = torch.zeros(groups, h.shape[1], dtype=h.dtype, device=h.device)
        sums = sums.scatter_reduce_(0, index, h, reduce="sum", include_self=False)

        ones = torch.ones(h.shape[0], 1, dtype=h.dtype, device=h.device)
        counts = torch.zeros(groups, 1, dtype=h.dtype, device=h.device)
        counts = counts.scatter_reduce_(
            0, batch.unsqueeze(1), ones, reduce="sum", include_self=False
        )
        return sums / counts


def run(model: torch.nn.Module) -> tuple[np.ndarray, np.ndarray, float, float]:
    model.eval()
    h = torch.tensor(
        [
            [1.0, 10.0],
            [3.0, 30.0],
            [5.0, 50.0],
            [7.0, 70.0],
            [2.0, 20.0],
            [4.0, 40.0],
        ],
        dtype=torch.float32,
    )
    batch = torch.tensor([0, 0, 1, 1, 2, 2], dtype=torch.int64)

    with torch.inference_mode():
        pt = model(h, batch).cpu().numpy()

    exported = torch.export.export(model, (h, batch), strict=False)
    onnx_program = torch.onnx.export(exported, f=None, dynamo=True)

    sess = ort.InferenceSession(
        onnx_program.model_proto.SerializeToString(),
        providers=["CPUExecutionProvider"],
    )
    input_names = [i.name for i in sess.get_inputs()]
    ort_out = sess.run(
        None, {input_names[0]: h.numpy(), input_names[1]: batch.numpy()}
    )[0]

    diff = np.abs(pt - ort_out)
    return pt, ort_out, float(diff.max()), float(diff.mean())


print("torch:", torch.__version__)
print("onnxruntime:", ort.__version__)

pt, ort_out, max_abs, mean_abs = run(ScatterMeanModel())
print("\n=== scatter_reduce(mean) ===")
print("PyTorch output:\n", pt)
print("ONNX Runtime output:\n", ort_out)
print("max_abs_diff:", max_abs)
print("mean_abs_diff:", mean_abs)

pt2, ort_out2, max_abs2, mean_abs2 = run(ScatterSumDivCountModel())
print("\n=== sum/count control ===")
print("PyTorch output:\n", pt2)
print("ONNX Runtime output:\n", ort_out2)
print("max_abs_diff:", max_abs2)
print("mean_abs_diff:", mean_abs2)

Example output:

python tmp/repro_pytorch_scatter_reduce_mean_onnx.py
torch: 2.7.1
onnxruntime: 1.24.1
W0225 11:07:22.309000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::nms
W0225 11:07:22.310000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::roi_align
W0225 11:07:22.310000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::roi_pool
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅

=== scatter_reduce(mean) ===
PyTorch output:
 [[ 2. 20.]
 [ 6. 60.]
 [ 3. 30.]]
ONNX Runtime output:
 [[ 3. 30.]
 [ 7. 70.]
 [ 4. 40.]]
max_abs_diff: 10.0
mean_abs_diff: 5.5
W0225 11:07:22.633000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::nms
W0225 11:07:22.634000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::roi_align
W0225 11:07:22.634000 94867 torch/onnx/_internal/exporter/_registration.py:103] torchvision is not installed. Skipping torchvision::roi_pool
[tor...

</details>




- Fixes pytorch/pytorch#175738


ONNX ScatterElements doesn't support 'mean' reduction. Previously the
code mapped 'mean' to 'none' (no reduction), which just overwrote values
instead of computing the mean.

The fix implements mean as sum/count:
- scatter_sum: ScatterElements with reduction='add' onto zeros
- scatter_count: ScatterElements of ones with reduction='add' onto zeros
- For include_self=True: add self to sum and 1 to count
- For include_self=False: use max(count, 1) to avoid div-by-zero
  (positions with count=0 also have sum=0, so 0/1=0 is correct)

Also removes the xfail for scatter_reduce mean in ops_test_data.py and
adds e2e tests for both include_self=True and include_self=False cases.

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Copilot AI changed the title [WIP] Fix incorrect results for scatter_reduce with ONNX export Fix scatter_reduce(reduce="mean") producing incorrect ONNX export results Feb 25, 2026
Copilot AI requested a review from justinchuby February 25, 2026 15:17

codecov bot commented Feb 25, 2026

❌ 4 Tests Failed:

Tests completed: 12741 · Failed: 4 · Passed: 12737 · Skipped: 939
View the top 3 failed test(s) by shortest run time
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__scatter_reduce_mean_cpu_float32
Stack Traces | 0.356s run time
.../function_libs/torch_lib/ops_test.py:243: in run_test_output_match
    torch.testing.assert_close(
E   AssertionError: Tensor-likes are not close!
E   
E   Mismatched elements: 30 / 50 (60.0%)
E   Greatest absolute difference: 8.478373527526855 at index (0, 0) (up to 1e-05 allowed)
E   Greatest relative difference: 1.0 at index (0, 0) (up to 1.3e-06 allowed)
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__scatter_reduce_mean_cpu_int32
Stack Traces | 0.826s run time
.../function_libs/torch_lib/ops_test.py:243: in run_test_output_match
    torch.testing.assert_close(
E   AssertionError: Tensor-likes are not equal!
E   
E   Mismatched elements: 30 / 50 (60.0%)
E   Greatest absolute difference: 9 at index (0, 2)
E   Greatest relative difference: 1.0 at index (0, 0)
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__scatter_reduce_mean_cpu_float16
Stack Traces | 2.28s run time
.../function_libs/torch_lib/ops_test_common.py:593: in _capture_graph_and_evaluate_torch_script_evaluator
    return _ort_session_run(model_proto.SerializeToString(), ort_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../function_libs/torch_lib/ops_test_common.py:349: in _ort_session_run
    return session.run(None, ort_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
..../test_onnx_ir_git/lib/python3.11.../onnxruntime/capi/onnxruntime_inference_collection.py:275: in run
    return self._sess.run(output_names, input_feed, run_options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running ScatterElements node. Name:'node_ScatterElements_11' Status Message: CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'add'.

The above exception was the direct cause of the following exception:
.../function_libs/torch_lib/ops_test.py:206: in run_test_output_match
    function_output = function_executor(test_name, reference_torch_outputs)(
.../function_libs/torch_lib/ops_test_common.py:607: in _capture_graph_and_evaluate_torch_script_evaluator
    raise RuntimeError(
E   RuntimeError: ONNX Runtime failed to evaluate:
E   Inputs:
E   {'input_0': array(-1.837, dtype=float16),
E    'input_2': array(0),
E    'input_3': array(-2.645, dtype=float16)}
E   Model:
E   <
E      ir_version: 10,
E      opset_import: ["" : 18, "pkg.torch.onnx" : 1, "pkg.onnxscript.torch_lib.common" : 1, "pkg.onnxscript.torch_lib" : 1],
E      producer_name: "torch_test"
E   >
E   main_graph (float16 input_0, int64 input_2, float16 input_3) => (float16 val_16) 
E      <int64[1] val_0, float16[1] val_1, int64[1] val_2, float16[1] val_3, int64[1] val_4, float16[1] val_5, float16[1] val_6, int64[1] val_7, float16[1] val_8, int64[1] val_9, float16[1] val_10, float16[1] val_11, int64[1] val_12, float16[1] val_13, float16[1] val_14, float16[1] val_15>
E   {
E      val_0 = Constant <value_ints: ints = [-1]> ()
E      val_1 = Reshape <allowzero: int = 0> (input_0, val_0)
E      val_2 = Reshape <allowzero: int = 0> (input_2, val_0)
E      val_3 = Reshape <allowzero: int = 0> (input_3, val_0)
E      val_4 = Shape <start: int = 0> (val_1)
E      val_5 = ConstantOfShape <value: tensor = float16[1]...> (val_4)
E      val_6 = ScatterElements <axis: int = 0, reduction: string = "add"> (val_5, val_2, val_3)
E      val_7 = Shape <start: int = 0> (val_1)
E      val_8 = ConstantOfShape <value: tensor = float16[1]...> (val_7)
E      val_9 = Shape <start: int = 0> (val_3)
E      val_10 = ConstantOfShape <value: tensor = float16[1]...> (val_9)
E      val_11 = ScatterElements <axis: int = 0, reduction: string = "add"> (val_8, val_2, val_10)
E      val_12 = Shape <start: int = 0> (val_11)
E      val_13 = ConstantOfShape <value: tensor = float16[1]...> (val_12)
E      val_14 = Max (val_11, val_13)
E      val_15 = Div (val_6, val_14)
E      val_16 = Squeeze (val_15)
E   }
View the full list of 1 ❄️ flaky test(s)
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__scatter_reduce_mean_cpu_int64

Flake rate in main: 8.33% (Passed 11 times, Failed 1 times)

Stack Traces | 0.495s run time
.../function_libs/torch_lib/ops_test.py:243: in run_test_output_match
    torch.testing.assert_close(
E   AssertionError: Tensor-likes are not equal!
E   
E   Mismatched elements: 30 / 50 (60.0%)
E   Greatest absolute difference: 9 at index (0, 2)
E   Greatest relative difference: 1.0 at index (0, 0)

