From 8ebd9245d1d13e013554b549247f5f324d556b21 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:20:37 +0000 Subject: [PATCH 1/4] tested perplexity Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 3 + modelopt/torch/quantization/config.py | 59 ++++++++++ modelopt/torch/quantization/mode.py | 39 ++++++- modelopt/torch/quantization/model_calib.py | 123 ++++++++++++++++++++- modelopt/torch/quantization/utils.py | 82 ++++++++++++++ modelopt/torch/utils/network.py | 33 ++++++ 6 files changed, 335 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index e32d0dae8..0d2a5aae9 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -86,6 +86,7 @@ "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, + "int4_gptq": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG, } KV_QUANT_CFG_CHOICES = { @@ -250,6 +251,7 @@ def auto_quantize( "w4a8_mxfp4_fp8", "nvfp4_mlp_only", "mxfp8", + "int4_gptq", ] for args.qformat in qformat_list ), "One or more quantization formats provided are not supported for unified checkpoint export" @@ -865,6 +867,7 @@ def quantize_main( "w4a8_mxfp4_fp8", "nvfp4_mlp_only", "mxfp8", + "int4_gptq", ] or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES ), f"Plain quantization format {args.qformat} not supported for HF export path" diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index e1b48ee60..73659a465 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -234,6 +234,18 @@ "algorithm": "max", } +INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, + "*input_quantizer": {"enable": False}, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + INT4_AWQ_CFG = { "quant_cfg": { @@ -992,6 +1004,15 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): title="This field specifies the name of the calibration algorithm. If None, no calibration is performed.", ) + use_sequential: bool = ModeloptField( + default=False, + title="Enable sequential layer-by-layer calibration.", + description=( + "If True, the calibration algorithm is applied sequentially to each decoder block. " + "Outputs from one layer become inputs to the next, reducing memory usage for large models." + ), + ) + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. @@ -1228,6 +1249,44 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig): ) +class GPTQConfig(QuantizeAlgorithmConfig): + """The config for GPTQ lite. + + GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. + + GPTQ lite does not perform sequential quantization of layers. This means that the updated + activations are not used to process the next layer. + + The default values are taken from the official GPTQ implementation: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 + + Note: This feature is currently experimental and may not translate to improved accuracy as expected. 
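+
+    At a high level, the calibration quantizes weight columns in blocks: each column is
+    quantized, its quantization error is scaled by the corresponding diagonal entry of the
+    inverse-Hessian factor, and the scaled error is propagated to the columns that have not
+    yet been quantized (see ``blockwise_weight_update`` in ``model_calib.py``).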
+ + + """ + + method: Literal["gptq"] = ModeloptField("gptq") + percdamp: float | None = ModeloptField( + default=0.01, + gt=0.0, + le=1.0, + title="Percentage damping factor.", + description="The percentage of average Hessian diagonal used for damping.", + ) + block_size: int | None = ModeloptField( + default=128, + title="Block size for GPTQ weight update.", + description="""The block size for GPTQ weight update, which must be a multiple of the + group_size used in the quantization.""", + ) + hessian_state_path: str | None = ModeloptField( + default=None, + title="Path to the Hessian state file.", + description="""The path to the Hessian state file. If hessian path exists, we load from + hessian file instead of recomputing them.""", + ) + + QuantizeQuantCfgType = dict[ str | Callable, QuantizerAttributeConfig diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index bfcdb64da..b3df12785 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -37,6 +37,7 @@ AWQFullCalibConfig, AWQLiteCalibConfig, CompressConfig, + GPTQConfig, GPTQLiteConfig, MaxCalibConfig, MseCalibConfig, @@ -56,7 +57,16 @@ restore_svdquant_model, update_quantize_metadata, ) -from .model_calib import awq, gptq_lite, max_calibrate, mse_calibrate, smoothquant, svdquant +from .model_calib import ( + awq, + gptq, + gptq_lite, + max_calibrate, + mse_calibrate, + sequential_calibrate, + smoothquant, + svdquant, +) __all__ = ["BaseCalibrateModeDescriptor"] @@ -212,13 +222,24 @@ def wrapped_calib_func( """ kwargs = config.model_dump() method = kwargs.pop("method") + sequential = kwargs.pop("use_sequential", False) + if method is not None and "awq" in method: # For backward compatibility kwargs["algorithm"] = method if func is not None: - # Call the function with forward_loop as a separate argument - func(model, forward_loop=forward_loop, **kwargs) + if sequential: + # Wrap with sequential processing - just pass func as calib_func! + sequential_calibrate( + model, + forward_loop=forward_loop, + calib_func=func, # <-- Pass func directly! 
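+                    # calib_func is applied once per decoder block, receiving that block's
+                    # captured (args, kwargs) activations from sequential_calibrate.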
+ **kwargs, + ) + else: + # Direct calibration (existing behavior) + func(model, forward_loop=forward_loop, **kwargs) # Lets get the latest metadata for the quantizer states metadata = {} @@ -452,3 +473,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]: return GPTQLiteConfig _calib_func = gptq_lite + + +@CalibrateModeRegistry.register_mode +class GPTQModeDescriptor(BaseCalibrateModeDescriptor): + """Mode for GPTQ calibration algorithm.""" + + @property + def config_class(self) -> type[QuantizeAlgorithmConfig]: + """Specifies the config class for the mode.""" + return GPTQConfig + + _calib_func = gptq diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 591de3240..a425128b2 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -18,6 +18,7 @@ import math import os import warnings +from collections.abc import Callable from functools import partial import torch @@ -27,9 +28,14 @@ from tqdm import tqdm from modelopt.torch.opt.searcher import ForwardLoop +from modelopt.torch.quantization.utils import LayerActivationGettr from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState -from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method +from modelopt.torch.utils.network import ( + bind_forward_method, + get_decoder_layers, + unpatch_forward_method, +) from modelopt.torch.utils.perf import get_used_gpu_mem_fraction from .calib import MseCalibrator @@ -1478,3 +1484,118 @@ def hessian_hook(module, input, output): torch.cuda.empty_cache() print_rank_0("GPTQ-lite quantization completed successfully") + + +@torch.no_grad() +def sequential_calibrate( + model: nn.Module, + forward_loop: ForwardLoop, + calib_func: Callable, + **calib_kwargs, +): + """Sequential calibration - a sequential layer-by-layer calibration algorithm.""" + transformer_layers = get_decoder_layers(model) + if transformer_layers is None: + raise ValueError( + "Could not find transformer layers in model'. " + "Sequential calibration requires a model with identifiable transformer layers." 
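+            # get_decoder_layers() recognizes model.model.layers, model.decoder.layers,
+            # model.layers, and model.transformer.h (see modelopt/torch/utils/network.py).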
+ ) + + print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") + gettr = LayerActivationGettr(model) + inputs = gettr.get_input_activations(transformer_layers[0], forward_loop) + + for layer in transformer_layers: + # Call GPTQ + calib_func(layer, inputs, **calib_kwargs) + # Get outputs + outputs = gettr.get_output_activations(layer, inputs) + # Update inputs + inputs = [(out, *inp[0][1:]) for inp, out in zip(inputs, outputs)] + + print_rank_0("Sequential calibration completed successfully") + + +@torch.no_grad() +def gptq( + layer: nn.Module, + inputs: list[tuple[tuple, dict]], + percdamp: float = 0.01, + block_size: int = 128, + **kwargs, +): + """GPTQ quantization - a GPTQ variant.""" + import time + + total_start = time.time() + + # Dictionary to store hessian matrices for all linear layers in this decoder + hessian_state = {} + + # Phase 1: Build tensor mapping for all quantized linear layers in this decoder layer + tensor_mapping = {} + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + in_features = module.weight.shape[-1] + tensor_mapping[name] = ((in_features, in_features), module.weight.device) + module.name = name # Attach name for easy access in hooks + + if not tensor_mapping: + print_rank_0("No quantized linear layers found in decoder layer, skipping GPTQ") + return + + # Initialize hessian state with zeros + for name, (shape, device) in tensor_mapping.items(): + hessian_state[name] = { + "hessian": torch.zeros(shape, dtype=torch.float32, device=device), + "n_samples": 0, + } + + # Phase 2: Register hooks to collect Hessians during forward passes + def hessian_hook(module, input, output): + """Hook to intercept activations and update hessian matrix.""" + state = hessian_state[module.name] + hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) + hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} + + handles = [] + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + handles.append(module.register_forward_hook(hessian_hook)) + + # Run forward passes with the provided inputs to collect Hessians + hessian_start = time.time() + print_rank_0( + f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..." 
+ ) + for args, kwargs_input in inputs: + layer(*args, **kwargs_input) + + # Remove hooks after collecting Hessians + for handle in handles: + handle.remove() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + hessian_time = time.time() - hessian_start + + # Phase 3: Update weights using computed Hessians (same as gptq_lite) + weight_update_start = time.time() + print_rank_0("Updating weights using GPTQ algorithm...") + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + state = hessian_state[module.name] + hessian = state["hessian"].to(module.weight.device) + blockwise_weight_update(module, hessian, block_size, percdamp) + # Free memory + del hessian_state[module.name] + torch.cuda.empty_cache() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + weight_update_time = time.time() - weight_update_start + + total_time = time.time() - total_start + print_rank_0( + f"GPTQ timing - Hessian: {hessian_time:.2f}s, " + f"Weight update: {weight_update_time:.2f}s, " + f"Total: {total_time:.2f}s" + ) diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index b663ef5f2..b794251d2 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -29,10 +29,13 @@ from torch.distributed.tensor import Replicate from modelopt.torch.utils import get_unwrapped_name, print_rank_0 +from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method if TYPE_CHECKING: from collections.abc import Generator + from modelopt.torch.opt.searcher import ForwardLoop + __all__ = [ "EXPORT_MODE", "convert_quantization_axis_to_reduce_axis", @@ -810,3 +813,82 @@ def update_quant_cfg_with_kv_cache_quant( quant_cfg["algorithm"] = "max" print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}") return quant_cfg + + +class LayerActivationGettr: + """Helper class for collecting layer activations during forward passes. + + This class allows for sequential layer calibration by + patching layers to capture inputs/outputs during forward passes + """ + + def __init__(self, model: nn.Module): + self.model = model + + @staticmethod + def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False): + """Patch a layer to collect inputs and outputs during forward passes.""" + + def _forward_w_data_collection(self, *args, **kwargs): + """Custom forward that collects inputs and outputs. + + Note: 'self' refers to the patched layer. + """ + assert len(args) >= 1 + self.inputs.append((args, kwargs)) + output = self._original_forward(*args, **kwargs) + self.outputs.append(output) + if getattr(self, "_stop_after_collection", False): + raise StopIteration() + return output + + bind_forward_method(layer, _forward_w_data_collection, "_original_forward") + layer.inputs = [] + layer.outputs = [] + layer._stop_after_collection = stop_after_collection + + @staticmethod + def _unpatch_and_cleanup_layer(layer: torch.nn.Module): + """Restore a layer's original forward method and clean up.""" + unpatch_forward_method(layer, "_original_forward") + del layer.inputs + del layer.outputs + if hasattr(layer, "_stop_after_collection"): + del layer._stop_after_collection + + def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: + """Collect input activations for a layer by running the forward loop. 
+ + Propagation stops at the patched layer for each batch (saves compute by not running deeper layers), + but the forward_loop continues to process all batches. + + This function is typically used to collect input activations for the first decoder layer of the model. + """ + + # Wrap model forward to catch StopIteration per-batch + def _early_stop_forward(self, *args, **kwargs): + try: + return self._original_forward(*args, **kwargs) + except StopIteration: + return None # Stop propagation but allow next batch + + bind_forward_method(self.model, _early_stop_forward, "_original_forward") + self._patch_and_initialize_layer(layer, stop_after_collection=True) + try: + forward_loop(self.model) + inputs = layer.inputs.copy() + finally: + self._unpatch_and_cleanup_layer(layer) + unpatch_forward_method(self.model, "_original_forward") + return inputs + + def get_output_activations(self, layer: torch.nn.Module, inputs: list) -> list: + """Run inputs through layer and collect outputs.""" + self._patch_and_initialize_layer(layer, stop_after_collection=False) + try: + for args, kwargs in inputs: + layer(*args, **kwargs) + outputs = layer.outputs.copy() + finally: + self._unpatch_and_cleanup_layer(layer) + return outputs diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index b54332375..1dd3f15d8 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -46,6 +46,7 @@ def _convert_to_wrapped_module_name(name: str) -> str: "ModelLike", "compare_dict", "create_param_grad_clear_hook", + "get_decoder_layers", "get_model_attributes", "get_module_device", "get_same_padding", @@ -634,3 +635,35 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str): with temporarily_remove_accelerate_hook(module): setattr(module, "forward", getattr(module, orig_forward_cache_name)) delattr(module, orig_forward_cache_name) + + +def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None: + """Get the decoder layers from a model for sequential calibration. + + Args: + model: The model to extract decoder layers from. + granularity: The type of layers to extract. Currently only "decoder" is supported. + + Returns: + A ModuleList of decoder layers, or None if not found. + """ + if granularity != "decoder": + raise ValueError(f"Unsupported granularity: {granularity}. 
Only 'decoder' is supported.") + + # HuggingFace transformers pattern: model.model.layers + if hasattr(model, "model") and hasattr(model.model, "layers"): + return model.model.layers + + # Megatron/MCore pattern: model.decoder.layers + if hasattr(model, "decoder") and hasattr(model.decoder, "layers"): + return model.decoder.layers + + # Direct layers attribute (some models) + if hasattr(model, "layers") and isinstance(model.layers, nn.ModuleList): + return model.layers + + # GPT-style: model.transformer.h + if hasattr(model, "transformer") and hasattr(model.transformer, "h"): + return model.transformer.h + + return None From ed2d46950b0115ca47f658a3757f220a73f48c95 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:46:47 +0000 Subject: [PATCH 2/4] tested, revert later --- examples/llm_ptq/hf_ptq.py | 76 ++++++++++++++++++++++ modelopt/torch/quantization/model_calib.py | 9 ++- 2 files changed, 80 insertions(+), 5 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 0d2a5aae9..82a9a0560 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -626,6 +626,82 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." ) + if True: + # Disable quantizers + # mtq.fold_weight(full_model) + # print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") + mtq.disable_quantizer(full_model, "*") + if True: + # mtq.fold_weight(full_model) + import os + + import torch.nn.functional as F + from datasets import load_dataset + from tqdm import trange + from transformers import AutoTokenizer + + # Set cache directory to work directory to avoid disk space issues + cache_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), ".hf_cache" + ) + os.makedirs(cache_dir, exist_ok=True) + os.environ["HF_DATASETS_CACHE"] = cache_dir + print(f"Using HuggingFace datasets cache: {cache_dir}") + + def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): + test_dataset_raw = load_dataset( + "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir + ) + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [ + test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] + for i in range(num_test_sequences) + ] + return test_loader + + @torch.no_grad() + def _compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + device = next(model.parameters()).device + # Running estimate of negative log-likelihood + nll_running = 0 + # Number of tokens processed to far + tokens_processed = 0 + # Loop through each batch + for i in trange( + 0, num_samples, batch_size, desc="Computing perplexity", leave=False + ): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.reshape(-1), + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + 
shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + eval_data = _get_wikitext2(tokenizer, 2048) + ppl = _compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + + breakpoint() export_hf_checkpoint( full_model, diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index a425128b2..4cd15f840 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1494,6 +1494,8 @@ def sequential_calibrate( **calib_kwargs, ): """Sequential calibration - a sequential layer-by-layer calibration algorithm.""" + max_calibrate(model) + breakpoint() transformer_layers = get_decoder_layers(model) if transformer_layers is None: raise ValueError( @@ -1503,15 +1505,12 @@ def sequential_calibrate( print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") gettr = LayerActivationGettr(model) - inputs = gettr.get_input_activations(transformer_layers[0], forward_loop) for layer in transformer_layers: + # Get outputs + inputs = gettr.get_input_activations(layer, forward_loop) # Call GPTQ calib_func(layer, inputs, **calib_kwargs) - # Get outputs - outputs = gettr.get_output_activations(layer, inputs) - # Update inputs - inputs = [(out, *inp[0][1:]) for inp, out in zip(inputs, outputs)] print_rank_0("Sequential calibration completed successfully") From 67ca94fa1be529f28b30b1a44da8b30dd294ece3 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 10 Feb 2026 04:41:46 +0000 Subject: [PATCH 3/4] tested Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 42 +++++++--- modelopt/torch/quantization/config.py | 94 ++++++++++++++++++++++ modelopt/torch/quantization/model_calib.py | 8 +- 3 files changed, 132 insertions(+), 12 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 82a9a0560..7c02fa26f 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -87,6 +87,10 @@ "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "int4_gptq": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG, + "nvfp4_static_wo_gptq": mtq.NVFP4_STATIC_WO_GPTQ_CFG, + "nvfp4_static_wo": mtq.NVFP4_STATIC_WO_CFG, + "nvfp4_static_wo_gptq_lite": mtq.NVFP4_STATIC_WO_GPTQ_LITE_CFG, + "nvfp4_dynamic_wo_gptq": mtq.NVFP4_DYNAMIC_WO_CFG, } KV_QUANT_CFG_CHOICES = { @@ -252,6 +256,10 @@ def auto_quantize( "nvfp4_mlp_only", "mxfp8", "int4_gptq", + "nvfp4_dynamic_wo_gptq", + "nvfp4_static_wo_gptq", + "nvfp4_static_wo", + "nvfp4_static_wo_gptq_lite", ] for args.qformat in qformat_list ), "One or more quantization formats provided are not supported for unified checkpoint export" @@ -626,14 +634,16 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." 
) - if True: + if args.export_qdq_weights: # Disable quantizers - # mtq.fold_weight(full_model) - # print("Folded weights") + if "gptq" not in args.qformat: + mtq.fold_weight(full_model) + print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") mtq.disable_quantizer(full_model, "*") + if True: - # mtq.fold_weight(full_model) import os import torch.nn.functional as F @@ -700,13 +710,13 @@ def _compute_perplexity(model, data, batch_size: int = 1): eval_data = _get_wikitext2(tokenizer, 2048) ppl = _compute_perplexity(full_model, eval_data) print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - - breakpoint() - - export_hf_checkpoint( - full_model, - export_dir=export_path, - ) + print(f"Saving model to {args.export_path}") + full_model.save_pretrained(args.export_path) + else: + export_hf_checkpoint( + full_model, + export_dir=export_path, + ) # Copy custom model files (Python files and JSON configs) if trust_remote_code is used copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code) @@ -944,6 +954,10 @@ def quantize_main( "nvfp4_mlp_only", "mxfp8", "int4_gptq", + "nvfp4_static_wo_gptq", + "nvfp4_static_wo", + "nvfp4_static_wo_gptq_lite", + "nvfp4_dynamic_wo_gptq", ] or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES ), f"Plain quantization format {args.qformat} not supported for HF export path" @@ -1108,6 +1122,12 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) + parser.add_argument( + "--export_qdq_weights", + help=("Used for GPTQ weights as is without compressed weights for deployment."), + default=False, + action="store_true", + ) parser.add_argument( "--verbose", help="Print verbose output (e.g. quantization summary). Disable by --no-verbose.", diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 73659a465..248f62fd6 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -246,6 +246,100 @@ }, } +NVFP4_STATIC_WO_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "max", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_DYNAMIC_WO_CFG = { + "quant_cfg": { + 
"*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} INT4_AWQ_CFG = { "quant_cfg": { diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 4cd15f840..a727dd29c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1505,12 +1505,18 @@ def sequential_calibrate( print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") gettr = LayerActivationGettr(model) + inputs = gettr.get_input_activations(transformer_layers[0], forward_loop) for layer in transformer_layers: - # Get outputs inputs = gettr.get_input_activations(layer, forward_loop) # Call GPTQ calib_func(layer, inputs, **calib_kwargs) + del inputs + torch.cuda.empty_cache() + # Get outputs + # outputs = gettr.get_output_activations(layer, inputs) + # Update inputs + # inputs = [((out, *inp[0][1:]), inp[1]) for inp, out in zip(inputs, outputs)] print_rank_0("Sequential calibration completed successfully") From 65300404851b996d647a899428f14082cd928063 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 11 Feb 2026 07:43:06 +0000 Subject: [PATCH 4/4] refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 80 +++++----------------- 1 file changed, 16 insertions(+), 64 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index a727dd29c..8a1a1fa6c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1250,56 +1250,6 @@ def prepare_hessian_inverse(h, weight, percdamp): return h_inv -def quantize_block(full_weight, block_start, block_end, h_inv, quantizer): - """Quantize a block of weights group by group (based on quantizer block sizes) with error propagation. 
- - Args: - full_weight: The full weight tensor (needed for INT4 quantization) - block_start: Starting column index of the block - block_end: Ending column index of the block - h_inv: Hessian inverse - quantizer: The quantizer to apply - Returns: - quantized_block: Quantized weights for this block - losses: Quantization losses per element - errors: Accumulated errors for propagation - """ - # Extract the block we're working on - block_weight = full_weight[:, block_start:block_end] - block_hinv = h_inv[block_start:block_end, block_start:block_end] - block_size = block_end - block_start - - quantized_block = torch.zeros_like(block_weight) - losses = torch.zeros_like(block_weight) - errors = torch.zeros_like(block_weight) - - # We perform column-wise update for GPTQ within the block - group_size = 1 - - for group_start in range(0, block_size, group_size): - group_end = min(group_start + group_size, block_size) - group_cols = slice(group_start, group_end) - # Get current column and its Hessian inverse diagonal - weight_col = block_weight[:, group_cols] - hinv_diag = torch.diag(block_hinv[group_cols, group_cols]) - - # Quantize using the full weight, then extract the columns we need - quantized_full = quantizer(full_weight) - quantized_cols = quantized_full[:, block_start + group_start : block_start + group_end] - quantized_block[:, group_cols] = quantized_cols - - # Compute quantization error and loss - error = (weight_col - quantized_cols) / hinv_diag - losses[:, group_cols] = (weight_col - quantized_cols) ** 2 / (hinv_diag**2) / 2 - errors[:, group_cols] = error - - # Propagate error to remaining columns in block - block_weight[:, group_start:] -= error @ block_hinv[group_start:group_end, group_start:] - full_weight[:, block_start:block_end] = block_weight - - return quantized_block, losses, errors - - def blockwise_weight_update(module, h, block_size, percdamp): """Update module weights using GPTQ-style blockwise quantization. 
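For reference, the per-column update performed by the rewritten loop in the next hunk can be
sketched in isolation roughly as follows. This is an illustrative sketch, not code from the
module: gptq_column_update_sketch is a hypothetical helper, quantize stands in for
module.weight_quantizer applied to the full weight, and h_inv_cho stands in for the
inverse-Hessian factor produced by prepare_hessian_inverse (h_inv_cho_blk in the new code).

import torch

@torch.no_grad()
def gptq_column_update_sketch(weight, h_inv_cho, quantize, block_size=128):
    """Quantize columns left to right, pushing each column's scaled quantization error
    onto the columns that have not been quantized yet (illustrative sketch)."""
    num_cols = weight.shape[1]
    for block_start in range(0, num_cols, block_size):
        block_end = min(block_start + block_size, num_cols)
        blk = h_inv_cho[block_start:block_end, block_start:block_end]
        errs = torch.zeros_like(weight[:, block_start:block_end])
        for i in range(block_end - block_start):
            col = block_start + i
            # Quantize-dequantize the full weight, keep only the current column.
            q_col = quantize(weight)[:, col]
            err = (weight[:, col] - q_col) / blk[i, i]
            weight[:, col] = q_col
            # Propagate the error within the block, only to not-yet-quantized columns.
            if i + 1 < block_end - block_start:
                weight[:, col + 1 : block_end] -= torch.outer(err, blk[i, i + 1 :])
            errs[:, i] = err
        # Propagate the accumulated block error to all later columns.
        weight[:, block_end:] -= errs @ h_inv_cho[block_start:block_end, block_end:]
    return weight

The hunk below realizes the same recurrence in place, using addr_ and addmm_ with alpha=-1 so
the error propagation is fused into the existing weight buffers rather than allocating
intermediate results.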
@@ -1315,28 +1265,30 @@ def blockwise_weight_update(module, h, block_size, percdamp): # Preprocess Hessian: handle dead neurons and add damping h_inv = prepare_hessian_inverse(h, weight, percdamp) - # Initialize output tensors - quantized_weight = torch.zeros_like(weight) - losses = torch.zeros_like(weight) - # Process weights in blocks for block_start in range(0, num_cols, block_size): block_end = min(block_start + block_size, num_cols) - - quantized_block, block_losses, block_errors = quantize_block( - weight, block_start, block_end, h_inv, module.weight_quantizer - ) - # Store results - quantized_weight[:, block_start:block_end] = quantized_block - losses[:, block_start:block_end] = block_losses + n_cols = block_end - block_start + wblk = weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] + + for i in range(n_cols): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = module.weight_quantizer(wblk) + weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err # Propagate errors to remaining weights - weight[:, block_end:] -= block_errors @ h_inv[block_start:block_end, block_end:] + weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) # Print relative mse error - _print_relative_mse_error(quantized_weight, module.weight.float(), h, module.name) + _print_relative_mse_error(weight, module.weight.float(), h, module.name) # Update module weights - module.weight.data = quantized_weight.reshape(module.weight.shape).to(module.weight.data.dtype) + module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) def gptq_lite(