109 changes: 104 additions & 5 deletions examples/llm_ptq/hf_ptq.py
@@ -86,6 +86,11 @@
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
"int4_gptq": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG,
"nvfp4_static_wo_gptq": mtq.NVFP4_STATIC_WO_GPTQ_CFG,
"nvfp4_static_wo": mtq.NVFP4_STATIC_WO_CFG,
"nvfp4_static_wo_gptq_lite": mtq.NVFP4_STATIC_WO_GPTQ_LITE_CFG,
"nvfp4_dynamic_wo_gptq": mtq.NVFP4_DYNAMIC_WO_CFG,
}

KV_QUANT_CFG_CHOICES = {
@@ -250,6 +255,11 @@ def auto_quantize(
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"mxfp8",
"int4_gptq",
"nvfp4_dynamic_wo_gptq",
"nvfp4_static_wo_gptq",
"nvfp4_static_wo",
"nvfp4_static_wo_gptq_lite",
]
for args.qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"
Expand Down Expand Up @@ -624,11 +634,89 @@ def export_quantized(
"Unified HF export format does not specify inference tensor parallel or pipeline parallel. "
"They will be set at deployment time."
)

export_hf_checkpoint(
full_model,
export_dir=export_path,
)
if args.export_qdq_weights:
# Fold quantized (QDQ) weights into the model for non-GPTQ formats
if "gptq" not in args.qformat:
mtq.fold_weight(full_model)
print("Folded weights")

print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)")
mtq.disable_quantizer(full_model, "*")

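# Inline sanity check: evaluate wikitext-2 perplexity of the QDQ model before saving.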
if True:
import os

import torch.nn.functional as F
from datasets import load_dataset
from tqdm import trange
from transformers import AutoTokenizer

# Set cache directory to work directory to avoid disk space issues
cache_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), ".hf_cache"
)
os.makedirs(cache_dir, exist_ok=True)
os.environ["HF_DATASETS_CACHE"] = cache_dir
print(f"Using HuggingFace datasets cache: {cache_dir}")

def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int):
test_dataset_raw = load_dataset(
"wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir
)
test_dataset_tok = tokenizer(
"\n\n".join(test_dataset_raw["text"]), return_tensors="pt"
).input_ids
num_test_sequences = test_dataset_tok.numel() // sequence_length
test_loader = [
test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length]
for i in range(num_test_sequences)
]
return test_loader

@torch.no_grad()
def _compute_perplexity(model, data, batch_size: int = 1):
num_samples = len(data)
device = next(model.parameters()).device
# Running estimate of negative log-likelihood
nll_running = 0
# Number of tokens processed so far
tokens_processed = 0
# Loop through each batch
for i in trange(
0, num_samples, batch_size, desc="Computing perplexity", leave=False
):
j = min(i + batch_size, num_samples)
inputs = torch.cat(data[i:j]).to(device)
# Forward pass through the model
lm_logits = model(inputs).logits
# Shift logits and labels for next token prediction
shift_logits = lm_logits[:, :-1, :].contiguous()
shift_labels = inputs[:, 1:]
# Compute loss
loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.size(-1)),
shift_labels.reshape(-1),
)
# Update the running token-weighted average of the negative log-likelihood
a = shift_labels.numel() / (tokens_processed + shift_labels.numel())
b = tokens_processed / (tokens_processed + shift_labels.numel())
nll_running = a * loss + b * nll_running
# Update number of processed tokens
tokens_processed += shift_labels.numel()
# Compute perplexity
ppl = nll_running.exp().item()
return ppl
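The update at the end of the loop is just an incremental form of the token-weighted mean negative log-likelihood: after the final batch, nll_running = (1 / T) * sum_b T_b * loss_b and ppl = exp(nll_running), where loss_b is the mean cross-entropy over the T_b shifted tokens of batch b and T = sum_b T_b.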

eval_data = _get_wikitext2(tokenizer, 2048)
ppl = _compute_perplexity(full_model, eval_data)
print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}")
print(f"Saving model to {args.export_path}")
full_model.save_pretrained(args.export_path)
else:
export_hf_checkpoint(
full_model,
export_dir=export_path,
)

# Copy custom model files (Python files and JSON configs) if trust_remote_code is used
copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code)
@@ -865,6 +953,11 @@ def quantize_main(
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"mxfp8",
"int4_gptq",
"nvfp4_static_wo_gptq",
"nvfp4_static_wo",
"nvfp4_static_wo_gptq_lite",
"nvfp4_dynamic_wo_gptq",
]
or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
), f"Plain quantization format {args.qformat} not supported for HF export path"
@@ -1029,6 +1122,12 @@ def parse_args() -> argparse.Namespace:
default=False,
action="store_true",
)
parser.add_argument(
"--export_qdq_weights",
help=("Used for GPTQ weights as is without compressed weights for deployment."),
default=False,
action="store_true",
)
parser.add_argument(
"--verbose",
help="Print verbose output (e.g. quantization summary). Disable by --no-verbose.",
153 changes: 153 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -234,6 +234,112 @@
"algorithm": "max",
}

INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True},
"*input_quantizer": {"enable": False},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "gptq",
"use_sequential": True,
},
}
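As a rough usage sketch (not part of this diff), a config like this would be applied through mtq.quantize with a short calibration loop; calib_dataloader below is an assumed placeholder for a tokenized calibration set:

import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Run a few calibration batches so GPTQ can collect Hessian statistics.
    for batch in calib_dataloader:  # assumed: iterable of tokenized batches
        model(batch["input_ids"].to(model.device))

# Weight-only INT4 blockwise quantization calibrated with GPTQ.
model = mtq.quantize(model, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG, forward_loop)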

NVFP4_STATIC_WO_GPTQ_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "gptq",
"use_sequential": True,
},
}

NVFP4_STATIC_WO_GPTQ_LITE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "gptq_lite",
"use_sequential": False,
},
}

NVFP4_STATIC_WO_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "max",
"use_sequential": False,
},
}

NVFP4_DYNAMIC_WO_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "gptq_lite",
"use_sequential": False,
},
}

INT4_AWQ_CFG = {
"quant_cfg": {
@@ -992,6 +1098,15 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
title="This field specifies the name of the calibration algorithm. If None, no calibration is performed.",
)

use_sequential: bool = ModeloptField(
default=False,
title="Enable sequential layer-by-layer calibration.",
description=(
"If True, the calibration algorithm is applied sequentially to each decoder block. "
"Outputs from one layer become inputs to the next, reducing memory usage for large models."
),
)
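As an illustration of how this field is meant to be used (hypothetical, not part of the diff), a copy of one of the GPTQ configs above can flip the flag inside its algorithm entry:

import copy

import modelopt.torch.quantization as mtq

# Hypothetical variant: same NVFP4 static weight-only GPTQ quantizers,
# but calibrate the whole model in one pass instead of block by block.
cfg = copy.deepcopy(mtq.NVFP4_STATIC_WO_GPTQ_CFG)
cfg["algorithm"]["use_sequential"] = False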


class MaxCalibConfig(QuantizeAlgorithmConfig):
"""The config for max calibration algorithm.
Expand Down Expand Up @@ -1228,6 +1343,44 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig):
)


class GPTQConfig(QuantizeAlgorithmConfig):
"""The config for GPTQ lite.

GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation.

GPTQ lite does not perform sequential quantization of layers. This means that the updated
activations are not used to process the next layer.

The default values are taken from the official GPTQ implementation:
https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35

Note: This feature is currently experimental and may not translate to improved accuracy as expected.


"""

method: Literal["gptq"] = ModeloptField("gptq")
percdamp: float | None = ModeloptField(
default=0.01,
gt=0.0,
le=1.0,
title="Percentage damping factor.",
description="The percentage of average Hessian diagonal used for damping.",
)
block_size: int | None = ModeloptField(
default=128,
title="Block size for GPTQ weight update.",
description="""The block size for GPTQ weight update, which must be a multiple of the
group_size used in the quantization.""",
)
hessian_state_path: str | None = ModeloptField(
default=None,
title="Path to the Hessian state file.",
description="""The path to the Hessian state file. If hessian path exists, we load from
hessian file instead of recomputing them.""",
)
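For orientation, these fields correspond one-to-one to keys in the algorithm dict of a quantize config; a hypothetical, hand-rolled config (values are illustrative only, and the disabled-quantizer defaults are omitted) might look like:

# Illustrative only: an INT4 weight-only config with explicit GPTQ hyperparameters.
custom_gptq_cfg = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True},
        "*input_quantizer": {"enable": False},
    },
    "algorithm": {
        "method": "gptq",
        "use_sequential": True,
        "percdamp": 0.01,            # damping as a fraction of the mean Hessian diagonal
        "block_size": 128,           # must be a multiple of the quantization group size
        "hessian_state_path": None,  # set to a path to reuse precomputed Hessians
    },
}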


QuantizeQuantCfgType = dict[
str | Callable,
QuantizerAttributeConfig
39 changes: 36 additions & 3 deletions modelopt/torch/quantization/mode.py
@@ -37,6 +37,7 @@
AWQFullCalibConfig,
AWQLiteCalibConfig,
CompressConfig,
GPTQConfig,
GPTQLiteConfig,
MaxCalibConfig,
MseCalibConfig,
@@ -56,7 +57,16 @@
restore_svdquant_model,
update_quantize_metadata,
)
from .model_calib import awq, gptq_lite, max_calibrate, mse_calibrate, smoothquant, svdquant
from .model_calib import (
awq,
gptq,
gptq_lite,
max_calibrate,
mse_calibrate,
sequential_calibrate,
smoothquant,
svdquant,
)

__all__ = ["BaseCalibrateModeDescriptor"]

@@ -212,13 +222,24 @@ def wrapped_calib_func(
"""
kwargs = config.model_dump()
method = kwargs.pop("method")
sequential = kwargs.pop("use_sequential", False)

if method is not None and "awq" in method:
# For backward compatibility
kwargs["algorithm"] = method

if func is not None:
# Call the function with forward_loop as a separate argument
func(model, forward_loop=forward_loop, **kwargs)
if sequential:
# Run the calibration function block by block via sequential_calibrate.
sequential_calibrate(
model,
forward_loop=forward_loop,
calib_func=func,
**kwargs,
)
else:
# Calibrate the whole model in a single pass (existing behavior).
func(model, forward_loop=forward_loop, **kwargs)

# Lets get the latest metadata for the quantizer states
metadata = {}
@@ -452,3 +473,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
return GPTQLiteConfig

_calib_func = gptq_lite


@CalibrateModeRegistry.register_mode
class GPTQModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for GPTQ calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
"""Specifies the config class for the mode."""
return GPTQConfig

_calib_func = gptq