diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index e32d0dae8..7c02fa26f 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -86,6 +86,11 @@
     "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
     "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
     "mxfp8": mtq.MXFP8_DEFAULT_CFG,
+    "int4_gptq": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG,
+    "nvfp4_static_wo_gptq": mtq.NVFP4_STATIC_WO_GPTQ_CFG,
+    "nvfp4_static_wo": mtq.NVFP4_STATIC_WO_CFG,
+    "nvfp4_static_wo_gptq_lite": mtq.NVFP4_STATIC_WO_GPTQ_LITE_CFG,
+    "nvfp4_dynamic_wo_gptq": mtq.NVFP4_DYNAMIC_WO_CFG,
 }
 
 KV_QUANT_CFG_CHOICES = {
@@ -250,6 +255,11 @@ def auto_quantize(
             "w4a8_mxfp4_fp8",
             "nvfp4_mlp_only",
             "mxfp8",
+            "int4_gptq",
+            "nvfp4_dynamic_wo_gptq",
+            "nvfp4_static_wo_gptq",
+            "nvfp4_static_wo",
+            "nvfp4_static_wo_gptq_lite",
         ]
         for args.qformat in qformat_list
     ), "One or more quantization formats provided are not supported for unified checkpoint export"
@@ -624,11 +634,89 @@ def export_quantized(
             "Unified HF export format does not specify inference tensor parallel or pipeline parallel. "
            "They will be set at deployment time."
         )
-
-        export_hf_checkpoint(
-            full_model,
-            export_dir=export_path,
-        )
+        if args.export_qdq_weights:
+            # For non-GPTQ formats, fold the quantizers into the weights so the stored
+            # weights are already QDQ'ed.
+            if "gptq" not in args.qformat:
+                mtq.fold_weight(full_model)
+                print("Folded weights")
+
+            print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)")
+            mtq.disable_quantizer(full_model, "*")
+
+            import os
+
+            import torch.nn.functional as F
+            from datasets import load_dataset
+            from tqdm import trange
+            from transformers import AutoTokenizer
+
+            # Set the cache directory to the working directory to avoid disk space issues
+            cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".hf_cache")
+            os.makedirs(cache_dir, exist_ok=True)
+            os.environ["HF_DATASETS_CACHE"] = cache_dir
+            print(f"Using HuggingFace datasets cache: {cache_dir}")
+
+            def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int):
+                test_dataset_raw = load_dataset(
+                    "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir
+                )
+                test_dataset_tok = tokenizer(
+                    "\n\n".join(test_dataset_raw["text"]), return_tensors="pt"
+                ).input_ids
+                num_test_sequences = test_dataset_tok.numel() // sequence_length
+                test_loader = [
+                    test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length]
+                    for i in range(num_test_sequences)
+                ]
+                return test_loader
+
+            @torch.no_grad()
+            def _compute_perplexity(model, data, batch_size: int = 1):
+                num_samples = len(data)
+                device = next(model.parameters()).device
+                # Running estimate of the negative log-likelihood
+                nll_running = 0
+                # Number of tokens processed so far
+                tokens_processed = 0
+                # Loop through each batch
+                for i in trange(
+                    0, num_samples, batch_size, desc="Computing perplexity", leave=False
+                ):
+                    j = min(i + batch_size, num_samples)
+                    inputs = torch.cat(data[i:j]).to(device)
+                    # Forward pass through the model
+                    lm_logits = model(inputs).logits
+                    # Shift logits and labels for next-token prediction
+                    shift_logits = lm_logits[:, :-1, :].contiguous()
+                    shift_labels = inputs[:, 1:]
+                    # Compute loss
+                    loss = F.cross_entropy(
+                        shift_logits.reshape(-1, shift_logits.size(-1)),
+                        shift_labels.reshape(-1),
+                    )
+                    # Update the running negative log-likelihood as a token-weighted average
+                    a = shift_labels.numel() / (tokens_processed + shift_labels.numel())
+                    b = tokens_processed / (tokens_processed + shift_labels.numel())
+                    nll_running = a * loss + b * nll_running
+                    # Update number of processed tokens
+                    tokens_processed += shift_labels.numel()
+                # Compute perplexity
+                ppl = nll_running.exp().item()
+                return ppl
+
+            eval_data = _get_wikitext2(tokenizer, 2048)
+            ppl = _compute_perplexity(full_model, eval_data)
+            print(f"Wikitext-2 perplexity: {ppl:.2f}")
+            print(f"Saving model to {args.export_path}")
+            full_model.save_pretrained(args.export_path)
+        else:
+            export_hf_checkpoint(
+                full_model,
+                export_dir=export_path,
+            )
 
     # Copy custom model files (Python files and JSON configs) if trust_remote_code is used
     copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code)
@@ -865,6 +953,11 @@ def quantize_main(
             "w4a8_mxfp4_fp8",
             "nvfp4_mlp_only",
             "mxfp8",
+            "int4_gptq",
+            "nvfp4_static_wo_gptq",
+            "nvfp4_static_wo",
+            "nvfp4_static_wo_gptq_lite",
+            "nvfp4_dynamic_wo_gptq",
         ]
         or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
    ), f"Plain quantization format {args.qformat} not supported for HF export path"
@@ -1029,6 +1122,12 @@ def parse_args() -> argparse.Namespace:
         default=False,
         action="store_true",
     )
+    parser.add_argument(
+        "--export_qdq_weights",
+        help="Export the quantize-dequantize (QDQ) weights as-is, without compressing them for deployment.",
+        default=False,
+        action="store_true",
+    )
     parser.add_argument(
         "--verbose",
         help="Print verbose output (e.g. quantization summary). Disable by --no-verbose.",
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index e1b48ee60..248f62fd6 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -234,6 +234,112 @@
     "algorithm": "max",
 }
 
+INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True},
+        "*input_quantizer": {"enable": False},
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "gptq",
+        "use_sequential": True,
+    },
+}
+
+NVFP4_STATIC_WO_GPTQ_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "gptq",
+        "use_sequential": True,
+    },
+}
+
+NVFP4_STATIC_WO_GPTQ_LITE_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "gptq_lite",
+        "use_sequential": False,
+    },
+}
+
+NVFP4_STATIC_WO_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "max",
+        "use_sequential": False,
+    },
+}
+
+NVFP4_DYNAMIC_WO_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "gptq_lite",
+        "use_sequential": False,
+    },
+}
 
 INT4_AWQ_CFG = {
     "quant_cfg": {
@@ -992,6 +1098,15 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
         title="This field specifies the name of the calibration algorithm. If None, no calibration is performed.",
     )
 
+    use_sequential: bool = ModeloptField(
+        default=False,
+        title="Enable sequential layer-by-layer calibration.",
+        description=(
+            "If True, the calibration algorithm is applied sequentially to each decoder block. "
+            "Outputs from one layer become inputs to the next, reducing memory usage for large models."
+        ),
+    )
+
 
 class MaxCalibConfig(QuantizeAlgorithmConfig):
     """The config for max calibration algorithm.
@@ -1228,6 +1343,44 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig):
     )
 
 
+class GPTQConfig(QuantizeAlgorithmConfig):
+    """The config for the GPTQ calibration algorithm.
+
+    GPTQ calibrates the model sequentially, one decoder block at a time: input activations are
+    collected for each block, per-layer Hessians are estimated from them, and the weights are
+    updated block by block to compensate for quantization error.
+
+    The default values follow the reference implementation:
+    https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35
+
+    Note: This feature is currently experimental and may not translate to improved accuracy as expected.
+    """
+
+    method: Literal["gptq"] = ModeloptField("gptq")
+    percdamp: float | None = ModeloptField(
+        default=0.01,
+        gt=0.0,
+        le=1.0,
+        title="Percentage damping factor.",
+        description="The percentage of the average Hessian diagonal used for damping.",
+    )
+    block_size: int | None = ModeloptField(
+        default=128,
+        title="Block size for the GPTQ weight update.",
+        description="""The block size for the GPTQ weight update, which must be a multiple of the
+        group_size used in the quantization.""",
+    )
+    hessian_state_path: str | None = ModeloptField(
+        default=None,
+        title="Path to the Hessian state file.",
+        description="""The path to the Hessian state file. If the file exists, the Hessians are
+        loaded from it instead of being recomputed.""",
+    )
+
+
 QuantizeQuantCfgType = dict[
     str | Callable,
     QuantizerAttributeConfig
diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py
index bfcdb64da..b3df12785 100644
--- a/modelopt/torch/quantization/mode.py
+++ b/modelopt/torch/quantization/mode.py
@@ -37,6 +37,7 @@
     AWQFullCalibConfig,
     AWQLiteCalibConfig,
     CompressConfig,
+    GPTQConfig,
     GPTQLiteConfig,
     MaxCalibConfig,
     MseCalibConfig,
@@ -56,7 +57,16 @@
     restore_svdquant_model,
     update_quantize_metadata,
 )
-from .model_calib import awq, gptq_lite, max_calibrate, mse_calibrate, smoothquant, svdquant
+from .model_calib import (
+    awq,
+    gptq,
+    gptq_lite,
+    max_calibrate,
+    mse_calibrate,
+    sequential_calibrate,
+    smoothquant,
+    svdquant,
+)
 
 __all__ = ["BaseCalibrateModeDescriptor"]
 
@@ -212,13 +222,24 @@ def wrapped_calib_func(
         """
         kwargs = config.model_dump()
         method = kwargs.pop("method")
+        sequential = kwargs.pop("use_sequential", False)
+
         if method is not None and "awq" in method:
             # For backward compatibility
             kwargs["algorithm"] = method
 
         if func is not None:
-            # Call the function with forward_loop as a separate argument
-            func(model, forward_loop=forward_loop, **kwargs)
+            if sequential:
+                # Wrap the per-layer calibration function with sequential layer-by-layer processing
+                sequential_calibrate(
+                    model,
+                    forward_loop=forward_loop,
+                    calib_func=func,
+                    **kwargs,
+                )
+            else:
+                # Direct calibration (existing behavior)
+                func(model, forward_loop=forward_loop, **kwargs)
 
         # Lets get the latest metadata for the quantizer states
         metadata = {}
@@ -452,3 +473,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
         return GPTQLiteConfig
 
     _calib_func = gptq_lite
+
+
+@CalibrateModeRegistry.register_mode
+class GPTQModeDescriptor(BaseCalibrateModeDescriptor):
+    """Mode for GPTQ calibration algorithm."""
+
+    @property
+    def config_class(self) -> type[QuantizeAlgorithmConfig]:
+        """Specifies the config class for the mode."""
+        return GPTQConfig
+
+    _calib_func = gptq
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 591de3240..a727dd29c 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -18,6 +18,7 @@
 import math
 import os
 import warnings
+from collections.abc import Callable
 from functools import partial
 
 import torch
@@ -27,9 +28,14 @@
 from tqdm import tqdm
 
 from modelopt.torch.opt.searcher import ForwardLoop
+from modelopt.torch.quantization.utils import LayerActivationGettr
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
-from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
+from modelopt.torch.utils.network import (
+    bind_forward_method,
+    get_decoder_layers,
+    unpatch_forward_method,
+)
 from modelopt.torch.utils.perf import get_used_gpu_mem_fraction
 
 from .calib import MseCalibrator
@@ -1478,3 +1484,123 @@ def hessian_hook(module, input, output):
         torch.cuda.empty_cache()
 
     print_rank_0("GPTQ-lite quantization completed successfully")
+
+
+@torch.no_grad()
+def sequential_calibrate(
+    model: nn.Module,
+    forward_loop: ForwardLoop,
+    calib_func: Callable,
+    **calib_kwargs,
+):
+    """Calibrate the model one decoder layer at a time by applying ``calib_func`` to each layer."""
+    # Initialize quantizer states with max calibration before running the per-layer algorithm
+    max_calibrate(model)
+    transformer_layers = get_decoder_layers(model)
+    if transformer_layers is None:
+        raise ValueError(
+            "Could not find transformer layers in the model. "
+            "Sequential calibration requires a model with identifiable transformer layers."
+        )
+
+    print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")
+    gettr = LayerActivationGettr(model)
+
+    for layer in transformer_layers:
+        inputs = gettr.get_input_activations(layer, forward_loop)
+        # Run the per-layer calibration function (e.g. GPTQ) on this decoder layer
+        calib_func(layer, inputs, **calib_kwargs)
+        del inputs
+        torch.cuda.empty_cache()
+        # TODO: propagate this layer's outputs as inputs to the next layer instead of
+        # re-running the forward loop for every layer:
+        # outputs = gettr.get_output_activations(layer, inputs)
+        # inputs = [((out, *inp[0][1:]), inp[1]) for inp, out in zip(inputs, outputs)]
+
+    print_rank_0("Sequential calibration completed successfully")
+
+
+@torch.no_grad()
+def gptq(
+    layer: nn.Module,
+    inputs: list[tuple[tuple, dict]],
+    percdamp: float = 0.01,
+    block_size: int = 128,
+    **kwargs,
+):
+    """Run GPTQ weight updates on a single decoder layer using the collected input activations."""
+    import time
+
+    total_start = time.time()
+
+    # Dictionary to store Hessian matrices for all linear layers in this decoder layer
+    hessian_state = {}
+
+    # Phase 1: Build tensor mapping for all quantized linear layers in this decoder layer
+    tensor_mapping = {}
+    for name, module in layer.named_modules():
+        if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
+            in_features = module.weight.shape[-1]
+            tensor_mapping[name] = ((in_features, in_features), module.weight.device)
+            module.name = name  # Attach name for easy access in hooks
+
+    if not tensor_mapping:
+        print_rank_0("No quantized linear layers found in decoder layer, skipping GPTQ")
+        return
+
+    # Initialize Hessian state with zeros
+    for name, (shape, device) in tensor_mapping.items():
+        hessian_state[name] = {
+            "hessian": torch.zeros(shape, dtype=torch.float32, device=device),
+            "n_samples": 0,
+        }
+
+    # Phase 2: Register hooks to collect Hessians during forward passes
+    def hessian_hook(module, input, output):
+        """Hook to intercept activations and update the Hessian matrix."""
+        state = hessian_state[module.name]
+        hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"])
+        hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples}
+
+    handles = []
+    for name, module in layer.named_modules():
+        if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
+            handles.append(module.register_forward_hook(hessian_hook))
+
+    # Run forward passes with the provided inputs to collect Hessians
+    hessian_start = time.time()
+    print_rank_0(
+        f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..."
+    )
+    for args, kwargs_input in inputs:
+        layer(*args, **kwargs_input)
+
+    # Remove hooks after collecting Hessians
+    for handle in handles:
+        handle.remove()
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    hessian_time = time.time() - hessian_start
+
+    # Phase 3: Update weights using the computed Hessians (same as gptq_lite)
+    weight_update_start = time.time()
+    print_rank_0("Updating weights using the GPTQ algorithm...")
+    for name, module in layer.named_modules():
+        if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
+            state = hessian_state[module.name]
+            hessian = state["hessian"].to(module.weight.device)
+            blockwise_weight_update(module, hessian, block_size, percdamp)
+            # Free memory
+            del hessian_state[module.name]
+            torch.cuda.empty_cache()
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    weight_update_time = time.time() - weight_update_start
+
+    total_time = time.time() - total_start
+    print_rank_0(
+        f"GPTQ timing - Hessian: {hessian_time:.2f}s, "
+        f"Weight update: {weight_update_time:.2f}s, "
+        f"Total: {total_time:.2f}s"
+    )
diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py
index b663ef5f2..b794251d2 100644
--- a/modelopt/torch/quantization/utils.py
+++ b/modelopt/torch/quantization/utils.py
@@ -29,10 +29,13 @@
 from torch.distributed.tensor import Replicate
 
 from modelopt.torch.utils import get_unwrapped_name, print_rank_0
+from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 
 if TYPE_CHECKING:
     from collections.abc import Generator
 
+    from modelopt.torch.opt.searcher import ForwardLoop
+
 __all__ = [
     "EXPORT_MODE",
     "convert_quantization_axis_to_reduce_axis",
@@ -810,3 +813,82 @@ def update_quant_cfg_with_kv_cache_quant(
         quant_cfg["algorithm"] = "max"
     print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}")
     return quant_cfg
+
+
+class LayerActivationGettr:
+    """Helper class for collecting layer activations during forward passes.
+
+    This class enables sequential layer calibration by patching layers to capture their
+    inputs and outputs during forward passes.
+    """
+
+    def __init__(self, model: nn.Module):
+        self.model = model
+
+    @staticmethod
+    def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False):
+        """Patch a layer to collect inputs and outputs during forward passes."""
+
+        def _forward_w_data_collection(self, *args, **kwargs):
+            """Custom forward that collects inputs and outputs.
+
+            Note: 'self' refers to the patched layer.
+            """
+            assert len(args) >= 1
+            self.inputs.append((args, kwargs))
+            output = self._original_forward(*args, **kwargs)
+            self.outputs.append(output)
+            if getattr(self, "_stop_after_collection", False):
+                raise StopIteration()
+            return output
+
+        bind_forward_method(layer, _forward_w_data_collection, "_original_forward")
+        layer.inputs = []
+        layer.outputs = []
+        layer._stop_after_collection = stop_after_collection
+
+    @staticmethod
+    def _unpatch_and_cleanup_layer(layer: torch.nn.Module):
+        """Restore a layer's original forward method and clean up."""
+        unpatch_forward_method(layer, "_original_forward")
+        del layer.inputs
+        del layer.outputs
+        if hasattr(layer, "_stop_after_collection"):
+            del layer._stop_after_collection
+
+    def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
+        """Collect input activations for a layer by running the forward loop.
+
+        Propagation stops at the patched layer for each batch (saving compute by not running
+        deeper layers), while the forward_loop still iterates over all batches.
+
+        This function is typically used to collect input activations for the first decoder layer
+        of the model.
+        """
+
+        # Wrap the model forward to catch the per-batch StopIteration raised by the patched layer
+        def _early_stop_forward(self, *args, **kwargs):
+            try:
+                return self._original_forward(*args, **kwargs)
+            except StopIteration:
+                return None  # Stop propagation but allow the next batch
+
+        bind_forward_method(self.model, _early_stop_forward, "_original_forward")
+        self._patch_and_initialize_layer(layer, stop_after_collection=True)
+        try:
+            forward_loop(self.model)
+            inputs = layer.inputs.copy()
+        finally:
+            self._unpatch_and_cleanup_layer(layer)
+            unpatch_forward_method(self.model, "_original_forward")
+        return inputs
+
+    def get_output_activations(self, layer: torch.nn.Module, inputs: list) -> list:
+        """Run the collected inputs through the layer and gather its outputs."""
+        self._patch_and_initialize_layer(layer, stop_after_collection=False)
+        try:
+            for args, kwargs in inputs:
+                layer(*args, **kwargs)
+            outputs = layer.outputs.copy()
+        finally:
+            self._unpatch_and_cleanup_layer(layer)
+        return outputs
diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py
index b54332375..1dd3f15d8 100644
--- a/modelopt/torch/utils/network.py
+++ b/modelopt/torch/utils/network.py
@@ -46,6 +46,7 @@ def _convert_to_wrapped_module_name(name: str) -> str:
     "ModelLike",
     "compare_dict",
     "create_param_grad_clear_hook",
+    "get_decoder_layers",
     "get_model_attributes",
     "get_module_device",
     "get_same_padding",
@@ -634,3 +635,35 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str):
     with temporarily_remove_accelerate_hook(module):
         setattr(module, "forward", getattr(module, orig_forward_cache_name))
     delattr(module, orig_forward_cache_name)
+
+
+def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None:
+    """Get the decoder layers from a model for sequential calibration.
+
+    Args:
+        model: The model to extract decoder layers from.
+        granularity: The type of layers to extract. Currently only "decoder" is supported.
+
+    Returns:
+        A ModuleList of decoder layers, or None if not found.
+    """
+    if granularity != "decoder":
+        raise ValueError(f"Unsupported granularity: {granularity}. Only 'decoder' is supported.")
+
+    # HuggingFace transformers pattern: model.model.layers
+    if hasattr(model, "model") and hasattr(model.model, "layers"):
+        return model.model.layers
+
+    # Megatron/MCore pattern: model.decoder.layers
+    if hasattr(model, "decoder") and hasattr(model.decoder, "layers"):
+        return model.decoder.layers
+
+    # Direct layers attribute (some models)
+    if hasattr(model, "layers") and isinstance(model.layers, nn.ModuleList):
+        return model.layers
+
+    # GPT-style: model.transformer.h
+    if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
+        return model.transformer.h
+
+    return None
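Usage note (not part of the diff): a minimal sketch of how the new configs might be exercised through the existing mtq.quantize entry point. The model name, calibration texts, and forward_loop here are illustrative placeholders; only INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG and mtq.quantize come from the changes above.

    # Illustrative sketch, assuming the standard mtq.quantize() API.
    # The checkpoint name and calibration data are placeholders.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    import modelopt.torch.quantization as mtq

    model_name = "meta-llama/Llama-3.1-8B"  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def forward_loop(m):
        # Tiny calibration loop; real calibration would iterate over a proper dataset.
        texts = ["The quick brown fox jumps over the lazy dog."] * 4
        for text in texts:
            batch = tokenizer(text, return_tensors="pt").to(m.device)
            with torch.no_grad():
                m(**batch)

    # This config sets method="gptq" and use_sequential=True, so wrapped_calib_func
    # routes through sequential_calibrate() and the per-layer gptq() function above.
    model = mtq.quantize(model, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG, forward_loop)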