[6058907] Fix ShapeInferenceError in ONNX int8+fp16 quantization of weakly-typed models #1627
ajrasane wants to merge 1 commit into
📝 Walkthrough
Run ONNX shape inference in strict mode with a try/except fallback to standalone type inference; add output-shape reconciliation to fix stale output declarations; integrate strict-mode inference into AutoCast conversions; add unit tests and a changelog entry.
Changes
Shape inference robustness and reconciliation
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 5 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
Codecov Report
❌ Patch coverage is
Additional details and impacted files

    @@            Coverage Diff             @@
    ##             main    #1627      +/-   ##
    ==========================================
    - Coverage   77.39%   77.21%   -0.18%
    ==========================================
      Files         482      482
      Lines       52960    53101     +141
    ==========================================
    + Hits        40986    41001      +15
    - Misses      11974    12100     +126

6d53f4d to
e6576a6
Compare
cjluo-nv
left a comment
Bot review — DM the bot to share feedback.
Targeted bug fix with three new tests. Two concerns worth a human look before merging:
- Dynamic-dim collateral damage in `_reconcile_stale_output_shapes`. The helper runs unconditionally on every model that goes through `_preprocess_onnx` → `clear_stale_value_info` (not just weakly-typed exports), and it calls ORT's `SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=True)`. With those flags, ORT routinely rewrites user-named `dim_param`s (e.g. `batch`, `seq_len`) into internal `unk__N` symbols, or commits a guessed rank when it can't fully propagate. The PR description says "re-inference reproduces their existing output shapes" and "healthy models are unaffected", but the bytes-level `_sig` comparison will treat a `batch` → `unk__0` rename as a change and overwrite the declared shape, which can break downstream consumers that look up output dims by name (export pipelines, TRT profile shapes, etc.). The added tests only cover fully static shapes (`[3,4]` / `[3,5]` / `[1,3]`), so this case is untested. Worth either (a) skipping the overwrite when only `dim_param` names differ but ranks/values agree, (b) only invoking the reconciler when the declared shape is actually inconsistent with the topology (the original motivation), or (c) at minimum adding a dynamic-batch regression test.
- `strict_mode=True` is now the default for `infer_types` callers in `autocast/convert.py`. The fallback to `_infer_types_only` on any exception is a behavior broadening: previously-tolerated weakly-typed graphs now silently route through the standalone path instead. The standalone inferencer is a reasonable fallback, but this is a change in the AutoCast hot path with no new AutoCast-level test (the new `test_infer_types_falls_back_to_standalone_when_onnx_fails` exercises `infer_types` directly, not `convert_to_mixed_precision` / `convert_to_f16`). Worth confirming the existing AutoCast test suite still passes against a representative weakly-typed model and not just the synthetic TopK fixture.
Fix itself is well-scoped and the docstrings/comments are clear; flagging mainly so an ONNX-pipeline owner can sign off on the dynamic-shape behavior change.
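Option (a) from the first concern could look roughly like the sketch below. It is not the code from `modelopt/onnx/utils.py`: shapes are modelled as plain Python tuples rather than ONNX `TensorShapeProto` messages, and the names `is_stale` and `reconcile` are hypothetical. The idea is to compare declared and re-derived shapes structurally, so that a `batch` → `unk__0` rename is not treated as a change while a rank mismatch or two conflicting concrete dims is.

```python
from typing import Optional, Sequence, Union

# Simplified stand-in for an ONNX dimension (an assumption of this sketch):
#   int  -> concrete dim_value
#   str  -> symbolic dim_param such as "batch" or "unk__0"
#   None -> dimension with neither field set
Dim = Union[int, str, None]
Shape = Sequence[Dim]


def is_stale(declared: Shape, inferred: Optional[Shape]) -> bool:
    """Return True only when the declared shape contradicts the inferred one."""
    if inferred is None:
        # Nothing could be re-derived, so keep the declaration.
        return False
    if len(declared) != len(inferred):
        # Rank mismatch, e.g. a leftover rank-0 annotation on a rank-2 tensor.
        return True
    for d, i in zip(declared, inferred):
        # Only two concrete values can conflict; symbolic names are ignored.
        if isinstance(d, int) and isinstance(i, int) and d != i:
            return True
    return False


def reconcile(declared: Shape, inferred: Optional[Shape]) -> Shape:
    """Adopt the inferred shape only when the declaration is stale."""
    return inferred if is_stale(declared, inferred) else declared
```

With this predicate a declared `("my_batch", 4)` survives a re-derived `("unk__0", 4)`, while a declared rank-0 `()` is replaced by a re-derived `(3, 4)`. A byte-level comparison of serialized shapes would flag both cases, which is the collateral damage described above.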
e6576a6 to
c7dc86b
Compare
Thanks for the careful review. Addressed both in the latest revision:
1. Dynamic-dim collateral damage. The reconciler now overwrites a declared output shape only when it's genuinely stale (a rank mismatch or a conflicting concrete dim); outputs that merely differ in symbolic …
2. …
…weakly-typed models
Weakly-typed ONNX models (e.g. TensorFlow exports) can carry graph.output
entries whose stored rank disagrees with the graph topology -- most commonly a
leftover rank-0 (scalar) annotation on a tensor that is really rank-2+. A stale
rank-0 passes onnx.checker (a scalar is valid) but poisons downstream shape
inference: ORT fails while augmenting the model for INT8 calibration
("axis must be in [-rank, rank-1]. Input rank was 0") and
onnx.shape_inference(strict_mode=True) raises "Inferred shape and existing shape
differ in rank" during fp16 autocast. Such models can also contain ops (e.g.
TopK) that ONNX's static shape inference cannot resolve, leaving downstream
tensors untyped and breaking AutoCast's type lookups.
Fixes:
- modelopt/onnx/utils.py: clear_stale_value_info now reconciles stale
graph.output shapes. It snapshots and clears each output shape, re-derives
shapes from the operator graph via ORT symbolic shape inference (falling back
to the size-aware infer_shapes wrapper), and overwrites a declaration only when
it is genuinely stale -- a rank mismatch or a conflicting concrete dim.
Outputs that merely differ in symbolic dim_param names (e.g. a re-derived
unk__0 vs a declared batch) keep their declaration, so healthy models with
dynamic dims are untouched. A graph output is never left without a shape field
(onnx.checker requires it).
- modelopt/onnx/utils.py: infer_types falls back to the schema-based standalone
type inferencer when ONNX shape inference fails (e.g. raises in strict mode on
an op it cannot resolve), so the returned model is still fully typed.
- modelopt/onnx/autocast/convert.py: convert_to_mixed_precision and
convert_to_f16 now run the initial infer_types in strict mode, so an
unresolvable op surfaces as an exception (triggering the fallback above)
instead of silently leaving tensors untyped.
Healthy models are unaffected: strict shape inference succeeds for them and
output-shape reconciliation preserves their declared (incl. dynamic) shapes.
Adds unit tests for stale rank-0 output reconciliation, preservation of a valid
output shape, preservation of named dynamic dims, the infer_types fallback on a
TopK-overflow model, and an AutoCast-level convert_to_f16 fallback test.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
c7dc86b to
cf2d65a
Compare
🧹 Nitpick comments (1)
tests/unit/onnx/autocast/test_autocast.py (1)
345-357: ⚡ Quick win
Assert the strict-inference failure precondition before conversion.
This test currently proves conversion succeeds, but it doesn’t explicitly prove strict ONNX inference fails first (the condition that should trigger fallback). Add a precondition assertion so the test remains a precise fallback regression test.
♻️ Suggested test hardening

     def test_convert_to_f16_falls_back_on_unresolvable_op(weakly_typed_topk_model):
    @@
    -    converted_model = convert_to_f16(weakly_typed_topk_model, keep_io_types=True)
    +    with pytest.raises(Exception):
    +        onnx.shape_inference.infer_shapes(weakly_typed_topk_model, strict_mode=True)
    +
    +    converted_model = convert_to_f16(weakly_typed_topk_model, keep_io_types=True)

As per coding guidelines, “Checked-in tests should document expected behavior [and] protect against regressions.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/onnx/autocast/test_autocast.py` around lines 345 - 357, Add a precondition assertion that strict ONNX inference fails on the weakly_typed_topk_model before calling convert_to_f16: wrap a call to ONNX strict inference (e.g., onnx.shape_inference.infer_shapes or infer_types with strict=True) in pytest.raises(...) to assert it raises for weakly_typed_topk_model, then proceed to call convert_to_f16(weakly_typed_topk_model, keep_io_types=True) as before; reference the test function name test_convert_to_f16_falls_back_on_unresolvable_op and the fixture weakly_typed_topk_model so the failure-before-conversion condition is explicit.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 34432f62-d82c-4657-b227-1e7dfd5a3692
📒 Files selected for processing (5)
- CHANGELOG.rst
- modelopt/onnx/autocast/convert.py
- modelopt/onnx/utils.py
- tests/unit/onnx/autocast/test_autocast.py
- tests/unit/onnx/test_onnx_utils.py
✅ Files skipped from review due to trivial changes (1)
- CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (3)
- modelopt/onnx/autocast/convert.py
- tests/unit/onnx/test_onnx_utils.py
- modelopt/onnx/utils.py
cjluo-nv
left a comment
Bot review — DM the bot to share feedback.
Both critical concerns from the previous review have been addressed:
- Dynamic-dim collateral damage (cjluo-nv): `_is_stale` now only flags rank mismatches or conflicting concrete dims; symbolic `dim_param` renames are intentionally ignored. `guess_output_rank=True` was also dropped. The new `test_clear_stale_value_info_preserves_dynamic_dim_names` test declares `[my_batch, 4]` while inference would derive `[batch, 4]` and asserts `my_batch` is preserved, which is exactly the regression case raised before.
- `strict_mode=True` AutoCast hot-path: The new `test_convert_to_f16_falls_back_on_unresolvable_op` exercises the full `convert_to_f16` entry point (not just `infer_types`) on a weakly-typed TopK model, confirming the standalone fallback fires end-to-end.
The local SymbolicShapeInference import inside _reconcile_stale_output_shapes is justified (optional onnxruntime dep with a graceful fallback to infer_shapes). Fix is well-scoped, no design-review concerns (additive fallback paths only, no new subsystem).
Complex PR: spans 5 directories (≥ 5). Looping in a human for approval.
What does this PR do?
Type of change: Bug fix
ONNX INT8 + FP16 quantization (`--quantize_mode int8 --high_precision_dtype fp16`) crashed with a `ShapeInferenceError` on weakly-typed models (e.g. TensorFlow exports). Two root causes, both fixed in `modelopt/onnx/utils.py`:
- Stale rank-0 output shapes. Such models can declare a `graph.output` rank that conflicts with the graph topology, typically a leftover rank-0 (scalar) annotation on a tensor that is really rank-2+. A stale rank-0 passes `onnx.checker` (a scalar is valid) but poisons downstream shape inference: ORT fails while augmenting the model for INT8 calibration (`axis must be in [-rank, rank-1]. Input rank was 0`), and `onnx.shape_inference(strict_mode=True)` raises `Inferred shape and existing shape differ in rank` during FP16 autocast. `clear_stale_value_info` now reconciles stale output shapes: it clears and re-derives them from the operator graph via ORT symbolic shape inference (falling back to the size-aware `infer_shapes` wrapper) and adopts the inferred shape. A graph output is never left without a shape field (which `onnx.checker` requires); an output whose shape cannot be re-derived keeps its original declaration.
- Ops ONNX static shape inference can't resolve. The same models can contain ops (e.g. `TopK`) that ONNX's static shape inference gives up on, leaving downstream tensors untyped and breaking AutoCast's type lookups. `infer_types` now falls back to the schema-based standalone type inferencer when ONNX shape inference raises or leaves tensors untyped, running the fallback on the shape-inferred model so any shapes ONNX did derive are preserved.

Healthy models are unaffected: re-inference reproduces their existing output shapes, and a fully typed graph skips the fallback.
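The fallback control flow described above can be illustrated with a small, library-free sketch. Everything in it is an assumption made for illustration: the function name `infer_types_with_fallback`, the `TypeMap` dictionary standing in for a typed ONNX graph, and the two callables standing in for strict ONNX shape inference and the schema-based standalone inferencer. It only shows the three branches: the primary pass raises, the primary pass leaves gaps, or the primary pass fully succeeds.

```python
from typing import Callable, Dict, Iterable, Optional

# Tensor name -> dtype name; None marks a tensor the inferencer left untyped.
TypeMap = Dict[str, Optional[str]]


def infer_types_with_fallback(
    tensor_names: Iterable[str],
    primary: Callable[[], TypeMap],
    fallback: Callable[[TypeMap], TypeMap],
) -> TypeMap:
    """Run the primary inferencer; use the fallback if it raises or leaves gaps."""
    names = list(tensor_names)
    try:
        partial = primary()
    except Exception:
        # Strict inference raised: start the fallback from an empty result.
        return fallback({})
    if any(partial.get(name) is None for name in names):
        # Inference ran but left gaps: let the fallback fill them in while
        # keeping whatever the primary pass did derive.
        return fallback(partial)
    # Fully typed graph: the fallback is skipped entirely.
    return partial


# Hypothetical stand-ins used only to exercise the three branches.
def strict_pass_raises() -> TypeMap:
    raise RuntimeError("cannot resolve TopK output shape")


def schema_based(known: TypeMap) -> TypeMap:
    defaults = {"x": "float32", "topk_values": "float32", "topk_indices": "int64"}
    # Entries already derived by the primary pass take precedence.
    return {**defaults, **{k: v for k, v in known.items() if v is not None}}
```

Calling `infer_types_with_fallback(["x", "topk_values", "topk_indices"], strict_pass_raises, schema_based)` returns a fully typed map even though the primary pass raised, which mirrors the behavior the new `infer_types` tests are described as checking.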
Usage
Testing
- `tests/unit/onnx/test_onnx_utils.py`: stale rank-0 output reconciliation, preservation of a valid output shape, and the type-inference fallback on a `TopK`-overflow model. Run: `CUDA_VISIBLE_DEVICES="" pytest tests/unit/onnx/test_onnx_utils.py`.
- … (`onnx.checker` passes, ORT loads it, QuantizeLinear/DequantizeLinear nodes inserted, FP16 initializers present, all graph-output ranks correct).
- `tests/unit/onnx` suite: no new failures versus the base branch.

Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- `CONTRIBUTING.md`: N/A

Summary by CodeRabbit
Bug Fixes
Tests