diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 4203473e2..1f439a7e7 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -24,7 +24,7 @@ from .. import utils as quant_utils from .calibrator import _Calibrator -__all__ = ["MseCalibrator"] +__all__ = ["MseCalibrator", "NVFP4MSECalibrator"] class MseCalibrator(_Calibrator): @@ -39,7 +39,6 @@ def __init__( stop_multiplier: float = 4.0, quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, - fp8_scale_sweep: bool = False, ): """Initialize MSE calibrator. @@ -54,9 +53,6 @@ def __init__( Should have signature: quant_func(x, amax) -> quantized_x. error_func: Function to compute error between x and xq. Default is F.mse_loss(x, xq, reduction='none'). - fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values - instead of using multipliers. This is specifically for NVFP4 - per-block quantization where scales are stored in FP8 format. """ super().__init__(num_bits=None, axis=axis, unsigned=None) self._initial_amax = amax @@ -67,17 +63,21 @@ def __init__( self._quant_func = quant_func self._error_func = error_func - self._losses_sum = [None] * self._num_steps - self._candidate_amaxs = [None] * self._num_steps - self._fp8_scale_sweep = fp8_scale_sweep - if fp8_scale_sweep: - # For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values - # (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN) - self._num_steps = 126 - self._losses_sum = [None] * self._num_steps - self._candidate_amaxs = [None] * self._num_steps - - self._amax = None + self._losses_sum: list[torch.Tensor | None] | None = None + self._candidates: torch.Tensor | None = None + self._amax: torch.Tensor | None = None + + def _generate_candidates(self, device: torch.device) -> torch.Tensor: + """Generate candidate multipliers. Override in subclasses for different candidate sets.""" + return torch.linspace( + self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device + ) + + def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: + """Compute amax from candidates. Override in subclasses for different amax computation.""" + if candidates.ndim != 0: # Called during final compute amax + candidates = candidates.view_as(self._initial_amax) + return self._initial_amax * candidates @torch.no_grad() def collect(self, x: torch.Tensor): @@ -87,39 +87,22 @@ def collect(self, x: torch.Tensor): x: Input tensor. """ if self._quant_func is None: - raise RuntimeError( - "Quantization function not set. Msecalibrator requires a quant_func to be provided." 
- ) + raise RuntimeError("Quantization function not set.") x = x.detach().to(dtype=torch.float32) - device = x.device - if self._fp8_scale_sweep: - global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True) - - # Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn) - # Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32 - uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) - fp8_values = uint8_values.view(torch.float8_e4m3fn).float() - - # Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers - valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) - fp8_values_valid = fp8_values[valid_mask] + candidates = self._generate_candidates(device) + if self._candidates is None: + self._candidates = candidates + self._num_steps = len(candidates) + self._losses_sum = [None] * self._num_steps - candidates = fp8_values_valid / 448.0 - else: - candidates = torch.linspace( - self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device - ) - # Get reduce axis for per-channel quantization + assert self._losses_sum is not None reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis) for step, candidate in enumerate(candidates): - if self._fp8_scale_sweep: - candidate_amax = (global_amax * candidate) * torch.ones_like(self._initial_amax) - else: - candidate_amax = self._initial_amax * candidate + candidate_amax = self._compute_candidate_amax(candidate) xq = self._quant_func(x, candidate_amax) if self._error_func is not None: @@ -129,9 +112,6 @@ def collect(self, x: torch.Tensor): loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False) - if self._candidate_amaxs[step] is None: - self._candidate_amaxs[step] = candidate_amax - if self._losses_sum[step] is None: self._losses_sum[step] = loss.clone() else: @@ -139,18 +119,9 @@ def collect(self, x: torch.Tensor): def reset(self): """Reset the stored losses and amax value.""" - self._losses_sum = [None] * self._num_steps - self._candidate_amaxs = [None] * self._num_steps + self._losses_sum = None + self._candidates = None self._amax = None - - def clear(self): - """Clear all cached data to free GPU memory. - - Call this after compute_amax() and load_calib_amax() are done. - """ - self._losses_sum = [] - self._candidate_amaxs = [] - if self._initial_amax is not None: del self._initial_amax self._initial_amax = None @@ -162,49 +133,28 @@ def compute_amax(self, verbose: bool = False): Args: verbose: If True, print the ratio of best_amax to initial_amax. 
""" - if not any(loss_sum is not None for loss_sum in self._losses_sum): + if self._losses_sum is None or not any(loss is not None for loss in self._losses_sum): return None - # Check if this is per-tensor or per-channel based on the first loss - first_loss_sum = None - for loss_sum in self._losses_sum: - if loss_sum is not None: - first_loss_sum = loss_sum - break - - if first_loss_sum is None: + first_loss = next((loss for loss in self._losses_sum if loss is not None), None) + if first_loss is None: return None - # Collect losses for all steps - losses_per_step = [] + # Stack losses: [num_steps] or [num_steps, num_channels] + losses = [] for step in range(self._num_steps): if self._losses_sum[step] is not None: - losses_per_step.append(self._losses_sum[step]) - # No data for this step, use inf - elif first_loss_sum.ndim == 0: - losses_per_step.append(torch.tensor(float("inf"), device=first_loss_sum.device)) + losses.append(self._losses_sum[step]) + elif first_loss.ndim == 0: + losses.append(torch.tensor(float("inf"), device=first_loss.device)) else: - losses_per_step.append(torch.full_like(first_loss_sum, float("inf"))) - - # Stack to get [num_steps] for per-tensor or [num_steps, num_channels] for per-channel - losses_per_step = torch.stack(losses_per_step) + losses.append(torch.full_like(first_loss, float("inf"))) - # Find best step(s): scalar for per-tensor, [num_channels] for per-channel - best_steps = torch.argmin(losses_per_step, dim=0) - - # Stack candidate amaxs and select based on best_steps - candidate_amaxs = torch.stack(self._candidate_amaxs) - - if first_loss_sum.ndim == 0: - # Per-tensor case: best_steps is a scalar - self._amax = self._candidate_amaxs[best_steps.item()] - else: - # Per-channel case: best_steps is a tensor - num_channels = best_steps.shape[0] - self._amax = candidate_amaxs[ - best_steps, torch.arange(num_channels, device=best_steps.device) - ] - self._amax = self._amax.reshape(self._initial_amax.shape) + losses = torch.stack(losses) + best_indices = torch.argmin(losses, dim=0) + assert self._candidates is not None + best_candidates = self._candidates[best_indices] + self._amax = self._compute_candidate_amax(best_candidates) if verbose: ratio = self._amax / self._initial_amax @@ -219,3 +169,32 @@ def compute_amax(self, verbose: bool = False): ) return self._amax + + +class NVFP4MSECalibrator(MseCalibrator): + """Per-block FP8 scale sweep calibrator for NVFP4 static quantization.""" + + def __init__( + self, + amax: torch.Tensor, # per_block_amax shape [num_blocks] + global_amax: torch.Tensor, # scalar + axis: int | tuple | list | None = None, + quant_func: Callable | None = None, + error_func: Callable | None = None, + ): + """Initialize NVFP4 MSE calibrator with per-block and global amax.""" + super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func) + self._global_amax = global_amax + + def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: + if candidates.ndim != 0: # Called during final compute amax + candidates = candidates.view_as(self._initial_amax) + return torch.ones_like(self._initial_amax) * self._global_amax * candidates + + def _generate_candidates(self, device: torch.device) -> torch.Tensor: + """Generate 126 valid FP8 E4M3 scale candidates.""" + uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) + fp8_values = uint8_values.view(torch.float8_e4m3fn).float() + valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) + fp8_values = fp8_values[valid_mask] + return fp8_values / 
448.0 diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index c93ea546f..f7ef704ee 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -35,6 +35,7 @@ _QuantizeExportConfig, ) from .nn import ( + NVFP4StaticQuantizer, QuantModule, QuantModuleRegistry, SequentialQuantizer, @@ -125,6 +126,12 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: for name, module in model.named_modules(): if isinstance(module, TensorQuantizer): name = get_unwrapped_name(name, model) + state = quantizer_state_dict[name] + # TODO: Add a registry for TensorQuantizers and avoid this manual conversion. + if state.get("_is_nvfp4_static_quantizer") and not isinstance( + module, NVFP4StaticQuantizer + ): + NVFP4StaticQuantizer.from_tensor_quantizer(module) module.set_from_modelopt_state(quantizer_state_dict[name]) for name, module in model.named_modules(): diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 591de3240..7e5414b9f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -32,9 +32,9 @@ from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method from modelopt.torch.utils.perf import get_used_gpu_mem_fraction -from .calib import MseCalibrator +from .calib import MseCalibrator, NVFP4MSECalibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context -from .nn import QuantModule, SequentialQuantizer, TensorQuantizer +from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( disable_calib, enable_fake_quant, @@ -44,6 +44,7 @@ is_quantized_linear, is_quantized_row_parallel_linear, quantizer_attr_names, + reduce_amax, weight_attr_names, ) @@ -264,7 +265,7 @@ def mse_calibrate( weight_quantizers = [] seen_modules = set() - for name, module in model.named_modules(): + for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): # Get the initial amax from max calibration @@ -276,6 +277,24 @@ def mse_calibrate( and module._block_sizes is not None and module._block_sizes.get("scale_bits") == (4, 3) ) + + if is_nvfp4_static: + # Compute and set global_amax + global_amax = reduce_amax(initial_amax, axis=None) + + # Convert to NVFP4StaticQuantizer in-place + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + + if fp8_scale_sweep and is_nvfp4_static: + # Replace calibrator with NVFP4MSECalibrator + module._calibrator = NVFP4MSECalibrator( + amax=initial_amax, + axis=module._calibrator._axis, + global_amax=module.global_amax, + quant_func=partial(_mse_quant_func, quantizer=module), + ) + continue + if fp8_scale_sweep and not is_nvfp4_static: warnings.warn( f"fp8_scale_sweep is enabled but quantizer '{name}' is not NVFP4 static " @@ -290,7 +309,6 @@ def mse_calibrate( start_multiplier=start_multiplier, stop_multiplier=stop_multiplier, quant_func=partial(_mse_quant_func, quantizer=module), - fp8_scale_sweep=fp8_scale_sweep and is_nvfp4_static, ) # Identify weight quantizers by checking if they have corresponding weight parameters @@ -309,40 +327,12 @@ def mse_calibrate( # This ensures weights are only calibrated once, not during every forward pass for parent_module, weight_name, weight_quantizer in 
weight_quantizers: # Enable calibration mode for the weight quantizer - weight_quantizer.disable_quant() - weight_quantizer.enable_calib() - + enable_stats_collection(parent_module) with enable_weight_access_and_writeback(parent_module, model): weight = getattr(parent_module, weight_name) weight_quantizer(weight) - - # Step 4: Disable weight quantizers during forward loop - for _, _, weight_quantizer in weight_quantizers: - weight_quantizer.disable() - - # Step 5: Collect data with MSE calibrators for activation quantizers only - enable_stats_collection(model) - if forward_loop is None: - # If no forward loop, nothing else to do since weights are already calibrated - pass - else: - # Run forward loop - only activation quantizers will collect data - forward_loop(model) - - # Step 6: Re-enable weight quantizers before finalizing calibration - # This ensures finish_stats_collection processes them correctly - for _, _, weight_quantizer in weight_quantizers: - weight_quantizer.enable() - - # Step 7: Compute optimal amax and load it for all quantizers (weights + activations) - finish_stats_collection(model, method="mse") - - # Step 8: Free GPU memory by clearing calibrator data - for name, module in model.named_modules(): - if isinstance(module, TensorQuantizer) and not module._disabled: - if hasattr(module, "_calibrator") and getattr(module, "_calibrator", None) is not None: - if hasattr(module._calibrator, "clear"): - module._calibrator.clear() + finish_stats_collection(parent_module, method="mse") + weight_quantizer._calibrator.reset() # TODO: Sync amax across distributed processes @@ -358,7 +348,7 @@ def enable_stats_collection(model: nn.Module): module.disable() -def finish_stats_collection(model: nn.Module, method: str | None = None): +def finish_stats_collection(model: nn.Module, method: str | None = None, **kwargs): """Finish stats collection for all quantizers in the model.""" for _, module in model.named_modules(): if not isinstance(module, TensorQuantizer) or module._disabled: @@ -366,15 +356,11 @@ def finish_stats_collection(model: nn.Module, method: str | None = None): cal = getattr(module, "_calibrator", None) if cal and not getattr(module, "_dynamic", False): - if method in {"mse", "entropy"}: + if method in {"entropy"}: if cal.compute_amax(method) is not None: - if method == "entropy": - module.load_calib_amax("entropy") - else: - module.load_calib_amax() - elif cal.compute_amax() is not None: - # Max calibrator - module.load_calib_amax() + module.load_calib_amax("entropy", **kwargs) + elif cal.compute_amax(**kwargs) is not None: + module.load_calib_amax(**kwargs) if module.bias_calibrator is not None and module.bias_type == "static": module.load_calib_bias() diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index e004cf0e7..3852d1144 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -64,6 +64,7 @@ from ..functional import normalized_hadamard_transform __all__ = [ + "NVFP4StaticQuantizer", "SequentialQuantizer", "TensorQuantizer", "TensorQuantizerCache", @@ -763,16 +764,6 @@ def _fake_quantize(self, inputs): getattr(self, "_onnx_quantizer_type", None), self._pass_through_bwd, ) - elif self._num_bits == (2, 1) and self.is_static_block_quant: - outputs = static_blockwise_fp4_fake_quant( - inputs, - None, # scale - None, # scale_fp8_quant_amax - False, # skip_scale_quant - inputs.dtype, # out_dtype - 
self._pass_through_bwd, # pass_through_bwd - amax, # amax - ) elif isinstance(self._num_bits, tuple): # Float-point quantization, e.g., FP8 E, M = self._num_bits # noqa: N806 @@ -1249,6 +1240,66 @@ def _set_buffer(self, key, value): self.register_buffer(key, value) +class NVFP4StaticQuantizer(TensorQuantizer): + """TensorQuantizer for NVFP4 static block quantization with two-level scaling. + + Uses _global_amax and inherited _amax for per-block amax values. + """ + + @classmethod + def from_tensor_quantizer( + cls, tq: TensorQuantizer, global_amax: torch.Tensor | None = None + ) -> "NVFP4StaticQuantizer": + """Convert a TensorQuantizer to NVFP4StaticQuantizer in-place. + + Args: + tq: The TensorQuantizer to convert. + global_amax: Optional global amax value to set on the quantizer. + """ + if isinstance(tq, cls): + if global_amax is not None: + tq.global_amax = global_amax + return tq + tq.__class__ = cls + tq._is_nvfp4_static_quantizer = True + if global_amax is not None: + tq.global_amax = global_amax + return tq + + @property + def global_amax(self): + """Return global_amax for quantization.""" + if not hasattr(self, "_global_amax"): + return None + return self._global_amax + + @global_amax.setter + def global_amax(self, value): + if value is None: + if hasattr(self, "_global_amax"): + self._global_amax = None + return + if not isinstance(value, torch.Tensor): + value = torch.tensor(value) + if not hasattr(self, "_global_amax") or self._global_amax is None: + self.register_buffer("_global_amax", value.clone().detach()) + else: + self._global_amax.data.copy_(value.clone().detach().to(self._global_amax.device)) + + def _fake_quantize(self, inputs): + """Fake quantization using two-level scaling with _amax and _global_amax.""" + if self.amax is not None: + return static_blockwise_fp4_fake_quant( + inputs, + self.amax, + self.global_amax, # Can be None, will be computed internally + True, # quantize_block_scales + inputs.dtype, + self._pass_through_bwd, + ) + return super()._fake_quantize(inputs) + + class SequentialQuantizer(nn.Sequential): """A sequential container for :class:`TensorQuantizer` modules. diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 0a95d9916..d9b583971 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -171,6 +171,7 @@ def _dynamic_block_quantize_impl( num_bits == (2, 1) # type: ignore[comparison-overlap] and scale_bits == (4, 3) and triton_kernel.IS_AVAILABLE + and hasattr(triton_kernel, "fp4_fake_quant_block") # requires compute >= 8.9 and not DISABLE_TRITON_KERNEL and amax is not None ): @@ -569,28 +570,31 @@ class StaticBlockwiseFP4FakeQuantFunction(Function): def forward( ctx, x, - scale, - scale_fp8_quant_amax, - skip_scale_quant, - out_dtype, + amax, + global_amax=None, + quantize_block_scales=True, + out_dtype=None, pass_through_bwd=False, - amax=None, ): """Forward method.""" - _save_for_backward_if_needed(ctx, pass_through_bwd, x, scale if scale is not None else amax) + if not triton_kernel.IS_AVAILABLE: + raise RuntimeError( + "static_blockwise_fp4_fake_quant requires triton. " + "Install with `pip install triton`." 
+ ) + _save_for_backward_if_needed(ctx, pass_through_bwd, x, amax) return triton_kernel.static_blockwise_fp4_fake_quant( x, - scale, - scale_fp8_quant_amax, - skip_scale_quant, - out_dtype, amax, + global_amax, + quantize_block_scales, + out_dtype, ) @staticmethod def backward(ctx, grad_outputs): """Implements straight through estimation with clipping.""" - return _fake_quant_backward_function(ctx, grad_outputs, num_args=7) + return _fake_quant_backward_function(ctx, grad_outputs, num_args=6) def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/quantization/triton/__init__.py index c513a4b11..0af34b21f 100644 --- a/modelopt/torch/quantization/triton/__init__.py +++ b/modelopt/torch/quantization/triton/__init__.py @@ -22,8 +22,7 @@ IS_AVAILABLE = False -# triton fp8 requires compute_cap >= 89 -if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9): +if torch.cuda.is_available(): with import_plugin( "triton", msg_if_missing=( @@ -31,6 +30,11 @@ "quantization simulations. Try to install triton with `pip install triton`." ), ): + # fp4_kernel works on any CUDA GPU with triton from .fp4_kernel import * + # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) + if torch.cuda.get_device_capability() >= (8, 9): + from .fp4_kernel_hopper import * + IS_AVAILABLE = True diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index 4735fb5ee..63a8b3dcb 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -24,7 +24,7 @@ import triton import triton.language as tl -__all__ = ["fp4_fake_quant_block", "static_blockwise_fp4_fake_quant"] +__all__ = ["fp4_dequantize", "static_blockwise_fp4_fake_quant"] _TORCH_TO_TL_DTYPE = { @@ -42,172 +42,6 @@ def _torch_dtype_to_tl(dtype: torch.dtype): return _TORCH_TO_TL_DTYPE[dtype] -@triton.jit -def fp4_fake_quant_kernel( - x_ptr, - y_ptr, - M, - N, - global_scale_ptr, - stride_xm, - stride_xn, - stride_ym, - stride_yn, - BLOCK_SIZE: tl.constexpr, - TILE_M: tl.constexpr, - TILE_N: tl.constexpr, - NUM_FP4_BLOCKS: tl.constexpr, - OUT_DTYPE: tl.constexpr, -): - """Applies FP4 fake quantization using block pointers for memory addressing.""" - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - - row_start = pid_m * TILE_M - col_start = pid_n * TILE_N - - x_block_ptr = tl.make_block_ptr( - base=x_ptr, - shape=(M, N), - strides=(stride_xm, stride_xn), - offsets=(row_start, col_start), - block_shape=(TILE_M, TILE_N), - order=(1, 0), - ) - y_block_ptr = tl.make_block_ptr( - base=y_ptr, - shape=(M, N), - strides=(stride_ym, stride_yn), - offsets=(row_start, col_start), - block_shape=(TILE_M, TILE_N), - order=(1, 0), - ) - - global_scale = tl.load(global_scale_ptr).to(tl.float32) - global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12) - - tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) - - tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)) - x_abs = tl.abs(tile_reshaped) - - block_max = tl.max(x_abs, axis=2, keep_dims=True) - - block_max_scaled = block_max / (6.0 * global_scale_safe) - block_max_scaled = tl.minimum(block_max_scaled, 448.0) - block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale - block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0) - - block_max_quant_broadcast = 
tl.broadcast_to( - block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE) - ) - - abs_scaled = x_abs / block_max_quant_broadcast - - q_val = tl.where( - abs_scaled <= 0.25, - 0.0, - tl.where( - abs_scaled < 0.75, - 0.5, - tl.where( - abs_scaled <= 1.25, - 1.0, - tl.where( - abs_scaled < 1.75, - 1.5, - tl.where( - abs_scaled <= 2.5, - 2.0, - tl.where( - abs_scaled < 3.5, - 3.0, - tl.where(abs_scaled <= 5.0, 4.0, 6.0), - ), - ), - ), - ), - ), - ) - - x_rescaled = q_val * block_max_quant_broadcast - x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled) - - tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N)) - - tl.store(y_block_ptr, tile_quant.to(OUT_DTYPE), boundary_check=(0, 1)) - - -def fp4_fake_quant_block( - x: torch.Tensor, - global_amax: torch.Tensor, - block_size: int = 16, - tile_rows: int = 16, - tile_cols: int = 64, - num_warps: int | None = None, - num_stages: int | None = None, -) -> torch.Tensor: - """FP4 fake quantization implementation using block-pointer tiling. - - Args: - x (torch.Tensor): Input tensor of shape ``(M, N)`` or higher. - global_amax (torch.Tensor): Global maximum value tensor for scaling. - block_size (int): Number of elements per FP4 block. - tile_rows (int, optional): Row tile size. Defaults to 64. - tile_cols (int, optional): Column tile size. Defaults to 128. Rounded up to - the nearest multiple of ``block_size`` internally. - num_warps (int | None, optional): Override for Triton warps. Autotuned when ``None``. - num_stages (int | None, optional): Override for pipeline stages. Autotuned when ``None``. - - Returns: - torch.Tensor: Fake-quantized tensor matching the input shape and dtype. - """ - x_shape = x.shape - x_dtype = x.dtype - x = x.reshape(-1, x_shape[-1]).contiguous() - - M, N = x.shape - y = torch.empty_like(x) - - stride_xm, stride_xn = x.stride() - stride_ym, stride_yn = y.stride() - - tile_cols = max(tile_cols, block_size) - tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size - num_fp4_blocks = tile_cols_aligned // block_size - - global_scale = global_amax.float() / (6.0 * 448.0) - - grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned)) - - launch_kwargs = { - "BLOCK_SIZE": block_size, - "TILE_M": tile_rows, - "TILE_N": tile_cols_aligned, - "NUM_FP4_BLOCKS": num_fp4_blocks, - "OUT_DTYPE": _torch_dtype_to_tl(x_dtype), - } - if num_warps is not None: - launch_kwargs["num_warps"] = num_warps - if num_stages is not None: - launch_kwargs["num_stages"] = num_stages - fp4_fake_quant_kernel[grid]( - x, - y, - M, - N, - global_scale, - stride_xm, - stride_xn, - stride_ym, - stride_yn, - **launch_kwargs, - ) - - y = y.view(*x_shape) - return y - - @triton.jit def fp4_dequantize_kernel( packed_ptr, @@ -368,7 +202,13 @@ def static_blockwise_fp4_fake_quant_kernel( x = tl.load(x_ptr + idx).to(tl.float32) x_abs = tl.abs(x) - scale_safe = tl.where(scale >= 1e-5, scale, 1.0) + # If scale is 0, inf, or nan, use 1.0 (matching CUDA kernel behavior) + # Note: (x != x) checks if x is NaN per IEEE 754 + scale_safe = tl.where( + (scale == 0) | (scale != scale) | (tl.abs(scale) == float("inf")), # noqa: PLR0124 + 1.0, + scale, + ) abs_scaled = x_abs / scale_safe # FP4 values: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0 @@ -406,47 +246,38 @@ def static_blockwise_fp4_fake_quant_kernel( def static_blockwise_fp4_fake_quant( x: torch.Tensor, - scale: torch.Tensor | None = None, - scale_fp8_quant_amax: torch.Tensor | None = None, - skip_scale_quant: bool = False, + amax: torch.Tensor, + global_amax: torch.Tensor 
| None = None, + quantize_block_scales: bool = True, out_dtype: torch.dtype | None = None, - amax: torch.Tensor | None = None, ): """Static blockwise FP4 fake quantization using Triton kernel. Args: x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA. - scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. Mutually exclusive with amax. - scale_fp8_quant_amax: Absolute max range for FP8 quantization of scale. If None, computed from scale. - skip_scale_quant: If True, skip FP8 quantization of scale. + amax: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] per-block amax values. + global_amax: FP32 scalar global amax. If provided, used to compute scale_fp8_quant_amax. + quantize_block_scales: If True, quantize block scales to FP8. out_dtype: Output dtype. Defaults to x.dtype if None. - amax: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. If provided, scale = amax / 6.0. - Mutually exclusive with scale. """ - if scale is None and amax is None: - raise ValueError("Either scale or amax must be provided") - if scale is not None and amax is not None: - raise ValueError("Cannot provide both scale and amax") - - if amax is not None: - scale = amax / 6.0 # FP4 max representable value is 6.0 - - assert scale is not None # Guaranteed by validation above assert x.ndim == 2 NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape if out_dtype is None: out_dtype = x.dtype - if not skip_scale_quant: + amax = amax.float() # Requires to be in float32 + scale = amax / 6.0 # FP4 max representable value is 6.0 + + if quantize_block_scales: from modelopt.torch.quantization.tensor_quant import scaled_e4m3_impl from modelopt.torch.quantization.utils import reduce_amax - if scale_fp8_quant_amax is None: - scale_fp8_quant_amax = reduce_amax( - scale, axis=None, keepdims=False, squeeze_scalar=True - ) + if global_amax is None: + global_amax = reduce_amax(amax, axis=None, keepdims=False, squeeze_scalar=True) + global_amax = global_amax.float() + scale_fp8_quant_amax = global_amax / 6.0 scale = scaled_e4m3_impl(scale, scale_fp8_quant_amax) x_flat = x.contiguous().view(-1) @@ -457,7 +288,6 @@ def static_blockwise_fp4_fake_quant( grid = (NUM_FP4_BLOCKS,) - # Ensure we're running on the correct CUDA device with torch.cuda.device(x.device): static_blockwise_fp4_fake_quant_kernel[grid]( x_flat, diff --git a/modelopt/torch/quantization/triton/fp4_kernel_hopper.py b/modelopt/torch/quantization/triton/fp4_kernel_hopper.py new file mode 100644 index 000000000..2ec31863e --- /dev/null +++ b/modelopt/torch/quantization/triton/fp4_kernel_hopper.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NVFP4 Fake Quantization Triton kernels requiring compute capability >= 8.9 (Hopper+). + +These kernels use tl.float8e4nv which requires native FP8 hardware support. 
+""" + +import torch +import triton +import triton.language as tl + +from .fp4_kernel import _torch_dtype_to_tl + +__all__ = ["fp4_fake_quant_block"] + + +@triton.jit +def fp4_fake_quant_kernel( + x_ptr, + y_ptr, + M, + N, + global_scale_ptr, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + BLOCK_SIZE: tl.constexpr, + TILE_M: tl.constexpr, + TILE_N: tl.constexpr, + NUM_FP4_BLOCKS: tl.constexpr, + OUT_DTYPE: tl.constexpr, +): + """Applies FP4 fake quantization using block pointers for memory addressing.""" + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + row_start = pid_m * TILE_M + col_start = pid_n * TILE_N + + x_block_ptr = tl.make_block_ptr( + base=x_ptr, + shape=(M, N), + strides=(stride_xm, stride_xn), + offsets=(row_start, col_start), + block_shape=(TILE_M, TILE_N), + order=(1, 0), + ) + y_block_ptr = tl.make_block_ptr( + base=y_ptr, + shape=(M, N), + strides=(stride_ym, stride_yn), + offsets=(row_start, col_start), + block_shape=(TILE_M, TILE_N), + order=(1, 0), + ) + + global_scale = tl.load(global_scale_ptr).to(tl.float32) + global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12) + + tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + + tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)) + x_abs = tl.abs(tile_reshaped) + + block_max = tl.max(x_abs, axis=2, keep_dims=True) + + block_max_scaled = block_max / (6.0 * global_scale_safe) + block_max_scaled = tl.minimum(block_max_scaled, 448.0) + block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale + block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0) + + block_max_quant_broadcast = tl.broadcast_to( + block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE) + ) + + abs_scaled = x_abs / block_max_quant_broadcast + + q_val = tl.where( + abs_scaled <= 0.25, + 0.0, + tl.where( + abs_scaled < 0.75, + 0.5, + tl.where( + abs_scaled <= 1.25, + 1.0, + tl.where( + abs_scaled < 1.75, + 1.5, + tl.where( + abs_scaled <= 2.5, + 2.0, + tl.where( + abs_scaled < 3.5, + 3.0, + tl.where(abs_scaled <= 5.0, 4.0, 6.0), + ), + ), + ), + ), + ), + ) + + x_rescaled = q_val * block_max_quant_broadcast + x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled) + + tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N)) + + tl.store(y_block_ptr, tile_quant.to(OUT_DTYPE), boundary_check=(0, 1)) + + +def fp4_fake_quant_block( + x: torch.Tensor, + global_amax: torch.Tensor, + block_size: int = 16, + tile_rows: int = 16, + tile_cols: int = 64, + num_warps: int | None = None, + num_stages: int | None = None, +) -> torch.Tensor: + """FP4 fake quantization implementation using block-pointer tiling. + + Args: + x (torch.Tensor): Input tensor of shape ``(M, N)`` or higher. + global_amax (torch.Tensor): Global maximum value tensor for scaling. + block_size (int): Number of elements per FP4 block. + tile_rows (int, optional): Row tile size. Defaults to 16. + tile_cols (int, optional): Column tile size. Defaults to 64. Rounded up to + the nearest multiple of ``block_size`` internally. + num_warps (int | None, optional): Override for Triton warps. Autotuned when ``None``. + num_stages (int | None, optional): Override for pipeline stages. Autotuned when ``None``. + + Returns: + torch.Tensor: Fake-quantized tensor matching the input shape and dtype. 
+ """ + x_shape = x.shape + x_dtype = x.dtype + x = x.reshape(-1, x_shape[-1]).contiguous() + + M, N = x.shape + y = torch.empty_like(x) + + stride_xm, stride_xn = x.stride() + stride_ym, stride_yn = y.stride() + + tile_cols = max(tile_cols, block_size) + tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size + num_fp4_blocks = tile_cols_aligned // block_size + + global_scale = (global_amax.float() / (6.0 * 448.0)).to(x.device) + + grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned)) + + launch_kwargs = { + "BLOCK_SIZE": block_size, + "TILE_M": tile_rows, + "TILE_N": tile_cols_aligned, + "NUM_FP4_BLOCKS": num_fp4_blocks, + "OUT_DTYPE": _torch_dtype_to_tl(x_dtype), + } + if num_warps is not None: + launch_kwargs["num_warps"] = num_warps + if num_stages is not None: + launch_kwargs["num_stages"] = num_stages + with torch.cuda.device(x.device): + fp4_fake_quant_kernel[grid]( + x, + y, + M, + N, + global_scale, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + **launch_kwargs, + ) + + y = y.view(*x_shape) + return y diff --git a/tests/_test_utils/torch/quantization/quantize_common.py b/tests/_test_utils/torch/quantization/quantize_common.py index f62d2d991..2b2e43dcf 100644 --- a/tests/_test_utils/torch/quantization/quantize_common.py +++ b/tests/_test_utils/torch/quantization/quantize_common.py @@ -75,7 +75,14 @@ def forward_loop(model, run_backward=False): forward_loop(model, run_backward=True) -def save_restore_test(model_cls, device, quant_config, compress=False, version=None): +def save_restore_test( + model_cls, + device, + quant_config, + compress=False, + version=None, + test_cpu_restore: bool = False, +): # test restoring to an unquantized model model_quant = model_cls().to(device) model_ref = model_cls().to(device) @@ -89,11 +96,20 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N model_ref.load_state_dict(model_quant.state_dict()) assert torch.allclose(model_quant(calib_data[0]), model_ref(calib_data[0])) + # Verify that TensorQuantizer subclass types are preserved after restore + for name_q, mod_q in model_quant.named_modules(): + if name_q.endswith("quantizer"): + mod_r = dict(model_ref.named_modules())[name_q] + assert type(mod_q) is type(mod_r), ( + f"Quantizer class mismatch for '{name_q}': " + f"expected {type(mod_q).__name__}, got {type(mod_r).__name__}" + ) + if version is not None and Version(version) < Version("0.29"): # Rest of the tests are not needed for version < 0.29 return - if not compress: + if test_cpu_restore: # gpu: test restoring to a model on cpu. If the quantizer states are not initialized correctly, # the buffers will be created on cuda and this test will fail model_ref = model_cls().to("cpu") diff --git a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py new file mode 100644 index 000000000..b1b3691a7 --- /dev/null +++ b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for NVFP4StaticQuantizer and NVFP4MSECalibrator.""" + +import pytest +import torch + +from modelopt.torch.quantization.calib import NVFP4MSECalibrator +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer +from modelopt.torch.quantization.tensor_quant import ( + scaled_e4m3_impl, + static_blockwise_fp4_fake_quant, +) + + +@pytest.mark.parametrize("device", ["cuda"]) +class TestNVFP4StaticQuantizer: + def test_from_tensor_quantizer(self, device): + """Test creating NVFP4StaticQuantizer from TensorQuantizer.""" + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "static", "scale_bits": (4, 3)}, + ) + tq = TensorQuantizer(quant_attribute_cfg=cfg).to(device) + tq.amax = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device) + + nvfp4_quantizer = NVFP4StaticQuantizer.from_tensor_quantizer(tq) + + assert nvfp4_quantizer.global_amax is None + assert nvfp4_quantizer._num_bits == (2, 1) + assert torch.allclose(nvfp4_quantizer._amax, tq._amax) + + def test_global_amax_property(self, device): + """Test global_amax property getter/setter.""" + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "static", "scale_bits": (4, 3)}, + ) + quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + + assert quantizer.global_amax is None + + quantizer.global_amax = torch.tensor(5.0, device=device) + assert quantizer.global_amax is not None + assert torch.isclose(quantizer.global_amax, torch.tensor(5.0, device=device)) + + quantizer.global_amax = 10.0 + assert torch.isclose(quantizer.global_amax, torch.tensor(10.0, device=device)) + + quantizer.global_amax = None + assert quantizer.global_amax is None + + def test_fake_quantize_with_both_amaxs(self, device): + """Test _fake_quantize uses both _amax and _global_amax.""" + num_blocks = 4 + block_size = 16 + + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: block_size, "type": "static", "scale_bits": (4, 3)}, + ) + quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + + x = torch.randn(num_blocks, block_size, device=device) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + quantizer.amax = per_block_amax + quantizer.global_amax = global_amax + + output = quantizer._fake_quantize(x) + + expected = static_blockwise_fp4_fake_quant( + x, + per_block_amax, + global_amax, + ) + + assert torch.allclose(output, expected) + + +@pytest.mark.parametrize("device", ["cuda"]) +class TestNVFP4MSECalibrator: + def test_basic_initialization(self, device): + """Test NVFP4MSECalibrator initialization.""" + num_blocks = 4 + amax = torch.ones(num_blocks, device=device) + global_amax = torch.tensor(10.0, device=device) + cal = NVFP4MSECalibrator(amax=amax, global_amax=global_amax) + + assert cal._losses_sum is None + assert cal._amax is None + + def test_fp8_candidates_generation(self, device): + """Test that 126 valid FP8 candidates are generated.""" + num_blocks = 4 + amax = torch.ones(num_blocks, device=device) + global_amax = 
torch.tensor(10.0, device=device) + cal = NVFP4MSECalibrator(amax=amax, global_amax=global_amax) + + candidates = cal._generate_candidates(device) + + assert candidates.shape[0] == 126 + assert torch.all(torch.isfinite(candidates)) + assert torch.all(candidates > 0) + + def test_collect_and_compute_amax(self, device): + """Test collect and compute_amax workflow.""" + num_blocks = 8 + block_size = 16 + per_block_amax = torch.ones(num_blocks, device=device) + global_amax = torch.tensor(10.0, device=device) + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + cal = NVFP4MSECalibrator( + amax=per_block_amax, + global_amax=global_amax, + quant_func=quant_func, + ) + + x = torch.randn(num_blocks, block_size, device=device) + cal.collect(x) + + assert cal._losses_sum is not None + assert len(cal._losses_sum) == 126 + + amax = cal.compute_amax() + + assert amax is not None + assert amax.shape[0] == num_blocks + assert torch.all(torch.isfinite(amax)) + assert torch.all(amax > 0) + + def test_multiple_collections(self, device): + """Test that multiple collections accumulate correctly.""" + num_blocks = 4 + block_size = 16 + per_block_amax = torch.ones(num_blocks, device=device) + global_amax = torch.tensor(5.0, device=device) + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + cal = NVFP4MSECalibrator( + amax=per_block_amax, + global_amax=global_amax, + quant_func=quant_func, + ) + + x1 = torch.randn(num_blocks, block_size, device=device) + x2 = torch.randn(num_blocks, block_size, device=device) + + cal.collect(x1) + losses_after_first = [loss.clone() for loss in cal._losses_sum] + + cal.collect(x2) + losses_after_second = cal._losses_sum + + for loss1, loss2 in zip(losses_after_first, losses_after_second): + assert torch.all(loss2 >= loss1 - 1e-6) + + def test_per_block_independent_optimization(self, device): + """Test that each block is optimized independently. + + Uses constant values per block to ensure deterministic behavior. 
+ """ + num_blocks = 4 + block_size = 16 + + # Create blocks with constant values (all elements in a block are the same) + # This ensures deterministic behavior for the test + x = torch.zeros(num_blocks, block_size, device=device) + x[0, :] = 0.5 + x[1, :] = 2.0 + x[2, :] = 5.0 + x[3, :] = 10.0 + + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + cal = NVFP4MSECalibrator( + amax=per_block_amax, + axis=0, # reduce_axis = -1 + global_amax=global_amax, + quant_func=quant_func, + ) + + cal.collect(x) + amax = cal.compute_amax() + + # With constant values per block, the optimal amax should scale with the block values + assert amax[1] > amax[0] + assert amax[2] > amax[1] + assert amax[3] > amax[2] + + def test_fp8_sweep_generates_quantized_scales(self, device): + """Test that the fp8 sweep produces scales that are already FP8-quantized.""" + num_blocks = 8 + block_size = 16 + + x = torch.randn(num_blocks, block_size, device=device) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + cal = NVFP4MSECalibrator( + amax=per_block_amax, + global_amax=global_amax, + quant_func=quant_func, + ) + + cal.collect(x) + amax = cal.compute_amax() + + # The calibrator sweeps over FP8 candidates, so the resulting scales + # should already be representable in FP8 (i.e., quantize-dequantize is a no-op). + scale = amax.float() / 6.0 + scale_fp8_quant_amax = global_amax.float() / 6.0 + scale_qdq = scaled_e4m3_impl(scale, scale_fp8_quant_amax) + assert torch.allclose(scale_qdq, scale) diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py index 3d1de84d3..b5aca034a 100644 --- a/tests/gpu/torch/quantization/test_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_quantize_cuda.py @@ -133,7 +133,10 @@ def test_quantize(model_cls, config): (SimpleLinear, mtq.INT8_SMOOTHQUANT_CFG), (SimpleLinear, mtq.W4A8_AWQ_BETA_CFG), (SimpleConvLinear, mtq.INT8_DEFAULT_CFG), + (SimpleLinear, NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG), + (SimpleLinear, NVFP4_WEIGHT_ACT_MSE_CFG), ], ) def test_save_restore(model_cls, quant_config): - save_restore_test(model_cls, "cuda", quant_config) + test_cpu_restore = quant_config == mtq.INT8_SMOOTHQUANT_CFG + save_restore_test(model_cls, "cuda", quant_config, test_cpu_restore=test_cpu_restore) diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py index 509e1b4e7..e84b1a49a 100644 --- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py +++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py @@ -172,7 +172,7 @@ def _test_fp4_kernel(test_in, test_out, skip_triton=False): inputs.abs().amax(), ) assert torch.allclose(quantized_outputs, expected_outputs) - if triton_kernel.IS_AVAILABLE and not skip_triton: + if hasattr(triton_kernel, "fp4_fake_quant_block") and not skip_triton: quantized_outputs_triton = triton_kernel.fp4_fake_quant_block( inputs, inputs.abs().amax() ) @@ -220,13 +220,13 @@ def _get_test_inputs_outputs(test_in, test_out, num_blocks=4): num_blocks, 1 ), torch.concat((test_out,) * (block_size // 8), dim=-1).repeat(num_blocks, 1) - def _test_static_fp4_kernel(test_in, test_out, scale_value=1.0): + def _test_static_fp4_kernel(test_in, test_out, amax_value=6.0): inputs, expected_outputs = 
_get_test_inputs_outputs(test_in, test_out) num_blocks = inputs.shape[0] - scales = torch.full((num_blocks,), scale_value, device=inputs.device) + amax = torch.full((num_blocks,), amax_value, device=inputs.device) quantized_outputs_triton = triton_kernel.static_blockwise_fp4_fake_quant( - inputs, scale=scales, skip_scale_quant=skip_scale_quant + inputs, amax=amax, quantize_block_scales=not skip_scale_quant ) # Only check exact values when skip_scale_quant=True @@ -257,49 +257,33 @@ def _test_static_fp4_kernel(test_in, test_out, scale_value=1.0): test_out = torch.tensor([[0.5, 1, 1.5, 2, 3, 4, 6, 6]]).cuda() * sign _test_static_fp4_kernel(test_in, test_out) - @pytest.mark.skipif(not triton_kernel.IS_AVAILABLE, reason="triton kernel is not available") + @pytest.mark.skipif( + not hasattr(triton_kernel, "fp4_fake_quant_block"), + reason="fp4_fake_quant_block requires compute >= 8.9", + ) @pytest.mark.parametrize( "set_torch_dtype", [torch.float, torch.float16, torch.bfloat16], indirect=True ) @pytest.mark.parametrize("block_size", [16, 32, 64]) @pytest.mark.parametrize("num_blocks", [4, 8, 16]) - @pytest.mark.parametrize("use_explicit_amax", [False, True]) - @pytest.mark.parametrize("use_amax_param", [False, True]) - def test_static_vs_dynamic_fp4_kernels( - self, set_torch_dtype, block_size, num_blocks, use_explicit_amax, use_amax_param - ): + def test_static_vs_dynamic_fp4_kernels(self, set_torch_dtype, block_size, num_blocks): """Test that static kernel with computed scales matches dynamic kernel behavior. The dynamic kernel computes scales dynamically from block-wise max values with FP8 quantization. - This test verifies that the static kernel with pre-computed scales (matching dynamic kernel's logic) + This test verifies that the static kernel with pre-computed amax (matching dynamic kernel's logic) produces the same results as the dynamic kernel. - - Args: - use_amax_param: If True, use the amax parameter instead of scale parameter. 
""" torch.manual_seed(42) x = torch.randn(num_blocks, block_size, dtype=torch.float32).cuda() * 10 block_amax = x.abs().max(dim=1, keepdim=False)[0] global_amax = block_amax.max() - scales = block_amax / 6.0 - - if use_explicit_amax: - scale_fp8_quant_amax = global_amax / 6.0 - else: - scale_fp8_quant_amax = None - - if use_amax_param: - output_static = triton_kernel.static_blockwise_fp4_fake_quant( - x, - amax=block_amax, - scale_fp8_quant_amax=scale_fp8_quant_amax, - skip_scale_quant=False, - ) - else: - output_static = triton_kernel.static_blockwise_fp4_fake_quant( - x, scale=scales, scale_fp8_quant_amax=scale_fp8_quant_amax, skip_scale_quant=False - ) + output_static = triton_kernel.static_blockwise_fp4_fake_quant( + x, + amax=block_amax, + global_amax=global_amax, + quantize_block_scales=True, + ) output_dynamic = triton_kernel.fp4_fake_quant_block( x, global_amax=global_amax, @@ -308,11 +292,9 @@ def test_static_vs_dynamic_fp4_kernels( tile_cols=block_size, ) - amax_mode = "explicit" if use_explicit_amax else "automatic" - param_mode = "amax" if use_amax_param else "scale" assert torch.allclose(output_static, output_dynamic, rtol=1e-3, atol=1e-5), ( f"Static and dynamic kernels produced different outputs " - f"(scale_fp8_quant_amax={amax_mode}, param={param_mode}).\n" + f"(param=amax).\n" f"Max abs diff: {(output_static - output_dynamic).abs().max()}\n" f"Mean abs diff: {(output_static - output_dynamic).abs().mean()}\n" f"Max relative diff: {((output_static - output_dynamic).abs() / (output_dynamic.abs() + 1e-8)).max()}" diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index efccec4c4..5e5546512 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -526,70 +526,3 @@ def quant_func(x, amax): assert a_best.numel() == 2 assert torch.all(torch.isfinite(a_best)) assert torch.all(a_best > 0) - - def test_fp8_scale_sweep_with_fixed_values_and_reset(self): - """Test FP8 scale sweep with fixed hand-written values and reset functionality.""" - x = torch.full((100,), 2.0, dtype=torch.float32) - x[0] = 20.0 - - initial_amax = torch.tensor(20.0) - - quant_cfg = QuantizerAttributeConfig(num_bits=(4, 3), axis=None, unsigned=False) - tq = TensorQuantizer(quant_attribute_cfg=quant_cfg, amax=initial_amax) - - def quant_func(x, amax): - original_amax = tq._amax.clone() if hasattr(tq, "_amax") else None - was_quant_enabled = tq._if_quant - was_calib_enabled = tq._if_calib - - tq._amax = amax - tq._if_quant = True - tq._if_calib = False - - with enable_fake_quant(tq): - xq = tq(x) - - if original_amax is not None: - tq._amax = original_amax - tq._if_quant = was_quant_enabled - tq._if_calib = was_calib_enabled - return xq - - cal = calib.MseCalibrator( - amax=initial_amax, - quant_func=quant_func, - fp8_scale_sweep=True, - ) - - assert cal._num_steps == 126 - - cal.collect(x) - - a_best = cal.compute_amax() - - assert torch.isfinite(a_best), "Optimal amax should be finite" - assert a_best > 0, "Optimal amax should be positive" - assert a_best <= initial_amax, "Optimal amax should not exceed initial amax" - - # FP8 scale sweep uses global_amax * fp8_multiplier where fp8_multiplier - # ranges from ~4.36e-06 to 1.0. 
For mostly 2.0 values with one 20.0 outlier, - # the optimal amax should be somewhere between these extremes - assert a_best >= initial_amax * 1e-6, "Optimal amax should not be unreasonably small" - - a_best_value = a_best.item() - - cal.reset() - - a_after_reset = cal.compute_amax() - assert a_after_reset is None, "After reset, compute_amax should return None" - - assert cal._num_steps == 126, "After reset, num_steps should still be 126" - - cal.collect(x) - a_best_after_reset = cal.compute_amax() - - assert torch.isfinite(a_best_after_reset), "Should be able to compute amax after reset" - assert a_best_after_reset > 0, "Amax after reset should be positive" - assert abs(a_best_after_reset.item() - a_best_value) < 1e-6, ( - "Amax after reset should match original value with same data" - )
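
Reviewer note, not part of the patch: the sketch below is a minimal illustration of how the new pieces are intended to fit together. It mirrors the usage in the added tests (tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py), assumes a CUDA device with triton installed, and only uses APIs that appear in this diff.

import torch

from modelopt.torch.quantization.calib import NVFP4MSECalibrator
from modelopt.torch.quantization.config import QuantizerAttributeConfig
from modelopt.torch.quantization.nn import NVFP4StaticQuantizer
from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant

num_blocks, block_size = 8, 16
x = torch.randn(num_blocks, block_size, device="cuda")

# Max calibration would normally supply these two statistics.
per_block_amax = x.abs().amax(dim=-1)
global_amax = per_block_amax.max()


def quant_func(x, amax):
    # Fake-quantize with a candidate per-block amax; global_amax fixes the FP8 scale range.
    return static_blockwise_fp4_fake_quant(x, amax, global_amax)


# Sweep the 126 positive, finite FP8 E4M3 byte values (divided by 448) as scale candidates
# and keep, per block, the candidate amax with the lowest MSE.
cal = NVFP4MSECalibrator(
    amax=per_block_amax,
    axis=0,  # quantization axis 0: losses are reduced within each block, so blocks are optimized independently
    global_amax=global_amax,
    quant_func=quant_func,
)
cal.collect(x)
best_amax = cal.compute_amax()  # per-block amax minimizing MSE, shape [num_blocks]

# Apply the result with the new two-level static quantizer.
cfg = QuantizerAttributeConfig(
    num_bits=(2, 1),
    block_sizes={-1: block_size, "type": "static", "scale_bits": (4, 3)},
)
quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to("cuda")
quantizer.amax = best_amax
quantizer.global_amax = global_amax
xq = quantizer._fake_quantize(x)  # per-block amax plus FP8-quantized block scales

Because the sweep only visits FP8-representable scales, best_amax / 6.0 round-trips through FP8 quantize-dequantize unchanged, which is what the new test_fp8_sweep_generates_quantized_scales asserts.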