
[Bug] ONNX Softmax/LogSoftmax/Hardmax frontend ignores opset <13 semantics #19420

@wuyii8941

Description

The ONNX frontend converter for Softmax, LogSoftmax, and Hardmax only implements _impl_v13. It does not implement _impl_v1 or _impl_v11, which have different semantics from v13.

This causes silently wrong results when importing any ONNX model with opset version ≤12.

ONNX spec difference

  • Opset < 13: the input is coerced to 2D. Dimensions before `axis` (default 1) are collapsed into the first dimension and the remaining dimensions into the second; softmax is computed over that second dimension, and the result is reshaped back to the original shape.
  • Opset ≥ 13: softmax is computed directly along `axis` (default -1), with no flattening.

For a 3D input (2,3,4) with default axis, opset 11 computes softmax over the flattened last 2 dims (12 elements), while opset 13 computes softmax over the last dim (4 elements). These produce very different results.
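The two semantics can be checked directly in NumPy. This is a reference sketch of both behaviors (my own helper names, not ONNX or TVM code); for a (2, 3, 4) input with default axes, the opset-11 result normalizes each 12-element block while the opset-13 result normalizes each length-4 slice:

```python
import numpy as np

def softmax_v11(x, axis=1):
    # Opset <13 reference: flatten to 2D at `axis`, normalize each row,
    # then restore the original shape.
    rows = int(np.prod(x.shape[:axis]))
    flat = x.reshape(rows, -1)
    e = np.exp(flat - flat.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)).reshape(x.shape)

def softmax_v13(x, axis=-1):
    # Opset >=13 reference: softmax directly along `axis`, no flattening.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.randn(2, 3, 4).astype(np.float32)
v11, v13 = softmax_v11(x), softmax_v13(x)
print(v11.reshape(2, -1).sum(axis=-1))  # each 12-element block sums to 1
print(v13.sum(axis=-1))                 # each length-4 slice sums to 1
print(np.max(np.abs(v11 - v13)))        # clearly nonzero
```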

Reproduction

import numpy as np
import onnx
from onnx import helper, TensorProto
import onnxruntime as ort

shape = (2, 3, 4)
X = np.random.randn(*shape).astype(np.float32)

X_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, list(shape))
Y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(shape))
node = helper.make_node("Softmax", ["X"], ["Y"])
graph = helper.make_graph([node], "test", [X_info], [Y_info])

model11 = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])

# ORT handles opset 11 correctly
sess = ort.InferenceSession(model11.SerializeToString())
ort_out = sess.run(None, {"X": X})[0]

# TVM treats it as opset 13
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

mod = from_onnx(model11)
pipeline = tvm.ir.transform.Sequential([relax.transform.LegalizeOps()])
exe = tvm.relax.build(pipeline(mod), target="llvm")
vm = tvm.relax.VirtualMachine(exe, device=tvm.cpu())
tvm_out = vm["main"](tvm.runtime.tensor(X, device=tvm.cpu())).numpy()

print(f"max_diff = {np.max(np.abs(tvm_out - ort_out))}")  # ~0.47

Root cause

In onnx_frontend.py, the Softmax class only has _impl_v13:

class Softmax(OnnxOpConverter):
    @classmethod
    def _impl_v13(cls, bb, inputs, attr, params):
        axis = attr.get("axis", -1)
        return relax.op.nn.softmax(inputs[0], axis=axis)

Missing implementations:

  • _impl_v1: axis default is 1, flatten semantics
  • _impl_v11: axis default is 1, flatten semantics

The same issue affects LogSoftmax and Hardmax.
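One possible fix is to add the pre-13 implementations alongside `_impl_v13`, emitting the flatten/softmax/reshape sequence from the spec. The following is an untested sketch, not a patch: the `struct_info.shape` access and static-shape assumption are mine, and dynamic shapes and negative-axis normalization would need extra handling.

```python
class Softmax(OnnxOpConverter):
    @classmethod
    def _impl_v1(cls, bb, inputs, attr, params):
        data = inputs[0]
        axis = attr.get("axis", 1)
        # Assumes static shapes for simplicity.
        shape = [int(d) for d in data.struct_info.shape]
        # Coerce to 2D: dims before `axis` become rows, the rest columns.
        rows = 1
        for d in shape[:axis]:
            rows *= d
        flat = relax.op.reshape(data, (rows, -1))
        out = relax.op.nn.softmax(flat, axis=-1)
        # Restore the original shape.
        return relax.op.reshape(out, shape)

    _impl_v11 = _impl_v1  # opset 11 keeps the flatten semantics

    @classmethod
    def _impl_v13(cls, bb, inputs, attr, params):
        axis = attr.get("axis", -1)
        return relax.op.nn.softmax(inputs[0], axis=axis)
```

LogSoftmax and Hardmax would follow the same pattern with `relax.op.nn.log_softmax` and a hardmax lowering, respectively.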

Environment

  • TVM: v0.23.0 (also unfixed on main branch)
  • Python: 3.11

Metadata

    Labels: needs-triage, type: bug
