[6058907] Fix ShapeInferenceError in ONNX int8+fp16 quantization of weakly-typed models#1627

Open
ajrasane wants to merge 1 commit into main from ajrasane/onnx-stale-shape-typeinfer-fix

@ajrasane
Contributor

@ajrasane ajrasane commented Jun 4, 2026

What does this PR do?

Type of change: Bug fix

ONNX INT8 + FP16 quantization (--quantize_mode int8 --high_precision_dtype fp16) crashed with a ShapeInferenceError on weakly-typed models (e.g. TensorFlow exports). Two root causes, both fixed in modelopt/onnx/utils.py:

  • Stale rank-0 output shapes. Such models can declare a graph.output rank that conflicts with the graph topology — typically a leftover rank-0 (scalar) annotation on a tensor that is really rank-2+. A stale rank-0 passes onnx.checker (a scalar is valid) but poisons downstream shape inference: ORT fails while augmenting the model for INT8 calibration (axis must be in [-rank, rank-1]. Input rank was 0), and onnx.shape_inference(strict_mode=True) raises Inferred shape and existing shape differ in rank during FP16 autocast. clear_stale_value_info now reconciles stale output shapes — it clears and re-derives them from the operator graph via ORT symbolic shape inference (falling back to the size-aware infer_shapes wrapper) and adopts the inferred shape. A graph output is never left without a shape field (which onnx.checker requires); an output whose shape cannot be re-derived keeps its original declaration.

  • Ops ONNX static shape inference can't resolve. The same models can contain ops (e.g. TopK) that ONNX's static shape inference gives up on, leaving downstream tensors untyped and breaking AutoCast's type lookups. infer_types now falls back to the schema-based standalone type inferencer when ONNX shape inference raises or leaves tensors untyped, running the fallback on the shape-inferred model so any shapes ONNX did derive are preserved.

Healthy models are unaffected: re-inference reproduces their existing output shapes, and a fully typed graph skips the fallback.
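
The fallback control flow described for infer_types can be sketched as follows. This is a minimal illustration, not the code in modelopt/onnx/utils.py: the toy model (a dict of tensor name to dtype), infer_with_fallback, and the three pass functions are all hypothetical stand-ins, and starting the fallback from the original model when the primary pass raises is an assumption.

```python
from typing import Callable, Dict, Optional

# Toy stand-in for a model: tensor name -> dtype string, or None if untyped.
ToyModel = Dict[str, Optional[str]]


def infer_with_fallback(
    model: ToyModel,
    primary: Callable[[ToyModel], ToyModel],
    fallback: Callable[[ToyModel], ToyModel],
) -> ToyModel:
    """Run `primary`; use `fallback` if it raises or leaves tensors untyped."""
    try:
        inferred = primary(model)
    except Exception:
        # Primary inference raised: nothing was derived, so start from the input
        # (assumption of this sketch).
        return fallback(model)
    if all(dtype is not None for dtype in inferred.values()):
        # Fully typed graph: the fallback is skipped.
        return inferred
    # Partially typed: run the fallback on the inferred model so that
    # whatever the primary pass did derive is preserved.
    return fallback(inferred)


def strict_pass_that_raises(model: ToyModel) -> ToyModel:
    # Simulates strict inference failing on an unresolvable op.
    raise RuntimeError("ShapeInferenceError (simulated)")


def partial_pass(model: ToyModel) -> ToyModel:
    # Types "x" but gives up on every other tensor.
    return {name: ("float32" if name == "x" else dtype) for name, dtype in model.items()}


def schema_fallback(model: ToyModel) -> ToyModel:
    # Fills only the tensors that are still untyped.
    return {name: dtype or "int64" for name, dtype in model.items()}


weak = {"x": None, "topk_indices": None}
from_raise = infer_with_fallback(weak, strict_pass_that_raises, schema_fallback)
from_partial = infer_with_fallback(weak, partial_pass, schema_fallback)
```

In the partial case the type that the primary pass derived for "x" survives the fallback, which mirrors the stated intent of running the fallback on the shape-inferred model.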

Usage

python -m modelopt.onnx.quantization \
  --quantize_mode int8 --high_precision_dtype fp16 \
  --onnx_path model.onnx --output_path model_int8_fp16.onnx

Testing

  • Added CPU-only unit tests in tests/unit/onnx/test_onnx_utils.py: stale rank-0 output reconciliation, preservation of a valid output shape, and the type-inference fallback on a TopK-overflow model. Run: CUDA_VISIBLE_DEVICES="" pytest tests/unit/onnx/test_onnx_utils.py.
  • Verified end-to-end on a weakly-typed object-detection model that previously failed: the command now completes and produces a valid quantized FP16 model (onnx.checker passes, ORT loads it, QuantizeLinear/DequantizeLinear nodes inserted, FP16 initializers present, all graph-output ranks correct).
  • Full tests/unit/onnx suite: no new failures versus the base branch.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅

Summary by CodeRabbit

  • Bug Fixes

    • Resolved cases where mixed-precision conversion left tensors untyped by reconciling stale output shapes and using stricter inference with a safe fallback, improving model validity for INT8+FP16 conversions.
  • Tests

    • Added unit tests validating output-shape reconciliation and the strict-mode inference fallback, including scenarios with unresolved ops to ensure conversion robustness.

@copy-pr-bot

copy-pr-bot Bot commented Jun 4, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai Bot commented Jun 4, 2026

📝 Walkthrough

Run ONNX shape inference in strict mode with a try/except fallback to standalone type inference; add output-shape reconciliation to fix stale output declarations; integrate strict-mode inference into AutoCast conversions; add unit tests and a changelog entry.

Changes

Shape inference robustness and reconciliation

| Layer / File(s) | Summary |
|---|---|
| Core shape/type inference refactoring (modelopt/onnx/utils.py) | infer_types() now runs ONNX inference inside try/except and falls back to _infer_types_only(model) on errors; added _reconcile_stale_output_shapes() to re-derive and patch stale graph.output shapes; clear_stale_value_info() includes fixed shape count in its return. |
| AutoCast strict-mode integration (modelopt/onnx/autocast/convert.py) | convert_to_mixed_precision and convert_to_f16 call infer_types(..., strict_mode=True) to ensure strict ONNX inference with fallback. |
| Test coverage for inference and reconciliation (tests/unit/onnx/test_onnx_utils.py, tests/unit/onnx/autocast/test_autocast.py) | Added tests for reconciling stale rank-0 outputs to correct ranks, preserving valid shapes and dim_param names, and verifying fallback to standalone inference when ONNX shape inference fails (TopK overflow). Also added an autocast test exercising the fallback path. |
| Changelog documentation (CHANGELOG.rst) | Adds a bugfix note describing the ShapeInferenceError fix for ONNX INT8+FP16 quantization on weakly-typed models and the reconciliation/strict-mode inference changes. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • cjluo-nv
  • galagam
  • gcunhase
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 60.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (5 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately reflects the main purpose of the PR: fixing ShapeInferenceError crashes during ONNX INT8 + FP16 quantization on weakly-typed models by addressing stale rank-0 output shapes and type inference fallbacks. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | No security anti-patterns detected. Code avoids unsafe deserialization, hardcoded trust settings, and eval/exec patterns. All imports are from approved dependencies. |

@github-actions
Contributor

github-actions Bot commented Jun 4, 2026

PR Preview Action v1.8.1

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1627/

Built to branch gh-pages at 2026-06-04 21:58 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov Bot commented Jun 4, 2026

Codecov Report

❌ Patch coverage is 79.24528% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 77.21%. Comparing base (ca7eb64) to head (cf2d65a).
⚠️ Report is 3 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| modelopt/onnx/utils.py | 78.43% | 11 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1627      +/-   ##
==========================================
- Coverage   77.39%   77.21%   -0.18%     
==========================================
  Files         482      482              
  Lines       52960    53101     +141     
==========================================
+ Hits        40986    41001      +15     
- Misses      11974    12100     +126     
| Flag | Coverage Δ | |
|---|---|---|
| examples | 42.86% <5.66%> (+0.86%) | ⬆️ |
| gpu | 59.43% <67.92%> (-0.58%) | ⬇️ |
| unit | 53.95% <79.24%> (+0.03%) | ⬆️ |

@ajrasane ajrasane force-pushed the ajrasane/onnx-stale-shape-typeinfer-fix branch 2 times, most recently from 6d53f4d to e6576a6 Compare June 4, 2026 19:40
Collaborator

@cjluo-nv cjluo-nv left a comment

Bot review — DM the bot to share feedback.

Targeted bug fix with three new tests. Two concerns worth a human look before merging:

  1. Dynamic-dim collateral damage in _reconcile_stale_output_shapes. The helper runs unconditionally on every model that goes through _preprocess_onnx → clear_stale_value_info (not just weakly-typed exports), and it calls ORT's SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=True). With those flags, ORT routinely rewrites user-named dim_params (e.g. batch, seq_len) into internal unk__N symbols, or commits a guessed rank when it can't fully propagate. The PR description says "re-inference reproduces their existing output shapes" and "healthy models are unaffected", but the bytes-level _sig comparison will treat a batch → unk__0 rename as a change and overwrite the declared shape, which can break downstream consumers that look up output dims by name (export pipelines, TRT profile shapes, etc.). The added tests only cover fully static shapes ([3,4] / [3,5] / [1,3]), so this case is untested. Worth either (a) skipping the overwrite when only dim_param names differ but ranks/values agree, (b) only invoking the reconciler when the declared shape is actually inconsistent with the topology (the original motivation), or (c) at minimum adding a dynamic-batch regression test.

  2. strict_mode=True is now the default for infer_types callers in autocast/convert.py. The fallback to _infer_types_only on any exception is a behavior broadening — previously-tolerated weakly-typed graphs now silently route through the standalone path instead. The standalone inferencer is a reasonable fallback, but this is a change in the AutoCast hot path with no new AutoCast-level test (the new test_infer_types_falls_back_to_standalone_when_onnx_fails exercises infer_types directly, not convert_to_mixed_precision / convert_to_f16). Worth confirming the existing AutoCast test suite still passes against a representative weakly-typed model and not just the synthetic TopK fixture.

Fix itself is well-scoped and the docstrings/comments are clear; flagging mainly so an ONNX-pipeline owner can sign off on the dynamic-shape behavior change.
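
To make concern 1 concrete, here is a small sketch of why an exact, bytes-level comparison flags a pure rename, and how option (a) could relax it. The shapes-as-lists representation and both signature helpers are hypothetical stand-ins for illustration; they are not the PR's _sig helper.

```python
from typing import List, Union

Dim = Union[int, str]  # int = concrete dim_value, str = symbolic dim_param


def exact_signature(shape: List[Dim]) -> bytes:
    # Byte-for-byte signature: any difference, including a renamed symbol, shows up.
    return repr(tuple(shape)).encode()


def name_insensitive_signature(shape: List[Dim]) -> bytes:
    # Replace every symbolic name with a placeholder so that only the rank
    # and the concrete dims contribute to the signature.
    return repr(tuple("?" if isinstance(d, str) else d for d in shape)).encode()


declared: List[Dim] = ["batch", 128]     # user-named dynamic dim
rederived: List[Dim] = ["unk__0", 128]   # same shape, symbol renamed by inference
stale: List[Dim] = []                    # leftover rank-0 declaration

renamed_flagged_exact = exact_signature(declared) != exact_signature(rederived)
renamed_flagged_relaxed = (
    name_insensitive_signature(declared) != name_insensitive_signature(rederived)
)
stale_flagged_relaxed = (
    name_insensitive_signature(stale) != name_insensitive_signature(rederived)
)
```

Under the relaxed signature the rename is no longer treated as a change, while the stale rank-0 declaration is still detected.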

@ajrasane ajrasane marked this pull request as ready for review June 4, 2026 21:21
@ajrasane ajrasane requested review from a team as code owners June 4, 2026 21:22
@ajrasane ajrasane requested a review from galagam June 4, 2026 21:22
@ajrasane ajrasane force-pushed the ajrasane/onnx-stale-shape-typeinfer-fix branch from e6576a6 to c7dc86b Compare June 4, 2026 21:44
@ajrasane
Contributor Author

ajrasane commented Jun 4, 2026

Thanks for the careful review — addressed both in the latest revision:

1. Dynamic-dim collateral damage. The reconciler now overwrites a declared output shape only when it's genuinely stale (a rank mismatch or a conflicting concrete dim); outputs that merely differ in symbolic dim_param names (e.g. a re-derived unk__0 vs a declared batch) keep their declaration, so healthy models with dynamic dims are left untouched. I also dropped guess_output_rank=True — it's unnecessary for the fix (symbolic inference still derives the affected output ranks without it) and was the source of the guessed-rank risk you noted. Added test_clear_stale_value_info_preserves_dynamic_dim_names as a regression test (declares ["my_batch", 4] where inference derives ["batch", 4] and asserts the declared dim_param survives).

2. strict_mode=True / AutoCast coverage. Added test_convert_to_f16_falls_back_on_unresolvable_op, which exercises the full convert_to_f16 path (used by INT8 + --high_precision_dtype fp16) on a weakly-typed model whose TopK defeats strict ONNX shape inference, asserting it converts via the standalone type-inference fallback rather than crashing. The existing AutoCast suite (tests/unit/onnx/autocast/) still passes, and the original weakly-typed model is verified end-to-end.
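
For readers following along, the reconciliation policy from point 1 can be sketched roughly as below. This is an illustration under stated assumptions, not the code in modelopt/onnx/utils.py: shapes are modeled as plain Python lists, and is_stale and reconcile_outputs are hypothetical names chosen for the sketch.

```python
from typing import Dict, List, Optional, Union

Dim = Union[int, str]  # int = concrete dim_value, str = symbolic dim_param
Shape = List[Dim]


def is_stale(declared: Shape, inferred: Shape) -> bool:
    """A declaration is stale on a rank mismatch or a conflicting concrete dim."""
    if len(declared) != len(inferred):
        return True
    # Only compare positions where both sides are concrete; symbolic names
    # (e.g. "my_batch" vs "batch") never make a declaration stale.
    return any(
        isinstance(d, int) and isinstance(i, int) and d != i
        for d, i in zip(declared, inferred)
    )


def reconcile_outputs(
    declared: Dict[str, Shape], inferred: Dict[str, Optional[Shape]]
) -> Dict[str, Shape]:
    """Overwrite only stale declarations; every output always keeps a shape."""
    result: Dict[str, Shape] = {}
    for name, shape in declared.items():
        new_shape = inferred.get(name)
        if new_shape is not None and is_stale(shape, new_shape):
            result[name] = list(new_shape)  # adopt the re-derived shape
        else:
            # Not stale, or inference could not re-derive it: keep the declaration.
            result[name] = list(shape)
    return result


declared_shapes = {
    "scores": [],                # stale rank-0 annotation
    "boxes": ["my_batch", 4],    # differs from inference only by symbol name
    "labels": [3, 4],            # conflicting concrete dim
    "mask": [1, 3],              # inference yields nothing for this output
}
inferred_shapes = {
    "scores": [3, 5],
    "boxes": ["batch", 4],
    "labels": [3, 5],
    "mask": None,
}
reconciled = reconcile_outputs(declared_shapes, inferred_shapes)
```

The "boxes" entry corresponds to the case covered by test_clear_stale_value_info_preserves_dynamic_dim_names: the declared symbol survives even though inference derived a different name.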

…weakly-typed models

Weakly-typed ONNX models (e.g. TensorFlow exports) can carry graph.output
entries whose stored rank disagrees with the graph topology -- most commonly a
leftover rank-0 (scalar) annotation on a tensor that is really rank-2+. A stale
rank-0 passes onnx.checker (a scalar is valid) but poisons downstream shape
inference: ORT fails while augmenting the model for INT8 calibration
("axis must be in [-rank, rank-1]. Input rank was 0") and
onnx.shape_inference(strict_mode=True) raises "Inferred shape and existing shape
differ in rank" during fp16 autocast. Such models can also contain ops (e.g.
TopK) that ONNX's static shape inference cannot resolve, leaving downstream
tensors untyped and breaking AutoCast's type lookups.

Fixes:

- modelopt/onnx/utils.py: clear_stale_value_info now reconciles stale
  graph.output shapes. It snapshots and clears each output shape, re-derives
  shapes from the operator graph via ORT symbolic shape inference (falling back
  to the size-aware infer_shapes wrapper), and overwrites a declaration only when
  it is genuinely stale -- a rank mismatch or a conflicting concrete dim.
  Outputs that merely differ in symbolic dim_param names (e.g. a re-derived
  unk__0 vs a declared batch) keep their declaration, so healthy models with
  dynamic dims are untouched. A graph output is never left without a shape field
  (onnx.checker requires it).

- modelopt/onnx/utils.py: infer_types falls back to the schema-based standalone
  type inferencer when ONNX shape inference fails (e.g. raises in strict mode on
  an op it cannot resolve), so the returned model is still fully typed.

- modelopt/onnx/autocast/convert.py: convert_to_mixed_precision and
  convert_to_f16 now run the initial infer_types in strict mode, so an
  unresolvable op surfaces as an exception (triggering the fallback above)
  instead of silently leaving tensors untyped.

Healthy models are unaffected: strict shape inference succeeds for them and
output-shape reconciliation preserves their declared (incl. dynamic) shapes.

Adds unit tests for stale rank-0 output reconciliation, preservation of a valid
output shape, preservation of named dynamic dims, the infer_types fallback on a
TopK-overflow model, and an AutoCast-level convert_to_f16 fallback test.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
@ajrasane ajrasane force-pushed the ajrasane/onnx-stale-shape-typeinfer-fix branch from c7dc86b to cf2d65a Compare June 4, 2026 21:52
Contributor

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
tests/unit/onnx/autocast/test_autocast.py (1)

345-357: ⚡ Quick win

Assert the strict-inference failure precondition before conversion.

This test currently proves conversion succeeds, but it doesn’t explicitly prove strict ONNX inference fails first (the condition that should trigger fallback). Add a precondition assertion so the test remains a precise fallback regression test.

♻️ Suggested test hardening
 def test_convert_to_f16_falls_back_on_unresolvable_op(weakly_typed_topk_model):
@@
-    converted_model = convert_to_f16(weakly_typed_topk_model, keep_io_types=True)
+    with pytest.raises(Exception):
+        onnx.shape_inference.infer_shapes(weakly_typed_topk_model, strict_mode=True)
+
+    converted_model = convert_to_f16(weakly_typed_topk_model, keep_io_types=True)

As per coding guidelines, “Checked-in tests should document expected behavior [and] protect against regressions.”

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/onnx/autocast/test_autocast.py` around lines 345 - 357, Add a
precondition assertion that strict ONNX inference fails on the
weakly_typed_topk_model before calling convert_to_f16: wrap a call to ONNX
strict inference (e.g., onnx.shape_inference.infer_shapes or infer_types with
strict_mode=True) in pytest.raises(...) to assert it raises for
weakly_typed_topk_model, then proceed to call
convert_to_f16(weakly_typed_topk_model, keep_io_types=True) as before; reference
the test function name test_convert_to_f16_falls_back_on_unresolvable_op and the
fixture weakly_typed_topk_model so the failure-before-conversion condition is
explicit.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 34432f62-d82c-4657-b227-1e7dfd5a3692

📥 Commits

Reviewing files that changed from the base of the PR and between c7dc86b and cf2d65a.

📒 Files selected for processing (5)
  • CHANGELOG.rst
  • modelopt/onnx/autocast/convert.py
  • modelopt/onnx/utils.py
  • tests/unit/onnx/autocast/test_autocast.py
  • tests/unit/onnx/test_onnx_utils.py
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (3)
  • modelopt/onnx/autocast/convert.py
  • tests/unit/onnx/test_onnx_utils.py
  • modelopt/onnx/utils.py

Collaborator

@cjluo-nv cjluo-nv left a comment

Bot review — DM the bot to share feedback.

Both critical concerns from the previous review have been addressed:

  1. Dynamic-dim collateral damage (cjluo-nv): _is_stale now only flags rank mismatches or conflicting concrete dims; symbolic dim_param renames are intentionally ignored. guess_output_rank=True was also dropped. The new test_clear_stale_value_info_preserves_dynamic_dim_names test declares [my_batch, 4] while inference would derive [batch, 4] and asserts my_batch is preserved — exactly the regression case raised before.

  2. strict_mode=True AutoCast hot-path: The new test_convert_to_f16_falls_back_on_unresolvable_op exercises the full convert_to_f16 entry point (not just infer_types) on a weakly-typed TopK model, confirming the standalone fallback fires end-to-end.

The local SymbolicShapeInference import inside _reconcile_stale_output_shapes is justified (optional onnxruntime dep with a graceful fallback to infer_shapes). Fix is well-scoped, no design-review concerns (additive fallback paths only, no new subsystem).

Complex PR: spans 5 directories (≥ 5). Looping in a human for approval.
