
[3/4] Diffusion Quantized ckpt export #834

Open
jingyu-ml wants to merge 52 commits into main from jingyux/3-4-diffusion

Conversation

jingyu-ml (Contributor) commented Jan 31, 2026

What does this PR do?

Type of change: documentation

Overview:

Updated the diffusers quantization docs to cover the new models and HF/torch checkpoint export, moved the ONNX/TensorRT workflow into a dedicated quantization/ONNX.md, and added a specific LTX-2 FP4 example command. Also refreshed the core docs (overview, deployment, and root README) to state that the unified Hugging Face export API supports diffusers models.
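For context, a minimal sketch of the end-to-end flow the updated docs describe, using the public modelopt APIs. The model id, the FP8 config choice, the prompt list, and the output directory are illustrative (the docs add an LTX-2 FP4 command, not this one), and passing the whole pipeline to export_hf_checkpoint assumes the diffusers support this PR describes:

```python
import torch
from diffusers import FluxPipeline

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

# Illustrative model id and settings, not the exact command added to the docs.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

def forward_loop(_backbone):
    # Run a few denoising calls so the quantizers attached to the backbone
    # can collect calibration statistics.
    for prompt in ["a photo of a cat", "a watercolor landscape"]:
        pipe(prompt, num_inference_steps=20)

mtq.quantize(pipe.transformer, mtq.FP8_DEFAULT_CFG, forward_loop)

# The unified export API accepts diffusers objects per this PR;
# the output directory name is arbitrary.
export_hf_checkpoint(pipe, export_dir="flux-fp8-hf")
```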

We will add more documentation on vLLM, SGLang, and TRTLLM/ComfyUI deployment once they are ready. @Edwardf0t1

Plans

  • [1/4] Add the basic functionality to support a limited set of image models with NVFP4 + FP8, with some refactoring of the previous LLM code and the diffusers example. PIC: @jingyu-ml
  • [2/4] Add support for more video-generation models. PIC: @jingyu-ml
  • [3/4] Add test cases, refactor the docs, and update all related READMEs. PIC: @jingyu-ml
  • [4/4] Add the final support for ComfyUI. PIC: @jingyu-ml

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: No
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

Summary by CodeRabbit

  • New Features

    • Unified Hugging Face export API now supports diffusers models in addition to transformers.
    • Added support for LTX-2 and expanded diffusion model variants (FLUX, SD3, LTX-Video).
    • Introduced ONNX export and TensorRT engine building workflows for quantized diffusion models.
  • Documentation

    • Updated deployment and getting started guides to document the unified export capabilities and new quantization workflows.


jingyu-ml and others added 12 commits January 26, 2026 15:04
@jingyu-ml jingyu-ml requested review from a team as code owners January 31, 2026 00:55
@jingyu-ml jingyu-ml self-assigned this Jan 31, 2026
coderabbitai bot (Contributor) commented Jan 31, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

This pull request introduces comprehensive infrastructure improvements for diffusion model quantization and export, including new pipeline management and calibration abstractions, LTX-2 quantization support, enhanced diffusers export capabilities, and reorganization of plugin modules under a diffusion subpackage. Documentation is updated to reflect unified Hugging Face export support.

Changes

  • Documentation Updates (README.md, docs/source/deployment/3_unified_hf.rst, docs/source/getting_started/1_overview.rst, examples/diffusers/README.md, examples/diffusers/quantization/ONNX.md): Updates to documentation describing unified Hugging Face export API support for transformers and diffusers models; major restructuring of the examples README to cross-reference the ONNX/TensorRT workflows and include LoRA/knowledge distillation guidance; new comprehensive ONNX.md guide for export and TensorRT engine building.
  • Configuration Extraction (examples/diffusers/quantization/quantize_config.py, examples/diffusers/quantization/quantize.py): New quantize_config.py module defines enums (DataType, QuantFormat, QuantAlgo, CollectMethod) and dataclasses (QuantizationConfig, CalibrationConfig, ModelConfig, ExportConfig) with validation logic; quantize.py now imports these classes and drops the local definitions, reducing complexity.
  • Calibration & Pipeline Infrastructure (examples/diffusers/quantization/calibration.py, examples/diffusers/quantization/pipeline_manager.py, examples/diffusers/quantization/models_utils.py): New Calibrator class manages model calibration with support for specialized handlers (LTX2, LTX_VIDEO_DEV, WAN22_T2V); new PipelineManager class orchestrates diffusion pipeline creation with device setup; models_utils.py adds LTX2 model type support and extra-parameter parsing utilities.
  • Diffusers Export Enhancement (modelopt/torch/export/diffusers_utils.py, modelopt/torch/export/unified_export_hf.py): Enhanced type handling for diffusers and LTX-2 pipelines; new is_diffusers_object() and generate_diffusion_dummy_forward_fn() functions; refactored component extraction with get_diffusion_components() supporting LTX-2 duck-typed export; the export pipeline now handles components without save_pretrained via safetensors export.
  • LTX-2 Quantization Support (modelopt/torch/quantization/plugins/diffusion/ltx2.py, examples/diffusers/quantization/utils.py): New _QuantLTX2Linear class with FP8 weight upcasting; register_ltx2_quant_linear() registration hook; _upcast_fp8_weight() utility for dtype conversion (see the sketch after this list).
  • Plugin Module Restructuring (modelopt/torch/quantization/plugins/__init__.py, modelopt/torch/quantization/plugins/diffusion/diffusers.py, modelopt/torch/quantization/plugins/diffusion/fastvideo.py): Import path reorganization: the diffusers and fastvideo plugins move under the diffusion submodule, with relative import paths in diffusers.py and fastvideo.py updated to reflect the deeper package structure.
  • Forward Patching & Dynamic Modules (modelopt/torch/opt/dynamic.py, modelopt/torch/quantization/nn/modules/quant_module.py, modelopt/torch/quantization/utils.py): Dynamic module conversion now preserves the pre-conversion forward via a _forward_pre_dm attribute to prevent leakage into exported modules; QuantInputBase.forward conditionally delegates to the pre-forward callback; weight_attr_names() relaxes the nn.Parameter requirement.
  • Test Coverage (tests/unit/torch/opt/test_chaining.py, tests/unit/torch/quantization/test_forward_patching.py): New tests for chained modes preserving forward patching during quantization; new tests verifying forward patching behavior in QuantInputBase with _forward_pre_dm attribute handling.
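A rough sketch of the FP8 upcasting idea from the LTX-2 row above; the helper below is paraphrased from the summary, not copied from _upcast_fp8_weight:

```python
import torch
from torch import nn

def upcast_fp8_weight(module: nn.Module, dtype: torch.dtype = torch.bfloat16) -> None:
    # If the module stores its weight in float8 (as the LTX-2 path apparently
    # encounters), cast it up to a higher-precision dtype so calibration and
    # checkpoint export can operate on it normally.
    weight = getattr(module, "weight", None)
    if weight is not None and weight.dtype == torch.float8_e4m3fn:
        module.weight = nn.Parameter(weight.to(dtype), requires_grad=False)
```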

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant PipelineManager
    participant Calibrator
    participant DiffusionPipeline
    participant QuantModules

    User->>PipelineManager: create_pipeline(model_type, torch_dtype)
    PipelineManager->>DiffusionPipeline: instantiate from model_type
    DiffusionPipeline-->>PipelineManager: pipeline instance
    PipelineManager->>PipelineManager: setup_device()
    PipelineManager-->>User: ready pipeline

    User->>Calibrator: load_and_batch_prompts()
    Calibrator-->>User: batched_prompts

    User->>Calibrator: run_calibration(batched_prompts)
    loop for each batch
        Calibrator->>PipelineManager: get_backbone()
        PipelineManager-->>Calibrator: backbone module
        Calibrator->>DiffusionPipeline: forward(prompt, kwargs)
        DiffusionPipeline->>QuantModules: quantize inputs/weights
        QuantModules-->>DiffusionPipeline: quantized outputs
        DiffusionPipeline-->>Calibrator: calibration data collected
    end
    Calibrator-->>User: calibration complete
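To make the prompt-loading step in the diagram concrete, here is a hypothetical helper playing the role of Calibrator.load_and_batch_prompts; the file format (one prompt per line) and the truncation to calib_size are assumptions, not the PR's implementation:

```python
from pathlib import Path

def load_and_batch_prompts(prompt_file: str, batch_size: int, calib_size: int) -> list[list[str]]:
    # Read one prompt per line, keep at most `calib_size` prompts, and group
    # them into `batch_size`-sized batches for the calibration loop.
    prompts = [line.strip() for line in Path(prompt_file).read_text().splitlines() if line.strip()]
    prompts = prompts[:calib_size]
    return [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]
```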
sequenceDiagram
    participant ExportAPI
    participant DiffusersUtils
    participant UnifiedExport
    participant ComponentExport
    participant Filesystem

    ExportAPI->>DiffusersUtils: is_diffusers_object(model)
    DiffusersUtils-->>ExportAPI: true/false

    ExportAPI->>UnifiedExport: export_hf_checkpoint(model)
    UnifiedExport->>DiffusersUtils: get_diffusion_components(model)
    DiffusersUtils-->>UnifiedExport: {component_name: module}

    loop for each component
        UnifiedExport->>ComponentExport: _fuse_qkv_linears_diffusion(component)
        ComponentExport->>DiffusersUtils: generate_diffusion_dummy_forward_fn()
        DiffusersUtils-->>ComponentExport: dummy_forward callable
        ComponentExport->>ComponentExport: run dummy forward, fuse QKV
        ComponentExport-->>UnifiedExport: fused component
    end

    loop for each component
        UnifiedExport->>ComponentExport: export component (safetensors or save_pretrained)
        ComponentExport->>Filesystem: write component weights & config
        Filesystem-->>ComponentExport: success
    end

    UnifiedExport->>Filesystem: write model_index.json
    Filesystem-->>UnifiedExport: success
    UnifiedExport-->>ExportAPI: export complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 54.55%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Title check (✅ Passed): The title '[3/4] Diffusion Quantized ckpt export' accurately describes a core change in the PR, adding diffusion quantized checkpoint export functionality. It is concise and specific enough for scanning history, though it focuses on implementation rather than the broader documentation and API updates also present.
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/diffusers/README.md (1)

15-15: ⚠️ Potential issue | 🟡 Minor

Typo: "cahce" should be "cache".

-| Support Matrix | View the support matrix to see quantization/cahce diffusion compatibility and feature availability across different models | \[[Link](`#support-matrix`)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
+| Support Matrix | View the support matrix to see quantization/cache diffusion compatibility and feature availability across different models | \[[Link](`#support-matrix`)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
examples/diffusers/quantization/quantize.py (1)

539-542: ⚠️ Potential issue | 🟠 Major

Bug: samefile compares path to itself, condition is always False.

The condition compares export_config.restore_from with itself, which always returns True for samefile(), making the entire condition always False (due to not).

This was likely intended to check whether the save path differs from the restore path:

🐛 Proposed fix
-            if export_config.quantized_torch_ckpt_path and not export_config.restore_from.samefile(
-                export_config.restore_from
+            if export_config.quantized_torch_ckpt_path and not export_config.quantized_torch_ckpt_path.samefile(
+                export_config.restore_from
             ):
                 export_manager.save_checkpoint(backbone)
🤖 Fix all issues with AI agents
In `@examples/diffusers/quantization/calibration.py`:
- Around lines 60-72: The code path that assumes a dict-based prompts_dataset may
raise a generic KeyError. In calibration.py, before calling load_calib_prompts
with self.config.prompts_dataset["name"], ["split"], and ["column"], validate that
prompts_dataset is a dict and contains the keys "name", "split", and "column"
(or provide sensible defaults); if it does not, raise a clear ValueError indicating
the missing keys. Update the branch around the return that calls load_calib_prompts
to check these keys and surface a descriptive error instead of letting a KeyError
bubble up (a sketch follows these items).
- Around line 107-118: The method _run_wan_video_calibration currently indexes
extra_args directly and can raise obscure KeyError; instead validate required
keys up front (e.g., required =
["negative_prompt","height","width","num_frames","guidance_scale"]) and compute
missing = [k for k in required if k not in extra_args]; if any are missing raise
a clear ValueError (or custom exception) that names the missing keys and that
they are required for WAN calibration; then populate kwargs from extra_args
(include guidance_scale_2 only if present) and set kwargs["num_inference_steps"]
= self.config.n_steps as before.
- Around line 122-126: The _run_ltx2_calibration function currently silently
takes only the first entry from prompt_batch; add an explicit validation at the
start of _run_ltx2_calibration to ensure a single prompt is provided (e.g.,
check len(prompt_batch) == 1) and raise a clear exception (or error log +
exception) if more than one prompt is passed, referencing prompt_batch and the
function name _run_ltx2_calibration so callers know the LTX2 calibration
requires a single prompt; keep the rest of the logic (using prompt =
prompt_batch[0]) unchanged after the validation.
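The key validation asked for in the first calibration.py item above might look like the following; this is a free-standing illustration, not the PR's code:

```python
def validate_prompts_dataset(ds_cfg) -> None:
    # Hypothetical pre-check for the dict-based dataset config described above.
    required = ("name", "split", "column")
    if not isinstance(ds_cfg, dict):
        raise ValueError(
            f"prompts_dataset must be a dict with keys {required}, got {type(ds_cfg).__name__}"
        )
    missing = [key for key in required if key not in ds_cfg]
    if missing:
        raise ValueError(f"prompts_dataset is missing required keys: {missing}")
```

Calling such a check before load_calib_prompts turns a bare KeyError into an actionable message naming the missing keys.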

In `@examples/diffusers/quantization/quantize_config.py`:
- Around line 36-40: The _dtype_map currently uses enum members as keys but the
torch_dtype property indexes it with self.value (a string), causing KeyError;
update the code so the lookup and keys match — either change _dtype_map keys to
the enum .value strings (e.g., "Half") or change the torch_dtype property to use
the enum member (use self or DataType(self.value)) when indexing; locate
DataType, _dtype_map, and the torch_dtype property and make them consistent so
the dictionary lookup succeeds.
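A self-contained illustration of the consistent variant the comment asks for; the enum member names and values here are hypothetical and may not match the real DataType in quantize_config.py:

```python
from enum import Enum

import torch

class DataType(str, Enum):
    # Hypothetical values; the real enum may differ.
    HALF = "Half"
    BFLOAT16 = "BFloat16"
    FLOAT = "Float"

    @property
    def torch_dtype(self) -> torch.dtype:
        # Key the map by enum members and index with `self`,
        # so the lookup can never miss on a string value.
        _dtype_map = {
            DataType.HALF: torch.float16,
            DataType.BFLOAT16: torch.bfloat16,
            DataType.FLOAT: torch.float32,
        }
        return _dtype_map[self]

assert DataType("Half").torch_dtype is torch.float16
```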

In `@modelopt/torch/export/diffusers_utils.py`:
- Around line 359-360: The two calls using next(model.parameters()) can raise
StopIteration for parameter-less models; change them to use a safe retrieval
like first_param = next(model.parameters(), None) and then set device =
first_param.device if first_param is not None else torch.device("cpu"), and
default_dtype = first_param.dtype if first_param is not None else
torch.get_default_dtype() (or torch.float32) — ensure torch is imported and
apply the same defensive pattern where device/default_dtype are computed (the
occurrences that produce device and default_dtype around the
next(model.parameters()) calls and the similar ones at lines ~440-441).
- Around line 52-68: The function is_diffusers_object currently bails out early
when _HAS_DIFFUSERS is False so LTX-2 pipelines (TI2VidTwoStagesPipeline) are
never detected; remove the early return and instead always build diffusers_types
by checking DiffusionPipeline, ModelMixin, and TI2VidTwoStagesPipeline
individually (i.e., keep the conditional additions for DiffusionPipeline and
ModelMixin but allow TI2VidTwoStagesPipeline to be added even if _HAS_DIFFUSERS
is False), then return isinstance(model, diffusers_types) only after verifying
diffusers_types is non-empty.
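The defensive pattern being requested, shown in isolation; falling back to CPU and the global default dtype for parameter-less models is the assumption carried over from the comment:

```python
import torch
from torch import nn

def infer_device_and_dtype(model: nn.Module) -> tuple[torch.device, torch.dtype]:
    # next(..., None) avoids StopIteration on parameter-less models.
    first_param = next(model.parameters(), None)
    if first_param is None:
        return torch.device("cpu"), torch.get_default_dtype()
    return first_param.device, first_param.dtype
```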

In `@modelopt/torch/quantization/nn/modules/quant_module.py`:
- Around line 113-132: Remove the dead identity comparison and rely solely on
the MRO-function check: drop the `pre_fwd is getattr(self, "forward")` branch
and change the conditional to call `_is_forward_in_mro(pre_fwd)` (or its
negation as appropriate) when deciding whether to call `super().forward(...)`
versus `pre_fwd(...)`; keep the helper `_is_forward_in_mro` and the `pre_fwd =
getattr(self, "_forward_pre_dm")` logic unchanged so recursion is detected by
comparing the underlying function in the MRO rather than by bound-method
identity.
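For illustration, an MRO-based check along the lines the comment describes; _is_forward_in_mro exists in the PR, but the body below is a sketch of the idea rather than the actual helper:

```python
def is_forward_in_mro(obj, fn) -> bool:
    # Unwrap a bound method to its underlying function, then test whether that
    # function is one of the `forward` definitions along obj's MRO. If it is,
    # delegating to it from QuantInputBase.forward would recurse, so the caller
    # should fall back to super().forward() instead.
    target = getattr(fn, "__func__", fn)
    return any(target is vars(cls).get("forward") for cls in type(obj).__mro__)
```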
🧹 Nitpick comments (8)
docs/source/deployment/3_unified_hf.rst (1)

35-38: Consider adding a cross-reference link to the diffusers examples.

The note mentions "diffusers quantization examples" but doesn't provide a direct reference. Adding a link (e.g., to examples/diffusers/quantization/ or the README) would help users navigate to the relevant documentation.

examples/diffusers/quantization/ONNX.md (1)

147-149: Fix table column alignment for consistent formatting.

The table pipes are not aligned with headers. This is flagged by markdownlint (MD060).

Aligned table formatting
-| SDXL FP16 | SDXL INT8 |
-|:---------:|:---------:|
-| ![FP16](./assets/xl_base-fp16.png) | ![INT8](./assets/xl_base-int8.png) |
+| SDXL FP16                          | SDXL INT8                          |
+| :--------------------------------: | :--------------------------------: |
+| ![FP16](./assets/xl_base-fp16.png) | ![INT8](./assets/xl_base-int8.png) |
examples/diffusers/quantization/quantize_config.py (2)

110-113: Consider documenting or warning about truncation behavior.

The num_batches property uses integer division, meaning if calib_size is not evenly divisible by batch_size, some samples will be silently skipped. For example, with calib_size=100 and batch_size=32, only 96 samples (3 batches) will be used.

Consider either:

  1. Adding a warning in validate() when truncation occurs
  2. Documenting this behavior in the docstring
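For example, the property could flag the truncation explicitly; the class and field names follow the walkthrough, while the defaults and the warning are illustrative additions:

```python
import warnings
from dataclasses import dataclass

@dataclass
class CalibrationConfig:
    calib_size: int = 128
    batch_size: int = 2

    @property
    def num_batches(self) -> int:
        # Integer division silently drops the remainder, so warn when it happens.
        if self.calib_size % self.batch_size:
            warnings.warn(
                f"calib_size={self.calib_size} is not divisible by batch_size={self.batch_size}; "
                f"{self.calib_size % self.batch_size} calibration sample(s) will be skipped."
            )
        return self.calib_size // self.batch_size
```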

146-160: Validation method has side effects (directory creation).

The validate() method creates directories (lines 152-160), which is unusual for validation logic. This could cause unexpected behavior if called multiple times or in read-only contexts.

Consider separating validation from directory creation:

♻️ Suggested refactor
     def validate(self) -> None:
         """Validate export configuration."""
         if self.restore_from and not self.restore_from.exists():
             raise FileNotFoundError(f"Restore checkpoint not found: {self.restore_from}")
 
+    def ensure_directories(self) -> None:
+        """Create output directories if they don't exist."""
         if self.quantized_torch_ckpt_path:
             parent_dir = self.quantized_torch_ckpt_path.parent
             if not parent_dir.exists():
                 parent_dir.mkdir(parents=True, exist_ok=True)
 
         if self.onnx_dir and not self.onnx_dir.exists():
             self.onnx_dir.mkdir(parents=True, exist_ok=True)
 
         if self.hf_ckpt_dir and not self.hf_ckpt_dir.exists():
             self.hf_ckpt_dir.mkdir(parents=True, exist_ok=True)
examples/diffusers/quantization/quantize.py (1)

19-19: Remove redundant import alias.

import time as time is redundant—just use import time.

♻️ Proposed fix
-import time as time
+import time
examples/diffusers/quantization/pipeline_manager.py (2)

72-73: Remove unnecessary try/except that only re-raises.

The try/except block doesn't add error handling—it just re-raises the exception. Remove it for cleaner code.

♻️ Proposed fix
-        try:
-            pipeline_cls = MODEL_PIPELINE[model_type]
-            if pipeline_cls is None:
-                raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.")
-            model_id = (
-                MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
-            )
-            pipe = pipeline_cls.from_pretrained(
-                model_id,
-                torch_dtype=torch_dtype,
-                use_safetensors=True,
-                **MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}),
-            )
-            pipe.set_progress_bar_config(disable=True)
-            return pipe
-        except Exception as e:
-            raise e
+        pipeline_cls = MODEL_PIPELINE[model_type]
+        if pipeline_cls is None:
+            raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.")
+        model_id = (
+            MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
+        )
+        pipe = pipeline_cls.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            use_safetensors=True,
+            **MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}),
+        )
+        pipe.set_progress_bar_config(disable=True)
+        return pipe

156-169: Type hint mismatch: get_backbone can return None for LTX2.

The return type is declared as torch.nn.Module, but for LTX2, self._transformer could theoretically be None if _ensure_ltx2_transformer_cached fails silently. While the current implementation sets self._transformer correctly, the type annotation on _transformer (line 40) is torch.nn.Module | None.

Consider adding an assertion or raising an explicit error if the transformer isn't available:

♻️ Proposed fix
     def get_backbone(self) -> torch.nn.Module:
         ...
         if self.config.model_type == ModelType.LTX2:
             self._ensure_ltx2_transformer_cached()
+            if self._transformer is None:
+                raise RuntimeError("Failed to retrieve LTX2 transformer")
             return self._transformer
         return getattr(self.pipe, self.config.backbone)
examples/diffusers/quantization/calibration.py (1)

190-195: Comment is inconsistent with behavior.

The comment says the upscale step is omitted, but the upsampler is still executed. Either remove the call or update the comment to match intent.

codecov bot commented Jan 31, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.72%. Comparing base (452c5a0) to head (920fd4f).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #834   +/-   ##
=======================================
  Coverage   73.72%   73.72%           
=======================================
  Files         196      196           
  Lines       20457    20457           
=======================================
  Hits        15082    15082           
  Misses       5375     5375           

☔ View full report in Codecov by Sentry.
