Track global_amax for weight FP4 MSE sweep; Refactor to NVFP4StaticQuantizer, NVFP4MSECalibrator #849
base: main
Conversation
Important: Review skipped. Auto incremental reviews are disabled on this repository; please check the settings in the CodeRabbit UI.
📝 Walkthrough

This PR introduces NVFP4 (NVIDIA FP4) static quantization support with MSE calibration. It refactors MseCalibrator into a generalized candidate-based framework, adds NVFP4StaticQuantizer for two-level scaling quantization, updates the Triton FP4 kernel implementations with dequantization and Hopper-optimized paths, and integrates these components into the model calibration pipeline.
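As a rough illustration of what such a candidate-based MSE sweep looks like, here is a standalone sketch (not the ModelOpt code: the `fake_quant_fn` callable, the 0.5x–1.0x sweep range, and the candidate count are all assumptions):

```python
import torch
from typing import Callable

def mse_amax_sweep(
    x: torch.Tensor,
    amax: torch.Tensor,
    fake_quant_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    num_candidates: int = 11,
) -> torch.Tensor:
    """Return the amax candidate whose fake-quantized output has the lowest MSE."""
    best_amax, best_loss = amax, float("inf")
    for step in range(num_candidates):
        # Sweep scaled-down versions of the observed amax (0.5x .. 1.0x here).
        candidate = amax * (0.5 + 0.5 * step / max(num_candidates - 1, 1))
        loss = (fake_quant_fn(x, candidate) - x).float().pow(2).mean().item()
        if loss < best_loss:
            best_amax, best_loss = candidate, loss
    return best_amax
```

A per-layer calibrator would accumulate these losses across batches before picking the winner, which is what the compute_amax() step in the diagram below corresponds to.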
Sequence Diagram

```mermaid
sequenceDiagram
    participant Model
    participant Calibrator as NVFP4MSECalibrator
    participant Quantizer as NVFP4StaticQuantizer
    participant TritonKernel as Triton FP4

    Model->>Calibrator: collect(activations)
    activate Calibrator
    Calibrator->>Calibrator: _generate_candidates()
    Note over Calibrator: Generate FP8-based candidates
    loop For each candidate
        Calibrator->>TritonKernel: static_blockwise_fp4_fake_quant(x, amax, global_amax)
        activate TritonKernel
        TritonKernel->>TritonKernel: Two-level scaling<br/>(per-block + global)
        TritonKernel-->>Calibrator: quantized output
        deactivate TritonKernel
        Calibrator->>Calibrator: Accumulate loss
    end
    deactivate Calibrator
    Calibrator->>Calibrator: compute_amax()
    Note over Calibrator: Select best candidate<br/>based on minimal loss
    Calibrator-->>Quantizer: Update amax & global_amax
    Model->>Quantizer: forward(x)
    activate Quantizer
    Quantizer->>TritonKernel: _fake_quantize(inputs)
    activate TritonKernel
    TritonKernel->>TritonKernel: Block-wise FP4 quantization<br/>scaled by global_amax
    TritonKernel-->>Quantizer: Quantized tensor
    deactivate TritonKernel
    Quantizer-->>Model: Output
    deactivate Quantizer
```
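For reference, the two-level scaling step in the diagram can be approximated in plain PyTorch as below. This is a hedged sketch, not the Triton kernel from this PR: the function name, the block_size=16 default, the assumption that the last dimension is a multiple of block_size, and the reliance on torch.float8_e4m3fn (available in recent PyTorch builds) are all illustrative; the 6.0 (FP4 max) and 448.0 (E4M3 max) constants follow the global_amax / (6.0 * 448.0) scaling used elsewhere in this review.

```python
import torch

# Representable magnitudes of the FP4 (E2M1) format.
FP4_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def nvfp4_fake_quant_ref(x: torch.Tensor, global_amax: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Assumes the last dimension of x is divisible by block_size and global_amax is a scalar tensor.
    orig_shape, orig_dtype = x.shape, x.dtype
    xf = x.float().reshape(-1, block_size)
    # Level 1: global scale maps per-block scales into the FP8 E4M3 range (max 448).
    global_scale = global_amax.float() / (6.0 * 448.0)
    # Level 2: per-block scale so each block's amax maps to the FP4 max (6.0).
    block_scale = xf.abs().amax(dim=-1, keepdim=True) / 6.0
    # Quantize the block scales to E4M3 relative to the global scale, then decode.
    block_scale_q = (block_scale / global_scale).to(torch.float8_e4m3fn).float() * global_scale
    # Snap each element to the nearest FP4 magnitude, then dequantize.
    scaled = torch.where(block_scale_q > 0, xf / block_scale_q, torch.zeros_like(xf))
    grid = FP4_GRID.to(scaled.device)
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    y = grid[idx] * torch.sign(scaled) * block_scale_q
    return y.reshape(orig_shape).to(orig_dtype)
```

Calling nvfp4_fake_quant_ref(w, w.abs().max()) on a weight tensor gives a rough reference output against which the Triton kernels can be sanity-checked.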
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/triton/fp4_kernel.py (1)
69-85:⚠️ Potential issue | 🔴 CriticalGuard the final tile to prevent out-of-bounds access in
fp4_dequantize_kernel.
packed_maskonly checkspacked_col_idx < (N // 2), which is always true sincepacked_col_idx = packed_offs % (N // 2). When TILE_SIZE doesn't evenly dividepacked_tensor.numel(), the final kernel instance can read/write past the end of the tensor. Add aTOTAL_PACKED_ELEMSparameter and change the mask topacked_offs < TOTAL_PACKED_ELEMS.🛠️ Proposed fix
def fp4_dequantize_kernel( packed_ptr, scale_ptr, global_scale_ptr, output_ptr, N, + TOTAL_PACKED_ELEMS, BLOCK_SIZE: tl.constexpr, TILE_SIZE: tl.constexpr, ): @@ - packed_mask = packed_col_idx < (N // 2) + packed_mask = packed_offs < TOTAL_PACKED_ELEMS @@ - fp4_dequantize_kernel[grid]( + fp4_dequantize_kernel[grid]( packed_tensor, scale_tensor, global_scale, output, N, + TOTAL_PACKED_ELEMS=packed_tensor.numel(), BLOCK_SIZE=block_size, TILE_SIZE=tile_size, )
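A minimal standalone Triton kernel showing the masking pattern the fix recommends (illustrative names only, not the dequantize kernel itself; assumes a contiguous CUDA tensor and an installed Triton):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_packed_kernel(src_ptr, dst_ptr, TOTAL_PACKED_ELEMS, TILE_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)
    # Guard the final tile: offsets past the end of the flat buffer are masked out,
    # so a partial tile can neither read nor write out of bounds.
    mask = offs < TOTAL_PACKED_ELEMS
    vals = tl.load(src_ptr + offs, mask=mask, other=0)
    tl.store(dst_ptr + offs, vals, mask=mask)

def copy_packed(src: torch.Tensor, tile_size: int = 1024) -> torch.Tensor:
    dst = torch.empty_like(src)
    grid = (triton.cdiv(src.numel(), tile_size),)
    copy_packed_kernel[grid](src, dst, src.numel(), TILE_SIZE=tile_size)
    return dst
```

The same element-count mask carries over directly to the dequantize kernel's packed loads and stores.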
🤖 Fix all issues with AI agents
In `@modelopt/torch/quantization/calib/mse.py`:
- Around line 161-172: Update NVFP4MSECalibrator.__init__ to accept a
block_size:int parameter and assign it to self._block_size so tests passing
block_size=16 succeed; specifically, add block_size to the __init__ signature of
NVFP4MSECalibrator and set self._block_size = block_size (leave the existing
call to super() and other parameters unchanged).
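A minimal sketch of that change, with the base class stubbed so the snippet stands alone (the real MseCalibrator constructor arguments are elided):

```python
class MseCalibrator:  # stand-in for the real base class in calib/mse.py
    def __init__(self, *args, **kwargs):
        pass

class NVFP4MSECalibrator(MseCalibrator):
    def __init__(self, *args, block_size: int = 16, **kwargs):
        # Forward everything else unchanged; only block_size handling is new.
        super().__init__(*args, **kwargs)
        # Store the block size so tests constructing the calibrator with
        # block_size=16 succeed.
        self._block_size = block_size
```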
In `@modelopt/torch/quantization/nn/modules/tensor_quantizer.py`:
- Around line 1278-1288: The call in NVFP4StaticQuantizer._fake_quantize passes
a removed parameter (_pass_through_bwd) to static_blockwise_fp4_fake_quant
causing a TypeError; update the invocation in _fake_quantize to call
static_blockwise_fp4_fake_quant with only the supported arguments (inputs,
self.amax, self.global_amax, True, inputs.dtype) and remove the trailing
_pass_through_bwd argument, ensuring argument order or keywords match the
current static_blockwise_fp4_fake_quant signature.
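Illustrative shape of the corrected call, following the argument order given above; the class here is a bare stand-in for NVFP4StaticQuantizer and assumes amax/global_amax are set elsewhere:

```python
from modelopt.torch.quantization.triton.fp4_kernel import static_blockwise_fp4_fake_quant

class _StaticQuantizerSketch:
    def _fake_quantize(self, inputs):
        # No trailing _pass_through_bwd argument; the kernel no longer accepts it.
        return static_blockwise_fp4_fake_quant(
            inputs, self.amax, self.global_amax, True, inputs.dtype
        )
```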
In `@modelopt/torch/quantization/triton/fp4_kernel_hopper.py`:
- Around line 141-144: The docstring for the FP4 kernel (the
block_size/tile_rows/tile_cols parameter description) is out of sync with the
function signature: update the documented defaults to match the signature's
tile_rows=16 and tile_cols=64 (instead of 64/128) so the docstring for the
function/class that documents block_size, tile_rows, and tile_cols reflects the
actual defaults used in the code.
- Around line 151-191: The kernel launch uses global_amax on the wrong device
and doesn't ensure the correct CUDA context; move and validate global_amax to
x.device before using it (e.g., ensure global_amax is a scalar tensor, call
global_amax = global_amax.to(x.device) and then .float()), compute global_scale
from that device-local tensor, and wrap the fp4_fake_quant_kernel[grid](...)
launch in the same CUDA context as x (use torch.cuda.device(x.device) as in
static_blockwise_fp4_fake_quant) so the kernel runs on the correct GPU; also
validate that global_amax is a scalar and raise/convert if not.
In `@modelopt/torch/quantization/triton/fp4_kernel.py`:
- Around line 241-275: The function static_blockwise_fp4_fake_quant must be
back-compatible with older callers that pass scale, skip_scale_quant, or
scale_fp8_quant_amax: update the signature to accept those parameters (e.g.,
scale=None, skip_scale_quant=None, scale_fp8_quant_amax=None or via **kwargs)
and map them to the new behavior—if scale is provided, use it instead of
computing scale = amax/6.0; if skip_scale_quant is True, set
quantize_block_scales=False; if scale_fp8_quant_amax is provided, use it as
global_amax (convert to float and compute scale_fp8_quant_amax =
scale_fp8_quant_amax/6.0) before calling scaled_e4m3_impl/reduce_amax; preserve
existing type conversions (amax.float(), global_amax.float()) and emit a
deprecation warning when any legacy arg is used.
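One way to realize that shim is sketched below as a standalone wrapper (the real change would live inside static_blockwise_fp4_fake_quant itself; the argument names and the scale = amax / 6.0 relationship follow the prompt above, and _new_impl is a placeholder for the new-style implementation):

```python
import warnings

def fake_quant_with_legacy_args(
    x,
    amax,
    global_amax=None,
    quantize_block_scales=True,
    *,
    scale=None,
    skip_scale_quant=None,
    scale_fp8_quant_amax=None,
    _new_impl=None,
):
    """Map deprecated keyword arguments onto the new-style call, with a warning."""
    if any(v is not None for v in (scale, skip_scale_quant, scale_fp8_quant_amax)):
        warnings.warn(
            "scale / skip_scale_quant / scale_fp8_quant_amax are deprecated; "
            "pass amax, global_amax and quantize_block_scales instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    if scale is not None:
        amax = scale * 6.0  # legacy callers passed scale = amax / 6.0
    if skip_scale_quant is not None:
        quantize_block_scales = not skip_scale_quant
    if scale_fp8_quant_amax is not None:
        global_amax = scale_fp8_quant_amax
    return _new_impl(x, amax, global_amax, quantize_block_scales)
```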
Code under review (fp4_kernel_hopper.py):

```python
    x_shape = x.shape
    x_dtype = x.dtype
    x = x.reshape(-1, x_shape[-1]).contiguous()

    M, N = x.shape
    y = torch.empty_like(x)

    stride_xm, stride_xn = x.stride()
    stride_ym, stride_yn = y.stride()

    tile_cols = max(tile_cols, block_size)
    tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size
    num_fp4_blocks = tile_cols_aligned // block_size

    global_scale = global_amax.float() / (6.0 * 448.0)

    grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned))

    launch_kwargs = {
        "BLOCK_SIZE": block_size,
        "TILE_M": tile_rows,
        "TILE_N": tile_cols_aligned,
        "NUM_FP4_BLOCKS": num_fp4_blocks,
        "OUT_DTYPE": _torch_dtype_to_tl(x_dtype),
    }
    if num_warps is not None:
        launch_kwargs["num_warps"] = num_warps
    if num_stages is not None:
        launch_kwargs["num_stages"] = num_stages
    fp4_fake_quant_kernel[grid](
        x,
        y,
        M,
        N,
        global_scale,
        stride_xm,
        stride_xn,
        stride_ym,
        stride_yn,
        **launch_kwargs,
    )
```
🧩 Analysis chain
🏁 Scripts executed:

```shell
find . -name "fp4_kernel_hopper.py" -type f

cat -n ./modelopt/torch/quantization/triton/fp4_kernel_hopper.py | sed -n '140,200p'
cat -n ./modelopt/torch/quantization/triton/fp4_kernel_hopper.py | sed -n '100,155p'
cat -n ./modelopt/torch/quantization/triton/fp4_kernel_hopper.py | sed -n '1,70p'
cat -n ./modelopt/torch/quantization/triton/fp4_kernel_hopper.py | sed -n '70,125p'

# Check if there are any existing validations or device handling in the full function
cat -n ./modelopt/torch/quantization/triton/fp4_kernel_hopper.py | sed -n '127,200p'

# Check if there are similar functions in the codebase that might show the correct pattern
find . -name "*.py" -path "*/quantization/triton/*" | head -5

# Check the similar fp4_kernel.py to see how it handles device management
rg -A 50 "def fp4_fake_quant" ./modelopt/torch/quantization/triton/fp4_kernel.py | head -80

cat -n ./modelopt/torch/quantization/triton/fp4_kernel.py | head -150
cat -n ./modelopt/torch/quantization/triton/fp4_kernel.py | sed -n '141,250p'
cat -n ./modelopt/torch/quantization/triton/fp4_kernel.py | sed -n '241,330p'

# Verify whether any calls to fp4_fake_quant_block guarantee global_amax is already on device
rg -B 5 -A 2 "fp4_fake_quant_block" ./modelopt/torch/quantization/triton/ --type py

# Look for calls/usages of fp4_fake_quant_block in the entire codebase
rg "fp4_fake_quant_block" --type py -B 3 -A 3 | grep -v "def fp4_fake_quant_block" | grep -v "__all__"

cat -n ./modelopt/torch/quantization/tensor_quant.py | grep -B 20 "fp4_fake_quant_block" | head -40
cat -n ./tests/gpu/torch/quantization/test_tensor_quant_cuda.py | grep -B 5 -A 10 "fp4_fake_quant_block"
```
Ensure global_amax is on x.device and launch kernel in correct CUDA context.
Lines 165-191: global_amax.float() does not move the tensor to x.device. If global_amax is on CPU or a different CUDA device, global_scale will be on the wrong device and the kernel will fail or use incorrect data. Additionally, the kernel launch lacks a torch.cuda.device context, unlike the similar static_blockwise_fp4_fake_quant function in fp4_kernel.py. Add a CUDA device check, move global_amax to x.device, validate it's scalar, and wrap the kernel launch in torch.cuda.device.
🔧 Proposed fix

```diff
 x_shape = x.shape
 x_dtype = x.dtype
+if not x.is_cuda:
+    raise RuntimeError("fp4_fake_quant_block requires a CUDA tensor.")
+if global_amax.numel() != 1:
+    raise ValueError("global_amax must be a scalar tensor.")
+global_amax = global_amax.to(device=x.device, dtype=torch.float32)
 x = x.reshape(-1, x_shape[-1]).contiguous()
@@
-global_scale = global_amax.float() / (6.0 * 448.0)
+global_scale = global_amax / (6.0 * 448.0)
@@
-fp4_fake_quant_kernel[grid](
-    x,
-    y,
-    M,
-    N,
-    global_scale,
-    stride_xm,
-    stride_xn,
-    stride_ym,
-    stride_yn,
-    **launch_kwargs,
-)
+with torch.cuda.device(x.device):
+    fp4_fake_quant_kernel[grid](
+        x,
+        y,
+        M,
+        N,
+        global_scale,
+        stride_xm,
+        stride_xn,
+        stride_ym,
+        stride_yn,
+        **launch_kwargs,
+    )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
x_shape = x.shape
x_dtype = x.dtype
if not x.is_cuda:
    raise RuntimeError("fp4_fake_quant_block requires a CUDA tensor.")
if global_amax.numel() != 1:
    raise ValueError("global_amax must be a scalar tensor.")
global_amax = global_amax.to(device=x.device, dtype=torch.float32)
x = x.reshape(-1, x_shape[-1]).contiguous()

M, N = x.shape
y = torch.empty_like(x)

stride_xm, stride_xn = x.stride()
stride_ym, stride_yn = y.stride()

tile_cols = max(tile_cols, block_size)
tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size
num_fp4_blocks = tile_cols_aligned // block_size

global_scale = global_amax / (6.0 * 448.0)

grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned))

launch_kwargs = {
    "BLOCK_SIZE": block_size,
    "TILE_M": tile_rows,
    "TILE_N": tile_cols_aligned,
    "NUM_FP4_BLOCKS": num_fp4_blocks,
    "OUT_DTYPE": _torch_dtype_to_tl(x_dtype),
}
if num_warps is not None:
    launch_kwargs["num_warps"] = num_warps
if num_stages is not None:
    launch_kwargs["num_stages"] = num_stages
with torch.cuda.device(x.device):
    fp4_fake_quant_kernel[grid](
        x,
        y,
        M,
        N,
        global_scale,
        stride_xm,
        stride_xn,
        stride_ym,
        stride_yn,
        **launch_kwargs,
    )
```
🤖 Prompt for AI Agents
In `@modelopt/torch/quantization/triton/fp4_kernel_hopper.py` around lines 151-191: The kernel launch uses global_amax on the wrong device and doesn't ensure
the correct CUDA context; move and validate global_amax to x.device before using
it (e.g., ensure global_amax is a scalar tensor, call global_amax =
global_amax.to(x.device) and then .float()), compute global_scale from that
device-local tensor, and wrap the fp4_fake_quant_kernel[grid](...) launch in the
same CUDA context as x (use torch.cuda.device(x.device) as in
static_blockwise_fp4_fake_quant) so the kernel runs on the correct GPU; also
validate that global_amax is a scalar and raise/convert if not.
Force-pushed 1a09c12 to fcf071f (Compare)
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #849      +/-   ##
==========================================
+ Coverage   73.38%   73.45%   +0.07%
==========================================
  Files         193      197       +4
  Lines       19893    20651     +758
==========================================
+ Hits        14598    15169     +571
- Misses       5295     5482     +187
```

☔ View full report in Codecov by Sentry.
Fridah-nv left a comment
NVFP4MSECalibrator and NVFP4StaticQuantizer refactor LGTM.
Commits:
- …antizer, NVFP4MSECalibrator (Signed-off-by: realAsma <akuriparambi@nvidia.com>)
- fp4 static kernel fix, test fixes, minor clean ups (Signed-off-by: realAsma <akuriparambi@nvidia.com>)
- minor (Signed-off-by: realAsma <akuriparambi@nvidia.com>)
- minor (Signed-off-by: realAsma <akuriparambi@nvidia.com>)
- minor (Signed-off-by: realAsma <akuriparambi@nvidia.com>)
- minor (Signed-off-by: realAsma <akuriparambi@nvidia.com>)
Force-pushed 4ca3180 to d0dfae0 (Compare)
sugunav14 left a comment
I reviewed the calibrator and the static quantizer logic! Just had a general question about scale search.
mxinO left a comment
LGTM, left one comment.
Signed-off-by: realAsma <akuriparambi@nvidia.com>
What does this PR do?
Type of change: ?
Overview: ?
Usage
# Add a code snippet demonstrating how to use this
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Improvements
Tests