Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
7fd45d4
Support WAN2.2 models W4A16 quantization
lvliang-intel Apr 13, 2026
dc44b4d
Merge branch 'main' of https://github.com/intel/auto-round into lvl/s…
lvliang-intel Apr 13, 2026
b57fde7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2026
4853279
fix ci issue
lvliang-intel Apr 14, 2026
ab388cb
Merge branch 'main' of https://github.com/intel/auto-round into lvl/s…
lvliang-intel Apr 14, 2026
aa8e45c
Merge branch 'lvl/support_wan2.2' of https://github.com/intel/auto-ro…
lvliang-intel Apr 14, 2026
257690b
Potential fix for pull request finding
lvliang-intel Apr 14, 2026
8195bc1
Potential fix for pull request finding
lvliang-intel Apr 14, 2026
be09c1d
Potential fix for pull request finding
lvliang-intel Apr 14, 2026
3874e89
Potential fix for pull request finding
lvliang-intel Apr 14, 2026
05019a8
avoid lint issue
lvliang-intel Apr 14, 2026
92436a8
Merge branch 'lvl/support_wan2.2' of https://github.com/intel/auto-ro…
lvliang-intel Apr 14, 2026
34802fb
Merge origin/main into lvl/support_wan2.2 to resolve conflicts
Copilot Apr 14, 2026
72125d2
Merge origin/main into branch to resolve conflicts
Copilot Apr 14, 2026
3f0b0d0
support quantize transformers_2 in WAN2.2
lvliang-intel Apr 14, 2026
b5a9c71
Merge branch 'lvl/support_wan2.2' of https://github.com/intel/auto-ro…
lvliang-intel Apr 14, 2026
e7e45d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2026
f0ee948
Merge branch 'main' into lvl/support_wan2.2
lvliang-intel Apr 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,9 +1766,10 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
formats = self.formats if hasattr(self, "formats") else None
# It is best to modify the model structure in the quantize function and check the format,
# because it may cause the gguf format to not be exported normally.
self.model = update_module(
self.model, formats=formats, trust_remote_code=self.trust_remote_code, cleanup_original=False
)
if not self.diffusion:
self.model = update_module(
self.model, formats=formats, trust_remote_code=self.trust_remote_code, cleanup_original=False
)

# Temporary names must be assigned after handle_moe_model;
# placing them earlier would cause them to be removed when the module is replaced.
Expand Down Expand Up @@ -3025,7 +3026,7 @@ def _quantize_block(
else:
card_0_in_high_risk, loss_device = False, device

if len(self.device_list) > 1 and auto_offload:
if len(self.device_list) > 1 and auto_offload and not self.diffusion:
for n, m in block.named_modules():
if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"):
continue
Expand Down
6 changes: 4 additions & 2 deletions auto_round/compressors/diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,17 @@ auto-round \

### Diffusion Support Matrix

For diffusion models, currently we only validate quantizaion on the FLUX.1-dev, which involves quantizing the transformer component of the pipeline.
For diffusion models, currently we validate quantization on a few models, which involves quantizing the transformer component of the pipeline.

| Model | calibration dataset | Model Link |
|---------------|---------------------|--------------|
| black-forest-labs/FLUX.1-dev | COCO2014 | - |
| Tongyi-MAI/Z-Image | COCO2014 | - |
| Tongyi-MAI/Z-Image-Turbo | COCO2014 | - |
| stepfun-ai/NextStep-1.1 | COCO2014 | - |
| Wan-AI/Wan2.2-I2V-A14B-Diffusers | COCO2014 | - |
| Wan-AI/Wan2.2-TI2V-5B-Diffusers | COCO2014 | - |
| Wan-AI/Wan2.2-T2V-A14B-Diffusers | COCO2014 | - |


<details>
Expand Down
221 changes: 210 additions & 11 deletions auto_round/compressors/diffusion/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from collections import defaultdict
from copy import deepcopy
Expand All @@ -33,10 +34,12 @@
config_save_pretrained,
copy_python_files_from_model_cache,
diffusion_load_model,
dispatch_model_block_wise,
dispatch_model_by_all_available_devices,
extract_block_names_to_str,
find_matching_blocks,
get_block_names,
is_auto_device_mapping,
merge_block_output_keys,
rename_weights_files,
wrap_block_forward_positional_to_kwargs,
Expand All @@ -47,6 +50,7 @@
output_configs = {
"FluxTransformerBlock": ["encoder_hidden_states", "hidden_states"],
"FluxSingleTransformerBlock": ["encoder_hidden_states", "hidden_states"],
"WanTransformerBlock": ["hidden_states"],
}


Expand Down Expand Up @@ -116,7 +120,7 @@ def __init__(
pipeline_fn: callable = None,
**kwargs,
):
logger.warning("Diffusion model quantization is experimental and is only validated on Flux models.")
logger.warning("Diffusion model quantization is experimental and is only validated on a few models.")
if dataset == "NeelNanda/pile-10k":
dataset = "coco2014"
logger.warning(
Expand All @@ -137,6 +141,7 @@ def __init__(
pipe, model = diffusion_load_model(model, platform=platform, device=self.device, model_dtype=model_dtype)

self.model = model
self._current_transformer_name = "transformer"
self.pipe = pipe
# Use explicit pipeline_fn; fall back to whatever diffusion_load_model attached to the pipe
self.pipeline_fn = pipeline_fn or getattr(pipe, "_autoround_pipeline_fn", None)
Expand Down Expand Up @@ -185,6 +190,11 @@ def __init__(
self._align_device_and_dtype()

def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, dict]:
if self._uses_single_hidden_state_input():
if q_inputs is not None:
q_inputs = q_inputs.pop("hidden_states", None)
return inputs, q_inputs

# flux transformer model's blocks will update hidden_states and encoder_hidden_states
input_id_str = [key for key in inputs.keys() if "hidden_state" in key]
if q_inputs is not None:
Expand All @@ -194,7 +204,44 @@ def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, dict]:
def _get_block_forward_func(self, name):
return wrap_block_forward_positional_to_kwargs(super()._get_block_forward_func(name))

def _uses_single_hidden_state_input(self) -> bool:
    """Return True when the first quantized block emits only ``hidden_states``.

    Blocks such as ``WanTransformerBlock`` are registered in ``output_configs``
    with a single ``hidden_states`` output, unlike Flux-style blocks which also
    carry ``encoder_hidden_states``. An empty ``quant_block_list`` means there
    is nothing to inspect, so the answer defaults to False.
    """
    if not self.quant_block_list:
        return False
    leading_block_name = self.quant_block_list[0][0]
    leading_block = self.model.get_submodule(leading_block_name)
    registered_outputs = output_configs.get(type(leading_block).__name__, [])
    return registered_outputs == ["hidden_states"]

Comment thread
lvliang-intel marked this conversation as resolved.
def _requires_calibration_image(self) -> bool:
    """Return True if the pipeline's ``__call__`` declares a mandatory ``image`` argument."""
    call_params = inspect.signature(self.pipe.__call__).parameters
    if "image" not in call_params:
        return False
    return call_params["image"].default is inspect.Parameter.empty

def _get_calibration_image(self, batch_size: int):
    """Build placeholder RGB image(s) for pipelines whose ``__call__`` takes an ``image``.

    The image size comes from the pipeline ``__call__`` defaults for ``width`` and
    ``height``; when a dimension has no usable default (missing, empty, or None)
    it falls back to 832x480. A flat mid-gray fill (127, 127, 127) is used.

    Args:
        batch_size: Number of images needed; 1 returns a single image, larger
            values return a list of independent copies.
    """
    from PIL import Image  # pylint: disable=E0401

    call_params = inspect.signature(self.pipe.__call__).parameters

    def _dimension(name: str, fallback: int):
        # A missing parameter or a default of empty/None means "no usable default".
        param = call_params.get(name)
        if param is None or param.default in (inspect.Parameter.empty, None):
            return fallback
        return param.default

    size = (int(_dimension("width", 832)), int(_dimension("height", 480)))
    placeholder = Image.new("RGB", size, color=(127, 127, 127))
    if batch_size == 1:
        return placeholder
    return [placeholder.copy() for _ in range(batch_size)]

def _split_inputs(self, inputs: dict, first_input_name: str) -> tuple[dict, dict]:
if self._uses_single_hidden_state_input():
input_ids = inputs.pop("hidden_states", None)
input_others = inputs
return input_ids, input_others

input_id_str = [key for key in inputs.keys() if "hidden_state" in key]
input_ids = {k: inputs.pop(k, None) for k in input_id_str}
input_others = inputs
Expand Down Expand Up @@ -295,11 +342,52 @@ def _get_block_outputs(

def _get_current_num_elm(
    self,
    input_ids: Union[list[torch.Tensor], dict],
    indices: list[int],
) -> int:
    """Count the total tensor elements selected by ``indices``.

    ``input_ids`` is either a plain list of tensors (single hidden-state
    models such as WAN) or a dict keyed by input name (Flux-style models),
    in which case only the ``hidden_states`` entry is counted.
    """
    tensors = input_ids["hidden_states"] if isinstance(input_ids, dict) else input_ids
    return sum(tensors[i].numel() for i in indices)

def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_name=None):
    """Dispatch multi-device before caching so accelerate hooks are added before _replace_forward.

    When the device map is "auto" and more than one device is available, this
    first moves the non-target pipeline components (text encoder, VAE, other
    transformers still on CPU, ...) to the last device, then block-wise
    dispatches the current transformer, and finally re-attaches it to the
    pipeline under ``self._current_transformer_name`` before delegating to the
    base implementation.
    """
    multi_device_diffusion = is_auto_device_mapping(self.device_map) and len(self.device_list) > 1
    if multi_device_diffusion:
        # Skip the component shuffle if the model is already sharded across devices.
        if not (hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1):
            # Place other pipeline components on GPU/XPU *before* dispatching the
            # current transformer. dispatch_model_block_wise queries free memory
            # so it will leave room for these components automatically.
            comp_device = self.device_list[-1]
            for comp_name in self.pipe.components:
                comp = getattr(self.pipe, comp_name, None)
                if comp is None or comp is self.model:
                    continue
                # Non-module components (schedulers, processors) without .to() are skipped.
                if not hasattr(comp, "to"):
                    continue
                # A sibling transformer still parked on CPU (e.g. WAN's transformer_2).
                # NOTE(review): next(comp.parameters()) raises StopIteration for a
                # parameter-less module — confirm such components cannot occur here.
                is_other_transformer = (
                    comp_name.startswith("transformer")
                    and isinstance(comp, torch.nn.Module)
                    and next(comp.parameters()).device.type == "cpu"
                )
                is_other_component = not comp_name.startswith("transformer")
                if is_other_transformer or is_other_component:
                    # Convert dtype on CPU first to reduce GPU/XPU memory for large components
                    if (
                        isinstance(comp, torch.nn.Module)
                        and hasattr(comp, "dtype")
                        and comp.dtype != self.model.dtype
                    ):
                        comp.to(dtype=self.model.dtype)
                    try:
                        comp.to(comp_device)
                    except (NotImplementedError, RuntimeError):
                        # Component may have meta/unmaterialized tensors (already quantized & saved)
                        continue
        self.model = dispatch_model_block_wise(self.model, self.device_map)
        # Re-attach the dispatched model so the pipeline calls the hooked instance.
        setattr(self.pipe, self._current_transformer_name, self.model)
    return super().cache_inter_data(block_names, nsamples, layer_names, last_cache_name)

def _run_pipeline(self, prompts: list) -> None:
"""Execute one full diffusion pipeline forward pass for calibration input capture.
Expand Down Expand Up @@ -343,20 +431,25 @@ def nextstep_fn(pipe, prompts, guidance_scale=7.5,
if self.generator_seed is None
else torch.Generator(device=self.pipe.device).manual_seed(self.generator_seed)
)
extra_kwargs = {}
if self._requires_calibration_image():
extra_kwargs["image"] = self._get_calibration_image(len(prompts) if isinstance(prompts, list) else 1)
if self.pipeline_fn is not None:
self.pipeline_fn(
self.pipe,
prompts,
guidance_scale=self.guidance_scale,
num_inference_steps=self.num_inference_steps,
generator=generator,
**extra_kwargs,
)
else:
self.pipe(
prompts,
guidance_scale=self.guidance_scale,
num_inference_steps=self.num_inference_steps,
generator=generator,
**extra_kwargs,
)

def calib(self, nsamples, bs):
Expand Down Expand Up @@ -398,8 +491,8 @@ def calib(self, nsamples, bs):
self._run_pipeline(prompts)
except NotImplementedError:
pass
except Exception as error:
raise error
except Exception:
raise
step = len(prompts)
total_cnt += step
pbar.update(step)
Expand Down Expand Up @@ -435,6 +528,105 @@ def _should_stop_cache_forward(self, name: str) -> bool:
# diffusion model needs to run all steps to collect input
return False

def _find_additional_transformers(self):
    """Find transformer components beyond the primary one (e.g. transformer_2 in WAN).

    Returns:
        list[tuple[str, torch.nn.Module]]: ``(component_name, module)`` pairs for
        every pipeline component named ``transformer*`` other than ``transformer``
        itself, skipping missing or non-module entries.
    """
    extras = []
    for comp_name in self.pipe.components:
        if comp_name == "transformer" or not comp_name.startswith("transformer"):
            continue
        component = getattr(self.pipe, comp_name, None)
        # isinstance also rejects None, so no separate null check is needed.
        if isinstance(component, torch.nn.Module):
            extras.append((comp_name, component))
    return extras

def quantize_and_save(
    self, output_dir: str = "tmp_autoround", format: str = "auto_round", inplace: bool = True, **kwargs
):
    """Quantize all transformer components and save the pipeline.

    Single-transformer pipelines defer entirely to the base implementation.
    Multi-transformer pipelines (e.g. WAN2.2) quantize the primary
    ``transformer`` first, then each additional ``transformer_*`` component
    (re-running calibration through the full pipeline), and finally save
    everything via ``save_quantized``.

    Args:
        output_dir: Root directory the quantized pipeline is written to.
        format: Output format string resolved through ``get_formats``.
        inplace: Whether quantization may modify the model in place; forced
            to False when multiple output formats are requested.
        **kwargs: Forwarded to ``save_quantized``.

    NOTE(review): unlike the single-transformer branch, the multi-transformer
    path returns None instead of propagating a result from ``save_quantized``
    — confirm callers do not rely on the return value.
    """
    additional = self._find_additional_transformers()
    if not additional:
        return super().quantize_and_save(output_dir, format=format, inplace=inplace, **kwargs)

    # Setup (mirrors BaseCompressor.quantize_and_save)
    from auto_round.formats import get_formats

    self.orig_output_dir = output_dir
    format_list = get_formats(format, self)
    self.formats = format_list
    if len(format_list) > 1:
        inplace = False
    self.inplace = kwargs.get("inplace", inplace)
    kwargs.pop("inplace", None)

    # For multi-transformer quantization, disable low_cpu_mem_usage so that
    # is_immediate_saving stays False. With immediate saving, each quantized
    # block is written to disk and then moved to meta device — making the
    # model a hollow shell that cannot run inference. We need the quantized
    # primary transformer to remain functional in CPU memory so the pipeline
    # can use it during calibration of subsequent transformers.
    orig_low_cpu = self.low_cpu_mem_usage
    self.low_cpu_mem_usage = False

    # Quantize primary transformer
    logger.info("Quantizing transformer")
    self.quantize()

    # Remove the stale device map so cache_inter_data can re-dispatch freely.
    if hasattr(self.model, "hf_device_map"):
        del self.model.hf_device_map
    clear_memory(device_list=self.device_list)

    # Store results and save state so they can be restored after the loop
    # below repoints self.model / self.layer_config at each extra transformer.
    primary_model = self.model
    primary_layer_config = self.layer_config
    primary_quant_block_list = self.quant_block_list
    quantized_extras = {}

    # Dual-transformer pipelines (e.g. WAN) switch between transformers based on
    # a boundary timestep. With num_inference_steps=1 only the primary (high-noise)
    # transformer is called, so the secondary never receives calibration data.
    # Ensure at least 2 steps so both transformers are exercised.
    orig_steps = self.num_inference_steps
    if self.num_inference_steps < 2:
        logger.warning(
            f"num_inference_steps={self.num_inference_steps} is too low for dual-transformer "
            f"quantization — increasing to 2 so all transformers receive calibration data."
        )
        self.num_inference_steps = 2

    for comp_name, transformer in additional:
        logger.info(f"Quantizing {comp_name}")
        # Retarget the compressor state at this component and reset per-model
        # bookkeeping before quantizing it.
        self._current_transformer_name = comp_name
        self.model = transformer
        self.quantized = False
        all_blocks = get_block_names(self.model)
        self.quant_block_list = find_matching_blocks(self.model, all_blocks, None)
        self.layer_config = {}

        self.quantize()

        # Keep a copy of the layer config; save_quantized reads it per component.
        quantized_extras[comp_name] = (self.model, dict(self.layer_config))
        setattr(self.pipe, comp_name, self.model)

    # Restore primary state for save
    self._current_transformer_name = "transformer"
    self.model = primary_model
    self.layer_config = primary_layer_config
    self.quant_block_list = primary_quant_block_list
    self._quantized_transformers = quantized_extras
    self.low_cpu_mem_usage = orig_low_cpu
    self.num_inference_steps = orig_steps

    # Save everything
    self.save_quantized(output_dir, format=self.formats, inplace=inplace, return_folders=True, **kwargs)

    from auto_round.utils import memory_monitor

    memory_monitor.log_summary()

def _get_save_folder_name(self, format: OutputFormat) -> str:
"""Generates the save folder name based on the provided format string.

Expand Down Expand Up @@ -483,6 +675,7 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
if output_dir is None:
return super().save_quantized(output_dir, format=format, inplace=inplace, **kwargs)

quantized_transformers = getattr(self, "_quantized_transformers", {})
compressed_model = None
if hasattr(self.model, "config") and getattr(self.model.config, "model_type", None) == "nextstep":
compressed_model = super().save_quantized(
Expand All @@ -500,11 +693,17 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
sub_module_path = (
os.path.join(output_dir, name) if os.path.basename(os.path.normpath(output_dir)) != name else output_dir
)
if (
hasattr(val, "config")
and hasattr(val.config, "_name_or_path")
and val.config._name_or_path == self.model.config._name_or_path
):
if name in quantized_transformers:
saved_model, saved_lc = self.model, self.layer_config
self.model, self.layer_config = quantized_transformers[name]
compressed_model = super().save_quantized(
output_dir=sub_module_path if not self.is_immediate_saving else output_dir,
format=format,
inplace=inplace,
**kwargs,
)
self.model, self.layer_config = saved_model, saved_lc
elif val is self.model:
compressed_model = super().save_quantized(
output_dir=sub_module_path if not self.is_immediate_saving else output_dir,
format=format,
Expand Down
3 changes: 2 additions & 1 deletion auto_round/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,7 +1355,8 @@ def dispatch_model_block_wise(model: torch.nn.Module, device_map: str, max_mem_r
max_memory=new_max_memory,
no_split_module_classes=no_split_modules,
)
model.tie_weights()
if hasattr(model, "tie_weights"):
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=new_max_memory, no_split_module_classes=no_split_modules)
if len(devices) > 1 and "cpu" in device_map.values():
logger.warning(
Expand Down
Loading
Loading