159 changes: 69 additions & 90 deletions modelopt/torch/quantization/calib/mse.py
@@ -24,7 +24,7 @@
from .. import utils as quant_utils
from .calibrator import _Calibrator

__all__ = ["MseCalibrator"]
__all__ = ["MseCalibrator", "NVFP4MSECalibrator"]


class MseCalibrator(_Calibrator):
@@ -39,7 +39,6 @@ def __init__(
stop_multiplier: float = 4.0,
quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
fp8_scale_sweep: bool = False,
):
"""Initialize MSE calibrator.

@@ -54,9 +53,6 @@ def __init__(
Should have signature: quant_func(x, amax) -> quantized_x.
error_func: Function to compute error between x and xq.
Default is F.mse_loss(x, xq, reduction='none').
fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
instead of using multipliers. This is specifically for NVFP4
per-block quantization where scales are stored in FP8 format.
"""
super().__init__(num_bits=None, axis=axis, unsigned=None)
self._initial_amax = amax
@@ -67,17 +63,21 @@

self._quant_func = quant_func
self._error_func = error_func
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps
self._fp8_scale_sweep = fp8_scale_sweep
if fp8_scale_sweep:
# For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values
# (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN)
self._num_steps = 126
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps

self._amax = None
self._losses_sum: list[torch.Tensor | None] | None = None
self._candidates: torch.Tensor | None = None
self._amax: torch.Tensor | None = None

def _generate_candidates(self, device: torch.device) -> torch.Tensor:
"""Generate candidate multipliers. Override in subclasses for different candidate sets."""
return torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
)

def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
"""Compute amax from candidates. Override in subclasses for different amax computation."""
if candidates.ndim != 0: # Called during final compute amax
candidates = candidates.view_as(self._initial_amax)
return self._initial_amax * candidates

@torch.no_grad()
def collect(self, x: torch.Tensor):
@@ -87,39 +87,22 @@ def collect(self, x: torch.Tensor):
x: Input tensor.
"""
if self._quant_func is None:
raise RuntimeError(
"Quantization function not set. Msecalibrator requires a quant_func to be provided."
)
raise RuntimeError("Quantization function not set.")

x = x.detach().to(dtype=torch.float32)

device = x.device

if self._fp8_scale_sweep:
global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)

# Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
# Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()

# Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values_valid = fp8_values[valid_mask]
candidates = self._generate_candidates(device)
if self._candidates is None:
self._candidates = candidates
self._num_steps = len(candidates)
self._losses_sum = [None] * self._num_steps

candidates = fp8_values_valid / 448.0
else:
candidates = torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
)
# Get reduce axis for per-channel quantization
assert self._losses_sum is not None
reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)

for step, candidate in enumerate(candidates):
if self._fp8_scale_sweep:
candidate_amax = (global_amax * candidate) * torch.ones_like(self._initial_amax)
else:
candidate_amax = self._initial_amax * candidate
candidate_amax = self._compute_candidate_amax(candidate)
xq = self._quant_func(x, candidate_amax)

if self._error_func is not None:
@@ -129,28 +112,16 @@

loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False)

if self._candidate_amaxs[step] is None:
self._candidate_amaxs[step] = candidate_amax

if self._losses_sum[step] is None:
self._losses_sum[step] = loss.clone()
else:
self._losses_sum[step] += loss

def reset(self):
"""Reset the stored losses and amax value."""
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps
self._losses_sum = None
self._candidates = None
self._amax = None

def clear(self):
"""Clear all cached data to free GPU memory.

Call this after compute_amax() and load_calib_amax() are done.
"""
self._losses_sum = []
self._candidate_amaxs = []

if self._initial_amax is not None:
del self._initial_amax
self._initial_amax = None
@@ -162,49 +133,28 @@ def compute_amax(self, verbose: bool = False):
Args:
verbose: If True, print the ratio of best_amax to initial_amax.
"""
if not any(loss_sum is not None for loss_sum in self._losses_sum):
if self._losses_sum is None or not any(loss is not None for loss in self._losses_sum):
return None

# Check if this is per-tensor or per-channel based on the first loss
first_loss_sum = None
for loss_sum in self._losses_sum:
if loss_sum is not None:
first_loss_sum = loss_sum
break

if first_loss_sum is None:
first_loss = next((loss for loss in self._losses_sum if loss is not None), None)
if first_loss is None:
return None

# Collect losses for all steps
losses_per_step = []
# Stack losses: [num_steps] or [num_steps, num_channels]
losses = []
for step in range(self._num_steps):
if self._losses_sum[step] is not None:
losses_per_step.append(self._losses_sum[step])
# No data for this step, use inf
elif first_loss_sum.ndim == 0:
losses_per_step.append(torch.tensor(float("inf"), device=first_loss_sum.device))
losses.append(self._losses_sum[step])
elif first_loss.ndim == 0:
losses.append(torch.tensor(float("inf"), device=first_loss.device))
else:
losses_per_step.append(torch.full_like(first_loss_sum, float("inf")))

# Stack to get [num_steps] for per-tensor or [num_steps, num_channels] for per-channel
losses_per_step = torch.stack(losses_per_step)
losses.append(torch.full_like(first_loss, float("inf")))

# Find best step(s): scalar for per-tensor, [num_channels] for per-channel
best_steps = torch.argmin(losses_per_step, dim=0)

# Stack candidate amaxs and select based on best_steps
candidate_amaxs = torch.stack(self._candidate_amaxs)

if first_loss_sum.ndim == 0:
# Per-tensor case: best_steps is a scalar
self._amax = self._candidate_amaxs[best_steps.item()]
else:
# Per-channel case: best_steps is a tensor
num_channels = best_steps.shape[0]
self._amax = candidate_amaxs[
best_steps, torch.arange(num_channels, device=best_steps.device)
]
self._amax = self._amax.reshape(self._initial_amax.shape)
losses = torch.stack(losses)
best_indices = torch.argmin(losses, dim=0)
assert self._candidates is not None
best_candidates = self._candidates[best_indices]
self._amax = self._compute_candidate_amax(best_candidates)

if verbose:
ratio = self._amax / self._initial_amax
@@ -219,3 +169,32 @@
)

return self._amax
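
To make the search concrete, here is a minimal standalone sketch of what MseCalibrator does across collect() and compute_amax(), assuming a toy symmetric int8 fake-quant for quant_func and an illustrative start multiplier of 0.25 (in the real flow, quant_func is supplied by the caller, e.g. via _mse_quant_func in model_calib.py):

import torch
import torch.nn.functional as F

def fake_quant_int8(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for quant_func(x, amax) -> quantized_x.
    scale = amax / 127.0
    return (x / scale).round().clamp_(-127, 127) * scale

x = torch.randn(4096)
initial_amax = x.abs().max()

# Same grid as _generate_candidates(): linspace(start, stop, num_steps).
candidates = torch.linspace(0.25, 4.0, steps=16)
# collect(): accumulate one loss per candidate amax.
losses = torch.stack(
    [F.mse_loss(x, fake_quant_int8(x, initial_amax * c)) for c in candidates]
)
# compute_amax(): pick the candidate that minimizes the accumulated loss.
best_amax = initial_amax * candidates[losses.argmin()]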


class NVFP4MSECalibrator(MseCalibrator):
"""Per-block FP8 scale sweep calibrator for NVFP4 static quantization."""

def __init__(
self,
amax: torch.Tensor, # per_block_amax shape [num_blocks]
global_amax: torch.Tensor, # scalar
axis: int | tuple | list | None = None,
quant_func: Callable | None = None,
error_func: Callable | None = None,
):
"""Initialize NVFP4 MSE calibrator with per-block and global amax."""
super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
self._global_amax = global_amax

def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
if candidates.ndim != 0: # Called during final compute amax
candidates = candidates.view_as(self._initial_amax)
return torch.ones_like(self._initial_amax) * self._global_amax * candidates

def _generate_candidates(self, device: torch.device) -> torch.Tensor:
"""Generate 126 valid FP8 E4M3 scale candidates."""
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values = fp8_values[valid_mask]
return fp8_values / 448.0
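
As a sanity check on the candidate set above, a short sketch showing why the sweep has exactly 126 steps: bytes 0..127 viewed as float8_e4m3fn give 128 codes, of which byte 0 encodes zero and byte 127 encodes NaN, and the survivors are normalized by 448.0, the E4M3 maximum:

import torch

codes = torch.arange(0, 128, dtype=torch.uint8).view(torch.float8_e4m3fn).float()
valid = codes[torch.isfinite(codes) & (codes > 0)]  # drops byte 0 (zero) and byte 127 (NaN)
assert valid.numel() == 126
candidates = valid / 448.0  # per-block multipliers applied to global_amax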
7 changes: 7 additions & 0 deletions modelopt/torch/quantization/conversion.py
@@ -35,6 +35,7 @@
_QuantizeExportConfig,
)
from .nn import (
NVFP4StaticQuantizer,
QuantModule,
QuantModuleRegistry,
SequentialQuantizer,
@@ -125,6 +126,12 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata:
for name, module in model.named_modules():
if isinstance(module, TensorQuantizer):
name = get_unwrapped_name(name, model)
state = quantizer_state_dict[name]
# TODO: Add a registry for TensorQuantizers and avoid this manual conversion.
if state.get("_is_nvfp4_static_quantizer") and not isinstance(
module, NVFP4StaticQuantizer
):
NVFP4StaticQuantizer.from_tensor_quantizer(module)
module.set_from_modelopt_state(quantizer_state_dict[name])

for name, module in model.named_modules():
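
The from_tensor_quantizer helper itself is not part of this diff. Purely as an illustration of the in-place promotion the TODO alludes to, and not the actual implementation, a hypothetical stub could look like this (note the instance is mutated rather than replaced, so references held by the parent module stay valid):

import torch.nn as nn

class TensorQuantizerStub(nn.Module):  # hypothetical stand-in
    pass

class NVFP4StaticQuantizerStub(TensorQuantizerStub):  # hypothetical stand-in
    @classmethod
    def from_tensor_quantizer(cls, quantizer, **extra_state):
        quantizer.__class__ = cls  # promote in place; parameters/buffers survive
        for name, value in extra_state.items():  # e.g. global_amax for NVFP4
            setattr(quantizer, name, value)
        return quantizer

q = TensorQuantizerStub()
NVFP4StaticQuantizerStub.from_tensor_quantizer(q, global_amax=None)
assert isinstance(q, NVFP4StaticQuantizerStub)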
74 changes: 30 additions & 44 deletions modelopt/torch/quantization/model_calib.py
@@ -32,9 +32,9 @@
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction

from .calib import MseCalibrator
from .calib import MseCalibrator, NVFP4MSECalibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import QuantModule, SequentialQuantizer, TensorQuantizer
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import (
disable_calib,
enable_fake_quant,
Expand All @@ -44,6 +44,7 @@
is_quantized_linear,
is_quantized_row_parallel_linear,
quantizer_attr_names,
reduce_amax,
weight_attr_names,
)

Expand Down Expand Up @@ -264,7 +265,7 @@ def mse_calibrate(
weight_quantizers = []
seen_modules = set()

for name, module in model.named_modules():
for name, module in list(model.named_modules()):
if isinstance(module, TensorQuantizer) and not module._disabled:
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
# Get the initial amax from max calibration
@@ -276,6 +277,24 @@
and module._block_sizes is not None
and module._block_sizes.get("scale_bits") == (4, 3)
)

if is_nvfp4_static:
# Compute and set global_amax
global_amax = reduce_amax(initial_amax, axis=None)

# Convert to NVFP4StaticQuantizer in-place
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

if fp8_scale_sweep and is_nvfp4_static:
# Replace calibrator with NVFP4MSECalibrator
module._calibrator = NVFP4MSECalibrator(
amax=initial_amax,
axis=module._calibrator._axis,
global_amax=module.global_amax,
quant_func=partial(_mse_quant_func, quantizer=module),
)
continue

if fp8_scale_sweep and not is_nvfp4_static:
warnings.warn(
f"fp8_scale_sweep is enabled but quantizer '{name}' is not NVFP4 static "
@@ -290,7 +309,6 @@
start_multiplier=start_multiplier,
stop_multiplier=stop_multiplier,
quant_func=partial(_mse_quant_func, quantizer=module),
fp8_scale_sweep=fp8_scale_sweep and is_nvfp4_static,
)

# Identify weight quantizers by checking if they have corresponding weight parameters
@@ -309,40 +327,12 @@
# This ensures weights are only calibrated once, not during every forward pass
for parent_module, weight_name, weight_quantizer in weight_quantizers:
# Enable calibration mode for the weight quantizer
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()

enable_stats_collection(parent_module)
with enable_weight_access_and_writeback(parent_module, model):
weight = getattr(parent_module, weight_name)
weight_quantizer(weight)

# Step 4: Disable weight quantizers during forward loop
for _, _, weight_quantizer in weight_quantizers:
weight_quantizer.disable()

# Step 5: Collect data with MSE calibrators for activation quantizers only
enable_stats_collection(model)
if forward_loop is None:
# If no forward loop, nothing else to do since weights are already calibrated
pass
else:
# Run forward loop - only activation quantizers will collect data
forward_loop(model)

# Step 6: Re-enable weight quantizers before finalizing calibration
# This ensures finish_stats_collection processes them correctly
for _, _, weight_quantizer in weight_quantizers:
weight_quantizer.enable()

# Step 7: Compute optimal amax and load it for all quantizers (weights + activations)
finish_stats_collection(model, method="mse")

# Step 8: Free GPU memory by clearing calibrator data
for name, module in model.named_modules():
if isinstance(module, TensorQuantizer) and not module._disabled:
if hasattr(module, "_calibrator") and getattr(module, "_calibrator", None) is not None:
if hasattr(module._calibrator, "clear"):
module._calibrator.clear()
finish_stats_collection(parent_module, method="mse")
weight_quantizer._calibrator.reset()

# TODO: Sync amax across distributed processes
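
For orientation, a hedged usage sketch of the flow above. The fp8_scale_sweep keyword appears in the body; forward_loop follows the pre-change body and is assumed to still be accepted, and the model and calib_dataloader names are hypothetical. Max calibration is assumed to have already populated each quantizer's initial _amax:

def forward_loop(model):
    for batch in calib_dataloader:  # hypothetical calibration dataloader
        model(batch)

# Weight quantizers are MSE-calibrated eagerly, one module at a time; with
# fp8_scale_sweep=True, NVFP4 static quantizers use NVFP4MSECalibrator's
# 126-step FP8 scale sweep instead of the linspace multiplier grid.
mse_calibrate(model, forward_loop=forward_loop, fp8_scale_sweep=True)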

@@ -358,23 +348,19 @@ def enable_stats_collection(model: nn.Module):
module.disable()


def finish_stats_collection(model: nn.Module, method: str | None = None):
def finish_stats_collection(model: nn.Module, method: str | None = None, **kwargs):
"""Finish stats collection for all quantizers in the model."""
for _, module in model.named_modules():
if not isinstance(module, TensorQuantizer) or module._disabled:
continue

cal = getattr(module, "_calibrator", None)
if cal and not getattr(module, "_dynamic", False):
if method in {"mse", "entropy"}:
if method in {"entropy"}:
if cal.compute_amax(method) is not None:
if method == "entropy":
module.load_calib_amax("entropy")
else:
module.load_calib_amax()
elif cal.compute_amax() is not None:
# Max calibrator
module.load_calib_amax()
module.load_calib_amax("entropy", **kwargs)
elif cal.compute_amax(**kwargs) is not None:
module.load_calib_amax(**kwargs)

if module.bias_calibrator is not None and module.bias_type == "static":
module.load_calib_bias()