Skip to content

Add NVFP4 per-token quantization recipe#3045

Draft
cael-ling wants to merge 6 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-per-token-recipe
Draft

Add NVFP4 per-token quantization recipe#3045
cael-ling wants to merge 6 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-per-token-recipe

Conversation

@cael-ling
Copy link
Copy Markdown
Contributor

@cael-ling cael-ling commented May 26, 2026

Description

This PR adds an NVFP4 per-token quantization fast path for bf16 inputs on Blackwell (SM100+) for Model Pre-training. Per-token uses per-row / per-column outer amax instead of the per-tensor scalar amax, which factors cleanly out of the GEMM K-summation and lets the inner GEMM stay on production cuBLASLT NVFP4 plus a thin trailing post-scale.

Status: draft. The cast kernel and GEMM composite, byte-equal-verified against a Python reference, and benched against the per-tensor (RHT + SR) recipe are still in progress. Partial experimental results are shown as follows.

Tests and benches

This PR ships four new pytest / benchmark files under tests/pytorch/nvfp4/. All four require bf16 input and M % 128 == 0 / K % 128 == 0; GEMM tests are gated by SM100 (Blackwell).

test_nvfp4_per_token.py — single-tensor correctness

test_nvfp4_per_token_group.py — group-tensor correctness

pytest -v tests/pytorch/nvfp4/test_nvfp4_per_token.py
pytest -v tests/pytorch/nvfp4/test_nvfp4_per_token_group.py

bench_nvfp4_per_token.py — single-tensor amax + quant bench

Wall-time benchmark of the single-tensor amax + quant composite (tex.nvfp4_per_token_quantize, this PR) against the per-tensor RHT+SR production baseline (NVFP4Quantizer(rht=True, sr=True) via tex.quantize). Both sides use rowwise + columnwise. Single output table:

python tests/pytorch/nvfp4/bench_nvfp4_per_token.py
python tests/pytorch/nvfp4/bench_nvfp4_per_token.py --rht

Each row reports eager wall-time plus CUDA Graphs replay (kernel-only floor). ratio < 1.0 ⇒ per-token is faster than the per-tensor baseline.

M K per-token (ms) per-tensor (ms) ratio per-token (Graph) (ms) per-tensor (Graph) (ms) ratio (Graph)
1024 2048 0.0257 0.0465 0.55x 0.0178 0.0363 0.49x
1024 4096 0.0270 0.0491 0.55x 0.0199 0.0371 0.54x
1024 8192 0.0270 0.0464 0.58x 0.0236 0.0390 0.61x
2048 2048 0.0270 0.0477 0.57x 0.0199 0.0352 0.57x
2048 4096 0.0291 0.0475 0.61x 0.0219 0.0371 0.59x
2048 8192 0.0362 0.0452 0.80x 0.0299 0.0391 0.77x
4096 2048 0.0283 0.0454 0.62x 0.0238 0.0389 0.61x
4096 4096 0.0363 0.0454 0.80x 0.0285 0.0412 0.69x
4096 8192 0.0526 0.0638 0.82x 0.0462 0.0659 0.70x
8192 2048 0.0362 0.0451 0.80x 0.0301 0.0408 0.74x
8192 4096 0.0526 0.0637 0.83x 0.0465 0.0651 0.71x
8192 8192 0.0875 0.1181 0.74x 0.0813 0.1212 0.67x
16384 2048 0.0526 0.0649 0.81x 0.0465 0.0658 0.71x
16384 4096 0.0874 0.1182 0.74x 0.0813 0.1199 0.68x
16384 8192 0.1471 0.1960 0.75x 0.1418 0.1981 0.72x
32768 2048 0.0875 0.1180 0.74x 0.0793 0.1211 0.65x
32768 4096 0.1488 0.1970 0.76x 0.1408 0.1980 0.71x
32768 8192 0.2678 0.3609 0.74x 0.2615 0.3635 0.72x
M K per-token (ms) per-token (+rht) (ms) per-tensor (ms) ratio per-token (Graph) (ms) per-token (+rht) (Graph) (ms) per-tensor (Graph) (ms) ratio (Graph)
1024 2048 0.0258 0.0296 0.0495 0.60x 0.0178 0.0219 0.0362 0.60x
1024 4096 0.0251 0.0270 0.0475 0.57x 0.0198 0.0219 0.0369 0.59x
1024 8192 0.0270 0.0301 0.0485 0.62x 0.0237 0.0260 0.0371 0.70x
2048 2048 0.0294 0.0291 0.0485 0.60x 0.0182 0.0217 0.0359 0.60x
2048 4096 0.0294 0.0321 0.0454 0.71x 0.0233 0.0260 0.0371 0.70x
2048 8192 0.0357 0.0413 0.0454 0.91x 0.0284 0.0342 0.0392 0.87x
4096 2048 0.0285 0.0321 0.0484 0.66x 0.0236 0.0260 0.0372 0.70x
4096 4096 0.0359 0.0413 0.0475 0.87x 0.0301 0.0342 0.0392 0.87x
4096 8192 0.0526 0.0637 0.0635 1.00x 0.0465 0.0567 0.0664 0.85x
8192 2048 0.0354 0.0414 0.0466 0.89x 0.0301 0.0342 0.0393 0.87x
8192 4096 0.0527 0.0647 0.0632 1.02x 0.0465 0.0567 0.0672 0.84x
8192 8192 0.0874 0.1050 0.1164 0.90x 0.0813 0.0977 0.1212 0.81x
16384 2048 0.0526 0.0629 0.0636 0.99x 0.0450 0.0577 0.0649 0.89x
16384 4096 0.0874 0.1039 0.1179 0.88x 0.0813 0.0977 0.1212 0.81x
16384 8192 0.1470 0.1807 0.1960 0.92x 0.1407 0.1735 0.1980 0.88x
32768 2048 0.0876 0.1041 0.1163 0.90x 0.0803 0.0977 0.1212 0.81x
32768 4096 0.1486 0.1808 0.1970 0.92x 0.1419 0.1734 0.1985 0.87x
32768 8192 0.2678 0.3296 0.3599 0.92x 0.2615 0.3168 0.3646 0.87x

bench_nvfp4_per_token_group.py — grouped amax + quant bench

Wall-time benchmark of the grouped amax + quant composite (nvfp4_per_token_group_quantize Python wrapper backed by the new C++ bulk entry, this PR) against the per-tensor RHT+SR grouped production baseline (tex.split_quantize(...) with NVFP4Quantizer(rht=True, sr=True) per split). Layout identical to the single-tensor bench:

python tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py
python tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py --rht

Default sweep is 6 × 3 = 18 cases at fixed N = 8 equal splits (MoE-typical): sum_M ∈ {1024, 2048, 4096, 8192, 16384, 32768} (so per-split M_i ∈ {128 … 4096}) × K ∈ {2048, 4096, 8192}. CUDA Graphs replay reported on every row.

sum_M K per-token (ms) per-tensor (ms) ratio per-token (Graph) (ms) per-tensor (Graph) (ms) ratio (Graph)
1024 2048 0.1271 0.1801 0.71x 0.0177 0.0433 0.41x
1024 4096 0.1358 0.1762 0.77x 0.0198 0.0433 0.46x
1024 8192 0.1318 0.1752 0.75x 0.0239 0.0430 0.56x
2048 2048 0.1317 0.1727 0.76x 0.0178 0.0431 0.41x
2048 4096 0.1313 0.1745 0.75x 0.0238 0.0432 0.55x
2048 8192 0.1299 0.1737 0.75x 0.0301 0.0444 0.68x
4096 2048 0.1297 0.1739 0.75x 0.0219 0.0433 0.51x
4096 4096 0.1289 0.1763 0.73x 0.0301 0.0456 0.66x
4096 8192 0.1243 0.1688 0.74x 0.0485 0.0724 0.67x
8192 2048 0.1296 0.1652 0.78x 0.0301 0.0451 0.67x
8192 4096 0.1268 0.1698 0.75x 0.0503 0.0724 0.69x
8192 8192 0.1309 0.1702 0.77x 0.0854 0.1254 0.68x
16384 2048 0.1330 0.1717 0.77x 0.0504 0.0720 0.70x
16384 4096 0.1291 0.1761 0.73x 0.0874 0.1247 0.70x
16384 8192 0.1593 0.2052 0.78x 0.1509 0.2021 0.75x
32768 2048 0.1319 0.1772 0.74x 0.0866 0.1262 0.69x
32768 4096 0.1591 0.2033 0.78x 0.1510 0.2041 0.74x
32768 8192 0.2900 0.3671 0.79x 0.2820 0.3643 0.77x
sum_M K per-token (ms) per-token (+rht) (ms) per-tensor (ms) ratio per-token (Graph) (ms) per-token (+rht) (Graph) (ms) per-tensor (Graph) (ms) ratio (Graph)
1024 2048 0.1178 0.1205 0.1754 0.69x 0.0178 0.0215 0.0434 0.50x
1024 4096 0.1212 0.1197 0.1607 0.74x 0.0195 0.0219 0.0433 0.51x
1024 8192 0.1184 0.1183 0.1642 0.72x 0.0219 0.0260 0.0431 0.60x
2048 2048 0.1194 0.1170 0.1565 0.75x 0.0183 0.0216 0.0433 0.50x
2048 4096 0.1183 0.1180 0.1568 0.75x 0.0219 0.0260 0.0454 0.57x
2048 8192 0.1196 0.1189 0.1581 0.75x 0.0310 0.0342 0.0454 0.75x
4096 2048 0.1190 0.1169 0.1538 0.76x 0.0219 0.0260 0.0431 0.60x
4096 4096 0.1185 0.1191 0.1536 0.78x 0.0301 0.0342 0.0454 0.75x
4096 8192 0.1185 0.1198 0.1579 0.76x 0.0486 0.0588 0.0730 0.81x
8192 2048 0.1206 0.1589 0.1606 0.99x 0.0302 0.0342 0.0543 0.63x
8192 4096 0.1476 0.1387 0.2148 0.65x 0.0485 0.0582 0.0729 0.80x
8192 8192 0.1387 0.1314 0.1902 0.69x 0.0873 0.1018 0.1264 0.81x
16384 2048 0.1262 0.1242 0.1720 0.72x 0.0485 0.0588 0.0726 0.81x
16384 4096 0.1241 0.1229 0.1717 0.72x 0.0854 0.1018 0.1249 0.81x
16384 8192 0.1589 0.1901 0.2031 0.94x 0.1510 0.1837 0.2035 0.90x
32768 2048 0.1212 0.1208 0.1577 0.77x 0.0875 0.1018 0.1251 0.81x
32768 4096 0.1589 0.1900 0.2029 0.94x 0.1509 0.1837 0.2035 0.90x
32768 8192 0.2903 0.3511 0.3665 0.96x 0.2821 0.3435 0.3640 0.94x

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 26, 2026
Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast
pair and ships pytest correctness + sweep benches against the per-tensor
RHT+SR production baseline.

  * common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused
    grouped kernel, reusing the single-tensor 4-stage TMA pipeline.
  * common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale
    kernel for the cuBLASLT NVFP4 dequantize step (maybe updated due
    to 2d quant of W).
  * pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++
    grouped bulk binding and per-token GEMM entry; thin pybind layer.
  * pytorch/custom_recipes/{gemm_nvfp4_per_token,
    quantization_nvfp4_per_token_group}.py: Python wrappers.
  * tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal
    cast tests + bf16-close GEMM tests.
  * tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep
    over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA
    Graphs columns, ratio against per-tensor RHT+SR baseline.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling force-pushed the feature/nvfp4-per-token-recipe branch from 6f17fe4 to 928ab1c Compare May 27, 2026 13:09
pre-commit-ci Bot and others added 5 commits May 27, 2026 13:10
…uped)

Wire `with_rht` / `random_sign_mask_t` through the per-token K1 (amax)
and K2 (encode) kernels for both single-tensor and grouped paths.
with_rht=False is byte-equal to the pre-RHT code path; when true,
applies a 16-pt RHT on the columnwise direction in both K1 and K2
(rowwise stays raw) with outer amax + inner SF self-consistent.

Implementation: per-thread fp32 FHT on CUDA cores, branchless fp32
sign-bit XOR for the +/-1 sign diagonal, 0.25 normalization folded into
block_amax / block_scale (bit-exact).

Tests cover K1, K2, composite + grouped vs a PyTorch fp32 reference and
byte-equality regressions. Benches gain a --rht flag (2-way default,
3-way under --rht).

Perf vs prod NVFP4Quantizer(rht+sr), Graph mode, 18 shapes M up to 32K:
* single tensor : 0.49x-0.77x (no RHT), 0.59x-0.88x (+RHT)
* grouped (N=8) : 0.41x-0.77x (no RHT), 0.50x-0.94x (+RHT)

Also drops unused THREADS_X_TR / THREADS_Y_TR (nvcc warning NVIDIA#177-D).

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Add an optional fused-swizzle path to the NVFP4 per-token K2 encode
kernel: when with_swizzle=True the rowwise scale_inv is emitted directly
in the cuBLAS LT 128Mx4K swizzled tile layout, skipping the downstream
nvte_swizzle_scaling_factors launch. The colwise scale_inv stays in the
compact M-major layout (rowwise-only fusion for now).

The new code path is gated by a kWithSwizzle template parameter on
per_token_encode_kernel. The scatter epilogue uses thread mapping
b=tid&3, ty=tid>>2 to give each warp a coalesced 128-byte gmem store,
and packs two K-tiles into one uint64_t SMEM load (2-way bank conflict
instead of 4-way). Pre-existing code path is byte-equal.

with_swizzle is threaded through nvte_nvfp4_per_token_{quantize,encode},
their PyTorch bindings, and the nvfp4_per_token_{quantize,encode} Python
recipes. nvfp4_per_token_gemm takes new a_sf_swizzled / b_sf_swizzled
flags so the caller opts into the fast path per operand (mirrors prod
NVFP4 GEMM's per-operand swizzle).

Add tex.nvfp4_per_token_swizzle_rowwise_sf -- a thin wrapper around
nvte_swizzle_scaling_factors that does one standalone per-operand
swizzle launch. Bench-only; lets --qs attribute swizzle cost separately
from K1+K2 and from cuBLAS LT GEMM.

Bench (bench_nvfp4_per_token.py): add --qs mode (K1+K2 + standalone
swizzle, no GEMM) with two modifiers -- --pair (2 operands, matches one
prod GEMM call's quant+swizzle pipeline) and --fuse (adds a per-token
(fuse) column for the K2-fused path). The existing --swizzle end-to-end
mode also gains the fused-swizzle column. --pair / --fuse auto-imply
--qs to avoid silent fall-through to the default --composite table.

Tests (test_nvfp4_per_token.py): byte-equality of the fused-swizzle
rowwise SF vs a pure-Python permutation reference, byte-equality of all
other outputs (FP4 data, colwise SF, row/col amax) vs with_swizzle=False,
and numerical equivalence of the end-to-end GEMM via both code paths.

Perf at K=N=4096, Graph mode: fused-swizzle path is ~7-35% faster than
the unfused per-token pipeline (--qs) and reaches up to ~2.6x faster
than per-tensor at small M.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant