From 928ab1cec264e01a0e3e7f3e5eca6c17ed96a433 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Tue, 26 May 2026 07:56:30 -0700 Subject: [PATCH 1/6] Add NVFP4 per-token GEMM, fused grouped amax, cast, tests and benches Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast pair and ships pytest correctness + sweep benches against the per-tensor RHT+SR production baseline. * common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused grouped kernel, reusing the single-tensor 4-stage TMA pipeline. * common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale kernel for the cuBLASLT NVFP4 dequantize step (maybe updated due to 2d quant of W). * pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++ grouped bulk binding and per-token GEMM entry; thin pybind layer. * pytorch/custom_recipes/{gemm_nvfp4_per_token, quantization_nvfp4_per_token_group}.py: Python wrappers. * tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal cast tests + bf16-close GEMM tests. * tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA Graphs columns, ratio against per-tensor RHT+SR baseline. Co-authored-by: Zhongbo Zhu Signed-off-by: Cael Ling --- tests/pytorch/nvfp4/bench_nvfp4_per_token.py | 217 ++++ .../nvfp4/bench_nvfp4_per_token_group.py | 229 ++++ tests/pytorch/nvfp4/test_nvfp4_per_token.py | 375 ++++++ .../nvfp4/test_nvfp4_per_token_group.py | 344 ++++++ transformer_engine/common/CMakeLists.txt | 7 +- .../cast/nvfp4/quantize_nvfp4_per_token.cu | 1025 +++++++++++++++++ .../nvfp4/quantize_nvfp4_per_token_group.cu | 1017 ++++++++++++++++ .../common/gemm/nvfp4_per_token_post_scale.cu | 141 +++ .../transformer_engine/nvfp4_per_token.h | 124 ++ transformer_engine/pytorch/csrc/extensions.h | 58 + .../csrc/extensions/nvfp4_per_token.cpp | 793 +++++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 61 + .../pytorch/csrc/extensions/recipe.cpp | 32 + .../pytorch/custom_recipes/__init__.py | 32 + .../custom_recipes/gemm_nvfp4_per_token.py | 206 ++++ .../quantization_nvfp4_per_token.py | 386 +++++++ .../quantization_nvfp4_per_token_group.py | 112 ++ 17 files changed, 5157 insertions(+), 2 deletions(-) create mode 100644 tests/pytorch/nvfp4/bench_nvfp4_per_token.py create mode 100644 tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_per_token.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_per_token_group.py create mode 100644 transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu create mode 100644 transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu create mode 100644 transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu create mode 100644 transformer_engine/common/include/transformer_engine/nvfp4_per_token.h create mode 100644 transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp create mode 100644 transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py create mode 100644 transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py create mode 100644 transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py new file mode 100644 index 0000000000..54aaaeccb7 --- /dev/null +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py @@ -0,0 +1,217 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Bench NVFP4 per-token K1+K2 quant vs per-tensor RHT+SR baseline. + +Quant-only (no GEMM). Both sides time the K1 (amax) + K2 (cast) composite on +activation A, rowwise+columnwise. Requires bf16 input, M % 128 == 0, K % 128 == 0. +""" + +from __future__ import annotations + +import argparse +import math +import statistics +import sys +from dataclasses import dataclass +from typing import Callable, List, Tuple + +import torch + +# Import transformer_engine first so libtransformer_engine.so is dlopen'd +# before transformer_engine_torch tries to resolve its typeinfo symbols. +import transformer_engine.pytorch as te # noqa: F401 +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer + + +def cuda_time_ms(fn: Callable[[], None], *, warmup: int = 5, iters: int = 50) -> float: + """Median wall time of fn over iters invocations, in ms.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + for i in range(iters): + starts[i].record() + fn() + ends[i].record() + torch.cuda.synchronize() + samples = [starts[i].elapsed_time(ends[i]) for i in range(iters)] + return statistics.median(samples) + + +def cuda_graph_time_ms( + fn: Callable[[], object], *, warmup: int = 5, iters: int = 50 +) -> float: + """Median g.replay() wall time of fn captured into a CUDA Graph (kernel-only floor). + + Returns nan if capture fails. + """ + try: + side = torch.cuda.Stream() + side.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(side): + for _ in range(warmup): + fn() + torch.cuda.current_stream().wait_stream(side) + torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + except Exception as e: + print(f" [graph capture skipped: {type(e).__name__}: {e}]", file=sys.stderr) + return float("nan") + + starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + for i in range(iters): + starts[i].record() + g.replay() + ends[i].record() + torch.cuda.synchronize() + samples = [starts[i].elapsed_time(ends[i]) for i in range(iters)] + return statistics.median(samples) + + +def _make_baseline_quantizer() -> NVFP4Quantizer: + """Per-tensor baseline quantizer: RHT + SR + random sign mask.""" + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + + +def _has_sm100() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 + + +@dataclass +class ShapeBench: + M: int + K: int + t_pt: float # per-token full K1+K2 (eager pybind, ms) + t_pten: float # per-tensor full K1+K2 (eager pybind, ms) + t_pt_g: float # per-token under CUDA Graphs replay (ms) + t_pten_g: float # per-tensor under CUDA Graphs replay (ms) + + +def _bench_shape(M: int, K: int, *, device: torch.device) -> ShapeBench: + """Time per-tensor vs per-token K1+K2 quant at one (M, K) shape.""" + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + # Per-tensor quantizer + A output tensor. + quantizer = _make_baseline_quantizer() + dst_a = quantizer.make_empty(a.shape, dtype=torch.bfloat16, device=device) + + # Per-token A-side buffers: BLOCK_K=16 (1x16 e4m3 inner SF). + BLOCK_K = 16 + ra_a = torch.empty((M,), dtype=torch.float32, device=device) + ca_a = torch.empty((K,), dtype=torch.float32, device=device) + q_row_a = torch.empty((M, K // 2), dtype=torch.uint8, device=device) + s_dec_row_a = torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device) + q_col_a = torch.empty((K, M // 2), dtype=torch.uint8, device=device) + s_dec_col_a = torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device) + + def _baseline_quant_fn(): + tex.quantize(a, quantizer, dst_a, None) + + def _pt_full_quant_fn(): + tex.nvfp4_per_token_quantize( + a, q_row_a, s_dec_row_a, ra_a, q_col_a, s_dec_col_a, ca_a, True, True, + ) + + t_pten = cuda_time_ms(_baseline_quant_fn) + t_pt = cuda_time_ms(_pt_full_quant_fn) + t_pten_g = cuda_graph_time_ms(_baseline_quant_fn) + t_pt_g = cuda_graph_time_ms(_pt_full_quant_fn) + + return ShapeBench(M=M, K=K, t_pt=t_pt, t_pten=t_pten, t_pt_g=t_pt_g, t_pten_g=t_pten_g) + + +# 6x3 sweep matching bench_nvfp4_per_token_group.py: M in {1024..32768}, K in {2048,4096,8192}. +_M_VALUES: Tuple[int, ...] = (1024, 2048, 4096, 8192, 16384, 32768) +_K_VALUES: Tuple[int, ...] = (2048, 4096, 8192) +_DEFAULT_SHAPES: Tuple[Tuple[int, int], ...] = tuple( + (m, k) for m in _M_VALUES for k in _K_VALUES +) + + +def _parse_shape(s: str) -> Tuple[int, int]: + parts = s.split("x") + if len(parts) != 2: + raise argparse.ArgumentTypeError(f"Shape must be MxK, got '{s}'") + return tuple(int(p) for p in parts) # type: ignore[return-value] + + +def _ratio(num: float, den: float) -> float: + if den <= 0 or math.isnan(num) or math.isnan(den): + return float("nan") + return num / den + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Benchmark NVFP4 per-token K1+K2 quant vs per-tensor production NVFP4." + ) + parser.add_argument( + "--shapes", type=_parse_shape, nargs="+", default=None, + help="Shapes to bench, in MxK form (e.g. 4096x4096). " + "Default: an internally-chosen production-shape sweep.", + ) + args = parser.parse_args() + + if not _has_sm100(): + print("SKIP: NVFP4 per-token quant requires SM100 (Blackwell).", file=sys.stderr) + return 1 + + device = torch.device("cuda") + shapes = list(args.shapes) if args.shapes else list(_DEFAULT_SHAPES) + + records: List[ShapeBench] = [_bench_shape(M, K, device=device) for (M, K) in shapes] + + header = ( + f"{'M':>7} {'K':>6}" + f" |" + f"{'per-token':>10} {'per-tensor':>11} {'ratio':>8}" + f" |" + f"{'per-token(Graph)':>17} {'per-tensor(Graph)':>18} {'ratio(Graph)':>13}" + ) + print(header) + print("-" * len(header)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt, rec.t_pten) + ratio_g = _ratio(rec.t_pt_g, rec.t_pten_g) + ratio_s = "nan" if math.isnan(ratio) else f"{ratio:.2f}x" + ratio_g_s = "nan" if math.isnan(ratio_g) else f"{ratio_g:.2f}x" + print( + f"{rec.M:>7} {rec.K:>6}" + f" |" + f"{rec.t_pt:>10.4f} {rec.t_pten:>11.4f} {ratio_s:>8}" + f" |" + f"{rec.t_pt_g:>17.4f} {rec.t_pten_g:>18.4f} {ratio_g_s:>13}" + ) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py new file mode 100644 index 0000000000..257cda32c8 --- /dev/null +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py @@ -0,0 +1,229 @@ +"""Bench: NVFP4 per-token grouped (K1+K2 fused) vs per-tensor+RHT baseline. + +18-row sweep at fixed N=8 splits: sum_M in {1024..32768} x K in {2048,4096,8192}. +Both eager and CUDA Graphs columns reported on every row (ratio < 1.0 wins). +Requires bf16, K % 128 == 0, every split % 128 == 0, num_splits <= 64. +""" + +from __future__ import annotations + +import math +import statistics +import sys +from typing import Callable, List, Tuple + +import torch + +# Import transformer_engine first so libtransformer_engine.so is dlopen'd +# before transformer_engine_torch tries to resolve its typeinfo symbols. +import transformer_engine.pytorch as te # noqa: F401 +import transformer_engine_torch as tex # type: ignore # noqa: F401 + +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token_group import ( + nvfp4_per_token_group_quantize, +) + + +def _make_baseline_quantizer_list(num_splits: int) -> List[NVFP4Quantizer]: + """Per-tensor RHT+SR baseline: one quantizer instance shared across N splits.""" + q = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + return [q] * num_splits + + +def cuda_graph_time_ms( + fn: Callable[[], object], *, warmup: int = 5, iters: int = 50 +) -> float: + """Median g.replay() time of fn captured into a CUDA Graph, in ms. + + Returns nan if capture fails (e.g. some C-API does an incompatible sync). + """ + try: + side = torch.cuda.Stream() + side.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(side): + for _ in range(warmup): + fn() + torch.cuda.current_stream().wait_stream(side) + torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + except Exception as e: + print(f" [graph capture skipped: {type(e).__name__}: {e}]", file=sys.stderr) + return float("nan") + + starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + for i in range(iters): + starts[i].record() + g.replay() + ends[i].record() + torch.cuda.synchronize() + return statistics.median(starts[i].elapsed_time(ends[i]) for i in range(iters)) + + +def _time_grouped(x_concat, split_sections, rowwise, columnwise, n_iters=20, n_warmup=5): + """Per-token grouped via the BULK Python wrapper. Allocation in-loop.""" + for _ in range(n_warmup): + _ = nvfp4_per_token_group_quantize( + x_concat, split_sections, rowwise=rowwise, columnwise=columnwise + ) + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + stop = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(n_iters): + _ = nvfp4_per_token_group_quantize( + x_concat, split_sections, rowwise=rowwise, columnwise=columnwise + ) + stop.record() + torch.cuda.synchronize() + return start.elapsed_time(stop) / n_iters # ms + + +def _time_split_quantize(x_concat, split_sections, quantizer_list, n_iters=20, n_warmup=5): + """Per-tensor grouped baseline: tex.split_quantize, allocation in-binding.""" + for _ in range(n_warmup): + _ = tex.split_quantize(x_concat, split_sections, quantizer_list) + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + stop = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(n_iters): + _ = tex.split_quantize(x_concat, split_sections, quantizer_list) + stop.record() + torch.cuda.synchronize() + return start.elapsed_time(stop) / n_iters # ms + + +def _time_split_quantize_graph(x_concat, split_sections, quantizer_list, + n_iters=20, n_warmup=5): + """Per-tensor grouped under CUDA Graphs replay.""" + def fn() -> None: + _ = tex.split_quantize(x_concat, split_sections, quantizer_list) + + return cuda_graph_time_ms(fn, warmup=n_warmup, iters=n_iters) + + +def _time_grouped_graph(x_concat, split_sections, rowwise, columnwise, n_iters=20, n_warmup=5): + """Per-token grouped under CUDA Graphs replay.""" + def fn() -> None: + _ = nvfp4_per_token_group_quantize( + x_concat, split_sections, rowwise=rowwise, columnwise=columnwise + ) + + return cuda_graph_time_ms(fn, warmup=n_warmup, iters=n_iters) + + +# N = 8 equal splits (MoE-typical), sum_M in {1024..32768}, K in {2048..8192}. +_NUM_SPLITS: int = 8 + +_SUM_M_VALUES: List[int] = [1024, 2048, 4096, 8192, 16384, 32768] +_K_VALUES: List[int] = [2048, 4096, 8192] + +_BENCH_CASES: List[Tuple[List[int], int]] = [] +for _sum_M in _SUM_M_VALUES: + _M_i = _sum_M // _NUM_SPLITS + for _K in _K_VALUES: + _BENCH_CASES.append(([_M_i] * _NUM_SPLITS, _K)) + + +def main() -> None: + if not torch.cuda.is_available(): + print("CUDA unavailable, skipping bench.") + return + cap = torch.cuda.get_device_capability() + if cap[0] < 10: + print(f"NVFP4 per-token requires SM100+ (got SM{cap[0]}.{cap[1]}); skipping.") + return + + device = torch.device("cuda") + print(f"# Device: {torch.cuda.get_device_name(0)} (cap {cap[0]}.{cap[1]})") + print(f"# Split structure: N={_NUM_SPLITS} equal splits, M_i = sum_M / {_NUM_SPLITS}") + print() + + # Per-tensor baseline quantizer is fixed to row+col, so both enabled. + rowwise = True + columnwise = True + + header = ( + f"{'sum_M':>6} {'K':>5}" + f" |" + f"{'per-token':>10} {'per-tensor':>10} {'ratio':>8}" + f" |" + f"{'per-token(Graph)':>17} {'per-tensor(Graph)':>17} {'ratio(Graph)':>13}" + ) + print(header) + print("-" * len(header)) + + prev_sum_M = None + for split_sections, K in _BENCH_CASES: + sum_M = sum(split_sections) + num_splits = len(split_sections) + + # Blank line between sum_M groups for readability. + if prev_sum_M is not None and sum_M != prev_sum_M: + print() + prev_sum_M = sum_M + + x_concat = ( + torch.randn((sum_M, K), dtype=torch.bfloat16, device=device) * 3.0 + ).contiguous() + quantizer_list = _make_baseline_quantizer_list(num_splits) + + t_pt = _time_grouped(x_concat, split_sections, rowwise, columnwise) + t_pten = _time_split_quantize(x_concat, split_sections, quantizer_list) + ratio = t_pt / t_pten if t_pten > 0 else float("nan") + + t_pt_g = _time_grouped_graph( + x_concat, split_sections, rowwise, columnwise + ) + t_pten_g = _time_split_quantize_graph( + x_concat, split_sections, quantizer_list + ) + if math.isnan(t_pt_g) or math.isnan(t_pten_g) or t_pten_g <= 0: + ratio_g = float("nan") + graph_cells = ( + f"{t_pt_g:>17.4f} {t_pten_g:>17.4f} {'nan':>13}" + ) + else: + ratio_g = t_pt_g / t_pten_g + graph_cells = ( + f"{t_pt_g:>17.4f} {t_pten_g:>17.4f} {ratio_g:>12.2f}x" + ) + + print( + f"{sum_M:>6d} {K:>5d}" + f" |" + f"{t_pt:>10.4f} {t_pten:>10.4f} {ratio:>7.2f}x" + f" |" + f"{graph_cells}" + ) + + del x_concat, quantizer_list + torch.cuda.empty_cache() + + print() + print("Legend:") + print(" per-token = nvfp4_per_token_group_quantize(x, splits, rowwise+colwise)") + print(" = K1 fused amax + K2 fused cast (2 launches), this PR") + print(" per-tensor = tex.split_quantize(x, splits, [NVFP4Quantizer(rht+sr)]*N)") + print(" = nvte_group_hadamard_transform_amax") + print(" + nvte_group_hadamard_transform_cast_fusion (2 launches)") + print(" ratio = per-token / per-tensor ** < 1.0 = this PR wins **") + print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided,") + print(" pure kernel-level wall time, ALL rows)") + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token.py b/tests/pytorch/nvfp4/test_nvfp4_per_token.py new file mode 100644 index 0000000000..0cc4f8347f --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token.py @@ -0,0 +1,375 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Correctness tests for NVFP4 per-token cast + cuBLAS LT NVFP4 GEMM. + +Covers byte-equal kernel-vs-reference quantize parity, K1/K2 split-vs-composite +parity, dequant + fp32 reference, and a cuBLAS LT NVFP4 GEMM smoke. Requires +bf16 input, M % 128 == 0, K % 128 == 0; GEMM tests gated by SM100. +""" + +from __future__ import annotations + +import pytest +import torch + +# Must import transformer_engine first to dlopen libtransformer_engine.so so +# transformer_engine_torch.so can resolve typeinfo / vtable symbols at load time. +import transformer_engine.pytorch as te # noqa: F401 +import transformer_engine_torch as tex # type: ignore # noqa: F401 + +from transformer_engine.pytorch.custom_recipes.gemm_nvfp4_per_token import ( + dequantize_nvfp4_per_token, + nvfp4_per_token_gemm, + nvfp4_per_token_gemm_dequant, +) +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token import ( + BLOCK_K, + NVFP4QuantizerPerTokenRef, + nvfp4_per_token_amax, + nvfp4_per_token_encode, + nvfp4_per_token_quantize, +) + + +def _has_sm100() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 + + +_GATED_SM100 = pytest.mark.skipif( + not _has_sm100(), + reason="NVFP4 per-token GEMM via cuBLAS LT requires SM100 (Blackwell).", +) + +_GATED_FP4 = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="NVFP4 per-token cast requires CUDA.", +) + + +# (1) Quantize parity: kernel vs Python reference. + +# Shapes obey the kernel contract (M % 128 == 0, K % 128 == 0). +_QUANT_SHAPES = [ + (128, 128), # smallest legal shape + (128, 256), # K > inner SF window of single chunk + (256, 128), # M > inner SF window of single chunk + (256, 512), + (512, 1024), +] + + +def _unpack_fp4_byte_pairs(x: torch.Tensor) -> torch.Tensor: + """Unpack two FP4 values per byte into one uint8 nibble per element.""" + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +@_GATED_FP4 +@pytest.mark.parametrize("M,N", _QUANT_SHAPES) +@pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)]) +def test_per_token_quantize_byte_exact(M: int, N: int, rowwise: bool, columnwise: bool) -> None: + """Composite per-token output is byte-equal to the Python reference.""" + torch.manual_seed(0xBEEF * (M + 17) + (N + 3)) + device = torch.device("cuda") + x = torch.randn((M, N), dtype=torch.bfloat16, device=device) * 4.0 + # Outliers so the per-row outer is exercised. + if M >= 4: + x[0, :] *= 8.0 + x[-1, :] *= 0.125 + + ref = NVFP4QuantizerPerTokenRef(rowwise=rowwise, columnwise=columnwise).quantize(x) + sut = nvfp4_per_token_quantize(x, rowwise=rowwise, columnwise=columnwise) + + if rowwise: + qx_sut = _unpack_fp4_byte_pairs(sut.data.view(torch.uint8)) + qx_ref = _unpack_fp4_byte_pairs(ref.data.view(torch.uint8)) + torch.testing.assert_close(qx_sut, qx_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close( + sut.scale.view(torch.uint8), ref.scale.view(torch.uint8), + atol=0.0, rtol=0.0, + ) + torch.testing.assert_close(sut.row_amax, ref.row_amax, atol=0.0, rtol=0.0) + + if columnwise: + qxt_sut = _unpack_fp4_byte_pairs(sut.columnwise_data.view(torch.uint8)) + qxt_ref = _unpack_fp4_byte_pairs(ref.columnwise_data.view(torch.uint8)) + torch.testing.assert_close(qxt_sut, qxt_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close( + sut.columnwise_scale.view(torch.uint8), + ref.columnwise_scale.view(torch.uint8), + atol=0.0, rtol=0.0, + ) + torch.testing.assert_close(sut.col_amax, ref.col_amax, atol=0.0, rtol=0.0) + + +# (2) Split-kernel parity: K1 then K2 == composite K1+K2. + +@_GATED_FP4 +@pytest.mark.parametrize("M,N", _QUANT_SHAPES) +@pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)]) +def test_per_token_split_byte_equal( + M: int, N: int, rowwise: bool, columnwise: bool, +) -> None: + """K1 (amax) then K2 (encode) byte-equals the composite K1+K2.""" + torch.manual_seed(0xC0FFEE * (M + 7) + (N + 11)) + device = torch.device("cuda") + x = torch.randn((M, N), dtype=torch.bfloat16, device=device) * 4.0 + if M >= 4: + x[0, :] *= 8.0 + x[-1, :] *= 0.125 + + composite = nvfp4_per_token_quantize(x, rowwise=rowwise, columnwise=columnwise) + + row_amax, col_amax = nvfp4_per_token_amax( + x, rowwise=rowwise, columnwise=columnwise, + ) + split = nvfp4_per_token_encode( + x, + row_amax=row_amax, + col_amax=col_amax, + rowwise=rowwise, + columnwise=columnwise, + ) + + if rowwise: + torch.testing.assert_close(split.row_amax, composite.row_amax, atol=0.0, rtol=0.0) + torch.testing.assert_close( + split.data.view(torch.uint8), composite.data.view(torch.uint8), + atol=0.0, rtol=0.0, + ) + torch.testing.assert_close( + split.scale.view(torch.uint8), composite.scale.view(torch.uint8), + atol=0.0, rtol=0.0, + ) + if columnwise: + torch.testing.assert_close(split.col_amax, composite.col_amax, atol=0.0, rtol=0.0) + torch.testing.assert_close( + split.columnwise_data.view(torch.uint8), + composite.columnwise_data.view(torch.uint8), + atol=0.0, rtol=0.0, + ) + torch.testing.assert_close( + split.columnwise_scale.view(torch.uint8), + composite.columnwise_scale.view(torch.uint8), + atol=0.0, rtol=0.0, + ) + + +# (2b) Input-validation rejections. + +@_GATED_FP4 +def test_per_token_validation_rejects_fp32() -> None: + """Per-token must ``ValueError`` on non-bf16 input (no fallback path).""" + device = torch.device("cuda") + x = torch.randn((128, 128), dtype=torch.float32, device=device) + with pytest.raises(ValueError, match="bf16"): + nvfp4_per_token_quantize(x, rowwise=True, columnwise=False) + + +@_GATED_FP4 +def test_per_token_validation_rejects_unaligned() -> None: + """Per-token must ``ValueError`` on M or K not 128-aligned.""" + device = torch.device("cuda") + x = torch.randn((128, 64), dtype=torch.bfloat16, device=device) + with pytest.raises(ValueError, match="K % 128"): + nvfp4_per_token_quantize(x, rowwise=True, columnwise=False) + + x2 = torch.randn((64, 128), dtype=torch.bfloat16, device=device) + with pytest.raises(ValueError, match="M % 128"): + nvfp4_per_token_quantize(x2, rowwise=True, columnwise=False) + + +# (3) Dequant + fp32 reference matmul sanity (pure-Python, no kernel). + +@_GATED_FP4 +@pytest.mark.parametrize("M,N", [(32, 64), (64, 256)]) +def test_per_token_dequant_roundtrip_close(M: int, N: int) -> None: + """``dequantize(quantize(x)) ~ x`` at FP4 quantization precision.""" + torch.manual_seed(0x1234) + device = torch.device("cuda") + x = torch.randn((M, N), dtype=torch.float32, device=device) + + ref = NVFP4QuantizerPerTokenRef(rowwise=True).quantize(x) + y = dequantize_nvfp4_per_token(ref.data, ref.scale, ref.row_amax) + + # Loose bound: catches dequant-formula bugs, not quantization quality. + rel = (y - x).abs() / x.abs().clamp(min=1e-6) + assert rel.mean().item() < 0.5, f"mean rel error {rel.mean().item():.3g} > 0.5" + + +# (4) Production GEMM: cuBLAS LT NVFP4 + post-scale composite. +# Shapes need M, N % 128 == 0 and K % 16 == 0 for cuBLAS LT NVFP4. +_GEMM_SHAPES = [ + (128, 128, 128), # smallest legal shape + (128, 128, 256), # exercise K > inner SF window + (256, 128, 256), # non-square (M != N) + (256, 256, 256), # square mid-size +] + + +def _three_pronged_bf16_close( + d_test: torch.Tensor, + d_ref: torch.Tensor, + *, + label: str, + rel_l2_floor: float = 2e-2, + bad_count_ratio: float = 1e-2, + atol: float = 1e-1, + bad_rtol: float = 5e-2, +) -> None: + """Dequant-vs-SUT closeness for random GEMM outputs. + + Three-pronged: energy-weighted rel_l2 (primary), torch.allclose-style + n_bad_mixed (localised faults), max_abs (NaN-like blow-up sanity). + """ + finite_mask = torch.isfinite(d_test) & torch.isfinite(d_ref) + d_t = d_test.float()[finite_mask] + d_r = d_ref.float()[finite_mask] + diff = (d_t - d_r).abs() + n = d_t.numel() + + diff_l2 = float(diff.norm().item()) + ref_l2 = float(d_r.norm().item()) + rel_l2 = diff_l2 / (ref_l2 + 1e-30) + + n_bad_mixed = int((diff > atol + bad_rtol * d_r.abs()).sum().item()) + + max_abs = float(diff.max().item()) if n else float("nan") + mean_ref_abs = float(d_r.abs().mean().item()) if n else float("nan") + max_abs_bound = atol + bad_rtol * mean_ref_abs + + rel = diff / d_r.abs().clamp(min=1e-30) + mean_rel = float(rel.mean().item()) if n else float("nan") + max_rel = float(rel.max().item()) if n else float("nan") + + diag = ( + f"[{label}] N_finite={n}/{int(finite_mask.numel())} " + f"rel_l2={rel_l2:.3g} max_abs={max_abs:.3g} n_bad_mixed={n_bad_mixed} " + f"mean_|d_ref|={mean_ref_abs:.3g} " + f"(diag: mean_rel={mean_rel:.3g} max_rel={max_rel:.3g} " + f"— mean_rel/max_rel are NOT asserted; see helper docstring)" + ) + print(diag) + + bad_count_abs_floor = max(8, int(bad_count_ratio * n)) + assert rel_l2 <= rel_l2_floor, ( + f"{diag} -> rel_l2 > {rel_l2_floor} (energy-weighted global " + f"relative error too high — possible structural bug)" + ) + assert n_bad_mixed <= bad_count_abs_floor, ( + f"{diag} -> n_bad_mixed > {bad_count_abs_floor} " + f"(|diff| > atol={atol} + rtol={bad_rtol} * |d_r| for too " + f"many elements — possible localised broken row/col)" + ) + assert max_abs <= max_abs_bound, ( + f"{diag} -> max_abs > {max_abs_bound:.3g} = atol + " + f"bad_rtol * mean_|d_ref| (worst element is way outside the " + f"noise envelope — possible NaN-like blow-up)" + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,N,K", _GEMM_SHAPES) +def test_per_token_gemm_close_to_bf16(M: int, N: int, K: int) -> None: + """End-to-end per_token_gemm is structurally close to BF16 GEMM. + + Uses cos_sim + magnitude-ratio (direction + magnitude) instead of + per-element mean_rel, which is pathological on random GEMM outputs. + """ + torch.manual_seed(0xACE * M + K) + device = torch.device("cuda") + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) + + a_q = nvfp4_per_token_quantize(a, rowwise=True) + b_q = nvfp4_per_token_quantize(b, rowwise=True) + + d_sut = nvfp4_per_token_gemm( + a_q.data, a_q.scale, a_q.row_amax, + b_q.data, b_q.scale, b_q.row_amax, + ) + + d_ref = (a.float() @ b.float().t()).to(torch.bfloat16) + + d_sut_f = d_sut.float().flatten() + d_ref_f = d_ref.float().flatten() + + sut_norm = d_sut_f.norm() + ref_norm = d_ref_f.norm() + cos_sim = float((d_sut_f @ d_ref_f) / (sut_norm * ref_norm + 1e-30)) + mag_ratio = float(sut_norm / (ref_norm + 1e-30)) + + # cos_sim >= 0.95 catches operand swap; mag in [0.7, 1.3] catches + # missing/duplicated scale or wrong alpha-by-constant. + cos_sim_floor = 0.95 + mag_lo, mag_hi = 0.7, 1.3 + + diag = ( + f"[per_token({M}x{N}x{K})] cos_sim={cos_sim:.4f} " + f"mag_ratio={mag_ratio:.4f} " + f"||d_sut||={float(sut_norm):.4g} ||d_ref||={float(ref_norm):.4g}" + ) + assert cos_sim >= cos_sim_floor, ( + f"{diag} -> cos_sim < {cos_sim_floor} (structural mismatch; " + f"likely wrong operand swap, missing scale, or indexing bug)" + ) + assert mag_lo <= mag_ratio <= mag_hi, ( + f"{diag} -> mag_ratio not in [{mag_lo}, {mag_hi}] " + f"(systematic magnitude error; check alpha/post-scale)" + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,N,K", _GEMM_SHAPES) +def test_per_token_gemm_close_to_dequant_ref(M: int, N: int, K: int) -> None: + """End-to-end per_token_gemm close to dequant + fp32 matmul (TF32 envelope).""" + torch.manual_seed(0xDEAD * (M + 7) + (N + 1) * K) + device = torch.device("cuda") + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) * 0.5 + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) * 0.5 + + a_q = nvfp4_per_token_quantize(a, rowwise=True) + b_q = nvfp4_per_token_quantize(b, rowwise=True) + + d_sut = nvfp4_per_token_gemm( + a_q.data, a_q.scale, a_q.row_amax, + b_q.data, b_q.scale, b_q.row_amax, + ).float() + + d_ref = nvfp4_per_token_gemm_dequant( + a_q.data, a_q.scale, a_q.row_amax, + b_q.data, b_q.scale, b_q.row_amax, + out_dtype=torch.float32, + ) + + _three_pronged_bf16_close( + d_sut, d_ref, + label=f"vs_dequant({M}x{N}x{K})", + # Empirical rel_l2 ~5e-3..1.5e-2 on random N(0, 0.5), K=128-256. + rel_l2_floor=2e-2, atol=1e-1, bad_rtol=5e-2, bad_count_ratio=1e-2, + ) + + +@_GATED_SM100 +def test_per_token_gemm_rejects_beta_nonzero() -> None: + """beta != 0 raises until residual handling is added.""" + device = torch.device("cuda") + M, N, K = 128, 128, 128 + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) + a_q = nvfp4_per_token_quantize(a, rowwise=True) + b_q = nvfp4_per_token_quantize(b, rowwise=True) + + with pytest.raises(ValueError, match=r"beta != 0"): + nvfp4_per_token_gemm( + a_q.data, a_q.scale, a_q.row_amax, + b_q.data, b_q.scale, b_q.row_amax, + beta=1.0, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py b/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py new file mode 100644 index 0000000000..3a72616f26 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py @@ -0,0 +1,344 @@ +"""Correctness tests for grouped (multi-tensor) NVFP4 per-token cast. + +The grouped kernel must be byte-equal to a for-loop of single-tensor +calls. Covers composite K1+K2, K1-only, single-split, and many-split. +""" + +from __future__ import annotations + +from typing import List, Optional, Tuple + +import pytest +import torch + +# Import transformer_engine first to dlopen libtransformer_engine.so so that +# transformer_engine_torch can resolve typeinfo / vtable symbols at load time. +import transformer_engine.pytorch as te # noqa: F401 +import transformer_engine_torch as tex # type: ignore # noqa: F401 + +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token import ( + BLOCK_K, + RefNVFP4TensorPerToken, + nvfp4_per_token_quantize, +) + + +def _has_fp4() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 + + +_GATED_FP4 = pytest.mark.skipif( + not _has_fp4(), + reason="NVFP4 per-token cast requires SM100 (Blackwell) + CUDA 12.8+", +) + + +# Helper: invoke the grouped binding. +def _alloc_per_token_buffers( + M_i: int, + K: int, + rowwise: bool, + columnwise: bool, + device: torch.device, +) -> Tuple[ + Optional[torch.Tensor], # q_row + Optional[torch.Tensor], # s_dec_row + Optional[torch.Tensor], # row_amax + Optional[torch.Tensor], # q_col + Optional[torch.Tensor], # s_dec_col + Optional[torch.Tensor], # col_amax +]: + q_row = None + s_dec_row = None + row_amax = None + q_col = None + s_dec_col = None + col_amax = None + if rowwise: + q_row = torch.empty((M_i, K // 2), dtype=torch.uint8, device=device) + s_dec_row = torch.empty((M_i, K // BLOCK_K), dtype=torch.uint8, device=device) + row_amax = torch.empty((M_i,), dtype=torch.float32, device=device) + if columnwise: + q_col = torch.empty((K, M_i // 2), dtype=torch.uint8, device=device) + s_dec_col = torch.empty((K, M_i // BLOCK_K), dtype=torch.uint8, device=device) + col_amax = torch.empty((K,), dtype=torch.float32, device=device) + return q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax + + +def _group_quantize_py( + x_concat: torch.Tensor, + split_sections: List[int], + rowwise: bool, + columnwise: bool, +) -> List[RefNVFP4TensorPerToken]: + """Pre-allocate per-split outputs, dispatch tex.nvfp4_per_token_group_quantize.""" + assert x_concat.dim() == 2 + sum_M, K = x_concat.shape + assert sum(split_sections) == sum_M + device = x_concat.device + + n = len(split_sections) + q_row_list: List[torch.Tensor] = [] + s_dec_row_list: List[torch.Tensor] = [] + row_amax_list: List[torch.Tensor] = [] + q_col_list: List[torch.Tensor] = [] + s_dec_col_list: List[torch.Tensor] = [] + col_amax_list: List[torch.Tensor] = [] + + for M_i in split_sections: + qr, sr, ra, qc, sc, ca = _alloc_per_token_buffers( + M_i, K, rowwise, columnwise, device + ) + if rowwise: + q_row_list.append(qr) + s_dec_row_list.append(sr) + row_amax_list.append(ra) + if columnwise: + q_col_list.append(qc) + s_dec_col_list.append(sc) + col_amax_list.append(ca) + + # Binding wants lists matching num_tensors; pass empty for skipped direction. + empty: List[torch.Tensor] = [] + + tex.nvfp4_per_token_group_quantize( + x_concat, + split_sections, + q_row_list if rowwise else empty, + s_dec_row_list if rowwise else empty, + row_amax_list if rowwise else empty, + q_col_list if columnwise else empty, + s_dec_col_list if columnwise else empty, + col_amax_list if columnwise else empty, + rowwise, + columnwise, + ) + + out: List[RefNVFP4TensorPerToken] = [] + for i in range(n): + # Re-view e4m3 SF as torch.float8_e4m3fn (same bytes, expected dtype). + tensor = RefNVFP4TensorPerToken( + data=q_row_list[i] if rowwise else None, + scale=( + s_dec_row_list[i].view(torch.float8_e4m3fn) if rowwise else None + ), + row_amax=row_amax_list[i] if rowwise else None, + columnwise_data=q_col_list[i] if columnwise else None, + columnwise_scale=( + s_dec_col_list[i].view(torch.float8_e4m3fn) if columnwise else None + ), + col_amax=col_amax_list[i] if columnwise else None, + ) + out.append(tensor) + return out + + +# Test fixtures. Per-token kernel requires M_i % 128 == 0 and K % 128 == 0. +_SHAPES: List[Tuple[List[int], int]] = [ + # (split_sections, K) + ([128], 128), # trivial: 1 split, smallest legal shape + ([128, 128], 128), # 2 equal splits + ([128, 256], 128), # 2 unequal splits + ([128, 256, 128], 256), # 3 splits, mixed sizes + ([128, 128, 128, 128], 256), # 4 equal splits + ([256, 128, 384, 128, 128], 512), # 5-way unequal split, typical MoE + ([256, 256], 1024), # larger K, 2 splits +] + + +# (1) Composite K1+K2: grouped == for-loop of single-tensor, byte-equal. +@_GATED_FP4 +@pytest.mark.parametrize("split_sections,K", _SHAPES) +@pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_group_per_token_quantize_byte_equal( + split_sections: List[int], + K: int, + rowwise: bool, + columnwise: bool, + dtype: torch.dtype, +) -> None: + """Grouped == for-loop of single-tensor, byte-equal (FP4 + SF + amax).""" + torch.manual_seed(0xCAFE * (sum(split_sections) + 7) + K) + device = torch.device("cuda") + sum_M = sum(split_sections) + + # Per-split inputs with sprinkled outliers to stress per-row outer. + splits_in: List[torch.Tensor] = [] + for i, M_i in enumerate(split_sections): + s = torch.randn((M_i, K), dtype=dtype, device=device) * (2.0 + 0.5 * i) + if M_i >= 4: + s[0, :] *= 8.0 + s[-1, :] *= 0.125 + splits_in.append(s) + + x_concat = torch.cat(splits_in, dim=0) + assert x_concat.shape == (sum_M, K) + + oracle: List[RefNVFP4TensorPerToken] = [ + nvfp4_per_token_quantize(s, rowwise=rowwise, columnwise=columnwise) + for s in splits_in + ] + + sut: List[RefNVFP4TensorPerToken] = _group_quantize_py( + x_concat, split_sections, rowwise=rowwise, columnwise=columnwise + ) + + assert len(sut) == len(oracle) == len(split_sections) + + for i in range(len(split_sections)): + if rowwise: + torch.testing.assert_close( + sut[i].data.view(torch.uint8), + oracle[i].data.view(torch.uint8), + atol=0.0, rtol=0.0, + msg=f"rowwise q[{i}] mismatch", + ) + torch.testing.assert_close( + sut[i].scale.view(torch.uint8), + oracle[i].scale.view(torch.uint8), + atol=0.0, rtol=0.0, + msg=f"rowwise s_dec[{i}] mismatch", + ) + torch.testing.assert_close( + sut[i].row_amax, oracle[i].row_amax, atol=0.0, rtol=0.0, + msg=f"row_amax[{i}] mismatch", + ) + if columnwise: + torch.testing.assert_close( + sut[i].columnwise_data.view(torch.uint8), + oracle[i].columnwise_data.view(torch.uint8), + atol=0.0, rtol=0.0, + msg=f"columnwise q[{i}] mismatch", + ) + torch.testing.assert_close( + sut[i].columnwise_scale.view(torch.uint8), + oracle[i].columnwise_scale.view(torch.uint8), + atol=0.0, rtol=0.0, + msg=f"columnwise s_dec[{i}] mismatch", + ) + torch.testing.assert_close( + sut[i].col_amax, oracle[i].col_amax, atol=0.0, rtol=0.0, + msg=f"col_amax[{i}] mismatch", + ) + + +# (2) K1-only (amax) entry == K1-only of single-tensor, byte-equal. +@_GATED_FP4 +@pytest.mark.parametrize("split_sections,K", _SHAPES[:3]) # subset, K1 is simple +@pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)]) +def test_group_per_token_amax_byte_equal( + split_sections: List[int], + K: int, + rowwise: bool, + columnwise: bool, +) -> None: + """tex.nvfp4_per_token_group_amax matches K1 of the for-loop variant.""" + torch.manual_seed(0xDEAD * sum(split_sections) + K) + device = torch.device("cuda") + sum_M = sum(split_sections) + n = len(split_sections) + + splits_in: List[torch.Tensor] = [] + for i, M_i in enumerate(split_sections): + splits_in.append(torch.randn((M_i, K), dtype=torch.bfloat16, device=device) * 3.0) + x_concat = torch.cat(splits_in, dim=0) + + # Oracle row_amax / col_amax via single-tensor quantize (shared K1). + oracle_row = [] + oracle_col = [] + for s in splits_in: + o = nvfp4_per_token_quantize(s, rowwise=rowwise, columnwise=columnwise) + oracle_row.append(o.row_amax if rowwise else None) + oracle_col.append(o.col_amax if columnwise else None) + + row_amax_list = [ + torch.empty((M_i,), dtype=torch.float32, device=device) for M_i in split_sections + ] if rowwise else [] + col_amax_list = [ + torch.empty((K,), dtype=torch.float32, device=device) for _ in range(n) + ] if columnwise else [] + + tex.nvfp4_per_token_group_amax( + x_concat, split_sections, row_amax_list, col_amax_list, rowwise, columnwise + ) + + if rowwise: + for i in range(n): + torch.testing.assert_close( + row_amax_list[i], oracle_row[i], atol=0.0, rtol=0.0, + msg=f"row_amax[{i}] mismatch", + ) + if columnwise: + for i in range(n): + torch.testing.assert_close( + col_amax_list[i], oracle_col[i], atol=0.0, rtol=0.0, + msg=f"col_amax[{i}] mismatch", + ) + + +# (3) Single-split call must equal the single-tensor kernel. +@_GATED_FP4 +@pytest.mark.parametrize("M,K", [(128, 128), (128, 256), (256, 1024)]) +@pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)]) +def test_group_single_split_matches_single_tensor( + M: int, K: int, rowwise: bool, columnwise: bool +) -> None: + """One-split grouped call == single-tensor call (boundary-advance no-op).""" + torch.manual_seed(0xBABE * M + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) * 4.0 + + oracle = nvfp4_per_token_quantize(x, rowwise=rowwise, columnwise=columnwise) + sut_list = _group_quantize_py(x, [M], rowwise=rowwise, columnwise=columnwise) + assert len(sut_list) == 1 + sut = sut_list[0] + + if rowwise: + torch.testing.assert_close(sut.data, oracle.data, atol=0.0, rtol=0.0) + torch.testing.assert_close( + sut.scale.view(torch.uint8), oracle.scale.view(torch.uint8), + atol=0.0, rtol=0.0, + ) + torch.testing.assert_close(sut.row_amax, oracle.row_amax, atol=0.0, rtol=0.0) + if columnwise: + torch.testing.assert_close( + sut.columnwise_data, oracle.columnwise_data, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sut.columnwise_scale.view(torch.uint8), + oracle.columnwise_scale.view(torch.uint8), + atol=0.0, rtol=0.0, + ) + torch.testing.assert_close(sut.col_amax, oracle.col_amax, atol=0.0, rtol=0.0) + + +# (4) Many-split scaling test (close to the 64-tensor cap). +@_GATED_FP4 +@pytest.mark.parametrize("n_splits", [8, 16, 32, 64]) +def test_group_many_splits_byte_equal(n_splits: int) -> None: + """Many small splits (MoE expert layout) still byte-equal to oracle.""" + torch.manual_seed(0xFEED * n_splits) + device = torch.device("cuda") + K = 256 + split_sections = [128] * n_splits + + splits_in = [ + torch.randn((128, K), dtype=torch.bfloat16, device=device) * (1.0 + 0.1 * i) + for i in range(n_splits) + ] + x_concat = torch.cat(splits_in, dim=0) + + oracle = [nvfp4_per_token_quantize(s, rowwise=True, columnwise=True) for s in splits_in] + sut = _group_quantize_py(x_concat, split_sections, rowwise=True, columnwise=True) + + for i in range(n_splits): + torch.testing.assert_close(sut[i].data, oracle[i].data, atol=0.0, rtol=0.0) + torch.testing.assert_close(sut[i].row_amax, oracle[i].row_amax, atol=0.0, rtol=0.0) + torch.testing.assert_close( + sut[i].columnwise_data, oracle[i].columnwise_data, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(sut[i].col_amax, oracle[i].col_amax, atol=0.0, rtol=0.0) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 06d85b6d84..5cdd255b08 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -240,8 +240,11 @@ list(APPEND transformer_engine_cuda_arch_specific_sources multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu recipe/nvfp4.cu - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise_fp4.cu) + cast/nvfp4/quantize_nvfp4_per_token.cu + cast/nvfp4/quantize_nvfp4_per_token_group.cu + gemm/nvfp4_per_token_post_scale.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu) # Compiling the files with the worst compilation time first to hopefully overlap # better with the faster-compiling cpp files diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu new file mode 100644 index 0000000000..57b8232c92 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu @@ -0,0 +1,1025 @@ +/************************************************************************* + * Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4_per_token.cu + * \brief NVFP4 per-token cast on the bf16 fast path: + * TMA + mbarrier + 64x64 sub-tile + 2-buffer ping-pong. + * + * Pipeline structure mirrors the per-tensor cast kernel + * (``quantize_transpose_nvfp4_tuned_1D_kernel``) and the RHT + * amax kernel (``HadamardAmaxTmaKernel``): + * + * * Two-kernel design (amax pass + encode pass). Output of amax fed + * into encode via per-row/per-col buffers in ``output->amax`` + * and ``output->columnwise_amax`` (sized [M] and [K] respectively). + * * Each CTA covers a 128x128 chunk decomposed as 4 sequential 64x64 + * sub-tiles, double-buffered. Each sub-tile is one TMA bulk-2D + * tensor transaction. mbarrier expect_tx + parity wait gives + * one-iteration-overlap between HBM and compute. + * * Encode pass reads the input tile ONCE into SMEM, then dispatches + * both the rowwise (FP4 + per-row scale) and the columnwise (FP4 + * transpose + per-col scale) outputs from that same staged copy. + * Outer scaling factors S_enc are loaded from + * ``row_amax_in[M]`` / ``col_amax_in[K]`` once per CTA into a small + * SMEM cache (1 KiB total). + */ + +#include + +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "common/cast/core/common.cuh" +#include "common/cast/nvfp4/core_nvfp4.cuh" + +namespace transformer_engine { +namespace nvfp4_per_token { + +#if FP4_TYPE_SUPPORTED + +using dispatch::common::align_smem_ptr_per_TMA_requirements; +using dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; +using dispatch::nvfp4::quantization_SF::compute_decoding_scaling_factor; +using dispatch::nvfp4::nvfp4_scale_t; + +constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows of input +constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols of input +constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height +constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width +constexpr int THREADS_NUM = 128; // threads per CTA +constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM +constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) +constexpr int PREFETCH_STAGES = 1; // 1-stage prefetch overlap +constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; // = 2 ping-pong input buffers + +// Derived (chunk / tile / stage) +constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; // 2 +constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X; // 2 +constexpr int STAGES = TILES_Y * TILES_X; // 4 + +constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; // 8 inner blocks per row of the chunk +constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; // 8 inner blocks per col of the chunk +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 + +// Encode helpers' thread layout (rowwise pass: 4x32 = K-dim x M-dim) +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 (each block owned by 1 thread) +constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 + +// Encode helpers' thread layout (colwise pass: tid.X for col, warp for M-block) +constexpr int THREADS_X_TR = TILE_DIM_X / 2; // 32 cols per warp +constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; // 4 (warps) + +// Buffer dimensions (input bf16 SMEM tiles + FP4 output SMEM tiles for TMA store) +constexpr int BUFF_IN_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_IN_DIM_X = TILE_DIM_X; +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; // elements +constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (2 fp4 per byte) +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; +constexpr int BUFF_OUT_TR_DIM_Y = TILE_DIM_X; +constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 ping-pong (matches input) +constexpr int BUFFS_NUM_OUT_TR = 2; // 2 ping-pong for transpose + +// Manual swizzling parameters to reduce SMEM bank conflicts on rowwise loads +constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 + +using IType = bf16; +using IType2 = ptx::FPx2; // = ptx::bf16x2 +using IType3D = IType [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +using IType2x3D = IType2 [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; + +// Compute the per-block (1x16) byte-equal arithmetic and emit FP4 codes into +// SMEM rowwise output buffer + e4m3 scale into SMEM rowwise scale buffer. +__device__ __forceinline__ void rowwise_scaling_per_token( + const IType* __restrict__ sIn_ptr, + fp4e2m1x2* __restrict__ sOut_ptr, + nvfp4_scale_t* __restrict__ sSFrowwise_ptr, + const float* __restrict__ sRowAmax, // [CHUNK_DIM_Y], indexed by chunk-local row + const int stage_Y, const int stage_X, + const int buff_in, const int buff_out) { + const auto& sIn = *reinterpret_cast(sIn_ptr); + auto& sOut = *reinterpret_cast(sOut_ptr); + auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 + + const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; // K-elt offset in tile (0/16/32/48) + + const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; // = tid_X_rowwise here + const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); + + const int stage_rowwise_scales_offset_X = + SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; + +#pragma unroll + for (int it = 0; it < ITERATIONS_NORMAL; ++it) { + const int it_offset_Y_rowwise = tid_Y_rowwise + it * THREADS_Y_ROWWISE; // 0..63 over 2 iters + const int chunk_local_row = stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; // 0..127 + + // Per-row S_enc (look up from CTA-cached row amax buffer) + const float row_amax = sRowAmax[chunk_local_row]; + const float S_enc = compute_global_encode_scaling_factor_FP4(fmaxf(row_amax, 1e-12f)); + + __align__(16) IType2 rIn[WAVES][PACK_SIZE / 2]; + + // Read 16 elements (in PACK_SIZE=8 waves), swizzled to avoid bank conflicts, + // and reduce to a 1x16 block amax. + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + + __uint128_t& elts_8x = *reinterpret_cast<__uint128_t*>(&rIn[w]); + elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); + } + } + const float block_amax = static_cast( + __hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + + // Byte-equal compute path (matches the Python reference in + // ``NVFP4QuantizerPerTokenRef``): + const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax, S_enc); + const float s_dec_f = static_cast(s_dec); + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc, s_dec_f); + + // Store e4m3 scale to SMEM SF buffer (1 thread per 1x16 block stores). + if (SF_storing_thread) { + const int scales_offset_Y = chunk_local_row; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = s_dec; + } + + // Cast 16 elements to FP4 using mul_cvt_4x (4 elements per call, the + // byte-equal path against the Python reference). We've already pre-loaded + // into rIn[WAVES][4]. + // WAVES = 2, PACK_SIZE/2 = 4 elements per wave + // Total per iteration: 2 waves * (4 IType2 elts) = 16 elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + + // 4 fp4 quads from 8 bf16 elements (in PACK_SIZE=8 waves): + // rIn[w][0..3] = 4 IType2 pairs = 8 elements. + // Each mul_cvt_4x packs 4 elements; we need 2 calls per wave. + fp4e2m1x4 qu0{}, qu1{}; + ptx::mul_cvt_4x(qu0, rIn[w][0], rIn[w][1], block_scale); + ptx::mul_cvt_4x(qu1, rIn[w][2], rIn[w][3], block_scale); + + // Pack into a 32-bit word and store to SMEM out (b32 store) + uint32_t out_x8 = (static_cast(*reinterpret_cast(&qu0))) | + (static_cast(*reinterpret_cast(&qu1)) << 16); + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + } + } +} + +// Compute the per-block (1x16, along M) byte-equal arithmetic for the columnwise +// pass; emit transposed FP4 + e4m3 scale into SMEM. +__device__ __forceinline__ void colwise_scaling_per_token( + const IType* __restrict__ sIn_ptr, + fp4e2m1x2* __restrict__ sOut_tr_ptr, + nvfp4_scale_t* __restrict__ sSFcolwise_ptr, + const float* __restrict__ sColAmax, // [CHUNK_DIM_X], indexed by chunk-local col + const int stage_Y, const int stage_X, + const int buff_in, const int buff_out_tr) { + const auto& sIn2x = *reinterpret_cast(sIn_ptr); + auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); + auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 (M-block index in tile) + const int tid_X_colwise = thread_lane; // 0..31 (col-pair index in tile) + + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; // 0/16/32/48 + const int thread_offset_X_colwise = tid_X_colwise * 2; // 0/2/.../62 (2 cols per thread) + + const int in_thread_offset_Y = thread_offset_Y_colwise; + const int in_thread_offset_X = thread_offset_X_colwise / 2; // index into IType2[] + + const int out_tr_thread_offset_Y = thread_offset_X_colwise; // transpose: X becomes Y + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; // /2 for fp4e2m1x2 byte index + + const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; // chunk-local col index (×1) + const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; // chunk-local M-block index + + __align__(8) IType rIn[2][SCALE_DIM]; + // Read 2 columns x 16 rows, accumulate per-column amax. + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const IType2 elt_pair = + ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); + rIn[0][i] = elt_pair.x; + rIn[1][i] = elt_pair.y; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + } + // NOTE: thread_amax_2x.x is the amax of column .x; thread_amax_2x.y is amax of column .y. + const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), + static_cast(__habs(thread_amax_2x.y))}; + +#pragma unroll + for (int w = 0; w < 2; ++w) { + // Per-col S_enc lookup (each of the 2 cols this thread owns has its own amax/S_enc). + const int chunk_local_col = scale_tr_offset_Y + w; + const float col_amax = sColAmax[chunk_local_col]; + const float S_enc_col = compute_global_encode_scaling_factor_FP4(fmaxf(col_amax, 1e-12f)); + + const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax[w], S_enc_col); + const float s_dec_f = static_cast(s_dec); + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc_col, s_dec_f); + + // Store e4m3 scale to SMEM colwise SF buffer. + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = s_dec; + + // Cast 16 elements to FP4 via 4x mul_cvt_4x (4 elements per call -> 4 calls). + // The 16 rIn[w][...] values are bf16; pack into IType2 pairs. + fp4e2m1x4 qu[4]; +#pragma unroll + for (int e = 0; e < 4; ++e) { + IType2 in01{rIn[w][4 * e + 0], rIn[w][4 * e + 1]}; + IType2 in23{rIn[w][4 * e + 2], rIn[w][4 * e + 3]}; + ptx::mul_cvt_4x(qu[e], in01, in23, block_scale); + } + + // Pack 4 fp4e2m1x4 (= 16 fp4) into a 64-bit value and store to SMEM transpose buffer. + uint64_t out_pack_16x = (static_cast(*reinterpret_cast(&qu[0])) << 0) | + (static_cast(*reinterpret_cast(&qu[1])) << 16) | + (static_cast(*reinterpret_cast(&qu[2])) << 32) | + (static_cast(*reinterpret_cast(&qu[3])) << 48); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); + } +} + +// ============================================================================= +// Kernel 2: per-token encode (rowwise + optional colwise transpose). +// ============================================================================= +template +__global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t* const scales_ptr, + nvfp4_scale_t* const scales_t_ptr, + const float* const row_amax_in, + const float* const col_amax_in, + const float* noop, + const size_t rows, const size_t cols, + const size_t scale_stride, const size_t scale_stride_t) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const bool leading_thread = (threadIdx.x == 0); + + // ------------------------------------------------------------------------- + // Dynamic SMEM layout + // sIn: 2 buffers x (64x64 bf16) = 16 KiB + // sOut: 2 buffers x (64x32 fp4 packed) = 4 KiB (rowwise FP4) + // sOut_tr: 2 buffers x (64x32 fp4 packed) = 4 KiB (colwise FP4) + // sSFrowwise: 128 x 8 e4m3 = 1 KiB + // sSFcolwise: 128 x 8 e4m3 = 1 KiB + // sRowAmax: 128 fp32 = 512 B + // sColAmax: 128 fp32 = 512 B + // IN_buff_readable_mbar: 2 x 8 B = 16 B + // Total: ~27 KiB + alignment padding. + // ------------------------------------------------------------------------- + constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int out_mem_rowwise_data = DO_ROW ? buff_size_aligned_out : 0; + constexpr int out_mem_colwise_data = DO_COL ? buff_size_aligned_out_t : 0; + constexpr int out_mem_rowwise_scales = + DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) : 0; + constexpr int out_mem_colwise_scales = + DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) : 0; + + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char* dshmem = align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + IType* sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); + fp4e2m1x2* sOut_tr_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data); + + nvfp4_scale_t* sSFrowwise_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t* sSFcolwise_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data + + out_mem_rowwise_scales); + + // Per-CTA row/col amax SMEM cache (128 floats each). + __shared__ float sRowAmax[CHUNK_DIM_Y]; + __shared__ float sColAmax[CHUNK_DIM_X]; + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + + auto& sIn = *reinterpret_cast(sIn_ptr); + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const int32_t ctaid_X = blockIdx.x; + const int32_t ctaid_Y = blockIdx.y; + const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * CHUNK_DIM_X; + // Transpose-output block offsets: row-CTA(X) -> col-tensor's M; col-CTA(Y) -> col-tensor's N. + const int block_offset_Y_tr = ctaid_X * CHUNK_DIM_X; + const int block_offset_X_tr = ctaid_Y * CHUNK_DIM_Y; + + const int scales_block_offset_Y_rowwise = ctaid_Y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = ctaid_X * SCALES_PER_CHUNK_X; + const int scales_block_offset_Y_tr = ctaid_X * CHUNK_DIM_X; + const int scales_block_offset_X_tr = ctaid_Y * SCALES_PER_CHUNK_Y; + + // Load per-row / per-col amax into SMEM cache (cooperative, full chunk = 128 entries each). + if (DO_ROW && threadIdx.x < CHUNK_DIM_Y) { + sRowAmax[threadIdx.x] = row_amax_in[block_offset_Y + threadIdx.x]; + } + if (DO_COL && threadIdx.x < CHUNK_DIM_X) { + sColAmax[threadIdx.x] = col_amax_in[block_offset_X + threadIdx.x]; + } + + // Initialize mbarriers. + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + // Prefetch stage 0 (one-iteration overlap throughout main loop). +#pragma unroll + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const int buff_in = stage; + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + const int global_offset_Y = block_offset_Y + stage_Y * TILE_DIM_Y; + const int global_offset_X = block_offset_X + stage_X * TILE_DIM_X; + if (leading_thread) { + uint64_t* dst = reinterpret_cast(&sIn[buff_in]); + const uint64_t* src = reinterpret_cast(&tensor_map_input); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + dst, src, global_offset_X, global_offset_Y, &IN_buff_readable_mbar[buff_in]); + } + } + + int buff_in = 0; + int buff_out = 0; + int buff_out_tr = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + + // Prefetch next stage's input (skip after the second-to-last stage). + if (stage < STAGES - PREFETCH_STAGES) { + const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const int next_stage_Y = next_prefetch_stage / TILES_X; + const int next_stage_X = next_prefetch_stage % TILES_X; + const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; + const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; + + if (leading_thread) { + uint64_t* dst = reinterpret_cast(&sIn[next_prefetch_buff]); + const uint64_t* src = reinterpret_cast(&tensor_map_input); + ptx::mbarrier_arrive_expect_tx( + &IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + dst, src, next_global_offset_X, next_global_offset_Y, + &IN_buff_readable_mbar[next_prefetch_buff]); + } + ptx::fence_proxy_async_shared_cta(); + } + + // Wait for current stage's input to land. + ptx::mbarrier_wait_parity_acquire_cta_shared_cta( + &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + + // Wait for any prior TMA store to have finished reading the output SMEM + // buffers (so we can overwrite them). + ptx::cp_async_bulk_wait_group_read(); + + // ----- Compute: rowwise + colwise from the same SMEM tile ----- + if (DO_ROW) { + rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, + sRowAmax, stage_Y, stage_X, buff_in, buff_out); + } + if (DO_COL) { + colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, + sColAmax, stage_Y, stage_X, buff_in, buff_out_tr); + } + + // Fence + sync so all threads' SMEM writes are visible to TMA store. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + // Issue TMA store(s) for this stage's outputs. + if (leading_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X + stage_offset_X; + const int global_offset_Y_tr = block_offset_Y_tr + stage_offset_X; + const int global_offset_X_tr = block_offset_X_tr + stage_offset_Y; + + if (DO_ROW) { + auto& sOut = *reinterpret_cast(sOut_ptr); + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), + global_offset_X, global_offset_Y, + reinterpret_cast(&sOut[buff_out])); + } + if (DO_COL) { + auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), + global_offset_X_tr, global_offset_Y_tr, + reinterpret_cast(&sOut_tr[buff_out_tr])); + } + ptx::cp_async_bulk_commit_group(); + } + + buff_in = (buff_in + 1) % BUFFS_NUM; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; + } // end of stages + + // Vectorized SF scatter to global (chunk-end batch). Mirrors the + // production tuned 1D scale-store epilogue. + if (DO_ROW) { + auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + using ScalesVec = Vec; + const int chunk_cols = static_cast(cols) - block_offset_X; + const int count = min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM); + + for (size_t row = threadIdx.x; row < CHUNK_DIM_Y; row += THREADS_NUM) { + const size_t row_global = scales_block_offset_Y_rowwise + row; + if (row_global < rows) { + ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); + const size_t scale_idx_global = + row_global * scale_stride + scales_block_offset_X_rowwise; + scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); + } + } + } + if (DO_COL) { + auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + using ScalesVec = Vec; + const int chunk_rows = static_cast(rows) - block_offset_Y; + const int count = min(SCALES_PER_CHUNK_Y, chunk_rows / SCALE_DIM); + + for (size_t row_tr = threadIdx.x; row_tr < CHUNK_DIM_X; row_tr += THREADS_NUM) { + const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; + if (row_tr_global < cols) { + ScalesVec& scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); + const size_t scale_idx_global = + row_tr_global * scale_stride_t + scales_block_offset_X_tr; + scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count); + } + } + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + } +#else + NVTE_DEVICE_ERROR("Per-token encode kernel requires SM 10.0+ (Blackwell)."); +#endif // __CUDA_ARCH__ >= 1000 +} + +// ============================================================================= +// Kernel 1: per-token amax (rowwise + colwise atomicMaxFloat). +// +// Same TMA + mbarrier + 64x64 sub-tile + ping-pong pipeline as the encode +// kernel above, just with compute = abs + reduce instead of FP4 encode. +// +// Compute mapping (one thread per output slot): +// tid t in [0, 128): +// row partial: max over (cols 0..127) for row (row_base + t) +// col partial: max over (rows 0..127) for col (col_base + t) +// For each 64x64 sub-tile in stage (stage_Y, stage_X): +// if t in [stage_Y*64, stage_Y*64+64): scan 64 cols of sub-tile for row t +// if t in [stage_X*64, stage_X*64+64): scan 64 rows of sub-tile for col t +// After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot. +// ============================================================================= +template +__global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + float* __restrict__ row_amax_out, // [M], nullptr if !DO_ROW + float* __restrict__ col_amax_out, // [K], nullptr if !DO_COL + const float* noop, + const size_t rows, const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const bool leading_thread = (threadIdx.x == 0); + const int tid = threadIdx.x; + + constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char* dshmem = align_smem_ptr_per_TMA_requirements(dynamic_shmem); + IType* sIn_ptr = reinterpret_cast(dshmem); + auto& sIn = *reinterpret_cast(sIn_ptr); + + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const int32_t ctaid_X = blockIdx.x; + const int32_t ctaid_Y = blockIdx.y; + const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * CHUNK_DIM_X; + + // Per-thread row & col partial accumulators (each thread owns 1 of each). + float row_partial = 0.f; + float col_partial = 0.f; + + // Which row / col does THIS thread own within the 128x128 chunk? + // row owned: row_base + tid -> needs sub-tile rows [stage_Y*64, +64) + // i.e., this thread contributes to row partial in stages + // where stage_Y == tid / 64. + // col owned: col_base + tid -> stage_X == tid / 64. + const int my_row_stage_Y = tid / TILE_DIM_Y; // 0 or 1 + const int my_col_stage_X = tid / TILE_DIM_X; // 0 or 1 + const int my_row_in_subtile = tid % TILE_DIM_Y; // 0..63 + const int my_col_in_subtile = tid % TILE_DIM_X; // 0..63 + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + // Prefetch stage 0. +#pragma unroll + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const int buff_in = stage; + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + const int global_offset_Y = block_offset_Y + stage_Y * TILE_DIM_Y; + const int global_offset_X = block_offset_X + stage_X * TILE_DIM_X; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&sIn[buff_in]), + reinterpret_cast(&tensor_map_input), + global_offset_X, global_offset_Y, &IN_buff_readable_mbar[buff_in]); + } + } + + int buff_in = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + + // Prefetch next stage. + if (stage < STAGES - PREFETCH_STAGES) { + const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const int next_stage_Y = next_prefetch_stage / TILES_X; + const int next_stage_X = next_prefetch_stage % TILES_X; + const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; + const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx( + &IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&sIn[next_prefetch_buff]), + reinterpret_cast(&tensor_map_input), + next_global_offset_X, next_global_offset_Y, + &IN_buff_readable_mbar[next_prefetch_buff]); + } + ptx::fence_proxy_async_shared_cta(); + } + + // Wait for this stage's tile. + ptx::mbarrier_wait_parity_acquire_cta_shared_cta( + &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + + // ----- Row partial update: walk this thread's row across the sub-tile ----- + if (DO_ROW && stage_Y == my_row_stage_Y) { + // 32 warp lanes each own a distinct row but read col-offset e in lockstep; + // SMEM row stride is 64*sizeof(bf16) = 128 B = exactly 32 banks, so every + // lane lands on the same bank set -> 32-way bank conflict per LDS.128. + // Rotate the e-iter visit order by (my_row_in_subtile >> 2) so that lanes + // in distinct row-quads pick distinct e values per iter, splitting the + // warp into 8 disjoint bank groups (4-way conflict, 8x reduction). + // Per-thread data set unchanged; max() is associative & commutative => byte-equal. + float local_max = row_partial; + const int row_bank_group = (my_row_in_subtile >> 2) & 0x7; +#pragma unroll + for (int e_iter = 0; e_iter < 8; ++e_iter) { + const int e = ((e_iter + row_bank_group) & 0x7) << 3; + __uint128_t elts_8x = ptx::ld_shared_b128(&sIn[buff_in][my_row_in_subtile][e]); + const IType2* pairs = reinterpret_cast(&elts_8x); + IType2 amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int p = 0; p < 4; ++p) { + ptx::abs_max_2x(amax_2x, amax_2x, pairs[p]); + } + local_max = fmaxf(local_max, + static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); + } + row_partial = local_max; + } + + // ----- Col partial update: walk this thread's col down the sub-tile ----- + if (DO_COL && stage_X == my_col_stage_X) { + // Scan 64 rows for our col. Single-column access pattern (1 byte stride + // per row in SMEM); we read 1 bf16 at a time. Bank conflicts mitigated + // by 64-wide tile (column stride = TILE_DIM_X * 2 = 128 bytes, which is + // 1 bank * 32 rows; with 32 threads on different cols, conflicts hit + // groups of 32 -> serialized 32-way, accepted for v1). + float local_max = col_partial; +#pragma unroll + for (int e = 0; e < TILE_DIM_Y; ++e) { + const IType v = sIn[buff_in][e][my_col_in_subtile]; + local_max = fmaxf(local_max, fabsf(static_cast(v))); + } + col_partial = local_max; + } + + __syncthreads(); + buff_in = (buff_in + 1) % BUFFS_NUM; + } + + // ----- Cross-CTA reduction: 1 atomicMaxFloat per row/col slot per CTA ----- + if (DO_ROW) { + atomicMaxFloat(&row_amax_out[block_offset_Y + tid], row_partial); + } + if (DO_COL) { + atomicMaxFloat(&col_amax_out[block_offset_X + tid], col_partial); + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + } +#else + NVTE_DEVICE_ERROR("Per-token amax kernel requires SM 10.0+ (Blackwell)."); +#endif // __CUDA_ARCH__ >= 1000 +} + +#endif // FP4_TYPE_SUPPORTED (closes the kernels block opened at line 69) + +// ============================================================================= +// Launchers +// ============================================================================= + +#if FP4_TYPE_SUPPORTED +// Launch Kernel 1 (amax). Writes only to output->amax / output->columnwise_amax; +// other output fields untouched. Pre-zeroes the amax buffers (atomicMax identity). +inline void launch_amax(const Tensor& input, Tensor* output, + const Tensor& noop, cudaStream_t stream) { + const size_t M = input.flat_first_dim(); + const size_t K = input.flat_last_dim(); + + const bool do_row = (output->amax.dptr != nullptr); + const bool do_col = (output->columnwise_amax.dptr != nullptr); + if (!do_row && !do_col) return; + + // Pre-zero amax buffers (atomicMaxFloat identity for non-negative values). + if (do_row) { + NVTE_CHECK(output->amax.numel() == M, + "Per-token amax: output->amax numel must equal M = ", M, + ", got ", output->amax.numel()); + NVTE_CHECK_CUDA(cudaMemsetAsync(output->amax.dptr, 0, M * sizeof(float), stream)); + } + if (do_col) { + NVTE_CHECK(output->columnwise_amax.numel() == K, + "Per-token amax: output->columnwise_amax numel must equal K = ", K, + ", got ", output->columnwise_amax.numel()); + NVTE_CHECK_CUDA(cudaMemsetAsync(output->columnwise_amax.dptr, 0, K * sizeof(float), stream)); + } + + checkCuDriverContext(stream); + + alignas(64) CUtensorMap tmap_in{}; + create_2D_tensor_map(tmap_in, input.data, M, K, + TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); + + constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; // + align pad + + dim3 grid(static_cast(K / CHUNK_DIM_X), + static_cast(M / CHUNK_DIM_Y), 1); + dim3 block(THREADS_NUM, 1, 1); + + const float* noop_ptr = (noop.data.dptr != nullptr) + ? reinterpret_cast(noop.data.dptr) + : nullptr; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { + auto kernel = per_token_amax_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tmap_in, + do_row ? reinterpret_cast(output->amax.dptr) : nullptr, + do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr, + noop_ptr, M, K); + });); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Launch Kernel 2 (encode). Requires output->amax / columnwise_amax to be pre-filled +// (by a prior launch_amax call or by an external caller); writes +// output->data / scale_inv / columnwise_data / columnwise_scale_inv. +inline void launch_encode(const Tensor& input, Tensor* output, + const Tensor& noop, cudaStream_t stream) { + const size_t M = input.flat_first_dim(); + const size_t K = input.flat_last_dim(); + + const bool do_row = output->has_data(); + const bool do_col = output->has_columnwise_data(); + if (!do_row && !do_col) return; + + if (do_row) { + NVTE_CHECK(output->amax.dptr != nullptr, + "Per-token encode: output->amax (per-row, [M]) must be pre-filled."); + NVTE_CHECK(output->data.dptr != nullptr, + "Per-token encode: output->data (rowwise FP4) must be allocated."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, + "Per-token encode: output->scale_inv must be allocated."); + } + if (do_col) { + NVTE_CHECK(output->columnwise_amax.dptr != nullptr, + "Per-token encode: output->columnwise_amax (per-col, [K]) must be pre-filled."); + NVTE_CHECK(output->columnwise_data.dptr != nullptr, + "Per-token encode: output->columnwise_data must be allocated."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Per-token encode: output->columnwise_scale_inv must be allocated."); + } + + checkCuDriverContext(stream); + + alignas(64) CUtensorMap tmap_in{}; + alignas(64) CUtensorMap tmap_out{}; + alignas(64) CUtensorMap tmap_out_t{}; + + create_2D_tensor_map(tmap_in, input.data, M, K, + TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); + if (do_row) { + create_2D_tensor_map(tmap_out, output->data, M, K, + TILE_DIM_Y, TILE_DIM_X, K, 0, 4); + } + if (do_col) { + create_2D_tensor_map(tmap_out_t, output->columnwise_data, K, M, + TILE_DIM_X, TILE_DIM_Y, M, 0, 4); + } + + constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales = + DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_t = + DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT); + + // Total dyn SMEM: input + output FP4 (row + col) + SF (row + col) + 128B align. + const int dshmem_size = + buff_size_aligned_in + + (do_row ? buff_size_aligned_out : 0) + + (do_col ? buff_size_aligned_out_t : 0) + + (do_row ? buff_size_scales : 0) + + (do_col ? buff_size_scales_t : 0) + + TMA_SHMEM_ALIGNMENT; + + dim3 grid(static_cast(K / CHUNK_DIM_X), + static_cast(M / CHUNK_DIM_Y), 1); + dim3 block(THREADS_NUM, 1, 1); + + const float* noop_ptr = (noop.data.dptr != nullptr) + ? reinterpret_cast(noop.data.dptr) + : nullptr; + const size_t scale_stride = do_row ? output->scale_inv.shape[1] : 0; + const size_t scale_stride_t = do_col ? output->columnwise_scale_inv.shape[1] : 0; + + nvfp4_scale_t* scales_ptr = + do_row ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + nvfp4_scale_t* scales_t_ptr = + do_col ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + const float* row_amax_in = + do_row ? reinterpret_cast(output->amax.dptr) : nullptr; + const float* col_amax_in = + do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { + auto kernel = per_token_encode_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tmap_in, tmap_out, tmap_out_t, + scales_ptr, scales_t_ptr, + row_amax_in, col_amax_in, + noop_ptr, M, K, scale_stride, scale_stride_t); + });); + NVTE_CHECK_CUDA(cudaGetLastError()); +} +#endif // FP4_TYPE_SUPPORTED + +// ============================================================================= +// Impls (validation + dispatch). The K1 amax / K2 encode passes are exposed +// as separately callable entry points alongside the composite K1+K2 entry, +// to enable per-kernel benchmarking and diagnostic use. +// ============================================================================= + +#if FP4_TYPE_SUPPORTED +// Common input + shape validation, shared by all 3 entry points. +// Output constraints differ by entry point (see validate_*_output helpers below). +inline void validate_input_shape(const Tensor& input) { + NVTE_CHECK(input.has_data(), "Per-token cast: input has no data."); + NVTE_CHECK(input.dtype() == DType::kBFloat16, + "Per-token cast is bf16-only. Got dtype enum ", + static_cast(input.dtype())); + const size_t M = input.flat_first_dim(); + const size_t K = input.flat_last_dim(); + NVTE_CHECK(M % CHUNK_DIM_Y == 0, + "Per-token cast: M must be a multiple of ", CHUNK_DIM_Y, ", got M=", M); + NVTE_CHECK(K % CHUNK_DIM_X == 0, + "Per-token cast: K must be a multiple of ", CHUNK_DIM_X, ", got K=", K); +} + +// K1 (amax-only) requires at least one amax buffer allocated; FP4 output is not used. +inline void validate_amax_output(const Tensor* output) { + NVTE_CHECK(output->amax.dptr != nullptr || output->columnwise_amax.dptr != nullptr, + "Per-token K1 (amax): at least one of rowwise/columnwise amax buffer " + "must be allocated."); +} + +// K2 (encode) and composite require at least one FP4 output buffer allocated. +inline void validate_encode_output(const Tensor* output) { + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Per-token K2 (encode): at least one of rowwise/columnwise FP4 output " + "must be allocated."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, + "Per-token cast emits compact (non-swizzled) inner SF."); +} + +void per_token_amax_blocked_impl(const Tensor& input, const Tensor& noop, + Tensor* output, cudaStream_t stream) { + validate_input_shape(input); + validate_amax_output(output); + if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; + launch_amax(input, output, noop, stream); +} + +void per_token_encode_blocked_impl(const Tensor& input, const Tensor& noop, + Tensor* output, cudaStream_t stream) { + validate_input_shape(input); + validate_encode_output(output); + if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; + launch_encode(input, output, noop, stream); +} + +void per_token_quantize_blocked_impl(const Tensor& input, const Tensor& noop, + Tensor* output, cudaStream_t stream) { + validate_input_shape(input); + validate_encode_output(output); + if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; + launch_amax(input, output, noop, stream); + launch_encode(input, output, noop, stream); +} + +bool can_use_per_token(size_t M, size_t K, DType dtype) { + return (dtype == DType::kBFloat16) && (M % CHUNK_DIM_Y == 0) && (K % CHUNK_DIM_X == 0); +} +#else // !FP4_TYPE_SUPPORTED +void per_token_amax_blocked_impl(const Tensor&, const Tensor&, Tensor*, cudaStream_t) { + NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); +} +void per_token_encode_blocked_impl(const Tensor&, const Tensor&, Tensor*, cudaStream_t) { + NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); +} +void per_token_quantize_blocked_impl(const Tensor&, const Tensor&, Tensor*, cudaStream_t) { + NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); +} +bool can_use_per_token(size_t, size_t, DType) { return false; } +#endif // FP4_TYPE_SUPPORTED + +} // namespace nvfp4_per_token +} // namespace transformer_engine + +// ============================================================================= +// C-API entry points +// ============================================================================= + +void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, + NVTETensor output, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_nvfp4_per_token_amax); + using namespace transformer_engine; + const Tensor* input_tensor = convertNVTETensorCheck(input); + Tensor* output_tensor = convertNVTETensorCheck(output); + Tensor dummy_noop; + const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; + nvfp4_per_token::per_token_amax_blocked_impl( + *input_tensor, *noop_tensor, output_tensor, stream); +#else + (void)input; (void)noop; (void)output; (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, + NVTETensor output, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_nvfp4_per_token_encode); + using namespace transformer_engine; + const Tensor* input_tensor = convertNVTETensorCheck(input); + Tensor* output_tensor = convertNVTETensorCheck(output); + Tensor dummy_noop; + const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; + nvfp4_per_token::per_token_encode_blocked_impl( + *input_tensor, *noop_tensor, output_tensor, stream); +#else + (void)input; (void)noop; (void)output; (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, + NVTETensor output, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_nvfp4_per_token_quantize); + using namespace transformer_engine; + const Tensor* input_tensor = convertNVTETensorCheck(input); + Tensor* output_tensor = convertNVTETensorCheck(output); + Tensor dummy_noop; + const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; + nvfp4_per_token::per_token_quantize_blocked_impl( + *input_tensor, *noop_tensor, output_tensor, stream); +#else + (void)input; (void)noop; (void)output; (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +int nvte_nvfp4_per_token_can_dispatch(size_t M, size_t K, int input_dtype_enum) { + using namespace transformer_engine; + const DType dtype = static_cast(input_dtype_enum); + return nvfp4_per_token::can_use_per_token(M, K, dtype) ? 1 : 0; +} diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu new file mode 100644 index 0000000000..9e5ede9ff2 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu @@ -0,0 +1,1017 @@ +/************************************************************************* + * Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4_per_token_group.cu + * \brief Grouped NVFP4 per-token cast: bf16 input (sum_M, K), splits along + * M; K1 fused row+col amax + K2 row + col cast. Requires K % 128 == 0 + * and every split_sections[i] % 128 == 0. + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "common/cast/core/common.cuh" +#include "common/cast/nvfp4/core_nvfp4.cuh" +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace nvfp4_per_token_group { + +#if FP4_TYPE_SUPPORTED + +using dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; +using dispatch::nvfp4::quantization_SF::compute_decoding_scaling_factor; +using dispatch::nvfp4::nvfp4_scale_t; +using ptx::FPx2; + +constexpr int kInnerK = 16; // NVFP4 inner block: 16 elements per e4m3 SF + +// 64-tensor cap so the args struct fits under the 4 KB launch-param limit. +constexpr int kMaxTensorsPerKernel = 64; + +// Per-launch arg table; passed as __grid_constant__ for constant-cache reads. +struct NVFP4PerTokenMultiArgs { + // K1 outputs (per-tensor pointers; one fp32 array per tensor) + void* row_amax_list[kMaxTensorsPerKernel]; // each: float* (M_i,) + void* col_amax_list[kMaxTensorsPerKernel]; // each: float* (K,) + + // K2 outputs (per-tensor pointers; FP4 codes + e4m3 inner SF) + void* q_row_list[kMaxTensorsPerKernel]; // each: uint8* (M_i, K/2) + void* s_dec_row_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (M_i, K/16) + void* q_col_list[kMaxTensorsPerKernel]; // each: uint8* (K, M_i/2) + void* s_dec_col_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (K, M_i/16) + + // Shared layout info + int split_sections_range[kMaxTensorsPerKernel + 1]; // prefix sum w/ leading 0 + int num_tensors; +}; + +__device__ __forceinline__ int GetTensorId(const NVFP4PerTokenMultiArgs& args, int global_row) { + const int n = args.num_tensors; + if (global_row >= args.split_sections_range[n]) return n - 1; + int tid = 0; + while (args.split_sections_range[tid + 1] <= global_row) ++tid; + return tid; +} + +// Fused K1: TMA-loaded SMEM tile feeds row+col amax; routes atomicMax to the +// per-tensor buffer via tensor_id lookup at CTA entry. +namespace fused { + +constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows +constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols +constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height +constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width +constexpr int THREADS_NUM = 128; +constexpr int PREFETCH_STAGES = 1; +constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; +constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; // 2 +constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X; // 2 +constexpr int STAGES = TILES_Y * TILES_X; // 4 + +constexpr int BUFF_IN_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_IN_DIM_X = TILE_DIM_X; +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +using FusedIType = bf16; +using FusedIType2 = ptx::FPx2; +using FusedIType3D = FusedIType[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; + +// Pre-zero amax buffers (identity for atomicMax). +template +__global__ void group_per_token_fused_zero_amax_kernel(NVFP4PerTokenMultiArgs args, + int K) { + const int tensor_id = blockIdx.x; + if (tensor_id >= args.num_tensors) return; + if (DO_ROW) { + float* row_amax = reinterpret_cast(args.row_amax_list[tensor_id]); + if (row_amax != nullptr) { + const int M_i = args.split_sections_range[tensor_id + 1] - + args.split_sections_range[tensor_id]; + for (int m = threadIdx.x; m < M_i; m += blockDim.x) { + row_amax[m] = 0.0f; + } + } + } + if (DO_COL) { + float* col_amax = reinterpret_cast(args.col_amax_list[tensor_id]); + if (col_amax != nullptr) { + for (int k = threadIdx.x; k < K; k += blockDim.x) { + col_amax[k] = 0.0f; + } + } + } +} + +template +__global__ void __launch_bounds__(THREADS_NUM) + group_per_token_fused_amax_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ NVFP4PerTokenMultiArgs args, + const float* noop, const size_t rows, + const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const bool leading_thread = (threadIdx.x == 0); + const int tid = threadIdx.x; + + constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); + + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char* dshmem = + dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + FusedIType* sIn_ptr = reinterpret_cast(dshmem); + auto& sIn = *reinterpret_cast(sIn_ptr); + + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const int32_t ctaid_X = blockIdx.x; + const int32_t ctaid_Y = blockIdx.y; + const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * CHUNK_DIM_X; + + // Tile lies fully inside one tensor (split_sections[i] % 128 == 0). + const int tensor_id = GetTensorId(args, block_offset_Y); + const int local_row_base = block_offset_Y - args.split_sections_range[tensor_id]; + float* row_amax_out = + DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; + float* col_amax_out = + DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; + + // Each thread owns chunk-row `tid` (for row amax) and chunk-col `tid` (for col amax). + float row_partial = 0.f; + float col_partial = 0.f; + const int my_row_stage_Y = tid / TILE_DIM_Y; + const int my_col_stage_X = tid / TILE_DIM_X; + const int my_row_in_subtile = tid % TILE_DIM_Y; + const int my_col_in_subtile = tid % TILE_DIM_X; + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + // Prefetch stage 0. +#pragma unroll + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const int buff_in = stage; + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + const int global_offset_Y = block_offset_Y + stage_Y * TILE_DIM_Y; + const int global_offset_X = block_offset_X + stage_X * TILE_DIM_X; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&sIn[buff_in]), + reinterpret_cast(&tensor_map_input), global_offset_X, + global_offset_Y, &IN_buff_readable_mbar[buff_in]); + } + } + + int buff_in = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + + // Prefetch next stage. + if (stage < STAGES - PREFETCH_STAGES) { + const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const int next_stage_Y = next_prefetch_stage / TILES_X; + const int next_stage_X = next_prefetch_stage % TILES_X; + const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; + const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], + shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&sIn[next_prefetch_buff]), + reinterpret_cast(&tensor_map_input), + next_global_offset_X, next_global_offset_Y, + &IN_buff_readable_mbar[next_prefetch_buff]); + } + ptx::fence_proxy_async_shared_cta(); + } + + // Wait for this stage's tile. + ptx::mbarrier_wait_parity_acquire_cta_shared_cta( + &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + + // Row partial: rotate e-iter by bank group to split warp into 8 groups. + if (DO_ROW && stage_Y == my_row_stage_Y) { + float local_max = row_partial; + const int row_bank_group = (my_row_in_subtile >> 2) & 0x7; +#pragma unroll + for (int e_iter = 0; e_iter < 8; ++e_iter) { + const int e = ((e_iter + row_bank_group) & 0x7) << 3; + __uint128_t elts_8x = ptx::ld_shared_b128(&sIn[buff_in][my_row_in_subtile][e]); + const FusedIType2* pairs = reinterpret_cast(&elts_8x); + FusedIType2 amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int p = 0; p < 4; ++p) { + ptx::abs_max_2x(amax_2x, amax_2x, pairs[p]); + } + local_max = fmaxf(local_max, static_cast( + __hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); + } + row_partial = local_max; + } + + // Col partial: 1 thread per column scans down 64 rows of the sub-tile. + if (DO_COL && stage_X == my_col_stage_X) { + float local_max = col_partial; +#pragma unroll + for (int e = 0; e < TILE_DIM_Y; ++e) { + const FusedIType v = sIn[buff_in][e][my_col_in_subtile]; + local_max = fmaxf(local_max, fabsf(static_cast(v))); + } + col_partial = local_max; + } + + __syncthreads(); + buff_in = (buff_in + 1) % BUFFS_NUM; + } + + // CTAs across (ctaid_X) share row_amax slots; across (ctaid_Y) share col_amax slots. + if (DO_ROW) { + atomicMaxFloat(&row_amax_out[local_row_base + tid], row_partial); + } + if (DO_COL) { + atomicMaxFloat(&col_amax_out[block_offset_X + tid], col_partial); + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + } +#else + (void)tensor_map_input; + (void)args; + (void)noop; + (void)rows; + (void)cols; + NVTE_DEVICE_ERROR("Fused grouped per-token amax kernel requires SM 10.0+ (Blackwell)."); +#endif // __CUDA_ARCH__ >= 1000 +} + +// K2 (encode) constants + helpers; byte-equal port of the single-tensor +// per-token cooperative 4x32 / 32x4 threading + ld_shared_b128 + mul_cvt_4x. +constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM +constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) +constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; // 8 +constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; // 8 +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 + +// Rowwise pass: 4 (K-dim) x 32 (M-dim) -> 1 NVFP4 block per thread. +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 +constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 + +// Colwise pass: tid.X = col-pair, warp = M-block (32 x 4). +constexpr int THREADS_X_TR = TILE_DIM_X / 2; // 32 +constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; // 4 + +// Output / SF SMEM buffer dims (sub-tile sized, double-buffered for ping-pong). +constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (fp4e2m1x2 bytes) +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; +constexpr int BUFF_OUT_TR_DIM_Y = TILE_DIM_X; +constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 +constexpr int BUFFS_NUM_OUT_TR = 2; + +// Manual SMEM swizzling parameters (matches single-tensor encode kernel). +constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 + +using IType = FusedIType; +using IType2 = FusedIType2; +using IType2x3D = IType2 [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; + +// Rowwise encode helper: reads sRowAmax (pre-populated by K1), writes FP4 + +// e4m3 SFs into sOut / sSFrowwise. Byte-equal to the single-tensor version. +__device__ __forceinline__ void rowwise_scaling_per_token( + const IType* __restrict__ sIn_ptr, + fp4e2m1x2* __restrict__ sOut_ptr, + nvfp4_scale_t* __restrict__ sSFrowwise_ptr, + const float* __restrict__ sRowAmax, + const int stage_Y, const int stage_X, + const int buff_in, const int buff_out) { + const auto& sIn = *reinterpret_cast(sIn_ptr); + auto& sOut = *reinterpret_cast(sOut_ptr); + auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 + + const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; + + const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; + const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); + + const int stage_rowwise_scales_offset_X = + SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; + +#pragma unroll + for (int it = 0; it < ITERATIONS_NORMAL; ++it) { + const int it_offset_Y_rowwise = tid_Y_rowwise + it * THREADS_Y_ROWWISE; + const int chunk_local_row = stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; + + const float row_amax = sRowAmax[chunk_local_row]; + const float S_enc = compute_global_encode_scaling_factor_FP4(fmaxf(row_amax, 1e-12f)); + + __align__(16) IType2 rIn[WAVES][PACK_SIZE / 2]; + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + + __uint128_t& elts_8x = *reinterpret_cast<__uint128_t*>(&rIn[w]); + elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); + } + } + const float block_amax = static_cast( + __hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + + const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax, S_enc); + const float s_dec_f = static_cast(s_dec); + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc, s_dec_f); + + if (SF_storing_thread) { + const int scales_offset_Y = chunk_local_row; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = s_dec; + } + +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + + fp4e2m1x4 qu0{}, qu1{}; + ptx::mul_cvt_4x(qu0, rIn[w][0], rIn[w][1], block_scale); + ptx::mul_cvt_4x(qu1, rIn[w][2], rIn[w][3], block_scale); + + uint32_t out_x8 = (static_cast(*reinterpret_cast(&qu0))) | + (static_cast(*reinterpret_cast(&qu1)) << 16); + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + } + } +} + +// Colwise encode helper. Byte-equal to the single-tensor version. +__device__ __forceinline__ void colwise_scaling_per_token( + const IType* __restrict__ sIn_ptr, + fp4e2m1x2* __restrict__ sOut_tr_ptr, + nvfp4_scale_t* __restrict__ sSFcolwise_ptr, + const float* __restrict__ sColAmax, + const int stage_Y, const int stage_X, + const int buff_in, const int buff_out_tr) { + const auto& sIn2x = *reinterpret_cast(sIn_ptr); + auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); + auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 + const int tid_X_colwise = thread_lane; // 0..31 + + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; + const int thread_offset_X_colwise = tid_X_colwise * 2; + + const int in_thread_offset_Y = thread_offset_Y_colwise; + const int in_thread_offset_X = thread_offset_X_colwise / 2; + + const int out_tr_thread_offset_Y = thread_offset_X_colwise; + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; + + const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; + const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; + + __align__(8) IType rIn[2][SCALE_DIM]; + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const IType2 elt_pair = + ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); + rIn[0][i] = elt_pair.x; + rIn[1][i] = elt_pair.y; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + } + const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), + static_cast(__habs(thread_amax_2x.y))}; + +#pragma unroll + for (int w = 0; w < 2; ++w) { + const int chunk_local_col = scale_tr_offset_Y + w; + const float col_amax = sColAmax[chunk_local_col]; + const float S_enc_col = compute_global_encode_scaling_factor_FP4(fmaxf(col_amax, 1e-12f)); + + const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax[w], S_enc_col); + const float s_dec_f = static_cast(s_dec); + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc_col, s_dec_f); + + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = s_dec; + + fp4e2m1x4 qu[4]; +#pragma unroll + for (int e = 0; e < 4; ++e) { + IType2 in01{rIn[w][4 * e + 0], rIn[w][4 * e + 1]}; + IType2 in23{rIn[w][4 * e + 2], rIn[w][4 * e + 3]}; + ptx::mul_cvt_4x(qu[e], in01, in23, block_scale); + } + + uint64_t out_pack_16x = (static_cast(*reinterpret_cast(&qu[0])) << 0) | + (static_cast(*reinterpret_cast(&qu[1])) << 16) | + (static_cast(*reinterpret_cast(&qu[2])) << 32) | + (static_cast(*reinterpret_cast(&qu[3])) << 48); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); + } +} + +// Fused K2: TMA-loads input, runs cooperative row+col encode helpers, scatters +// FP4 + SFs to per-tensor outputs via st.global (multi-dest, no TMA store). +template +__global__ void __launch_bounds__(THREADS_NUM) + group_per_token_fused_cast_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ NVFP4PerTokenMultiArgs args, + const float* noop, const size_t rows, + const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + (void)rows; + + const bool leading_thread = (threadIdx.x == 0); + + // Dynamic SMEM layout (~28 KiB): sIn (16K) + sOut (4K) + sOut_tr (4K) + + // sSF_row (1K) + sSF_col (1K) + sRowAmax/sColAmax (512B each). + constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int out_mem_rowwise_data = DO_ROW ? buff_size_aligned_out : 0; + constexpr int out_mem_colwise_data = DO_COL ? buff_size_aligned_out_t : 0; + constexpr int out_mem_rowwise_scales = + DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) : 0; + constexpr int out_mem_colwise_scales = + DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) : 0; + (void)out_mem_colwise_scales; + + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char* dshmem = + dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + IType* sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); + fp4e2m1x2* sOut_tr_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data); + nvfp4_scale_t* sSFrowwise_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t* sSFcolwise_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data + + out_mem_rowwise_scales); + + __shared__ float sRowAmax[CHUNK_DIM_Y]; + __shared__ float sColAmax[CHUNK_DIM_X]; + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + + auto& sIn = *reinterpret_cast(sIn_ptr); + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const int32_t ctaid_X = blockIdx.x; + const int32_t ctaid_Y = blockIdx.y; + const int block_offset_Y = ctaid_Y * CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * CHUNK_DIM_X; + + // Chunk Y stays inside one tensor (split_sections[i] % 128 == 0). + const int tensor_id = GetTensorId(args, block_offset_Y); + const int local_row_base = block_offset_Y - args.split_sections_range[tensor_id]; + const int M_t = args.split_sections_range[tensor_id + 1] - + args.split_sections_range[tensor_id]; + + // Per-tensor output bases (one constant-cache lookup per CTA). + uint8_t* const q_row_base = DO_ROW + ? reinterpret_cast(args.q_row_list[tensor_id]) : nullptr; + uint8_t* const q_col_base = DO_COL + ? reinterpret_cast(args.q_col_list[tensor_id]) : nullptr; + nvfp4_scale_t* const s_dec_row_base = DO_ROW + ? reinterpret_cast(args.s_dec_row_list[tensor_id]) : nullptr; + nvfp4_scale_t* const s_dec_col_base = DO_COL + ? reinterpret_cast(args.s_dec_col_list[tensor_id]) : nullptr; + const float* const row_amax_base = DO_ROW + ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; + const float* const col_amax_base = DO_COL + ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; + + const size_t data_stride_row = static_cast(cols) / 2; + const size_t data_stride_col = static_cast(M_t) / 2; + const size_t scale_stride_row = static_cast(cols) / SCALE_DIM; + const size_t scale_stride_col = static_cast(M_t) / SCALE_DIM; + + // Load per-row / per-col amax into SMEM cache. + if (DO_ROW && threadIdx.x < CHUNK_DIM_Y) { + sRowAmax[threadIdx.x] = row_amax_base[local_row_base + threadIdx.x]; + } + if (DO_COL && threadIdx.x < CHUNK_DIM_X) { + sColAmax[threadIdx.x] = col_amax_base[block_offset_X + threadIdx.x]; + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + // Prefetch stage 0. +#pragma unroll + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const int buff_in_p = stage; + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + const int global_offset_Y = block_offset_Y + stage_Y * TILE_DIM_Y; + const int global_offset_X = block_offset_X + stage_X * TILE_DIM_X; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in_p], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&sIn[buff_in_p]), + reinterpret_cast(&tensor_map_input), global_offset_X, + global_offset_Y, &IN_buff_readable_mbar[buff_in_p]); + } + } + + int buff_in = 0; + int buff_out = 0; + int buff_out_tr = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / TILES_X; + const int stage_X = stage % TILES_X; + + if (stage < STAGES - PREFETCH_STAGES) { + const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const int next_stage_Y = next_prefetch_stage / TILES_X; + const int next_stage_X = next_prefetch_stage % TILES_X; + const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; + const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], + shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&sIn[next_prefetch_buff]), + reinterpret_cast(&tensor_map_input), + next_global_offset_X, next_global_offset_Y, + &IN_buff_readable_mbar[next_prefetch_buff]); + } + ptx::fence_proxy_async_shared_cta(); + } + + // Wait for current stage's input tile to land. + ptx::mbarrier_wait_parity_acquire_cta_shared_cta( + &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + + // 4x32 cooperative row + col encode helpers. + if (DO_ROW) { + rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, + sRowAmax, stage_Y, stage_X, buff_in, buff_out); + } + if (DO_COL) { + colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, + sColAmax, stage_Y, stage_X, buff_in, buff_out_tr); + } + + // Make helper SMEM writes visible before the scatter epilogue. + __syncthreads(); + + // Scatter sOut / sOut_tr to per-tensor buffers via cooperative b128 stores; + // 2 threads per row/col x 16 B = 2048 B per sub-tile per direction. + if (DO_ROW) { + auto& sOut = *reinterpret_cast(sOut_ptr); + const int row_in_subtile = static_cast(threadIdx.x) >> 1; // 0..63 + const int half = static_cast(threadIdx.x) & 1; // 0..1 + const int local_row = local_row_base + stage_Y * TILE_DIM_Y + row_in_subtile; + const int byte_off_X = (block_offset_X / 2) + + stage_X * (TILE_DIM_X / 2) + + half * 16; + const uint4* src = reinterpret_cast( + &sOut[buff_out][row_in_subtile][half * 16]); + uint4* dst = reinterpret_cast( + q_row_base + static_cast(local_row) * data_stride_row + byte_off_X); + *dst = *src; + } + if (DO_COL) { + auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); + const int col_in_subtile = static_cast(threadIdx.x) >> 1; // 0..63 + const int half = static_cast(threadIdx.x) & 1; // 0..1 + const int global_col = block_offset_X + stage_X * TILE_DIM_X + col_in_subtile; + const int byte_off_M = (local_row_base / 2) + + stage_Y * (TILE_DIM_Y / 2) + + half * 16; + const uint4* src = reinterpret_cast( + &sOut_tr[buff_out_tr][col_in_subtile][half * 16]); + uint4* dst = reinterpret_cast( + q_col_base + static_cast(global_col) * data_stride_col + byte_off_M); + *dst = *src; + } + + // Sync so the scatter completes before next stage overwrites the buffer. + __syncthreads(); + + buff_in = (buff_in + 1) % BUFFS_NUM; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; + } + + // SF epilogue: cooperative store of sSFrowwise / sSFcolwise to global. + if (DO_ROW) { + auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + using ScalesVec = Vec; + const size_t scales_block_offset_X_rowwise = + static_cast(ctaid_X) * SCALES_PER_CHUNK_X; + for (int row = static_cast(threadIdx.x); row < CHUNK_DIM_Y; row += THREADS_NUM) { + ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); + const size_t local_row = static_cast(local_row_base) + row; + const size_t scale_idx_global = + local_row * scale_stride_row + scales_block_offset_X_rowwise; + scales_vec.store_to_elts(&s_dec_row_base[scale_idx_global], 0, SCALES_PER_CHUNK_X); + } + } + if (DO_COL) { + auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + using ScalesVec = Vec; + // M-block offset within s_dec_col[global_col] (shape (K, M_i/16) row-major). + const size_t local_block_offset_M = static_cast(local_row_base) / SCALE_DIM; + for (int row_tr = static_cast(threadIdx.x); row_tr < CHUNK_DIM_X; + row_tr += THREADS_NUM) { + ScalesVec& scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); + const size_t global_col = static_cast(block_offset_X) + row_tr; + const size_t scale_idx_global = + global_col * scale_stride_col + local_block_offset_M; + scales_vec.store_to_elts(&s_dec_col_base[scale_idx_global], 0, SCALES_PER_CHUNK_Y); + } + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + } +#else + (void)tensor_map_input; + (void)args; + (void)noop; + (void)rows; + (void)cols; + NVTE_DEVICE_ERROR("Fused grouped per-token cast kernel requires SM 10.0+ (Blackwell)."); +#endif // __CUDA_ARCH__ >= 1000 +} + +// Host launcher for the fused K2 path. bf16-only. +inline void launch_grouped_fused_cast_bf16(const NVFP4PerTokenMultiArgs& args, + const SimpleTensor& input_data, int sum_M, + int K, bool do_row, bool do_col, + const float* noop, cudaStream_t stream) { + if (!do_row && !do_col) return; + + checkCuDriverContext(stream); + + alignas(64) CUtensorMap tmap_in{}; + create_2D_tensor_map(tmap_in, input_data, sum_M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, + sizeof(FusedIType) * 8); + + dim3 grid(static_cast(K / CHUNK_DIM_X), + static_cast(sum_M / CHUNK_DIM_Y), 1); + dim3 block(THREADS_NUM, 1, 1); + + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { + constexpr int sz_in = DIVUP_TO_MULTIPLE( + BUFFS_NUM * BUFF_IN_SIZE * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); + constexpr int sz_out_r = DO_ROW + ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT) : 0; + constexpr int sz_out_c = DO_COL + ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int sz_sf_r = DO_ROW + ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int sz_sf_c = DO_COL + ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int dshmem_size = sz_in + sz_out_r + sz_out_c + sz_sf_r + sz_sf_c + + TMA_SHMEM_ALIGNMENT; + auto kernel = group_per_token_fused_cast_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>(tmap_in, args, noop, + static_cast(sum_M), + static_cast(K)); + });); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Host launcher for the fused K1 path. bf16-only. +inline void launch_grouped_fused_amax_bf16(const NVFP4PerTokenMultiArgs& args, + const SimpleTensor& input_data, int sum_M, + int K, bool do_row, bool do_col, + const float* noop, cudaStream_t stream) { + if (!do_row && !do_col) return; + + // Pre-zero amax slots (atomicMax identity). + { + dim3 grid_zero(static_cast(args.num_tensors)); + dim3 block_zero(256); + if (do_row && do_col) { + group_per_token_fused_zero_amax_kernel + <<>>(args, K); + } else if (do_row) { + group_per_token_fused_zero_amax_kernel + <<>>(args, K); + } else { + group_per_token_fused_zero_amax_kernel + <<>>(args, K); + } + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + checkCuDriverContext(stream); + + alignas(64) CUtensorMap tmap_in{}; + create_2D_tensor_map(tmap_in, input_data, sum_M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, + sizeof(FusedIType) * 8); + + constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); + constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; + + dim3 grid(static_cast(K / CHUNK_DIM_X), + static_cast(sum_M / CHUNK_DIM_Y), 1); + dim3 block(THREADS_NUM, 1, 1); + + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { + auto kernel = group_per_token_fused_amax_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>(tmap_in, args, noop, + static_cast(sum_M), + static_cast(K)); + });); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace fused + +// Populate per-tensor pointer tables + split_sections prefix-sum. +// which_buffers bitmask: kBufRowAmax | kBufColAmax | kBufRowCast | kBufColCast. +enum BufferFlags : int { + kBufRowAmax = 0x1, + kBufColAmax = 0x2, + kBufRowCast = 0x4, + kBufColCast = 0x8, +}; + +void populate_args(NVFP4PerTokenMultiArgs* args, std::vector& outputs, + const size_t* split_sections, size_t num_tensors, int which_buffers, + int expected_sum_M, int K) { + std::memset(args, 0, sizeof(*args)); + args->num_tensors = static_cast(num_tensors); + args->split_sections_range[0] = 0; + for (size_t i = 0; i < num_tensors; ++i) { + Tensor* o = outputs[i]; + NVTE_CHECK(split_sections[i] % 128 == 0, "split_sections[", i, + "] = ", split_sections[i], " must be a multiple of 128"); + args->split_sections_range[i + 1] = + args->split_sections_range[i] + static_cast(split_sections[i]); + if (split_sections[i] == 0) continue; + if (which_buffers & kBufRowAmax) { + NVTE_CHECK(o->amax.dptr != nullptr, + "NVFP4 per-token grouped: outputs[", i, "].amax must be allocated for rowwise"); + args->row_amax_list[i] = o->amax.dptr; + } + if (which_buffers & kBufColAmax) { + NVTE_CHECK(o->columnwise_amax.dptr != nullptr, + "NVFP4 per-token grouped: outputs[", i, + "].columnwise_amax must be allocated for columnwise"); + args->col_amax_list[i] = o->columnwise_amax.dptr; + } + if (which_buffers & kBufRowCast) { + NVTE_CHECK(o->data.dptr != nullptr && o->scale_inv.dptr != nullptr, + "NVFP4 per-token grouped: outputs[", i, + "].data + .scale_inv must be allocated for rowwise cast"); + args->q_row_list[i] = o->data.dptr; + args->s_dec_row_list[i] = o->scale_inv.dptr; + } + if (which_buffers & kBufColCast) { + NVTE_CHECK( + o->columnwise_data.dptr != nullptr && o->columnwise_scale_inv.dptr != nullptr, + "NVFP4 per-token grouped: outputs[", i, + "].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast"); + args->q_col_list[i] = o->columnwise_data.dptr; + args->s_dec_col_list[i] = o->columnwise_scale_inv.dptr; + } + } + NVTE_CHECK(args->split_sections_range[num_tensors] == expected_sum_M, + "NVFP4 per-token grouped: sum(split_sections) = ", + args->split_sections_range[num_tensors], " must equal input rows ", expected_sum_M); + (void)K; +} + +// Host entry. do_amax / do_cast select K1 / K2 phases (composite = both). +void quantize_per_token_grouped(const Tensor& input, std::vector& outputs, + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, bool do_amax, bool do_cast, cudaStream_t stream) { + NVTE_CHECK(num_tensors > 0, "NVFP4 per-token grouped: num_tensors must be > 0"); + NVTE_CHECK(num_tensors <= static_cast(kMaxTensorsPerKernel), + "NVFP4 per-token grouped: num_tensors (", num_tensors, + ") exceeds kMaxTensorsPerKernel = ", kMaxTensorsPerKernel); + NVTE_CHECK(rowwise || columnwise, + "NVFP4 per-token grouped: at least one of rowwise/columnwise must be true"); + NVTE_CHECK(input.has_data(), "NVFP4 per-token grouped: input has no data"); + NVTE_CHECK(input.dtype() == DType::kBFloat16, + "NVFP4 per-token grouped: input dtype must be bf16 (got ", + static_cast(input.dtype()), ")"); + + const int sum_M = static_cast(input.flat_first_dim()); + const int K = static_cast(input.flat_last_dim()); + if (sum_M == 0 || K == 0) return; + NVTE_CHECK(K % 128 == 0, + "NVFP4 per-token grouped: K (", K, ") must be a multiple of 128"); + + int which_buffers = 0; + if ((do_amax || do_cast) && rowwise) which_buffers |= kBufRowAmax; + if ((do_amax || do_cast) && columnwise) which_buffers |= kBufColAmax; + if (do_cast && rowwise) which_buffers |= kBufRowCast; + if (do_cast && columnwise) which_buffers |= kBufColCast; + + NVFP4PerTokenMultiArgs args; + populate_args(&args, outputs, split_sections, num_tensors, which_buffers, sum_M, K); + + // K1 + K2 = 2 fused launches; K1 must complete before K2 reads its amax. + if (do_amax) { + fused::launch_grouped_fused_amax_bf16(args, input.data, sum_M, K, + /*do_row=*/rowwise, + /*do_col=*/columnwise, + /*noop=*/nullptr, stream); + } + if (do_cast) { + fused::launch_grouped_fused_cast_bf16(args, input.data, sum_M, K, + /*do_row=*/rowwise, + /*do_col=*/columnwise, + /*noop=*/nullptr, stream); + } +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace nvfp4_per_token_group +} // namespace transformer_engine + +// C-API entries. +namespace { + +std::vector collect_outputs(NVTETensor* outputs, size_t num_tensors) { + std::vector v; + v.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + v.push_back(transformer_engine::convertNVTETensorCheck(outputs[i])); + } + return v; +} + +} // namespace + +void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_group_nvfp4_per_token_amax); + using namespace transformer_engine; + if (num_tensors == 0) return; + const Tensor* in = convertNVTETensorCheck(input); + std::vector outs = collect_outputs(outputs, num_tensors); + nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, + rowwise, columnwise, + /*do_amax=*/true, /*do_cast=*/false, stream); +#else + (void)input; + (void)outputs; + (void)split_sections; + (void)num_tensors; + (void)rowwise; + (void)columnwise; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_group_nvfp4_per_token_cast); + using namespace transformer_engine; + if (num_tensors == 0) return; + const Tensor* in = convertNVTETensorCheck(input); + std::vector outs = collect_outputs(outputs, num_tensors); + nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, + rowwise, columnwise, + /*do_amax=*/false, /*do_cast=*/true, stream); +#else + (void)input; + (void)outputs; + (void)split_sections; + (void)num_tensors; + (void)rowwise; + (void)columnwise; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + NVTE_API_CALL(nvte_group_nvfp4_per_token_quantize); + using namespace transformer_engine; + if (num_tensors == 0) return; + const Tensor* in = convertNVTETensorCheck(input); + std::vector outs = collect_outputs(outputs, num_tensors); + nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, + rowwise, columnwise, + /*do_amax=*/true, /*do_cast=*/true, stream); +#else + (void)input; + (void)outputs; + (void)split_sections; + (void)num_tensors; + (void)rowwise; + (void)columnwise; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} diff --git a/transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu b/transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu new file mode 100644 index 0000000000..849b76a6eb --- /dev/null +++ b/transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu @@ -0,0 +1,141 @@ +/************************************************************************* + * Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file nvfp4_per_token_post_scale.cu + * \brief NVFP4 per-token GEMM-output post-scale: d[i,j] *= r_A[i] * r_B[j]. + * + * Standalone bf16 epilogue applied after cuBLAS LT NVFP4 GEMM with the + * operand amaxes pinned to 1.0. See nvfp4_per_token.h for the math chain. + */ + +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../util/ptx.cuh" + +namespace transformer_engine { +namespace nvfp4_per_token { + +namespace { + +// Each block tiles 16 rows x 256 cols of the output: amaxes are loaded +// once into SMEM, then each thread handles 8 cols via a 16-byte int4 LD/ST +// for peak HBM coalescing on SM100. Wrapper enforces M, N % 128 alignment. +constexpr int kTileCols = 256; +constexpr int kTileRows = 16; +constexpr int kElemsPerThread = 8; // bf16x8 = 16-byte vector +constexpr int kThreadsX = kTileCols / kElemsPerThread; +constexpr int kThreadsY = kTileRows; +constexpr int kThreadsPerBlock = kThreadsX * kThreadsY; +static_assert(kTileCols % kElemsPerThread == 0, + "kTileCols must be a multiple of kElemsPerThread"); +static_assert(kElemsPerThread * sizeof(__nv_bfloat16) == sizeof(int4), + "kElemsPerThread bf16 must pack into a single int4 (16 bytes)"); + +__global__ void __launch_bounds__(kThreadsPerBlock) + per_token_post_scale_kernel(__nv_bfloat16* __restrict__ d, const float* __restrict__ row_amax_a, + const float* __restrict__ row_amax_b, const int M, const int N) { + __shared__ float s_row_amax[kTileRows]; + __shared__ float s_col_amax[kTileCols]; + + const int row_tile = blockIdx.y * kTileRows; + const int col_tile = blockIdx.x * kTileCols; + + // Cooperatively load row + col amaxes into SMEM (272 floats / 512 threads). + const int tid = threadIdx.y * kThreadsX + threadIdx.x; + if (tid < kTileRows) { + const int gi = row_tile + tid; + s_row_amax[tid] = (gi < M) ? row_amax_a[gi] : 0.0f; + } + if (tid < kTileCols) { + const int gj = col_tile + tid; + s_col_amax[tid] = (gj < N) ? row_amax_b[gj] : 0.0f; + } + __syncthreads(); + + const int i = row_tile + threadIdx.y; + const int j0 = col_tile + threadIdx.x * kElemsPerThread; + if (i >= M || j0 >= N) return; + + const float a = s_row_amax[threadIdx.y]; + const size_t base = static_cast(i) * N + j0; + + // Fast path = 16-byte aligned LD/ST; slow path = boundary tile fallback. + if (j0 + kElemsPerThread <= N) { + // __align__(16) is required for the int4 reinterpret_cast to be defined. + __nv_bfloat16 __align__(16) chunk[kElemsPerThread]; + *reinterpret_cast(chunk) = *reinterpret_cast(&d[base]); +#pragma unroll + for (int e = 0; e < kElemsPerThread; ++e) { + const float b = s_col_amax[threadIdx.x * kElemsPerThread + e]; + const float current = static_cast(chunk[e]); + chunk[e] = static_cast<__nv_bfloat16>(current * a * b); + } + *reinterpret_cast(&d[base]) = *reinterpret_cast(chunk); + } else { +#pragma unroll + for (int e = 0; e < kElemsPerThread; ++e) { + const int j = j0 + e; + if (j >= N) break; + const float b = s_col_amax[threadIdx.x * kElemsPerThread + e]; + const size_t idx = base + e; + const float current = static_cast(d[idx]); + d[idx] = static_cast<__nv_bfloat16>(current * a * b); + } + } +} + +} // namespace + +void per_token_post_scale(Tensor* d, const Tensor& row_amax_a, const Tensor& row_amax_b, + cudaStream_t stream) { + NVTE_CHECK(d->has_data(), "NVFP4 per-token post-scale: d has no data."); + NVTE_CHECK(d->data.dtype == DType::kBFloat16, + "NVFP4 per-token post-scale: d must be BF16 (got non-BF16 dtype)."); + NVTE_CHECK(row_amax_a.data.dtype == DType::kFloat32, + "NVFP4 per-token post-scale: row_amax_a must be FP32."); + NVTE_CHECK(row_amax_b.data.dtype == DType::kFloat32, + "NVFP4 per-token post-scale: row_amax_b must be FP32."); + + const auto& d_shape = d->data.shape; + NVTE_CHECK(d_shape.size() == 2, "NVFP4 per-token post-scale: d must be 2D, got rank=", + d_shape.size()); + const int M = static_cast(d_shape[0]); + const int N = static_cast(d_shape[1]); + NVTE_CHECK(row_amax_a.data.numel() == static_cast(M), + "NVFP4 per-token post-scale: row_amax_a numel must equal M=", M, ", got ", + row_amax_a.data.numel()); + NVTE_CHECK(row_amax_b.data.numel() == static_cast(N), + "NVFP4 per-token post-scale: row_amax_b numel must equal N=", N, ", got ", + row_amax_b.data.numel()); + + if (M == 0 || N == 0) { + return; + } + + // 32 x 16 threads = 512/block; covers 256 cols x 16 rows = 4096 elems/block. + dim3 block(kThreadsX, kThreadsY, 1); + dim3 grid((N + kTileCols - 1) / kTileCols, (M + kTileRows - 1) / kTileRows, 1); + per_token_post_scale_kernel<<>>( + reinterpret_cast<__nv_bfloat16*>(d->data.dptr), + reinterpret_cast(row_amax_a.data.dptr), + reinterpret_cast(row_amax_b.data.dptr), M, N); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace nvfp4_per_token +} // namespace transformer_engine + +void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a, + const NVTETensor row_amax_b, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_per_token_post_scale); + using namespace transformer_engine; + + transformer_engine::nvfp4_per_token::per_token_post_scale( + convertNVTETensorCheck(d), *convertNVTETensorCheck(row_amax_a), + *convertNVTETensorCheck(row_amax_b), stream); +} diff --git a/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h new file mode 100644 index 0000000000..8743a40afa --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h @@ -0,0 +1,124 @@ +/************************************************************************* + * Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_NVFP4_PER_TOKEN_H_ +#define TRANSFORMER_ENGINE_NVFP4_PER_TOKEN_H_ + +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + + +/*! \brief Composite K1+K2: per-row + per-col amax (K1) then FP4 + 1x16 + * e4m3 SF encode (K2), back-to-back on the same stream. + * + * This is the production entry point for the per-token cast on bf16 + + * 128-aligned shapes. + */ +void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, + NVTETensor output, cudaStream_t stream); + +/*! \brief Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. + * Pre-zeroes the amax buffers and merges per-CTA partials into + * ``output->amax`` (size [M]) / ``output->columnwise_amax`` + * (size [K]). Does NOT touch FP4 data / scale_inv slots. + */ +void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, + NVTETensor output, cudaStream_t stream); + +/*! \brief Kernel 2 in isolation: FP4 + 1x16 e4m3 SF encode given a + * pre-filled ``output->amax`` / ``output->columnwise_amax``. Reads + * the outer amax buffer(s) and writes the FP4 data / scale_inv + * tensors only. + */ +void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, + NVTETensor output, cudaStream_t stream); + +/*! \brief Returns 1 iff the per-token kernels accept ``(M, K, dtype)``. + * + * Currently returns 1 iff ``dtype`` is bf16 AND ``M % 128 == 0`` AND + * ``K % 128 == 0``. Cheap host-side query (no CUDA call). + * + * \param[in] M first-dim (rows). + * \param[in] K last-dim (cols). + * \param[in] input_dtype_enum NVTE_DType cast to int. + */ +int nvte_nvfp4_per_token_can_dispatch(size_t M, size_t K, int input_dtype_enum); + +/*! \brief Apply per-row * per-col outer-scale to a (M, N) bf16 GEMM output. + * + * Computes: + * + * d[i, j] = d[i, j] * row_amax_a[i] * row_amax_b[j] + */ +void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a, + const NVTETensor row_amax_b, + cudaStream_t stream); + +/* ============================================================================ + * Grouped (multi-tensor) per-token quantize. + * + * \param[in] input (sum_M, K) bf16/fp32, row-major contiguous + * \param[in,out] outputs array of `num_tensors` NVTETensors; on + * return, amax/columnwise_amax slots are filled. + * \param[in] split_sections array of `num_tensors` size_t values, + * each a multiple of 64; sum must equal sum_M. + * \param[in] num_tensors <= 64 + * \param[in] rowwise emit per-row amax in `outputs[i].amax` + * \param[in] columnwise emit per-col amax in `outputs[i].columnwise_amax` + * \param[in] stream CUDA stream + */ +void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, cudaStream_t stream); + +/*! \brief Grouped per-token encode (FP4 + 1x16 e4m3 inner SF) using the + * row_amax / col_amax values already populated by + * `nvte_group_nvfp4_per_token_amax`. + * + * \param[in] input same as `nvte_group_nvfp4_per_token_amax` + * \param[in,out] outputs on entry: amax/columnwise_amax populated; + * on return: data/scale_inv + columnwise_data/ + * columnwise_scale_inv populated. + * \param[in] split_sections same as `nvte_group_nvfp4_per_token_amax` + * \param[in] num_tensors <= 64 + * \param[in] rowwise emit per-row FP4 + inner SF + * \param[in] columnwise emit per-col FP4 + inner SF + * \param[in] stream CUDA stream + */ +void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, cudaStream_t stream); + +/*! \brief Composite K1+K2 grouped per-token quantize. Calls the amax + cast + * kernels on the same stream. This is the external API + * `tex.split_quantize(per_token=True)` should call. + * + * \param[in] input (sum_M, K) bf16/fp32, row-major contiguous + * \param[in,out] outputs on entry: amax / columnwise_amax / data / + * scale_inv / columnwise_data / + * columnwise_scale_inv slots allocated; + * on return: all populated. + * \param[in] split_sections array of `num_tensors` size_t values, + * each a multiple of 64; sum must equal sum_M. + * \param[in] num_tensors <= 64 + * \param[in] rowwise emit rowwise output + * \param[in] columnwise emit columnwise output + * \param[in] stream CUDA stream + */ +void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* outputs, + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_NVFP4_PER_TOKEN_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8082ff07ed..6411811323 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -411,6 +411,10 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads void compute_amax(const at::Tensor &tensor, at::Tensor &amax); +void hadamard_transform_amax(const at::Tensor &tensor, at::Tensor &rowwise_amax, + at::Tensor &columnwise_amax, + int64_t rht_matrix_random_sign_mask); + void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, @@ -448,6 +452,60 @@ void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwi const at::Tensor &scale_inv_colwise, int rows, int cols, size_t start_offset); +void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row, + at::Tensor s_dec_row, at::Tensor row_amax, + at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise); + +void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, + at::Tensor col_amax, bool rowwise, bool columnwise); + +void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row, + at::Tensor s_dec_row, at::Tensor row_amax, + at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise); + +void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, + const at::Tensor &row_amax_b); + +void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, + const at::Tensor &a_sf, const at::Tensor &b_sf, + const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, + at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, + int64_t k, double alpha, double beta); + +// Bench-only per-tensor twin of nvfp4_per_token_gemm: scalar amaxes folded +// into cuBLAS LT alpha via the amax slot; no trailing post-scale. +void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, + const at::Tensor &a_sf, const at::Tensor &b_sf, + const at::Tensor &a_amax, const at::Tensor &b_amax, + at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, + int64_t k, double alpha, double beta); + +void nvfp4_per_token_group_quantize( + const at::Tensor &input, const std::vector &split_sections, + std::vector q_row_list, std::vector s_dec_row_list, + std::vector row_amax_list, std::vector q_col_list, + std::vector s_dec_col_list, std::vector col_amax_list, + bool rowwise, bool columnwise); + +// Amax-only variant of the grouped quantize. Useful for multi-rank training +// where amax is allReduced before the cast pass. +void nvfp4_per_token_group_amax(const at::Tensor &input, + const std::vector &split_sections, + std::vector row_amax_list, + std::vector col_amax_list, bool rowwise, + bool columnwise); + +// Bulk grouped quantize: allocate-view-dispatch all in one pybind hop. +// Returns 6 per-split vectors (q_row, s_dec_row_fp8, row_amax, q_col, +// s_dec_col_fp8, col_amax); disabled directions return empty vectors. +std::tuple, std::vector, std::vector, + std::vector, std::vector, std::vector> +nvfp4_per_token_group_quantize_bulk(const at::Tensor &input, + const std::vector &split_sections, + bool rowwise, bool columnwise); + /*************************************************************************************************** * Rotary positional embedding **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp new file mode 100644 index 0000000000..cb6b56d9ee --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp @@ -0,0 +1,793 @@ +/************************************************************************* + * Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../extensions.h" + +namespace transformer_engine::pytorch { + +// NVFP4 per-token cast bindings. Shared TensorWrapper assembler dispatches +// composite (K1+K2), K1-only and K2-only via `mode`. bf16-only, M/K % 128 == 0. +// SFs emit in compact (non-swizzled) layout; swizzle for cuBLAS LT lives elsewhere. +namespace { + +// Validates the input and assembles ``out_te`` for all 3 modes; caller +// dispatches to the right C-API entry on the caller's stream. +void assemble_per_token_tensors(const at::Tensor& input, + at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, + at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, + bool rowwise, bool columnwise, int mode, + TensorWrapper& in_te, TensorWrapper& out_te) { + TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, + "Per-token cast is bf16-only. Got dtype ", input.scalar_type()); + const int64_t M = input.size(0); + const int64_t K = input.size(1); + TORCH_CHECK(M % 128 == 0, "Per-token cast requires M % 128 == 0; got M=", M); + TORCH_CHECK(K % 128 == 0, "Per-token cast requires K % 128 == 0; got K=", K); + + const std::vector in_shape = {static_cast(M), static_cast(K)}; + in_te = makeTransformerEngineTensor(input.data_ptr(), in_shape, DType::kBFloat16); + + // K1 (mode==1) populates ONLY amax slots; K2 / composite (mode==0/2) + // populate the FP4 + e4m3 SF slots too. The amax slots are also wired + // for K2 because the kernel READS them. + const bool needs_fp4_outputs = (mode == 0) || (mode == 2); + + if (rowwise) { + TORCH_CHECK(row_amax.is_cuda() && row_amax.is_contiguous(), + "row_amax must be a contiguous CUDA tensor"); + TORCH_CHECK(row_amax.scalar_type() == at::ScalarType::Float, "row_amax must be float32"); + TORCH_CHECK(row_amax.numel() == M, + "row_amax numel mismatch: expected M=", M, ", got ", row_amax.numel()); + out_te.set_amax(row_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(M)}); + + if (needs_fp4_outputs) { + TORCH_CHECK(q_row.is_cuda() && q_row.is_contiguous(), + "q_row must be a contiguous CUDA tensor"); + TORCH_CHECK(s_dec_row.is_cuda() && s_dec_row.is_contiguous(), + "s_dec_row must be a contiguous CUDA tensor"); + TORCH_CHECK(q_row.scalar_type() == at::ScalarType::Byte, + "q_row must be uint8 (FP4 packed)"); + TORCH_CHECK(s_dec_row.scalar_type() == at::ScalarType::Byte, + "s_dec_row must be uint8 (FP8 e4m3 raw bytes)"); + TORCH_CHECK(q_row.numel() == M * K / 2, + "q_row numel mismatch: expected M*K/2=", M * K / 2, ", got ", q_row.numel()); + TORCH_CHECK(s_dec_row.numel() == M * K / 16, + "s_dec_row numel mismatch: expected M*K/16=", M * K / 16, + ", got ", s_dec_row.numel()); + out_te.set_rowwise_data(q_row.data_ptr(), DType::kFloat4E2M1, in_shape); + out_te.set_rowwise_scale_inv( + s_dec_row.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(M), static_cast(K / 16)}); + } + } + if (columnwise) { + TORCH_CHECK(col_amax.is_cuda() && col_amax.is_contiguous(), + "col_amax must be a contiguous CUDA tensor"); + TORCH_CHECK(col_amax.scalar_type() == at::ScalarType::Float, "col_amax must be float32"); + TORCH_CHECK(col_amax.numel() == K, + "col_amax numel mismatch: expected K=", K, ", got ", col_amax.numel()); + out_te.set_columnwise_amax(col_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(K)}); + + if (needs_fp4_outputs) { + TORCH_CHECK(q_col.is_cuda() && q_col.is_contiguous(), + "q_col must be a contiguous CUDA tensor"); + TORCH_CHECK(s_dec_col.is_cuda() && s_dec_col.is_contiguous(), + "s_dec_col must be a contiguous CUDA tensor"); + TORCH_CHECK(q_col.scalar_type() == at::ScalarType::Byte, + "q_col must be uint8 (FP4 packed)"); + TORCH_CHECK(s_dec_col.scalar_type() == at::ScalarType::Byte, + "s_dec_col must be uint8 (FP8 e4m3 raw bytes)"); + TORCH_CHECK(q_col.numel() == K * M / 2, + "q_col numel mismatch: expected K*M/2=", K * M / 2, ", got ", q_col.numel()); + TORCH_CHECK(s_dec_col.numel() == K * M / 16, + "s_dec_col numel mismatch: expected K*M/16=", K * M / 16, + ", got ", s_dec_col.numel()); + out_te.set_columnwise_data( + q_col.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(K), static_cast(M)}); + out_te.set_columnwise_scale_inv( + s_dec_col.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(K), static_cast(M / 16)}); + } + } +} + +} // namespace + +// Production composite (K1 + K2 back-to-back). +void nvfp4_per_token_quantize(const at::Tensor& input, + at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, + at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, + bool rowwise, bool columnwise) { + TensorWrapper in_te; + TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); + assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, + q_col, s_dec_col, col_amax, + rowwise, columnwise, /*mode=*/0, in_te, out_te); + const auto stream = at::cuda::getCurrentCUDAStream(); + nvte_nvfp4_per_token_quantize(in_te.data(), nullptr, out_te.data(), stream); +} + +// K1-only (diagnostic / bench): populates only amax buffers. +void nvfp4_per_token_amax(const at::Tensor& input, at::Tensor row_amax, + at::Tensor col_amax, bool rowwise, bool columnwise) { + at::Tensor empty_u8; // not consumed by K1 + TensorWrapper in_te; + TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); + assemble_per_token_tensors(input, empty_u8, empty_u8, row_amax, + empty_u8, empty_u8, col_amax, + rowwise, columnwise, /*mode=*/1, in_te, out_te); + const auto stream = at::cuda::getCurrentCUDAStream(); + nvte_nvfp4_per_token_amax(in_te.data(), nullptr, out_te.data(), stream); +} + +// K2-only (diagnostic / bench): reads pre-filled amax buffers, emits FP4 + SFs. +void nvfp4_per_token_encode(const at::Tensor& input, + at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, + at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, + bool rowwise, bool columnwise) { + TensorWrapper in_te; + TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); + assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, + q_col, s_dec_col, col_amax, + rowwise, columnwise, /*mode=*/2, in_te, out_te); + const auto stream = at::cuda::getCurrentCUDAStream(); + nvte_nvfp4_per_token_encode(in_te.data(), nullptr, out_te.data(), stream); +} + +// Apply per-token post-scale to a GEMM output (see nvfp4_per_token.h for math). +void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, + const at::Tensor &row_amax_b) { + TORCH_CHECK(d.is_cuda() && d.is_contiguous(), "d must be a contiguous CUDA tensor"); + TORCH_CHECK(row_amax_a.is_cuda() && row_amax_a.is_contiguous(), + "row_amax_a must be a contiguous CUDA tensor"); + TORCH_CHECK(row_amax_b.is_cuda() && row_amax_b.is_contiguous(), + "row_amax_b must be a contiguous CUDA tensor"); + TORCH_CHECK(d.dim() == 2, "d must be 2D"); + TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bf16"); + TORCH_CHECK(row_amax_a.scalar_type() == at::ScalarType::Float, "row_amax_a must be fp32"); + TORCH_CHECK(row_amax_b.scalar_type() == at::ScalarType::Float, "row_amax_b must be fp32"); + + const int64_t M = d.size(0); + const int64_t N = d.size(1); + TORCH_CHECK(row_amax_a.numel() == M, + "row_amax_a numel mismatch: expected M=", M, ", got ", row_amax_a.numel()); + TORCH_CHECK(row_amax_b.numel() == N, + "row_amax_b numel mismatch: expected N=", N, ", got ", row_amax_b.numel()); + + const auto stream = at::cuda::getCurrentCUDAStream(); + + TensorWrapper d_te = makeTransformerEngineTensor( + d.data_ptr(), + std::vector{static_cast(M), static_cast(N)}, DType::kBFloat16); + TensorWrapper ra_te = makeTransformerEngineTensor( + row_amax_a.data_ptr(), std::vector{static_cast(M)}, DType::kFloat32); + TensorWrapper rb_te = makeTransformerEngineTensor( + row_amax_b.data_ptr(), std::vector{static_cast(N)}, DType::kFloat32); + + nvte_nvfp4_per_token_post_scale(d_te.data(), ra_te.data(), rb_te.data(), stream); +} + +// End-to-end NVFP4 per-token GEMM: swizzle compact SFs -> cuBLAS LT NVFP4 +// GEMM (operand amax pinned to 1.0 to cancel the 2688^2 inner-SF factor) -> +// per-row post-scale. beta must be 0.0. Math in nvfp4_per_token.h. +void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, + const at::Tensor &a_sf, const at::Tensor &b_sf, + const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, + at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, + int64_t k, double alpha, double beta) { + TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && + a_row_amax.is_cuda() && b_row_amax.is_cuda() && d.is_cuda() && + workspace.is_cuda(), + "All tensors must be CUDA tensors"); + TORCH_CHECK(a_data.is_contiguous() && b_data.is_contiguous() && a_sf.is_contiguous() && + b_sf.is_contiguous() && a_row_amax.is_contiguous() && + b_row_amax.is_contiguous() && d.is_contiguous() && workspace.is_contiguous(), + "All tensors must be contiguous"); + + TORCH_CHECK(a_data.scalar_type() == at::ScalarType::Byte, "a_data must be uint8 (FP4 packed)"); + TORCH_CHECK(b_data.scalar_type() == at::ScalarType::Byte, "b_data must be uint8 (FP4 packed)"); + TORCH_CHECK(a_sf.scalar_type() == at::ScalarType::Byte, "a_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(b_sf.scalar_type() == at::ScalarType::Byte, "b_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(a_row_amax.scalar_type() == at::ScalarType::Float, "a_row_amax must be float32"); + TORCH_CHECK(b_row_amax.scalar_type() == at::ScalarType::Float, "b_row_amax must be float32"); + TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); + TORCH_CHECK(workspace.scalar_type() == at::ScalarType::Byte, "workspace must be uint8"); + + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, + "a_data/b_data/d must be 2D"); + TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, + "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", + a_data.size(0), ", ", a_data.size(1), ")"); + TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, + "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", + b_data.size(0), ", ", b_data.size(1), ")"); + TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, + "), got (", d.size(0), ", ", d.size(1), ")"); + + TORCH_CHECK(k % 16 == 0, "k must be a multiple of 16 (NVFP4 inner SFVecSize)"); + TORCH_CHECK(a_sf.numel() == static_cast(m * k / 16), + "a_sf numel mismatch: expected M*K/16=", m * k / 16, ", got ", a_sf.numel()); + TORCH_CHECK(b_sf.numel() == static_cast(n * k / 16), + "b_sf numel mismatch: expected N*K/16=", n * k / 16, ", got ", b_sf.numel()); + TORCH_CHECK(a_row_amax.numel() == m, + "a_row_amax numel mismatch: expected M=", m, ", got ", a_row_amax.numel()); + TORCH_CHECK(b_row_amax.numel() == n, + "b_row_amax numel mismatch: expected N=", n, ", got ", b_row_amax.numel()); + + TORCH_CHECK(static_cast(beta) == 0.0f, + "nvfp4_per_token_gemm: beta != 0 not yet supported. Got beta=", beta); + + const auto stream = at::cuda::getCurrentCUDAStream(); + + const std::vector a_data_shape = {static_cast(m), static_cast(k)}; + const std::vector b_data_shape = {static_cast(n), static_cast(k)}; + const std::vector a_sf_shape = {static_cast(m), static_cast(k / 16)}; + const std::vector b_sf_shape = {static_cast(n), static_cast(k / 16)}; + + // Swizzled SF buffers (cuBLAS LT requires swizzled layout). + auto byte_opts = a_sf.options().dtype(at::kByte); + at::Tensor a_sf_swizzled = at::empty({a_sf.numel()}, byte_opts); + at::Tensor b_sf_swizzled = at::empty({b_sf.numel()}, byte_opts); + + { + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + in_nvte.set_rowwise_scale_inv(a_sf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + out_nvte.set_rowwise_scale_inv(a_sf_swizzled.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + { + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + in_nvte.set_rowwise_scale_inv(b_sf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + out_nvte.set_rowwise_scale_inv(b_sf_swizzled.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + + // Pin operand amaxes to 1.0 so cuBLAS-internal alpha cancels the 2688^2 + // inner-SF factor. Cache one fp32 "1.0" tensor per device to avoid the + // ~30-50us per-call cost of at::ones({1}) at small shapes. + static std::array s_amax_one_cache; + static std::array s_amax_one_init; + const int dev_idx = a_data.device().index(); + TORCH_CHECK(dev_idx >= 0 && dev_idx < static_cast(s_amax_one_cache.size()), + "nvfp4_per_token_gemm: unexpected device index ", dev_idx); + std::call_once(s_amax_one_init[dev_idx], [&]() { + auto fp32_opts = a_data.options().dtype(at::kFloat); + s_amax_one_cache[dev_idx] = at::ones({1}, fp32_opts); + }); + at::Tensor& amax_one = s_amax_one_cache[dev_idx]; + + // Assemble A's NVTE tensor: NVFP4_1D_SCALING + swizzled SF + amax=1.0. + TensorWrapper a_te(NVTE_NVFP4_1D_SCALING); + a_te.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + a_te.set_rowwise_scale_inv(a_sf_swizzled.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + a_te.set_amax(amax_one.data_ptr(), DType::kFloat32, std::vector{1}); + a_te.set_with_gemm_swizzled_scales(true); + + TensorWrapper b_te(NVTE_NVFP4_1D_SCALING); + b_te.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + b_te.set_rowwise_scale_inv(b_sf_swizzled.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + b_te.set_amax(amax_one.data_ptr(), DType::kFloat32, std::vector{1}); + b_te.set_with_gemm_swizzled_scales(true); + + TensorWrapper d_te = makeTransformerEngineTensor( + d.data_ptr(), + std::vector{static_cast(m), static_cast(n)}, + DType::kBFloat16); + + TensorWrapper workspace_te = makeTransformerEngineTensor( + workspace.data_ptr(), std::vector{static_cast(workspace.numel())}, + DType::kByte); + + // Operands SWAPPED so cuBLAS column-major D = op(B) @ op(A) matches the + // row-major (M, N) PyTorch expects. transa=T forced (NVFP4 is TN-only). + // C and D alias (no separate accumulator). + const float alpha_f = static_cast(alpha); + const float beta_f = static_cast(beta); + nvte_cublas_gemm_v2(/*transa=*/1, /*transb=*/0, &alpha_f, + b_te.data(), // cuBLAS-A := caller's B (N, K) + a_te.data(), // cuBLAS-B := caller's A (M, K) + &beta_f, d_te.data(), d_te.data(), workspace_te.data(), + /*config=*/nullptr, stream); + + // Per-row * per-col post-scale to recover C_true from D_cublas. + TensorWrapper ra_te = makeTransformerEngineTensor( + a_row_amax.data_ptr(), std::vector{static_cast(m)}, DType::kFloat32); + TensorWrapper rb_te = makeTransformerEngineTensor( + b_row_amax.data_ptr(), std::vector{static_cast(n)}, DType::kFloat32); + + nvte_nvfp4_per_token_post_scale(d_te.data(), ra_te.data(), rb_te.data(), stream); +} + +// Per-tensor twin of nvfp4_per_token_gemm: scalar amax goes through cuBLAS's +// own amax slot (no post-scale). Bench-only apples-to-apples baseline. +void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, + const at::Tensor &a_sf, const at::Tensor &b_sf, + const at::Tensor &a_amax, const at::Tensor &b_amax, + at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, + int64_t k, double alpha, double beta) { + TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && + a_amax.is_cuda() && b_amax.is_cuda() && d.is_cuda() && workspace.is_cuda(), + "All tensors must be CUDA tensors"); + TORCH_CHECK(a_data.is_contiguous() && b_data.is_contiguous() && a_sf.is_contiguous() && + b_sf.is_contiguous() && a_amax.is_contiguous() && b_amax.is_contiguous() && + d.is_contiguous() && workspace.is_contiguous(), + "All tensors must be contiguous"); + TORCH_CHECK(a_data.scalar_type() == at::ScalarType::Byte, "a_data must be uint8 (FP4 packed)"); + TORCH_CHECK(b_data.scalar_type() == at::ScalarType::Byte, "b_data must be uint8 (FP4 packed)"); + TORCH_CHECK(a_sf.scalar_type() == at::ScalarType::Byte, "a_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(b_sf.scalar_type() == at::ScalarType::Byte, "b_sf must be uint8 (FP8 e4m3)"); + TORCH_CHECK(a_amax.scalar_type() == at::ScalarType::Float, "a_amax must be float32"); + TORCH_CHECK(b_amax.scalar_type() == at::ScalarType::Float, "b_amax must be float32"); + TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); + TORCH_CHECK(workspace.scalar_type() == at::ScalarType::Byte, "workspace must be uint8"); + + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, + "a_data/b_data/d must be 2D"); + TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, + "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", + a_data.size(0), ", ", a_data.size(1), ")"); + TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, + "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", + b_data.size(0), ", ", b_data.size(1), ")"); + TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, + "), got (", d.size(0), ", ", d.size(1), ")"); + + TORCH_CHECK(k % 16 == 0, "k must be a multiple of 16 (NVFP4 inner SFVecSize)"); + TORCH_CHECK(a_sf.numel() == static_cast(m * k / 16), + "a_sf numel mismatch: expected M*K/16=", m * k / 16, ", got ", a_sf.numel()); + TORCH_CHECK(b_sf.numel() == static_cast(n * k / 16), + "b_sf numel mismatch: expected N*K/16=", n * k / 16, ", got ", b_sf.numel()); + TORCH_CHECK(a_amax.numel() == 1, "a_amax must be a scalar (numel=1), got ", a_amax.numel()); + TORCH_CHECK(b_amax.numel() == 1, "b_amax must be a scalar (numel=1), got ", b_amax.numel()); + + TORCH_CHECK(static_cast(beta) == 0.0f, + "nvfp4_per_tensor_gemm: beta != 0 not yet supported. Got beta=", beta); + + const auto stream = at::cuda::getCurrentCUDAStream(); + + const std::vector a_data_shape = {static_cast(m), static_cast(k)}; + const std::vector b_data_shape = {static_cast(n), static_cast(k)}; + const std::vector a_sf_shape = {static_cast(m), static_cast(k / 16)}; + const std::vector b_sf_shape = {static_cast(n), static_cast(k / 16)}; + + auto byte_opts = a_sf.options().dtype(at::kByte); + at::Tensor a_sf_swizzled = at::empty({a_sf.numel()}, byte_opts); + at::Tensor b_sf_swizzled = at::empty({b_sf.numel()}, byte_opts); + + { + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + in_nvte.set_rowwise_scale_inv(a_sf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + out_nvte.set_rowwise_scale_inv(a_sf_swizzled.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + { + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + in_nvte.set_rowwise_scale_inv(b_sf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + out_nvte.set_rowwise_scale_inv(b_sf_swizzled.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); + } + + // Per-tensor amaxes go in the amax slot; cuBLAS LT folds them into alpha. + TensorWrapper a_te(NVTE_NVFP4_1D_SCALING); + a_te.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); + a_te.set_rowwise_scale_inv(a_sf_swizzled.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + a_te.set_amax(a_amax.data_ptr(), DType::kFloat32, std::vector{1}); + a_te.set_with_gemm_swizzled_scales(true); + + TensorWrapper b_te(NVTE_NVFP4_1D_SCALING); + b_te.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); + b_te.set_rowwise_scale_inv(b_sf_swizzled.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + b_te.set_amax(b_amax.data_ptr(), DType::kFloat32, std::vector{1}); + b_te.set_with_gemm_swizzled_scales(true); + + TensorWrapper d_te = makeTransformerEngineTensor( + d.data_ptr(), + std::vector{static_cast(m), static_cast(n)}, + DType::kBFloat16); + + TensorWrapper workspace_te = makeTransformerEngineTensor( + workspace.data_ptr(), std::vector{static_cast(workspace.numel())}, + DType::kByte); + + // Operand swap: see nvfp4_per_token_gemm. + const float alpha_f = static_cast(alpha); + const float beta_f = static_cast(beta); + nvte_cublas_gemm_v2(/*transa=*/1, /*transb=*/0, &alpha_f, + b_te.data(), // cuBLAS-A := caller's B (N, K) + a_te.data(), // cuBLAS-B := caller's A (M, K) + &beta_f, d_te.data(), d_te.data(), workspace_te.data(), + /*config=*/nullptr, stream); + // No post-scale: per-tensor amaxes already folded into cuBLAS-internal alpha. +} + +// Grouped (multi-tensor) per-token quantize. Each direction takes 3 lists +// of per-split tensors; ``split_sections[i] = M_i`` (% 128, sum = sum_M). +// Disabled direction's lists are ignored. +namespace { + +void build_per_token_output_wrapper( + TensorWrapper& out_te, int64_t M_i, int64_t K, bool rowwise, bool columnwise, + const at::Tensor& q_row, const at::Tensor& s_dec_row, const at::Tensor& row_amax, + const at::Tensor& q_col, const at::Tensor& s_dec_col, const at::Tensor& col_amax) { + if (rowwise) { + TORCH_CHECK(q_row.is_cuda() && q_row.is_contiguous(), + "q_row must be a contiguous CUDA tensor"); + TORCH_CHECK(s_dec_row.is_cuda() && s_dec_row.is_contiguous(), + "s_dec_row must be a contiguous CUDA tensor"); + TORCH_CHECK(row_amax.is_cuda() && row_amax.is_contiguous(), + "row_amax must be a contiguous CUDA tensor"); + TORCH_CHECK(q_row.scalar_type() == at::ScalarType::Byte, "q_row must be uint8"); + TORCH_CHECK(s_dec_row.scalar_type() == at::ScalarType::Byte, "s_dec_row must be uint8"); + TORCH_CHECK(row_amax.scalar_type() == at::ScalarType::Float, "row_amax must be fp32"); + TORCH_CHECK(q_row.numel() == M_i * K / 2, "q_row numel mismatch for split: expected ", + M_i * K / 2, ", got ", q_row.numel()); + TORCH_CHECK(s_dec_row.numel() == M_i * K / 16, "s_dec_row numel mismatch for split"); + TORCH_CHECK(row_amax.numel() == M_i, "row_amax numel mismatch for split"); + out_te.set_rowwise_data( + q_row.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(M_i), static_cast(K)}); + out_te.set_rowwise_scale_inv( + s_dec_row.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(M_i), static_cast(K / 16)}); + out_te.set_amax(row_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(M_i)}); + } + if (columnwise) { + TORCH_CHECK(q_col.is_cuda() && q_col.is_contiguous(), + "q_col must be a contiguous CUDA tensor"); + TORCH_CHECK(s_dec_col.is_cuda() && s_dec_col.is_contiguous(), + "s_dec_col must be a contiguous CUDA tensor"); + TORCH_CHECK(col_amax.is_cuda() && col_amax.is_contiguous(), + "col_amax must be a contiguous CUDA tensor"); + TORCH_CHECK(q_col.scalar_type() == at::ScalarType::Byte, "q_col must be uint8"); + TORCH_CHECK(s_dec_col.scalar_type() == at::ScalarType::Byte, "s_dec_col must be uint8"); + TORCH_CHECK(col_amax.scalar_type() == at::ScalarType::Float, "col_amax must be fp32"); + TORCH_CHECK(q_col.numel() == K * M_i / 2, "q_col numel mismatch for split"); + TORCH_CHECK(s_dec_col.numel() == K * M_i / 16, "s_dec_col numel mismatch for split"); + TORCH_CHECK(col_amax.numel() == K, "col_amax numel mismatch for split"); + out_te.set_columnwise_data( + q_col.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(K), static_cast(M_i)}); + out_te.set_columnwise_scale_inv( + s_dec_col.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(K), static_cast(M_i / 16)}); + out_te.set_columnwise_amax(col_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(K)}); + } +} + +DType resolve_input_dtype(const at::Tensor& input) { + if (input.scalar_type() == at::ScalarType::BFloat16) return DType::kBFloat16; + if (input.scalar_type() == at::ScalarType::Float) return DType::kFloat32; + if (input.scalar_type() == at::ScalarType::Half) return DType::kFloat16; + TORCH_CHECK(false, "input dtype must be bf16/fp16/fp32, got ", input.scalar_type()); + return DType::kBFloat16; // unreachable +} + +} // namespace + +void nvfp4_per_token_group_quantize( + const at::Tensor& input, const std::vector& split_sections, + std::vector q_row_list, std::vector s_dec_row_list, + std::vector row_amax_list, std::vector q_col_list, + std::vector s_dec_col_list, std::vector col_amax_list, + bool rowwise, bool columnwise) { + TORCH_CHECK(rowwise || columnwise, + "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(input.is_cuda() && input.is_contiguous(), + "input must be a contiguous CUDA tensor"); + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + const int64_t sum_M = input.size(0); + const int64_t K = input.size(1); + const size_t num_tensors = split_sections.size(); + TORCH_CHECK(num_tensors > 0, "split_sections must not be empty"); + + // Sum + 64-multiple constraint. + int64_t acc = 0; + for (size_t i = 0; i < num_tensors; ++i) { + TORCH_CHECK(split_sections[i] >= 0, "split_sections[", i, "] must be non-negative"); + TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, + "] = ", split_sections[i], " must be a multiple of 64"); + acc += split_sections[i]; + } + TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, " must equal input.size(0) = ", sum_M); + + if (rowwise) { + TORCH_CHECK(q_row_list.size() == num_tensors, "q_row_list size mismatch"); + TORCH_CHECK(s_dec_row_list.size() == num_tensors, "s_dec_row_list size mismatch"); + TORCH_CHECK(row_amax_list.size() == num_tensors, "row_amax_list size mismatch"); + } + if (columnwise) { + TORCH_CHECK(q_col_list.size() == num_tensors, "q_col_list size mismatch"); + TORCH_CHECK(s_dec_col_list.size() == num_tensors, "s_dec_col_list size mismatch"); + TORCH_CHECK(col_amax_list.size() == num_tensors, "col_amax_list size mismatch"); + } + + const DType in_dtype = resolve_input_dtype(input); + const auto stream = at::cuda::getCurrentCUDAStream(); + + TensorWrapper in_te = makeTransformerEngineTensor( + input.data_ptr(), + std::vector{static_cast(sum_M), static_cast(K)}, in_dtype); + + // One TensorWrapper per split; raw NVTETensor handles go into `handles`. + std::vector wrappers; + wrappers.reserve(num_tensors); + std::vector handles; + handles.reserve(num_tensors); + std::vector split_sections_sz(num_tensors); + + at::Tensor empty_dummy; // for slots we don't populate + for (size_t i = 0; i < num_tensors; ++i) { + const int64_t M_i = split_sections[i]; + split_sections_sz[i] = static_cast(M_i); + wrappers.emplace_back(NVTE_NVFP4_1D_SCALING); + if (M_i == 0) { + handles.push_back(wrappers.back().data()); + continue; // empty split is allowed (skipped inside the kernel) + } + build_per_token_output_wrapper( + wrappers.back(), M_i, K, rowwise, columnwise, + rowwise ? q_row_list[i] : empty_dummy, + rowwise ? s_dec_row_list[i] : empty_dummy, + rowwise ? row_amax_list[i] : empty_dummy, + columnwise ? q_col_list[i] : empty_dummy, + columnwise ? s_dec_col_list[i] : empty_dummy, + columnwise ? col_amax_list[i] : empty_dummy); + handles.push_back(wrappers.back().data()); + } + + nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), + split_sections_sz.data(), num_tensors, rowwise, + columnwise, stream); +} + +// Amax-only grouped variant (K1 only); for allReduce-before-cast flows. +void nvfp4_per_token_group_amax(const at::Tensor& input, + const std::vector& split_sections, + std::vector row_amax_list, + std::vector col_amax_list, bool rowwise, + bool columnwise) { + TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(input.is_cuda() && input.is_contiguous(), + "input must be a contiguous CUDA tensor"); + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + const int64_t sum_M = input.size(0); + const int64_t K = input.size(1); + const size_t num_tensors = split_sections.size(); + TORCH_CHECK(num_tensors > 0, "split_sections must not be empty"); + int64_t acc = 0; + for (size_t i = 0; i < num_tensors; ++i) { + TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, + "] must be a multiple of 64"); + acc += split_sections[i]; + } + TORCH_CHECK(acc == sum_M, "sum(split_sections) must equal input.size(0)"); + if (rowwise) TORCH_CHECK(row_amax_list.size() == num_tensors, "row_amax_list size mismatch"); + if (columnwise) + TORCH_CHECK(col_amax_list.size() == num_tensors, "col_amax_list size mismatch"); + + const DType in_dtype = resolve_input_dtype(input); + const auto stream = at::cuda::getCurrentCUDAStream(); + + TensorWrapper in_te = makeTransformerEngineTensor( + input.data_ptr(), + std::vector{static_cast(sum_M), static_cast(K)}, in_dtype); + + std::vector wrappers; + wrappers.reserve(num_tensors); + std::vector handles; + handles.reserve(num_tensors); + std::vector split_sections_sz(num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + const int64_t M_i = split_sections[i]; + split_sections_sz[i] = static_cast(M_i); + wrappers.emplace_back(NVTE_NVFP4_1D_SCALING); + if (M_i == 0) { + handles.push_back(wrappers.back().data()); + continue; + } + if (rowwise) { + const at::Tensor& ra = row_amax_list[i]; + TORCH_CHECK(ra.is_cuda() && ra.scalar_type() == at::ScalarType::Float, "bad row_amax"); + TORCH_CHECK(ra.numel() == M_i, "row_amax numel mismatch"); + wrappers.back().set_amax(ra.data_ptr(), DType::kFloat32, + std::vector{static_cast(M_i)}); + } + if (columnwise) { + const at::Tensor& ca = col_amax_list[i]; + TORCH_CHECK(ca.is_cuda() && ca.scalar_type() == at::ScalarType::Float, "bad col_amax"); + TORCH_CHECK(ca.numel() == K, "col_amax numel mismatch"); + wrappers.back().set_columnwise_amax(ca.data_ptr(), DType::kFloat32, + std::vector{static_cast(K)}); + } + handles.push_back(wrappers.back().data()); + } + + nvte_group_nvfp4_per_token_amax(in_te.data(), handles.data(), split_sections_sz.data(), + num_tensors, rowwise, columnwise, stream); +} + +// BULK grouped per-token quantize: alloc + view + dispatch in ONE C++ call. +// Returns 6 per-split tensor lists (s_dec_* pre-cast to Float8_e4m3fn). +// Byte-equal to the prior Python wrap (saves ~70-90us at N=8). +std::tuple, std::vector, std::vector, + std::vector, std::vector, std::vector> +nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, + const std::vector& split_sections, bool rowwise, + bool columnwise) { + // Validation mirrors _validate_per_token_group_input in Python. + TORCH_CHECK(rowwise || columnwise, + "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + TORCH_CHECK(input.is_contiguous(), "x_concat must be contiguous (row-major)"); + TORCH_CHECK(input.dim() == 2, + "nvfp4_per_token_group_quantize expects a 2D input, got ", input.dim(), "D"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, + "Per-token grouped kernel is bf16-only; got dtype ", input.scalar_type()); + + const int64_t sum_M = input.size(0); + const int64_t K = input.size(1); + constexpr int64_t kPerTokenTile = 128; + constexpr int64_t kBlockK = 16; + + TORCH_CHECK(K % kPerTokenTile == 0, + "Per-token grouped kernel requires K % ", kPerTokenTile, " == 0; got K=", K); + + const size_t num_tensors = split_sections.size(); + TORCH_CHECK(num_tensors > 0, "split_sections must not be empty"); + TORCH_CHECK(num_tensors <= 64, + "num_tensors must be <= 64 (kernel arg-struct cap); got ", num_tensors); + + int64_t acc = 0; + for (size_t i = 0; i < num_tensors; ++i) { + const int64_t M_i = split_sections[i]; + TORCH_CHECK(M_i > 0, "split_sections[", i, "] must be > 0, got ", M_i); + TORCH_CHECK(M_i % kPerTokenTile == 0, "split_sections[", i, "] = ", M_i, + " must be a multiple of ", kPerTokenTile); + acc += M_i; + } + TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, + " must equal input.size(0) = ", sum_M); + + // Bulk allocation: one at::empty per output type, covers all splits. + auto opts_u8 = input.options().dtype(at::kByte); + auto opts_f32 = input.options().dtype(at::kFloat); + + at::Tensor q_row_bulk, s_dec_row_bulk, row_amax_bulk; + at::Tensor q_col_bulk, s_dec_col_bulk, col_amax_bulk; + + if (rowwise) { + q_row_bulk = at::empty({sum_M, K / 2}, opts_u8); + s_dec_row_bulk = at::empty({sum_M, K / kBlockK}, opts_u8); + row_amax_bulk = at::empty({sum_M}, opts_f32); + } + if (columnwise) { + q_col_bulk = at::empty({K * sum_M / 2}, opts_u8); + s_dec_col_bulk = at::empty({K * sum_M / kBlockK}, opts_u8); + col_amax_bulk = at::empty({static_cast(num_tensors), K}, opts_f32); + } + + // Per-split views built in C++; s_dec_* kept in both uint8 (for binding) + // and fp8_e4m3fn (returned to Python directly). + std::vector q_row_list, s_dec_row_u8_list, row_amax_list; + std::vector q_col_list, s_dec_col_u8_list, col_amax_list; + std::vector s_dec_row_fp8_list, s_dec_col_fp8_list; + if (rowwise) { + q_row_list.reserve(num_tensors); + s_dec_row_u8_list.reserve(num_tensors); + row_amax_list.reserve(num_tensors); + s_dec_row_fp8_list.reserve(num_tensors); + } + if (columnwise) { + q_col_list.reserve(num_tensors); + s_dec_col_u8_list.reserve(num_tensors); + col_amax_list.reserve(num_tensors); + s_dec_col_fp8_list.reserve(num_tensors); + } + + int64_t m_off = 0; + for (size_t i = 0; i < num_tensors; ++i) { + const int64_t M_i = split_sections[i]; + if (rowwise) { + q_row_list.emplace_back(q_row_bulk.narrow(0, m_off, M_i)); + s_dec_row_u8_list.emplace_back(s_dec_row_bulk.narrow(0, m_off, M_i)); + row_amax_list.emplace_back(row_amax_bulk.narrow(0, m_off, M_i)); + s_dec_row_fp8_list.emplace_back(s_dec_row_u8_list.back().view(at::kFloat8_e4m3fn)); + } + if (columnwise) { + auto q_col_flat = q_col_bulk.narrow(0, K * m_off / 2, K * M_i / 2); + q_col_list.emplace_back(q_col_flat.view({K, M_i / 2})); + auto s_dec_col_flat = + s_dec_col_bulk.narrow(0, K * m_off / kBlockK, K * M_i / kBlockK); + s_dec_col_u8_list.emplace_back(s_dec_col_flat.view({K, M_i / kBlockK})); + col_amax_list.emplace_back(col_amax_bulk.select(0, static_cast(i))); + s_dec_col_fp8_list.emplace_back(s_dec_col_u8_list.back().view(at::kFloat8_e4m3fn)); + } + m_off += M_i; + } + + // Dispatch K1+K2 grouped kernel via the same C-API the thin entry uses. + const auto stream = at::cuda::getCurrentCUDAStream(); + TensorWrapper in_te = makeTransformerEngineTensor( + input.data_ptr(), + std::vector{static_cast(sum_M), static_cast(K)}, + DType::kBFloat16); + + std::vector wrappers; + wrappers.reserve(num_tensors); + std::vector handles; + handles.reserve(num_tensors); + std::vector split_sections_sz(num_tensors); + + at::Tensor empty_dummy; + for (size_t i = 0; i < num_tensors; ++i) { + const int64_t M_i = split_sections[i]; + split_sections_sz[i] = static_cast(M_i); + wrappers.emplace_back(NVTE_NVFP4_1D_SCALING); + build_per_token_output_wrapper( + wrappers.back(), M_i, K, rowwise, columnwise, + rowwise ? q_row_list[i] : empty_dummy, + rowwise ? s_dec_row_u8_list[i] : empty_dummy, + rowwise ? row_amax_list[i] : empty_dummy, + columnwise ? q_col_list[i] : empty_dummy, + columnwise ? s_dec_col_u8_list[i] : empty_dummy, + columnwise ? col_amax_list[i] : empty_dummy); + handles.push_back(wrappers.back().data()); + } + + nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), + split_sections_sz.data(), num_tensors, rowwise, + columnwise, stream); + + return std::make_tuple(std::move(q_row_list), std::move(s_dec_row_fp8_list), + std::move(row_amax_list), std::move(q_col_list), + std::move(s_dec_col_fp8_list), std::move(col_amax_list)); +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a4571c64e2..f38a6532a4 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -342,6 +342,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("compute_amax", &transformer_engine::pytorch::compute_amax, "Compute absolute max value in tensor", py::arg("input"), py::arg("amax"), py::call_guard()); + m.def("hadamard_transform_amax", &transformer_engine::pytorch::hadamard_transform_amax, + "K1 of the NVFP4Quantizer RHT+post_rht_amax path: rowwise (pre-RHT) + " + "columnwise (RHT(input.T)) amax in one launch. Bench-only entry.", + py::arg("input"), py::arg("rowwise_amax"), py::arg("columnwise_amax"), + py::arg("rht_matrix_random_sign_mask"), + py::call_guard()); m.def("fused_amax_and_scale_update_after_reduction", &transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction", @@ -390,6 +396,61 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output_rowwise"), py::arg("output_colwise"), py::arg("scale_inv_rowwise"), py::arg("scale_inv_colwise"), py::arg("rows"), py::arg("cols"), py::arg("start_offset"), py::call_guard()); + m.def("nvfp4_per_token_quantize", &transformer_engine::pytorch::nvfp4_per_token_quantize, + "NVFP4 per-token cast (composite K1 amax + K2 encode). Same FP4 + 1x16 " + "e4m3 SF layout as per-tensor, but outer amax is per-row/per-col. " + "Requires bf16 input, M % 128 == 0, K % 128 == 0.", + py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), + py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), + py::arg("rowwise"), py::arg("columnwise")); + m.def("nvfp4_per_token_amax", + &transformer_engine::pytorch::nvfp4_per_token_amax, + "K1-only: per-row/per-col outer amax via TMA + atomicMax. Bench/diagnostic.", + py::arg("input"), py::arg("row_amax"), py::arg("col_amax"), + py::arg("rowwise"), py::arg("columnwise")); + m.def("nvfp4_per_token_encode", + &transformer_engine::pytorch::nvfp4_per_token_encode, + "K2-only: FP4 + e4m3 SF encode given pre-filled amax buffers. Bench/diagnostic.", + py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), + py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), + py::arg("rowwise"), py::arg("columnwise")); + m.def("nvfp4_per_token_post_scale", &transformer_engine::pytorch::nvfp4_per_token_post_scale, + "Apply d[i,j] *= row_amax_a[i] * row_amax_b[j] in-place on bf16 D.", + py::arg("d"), py::arg("row_amax_a"), py::arg("row_amax_b")); + m.def("nvfp4_per_token_gemm", &transformer_engine::pytorch::nvfp4_per_token_gemm, + "End-to-end NVFP4 per-token GEMM: swizzle compact SFs, cuBLAS LT NVFP4 " + "GEMM, then row*col post-scale to recover C = A @ B^T. beta must be 0.", + py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), + py::arg("a_row_amax"), py::arg("b_row_amax"), py::arg("d"), + py::arg("workspace"), py::arg("m"), py::arg("n"), py::arg("k"), + py::arg("alpha"), py::arg("beta")); + m.def("nvfp4_per_tensor_gemm", &transformer_engine::pytorch::nvfp4_per_tensor_gemm, + "Skinny prod NVFP4 GEMM twin of nvfp4_per_token_gemm: per-tensor amaxes " + "folded into cuBLAS alpha, no trailing post-scale. Bench-only.", + py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), + py::arg("a_amax"), py::arg("b_amax"), py::arg("d"), + py::arg("workspace"), py::arg("m"), py::arg("n"), py::arg("k"), + py::arg("alpha"), py::arg("beta")); + m.def("nvfp4_per_token_group_quantize", + &transformer_engine::pytorch::nvfp4_per_token_group_quantize, + "Grouped (multi-tensor) NVFP4 per-token cast: K1 + K2 across <= 64 splits " + "of a single (sum_M, K) input. Byte-equal to a for-loop of single-tensor.", + py::arg("input"), py::arg("split_sections"), py::arg("q_row_list"), + py::arg("s_dec_row_list"), py::arg("row_amax_list"), py::arg("q_col_list"), + py::arg("s_dec_col_list"), py::arg("col_amax_list"), py::arg("rowwise"), + py::arg("columnwise")); + m.def("nvfp4_per_token_group_amax", + &transformer_engine::pytorch::nvfp4_per_token_group_amax, + "K1-only variant of nvfp4_per_token_group_quantize: only fills amax slots.", + py::arg("input"), py::arg("split_sections"), py::arg("row_amax_list"), + py::arg("col_amax_list"), py::arg("rowwise"), py::arg("columnwise")); + m.def("nvfp4_per_token_group_quantize_bulk", + &transformer_engine::pytorch::nvfp4_per_token_group_quantize_bulk, + "Bulk grouped quantize: allocates per-split buffers + view-slices inside " + "the binding (one pybind hop instead of 1 + 6N), then dispatches the K1+K2 " + "kernel. Returns 6 per-split tensor lists; empty for disabled directions.", + py::arg("input"), py::arg("split_sections"), py::arg("rowwise"), + py::arg("columnwise")); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index c02d2ec616..80c088a24d 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -29,6 +29,38 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); } +// Thin pybind for nvte_hadamard_transform_amax: K1 of the production +// NVFP4Quantizer(with_rht, with_post_rht_amax) path. Computes rowwise (pre-RHT) +// and columnwise (RHT(input.T)) amax in one launch. Bench-only entry. +void hadamard_transform_amax(const at::Tensor& tensor, + at::Tensor& rowwise_amax, + at::Tensor& columnwise_amax, + int64_t rht_matrix_random_sign_mask) { + auto input_tensor = tensor.contiguous(); + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + + TORCH_CHECK(rowwise_amax.scalar_type() == at::kFloat, + "rowwise_amax must be a float tensor"); + TORCH_CHECK(rowwise_amax.numel() == 1, + "rowwise_amax must have exactly one element"); + TORCH_CHECK(columnwise_amax.scalar_type() == at::kFloat, + "columnwise_amax must be a float tensor"); + TORCH_CHECK(columnwise_amax.numel() == 1, + "columnwise_amax must have exactly one element"); + + // Mirror NVFP4Quantizer: empty NVFP4_1D_SCALING with two amax slots. + TensorWrapper te_output(NVTE_NVFP4_1D_SCALING); + te_output.set_amax(rowwise_amax.data_ptr(), DType::kFloat32, + std::vector{1}); + te_output.set_columnwise_amax(columnwise_amax.data_ptr(), DType::kFloat32, + std::vector{1}); + + nvte_hadamard_transform_amax(te_input.data(), te_output.data(), + /*random_sign_mask=*/0, + static_cast(rht_matrix_random_sign_mask), + at::cuda::getCurrentCUDAStream()); +} + void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer, std::vector amax_histories, std::vector scales, diff --git a/transformer_engine/pytorch/custom_recipes/__init__.py b/transformer_engine/pytorch/custom_recipes/__init__.py index f115ffe743..6d21422bb3 100644 --- a/transformer_engine/pytorch/custom_recipes/__init__.py +++ b/transformer_engine/pytorch/custom_recipes/__init__.py @@ -3,3 +3,35 @@ # See LICENSE for license information. """Experimental features and APIs.""" + +# Per-token NVFP4: per-row outer + 1x16 e4m3 inner SF; cuBLAS LT NVFP4 GEMM +# with operand amaxes pinned to 1.0 and a trailing row-amax post-scale. +# See quantization_nvfp4_per_token.py / gemm_nvfp4_per_token.py for the math. +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token import ( + NVFP4QuantizerPerTokenRef, + RefNVFP4TensorPerToken, + nvfp4_per_token_amax, + nvfp4_per_token_encode, + nvfp4_per_token_quantize, +) +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token_group import ( + nvfp4_per_token_group_quantize, +) +from transformer_engine.pytorch.custom_recipes.gemm_nvfp4_per_token import ( + dequantize_nvfp4_per_token, + nvfp4_per_token_gemm, + nvfp4_per_token_gemm_dequant, +) + + +__all__ = [ + "NVFP4QuantizerPerTokenRef", + "RefNVFP4TensorPerToken", + "nvfp4_per_token_quantize", + "nvfp4_per_token_group_quantize", + "nvfp4_per_token_amax", + "nvfp4_per_token_encode", + "dequantize_nvfp4_per_token", + "nvfp4_per_token_gemm", + "nvfp4_per_token_gemm_dequant", +] diff --git a/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py b/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py new file mode 100644 index 0000000000..a35bca3232 --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py @@ -0,0 +1,206 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Reference + production GEMM for the NVFP4 per-token quantization scheme. + +Per-token GEMM reuses cuBLAS LT NVFP4 (no TE fork) + a trailing row-amax +post-scale. Each side is a (data, scale, row_amax) triple matching what +tex.nvfp4_per_token_quantize emits. See include/transformer_engine/nvfp4_per_token.h. +""" + +from __future__ import annotations + +from typing import Optional + +import torch + +# get_cublas_workspace is imported lazily inside nvfp4_per_token_gemm to +# avoid a circular import with cpp_extensions.gemm at module load time. +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import cast_from_fp4x2 +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token import ( + BLOCK_K, + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + _AMAX_FLOOR, + RefNVFP4TensorPerToken, +) + + +__all__ = [ + "dequantize_nvfp4_per_token", + "nvfp4_per_token_gemm_dequant", + "nvfp4_per_token_gemm", +] + + +# Reference: dequantize + reference matmul. + +def _validate_per_token_triple( + data: torch.Tensor, scale: torch.Tensor, row_amax: torch.Tensor, side: str +) -> int: + """Sanity-check one (data, scale, row_amax) triple; return K.""" + if data.ndim != 2 or scale.ndim != 2 or row_amax.ndim != 1: + raise ValueError( + f"{side}: expected 2D data/scale + 1D row_amax, got dims " + f"data={data.ndim}, scale={scale.ndim}, row_amax={row_amax.ndim}" + ) + rows = data.shape[0] + K = data.shape[1] * 2 # FP4 packs 2 values/byte. + if K % BLOCK_K != 0: + raise ValueError(f"{side}: K={K} must be a multiple of BLOCK_K={BLOCK_K}") + if scale.shape != (rows, K // BLOCK_K): + raise ValueError( + f"{side}: scale shape {tuple(scale.shape)} != ({rows}, {K // BLOCK_K})" + ) + if row_amax.shape != (rows,): + raise ValueError( + f"{side}: row_amax shape {tuple(row_amax.shape)} != ({rows},)" + ) + return K + + +def dequantize_nvfp4_per_token( + data: torch.Tensor, scale: torch.Tensor, row_amax: torch.Tensor +) -> torch.Tensor: + """Dequantize a per-token NVFP4 (data, scale, row_amax) triple to fp32. + + x[i, k] = code[i, k] * s_dec[i, k//16] * row_amax[i] / (FP4_MAX * E4M3_MAX). + """ + K = _validate_per_token_triple(data, scale, row_amax, "dequant") + rows = data.shape[0] + + codes = data.contiguous().view(dtype=torch.uint8) + qf = cast_from_fp4x2(codes, torch.float32) + + if scale.dtype == torch.float8_e4m3fn: + s_dec = scale.to(torch.float32) + else: + s_dec = scale.view(torch.float8_e4m3fn).to(torch.float32) + + inv_outer = row_amax.to(torch.float32) / (FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX) + per_block_decode = s_dec * inv_outer.unsqueeze(-1) + per_elem_decode = per_block_decode.repeat_interleave(BLOCK_K, dim=1) + assert per_elem_decode.shape == (rows, K) + return qf * per_elem_decode + + +def nvfp4_per_token_gemm_dequant( + a_data: torch.Tensor, + a_scale: torch.Tensor, + a_row_amax: torch.Tensor, + b_data: torch.Tensor, + b_scale: torch.Tensor, + b_row_amax: torch.Tensor, + *, + out_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Reference C = A @ B^T via dequant-then-fp32-matmul. + + Agrees with the cuBLAS LT path at TF32 precision (~1e-3 relative). + Exists as executable docs of the math chain and a sanity oracle. + """ + K_a = _validate_per_token_triple(a_data, a_scale, a_row_amax, "A") + K_b = _validate_per_token_triple(b_data, b_scale, b_row_amax, "B") + if K_a != K_b: + raise ValueError(f"K mismatch between A and B: {K_a} vs {K_b}") + + a_fp32 = dequantize_nvfp4_per_token(a_data, a_scale, a_row_amax) + b_fp32 = dequantize_nvfp4_per_token(b_data, b_scale, b_row_amax) + c = a_fp32 @ b_fp32.t() + return c.to(out_dtype) + + +# Production wrapper: cuBLAS LT NVFP4 GEMM + per-token post-scale. + +def nvfp4_per_token_gemm( + a_data: torch.Tensor, + a_scale: torch.Tensor, + a_row_amax: torch.Tensor, + b_data: torch.Tensor, + b_scale: torch.Tensor, + b_row_amax: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + alpha: float = 1.0, + beta: float = 0.0, + out_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Production C = alpha * (A @ B^T) via cuBLAS LT NVFP4 + per-token post-scale. + + Binding swizzles compact SFs in-flight, runs cuBLAS LT NVFP4 with operand + amaxes pinned to 1.0, then applies the row_amax_A * row_amax_B post-scale. + Output is bf16 (cuBLAS LT NVFP4 locks D to bf16/fp32); beta != 0 unsupported. + """ + import transformer_engine_torch as tex # type: ignore + + K_a = _validate_per_token_triple(a_data, a_scale, a_row_amax, "A") + K_b = _validate_per_token_triple(b_data, b_scale, b_row_amax, "B") + if K_a != K_b: + raise ValueError(f"K mismatch between A and B: {K_a} vs {K_b}") + K = K_a + M = a_data.shape[0] + N = b_data.shape[0] + + if K % 16 != 0: + raise ValueError(f"K must be a multiple of 16 (got K={K})") + # cuBLAS LT NVFP4 SF buffer is padded to (roundup(rows, 128), roundup(K/16, 4)). + # Our compact quantize emits (rows, K/16); SF padding is a TODO so reject M/N < 128. + if M < 128 or M % 128 != 0: + raise ValueError( + f"M must be a multiple of 128 (got M={M}); SF padding is a TODO." + ) + if N < 128 or N % 128 != 0: + raise ValueError( + f"N must be a multiple of 128 (got N={N}); SF padding is a TODO." + ) + if a_data.device != b_data.device: + raise ValueError( + f"A and B must be on the same device (got {a_data.device} vs {b_data.device})" + ) + device = a_data.device + + if out is None: + out_bf16 = torch.empty((M, N), dtype=torch.bfloat16, device=device) + else: + if out.shape != (M, N): + raise ValueError( + f"out shape {tuple(out.shape)} != ({M}, {N})" + ) + if out.dtype != torch.bfloat16: + raise ValueError( + f"out dtype must be bf16 for in-place use, got {out.dtype}. " + "(The binding produces bf16; pass `out=None` for non-bf16 dtypes " + "and the result will be cast at the end.)" + ) + out_bf16 = out + + if float(beta) != 0.0: + raise ValueError( + f"nvfp4_per_token_gemm: beta != 0 not yet supported, got beta={beta}. " + "Use beta=0 and accumulate outside the call if needed." + ) + + a_data_u8 = a_data.contiguous().view(dtype=torch.uint8) + b_data_u8 = b_data.contiguous().view(dtype=torch.uint8) + + # Binding expects uint8 SFs (accepts both e4m3 view and raw uint8 storage). + a_scale_u8 = a_scale.contiguous().view(dtype=torch.uint8) + b_scale_u8 = b_scale.contiguous().view(dtype=torch.uint8) + a_scale_u8_flat = a_scale_u8.reshape(-1) + b_scale_u8_flat = b_scale_u8.reshape(-1) + + a_row_amax_f32 = a_row_amax.to(torch.float32).contiguous() + b_row_amax_f32 = b_row_amax.to(torch.float32).contiguous() + + # Lazy import to break the cpp_extensions.gemm circular import. + from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace + workspace = get_cublas_workspace(device.index, ub=False, grouped_gemm=False) + + tex.nvfp4_per_token_gemm( + a_data_u8, b_data_u8, a_scale_u8_flat, b_scale_u8_flat, + a_row_amax_f32, b_row_amax_f32, out_bf16, workspace, + M, N, K, float(alpha), float(beta), + ) + + return out_bf16 if out_dtype is torch.bfloat16 else out_bf16.to(out_dtype) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py new file mode 100644 index 0000000000..bd09f8970f --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py @@ -0,0 +1,386 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import dataclasses +from typing import Optional, Tuple + +import torch + +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import cast_to_fp4x2 + +# Inner sub-block size along K is fixed by the NVFP4 spec (one E4M3 +# ``s_dec`` per 16 FP4 samples); only the outer-amax granularity changes +# between per-token / per-tensor / blocked / 2D. +BLOCK_K: int = 16 + +# E2M1 / E4M3 numeric extrema (matches ``TypeExtrema`` in core_nvfp4.cuh). +FLOAT4_E2M1_MAX: float = 6.0 +FLOAT8_E4M3_MAX: float = 448.0 + +# Matches the kernel's ``fmaxf(row_amax, 1e-12f)`` clamp on the divisor of +# ``compute_global_encode_scaling_factor_FP4``. +_AMAX_FLOOR: float = 1e-12 + + +@dataclasses.dataclass +class RefNVFP4TensorPerToken: + """Container for the per-token reference output. + + Attributes + ---------- + data: + Packed rowwise FP4 bytes, ``(M, N // 2)`` ``uint8``. + scale: + Per-1x16-block rowwise decode scale (E4M3), ``(M, N // 16)`` + ``float8_e4m3fn``. + row_amax: + Per-row outer amax, ``(M,)`` ``float32``. This replaces the + per-tensor path's single-scalar ``amax`` and the blocked path's + per-window ``window_amax``. + columnwise_data, columnwise_scale, col_amax: + Their columnwise (transposed) counterparts. Shapes are + ``(N, M // 2)``, ``(N, M // 16)``, and ``(N,)`` respectively. + ``None`` if columnwise was not requested. + """ + + data: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + row_amax: Optional[torch.Tensor] = None + columnwise_data: Optional[torch.Tensor] = None + columnwise_scale: Optional[torch.Tensor] = None + col_amax: Optional[torch.Tensor] = None + + +class NVFP4QuantizerPerTokenRef: + """Pure-PyTorch reference for the NVFP4 per-token cast kernel. + + Constructor takes the two output-direction switches (``rowwise`` and + ``columnwise``). RHT, 2D scaling, and stochastic rounding are not + exposed because the per-token CUDA kernel does not implement them + (the per-token path is target-shape simple-and-fast: per-row outer + + 1x16 inner SF, nothing else). + + The arithmetic chain (``S_enc``, ``s_dec``, ``block_scale``, FP4 cast) + matches ``NVFP4Quantizer1x64Ref`` / ``NVFP4QuantizerBlockedRef``; + only the outer-amax granularity differs: + + * 1x64Ref / BlockedRef : one outer amax per ``OUTER_K``-K-window + * **PerTokenRef** : one outer amax per row (full K window) + """ + + def __init__( + self, + rowwise: bool = True, + columnwise: bool = False, + ) -> None: + if not rowwise and not columnwise: + raise ValueError("At least one of rowwise / columnwise must be True.") + self.rowwise = rowwise + self.columnwise = columnwise + + def _quantize_2d(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run the per-token reference math on a 2D input along its trailing dim. + + Returns ``(qx, sx, row_amax)`` where ``qx`` is ``(M, N // 2)`` + ``uint8``, ``sx`` is ``(M, N // BLOCK_K)`` ``float8_e4m3fn``, + and ``row_amax`` is ``(M,)`` ``float32``. + + The columnwise pass is implemented by calling this routine on + ``x.transpose(0, 1).contiguous()``. + """ + if x.ndim != 2: + raise ValueError(f"NVFP4QuantizerPerTokenRef expects a 2D tensor, got {x.ndim}D") + M, N = x.shape + if N % BLOCK_K != 0: + raise ValueError(f"N={N} must be a multiple of BLOCK_K={BLOCK_K}") + + device = x.device + fp32_max = torch.tensor(torch.finfo(torch.float32).max, device=device, dtype=torch.float32) + fp4_max = torch.tensor(FLOAT4_E2M1_MAX, device=device, dtype=torch.float32) + fp8_max = torch.tensor(FLOAT8_E4M3_MAX, device=device, dtype=torch.float32) + + n_blk = N // BLOCK_K + x_fp32 = x.to(torch.float32).contiguous() + x_blk = x_fp32.view(M, n_blk, BLOCK_K) + + # Outer = whole row. The kernel applies ``fmaxf(row_amax, 1e-12f)`` + # to the divisor; do the same here. + row_amax = torch.amax(torch.abs(x_fp32), dim=-1) # (M,) fp32 -- raw, pre-floor + row_amax_safe = torch.clamp(row_amax, min=_AMAX_FLOOR).unsqueeze(-1) # (M, 1) + + # Same ``compute_global_encode_scaling_factor_FP4`` form as the + # per-tensor / blocked paths (just with ``row_amax`` instead of + # ``window_amax`` / ``global_amax``). + S_enc_row = (fp8_max * fp4_max) / row_amax_safe # (M, 1) + S_enc_row = torch.minimum(S_enc_row, fp32_max) + S_enc_row = torch.where( + (row_amax_safe == 0) | (S_enc_row == 0), + torch.ones_like(S_enc_row), + S_enc_row, + ) + + # Fold ``1 / fp4_max`` into the multiplier the same way the kernel + # does in ``compute_decoding_scaling_factor`` (``S_enc * fp4_max_inv``). + S_enc_row_mul_inv6 = S_enc_row * torch.reciprocal(fp4_max) # (M, 1) + + # 1x16 block amax. Broadcast row's S_enc across n_blk blocks. + vec_max = torch.amax(torch.abs(x_blk), dim=-1, keepdim=True) # (M, n_blk, 1) + S_enc_per_blk = S_enc_row.unsqueeze(-1) # (M, 1, 1) -> broadcasts to (M, n_blk, 1) + S_enc_per_blk_mul = S_enc_row_mul_inv6.unsqueeze(-1) + + # decode_scale = saturating_cast(vec_max * S_enc / 6). + # Kernel does NOT clamp before the cast; we clamp here because + # PyTorch's ``.to(float8_e4m3fn)`` does not match CUDA's saturating + # cast for values above FP8_MAX. After the explicit clamp the two + # paths agree byte-for-byte. + decode_scale_fp32 = vec_max * S_enc_per_blk_mul + decode_scale_fp32 = torch.minimum(decode_scale_fp32, fp32_max) + decode_scale_fp32 = torch.clamp(decode_scale_fp32, min=-fp8_max, max=fp8_max) + decode_scale_e4m3 = decode_scale_fp32.to(torch.float8_e4m3fn) + decode_scale_back_fp32 = decode_scale_e4m3.to(torch.float32) + + # block_scale = S_enc / s_dec, matching ``__fdiv_rn`` in the + # kernel. All-zero blocks: s_dec saturates to 0, naive S_enc/0 + # would NaN; short-circuit to 0 to mirror the kernel. + zero_blk = decode_scale_back_fp32 == 0 + denom = torch.where(zero_blk, torch.ones_like(decode_scale_back_fp32), + decode_scale_back_fp32) + encode_scale = S_enc_per_blk / denom + encode_scale = torch.where(zero_blk, torch.zeros_like(encode_scale), encode_scale) + encode_scale = torch.minimum(encode_scale, fp32_max) + + # Apply scale, clamp to FP4 range, pack two FP4 values per byte. + scaled_x = x_blk * encode_scale + clipped_x = torch.clamp(scaled_x, -fp4_max, fp4_max).reshape(M, N) + qx = cast_to_fp4x2(clipped_x).contiguous() # (M, N // 2) + + sx = decode_scale_e4m3.squeeze(-1).contiguous() # (M, n_blk) + row_amax_out = row_amax.to(torch.float32).contiguous() # (M,) -- raw, no floor + return qx, sx, row_amax_out + + def quantize(self, tensor: torch.Tensor) -> RefNVFP4TensorPerToken: + """Quantize ``tensor`` and return a ``RefNVFP4TensorPerToken``.""" + out = RefNVFP4TensorPerToken() + if self.rowwise: + qx, sx, ra = self._quantize_2d(tensor) + out.data = qx + out.scale = sx + out.row_amax = ra + if self.columnwise: + # The columnwise output is the rowwise quantization of the + # transpose; both directions share the same math chain. + qx_t, sx_t, ca = self._quantize_2d(tensor.transpose(0, 1).contiguous()) + out.columnwise_data = qx_t + out.columnwise_scale = sx_t + out.col_amax = ca + return out + + +# ============================================================================ +# Production wrapper (calls the CUDA kernel via the C-API binding). +# ============================================================================ + +# ---------------------------------------------------------------------------- +# Shape / dtype gate shared by all three entries. +# ---------------------------------------------------------------------------- +_PER_TOKEN_TILE: int = 128 # CHUNK_DIM_Y / CHUNK_DIM_X in the kernel + + +def _validate_per_token_input(x: torch.Tensor) -> Tuple[int, int]: + """Enforce the per-token kernel's hard constraints. Returns ``(M, K)``.""" + if x.ndim != 2: + raise ValueError(f"nvfp4_per_token expects a 2D tensor, got {x.ndim}D") + if x.dtype != torch.bfloat16: + raise ValueError( + f"Per-token kernel is bf16-only; got dtype {x.dtype}. " + "Non-bf16 inputs are not supported (no fallback path)." + ) + M, K = x.shape + if M % _PER_TOKEN_TILE != 0: + raise ValueError( + f"Per-token kernel requires M % {_PER_TOKEN_TILE} == 0; got M={M}" + ) + if K % _PER_TOKEN_TILE != 0: + raise ValueError( + f"Per-token kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}" + ) + return M, K + + +def nvfp4_per_token_quantize( + x: torch.Tensor, *, rowwise: bool = True, columnwise: bool = False +) -> RefNVFP4TensorPerToken: + """Production NVFP4 per-token cast through ``tex.nvfp4_per_token_quantize``. + + Backed by the TMA + mbarrier + 64x64 sub-tile pipeline + (``common/cast/nvfp4/quantize_nvfp4_per_token.cu``). The C-API + runs K1 (per-row + per-col amax) and K2 (FP4 + e4m3 SF encode) back- + to-back on the same stream. + + Returns a ``RefNVFP4TensorPerToken`` populated with the kernel + output (compact, non-swizzled scales). The Python-level container is + the same as the reference for symmetry; only the source of the + values differs. + + For cuBLAS LT consumption, the caller must swizzle the inner SF + before forwarding to the GEMM; ``gemm_nvfp4_per_token`` handles + this automatically. + + Raises ``ValueError`` on non-bf16 input or non-128-aligned shapes. + """ + # Import lazily so the module does not require the binary at import time. + # (Mirrors the pattern in ``gemm_nvfp4_blocked.py``.) + import transformer_engine_torch as tex # type: ignore + + if not (rowwise or columnwise): + raise ValueError("At least one of rowwise / columnwise must be True.") + M, K = _validate_per_token_input(x) + + device = x.device + # Empty placeholders for the direction(s) we don't request -- the + # binding still expects the argument slots (typed-empty is fine). + empty = torch.empty(0, dtype=torch.uint8, device=device) + empty_f32 = torch.empty(0, dtype=torch.float32, device=device) + + if rowwise: + q_row = torch.empty((M, K // 2), dtype=torch.uint8, device=device) + s_dec_row = torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device) + row_amax = torch.empty((M,), dtype=torch.float32, device=device) + else: + q_row, s_dec_row, row_amax = empty, empty, empty_f32 + + if columnwise: + q_col = torch.empty((K, M // 2), dtype=torch.uint8, device=device) + s_dec_col = torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device) + col_amax = torch.empty((K,), dtype=torch.float32, device=device) + else: + q_col, s_dec_col, col_amax = empty, empty, empty_f32 + + tex.nvfp4_per_token_quantize( + x, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, columnwise + ) + + out = RefNVFP4TensorPerToken() + if rowwise: + out.data = q_row + out.scale = s_dec_row.view(torch.float8_e4m3fn) + out.row_amax = row_amax + if columnwise: + out.columnwise_data = q_col + out.columnwise_scale = s_dec_col.view(torch.float8_e4m3fn) + out.col_amax = col_amax + return out + + +# ============================================================================ +# Split entries (K1 = amax-only, K2 = encode-only). +# +# Diagnostic / benchmark interface, mirroring the production per-tensor +# kernel split (``HadamardAmaxTmaKernel`` for amax + the row_col_rht_gemm +# cast pass). Production callers should use ``nvfp4_per_token_quantize`` +# above; the composite handles K1 + K2 ordering on the same stream. +# ============================================================================ + +def nvfp4_per_token_amax( + x: torch.Tensor, *, rowwise: bool = True, columnwise: bool = True, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. + Returns ``(row_amax, col_amax)``; either may be ``None`` if the + corresponding direction is not requested. + + Lets the benchmark compare K1 wall-time against the production + ``HadamardAmaxTmaKernel``. Production callers should use the + composite ``nvfp4_per_token_quantize`` instead. + + Raises ``ValueError`` on non-bf16 input or non-128-aligned shapes. + """ + import transformer_engine_torch as tex # type: ignore + + if not (rowwise or columnwise): + raise ValueError("At least one of rowwise / columnwise must be True.") + M, K = _validate_per_token_input(x) + + device = x.device + row_amax = ( + torch.empty((M,), dtype=torch.float32, device=device) + if rowwise + else torch.empty(0, dtype=torch.float32, device=device) + ) + col_amax = ( + torch.empty((K,), dtype=torch.float32, device=device) + if columnwise + else torch.empty(0, dtype=torch.float32, device=device) + ) + + tex.nvfp4_per_token_amax(x, row_amax, col_amax, rowwise, columnwise) + + return (row_amax if rowwise else None, col_amax if columnwise else None) + + +def nvfp4_per_token_encode( + x: torch.Tensor, + *, + row_amax: Optional[torch.Tensor] = None, + col_amax: Optional[torch.Tensor] = None, + rowwise: bool = True, + columnwise: bool = True, +) -> RefNVFP4TensorPerToken: + """Kernel 2 in isolation: FP4 + e4m3 SF encode given pre-filled + amax buffer(s). + + ``row_amax`` of shape ``(M,)`` is required when ``rowwise=True``; same + for ``col_amax`` of shape ``(K,)`` when ``columnwise=True``. The + buffers are typically produced by a prior + ``nvfp4_per_token_amax`` call. + + Lets the benchmark compare K2 wall-time against the production + per-tensor cast pass. Production callers should use the composite + ``nvfp4_per_token_quantize`` instead. + + Raises ``ValueError`` on non-bf16 input, non-128-aligned shapes, or + missing / mis-shaped amax buffers. + """ + import transformer_engine_torch as tex # type: ignore + + if not (rowwise or columnwise): + raise ValueError("At least one of rowwise / columnwise must be True.") + M, K = _validate_per_token_input(x) + if rowwise and (row_amax is None or row_amax.shape != (M,)): + raise ValueError(f"row_amax must be (M={M},) fp32 when rowwise=True") + if columnwise and (col_amax is None or col_amax.shape != (K,)): + raise ValueError(f"col_amax must be (K={K},) fp32 when columnwise=True") + + device = x.device + empty = torch.empty(0, dtype=torch.uint8, device=device) + empty_f32 = torch.empty(0, dtype=torch.float32, device=device) + + if rowwise: + q_row = torch.empty((M, K // 2), dtype=torch.uint8, device=device) + s_dec_row = torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device) + row_amax_t = row_amax # type: ignore[assignment] + else: + q_row, s_dec_row, row_amax_t = empty, empty, empty_f32 + if columnwise: + q_col = torch.empty((K, M // 2), dtype=torch.uint8, device=device) + s_dec_col = torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device) + col_amax_t = col_amax # type: ignore[assignment] + else: + q_col, s_dec_col, col_amax_t = empty, empty, empty_f32 + + tex.nvfp4_per_token_encode( + x, q_row, s_dec_row, row_amax_t, q_col, s_dec_col, col_amax_t, rowwise, columnwise, + ) + + out = RefNVFP4TensorPerToken() + if rowwise: + out.data = q_row + out.scale = s_dec_row.view(torch.float8_e4m3fn) + out.row_amax = row_amax_t + if columnwise: + out.columnwise_data = q_col + out.columnwise_scale = s_dec_col.view(torch.float8_e4m3fn) + out.col_amax = col_amax_t + return out diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py new file mode 100644 index 0000000000..dd8b2283e9 --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py @@ -0,0 +1,112 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Grouped (multi-tensor) NVFP4 per-token quantize Python wrapper. + +Dispatches through ``tex.nvfp4_per_token_group_quantize_bulk`` -- the bulk +C++ binding owns allocation, view-slicing, and the composite K1+K2 kernel +dispatch. Requires bf16 input with K and every split_sections[i] a multiple +of 128; up to 64 splits. +""" + +from __future__ import annotations + +from typing import List, Sequence + +import torch + +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token import ( + RefNVFP4TensorPerToken, + _PER_TOKEN_TILE, +) + + +def _validate_per_token_group_input( + x_concat: torch.Tensor, split_sections: Sequence[int] +) -> tuple[int, int]: + """Enforce the per-token grouped kernel's hard constraints. Returns + ``(sum_M, K)``. + """ + if x_concat.ndim != 2: + raise ValueError( + f"nvfp4_per_token_group_quantize expects a 2D input, got {x_concat.ndim}D" + ) + if not x_concat.is_contiguous(): + raise ValueError("x_concat must be contiguous (row-major)") + if x_concat.dtype != torch.bfloat16: + raise ValueError( + f"Per-token grouped kernel is bf16-only; got dtype {x_concat.dtype}." + ) + sum_M, K = x_concat.shape + if K % _PER_TOKEN_TILE != 0: + raise ValueError( + f"Per-token grouped kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}" + ) + if len(split_sections) == 0: + raise ValueError("split_sections must not be empty") + if len(split_sections) > 64: + raise ValueError( + f"num_tensors must be <= 64 (kernel arg-struct cap); got {len(split_sections)}" + ) + acc = 0 + for i, M_i in enumerate(split_sections): + if M_i <= 0: + raise ValueError(f"split_sections[{i}] must be > 0, got {M_i}") + if M_i % _PER_TOKEN_TILE != 0: + raise ValueError( + f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}" + ) + acc += M_i + if acc != sum_M: + raise ValueError( + f"sum(split_sections) = {acc} must equal input.size(0) = {sum_M}" + ) + return sum_M, K + + +def nvfp4_per_token_group_quantize( + x_concat: torch.Tensor, + split_sections: Sequence[int], + *, + rowwise: bool = True, + columnwise: bool = False, +) -> List[RefNVFP4TensorPerToken]: + """Grouped NVFP4 per-token cast; returns N RefNVFP4TensorPerToken splits. + + Raises ``ValueError`` on shape / dtype / split-size violations. + """ + import transformer_engine_torch as tex # type: ignore + + if not (rowwise or columnwise): + raise ValueError("At least one of rowwise / columnwise must be True.") + + _validate_per_token_group_input(x_concat, split_sections) + split_sections_list = [int(M_i) for M_i in split_sections] + N = len(split_sections_list) + + # Bulk C++ call returns per-split views; s_dec_* already in fp8_e4m3fn dtype. + ( + q_row_list, + s_dec_row_list, + row_amax_list, + q_col_list, + s_dec_col_list, + col_amax_list, + ) = tex.nvfp4_per_token_group_quantize_bulk( + x_concat, split_sections_list, rowwise, columnwise + ) + + outs: List[RefNVFP4TensorPerToken] = [] + for i in range(N): + out = RefNVFP4TensorPerToken() + if rowwise: + out.data = q_row_list[i] + out.scale = s_dec_row_list[i] + out.row_amax = row_amax_list[i] + if columnwise: + out.columnwise_data = q_col_list[i] + out.columnwise_scale = s_dec_col_list[i] + out.col_amax = col_amax_list[i] + outs.append(out) + return outs From c3780567cc1a49c2451932e516941b5728626731 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 May 2026 13:10:11 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/nvfp4/bench_nvfp4_per_token.py | 45 +- .../nvfp4/bench_nvfp4_per_token_group.py | 43 +- tests/pytorch/nvfp4/test_nvfp4_per_token.py | 107 +++-- .../nvfp4/test_nvfp4_per_token_group.py | 90 ++-- .../cast/nvfp4/quantize_nvfp4_per_token.cu | 387 +++++++++--------- .../nvfp4/quantize_nvfp4_per_token_group.cu | 384 ++++++++--------- .../common/gemm/nvfp4_per_token_post_scale.cu | 7 +- .../transformer_engine/nvfp4_per_token.h | 24 +- transformer_engine/pytorch/csrc/extensions.h | 43 +- .../csrc/extensions/nvfp4_per_token.cpp | 249 +++++------ .../pytorch/csrc/extensions/pybind.cpp | 43 +- .../pytorch/csrc/extensions/recipe.cpp | 18 +- .../custom_recipes/gemm_nvfp4_per_token.py | 39 +- .../quantization_nvfp4_per_token.py | 29 +- .../quantization_nvfp4_per_token_group.py | 24 +- 15 files changed, 734 insertions(+), 798 deletions(-) diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py index 54aaaeccb7..1312c841c7 100644 --- a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py @@ -43,9 +43,7 @@ def cuda_time_ms(fn: Callable[[], None], *, warmup: int = 5, iters: int = 50) -> return statistics.median(samples) -def cuda_graph_time_ms( - fn: Callable[[], object], *, warmup: int = 5, iters: int = 50 -) -> float: +def cuda_graph_time_ms(fn: Callable[[], object], *, warmup: int = 5, iters: int = 50) -> float: """Median g.replay() wall time of fn captured into a CUDA Graph (kernel-only floor). Returns nan if capture fails. @@ -104,10 +102,10 @@ def _has_sm100() -> bool: class ShapeBench: M: int K: int - t_pt: float # per-token full K1+K2 (eager pybind, ms) - t_pten: float # per-tensor full K1+K2 (eager pybind, ms) - t_pt_g: float # per-token under CUDA Graphs replay (ms) - t_pten_g: float # per-tensor under CUDA Graphs replay (ms) + t_pt: float # per-token full K1+K2 (eager pybind, ms) + t_pten: float # per-tensor full K1+K2 (eager pybind, ms) + t_pt_g: float # per-token under CUDA Graphs replay (ms) + t_pten_g: float # per-tensor under CUDA Graphs replay (ms) def _bench_shape(M: int, K: int, *, device: torch.device) -> ShapeBench: @@ -132,7 +130,15 @@ def _baseline_quant_fn(): def _pt_full_quant_fn(): tex.nvfp4_per_token_quantize( - a, q_row_a, s_dec_row_a, ra_a, q_col_a, s_dec_col_a, ca_a, True, True, + a, + q_row_a, + s_dec_row_a, + ra_a, + q_col_a, + s_dec_col_a, + ca_a, + True, + True, ) t_pten = cuda_time_ms(_baseline_quant_fn) @@ -146,9 +152,7 @@ def _pt_full_quant_fn(): # 6x3 sweep matching bench_nvfp4_per_token_group.py: M in {1024..32768}, K in {2048,4096,8192}. _M_VALUES: Tuple[int, ...] = (1024, 2048, 4096, 8192, 16384, 32768) _K_VALUES: Tuple[int, ...] = (2048, 4096, 8192) -_DEFAULT_SHAPES: Tuple[Tuple[int, int], ...] = tuple( - (m, k) for m in _M_VALUES for k in _K_VALUES -) +_DEFAULT_SHAPES: Tuple[Tuple[int, int], ...] = tuple((m, k) for m in _M_VALUES for k in _K_VALUES) def _parse_shape(s: str) -> Tuple[int, int]: @@ -169,9 +173,14 @@ def main() -> int: description="Benchmark NVFP4 per-token K1+K2 quant vs per-tensor production NVFP4." ) parser.add_argument( - "--shapes", type=_parse_shape, nargs="+", default=None, - help="Shapes to bench, in MxK form (e.g. 4096x4096). " - "Default: an internally-chosen production-shape sweep.", + "--shapes", + type=_parse_shape, + nargs="+", + default=None, + help=( + "Shapes to bench, in MxK form (e.g. 4096x4096). " + "Default: an internally-chosen production-shape sweep." + ), ) args = parser.parse_args() @@ -186,9 +195,9 @@ def main() -> int: header = ( f"{'M':>7} {'K':>6}" - f" |" + " |" f"{'per-token':>10} {'per-tensor':>11} {'ratio':>8}" - f" |" + " |" f"{'per-token(Graph)':>17} {'per-tensor(Graph)':>18} {'ratio(Graph)':>13}" ) print(header) @@ -204,9 +213,9 @@ def main() -> int: ratio_g_s = "nan" if math.isnan(ratio_g) else f"{ratio_g:.2f}x" print( f"{rec.M:>7} {rec.K:>6}" - f" |" + " |" f"{rec.t_pt:>10.4f} {rec.t_pten:>11.4f} {ratio_s:>8}" - f" |" + " |" f"{rec.t_pt_g:>17.4f} {rec.t_pten_g:>18.4f} {ratio_g_s:>13}" ) diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py index 257cda32c8..7c62db0693 100644 --- a/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py @@ -39,9 +39,7 @@ def _make_baseline_quantizer_list(num_splits: int) -> List[NVFP4Quantizer]: return [q] * num_splits -def cuda_graph_time_ms( - fn: Callable[[], object], *, warmup: int = 5, iters: int = 50 -) -> float: +def cuda_graph_time_ms(fn: Callable[[], object], *, warmup: int = 5, iters: int = 50) -> float: """Median g.replay() time of fn captured into a CUDA Graph, in ms. Returns nan if capture fails (e.g. some C-API does an incompatible sync). @@ -106,9 +104,9 @@ def _time_split_quantize(x_concat, split_sections, quantizer_list, n_iters=20, n return start.elapsed_time(stop) / n_iters # ms -def _time_split_quantize_graph(x_concat, split_sections, quantizer_list, - n_iters=20, n_warmup=5): +def _time_split_quantize_graph(x_concat, split_sections, quantizer_list, n_iters=20, n_warmup=5): """Per-tensor grouped under CUDA Graphs replay.""" + def fn() -> None: _ = tex.split_quantize(x_concat, split_sections, quantizer_list) @@ -117,6 +115,7 @@ def fn() -> None: def _time_grouped_graph(x_concat, split_sections, rowwise, columnwise, n_iters=20, n_warmup=5): """Per-token grouped under CUDA Graphs replay.""" + def fn() -> None: _ = nvfp4_per_token_group_quantize( x_concat, split_sections, rowwise=rowwise, columnwise=columnwise @@ -158,9 +157,9 @@ def main() -> None: header = ( f"{'sum_M':>6} {'K':>5}" - f" |" + " |" f"{'per-token':>10} {'per-tensor':>10} {'ratio':>8}" - f" |" + " |" f"{'per-token(Graph)':>17} {'per-tensor(Graph)':>17} {'ratio(Graph)':>13}" ) print(header) @@ -176,39 +175,23 @@ def main() -> None: print() prev_sum_M = sum_M - x_concat = ( - torch.randn((sum_M, K), dtype=torch.bfloat16, device=device) * 3.0 - ).contiguous() + x_concat = (torch.randn((sum_M, K), dtype=torch.bfloat16, device=device) * 3.0).contiguous() quantizer_list = _make_baseline_quantizer_list(num_splits) t_pt = _time_grouped(x_concat, split_sections, rowwise, columnwise) t_pten = _time_split_quantize(x_concat, split_sections, quantizer_list) ratio = t_pt / t_pten if t_pten > 0 else float("nan") - t_pt_g = _time_grouped_graph( - x_concat, split_sections, rowwise, columnwise - ) - t_pten_g = _time_split_quantize_graph( - x_concat, split_sections, quantizer_list - ) + t_pt_g = _time_grouped_graph(x_concat, split_sections, rowwise, columnwise) + t_pten_g = _time_split_quantize_graph(x_concat, split_sections, quantizer_list) if math.isnan(t_pt_g) or math.isnan(t_pten_g) or t_pten_g <= 0: ratio_g = float("nan") - graph_cells = ( - f"{t_pt_g:>17.4f} {t_pten_g:>17.4f} {'nan':>13}" - ) + graph_cells = f"{t_pt_g:>17.4f} {t_pten_g:>17.4f} {'nan':>13}" else: ratio_g = t_pt_g / t_pten_g - graph_cells = ( - f"{t_pt_g:>17.4f} {t_pten_g:>17.4f} {ratio_g:>12.2f}x" - ) - - print( - f"{sum_M:>6d} {K:>5d}" - f" |" - f"{t_pt:>10.4f} {t_pten:>10.4f} {ratio:>7.2f}x" - f" |" - f"{graph_cells}" - ) + graph_cells = f"{t_pt_g:>17.4f} {t_pten_g:>17.4f} {ratio_g:>12.2f}x" + + print(f"{sum_M:>6d} {K:>5d} |{t_pt:>10.4f} {t_pten:>10.4f} {ratio:>7.2f}x |{graph_cells}") del x_concat, quantizer_list torch.cuda.empty_cache() diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token.py b/tests/pytorch/nvfp4/test_nvfp4_per_token.py index 0cc4f8347f..727b5792d2 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_per_token.py +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token.py @@ -55,9 +55,9 @@ def _has_sm100() -> bool: # Shapes obey the kernel contract (M % 128 == 0, K % 128 == 0). _QUANT_SHAPES = [ - (128, 128), # smallest legal shape - (128, 256), # K > inner SF window of single chunk - (256, 128), # M > inner SF window of single chunk + (128, 128), # smallest legal shape + (128, 256), # K > inner SF window of single chunk + (256, 128), # M > inner SF window of single chunk (256, 512), (512, 1024), ] @@ -92,8 +92,10 @@ def test_per_token_quantize_byte_exact(M: int, N: int, rowwise: bool, columnwise qx_ref = _unpack_fp4_byte_pairs(ref.data.view(torch.uint8)) torch.testing.assert_close(qx_sut, qx_ref, atol=0.0, rtol=0.0) torch.testing.assert_close( - sut.scale.view(torch.uint8), ref.scale.view(torch.uint8), - atol=0.0, rtol=0.0, + sut.scale.view(torch.uint8), + ref.scale.view(torch.uint8), + atol=0.0, + rtol=0.0, ) torch.testing.assert_close(sut.row_amax, ref.row_amax, atol=0.0, rtol=0.0) @@ -104,18 +106,23 @@ def test_per_token_quantize_byte_exact(M: int, N: int, rowwise: bool, columnwise torch.testing.assert_close( sut.columnwise_scale.view(torch.uint8), ref.columnwise_scale.view(torch.uint8), - atol=0.0, rtol=0.0, + atol=0.0, + rtol=0.0, ) torch.testing.assert_close(sut.col_amax, ref.col_amax, atol=0.0, rtol=0.0) # (2) Split-kernel parity: K1 then K2 == composite K1+K2. + @_GATED_FP4 @pytest.mark.parametrize("M,N", _QUANT_SHAPES) @pytest.mark.parametrize("rowwise,columnwise", [(True, False), (False, True), (True, True)]) def test_per_token_split_byte_equal( - M: int, N: int, rowwise: bool, columnwise: bool, + M: int, + N: int, + rowwise: bool, + columnwise: bool, ) -> None: """K1 (amax) then K2 (encode) byte-equals the composite K1+K2.""" torch.manual_seed(0xC0FFEE * (M + 7) + (N + 11)) @@ -128,7 +135,9 @@ def test_per_token_split_byte_equal( composite = nvfp4_per_token_quantize(x, rowwise=rowwise, columnwise=columnwise) row_amax, col_amax = nvfp4_per_token_amax( - x, rowwise=rowwise, columnwise=columnwise, + x, + rowwise=rowwise, + columnwise=columnwise, ) split = nvfp4_per_token_encode( x, @@ -141,29 +150,36 @@ def test_per_token_split_byte_equal( if rowwise: torch.testing.assert_close(split.row_amax, composite.row_amax, atol=0.0, rtol=0.0) torch.testing.assert_close( - split.data.view(torch.uint8), composite.data.view(torch.uint8), - atol=0.0, rtol=0.0, + split.data.view(torch.uint8), + composite.data.view(torch.uint8), + atol=0.0, + rtol=0.0, ) torch.testing.assert_close( - split.scale.view(torch.uint8), composite.scale.view(torch.uint8), - atol=0.0, rtol=0.0, + split.scale.view(torch.uint8), + composite.scale.view(torch.uint8), + atol=0.0, + rtol=0.0, ) if columnwise: torch.testing.assert_close(split.col_amax, composite.col_amax, atol=0.0, rtol=0.0) torch.testing.assert_close( split.columnwise_data.view(torch.uint8), composite.columnwise_data.view(torch.uint8), - atol=0.0, rtol=0.0, + atol=0.0, + rtol=0.0, ) torch.testing.assert_close( split.columnwise_scale.view(torch.uint8), composite.columnwise_scale.view(torch.uint8), - atol=0.0, rtol=0.0, + atol=0.0, + rtol=0.0, ) # (2b) Input-validation rejections. + @_GATED_FP4 def test_per_token_validation_rejects_fp32() -> None: """Per-token must ``ValueError`` on non-bf16 input (no fallback path).""" @@ -188,6 +204,7 @@ def test_per_token_validation_rejects_unaligned() -> None: # (3) Dequant + fp32 reference matmul sanity (pure-Python, no kernel). + @_GATED_FP4 @pytest.mark.parametrize("M,N", [(32, 64), (64, 256)]) def test_per_token_dequant_roundtrip_close(M: int, N: int) -> None: @@ -207,10 +224,10 @@ def test_per_token_dequant_roundtrip_close(M: int, N: int) -> None: # (4) Production GEMM: cuBLAS LT NVFP4 + post-scale composite. # Shapes need M, N % 128 == 0 and K % 16 == 0 for cuBLAS LT NVFP4. _GEMM_SHAPES = [ - (128, 128, 128), # smallest legal shape - (128, 128, 256), # exercise K > inner SF window - (256, 128, 256), # non-square (M != N) - (256, 256, 256), # square mid-size + (128, 128, 128), # smallest legal shape + (128, 128, 256), # exercise K > inner SF window + (256, 128, 256), # non-square (M != N) + (256, 256, 256), # square mid-size ] @@ -254,24 +271,24 @@ def _three_pronged_bf16_close( f"rel_l2={rel_l2:.3g} max_abs={max_abs:.3g} n_bad_mixed={n_bad_mixed} " f"mean_|d_ref|={mean_ref_abs:.3g} " f"(diag: mean_rel={mean_rel:.3g} max_rel={max_rel:.3g} " - f"— mean_rel/max_rel are NOT asserted; see helper docstring)" + "— mean_rel/max_rel are NOT asserted; see helper docstring)" ) print(diag) bad_count_abs_floor = max(8, int(bad_count_ratio * n)) assert rel_l2 <= rel_l2_floor, ( f"{diag} -> rel_l2 > {rel_l2_floor} (energy-weighted global " - f"relative error too high — possible structural bug)" + "relative error too high — possible structural bug)" ) assert n_bad_mixed <= bad_count_abs_floor, ( f"{diag} -> n_bad_mixed > {bad_count_abs_floor} " f"(|diff| > atol={atol} + rtol={bad_rtol} * |d_r| for too " - f"many elements — possible localised broken row/col)" + "many elements — possible localised broken row/col)" ) assert max_abs <= max_abs_bound, ( f"{diag} -> max_abs > {max_abs_bound:.3g} = atol + " - f"bad_rtol * mean_|d_ref| (worst element is way outside the " - f"noise envelope — possible NaN-like blow-up)" + "bad_rtol * mean_|d_ref| (worst element is way outside the " + "noise envelope — possible NaN-like blow-up)" ) @@ -292,8 +309,12 @@ def test_per_token_gemm_close_to_bf16(M: int, N: int, K: int) -> None: b_q = nvfp4_per_token_quantize(b, rowwise=True) d_sut = nvfp4_per_token_gemm( - a_q.data, a_q.scale, a_q.row_amax, - b_q.data, b_q.scale, b_q.row_amax, + a_q.data, + a_q.scale, + a_q.row_amax, + b_q.data, + b_q.scale, + b_q.row_amax, ) d_ref = (a.float() @ b.float().t()).to(torch.bfloat16) @@ -318,11 +339,11 @@ def test_per_token_gemm_close_to_bf16(M: int, N: int, K: int) -> None: ) assert cos_sim >= cos_sim_floor, ( f"{diag} -> cos_sim < {cos_sim_floor} (structural mismatch; " - f"likely wrong operand swap, missing scale, or indexing bug)" + "likely wrong operand swap, missing scale, or indexing bug)" ) assert mag_lo <= mag_ratio <= mag_hi, ( f"{diag} -> mag_ratio not in [{mag_lo}, {mag_hi}] " - f"(systematic magnitude error; check alpha/post-scale)" + "(systematic magnitude error; check alpha/post-scale)" ) @@ -339,21 +360,33 @@ def test_per_token_gemm_close_to_dequant_ref(M: int, N: int, K: int) -> None: b_q = nvfp4_per_token_quantize(b, rowwise=True) d_sut = nvfp4_per_token_gemm( - a_q.data, a_q.scale, a_q.row_amax, - b_q.data, b_q.scale, b_q.row_amax, + a_q.data, + a_q.scale, + a_q.row_amax, + b_q.data, + b_q.scale, + b_q.row_amax, ).float() d_ref = nvfp4_per_token_gemm_dequant( - a_q.data, a_q.scale, a_q.row_amax, - b_q.data, b_q.scale, b_q.row_amax, + a_q.data, + a_q.scale, + a_q.row_amax, + b_q.data, + b_q.scale, + b_q.row_amax, out_dtype=torch.float32, ) _three_pronged_bf16_close( - d_sut, d_ref, + d_sut, + d_ref, label=f"vs_dequant({M}x{N}x{K})", # Empirical rel_l2 ~5e-3..1.5e-2 on random N(0, 0.5), K=128-256. - rel_l2_floor=2e-2, atol=1e-1, bad_rtol=5e-2, bad_count_ratio=1e-2, + rel_l2_floor=2e-2, + atol=1e-1, + bad_rtol=5e-2, + bad_count_ratio=1e-2, ) @@ -369,7 +402,11 @@ def test_per_token_gemm_rejects_beta_nonzero() -> None: with pytest.raises(ValueError, match=r"beta != 0"): nvfp4_per_token_gemm( - a_q.data, a_q.scale, a_q.row_amax, - b_q.data, b_q.scale, b_q.row_amax, + a_q.data, + a_q.scale, + a_q.row_amax, + b_q.data, + b_q.scale, + b_q.row_amax, beta=1.0, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py b/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py index 3a72616f26..7c617200cd 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py @@ -89,9 +89,7 @@ def _group_quantize_py( col_amax_list: List[torch.Tensor] = [] for M_i in split_sections: - qr, sr, ra, qc, sc, ca = _alloc_per_token_buffers( - M_i, K, rowwise, columnwise, device - ) + qr, sr, ra, qc, sc, ca = _alloc_per_token_buffers(M_i, K, rowwise, columnwise, device) if rowwise: q_row_list.append(qr) s_dec_row_list.append(sr) @@ -122,14 +120,10 @@ def _group_quantize_py( # Re-view e4m3 SF as torch.float8_e4m3fn (same bytes, expected dtype). tensor = RefNVFP4TensorPerToken( data=q_row_list[i] if rowwise else None, - scale=( - s_dec_row_list[i].view(torch.float8_e4m3fn) if rowwise else None - ), + scale=(s_dec_row_list[i].view(torch.float8_e4m3fn) if rowwise else None), row_amax=row_amax_list[i] if rowwise else None, columnwise_data=q_col_list[i] if columnwise else None, - columnwise_scale=( - s_dec_col_list[i].view(torch.float8_e4m3fn) if columnwise else None - ), + columnwise_scale=(s_dec_col_list[i].view(torch.float8_e4m3fn) if columnwise else None), col_amax=col_amax_list[i] if columnwise else None, ) out.append(tensor) @@ -139,13 +133,13 @@ def _group_quantize_py( # Test fixtures. Per-token kernel requires M_i % 128 == 0 and K % 128 == 0. _SHAPES: List[Tuple[List[int], int]] = [ # (split_sections, K) - ([128], 128), # trivial: 1 split, smallest legal shape - ([128, 128], 128), # 2 equal splits - ([128, 256], 128), # 2 unequal splits - ([128, 256, 128], 256), # 3 splits, mixed sizes - ([128, 128, 128, 128], 256), # 4 equal splits - ([256, 128, 384, 128, 128], 512), # 5-way unequal split, typical MoE - ([256, 256], 1024), # larger K, 2 splits + ([128], 128), # trivial: 1 split, smallest legal shape + ([128, 128], 128), # 2 equal splits + ([128, 256], 128), # 2 unequal splits + ([128, 256, 128], 256), # 3 splits, mixed sizes + ([128, 128, 128, 128], 256), # 4 equal splits + ([256, 128, 384, 128, 128], 512), # 5-way unequal split, typical MoE + ([256, 256], 1024), # larger K, 2 splits ] @@ -179,8 +173,7 @@ def test_group_per_token_quantize_byte_equal( assert x_concat.shape == (sum_M, K) oracle: List[RefNVFP4TensorPerToken] = [ - nvfp4_per_token_quantize(s, rowwise=rowwise, columnwise=columnwise) - for s in splits_in + nvfp4_per_token_quantize(s, rowwise=rowwise, columnwise=columnwise) for s in splits_in ] sut: List[RefNVFP4TensorPerToken] = _group_quantize_py( @@ -194,34 +187,44 @@ def test_group_per_token_quantize_byte_equal( torch.testing.assert_close( sut[i].data.view(torch.uint8), oracle[i].data.view(torch.uint8), - atol=0.0, rtol=0.0, + atol=0.0, + rtol=0.0, msg=f"rowwise q[{i}] mismatch", ) torch.testing.assert_close( sut[i].scale.view(torch.uint8), oracle[i].scale.view(torch.uint8), - atol=0.0, rtol=0.0, + atol=0.0, + rtol=0.0, msg=f"rowwise s_dec[{i}] mismatch", ) torch.testing.assert_close( - sut[i].row_amax, oracle[i].row_amax, atol=0.0, rtol=0.0, + sut[i].row_amax, + oracle[i].row_amax, + atol=0.0, + rtol=0.0, msg=f"row_amax[{i}] mismatch", ) if columnwise: torch.testing.assert_close( sut[i].columnwise_data.view(torch.uint8), oracle[i].columnwise_data.view(torch.uint8), - atol=0.0, rtol=0.0, + atol=0.0, + rtol=0.0, msg=f"columnwise q[{i}] mismatch", ) torch.testing.assert_close( sut[i].columnwise_scale.view(torch.uint8), oracle[i].columnwise_scale.view(torch.uint8), - atol=0.0, rtol=0.0, + atol=0.0, + rtol=0.0, msg=f"columnwise s_dec[{i}] mismatch", ) torch.testing.assert_close( - sut[i].col_amax, oracle[i].col_amax, atol=0.0, rtol=0.0, + sut[i].col_amax, + oracle[i].col_amax, + atol=0.0, + rtol=0.0, msg=f"col_amax[{i}] mismatch", ) @@ -255,12 +258,16 @@ def test_group_per_token_amax_byte_equal( oracle_row.append(o.row_amax if rowwise else None) oracle_col.append(o.col_amax if columnwise else None) - row_amax_list = [ - torch.empty((M_i,), dtype=torch.float32, device=device) for M_i in split_sections - ] if rowwise else [] - col_amax_list = [ - torch.empty((K,), dtype=torch.float32, device=device) for _ in range(n) - ] if columnwise else [] + row_amax_list = ( + [torch.empty((M_i,), dtype=torch.float32, device=device) for M_i in split_sections] + if rowwise + else [] + ) + col_amax_list = ( + [torch.empty((K,), dtype=torch.float32, device=device) for _ in range(n)] + if columnwise + else [] + ) tex.nvfp4_per_token_group_amax( x_concat, split_sections, row_amax_list, col_amax_list, rowwise, columnwise @@ -269,13 +276,19 @@ def test_group_per_token_amax_byte_equal( if rowwise: for i in range(n): torch.testing.assert_close( - row_amax_list[i], oracle_row[i], atol=0.0, rtol=0.0, + row_amax_list[i], + oracle_row[i], + atol=0.0, + rtol=0.0, msg=f"row_amax[{i}] mismatch", ) if columnwise: for i in range(n): torch.testing.assert_close( - col_amax_list[i], oracle_col[i], atol=0.0, rtol=0.0, + col_amax_list[i], + oracle_col[i], + atol=0.0, + rtol=0.0, msg=f"col_amax[{i}] mismatch", ) @@ -300,18 +313,19 @@ def test_group_single_split_matches_single_tensor( if rowwise: torch.testing.assert_close(sut.data, oracle.data, atol=0.0, rtol=0.0) torch.testing.assert_close( - sut.scale.view(torch.uint8), oracle.scale.view(torch.uint8), - atol=0.0, rtol=0.0, + sut.scale.view(torch.uint8), + oracle.scale.view(torch.uint8), + atol=0.0, + rtol=0.0, ) torch.testing.assert_close(sut.row_amax, oracle.row_amax, atol=0.0, rtol=0.0) if columnwise: - torch.testing.assert_close( - sut.columnwise_data, oracle.columnwise_data, atol=0.0, rtol=0.0 - ) + torch.testing.assert_close(sut.columnwise_data, oracle.columnwise_data, atol=0.0, rtol=0.0) torch.testing.assert_close( sut.columnwise_scale.view(torch.uint8), oracle.columnwise_scale.view(torch.uint8), - atol=0.0, rtol=0.0, + atol=0.0, + rtol=0.0, ) torch.testing.assert_close(sut.col_amax, oracle.col_amax, atol=0.0, rtol=0.0) diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu index 57b8232c92..c9e7eec391 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu @@ -29,11 +29,11 @@ #include +#include "common/cast/core/common.cuh" +#include "common/cast/nvfp4/core_nvfp4.cuh" #include "common/common.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" -#include "common/cast/core/common.cuh" -#include "common/cast/nvfp4/core_nvfp4.cuh" namespace transformer_engine { namespace nvfp4_per_token { @@ -41,77 +41,76 @@ namespace nvfp4_per_token { #if FP4_TYPE_SUPPORTED using dispatch::common::align_smem_ptr_per_TMA_requirements; +using dispatch::nvfp4::nvfp4_scale_t; using dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; using dispatch::nvfp4::quantization_SF::compute_decoding_scaling_factor; -using dispatch::nvfp4::nvfp4_scale_t; -constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows of input -constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols of input -constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height -constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width -constexpr int THREADS_NUM = 128; // threads per CTA -constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM -constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) -constexpr int PREFETCH_STAGES = 1; // 1-stage prefetch overlap +constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows of input +constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols of input +constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height +constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width +constexpr int THREADS_NUM = 128; // threads per CTA +constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM +constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) +constexpr int PREFETCH_STAGES = 1; // 1-stage prefetch overlap constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; // = 2 ping-pong input buffers // Derived (chunk / tile / stage) constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; // 2 constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X; // 2 -constexpr int STAGES = TILES_Y * TILES_X; // 4 +constexpr int STAGES = TILES_Y * TILES_X; // 4 constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; // 8 inner blocks per row of the chunk constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; // 8 inner blocks per col of the chunk -constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 -constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 // Encode helpers' thread layout (rowwise pass: 4x32 = K-dim x M-dim) -constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 -constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 -constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 (each block owned by 1 thread) +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 +constexpr int THREADS_PER_SCALE_ROWWISE = + SCALE_DIM / ELTS_PER_THREAD; // 1 (each block owned by 1 thread) constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 // Encode helpers' thread layout (colwise pass: tid.X for col, warp for M-block) -constexpr int THREADS_X_TR = TILE_DIM_X / 2; // 32 cols per warp -constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; // 4 (warps) +constexpr int THREADS_X_TR = TILE_DIM_X / 2; // 32 cols per warp +constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; // 4 (warps) // Buffer dimensions (input bf16 SMEM tiles + FP4 output SMEM tiles for TMA store) constexpr int BUFF_IN_DIM_Y = TILE_DIM_Y; constexpr int BUFF_IN_DIM_X = TILE_DIM_X; -constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; // elements +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; // elements constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; -constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (2 fp4 per byte) -constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; +constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (2 fp4 per byte) +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; constexpr int BUFF_OUT_TR_DIM_Y = TILE_DIM_X; -constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 -constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; -constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 ping-pong (matches input) -constexpr int BUFFS_NUM_OUT_TR = 2; // 2 ping-pong for transpose +constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 ping-pong (matches input) +constexpr int BUFFS_NUM_OUT_TR = 2; // 2 ping-pong for transpose // Manual swizzling parameters to reduce SMEM bank conflicts on rowwise loads constexpr int PACK_SIZE = 8; -constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 -constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 -using IType = bf16; +using IType = bf16; using IType2 = ptx::FPx2; // = ptx::bf16x2 -using IType3D = IType [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; -using IType2x3D = IType2 [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using IType3D = IType[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +using IType2x3D = IType2[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; -using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; // Compute the per-block (1x16) byte-equal arithmetic and emit FP4 codes into // SMEM rowwise output buffer + e4m3 scale into SMEM rowwise scale buffer. __device__ __forceinline__ void rowwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, - fp4e2m1x2* __restrict__ sOut_ptr, + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_ptr, nvfp4_scale_t* __restrict__ sSFrowwise_ptr, - const float* __restrict__ sRowAmax, // [CHUNK_DIM_Y], indexed by chunk-local row - const int stage_Y, const int stage_X, - const int buff_in, const int buff_out) { + const float* __restrict__ sRowAmax, // [CHUNK_DIM_Y], indexed by chunk-local row + const int stage_Y, const int stage_X, const int buff_in, const int buff_out) { const auto& sIn = *reinterpret_cast(sIn_ptr); auto& sOut = *reinterpret_cast(sOut_ptr); auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); @@ -119,12 +118,14 @@ __device__ __forceinline__ void rowwise_scaling_per_token( const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; - const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 - const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 - const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; // K-elt offset in tile (0/16/32/48) + const int thread_offset_X_rowwise = + tid_X_rowwise * ELTS_PER_THREAD; // K-elt offset in tile (0/16/32/48) - const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; // = tid_X_rowwise here + const int SF_thread_offset_rowwise_X = + tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; // = tid_X_rowwise here const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); const int stage_rowwise_scales_offset_X = @@ -156,8 +157,8 @@ __device__ __forceinline__ void rowwise_scaling_per_token( ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); } } - const float block_amax = static_cast( - __hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + const float block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); // Byte-equal compute path (matches the Python reference in // ``NVFP4QuantizerPerTokenRef``): @@ -200,33 +201,33 @@ __device__ __forceinline__ void rowwise_scaling_per_token( // Compute the per-block (1x16, along M) byte-equal arithmetic for the columnwise // pass; emit transposed FP4 + e4m3 scale into SMEM. __device__ __forceinline__ void colwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, - fp4e2m1x2* __restrict__ sOut_tr_ptr, + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_tr_ptr, nvfp4_scale_t* __restrict__ sSFcolwise_ptr, - const float* __restrict__ sColAmax, // [CHUNK_DIM_X], indexed by chunk-local col - const int stage_Y, const int stage_X, - const int buff_in, const int buff_out_tr) { + const float* __restrict__ sColAmax, // [CHUNK_DIM_X], indexed by chunk-local col + const int stage_Y, const int stage_X, const int buff_in, const int buff_out_tr) { const auto& sIn2x = *reinterpret_cast(sIn_ptr); auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); - const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 + const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 (M-block index in tile) - const int tid_X_colwise = thread_lane; // 0..31 (col-pair index in tile) + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 (M-block index in tile) + const int tid_X_colwise = thread_lane; // 0..31 (col-pair index in tile) - const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; // 0/16/32/48 - const int thread_offset_X_colwise = tid_X_colwise * 2; // 0/2/.../62 (2 cols per thread) + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; // 0/16/32/48 + const int thread_offset_X_colwise = tid_X_colwise * 2; // 0/2/.../62 (2 cols per thread) const int in_thread_offset_Y = thread_offset_Y_colwise; - const int in_thread_offset_X = thread_offset_X_colwise / 2; // index into IType2[] + const int in_thread_offset_X = thread_offset_X_colwise / 2; // index into IType2[] - const int out_tr_thread_offset_Y = thread_offset_X_colwise; // transpose: X becomes Y - const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; // /2 for fp4e2m1x2 byte index + const int out_tr_thread_offset_Y = thread_offset_X_colwise; // transpose: X becomes Y + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; // /2 for fp4e2m1x2 byte index - const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; // chunk-local col index (×1) - const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; // chunk-local M-block index + const int scale_tr_offset_Y = + (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; // chunk-local col index (×1) + const int scale_tr_offset_X = + (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; // chunk-local M-block index __align__(8) IType rIn[2][SCALE_DIM]; // Read 2 columns x 16 rows, accumulate per-column amax. @@ -281,17 +282,14 @@ __device__ __forceinline__ void colwise_scaling_per_token( // Kernel 2: per-token encode (rowwise + optional colwise transpose). // ============================================================================= template -__global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( - const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - const __grid_constant__ CUtensorMap tensor_map_output_t, - nvfp4_scale_t* const scales_ptr, - nvfp4_scale_t* const scales_t_ptr, - const float* const row_amax_in, - const float* const col_amax_in, - const float* noop, - const size_t rows, const size_t cols, - const size_t scale_stride, const size_t scale_stride_t) { +__global__ void __launch_bounds__(THREADS_NUM) + per_token_encode_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t* const scales_ptr, nvfp4_scale_t* const scales_t_ptr, + const float* const row_amax_in, const float* const col_amax_in, + const float* noop, const size_t rows, const size_t cols, + const size_t scale_stride, const size_t scale_stride_t) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -322,24 +320,26 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( constexpr int out_mem_colwise_data = DO_COL ? buff_size_aligned_out_t : 0; constexpr int out_mem_rowwise_scales = DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) : 0; + TMA_SHMEM_ALIGNMENT) + : 0; constexpr int out_mem_colwise_scales = DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) : 0; + TMA_SHMEM_ALIGNMENT) + : 0; extern __shared__ unsigned char dynamic_shmem[]; unsigned char* dshmem = align_smem_ptr_per_TMA_requirements(dynamic_shmem); - IType* sIn_ptr = reinterpret_cast(dshmem); - fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); - fp4e2m1x2* sOut_tr_ptr = reinterpret_cast( - dshmem + buff_size_aligned_in + out_mem_rowwise_data); + IType* sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); + fp4e2m1x2* sOut_tr_ptr = + reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data); nvfp4_scale_t* sSFrowwise_ptr = reinterpret_cast( dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data); - nvfp4_scale_t* sSFcolwise_ptr = reinterpret_cast( - dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data - + out_mem_rowwise_scales); + nvfp4_scale_t* sSFcolwise_ptr = + reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data + + out_mem_colwise_data + out_mem_rowwise_scales); // Per-CTA row/col amax SMEM cache (128 floats each). __shared__ float sRowAmax[CHUNK_DIM_Y]; @@ -393,8 +393,8 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( uint64_t* dst = reinterpret_cast(&sIn[buff_in]); const uint64_t* src = reinterpret_cast(&tensor_map_input); ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); - ptx::cp_async_bulk_tensor_2d_global_to_shared( - dst, src, global_offset_X, global_offset_Y, &IN_buff_readable_mbar[buff_in]); + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + &IN_buff_readable_mbar[buff_in]); } } @@ -422,18 +422,17 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( if (leading_thread) { uint64_t* dst = reinterpret_cast(&sIn[next_prefetch_buff]); const uint64_t* src = reinterpret_cast(&tensor_map_input); - ptx::mbarrier_arrive_expect_tx( - &IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); - ptx::cp_async_bulk_tensor_2d_global_to_shared( - dst, src, next_global_offset_X, next_global_offset_Y, - &IN_buff_readable_mbar[next_prefetch_buff]); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, next_global_offset_X, + next_global_offset_Y, + &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for current stage's input to land. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta( - &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // Wait for any prior TMA store to have finished reading the output SMEM @@ -442,12 +441,12 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( // ----- Compute: rowwise + colwise from the same SMEM tile ----- if (DO_ROW) { - rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, - sRowAmax, stage_Y, stage_X, buff_in, buff_out); + rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, sRowAmax, stage_Y, stage_X, + buff_in, buff_out); } if (DO_COL) { - colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, - sColAmax, stage_Y, stage_X, buff_in, buff_out_tr); + colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, sColAmax, stage_Y, stage_X, + buff_in, buff_out_tr); } // Fence + sync so all threads' SMEM writes are visible to TMA store. @@ -464,22 +463,20 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( if (DO_ROW) { auto& sOut = *reinterpret_cast(sOut_ptr); ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), - global_offset_X, global_offset_Y, + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, reinterpret_cast(&sOut[buff_out])); } if (DO_COL) { auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_t), - global_offset_X_tr, global_offset_Y_tr, - reinterpret_cast(&sOut_tr[buff_out_tr])); + reinterpret_cast(&tensor_map_output_t), global_offset_X_tr, + global_offset_Y_tr, reinterpret_cast(&sOut_tr[buff_out_tr])); } ptx::cp_async_bulk_commit_group(); } - buff_in = (buff_in + 1) % BUFFS_NUM; - buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_in = (buff_in + 1) % BUFFS_NUM; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; } // end of stages @@ -495,8 +492,7 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( const size_t row_global = scales_block_offset_Y_rowwise + row; if (row_global < rows) { ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); - const size_t scale_idx_global = - row_global * scale_stride + scales_block_offset_X_rowwise; + const size_t scale_idx_global = row_global * scale_stride + scales_block_offset_X_rowwise; scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); } } @@ -511,8 +507,7 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; if (row_tr_global < cols) { ScalesVec& scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); - const size_t scale_idx_global = - row_tr_global * scale_stride_t + scales_block_offset_X_tr; + const size_t scale_idx_global = row_tr_global * scale_stride_t + scales_block_offset_X_tr; scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count); } } @@ -545,12 +540,11 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( // After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot. // ============================================================================= template -__global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( - const __grid_constant__ CUtensorMap tensor_map_input, - float* __restrict__ row_amax_out, // [M], nullptr if !DO_ROW - float* __restrict__ col_amax_out, // [K], nullptr if !DO_COL - const float* noop, - const size_t rows, const size_t cols) { +__global__ void __launch_bounds__(THREADS_NUM) + per_token_amax_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + float* __restrict__ row_amax_out, // [M], nullptr if !DO_ROW + float* __restrict__ col_amax_out, // [K], nullptr if !DO_COL + const float* noop, const size_t rows, const size_t cols) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -585,8 +579,8 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( // i.e., this thread contributes to row partial in stages // where stage_Y == tid / 64. // col owned: col_base + tid -> stage_X == tid / 64. - const int my_row_stage_Y = tid / TILE_DIM_Y; // 0 or 1 - const int my_col_stage_X = tid / TILE_DIM_X; // 0 or 1 + const int my_row_stage_Y = tid / TILE_DIM_Y; // 0 or 1 + const int my_col_stage_X = tid / TILE_DIM_X; // 0 or 1 const int my_row_in_subtile = tid % TILE_DIM_Y; // 0..63 const int my_col_in_subtile = tid % TILE_DIM_X; // 0..63 @@ -611,8 +605,8 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[buff_in]), - reinterpret_cast(&tensor_map_input), - global_offset_X, global_offset_Y, &IN_buff_readable_mbar[buff_in]); + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + &IN_buff_readable_mbar[buff_in]); } } @@ -633,20 +627,18 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; if (leading_thread) { - ptx::mbarrier_arrive_expect_tx( - &IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[next_prefetch_buff]), - reinterpret_cast(&tensor_map_input), - next_global_offset_X, next_global_offset_Y, - &IN_buff_readable_mbar[next_prefetch_buff]); + reinterpret_cast(&tensor_map_input), next_global_offset_X, + next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for this stage's tile. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta( - &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // ----- Row partial update: walk this thread's row across the sub-tile ----- @@ -670,8 +662,8 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( for (int p = 0; p < 4; ++p) { ptx::abs_max_2x(amax_2x, amax_2x, pairs[p]); } - local_max = fmaxf(local_max, - static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); + local_max = + fmaxf(local_max, static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); } row_partial = local_max; } @@ -724,8 +716,8 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( #if FP4_TYPE_SUPPORTED // Launch Kernel 1 (amax). Writes only to output->amax / output->columnwise_amax; // other output fields untouched. Pre-zeroes the amax buffers (atomicMax identity). -inline void launch_amax(const Tensor& input, Tensor* output, - const Tensor& noop, cudaStream_t stream) { +inline void launch_amax(const Tensor& input, Tensor* output, const Tensor& noop, + cudaStream_t stream) { const size_t M = input.flat_first_dim(); const size_t K = input.flat_last_dim(); @@ -735,46 +727,41 @@ inline void launch_amax(const Tensor& input, Tensor* output, // Pre-zero amax buffers (atomicMaxFloat identity for non-negative values). if (do_row) { - NVTE_CHECK(output->amax.numel() == M, - "Per-token amax: output->amax numel must equal M = ", M, + NVTE_CHECK(output->amax.numel() == M, "Per-token amax: output->amax numel must equal M = ", M, ", got ", output->amax.numel()); NVTE_CHECK_CUDA(cudaMemsetAsync(output->amax.dptr, 0, M * sizeof(float), stream)); } if (do_col) { NVTE_CHECK(output->columnwise_amax.numel() == K, - "Per-token amax: output->columnwise_amax numel must equal K = ", K, - ", got ", output->columnwise_amax.numel()); + "Per-token amax: output->columnwise_amax numel must equal K = ", K, ", got ", + output->columnwise_amax.numel()); NVTE_CHECK_CUDA(cudaMemsetAsync(output->columnwise_amax.dptr, 0, K * sizeof(float), stream)); } checkCuDriverContext(stream); alignas(64) CUtensorMap tmap_in{}; - create_2D_tensor_map(tmap_in, input.data, M, K, - TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); + create_2D_tensor_map(tmap_in, input.data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; constexpr int buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; // + align pad - dim3 grid(static_cast(K / CHUNK_DIM_X), - static_cast(M / CHUNK_DIM_Y), 1); + dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); - const float* noop_ptr = (noop.data.dptr != nullptr) - ? reinterpret_cast(noop.data.dptr) - : nullptr; + const float* noop_ptr = + (noop.data.dptr != nullptr) ? reinterpret_cast(noop.data.dptr) : nullptr; - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_row, DO_ROW, TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { auto kernel = per_token_amax_kernel; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); kernel<<>>( - tmap_in, - do_row ? reinterpret_cast(output->amax.dptr) : nullptr, - do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr, - noop_ptr, M, K); + tmap_in, do_row ? reinterpret_cast(output->amax.dptr) : nullptr, + do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr, noop_ptr, M, + K); });); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -782,8 +769,8 @@ inline void launch_amax(const Tensor& input, Tensor* output, // Launch Kernel 2 (encode). Requires output->amax / columnwise_amax to be pre-filled // (by a prior launch_amax call or by an external caller); writes // output->data / scale_inv / columnwise_data / columnwise_scale_inv. -inline void launch_encode(const Tensor& input, Tensor* output, - const Tensor& noop, cudaStream_t stream) { +inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noop, + cudaStream_t stream) { const size_t M = input.flat_first_dim(); const size_t K = input.flat_last_dim(); @@ -814,15 +801,13 @@ inline void launch_encode(const Tensor& input, Tensor* output, alignas(64) CUtensorMap tmap_out{}; alignas(64) CUtensorMap tmap_out_t{}; - create_2D_tensor_map(tmap_in, input.data, M, K, - TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); + create_2D_tensor_map(tmap_in, input.data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); if (do_row) { - create_2D_tensor_map(tmap_out, output->data, M, K, - TILE_DIM_Y, TILE_DIM_X, K, 0, 4); + create_2D_tensor_map(tmap_out, output->data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, 4); } if (do_col) { - create_2D_tensor_map(tmap_out_t, output->columnwise_data, K, M, - TILE_DIM_X, TILE_DIM_Y, M, 0, 4); + create_2D_tensor_map(tmap_out_t, output->columnwise_data, K, M, TILE_DIM_X, TILE_DIM_Y, M, 0, + 4); } constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; @@ -832,29 +817,21 @@ inline void launch_encode(const Tensor& input, Tensor* output, DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); constexpr int buff_size_aligned_out_t = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_scales = - DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_scales_t = - DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales = DIVUP_TO_MULTIPLE( + CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_t = DIVUP_TO_MULTIPLE( + CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); // Total dyn SMEM: input + output FP4 (row + col) + SF (row + col) + 128B align. - const int dshmem_size = - buff_size_aligned_in - + (do_row ? buff_size_aligned_out : 0) - + (do_col ? buff_size_aligned_out_t : 0) - + (do_row ? buff_size_scales : 0) - + (do_col ? buff_size_scales_t : 0) - + TMA_SHMEM_ALIGNMENT; - - dim3 grid(static_cast(K / CHUNK_DIM_X), - static_cast(M / CHUNK_DIM_Y), 1); + const int dshmem_size = buff_size_aligned_in + (do_row ? buff_size_aligned_out : 0) + + (do_col ? buff_size_aligned_out_t : 0) + (do_row ? buff_size_scales : 0) + + (do_col ? buff_size_scales_t : 0) + TMA_SHMEM_ALIGNMENT; + + dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); - const float* noop_ptr = (noop.data.dptr != nullptr) - ? reinterpret_cast(noop.data.dptr) - : nullptr; + const float* noop_ptr = + (noop.data.dptr != nullptr) ? reinterpret_cast(noop.data.dptr) : nullptr; const size_t scale_stride = do_row ? output->scale_inv.shape[1] : 0; const size_t scale_stride_t = do_col ? output->columnwise_scale_inv.shape[1] : 0; @@ -862,20 +839,17 @@ inline void launch_encode(const Tensor& input, Tensor* output, do_row ? reinterpret_cast(output->scale_inv.dptr) : nullptr; nvfp4_scale_t* scales_t_ptr = do_col ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const float* row_amax_in = - do_row ? reinterpret_cast(output->amax.dptr) : nullptr; + const float* row_amax_in = do_row ? reinterpret_cast(output->amax.dptr) : nullptr; const float* col_amax_in = do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr; - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_row, DO_ROW, TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { auto kernel = per_token_encode_kernel; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tmap_in, tmap_out, tmap_out_t, - scales_ptr, scales_t_ptr, - row_amax_in, col_amax_in, - noop_ptr, M, K, scale_stride, scale_stride_t); + kernel<<>>(tmap_in, tmap_out, tmap_out_t, scales_ptr, + scales_t_ptr, row_amax_in, col_amax_in, + noop_ptr, M, K, scale_stride, scale_stride_t); });); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -892,15 +866,14 @@ inline void launch_encode(const Tensor& input, Tensor* output, // Output constraints differ by entry point (see validate_*_output helpers below). inline void validate_input_shape(const Tensor& input) { NVTE_CHECK(input.has_data(), "Per-token cast: input has no data."); - NVTE_CHECK(input.dtype() == DType::kBFloat16, - "Per-token cast is bf16-only. Got dtype enum ", + NVTE_CHECK(input.dtype() == DType::kBFloat16, "Per-token cast is bf16-only. Got dtype enum ", static_cast(input.dtype())); const size_t M = input.flat_first_dim(); const size_t K = input.flat_last_dim(); - NVTE_CHECK(M % CHUNK_DIM_Y == 0, - "Per-token cast: M must be a multiple of ", CHUNK_DIM_Y, ", got M=", M); - NVTE_CHECK(K % CHUNK_DIM_X == 0, - "Per-token cast: K must be a multiple of ", CHUNK_DIM_X, ", got K=", K); + NVTE_CHECK(M % CHUNK_DIM_Y == 0, "Per-token cast: M must be a multiple of ", CHUNK_DIM_Y, + ", got M=", M); + NVTE_CHECK(K % CHUNK_DIM_X == 0, "Per-token cast: K must be a multiple of ", CHUNK_DIM_X, + ", got K=", K); } // K1 (amax-only) requires at least one amax buffer allocated; FP4 output is not used. @@ -919,24 +892,24 @@ inline void validate_encode_output(const Tensor* output) { "Per-token cast emits compact (non-swizzled) inner SF."); } -void per_token_amax_blocked_impl(const Tensor& input, const Tensor& noop, - Tensor* output, cudaStream_t stream) { +void per_token_amax_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, + cudaStream_t stream) { validate_input_shape(input); validate_amax_output(output); if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; launch_amax(input, output, noop, stream); } -void per_token_encode_blocked_impl(const Tensor& input, const Tensor& noop, - Tensor* output, cudaStream_t stream) { +void per_token_encode_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, + cudaStream_t stream) { validate_input_shape(input); validate_encode_output(output); if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; launch_encode(input, output, noop, stream); } -void per_token_quantize_blocked_impl(const Tensor& input, const Tensor& noop, - Tensor* output, cudaStream_t stream) { +void per_token_quantize_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, + cudaStream_t stream) { validate_input_shape(input); validate_encode_output(output); if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; @@ -967,8 +940,8 @@ bool can_use_per_token(size_t, size_t, DType) { return false; } // C-API entry points // ============================================================================= -void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, - NVTETensor output, cudaStream_t stream) { +void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NVTETensor output, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_amax); using namespace transformer_engine; @@ -976,16 +949,18 @@ void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, Tensor* output_tensor = convertNVTETensorCheck(output); Tensor dummy_noop; const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; - nvfp4_per_token::per_token_amax_blocked_impl( - *input_tensor, *noop_tensor, output_tensor, stream); + nvfp4_per_token::per_token_amax_blocked_impl(*input_tensor, *noop_tensor, output_tensor, stream); #else - (void)input; (void)noop; (void)output; (void)stream; + (void)input; + (void)noop; + (void)output; + (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif } -void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, - NVTETensor output, cudaStream_t stream) { +void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_encode); using namespace transformer_engine; @@ -993,16 +968,19 @@ void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, Tensor* output_tensor = convertNVTETensorCheck(output); Tensor dummy_noop; const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; - nvfp4_per_token::per_token_encode_blocked_impl( - *input_tensor, *noop_tensor, output_tensor, stream); + nvfp4_per_token::per_token_encode_blocked_impl(*input_tensor, *noop_tensor, output_tensor, + stream); #else - (void)input; (void)noop; (void)output; (void)stream; + (void)input; + (void)noop; + (void)output; + (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif } -void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, - NVTETensor output, cudaStream_t stream) { +void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_quantize); using namespace transformer_engine; @@ -1010,10 +988,13 @@ void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop Tensor* output_tensor = convertNVTETensorCheck(output); Tensor dummy_noop; const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; - nvfp4_per_token::per_token_quantize_blocked_impl( - *input_tensor, *noop_tensor, output_tensor, stream); + nvfp4_per_token::per_token_quantize_blocked_impl(*input_tensor, *noop_tensor, output_tensor, + stream); #else - (void)input; (void)noop; (void)output; (void)stream; + (void)input; + (void)noop; + (void)output; + (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif } diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu index 9e5ede9ff2..69d00fb139 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu @@ -30,9 +30,9 @@ namespace nvfp4_per_token_group { #if FP4_TYPE_SUPPORTED +using dispatch::nvfp4::nvfp4_scale_t; using dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; using dispatch::nvfp4::quantization_SF::compute_decoding_scaling_factor; -using dispatch::nvfp4::nvfp4_scale_t; using ptx::FPx2; constexpr int kInnerK = 16; // NVFP4 inner block: 16 elements per e4m3 SF @@ -43,14 +43,14 @@ constexpr int kMaxTensorsPerKernel = 64; // Per-launch arg table; passed as __grid_constant__ for constant-cache reads. struct NVFP4PerTokenMultiArgs { // K1 outputs (per-tensor pointers; one fp32 array per tensor) - void* row_amax_list[kMaxTensorsPerKernel]; // each: float* (M_i,) - void* col_amax_list[kMaxTensorsPerKernel]; // each: float* (K,) + void* row_amax_list[kMaxTensorsPerKernel]; // each: float* (M_i,) + void* col_amax_list[kMaxTensorsPerKernel]; // each: float* (K,) // K2 outputs (per-tensor pointers; FP4 codes + e4m3 inner SF) - void* q_row_list[kMaxTensorsPerKernel]; // each: uint8* (M_i, K/2) - void* s_dec_row_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (M_i, K/16) - void* q_col_list[kMaxTensorsPerKernel]; // each: uint8* (K, M_i/2) - void* s_dec_col_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (K, M_i/16) + void* q_row_list[kMaxTensorsPerKernel]; // each: uint8* (M_i, K/2) + void* s_dec_row_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (M_i, K/16) + void* q_col_list[kMaxTensorsPerKernel]; // each: uint8* (K, M_i/2) + void* s_dec_col_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (K, M_i/16) // Shared layout info int split_sections_range[kMaxTensorsPerKernel + 1]; // prefix sum w/ leading 0 @@ -69,10 +69,10 @@ __device__ __forceinline__ int GetTensorId(const NVFP4PerTokenMultiArgs& args, i // per-tensor buffer via tensor_id lookup at CTA entry. namespace fused { -constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows -constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols -constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height -constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width +constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows +constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols +constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height +constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width constexpr int THREADS_NUM = 128; constexpr int PREFETCH_STAGES = 1; constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; @@ -90,15 +90,14 @@ using FusedIType3D = FusedIType[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; // Pre-zero amax buffers (identity for atomicMax). template -__global__ void group_per_token_fused_zero_amax_kernel(NVFP4PerTokenMultiArgs args, - int K) { +__global__ void group_per_token_fused_zero_amax_kernel(NVFP4PerTokenMultiArgs args, int K) { const int tensor_id = blockIdx.x; if (tensor_id >= args.num_tensors) return; if (DO_ROW) { float* row_amax = reinterpret_cast(args.row_amax_list[tensor_id]); if (row_amax != nullptr) { - const int M_i = args.split_sections_range[tensor_id + 1] - - args.split_sections_range[tensor_id]; + const int M_i = + args.split_sections_range[tensor_id + 1] - args.split_sections_range[tensor_id]; for (int m = threadIdx.x; m < M_i; m += blockDim.x) { row_amax[m] = 0.0f; } @@ -118,8 +117,7 @@ template __global__ void __launch_bounds__(THREADS_NUM) group_per_token_fused_amax_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ NVFP4PerTokenMultiArgs args, - const float* noop, const size_t rows, - const size_t cols) { + const float* noop, const size_t rows, const size_t cols) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -133,8 +131,7 @@ __global__ void __launch_bounds__(THREADS_NUM) DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); extern __shared__ unsigned char dynamic_shmem[]; - unsigned char* dshmem = - dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + unsigned char* dshmem = dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); FusedIType* sIn_ptr = reinterpret_cast(dshmem); auto& sIn = *reinterpret_cast(sIn_ptr); @@ -149,10 +146,8 @@ __global__ void __launch_bounds__(THREADS_NUM) // Tile lies fully inside one tensor (split_sections[i] % 128 == 0). const int tensor_id = GetTensorId(args, block_offset_Y); const int local_row_base = block_offset_Y - args.split_sections_range[tensor_id]; - float* row_amax_out = - DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; - float* col_amax_out = - DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; + float* row_amax_out = DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; + float* col_amax_out = DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; // Each thread owns chunk-row `tid` (for row amax) and chunk-col `tid` (for col amax). float row_partial = 0.f; @@ -183,8 +178,8 @@ __global__ void __launch_bounds__(THREADS_NUM) ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[buff_in]), - reinterpret_cast(&tensor_map_input), global_offset_X, - global_offset_Y, &IN_buff_readable_mbar[buff_in]); + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + &IN_buff_readable_mbar[buff_in]); } } @@ -205,20 +200,18 @@ __global__ void __launch_bounds__(THREADS_NUM) const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; if (leading_thread) { - ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], - shmem_buff_size); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[next_prefetch_buff]), - reinterpret_cast(&tensor_map_input), - next_global_offset_X, next_global_offset_Y, - &IN_buff_readable_mbar[next_prefetch_buff]); + reinterpret_cast(&tensor_map_input), next_global_offset_X, + next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for this stage's tile. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta( - &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // Row partial: rotate e-iter by bank group to split warp into 8 groups. @@ -235,8 +228,8 @@ __global__ void __launch_bounds__(THREADS_NUM) for (int p = 0; p < 4; ++p) { ptx::abs_max_2x(amax_2x, amax_2x, pairs[p]); } - local_max = fmaxf(local_max, static_cast( - __hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); + local_max = + fmaxf(local_max, static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); } row_partial = local_max; } @@ -282,56 +275,53 @@ __global__ void __launch_bounds__(THREADS_NUM) // K2 (encode) constants + helpers; byte-equal port of the single-tensor // per-token cooperative 4x32 / 32x4 threading + ld_shared_b128 + mul_cvt_4x. -constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM -constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) +constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM +constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; // 8 constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; // 8 -constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 -constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 // Rowwise pass: 4 (K-dim) x 32 (M-dim) -> 1 NVFP4 block per thread. -constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 -constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 -constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 -constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 +constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 // Colwise pass: tid.X = col-pair, warp = M-block (32 x 4). -constexpr int THREADS_X_TR = TILE_DIM_X / 2; // 32 -constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; // 4 +constexpr int THREADS_X_TR = TILE_DIM_X / 2; // 32 +constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; // 4 // Output / SF SMEM buffer dims (sub-tile sized, double-buffered for ping-pong). -constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; -constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (fp4e2m1x2 bytes) -constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; +constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (fp4e2m1x2 bytes) +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; constexpr int BUFF_OUT_TR_DIM_Y = TILE_DIM_X; -constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 -constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; -constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 -constexpr int BUFFS_NUM_OUT_TR = 2; +constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 +constexpr int BUFFS_NUM_OUT_TR = 2; // Manual SMEM swizzling parameters (matches single-tensor encode kernel). -constexpr int PACK_SIZE = 8; -constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 -constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 -constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 - -using IType = FusedIType; -using IType2 = FusedIType2; -using IType2x3D = IType2 [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; -using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 + +using IType = FusedIType; +using IType2 = FusedIType2; +using IType2x3D = IType2[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; -using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; // Rowwise encode helper: reads sRowAmax (pre-populated by K1), writes FP4 + // e4m3 SFs into sOut / sSFrowwise. Byte-equal to the single-tensor version. __device__ __forceinline__ void rowwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, - fp4e2m1x2* __restrict__ sOut_ptr, - nvfp4_scale_t* __restrict__ sSFrowwise_ptr, - const float* __restrict__ sRowAmax, - const int stage_Y, const int stage_X, - const int buff_in, const int buff_out) { + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_ptr, + nvfp4_scale_t* __restrict__ sSFrowwise_ptr, const float* __restrict__ sRowAmax, + const int stage_Y, const int stage_X, const int buff_in, const int buff_out) { const auto& sIn = *reinterpret_cast(sIn_ptr); auto& sOut = *reinterpret_cast(sOut_ptr); auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); @@ -339,8 +329,8 @@ __device__ __forceinline__ void rowwise_scaling_per_token( const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; - const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 - const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; @@ -373,8 +363,8 @@ __device__ __forceinline__ void rowwise_scaling_per_token( ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); } } - const float block_amax = static_cast( - __hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + const float block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax, S_enc); const float s_dec_f = static_cast(s_dec); @@ -404,21 +394,18 @@ __device__ __forceinline__ void rowwise_scaling_per_token( // Colwise encode helper. Byte-equal to the single-tensor version. __device__ __forceinline__ void colwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, - fp4e2m1x2* __restrict__ sOut_tr_ptr, - nvfp4_scale_t* __restrict__ sSFcolwise_ptr, - const float* __restrict__ sColAmax, - const int stage_Y, const int stage_X, - const int buff_in, const int buff_out_tr) { + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_tr_ptr, + nvfp4_scale_t* __restrict__ sSFcolwise_ptr, const float* __restrict__ sColAmax, + const int stage_Y, const int stage_X, const int buff_in, const int buff_out_tr) { const auto& sIn2x = *reinterpret_cast(sIn_ptr); auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); - const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 + const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 - const int tid_X_colwise = thread_lane; // 0..31 + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 + const int tid_X_colwise = thread_lane; // 0..31 const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; const int thread_offset_X_colwise = tid_X_colwise * 2; @@ -480,8 +467,7 @@ template __global__ void __launch_bounds__(THREADS_NUM) group_per_token_fused_cast_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ NVFP4PerTokenMultiArgs args, - const float* noop, const size_t rows, - const size_t cols) { + const float* noop, const size_t rows, const size_t cols) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -503,25 +489,26 @@ __global__ void __launch_bounds__(THREADS_NUM) constexpr int out_mem_colwise_data = DO_COL ? buff_size_aligned_out_t : 0; constexpr int out_mem_rowwise_scales = DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) : 0; + TMA_SHMEM_ALIGNMENT) + : 0; constexpr int out_mem_colwise_scales = DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) : 0; + TMA_SHMEM_ALIGNMENT) + : 0; (void)out_mem_colwise_scales; extern __shared__ unsigned char dynamic_shmem[]; - unsigned char* dshmem = - dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + unsigned char* dshmem = dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); - IType* sIn_ptr = reinterpret_cast(dshmem); - fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); - fp4e2m1x2* sOut_tr_ptr = reinterpret_cast( - dshmem + buff_size_aligned_in + out_mem_rowwise_data); + IType* sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); + fp4e2m1x2* sOut_tr_ptr = + reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data); nvfp4_scale_t* sSFrowwise_ptr = reinterpret_cast( dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data); - nvfp4_scale_t* sSFcolwise_ptr = reinterpret_cast( - dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data - + out_mem_rowwise_scales); + nvfp4_scale_t* sSFcolwise_ptr = + reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data + + out_mem_colwise_data + out_mem_rowwise_scales); __shared__ float sRowAmax[CHUNK_DIM_Y]; __shared__ float sColAmax[CHUNK_DIM_X]; @@ -539,22 +526,21 @@ __global__ void __launch_bounds__(THREADS_NUM) // Chunk Y stays inside one tensor (split_sections[i] % 128 == 0). const int tensor_id = GetTensorId(args, block_offset_Y); const int local_row_base = block_offset_Y - args.split_sections_range[tensor_id]; - const int M_t = args.split_sections_range[tensor_id + 1] - - args.split_sections_range[tensor_id]; + const int M_t = args.split_sections_range[tensor_id + 1] - args.split_sections_range[tensor_id]; // Per-tensor output bases (one constant-cache lookup per CTA). - uint8_t* const q_row_base = DO_ROW - ? reinterpret_cast(args.q_row_list[tensor_id]) : nullptr; - uint8_t* const q_col_base = DO_COL - ? reinterpret_cast(args.q_col_list[tensor_id]) : nullptr; - nvfp4_scale_t* const s_dec_row_base = DO_ROW - ? reinterpret_cast(args.s_dec_row_list[tensor_id]) : nullptr; - nvfp4_scale_t* const s_dec_col_base = DO_COL - ? reinterpret_cast(args.s_dec_col_list[tensor_id]) : nullptr; - const float* const row_amax_base = DO_ROW - ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; - const float* const col_amax_base = DO_COL - ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; + uint8_t* const q_row_base = + DO_ROW ? reinterpret_cast(args.q_row_list[tensor_id]) : nullptr; + uint8_t* const q_col_base = + DO_COL ? reinterpret_cast(args.q_col_list[tensor_id]) : nullptr; + nvfp4_scale_t* const s_dec_row_base = + DO_ROW ? reinterpret_cast(args.s_dec_row_list[tensor_id]) : nullptr; + nvfp4_scale_t* const s_dec_col_base = + DO_COL ? reinterpret_cast(args.s_dec_col_list[tensor_id]) : nullptr; + const float* const row_amax_base = + DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; + const float* const col_amax_base = + DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; const size_t data_stride_row = static_cast(cols) / 2; const size_t data_stride_col = static_cast(M_t) / 2; @@ -590,13 +576,13 @@ __global__ void __launch_bounds__(THREADS_NUM) ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in_p], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[buff_in_p]), - reinterpret_cast(&tensor_map_input), global_offset_X, - global_offset_Y, &IN_buff_readable_mbar[buff_in_p]); + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + &IN_buff_readable_mbar[buff_in_p]); } } - int buff_in = 0; - int buff_out = 0; + int buff_in = 0; + int buff_out = 0; int buff_out_tr = 0; int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; @@ -613,30 +599,28 @@ __global__ void __launch_bounds__(THREADS_NUM) const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; if (leading_thread) { - ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], - shmem_buff_size); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[next_prefetch_buff]), - reinterpret_cast(&tensor_map_input), - next_global_offset_X, next_global_offset_Y, - &IN_buff_readable_mbar[next_prefetch_buff]); + reinterpret_cast(&tensor_map_input), next_global_offset_X, + next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for current stage's input tile to land. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta( - &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // 4x32 cooperative row + col encode helpers. if (DO_ROW) { - rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, - sRowAmax, stage_Y, stage_X, buff_in, buff_out); + rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, sRowAmax, stage_Y, stage_X, + buff_in, buff_out); } if (DO_COL) { - colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, - sColAmax, stage_Y, stage_X, buff_in, buff_out_tr); + colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, sColAmax, stage_Y, stage_X, + buff_in, buff_out_tr); } // Make helper SMEM writes visible before the scatter epilogue. @@ -647,13 +631,10 @@ __global__ void __launch_bounds__(THREADS_NUM) if (DO_ROW) { auto& sOut = *reinterpret_cast(sOut_ptr); const int row_in_subtile = static_cast(threadIdx.x) >> 1; // 0..63 - const int half = static_cast(threadIdx.x) & 1; // 0..1 - const int local_row = local_row_base + stage_Y * TILE_DIM_Y + row_in_subtile; - const int byte_off_X = (block_offset_X / 2) - + stage_X * (TILE_DIM_X / 2) - + half * 16; - const uint4* src = reinterpret_cast( - &sOut[buff_out][row_in_subtile][half * 16]); + const int half = static_cast(threadIdx.x) & 1; // 0..1 + const int local_row = local_row_base + stage_Y * TILE_DIM_Y + row_in_subtile; + const int byte_off_X = (block_offset_X / 2) + stage_X * (TILE_DIM_X / 2) + half * 16; + const uint4* src = reinterpret_cast(&sOut[buff_out][row_in_subtile][half * 16]); uint4* dst = reinterpret_cast( q_row_base + static_cast(local_row) * data_stride_row + byte_off_X); *dst = *src; @@ -661,13 +642,11 @@ __global__ void __launch_bounds__(THREADS_NUM) if (DO_COL) { auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); const int col_in_subtile = static_cast(threadIdx.x) >> 1; // 0..63 - const int half = static_cast(threadIdx.x) & 1; // 0..1 - const int global_col = block_offset_X + stage_X * TILE_DIM_X + col_in_subtile; - const int byte_off_M = (local_row_base / 2) - + stage_Y * (TILE_DIM_Y / 2) - + half * 16; - const uint4* src = reinterpret_cast( - &sOut_tr[buff_out_tr][col_in_subtile][half * 16]); + const int half = static_cast(threadIdx.x) & 1; // 0..1 + const int global_col = block_offset_X + stage_X * TILE_DIM_X + col_in_subtile; + const int byte_off_M = (local_row_base / 2) + stage_Y * (TILE_DIM_Y / 2) + half * 16; + const uint4* src = + reinterpret_cast(&sOut_tr[buff_out_tr][col_in_subtile][half * 16]); uint4* dst = reinterpret_cast( q_col_base + static_cast(global_col) * data_stride_col + byte_off_M); *dst = *src; @@ -676,8 +655,8 @@ __global__ void __launch_bounds__(THREADS_NUM) // Sync so the scatter completes before next stage overwrites the buffer. __syncthreads(); - buff_in = (buff_in + 1) % BUFFS_NUM; - buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_in = (buff_in + 1) % BUFFS_NUM; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; } @@ -685,13 +664,11 @@ __global__ void __launch_bounds__(THREADS_NUM) if (DO_ROW) { auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); using ScalesVec = Vec; - const size_t scales_block_offset_X_rowwise = - static_cast(ctaid_X) * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_X_rowwise = static_cast(ctaid_X) * SCALES_PER_CHUNK_X; for (int row = static_cast(threadIdx.x); row < CHUNK_DIM_Y; row += THREADS_NUM) { ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); const size_t local_row = static_cast(local_row_base) + row; - const size_t scale_idx_global = - local_row * scale_stride_row + scales_block_offset_X_rowwise; + const size_t scale_idx_global = local_row * scale_stride_row + scales_block_offset_X_rowwise; scales_vec.store_to_elts(&s_dec_row_base[scale_idx_global], 0, SCALES_PER_CHUNK_X); } } @@ -700,12 +677,10 @@ __global__ void __launch_bounds__(THREADS_NUM) using ScalesVec = Vec; // M-block offset within s_dec_col[global_col] (shape (K, M_i/16) row-major). const size_t local_block_offset_M = static_cast(local_row_base) / SCALE_DIM; - for (int row_tr = static_cast(threadIdx.x); row_tr < CHUNK_DIM_X; - row_tr += THREADS_NUM) { + for (int row_tr = static_cast(threadIdx.x); row_tr < CHUNK_DIM_X; row_tr += THREADS_NUM) { ScalesVec& scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); const size_t global_col = static_cast(block_offset_X) + row_tr; - const size_t scale_idx_global = - global_col * scale_stride_col + local_block_offset_M; + const size_t scale_idx_global = global_col * scale_stride_col + local_block_offset_M; scales_vec.store_to_elts(&s_dec_col_base[scale_idx_global], 0, SCALES_PER_CHUNK_Y); } } @@ -728,9 +703,9 @@ __global__ void __launch_bounds__(THREADS_NUM) // Host launcher for the fused K2 path. bf16-only. inline void launch_grouped_fused_cast_bf16(const NVFP4PerTokenMultiArgs& args, - const SimpleTensor& input_data, int sum_M, - int K, bool do_row, bool do_col, - const float* noop, cudaStream_t stream) { + const SimpleTensor& input_data, int sum_M, int K, + bool do_row, bool do_col, const float* noop, + cudaStream_t stream) { if (!do_row && !do_col) return; checkCuDriverContext(stream); @@ -739,44 +714,41 @@ inline void launch_grouped_fused_cast_bf16(const NVFP4PerTokenMultiArgs& args, create_2D_tensor_map(tmap_in, input_data, sum_M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(FusedIType) * 8); - dim3 grid(static_cast(K / CHUNK_DIM_X), - static_cast(sum_M / CHUNK_DIM_Y), 1); + dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(sum_M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { - constexpr int sz_in = DIVUP_TO_MULTIPLE( - BUFFS_NUM * BUFF_IN_SIZE * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); - constexpr int sz_out_r = DO_ROW - ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT) : 0; - constexpr int sz_out_c = DO_COL - ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT) - : 0; - constexpr int sz_sf_r = DO_ROW - ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) - : 0; - constexpr int sz_sf_c = DO_COL - ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) - : 0; - constexpr int dshmem_size = sz_in + sz_out_r + sz_out_c + sz_sf_r + sz_sf_c - + TMA_SHMEM_ALIGNMENT; + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_row, DO_ROW, TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { + constexpr int sz_in = + DIVUP_TO_MULTIPLE(BUFFS_NUM * BUFF_IN_SIZE * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); + constexpr int sz_out_r = + DO_ROW ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT) : 0; + constexpr int sz_out_c = + DO_COL ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int sz_sf_r = + DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int sz_sf_c = + DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int dshmem_size = + sz_in + sz_out_r + sz_out_c + sz_sf_r + sz_sf_c + TMA_SHMEM_ALIGNMENT; auto kernel = group_per_token_fused_cast_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>(tmap_in, args, noop, - static_cast(sum_M), - static_cast(K)); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tmap_in, args, noop, static_cast(sum_M), static_cast(K)); });); NVTE_CHECK_CUDA(cudaGetLastError()); } // Host launcher for the fused K1 path. bf16-only. inline void launch_grouped_fused_amax_bf16(const NVFP4PerTokenMultiArgs& args, - const SimpleTensor& input_data, int sum_M, - int K, bool do_row, bool do_col, - const float* noop, cudaStream_t stream) { + const SimpleTensor& input_data, int sum_M, int K, + bool do_row, bool do_col, const float* noop, + cudaStream_t stream) { if (!do_row && !do_col) return; // Pre-zero amax slots (atomicMax identity). @@ -807,18 +779,15 @@ inline void launch_grouped_fused_amax_bf16(const NVFP4PerTokenMultiArgs& args, DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; - dim3 grid(static_cast(K / CHUNK_DIM_X), - static_cast(sum_M / CHUNK_DIM_Y), 1); + dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(sum_M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_row, DO_ROW, TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { auto kernel = group_per_token_fused_amax_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>(tmap_in, args, noop, - static_cast(sum_M), - static_cast(K)); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tmap_in, args, noop, static_cast(sum_M), static_cast(K)); });); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -842,19 +811,18 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector& outputs, args->split_sections_range[0] = 0; for (size_t i = 0; i < num_tensors; ++i) { Tensor* o = outputs[i]; - NVTE_CHECK(split_sections[i] % 128 == 0, "split_sections[", i, - "] = ", split_sections[i], " must be a multiple of 128"); + NVTE_CHECK(split_sections[i] % 128 == 0, "split_sections[", i, "] = ", split_sections[i], + " must be a multiple of 128"); args->split_sections_range[i + 1] = args->split_sections_range[i] + static_cast(split_sections[i]); if (split_sections[i] == 0) continue; if (which_buffers & kBufRowAmax) { - NVTE_CHECK(o->amax.dptr != nullptr, - "NVFP4 per-token grouped: outputs[", i, "].amax must be allocated for rowwise"); + NVTE_CHECK(o->amax.dptr != nullptr, "NVFP4 per-token grouped: outputs[", i, + "].amax must be allocated for rowwise"); args->row_amax_list[i] = o->amax.dptr; } if (which_buffers & kBufColAmax) { - NVTE_CHECK(o->columnwise_amax.dptr != nullptr, - "NVFP4 per-token grouped: outputs[", i, + NVTE_CHECK(o->columnwise_amax.dptr != nullptr, "NVFP4 per-token grouped: outputs[", i, "].columnwise_amax must be allocated for columnwise"); args->col_amax_list[i] = o->columnwise_amax.dptr; } @@ -866,10 +834,9 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector& outputs, args->s_dec_row_list[i] = o->scale_inv.dptr; } if (which_buffers & kBufColCast) { - NVTE_CHECK( - o->columnwise_data.dptr != nullptr && o->columnwise_scale_inv.dptr != nullptr, - "NVFP4 per-token grouped: outputs[", i, - "].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast"); + NVTE_CHECK(o->columnwise_data.dptr != nullptr && o->columnwise_scale_inv.dptr != nullptr, + "NVFP4 per-token grouped: outputs[", i, + "].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast"); args->q_col_list[i] = o->columnwise_data.dptr; args->s_dec_col_list[i] = o->columnwise_scale_inv.dptr; } @@ -898,8 +865,7 @@ void quantize_per_token_grouped(const Tensor& input, std::vector& outpu const int sum_M = static_cast(input.flat_first_dim()); const int K = static_cast(input.flat_last_dim()); if (sum_M == 0 || K == 0) return; - NVTE_CHECK(K % 128 == 0, - "NVFP4 per-token grouped: K (", K, ") must be a multiple of 128"); + NVTE_CHECK(K % 128 == 0, "NVFP4 per-token grouped: K (", K, ") must be a multiple of 128"); int which_buffers = 0; if ((do_amax || do_cast) && rowwise) which_buffers |= kBufRowAmax; @@ -945,16 +911,16 @@ std::vector collect_outputs(NVTETensor* outputs, si } // namespace void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, cudaStream_t stream) { + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_group_nvfp4_per_token_amax); using namespace transformer_engine; if (num_tensors == 0) return; const Tensor* in = convertNVTETensorCheck(input); std::vector outs = collect_outputs(outputs, num_tensors); - nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, - rowwise, columnwise, + nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, rowwise, + columnwise, /*do_amax=*/true, /*do_cast=*/false, stream); #else (void)input; @@ -969,16 +935,16 @@ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs } void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, cudaStream_t stream) { + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_group_nvfp4_per_token_cast); using namespace transformer_engine; if (num_tensors == 0) return; const Tensor* in = convertNVTETensorCheck(input); std::vector outs = collect_outputs(outputs, num_tensors); - nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, - rowwise, columnwise, + nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, rowwise, + columnwise, /*do_amax=*/false, /*do_cast=*/true, stream); #else (void)input; @@ -1001,8 +967,8 @@ void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* out if (num_tensors == 0) return; const Tensor* in = convertNVTETensorCheck(input); std::vector outs = collect_outputs(outputs, num_tensors); - nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, - rowwise, columnwise, + nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, rowwise, + columnwise, /*do_amax=*/true, /*do_cast=*/true, stream); #else (void)input; diff --git a/transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu b/transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu index 849b76a6eb..4f4d22d22a 100644 --- a/transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu +++ b/transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu @@ -31,8 +31,7 @@ constexpr int kElemsPerThread = 8; // bf16x8 = 16-byte vector constexpr int kThreadsX = kTileCols / kElemsPerThread; constexpr int kThreadsY = kTileRows; constexpr int kThreadsPerBlock = kThreadsX * kThreadsY; -static_assert(kTileCols % kElemsPerThread == 0, - "kTileCols must be a multiple of kElemsPerThread"); +static_assert(kTileCols % kElemsPerThread == 0, "kTileCols must be a multiple of kElemsPerThread"); static_assert(kElemsPerThread * sizeof(__nv_bfloat16) == sizeof(int4), "kElemsPerThread bf16 must pack into a single int4 (16 bytes)"); @@ -102,8 +101,8 @@ void per_token_post_scale(Tensor* d, const Tensor& row_amax_a, const Tensor& row "NVFP4 per-token post-scale: row_amax_b must be FP32."); const auto& d_shape = d->data.shape; - NVTE_CHECK(d_shape.size() == 2, "NVFP4 per-token post-scale: d must be 2D, got rank=", - d_shape.size()); + NVTE_CHECK(d_shape.size() == 2, + "NVFP4 per-token post-scale: d must be 2D, got rank=", d_shape.size()); const int M = static_cast(d_shape[0]); const int N = static_cast(d_shape[1]); NVTE_CHECK(row_amax_a.data.numel() == static_cast(M), diff --git a/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h index 8743a40afa..e21be53047 100644 --- a/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h +++ b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h @@ -15,31 +15,30 @@ extern "C" { #endif - /*! \brief Composite K1+K2: per-row + per-col amax (K1) then FP4 + 1x16 * e4m3 SF encode (K2), back-to-back on the same stream. * * This is the production entry point for the per-token cast on bf16 + * 128-aligned shapes. */ -void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, - NVTETensor output, cudaStream_t stream); +void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output, + cudaStream_t stream); /*! \brief Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. * Pre-zeroes the amax buffers and merges per-CTA partials into * ``output->amax`` (size [M]) / ``output->columnwise_amax`` * (size [K]). Does NOT touch FP4 data / scale_inv slots. */ -void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, - NVTETensor output, cudaStream_t stream); +void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NVTETensor output, + cudaStream_t stream); /*! \brief Kernel 2 in isolation: FP4 + 1x16 e4m3 SF encode given a * pre-filled ``output->amax`` / ``output->columnwise_amax``. Reads * the outer amax buffer(s) and writes the FP4 data / scale_inv * tensors only. */ -void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, - NVTETensor output, cudaStream_t stream); +void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output, + cudaStream_t stream); /*! \brief Returns 1 iff the per-token kernels accept ``(M, K, dtype)``. * @@ -59,8 +58,7 @@ int nvte_nvfp4_per_token_can_dispatch(size_t M, size_t K, int input_dtype_enum); * d[i, j] = d[i, j] * row_amax_a[i] * row_amax_b[j] */ void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a, - const NVTETensor row_amax_b, - cudaStream_t stream); + const NVTETensor row_amax_b, cudaStream_t stream); /* ============================================================================ * Grouped (multi-tensor) per-token quantize. @@ -76,8 +74,8 @@ void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a, * \param[in] stream CUDA stream */ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, cudaStream_t stream); + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, cudaStream_t stream); /*! \brief Grouped per-token encode (FP4 + 1x16 e4m3 inner SF) using the * row_amax / col_amax values already populated by @@ -94,8 +92,8 @@ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs * \param[in] stream CUDA stream */ void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, cudaStream_t stream); + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, cudaStream_t stream); /*! \brief Composite K1+K2 grouped per-token quantize. Calls the amax + cast * kernels on the same stream. This is the external API diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 6411811323..535fac017e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -412,8 +412,7 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads void compute_amax(const at::Tensor &tensor, at::Tensor &amax); void hadamard_transform_amax(const at::Tensor &tensor, at::Tensor &rowwise_amax, - at::Tensor &columnwise_amax, - int64_t rht_matrix_random_sign_mask); + at::Tensor &columnwise_amax, int64_t rht_matrix_random_sign_mask); void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, @@ -452,47 +451,43 @@ void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwi const at::Tensor &scale_inv_colwise, int rows, int cols, size_t start_offset); -void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row, - at::Tensor s_dec_row, at::Tensor row_amax, - at::Tensor q_col, at::Tensor s_dec_col, +void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, bool rowwise, bool columnwise); -void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, - at::Tensor col_amax, bool rowwise, bool columnwise); +void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, at::Tensor col_amax, + bool rowwise, bool columnwise); -void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row, - at::Tensor s_dec_row, at::Tensor row_amax, - at::Tensor q_col, at::Tensor s_dec_col, - at::Tensor col_amax, bool rowwise, bool columnwise); +void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise); void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, const at::Tensor &row_amax_b); void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, const at::Tensor &a_sf, const at::Tensor &b_sf, - const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, - at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, - int64_t k, double alpha, double beta); + const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, at::Tensor d, + const at::Tensor &workspace, int64_t m, int64_t n, int64_t k, + double alpha, double beta); // Bench-only per-tensor twin of nvfp4_per_token_gemm: scalar amaxes folded // into cuBLAS LT alpha via the amax slot; no trailing post-scale. void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, - const at::Tensor &a_sf, const at::Tensor &b_sf, - const at::Tensor &a_amax, const at::Tensor &b_amax, - at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, - int64_t k, double alpha, double beta); + const at::Tensor &a_sf, const at::Tensor &b_sf, const at::Tensor &a_amax, + const at::Tensor &b_amax, at::Tensor d, const at::Tensor &workspace, + int64_t m, int64_t n, int64_t k, double alpha, double beta); void nvfp4_per_token_group_quantize( const at::Tensor &input, const std::vector &split_sections, std::vector q_row_list, std::vector s_dec_row_list, std::vector row_amax_list, std::vector q_col_list, - std::vector s_dec_col_list, std::vector col_amax_list, - bool rowwise, bool columnwise); + std::vector s_dec_col_list, std::vector col_amax_list, bool rowwise, + bool columnwise); // Amax-only variant of the grouped quantize. Useful for multi-rank training // where amax is allReduced before the cast pass. -void nvfp4_per_token_group_amax(const at::Tensor &input, - const std::vector &split_sections, +void nvfp4_per_token_group_amax(const at::Tensor &input, const std::vector &split_sections, std::vector row_amax_list, std::vector col_amax_list, bool rowwise, bool columnwise); @@ -503,8 +498,8 @@ void nvfp4_per_token_group_amax(const at::Tensor &input, std::tuple, std::vector, std::vector, std::vector, std::vector, std::vector> nvfp4_per_token_group_quantize_bulk(const at::Tensor &input, - const std::vector &split_sections, - bool rowwise, bool columnwise); + const std::vector &split_sections, bool rowwise, + bool columnwise); /*************************************************************************************************** * Rotary positional embedding diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp index cb6b56d9ee..d9831ed5e7 100644 --- a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp @@ -25,11 +25,10 @@ namespace { // Validates the input and assembles ``out_te`` for all 3 modes; caller // dispatches to the right C-API entry on the caller's stream. -void assemble_per_token_tensors(const at::Tensor& input, - at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, - at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, - bool rowwise, bool columnwise, int mode, - TensorWrapper& in_te, TensorWrapper& out_te) { +void assemble_per_token_tensors(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, int mode, + TensorWrapper& in_te, TensorWrapper& out_te) { TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); @@ -53,8 +52,8 @@ void assemble_per_token_tensors(const at::Tensor& input, TORCH_CHECK(row_amax.is_cuda() && row_amax.is_contiguous(), "row_amax must be a contiguous CUDA tensor"); TORCH_CHECK(row_amax.scalar_type() == at::ScalarType::Float, "row_amax must be float32"); - TORCH_CHECK(row_amax.numel() == M, - "row_amax numel mismatch: expected M=", M, ", got ", row_amax.numel()); + TORCH_CHECK(row_amax.numel() == M, "row_amax numel mismatch: expected M=", M, ", got ", + row_amax.numel()); out_te.set_amax(row_amax.data_ptr(), DType::kFloat32, std::vector{static_cast(M)}); @@ -63,15 +62,14 @@ void assemble_per_token_tensors(const at::Tensor& input, "q_row must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_row.is_cuda() && s_dec_row.is_contiguous(), "s_dec_row must be a contiguous CUDA tensor"); - TORCH_CHECK(q_row.scalar_type() == at::ScalarType::Byte, - "q_row must be uint8 (FP4 packed)"); + TORCH_CHECK(q_row.scalar_type() == at::ScalarType::Byte, "q_row must be uint8 (FP4 packed)"); TORCH_CHECK(s_dec_row.scalar_type() == at::ScalarType::Byte, "s_dec_row must be uint8 (FP8 e4m3 raw bytes)"); - TORCH_CHECK(q_row.numel() == M * K / 2, - "q_row numel mismatch: expected M*K/2=", M * K / 2, ", got ", q_row.numel()); + TORCH_CHECK(q_row.numel() == M * K / 2, "q_row numel mismatch: expected M*K/2=", M * K / 2, + ", got ", q_row.numel()); TORCH_CHECK(s_dec_row.numel() == M * K / 16, - "s_dec_row numel mismatch: expected M*K/16=", M * K / 16, - ", got ", s_dec_row.numel()); + "s_dec_row numel mismatch: expected M*K/16=", M * K / 16, ", got ", + s_dec_row.numel()); out_te.set_rowwise_data(q_row.data_ptr(), DType::kFloat4E2M1, in_shape); out_te.set_rowwise_scale_inv( s_dec_row.data_ptr(), DType::kFloat8E4M3, @@ -82,8 +80,8 @@ void assemble_per_token_tensors(const at::Tensor& input, TORCH_CHECK(col_amax.is_cuda() && col_amax.is_contiguous(), "col_amax must be a contiguous CUDA tensor"); TORCH_CHECK(col_amax.scalar_type() == at::ScalarType::Float, "col_amax must be float32"); - TORCH_CHECK(col_amax.numel() == K, - "col_amax numel mismatch: expected K=", K, ", got ", col_amax.numel()); + TORCH_CHECK(col_amax.numel() == K, "col_amax numel mismatch: expected K=", K, ", got ", + col_amax.numel()); out_te.set_columnwise_amax(col_amax.data_ptr(), DType::kFloat32, std::vector{static_cast(K)}); @@ -92,15 +90,14 @@ void assemble_per_token_tensors(const at::Tensor& input, "q_col must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_col.is_cuda() && s_dec_col.is_contiguous(), "s_dec_col must be a contiguous CUDA tensor"); - TORCH_CHECK(q_col.scalar_type() == at::ScalarType::Byte, - "q_col must be uint8 (FP4 packed)"); + TORCH_CHECK(q_col.scalar_type() == at::ScalarType::Byte, "q_col must be uint8 (FP4 packed)"); TORCH_CHECK(s_dec_col.scalar_type() == at::ScalarType::Byte, "s_dec_col must be uint8 (FP8 e4m3 raw bytes)"); - TORCH_CHECK(q_col.numel() == K * M / 2, - "q_col numel mismatch: expected K*M/2=", K * M / 2, ", got ", q_col.numel()); + TORCH_CHECK(q_col.numel() == K * M / 2, "q_col numel mismatch: expected K*M/2=", K * M / 2, + ", got ", q_col.numel()); TORCH_CHECK(s_dec_col.numel() == K * M / 16, - "s_dec_col numel mismatch: expected K*M/16=", K * M / 16, - ", got ", s_dec_col.numel()); + "s_dec_col numel mismatch: expected K*M/16=", K * M / 16, ", got ", + s_dec_col.numel()); out_te.set_columnwise_data( q_col.data_ptr(), DType::kFloat4E2M1, std::vector{static_cast(K), static_cast(M)}); @@ -114,49 +111,44 @@ void assemble_per_token_tensors(const at::Tensor& input, } // namespace // Production composite (K1 + K2 back-to-back). -void nvfp4_per_token_quantize(const at::Tensor& input, - at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, - at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, - bool rowwise, bool columnwise) { +void nvfp4_per_token_quantize(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise) { TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); - assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, - q_col, s_dec_col, col_amax, - rowwise, columnwise, /*mode=*/0, in_te, out_te); + assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, + columnwise, /*mode=*/0, in_te, out_te); const auto stream = at::cuda::getCurrentCUDAStream(); nvte_nvfp4_per_token_quantize(in_te.data(), nullptr, out_te.data(), stream); } // K1-only (diagnostic / bench): populates only amax buffers. -void nvfp4_per_token_amax(const at::Tensor& input, at::Tensor row_amax, - at::Tensor col_amax, bool rowwise, bool columnwise) { +void nvfp4_per_token_amax(const at::Tensor& input, at::Tensor row_amax, at::Tensor col_amax, + bool rowwise, bool columnwise) { at::Tensor empty_u8; // not consumed by K1 TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); - assemble_per_token_tensors(input, empty_u8, empty_u8, row_amax, - empty_u8, empty_u8, col_amax, - rowwise, columnwise, /*mode=*/1, in_te, out_te); + assemble_per_token_tensors(input, empty_u8, empty_u8, row_amax, empty_u8, empty_u8, col_amax, + rowwise, columnwise, /*mode=*/1, in_te, out_te); const auto stream = at::cuda::getCurrentCUDAStream(); nvte_nvfp4_per_token_amax(in_te.data(), nullptr, out_te.data(), stream); } // K2-only (diagnostic / bench): reads pre-filled amax buffers, emits FP4 + SFs. -void nvfp4_per_token_encode(const at::Tensor& input, - at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, - at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, - bool rowwise, bool columnwise) { +void nvfp4_per_token_encode(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise) { TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); - assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, - q_col, s_dec_col, col_amax, - rowwise, columnwise, /*mode=*/2, in_te, out_te); + assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, + columnwise, /*mode=*/2, in_te, out_te); const auto stream = at::cuda::getCurrentCUDAStream(); nvte_nvfp4_per_token_encode(in_te.data(), nullptr, out_te.data(), stream); } // Apply per-token post-scale to a GEMM output (see nvfp4_per_token.h for math). -void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, - const at::Tensor &row_amax_b) { +void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor& row_amax_a, + const at::Tensor& row_amax_b) { TORCH_CHECK(d.is_cuda() && d.is_contiguous(), "d must be a contiguous CUDA tensor"); TORCH_CHECK(row_amax_a.is_cuda() && row_amax_a.is_contiguous(), "row_amax_a must be a contiguous CUDA tensor"); @@ -169,16 +161,16 @@ void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, const int64_t M = d.size(0); const int64_t N = d.size(1); - TORCH_CHECK(row_amax_a.numel() == M, - "row_amax_a numel mismatch: expected M=", M, ", got ", row_amax_a.numel()); - TORCH_CHECK(row_amax_b.numel() == N, - "row_amax_b numel mismatch: expected N=", N, ", got ", row_amax_b.numel()); + TORCH_CHECK(row_amax_a.numel() == M, "row_amax_a numel mismatch: expected M=", M, ", got ", + row_amax_a.numel()); + TORCH_CHECK(row_amax_b.numel() == N, "row_amax_b numel mismatch: expected N=", N, ", got ", + row_amax_b.numel()); const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper d_te = makeTransformerEngineTensor( - d.data_ptr(), - std::vector{static_cast(M), static_cast(N)}, DType::kBFloat16); + d.data_ptr(), std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16); TensorWrapper ra_te = makeTransformerEngineTensor( row_amax_a.data_ptr(), std::vector{static_cast(M)}, DType::kFloat32); TensorWrapper rb_te = makeTransformerEngineTensor( @@ -190,11 +182,11 @@ void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, // End-to-end NVFP4 per-token GEMM: swizzle compact SFs -> cuBLAS LT NVFP4 // GEMM (operand amax pinned to 1.0 to cancel the 2688^2 inner-SF factor) -> // per-row post-scale. beta must be 0.0. Math in nvfp4_per_token.h. -void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, - const at::Tensor &a_sf, const at::Tensor &b_sf, - const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, - at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, - int64_t k, double alpha, double beta) { +void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, + const at::Tensor& a_sf, const at::Tensor& b_sf, + const at::Tensor& a_row_amax, const at::Tensor& b_row_amax, at::Tensor d, + const at::Tensor& workspace, int64_t m, int64_t n, int64_t k, + double alpha, double beta) { TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && a_row_amax.is_cuda() && b_row_amax.is_cuda() && d.is_cuda() && workspace.is_cuda(), @@ -213,14 +205,13 @@ void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); TORCH_CHECK(workspace.scalar_type() == at::ScalarType::Byte, "workspace must be uint8"); - TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, - "a_data/b_data/d must be 2D"); + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, "a_data/b_data/d must be 2D"); TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, - "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", - a_data.size(0), ", ", a_data.size(1), ")"); + "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", a_data.size(0), + ", ", a_data.size(1), ")"); TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, - "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", - b_data.size(0), ", ", b_data.size(1), ")"); + "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", b_data.size(0), + ", ", b_data.size(1), ")"); TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, "), got (", d.size(0), ", ", d.size(1), ")"); @@ -229,10 +220,10 @@ void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, "a_sf numel mismatch: expected M*K/16=", m * k / 16, ", got ", a_sf.numel()); TORCH_CHECK(b_sf.numel() == static_cast(n * k / 16), "b_sf numel mismatch: expected N*K/16=", n * k / 16, ", got ", b_sf.numel()); - TORCH_CHECK(a_row_amax.numel() == m, - "a_row_amax numel mismatch: expected M=", m, ", got ", a_row_amax.numel()); - TORCH_CHECK(b_row_amax.numel() == n, - "b_row_amax numel mismatch: expected N=", n, ", got ", b_row_amax.numel()); + TORCH_CHECK(a_row_amax.numel() == m, "a_row_amax numel mismatch: expected M=", m, ", got ", + a_row_amax.numel()); + TORCH_CHECK(b_row_amax.numel() == n, "b_row_amax numel mismatch: expected N=", n, ", got ", + b_row_amax.numel()); TORCH_CHECK(static_cast(beta) == 0.0f, "nvfp4_per_token_gemm: beta != 0 not yet supported. Got beta=", beta); @@ -302,8 +293,7 @@ void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, b_te.set_with_gemm_swizzled_scales(true); TensorWrapper d_te = makeTransformerEngineTensor( - d.data_ptr(), - std::vector{static_cast(m), static_cast(n)}, + d.data_ptr(), std::vector{static_cast(m), static_cast(n)}, DType::kBFloat16); TensorWrapper workspace_te = makeTransformerEngineTensor( @@ -332,11 +322,10 @@ void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, // Per-tensor twin of nvfp4_per_token_gemm: scalar amax goes through cuBLAS's // own amax slot (no post-scale). Bench-only apples-to-apples baseline. -void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, - const at::Tensor &a_sf, const at::Tensor &b_sf, - const at::Tensor &a_amax, const at::Tensor &b_amax, - at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, - int64_t k, double alpha, double beta) { +void nvfp4_per_tensor_gemm(const at::Tensor& a_data, const at::Tensor& b_data, + const at::Tensor& a_sf, const at::Tensor& b_sf, const at::Tensor& a_amax, + const at::Tensor& b_amax, at::Tensor d, const at::Tensor& workspace, + int64_t m, int64_t n, int64_t k, double alpha, double beta) { TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && a_amax.is_cuda() && b_amax.is_cuda() && d.is_cuda() && workspace.is_cuda(), "All tensors must be CUDA tensors"); @@ -353,14 +342,13 @@ void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); TORCH_CHECK(workspace.scalar_type() == at::ScalarType::Byte, "workspace must be uint8"); - TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, - "a_data/b_data/d must be 2D"); + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, "a_data/b_data/d must be 2D"); TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, - "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", - a_data.size(0), ", ", a_data.size(1), ")"); + "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", a_data.size(0), + ", ", a_data.size(1), ")"); TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, - "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", - b_data.size(0), ", ", b_data.size(1), ")"); + "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", b_data.size(0), + ", ", b_data.size(1), ")"); TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, "), got (", d.size(0), ", ", d.size(1), ")"); @@ -425,8 +413,7 @@ void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, b_te.set_with_gemm_swizzled_scales(true); TensorWrapper d_te = makeTransformerEngineTensor( - d.data_ptr(), - std::vector{static_cast(m), static_cast(n)}, + d.data_ptr(), std::vector{static_cast(m), static_cast(n)}, DType::kBFloat16); TensorWrapper workspace_te = makeTransformerEngineTensor( @@ -449,13 +436,13 @@ void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, // Disabled direction's lists are ignored. namespace { -void build_per_token_output_wrapper( - TensorWrapper& out_te, int64_t M_i, int64_t K, bool rowwise, bool columnwise, - const at::Tensor& q_row, const at::Tensor& s_dec_row, const at::Tensor& row_amax, - const at::Tensor& q_col, const at::Tensor& s_dec_col, const at::Tensor& col_amax) { +void build_per_token_output_wrapper(TensorWrapper& out_te, int64_t M_i, int64_t K, bool rowwise, + bool columnwise, const at::Tensor& q_row, + const at::Tensor& s_dec_row, const at::Tensor& row_amax, + const at::Tensor& q_col, const at::Tensor& s_dec_col, + const at::Tensor& col_amax) { if (rowwise) { - TORCH_CHECK(q_row.is_cuda() && q_row.is_contiguous(), - "q_row must be a contiguous CUDA tensor"); + TORCH_CHECK(q_row.is_cuda() && q_row.is_contiguous(), "q_row must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_row.is_cuda() && s_dec_row.is_contiguous(), "s_dec_row must be a contiguous CUDA tensor"); TORCH_CHECK(row_amax.is_cuda() && row_amax.is_contiguous(), @@ -467,9 +454,8 @@ void build_per_token_output_wrapper( M_i * K / 2, ", got ", q_row.numel()); TORCH_CHECK(s_dec_row.numel() == M_i * K / 16, "s_dec_row numel mismatch for split"); TORCH_CHECK(row_amax.numel() == M_i, "row_amax numel mismatch for split"); - out_te.set_rowwise_data( - q_row.data_ptr(), DType::kFloat4E2M1, - std::vector{static_cast(M_i), static_cast(K)}); + out_te.set_rowwise_data(q_row.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(M_i), static_cast(K)}); out_te.set_rowwise_scale_inv( s_dec_row.data_ptr(), DType::kFloat8E4M3, std::vector{static_cast(M_i), static_cast(K / 16)}); @@ -477,8 +463,7 @@ void build_per_token_output_wrapper( std::vector{static_cast(M_i)}); } if (columnwise) { - TORCH_CHECK(q_col.is_cuda() && q_col.is_contiguous(), - "q_col must be a contiguous CUDA tensor"); + TORCH_CHECK(q_col.is_cuda() && q_col.is_contiguous(), "q_col must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_col.is_cuda() && s_dec_col.is_contiguous(), "s_dec_col must be a contiguous CUDA tensor"); TORCH_CHECK(col_amax.is_cuda() && col_amax.is_contiguous(), @@ -514,12 +499,10 @@ void nvfp4_per_token_group_quantize( const at::Tensor& input, const std::vector& split_sections, std::vector q_row_list, std::vector s_dec_row_list, std::vector row_amax_list, std::vector q_col_list, - std::vector s_dec_col_list, std::vector col_amax_list, - bool rowwise, bool columnwise) { - TORCH_CHECK(rowwise || columnwise, - "At least one of rowwise/columnwise must be True."); - TORCH_CHECK(input.is_cuda() && input.is_contiguous(), - "input must be a contiguous CUDA tensor"); + std::vector s_dec_col_list, std::vector col_amax_list, bool rowwise, + bool columnwise) { + TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(input.is_cuda() && input.is_contiguous(), "input must be a contiguous CUDA tensor"); TORCH_CHECK(input.dim() == 2, "input must be 2D"); const int64_t sum_M = input.size(0); const int64_t K = input.size(1); @@ -530,8 +513,8 @@ void nvfp4_per_token_group_quantize( int64_t acc = 0; for (size_t i = 0; i < num_tensors; ++i) { TORCH_CHECK(split_sections[i] >= 0, "split_sections[", i, "] must be non-negative"); - TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, - "] = ", split_sections[i], " must be a multiple of 64"); + TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, "] = ", split_sections[i], + " must be a multiple of 64"); acc += split_sections[i]; } TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, " must equal input.size(0) = ", sum_M); @@ -551,8 +534,8 @@ void nvfp4_per_token_group_quantize( const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper in_te = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(sum_M), static_cast(K)}, in_dtype); + input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, + in_dtype); // One TensorWrapper per split; raw NVTETensor handles go into `handles`. std::vector wrappers; @@ -571,30 +554,24 @@ void nvfp4_per_token_group_quantize( continue; // empty split is allowed (skipped inside the kernel) } build_per_token_output_wrapper( - wrappers.back(), M_i, K, rowwise, columnwise, - rowwise ? q_row_list[i] : empty_dummy, - rowwise ? s_dec_row_list[i] : empty_dummy, - rowwise ? row_amax_list[i] : empty_dummy, - columnwise ? q_col_list[i] : empty_dummy, - columnwise ? s_dec_col_list[i] : empty_dummy, + wrappers.back(), M_i, K, rowwise, columnwise, rowwise ? q_row_list[i] : empty_dummy, + rowwise ? s_dec_row_list[i] : empty_dummy, rowwise ? row_amax_list[i] : empty_dummy, + columnwise ? q_col_list[i] : empty_dummy, columnwise ? s_dec_col_list[i] : empty_dummy, columnwise ? col_amax_list[i] : empty_dummy); handles.push_back(wrappers.back().data()); } - nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), - split_sections_sz.data(), num_tensors, rowwise, - columnwise, stream); + nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), split_sections_sz.data(), + num_tensors, rowwise, columnwise, stream); } // Amax-only grouped variant (K1 only); for allReduce-before-cast flows. -void nvfp4_per_token_group_amax(const at::Tensor& input, - const std::vector& split_sections, +void nvfp4_per_token_group_amax(const at::Tensor& input, const std::vector& split_sections, std::vector row_amax_list, std::vector col_amax_list, bool rowwise, bool columnwise) { TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); - TORCH_CHECK(input.is_cuda() && input.is_contiguous(), - "input must be a contiguous CUDA tensor"); + TORCH_CHECK(input.is_cuda() && input.is_contiguous(), "input must be a contiguous CUDA tensor"); TORCH_CHECK(input.dim() == 2, "input must be 2D"); const int64_t sum_M = input.size(0); const int64_t K = input.size(1); @@ -602,21 +579,19 @@ void nvfp4_per_token_group_amax(const at::Tensor& input, TORCH_CHECK(num_tensors > 0, "split_sections must not be empty"); int64_t acc = 0; for (size_t i = 0; i < num_tensors; ++i) { - TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, - "] must be a multiple of 64"); + TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, "] must be a multiple of 64"); acc += split_sections[i]; } TORCH_CHECK(acc == sum_M, "sum(split_sections) must equal input.size(0)"); if (rowwise) TORCH_CHECK(row_amax_list.size() == num_tensors, "row_amax_list size mismatch"); - if (columnwise) - TORCH_CHECK(col_amax_list.size() == num_tensors, "col_amax_list size mismatch"); + if (columnwise) TORCH_CHECK(col_amax_list.size() == num_tensors, "col_amax_list size mismatch"); const DType in_dtype = resolve_input_dtype(input); const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper in_te = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(sum_M), static_cast(K)}, in_dtype); + input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, + in_dtype); std::vector wrappers; wrappers.reserve(num_tensors); @@ -662,12 +637,11 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, const std::vector& split_sections, bool rowwise, bool columnwise) { // Validation mirrors _validate_per_token_group_input in Python. - TORCH_CHECK(rowwise || columnwise, - "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(input.is_contiguous(), "x_concat must be contiguous (row-major)"); - TORCH_CHECK(input.dim() == 2, - "nvfp4_per_token_group_quantize expects a 2D input, got ", input.dim(), "D"); + TORCH_CHECK(input.dim() == 2, "nvfp4_per_token_group_quantize expects a 2D input, got ", + input.dim(), "D"); TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, "Per-token grouped kernel is bf16-only; got dtype ", input.scalar_type()); @@ -676,13 +650,13 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, constexpr int64_t kPerTokenTile = 128; constexpr int64_t kBlockK = 16; - TORCH_CHECK(K % kPerTokenTile == 0, - "Per-token grouped kernel requires K % ", kPerTokenTile, " == 0; got K=", K); + TORCH_CHECK(K % kPerTokenTile == 0, "Per-token grouped kernel requires K % ", kPerTokenTile, + " == 0; got K=", K); const size_t num_tensors = split_sections.size(); TORCH_CHECK(num_tensors > 0, "split_sections must not be empty"); - TORCH_CHECK(num_tensors <= 64, - "num_tensors must be <= 64 (kernel arg-struct cap); got ", num_tensors); + TORCH_CHECK(num_tensors <= 64, "num_tensors must be <= 64 (kernel arg-struct cap); got ", + num_tensors); int64_t acc = 0; for (size_t i = 0; i < num_tensors; ++i) { @@ -692,8 +666,7 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, " must be a multiple of ", kPerTokenTile); acc += M_i; } - TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, - " must equal input.size(0) = ", sum_M); + TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, " must equal input.size(0) = ", sum_M); // Bulk allocation: one at::empty per output type, covers all splits. auto opts_u8 = input.options().dtype(at::kByte); @@ -743,8 +716,7 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, if (columnwise) { auto q_col_flat = q_col_bulk.narrow(0, K * m_off / 2, K * M_i / 2); q_col_list.emplace_back(q_col_flat.view({K, M_i / 2})); - auto s_dec_col_flat = - s_dec_col_bulk.narrow(0, K * m_off / kBlockK, K * M_i / kBlockK); + auto s_dec_col_flat = s_dec_col_bulk.narrow(0, K * m_off / kBlockK, K * M_i / kBlockK); s_dec_col_u8_list.emplace_back(s_dec_col_flat.view({K, M_i / kBlockK})); col_amax_list.emplace_back(col_amax_bulk.select(0, static_cast(i))); s_dec_col_fp8_list.emplace_back(s_dec_col_u8_list.back().view(at::kFloat8_e4m3fn)); @@ -755,8 +727,7 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, // Dispatch K1+K2 grouped kernel via the same C-API the thin entry uses. const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper in_te = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(sum_M), static_cast(K)}, + input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, DType::kBFloat16); std::vector wrappers; @@ -771,19 +742,15 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, split_sections_sz[i] = static_cast(M_i); wrappers.emplace_back(NVTE_NVFP4_1D_SCALING); build_per_token_output_wrapper( - wrappers.back(), M_i, K, rowwise, columnwise, - rowwise ? q_row_list[i] : empty_dummy, - rowwise ? s_dec_row_u8_list[i] : empty_dummy, - rowwise ? row_amax_list[i] : empty_dummy, - columnwise ? q_col_list[i] : empty_dummy, - columnwise ? s_dec_col_u8_list[i] : empty_dummy, + wrappers.back(), M_i, K, rowwise, columnwise, rowwise ? q_row_list[i] : empty_dummy, + rowwise ? s_dec_row_u8_list[i] : empty_dummy, rowwise ? row_amax_list[i] : empty_dummy, + columnwise ? q_col_list[i] : empty_dummy, columnwise ? s_dec_col_u8_list[i] : empty_dummy, columnwise ? col_amax_list[i] : empty_dummy); handles.push_back(wrappers.back().data()); } - nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), - split_sections_sz.data(), num_tensors, rowwise, - columnwise, stream); + nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), split_sections_sz.data(), + num_tensors, rowwise, columnwise, stream); return std::make_tuple(std::move(q_row_list), std::move(s_dec_row_fp8_list), std::move(row_amax_list), std::move(q_col_list), diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index f38a6532a4..46d00ba9c2 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -346,8 +346,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "K1 of the NVFP4Quantizer RHT+post_rht_amax path: rowwise (pre-RHT) + " "columnwise (RHT(input.T)) amax in one launch. Bench-only entry.", py::arg("input"), py::arg("rowwise_amax"), py::arg("columnwise_amax"), - py::arg("rht_matrix_random_sign_mask"), - py::call_guard()); + py::arg("rht_matrix_random_sign_mask"), py::call_guard()); m.def("fused_amax_and_scale_update_after_reduction", &transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction", @@ -401,36 +400,32 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "e4m3 SF layout as per-tensor, but outer amax is per-row/per-col. " "Requires bf16 input, M % 128 == 0, K % 128 == 0.", py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), - py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), - py::arg("rowwise"), py::arg("columnwise")); - m.def("nvfp4_per_token_amax", - &transformer_engine::pytorch::nvfp4_per_token_amax, + py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), py::arg("rowwise"), + py::arg("columnwise")); + m.def("nvfp4_per_token_amax", &transformer_engine::pytorch::nvfp4_per_token_amax, "K1-only: per-row/per-col outer amax via TMA + atomicMax. Bench/diagnostic.", - py::arg("input"), py::arg("row_amax"), py::arg("col_amax"), - py::arg("rowwise"), py::arg("columnwise")); - m.def("nvfp4_per_token_encode", - &transformer_engine::pytorch::nvfp4_per_token_encode, + py::arg("input"), py::arg("row_amax"), py::arg("col_amax"), py::arg("rowwise"), + py::arg("columnwise")); + m.def("nvfp4_per_token_encode", &transformer_engine::pytorch::nvfp4_per_token_encode, "K2-only: FP4 + e4m3 SF encode given pre-filled amax buffers. Bench/diagnostic.", py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), - py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), - py::arg("rowwise"), py::arg("columnwise")); + py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), py::arg("rowwise"), + py::arg("columnwise")); m.def("nvfp4_per_token_post_scale", &transformer_engine::pytorch::nvfp4_per_token_post_scale, - "Apply d[i,j] *= row_amax_a[i] * row_amax_b[j] in-place on bf16 D.", - py::arg("d"), py::arg("row_amax_a"), py::arg("row_amax_b")); + "Apply d[i,j] *= row_amax_a[i] * row_amax_b[j] in-place on bf16 D.", py::arg("d"), + py::arg("row_amax_a"), py::arg("row_amax_b")); m.def("nvfp4_per_token_gemm", &transformer_engine::pytorch::nvfp4_per_token_gemm, "End-to-end NVFP4 per-token GEMM: swizzle compact SFs, cuBLAS LT NVFP4 " "GEMM, then row*col post-scale to recover C = A @ B^T. beta must be 0.", py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), - py::arg("a_row_amax"), py::arg("b_row_amax"), py::arg("d"), - py::arg("workspace"), py::arg("m"), py::arg("n"), py::arg("k"), - py::arg("alpha"), py::arg("beta")); + py::arg("a_row_amax"), py::arg("b_row_amax"), py::arg("d"), py::arg("workspace"), + py::arg("m"), py::arg("n"), py::arg("k"), py::arg("alpha"), py::arg("beta")); m.def("nvfp4_per_tensor_gemm", &transformer_engine::pytorch::nvfp4_per_tensor_gemm, "Skinny prod NVFP4 GEMM twin of nvfp4_per_token_gemm: per-tensor amaxes " "folded into cuBLAS alpha, no trailing post-scale. Bench-only.", - py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), - py::arg("a_amax"), py::arg("b_amax"), py::arg("d"), - py::arg("workspace"), py::arg("m"), py::arg("n"), py::arg("k"), - py::arg("alpha"), py::arg("beta")); + py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), py::arg("a_amax"), + py::arg("b_amax"), py::arg("d"), py::arg("workspace"), py::arg("m"), py::arg("n"), + py::arg("k"), py::arg("alpha"), py::arg("beta")); m.def("nvfp4_per_token_group_quantize", &transformer_engine::pytorch::nvfp4_per_token_group_quantize, "Grouped (multi-tensor) NVFP4 per-token cast: K1 + K2 across <= 64 splits " @@ -439,8 +434,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("s_dec_row_list"), py::arg("row_amax_list"), py::arg("q_col_list"), py::arg("s_dec_col_list"), py::arg("col_amax_list"), py::arg("rowwise"), py::arg("columnwise")); - m.def("nvfp4_per_token_group_amax", - &transformer_engine::pytorch::nvfp4_per_token_group_amax, + m.def("nvfp4_per_token_group_amax", &transformer_engine::pytorch::nvfp4_per_token_group_amax, "K1-only variant of nvfp4_per_token_group_quantize: only fills amax slots.", py::arg("input"), py::arg("split_sections"), py::arg("row_amax_list"), py::arg("col_amax_list"), py::arg("rowwise"), py::arg("columnwise")); @@ -449,8 +443,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Bulk grouped quantize: allocates per-split buffers + view-slices inside " "the binding (one pybind hop instead of 1 + 6N), then dispatches the K1+K2 " "kernel. Returns 6 per-split tensor lists; empty for disabled directions.", - py::arg("input"), py::arg("split_sections"), py::arg("rowwise"), - py::arg("columnwise")); + py::arg("input"), py::arg("split_sections"), py::arg("rowwise"), py::arg("columnwise")); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index 80c088a24d..d9d21a78bf 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -32,26 +32,20 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { // Thin pybind for nvte_hadamard_transform_amax: K1 of the production // NVFP4Quantizer(with_rht, with_post_rht_amax) path. Computes rowwise (pre-RHT) // and columnwise (RHT(input.T)) amax in one launch. Bench-only entry. -void hadamard_transform_amax(const at::Tensor& tensor, - at::Tensor& rowwise_amax, - at::Tensor& columnwise_amax, - int64_t rht_matrix_random_sign_mask) { +void hadamard_transform_amax(const at::Tensor& tensor, at::Tensor& rowwise_amax, + at::Tensor& columnwise_amax, int64_t rht_matrix_random_sign_mask) { auto input_tensor = tensor.contiguous(); const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - TORCH_CHECK(rowwise_amax.scalar_type() == at::kFloat, - "rowwise_amax must be a float tensor"); - TORCH_CHECK(rowwise_amax.numel() == 1, - "rowwise_amax must have exactly one element"); + TORCH_CHECK(rowwise_amax.scalar_type() == at::kFloat, "rowwise_amax must be a float tensor"); + TORCH_CHECK(rowwise_amax.numel() == 1, "rowwise_amax must have exactly one element"); TORCH_CHECK(columnwise_amax.scalar_type() == at::kFloat, "columnwise_amax must be a float tensor"); - TORCH_CHECK(columnwise_amax.numel() == 1, - "columnwise_amax must have exactly one element"); + TORCH_CHECK(columnwise_amax.numel() == 1, "columnwise_amax must have exactly one element"); // Mirror NVFP4Quantizer: empty NVFP4_1D_SCALING with two amax slots. TensorWrapper te_output(NVTE_NVFP4_1D_SCALING); - te_output.set_amax(rowwise_amax.data_ptr(), DType::kFloat32, - std::vector{1}); + te_output.set_amax(rowwise_amax.data_ptr(), DType::kFloat32, std::vector{1}); te_output.set_columnwise_amax(columnwise_amax.data_ptr(), DType::kFloat32, std::vector{1}); diff --git a/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py b/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py index a35bca3232..74fc9dc228 100644 --- a/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py +++ b/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py @@ -36,6 +36,7 @@ # Reference: dequantize + reference matmul. + def _validate_per_token_triple( data: torch.Tensor, scale: torch.Tensor, row_amax: torch.Tensor, side: str ) -> int: @@ -50,13 +51,9 @@ def _validate_per_token_triple( if K % BLOCK_K != 0: raise ValueError(f"{side}: K={K} must be a multiple of BLOCK_K={BLOCK_K}") if scale.shape != (rows, K // BLOCK_K): - raise ValueError( - f"{side}: scale shape {tuple(scale.shape)} != ({rows}, {K // BLOCK_K})" - ) + raise ValueError(f"{side}: scale shape {tuple(scale.shape)} != ({rows}, {K // BLOCK_K})") if row_amax.shape != (rows,): - raise ValueError( - f"{side}: row_amax shape {tuple(row_amax.shape)} != ({rows},)" - ) + raise ValueError(f"{side}: row_amax shape {tuple(row_amax.shape)} != ({rows},)") return K @@ -113,6 +110,7 @@ def nvfp4_per_token_gemm_dequant( # Production wrapper: cuBLAS LT NVFP4 GEMM + per-token post-scale. + def nvfp4_per_token_gemm( a_data: torch.Tensor, a_scale: torch.Tensor, @@ -147,13 +145,9 @@ def nvfp4_per_token_gemm( # cuBLAS LT NVFP4 SF buffer is padded to (roundup(rows, 128), roundup(K/16, 4)). # Our compact quantize emits (rows, K/16); SF padding is a TODO so reject M/N < 128. if M < 128 or M % 128 != 0: - raise ValueError( - f"M must be a multiple of 128 (got M={M}); SF padding is a TODO." - ) + raise ValueError(f"M must be a multiple of 128 (got M={M}); SF padding is a TODO.") if N < 128 or N % 128 != 0: - raise ValueError( - f"N must be a multiple of 128 (got N={N}); SF padding is a TODO." - ) + raise ValueError(f"N must be a multiple of 128 (got N={N}); SF padding is a TODO.") if a_data.device != b_data.device: raise ValueError( f"A and B must be on the same device (got {a_data.device} vs {b_data.device})" @@ -164,9 +158,7 @@ def nvfp4_per_token_gemm( out_bf16 = torch.empty((M, N), dtype=torch.bfloat16, device=device) else: if out.shape != (M, N): - raise ValueError( - f"out shape {tuple(out.shape)} != ({M}, {N})" - ) + raise ValueError(f"out shape {tuple(out.shape)} != ({M}, {N})") if out.dtype != torch.bfloat16: raise ValueError( f"out dtype must be bf16 for in-place use, got {out.dtype}. " @@ -195,12 +187,23 @@ def nvfp4_per_token_gemm( # Lazy import to break the cpp_extensions.gemm circular import. from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace + workspace = get_cublas_workspace(device.index, ub=False, grouped_gemm=False) tex.nvfp4_per_token_gemm( - a_data_u8, b_data_u8, a_scale_u8_flat, b_scale_u8_flat, - a_row_amax_f32, b_row_amax_f32, out_bf16, workspace, - M, N, K, float(alpha), float(beta), + a_data_u8, + b_data_u8, + a_scale_u8_flat, + b_scale_u8_flat, + a_row_amax_f32, + b_row_amax_f32, + out_bf16, + workspace, + M, + N, + K, + float(alpha), + float(beta), ) return out_bf16 if out_dtype is torch.bfloat16 else out_bf16.to(out_dtype) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py index bd09f8970f..30e5b8e51e 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py @@ -146,8 +146,9 @@ def _quantize_2d(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, tor # kernel. All-zero blocks: s_dec saturates to 0, naive S_enc/0 # would NaN; short-circuit to 0 to mirror the kernel. zero_blk = decode_scale_back_fp32 == 0 - denom = torch.where(zero_blk, torch.ones_like(decode_scale_back_fp32), - decode_scale_back_fp32) + denom = torch.where( + zero_blk, torch.ones_like(decode_scale_back_fp32), decode_scale_back_fp32 + ) encode_scale = S_enc_per_blk / denom encode_scale = torch.where(zero_blk, torch.zeros_like(encode_scale), encode_scale) encode_scale = torch.minimum(encode_scale, fp32_max) @@ -200,13 +201,9 @@ def _validate_per_token_input(x: torch.Tensor) -> Tuple[int, int]: ) M, K = x.shape if M % _PER_TOKEN_TILE != 0: - raise ValueError( - f"Per-token kernel requires M % {_PER_TOKEN_TILE} == 0; got M={M}" - ) + raise ValueError(f"Per-token kernel requires M % {_PER_TOKEN_TILE} == 0; got M={M}") if K % _PER_TOKEN_TILE != 0: - raise ValueError( - f"Per-token kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}" - ) + raise ValueError(f"Per-token kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}") return M, K @@ -284,8 +281,12 @@ def nvfp4_per_token_quantize( # above; the composite handles K1 + K2 ordering on the same stream. # ============================================================================ + def nvfp4_per_token_amax( - x: torch.Tensor, *, rowwise: bool = True, columnwise: bool = True, + x: torch.Tensor, + *, + rowwise: bool = True, + columnwise: bool = True, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: """Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. Returns ``(row_amax, col_amax)``; either may be ``None`` if the @@ -371,7 +372,15 @@ def nvfp4_per_token_encode( q_col, s_dec_col, col_amax_t = empty, empty, empty_f32 tex.nvfp4_per_token_encode( - x, q_row, s_dec_row, row_amax_t, q_col, s_dec_col, col_amax_t, rowwise, columnwise, + x, + q_row, + s_dec_row, + row_amax_t, + q_col, + s_dec_col, + col_amax_t, + rowwise, + columnwise, ) out = RefNVFP4TensorPerToken() diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py index dd8b2283e9..25c6a324ad 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py @@ -29,20 +29,14 @@ def _validate_per_token_group_input( ``(sum_M, K)``. """ if x_concat.ndim != 2: - raise ValueError( - f"nvfp4_per_token_group_quantize expects a 2D input, got {x_concat.ndim}D" - ) + raise ValueError(f"nvfp4_per_token_group_quantize expects a 2D input, got {x_concat.ndim}D") if not x_concat.is_contiguous(): raise ValueError("x_concat must be contiguous (row-major)") if x_concat.dtype != torch.bfloat16: - raise ValueError( - f"Per-token grouped kernel is bf16-only; got dtype {x_concat.dtype}." - ) + raise ValueError(f"Per-token grouped kernel is bf16-only; got dtype {x_concat.dtype}.") sum_M, K = x_concat.shape if K % _PER_TOKEN_TILE != 0: - raise ValueError( - f"Per-token grouped kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}" - ) + raise ValueError(f"Per-token grouped kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}") if len(split_sections) == 0: raise ValueError("split_sections must not be empty") if len(split_sections) > 64: @@ -54,14 +48,10 @@ def _validate_per_token_group_input( if M_i <= 0: raise ValueError(f"split_sections[{i}] must be > 0, got {M_i}") if M_i % _PER_TOKEN_TILE != 0: - raise ValueError( - f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}" - ) + raise ValueError(f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}") acc += M_i if acc != sum_M: - raise ValueError( - f"sum(split_sections) = {acc} must equal input.size(0) = {sum_M}" - ) + raise ValueError(f"sum(split_sections) = {acc} must equal input.size(0) = {sum_M}") return sum_M, K @@ -93,9 +83,7 @@ def nvfp4_per_token_group_quantize( q_col_list, s_dec_col_list, col_amax_list, - ) = tex.nvfp4_per_token_group_quantize_bulk( - x_concat, split_sections_list, rowwise, columnwise - ) + ) = tex.nvfp4_per_token_group_quantize_bulk(x_concat, split_sections_list, rowwise, columnwise) outs: List[RefNVFP4TensorPerToken] = [] for i in range(N): From f10e50560ebd4f71ab846f0be6829f85fd7c3f51 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Wed, 27 May 2026 23:30:34 -0700 Subject: [PATCH 3/6] Add optional col-wise RHT to NVFP4 per-token amax+quant (single + grouped) Wire `with_rht` / `random_sign_mask_t` through the per-token K1 (amax) and K2 (encode) kernels for both single-tensor and grouped paths. with_rht=False is byte-equal to the pre-RHT code path; when true, applies a 16-pt RHT on the columnwise direction in both K1 and K2 (rowwise stays raw) with outer amax + inner SF self-consistent. Implementation: per-thread fp32 FHT on CUDA cores, branchless fp32 sign-bit XOR for the +/-1 sign diagonal, 0.25 normalization folded into block_amax / block_scale (bit-exact). Tests cover K1, K2, composite + grouped vs a PyTorch fp32 reference and byte-equality regressions. Benches gain a --rht flag (2-way default, 3-way under --rht). Perf vs prod NVFP4Quantizer(rht+sr), Graph mode, 18 shapes M up to 32K: * single tensor : 0.49x-0.77x (no RHT), 0.59x-0.88x (+RHT) * grouped (N=8) : 0.41x-0.77x (no RHT), 0.50x-0.94x (+RHT) Also drops unused THREADS_X_TR / THREADS_Y_TR (nvcc warning #177-D). Co-authored-by: Zhongbo Zhu Signed-off-by: Cael Ling --- tests/pytorch/nvfp4/bench_nvfp4_per_token.py | 427 ++++++++++-- .../nvfp4/bench_nvfp4_per_token_group.py | 306 +++++++-- tests/pytorch/nvfp4/test_nvfp4_per_token.py | 343 +++++++++- .../nvfp4/test_nvfp4_per_token_group.py | 192 +++++- .../cast/nvfp4/quantize_nvfp4_per_token.cu | 632 ++++++++++++------ .../nvfp4/quantize_nvfp4_per_token_group.cu | 585 ++++++++++------ .../transformer_engine/nvfp4_per_token.h | 76 ++- transformer_engine/pytorch/csrc/extensions.h | 58 +- .../csrc/extensions/nvfp4_per_token.cpp | 298 +++++---- .../pytorch/csrc/extensions/pybind.cpp | 76 ++- .../quantization_nvfp4_per_token.py | 55 +- .../quantization_nvfp4_per_token_group.py | 41 +- 12 files changed, 2346 insertions(+), 743 deletions(-) diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py index 1312c841c7..1663dce1f8 100644 --- a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py @@ -4,8 +4,15 @@ """Bench NVFP4 per-token K1+K2 quant vs per-tensor RHT+SR baseline. -Quant-only (no GEMM). Both sides time the K1 (amax) + K2 (cast) composite on -activation A, rowwise+columnwise. Requires bf16 input, M % 128 == 0, K % 128 == 0. +Quant-only (no GEMM). bf16, M % 128 == 0, K % 128 == 0. + +Modes: + * default: 2-way composite (per-token vs per-tensor). Ratio = pt / pten. + * ``--rht``: 3-way composite (adds per-token + col-wise 16-pt RHT). Ratio = + per-token (+rht) / per-tensor. + * ``--k1-only``: K1 in isolation. Without ``--rht``: pt_K1 vs prod_K1. + With ``--rht``: (A) pt_K1 vs pt_K1+RHT (apples-to-apples) and + (B) pt_K1+RHT vs prod_K1 (NOT apples-to-apples; output shapes differ). """ from __future__ import annotations @@ -43,7 +50,9 @@ def cuda_time_ms(fn: Callable[[], None], *, warmup: int = 5, iters: int = 50) -> return statistics.median(samples) -def cuda_graph_time_ms(fn: Callable[[], object], *, warmup: int = 5, iters: int = 50) -> float: +def cuda_graph_time_ms( + fn: Callable[[], object], *, warmup: int = 5, iters: int = 50 +) -> float: """Median g.replay() wall time of fn captured into a CUDA Graph (kernel-only floor). Returns nan if capture fails. @@ -102,21 +111,44 @@ def _has_sm100() -> bool: class ShapeBench: M: int K: int - t_pt: float # per-token full K1+K2 (eager pybind, ms) - t_pten: float # per-tensor full K1+K2 (eager pybind, ms) - t_pt_g: float # per-token under CUDA Graphs replay (ms) - t_pten_g: float # per-tensor under CUDA Graphs replay (ms) + t_pt: float # per-token full K1+K2, no RHT (Eager pybind, ms) + t_pt_rht: float # per-token full K1+K2, +RHT col-wise (Eager pybind, ms) + t_pten: float # per-tensor full K1+K2 with RHT+SR (Eager pybind, ms) + t_pt_g: float # per-token under CUDA Graphs replay (ms) + t_pt_rht_g: float # per-token+RHT under CUDA Graphs replay (ms) + t_pten_g: float # per-tensor under CUDA Graphs replay (ms) -def _bench_shape(M: int, K: int, *, device: torch.device) -> ShapeBench: - """Time per-tensor vs per-token K1+K2 quant at one (M, K) shape.""" +@dataclass +class K1ShapeBench: + M: int + K: int + # K1-only timings: 3 paths x 2 modes (Eager + CUDA Graphs). + t_pt: float # per-token K1, no RHT (rowwise+columnwise amax vectors) + t_pt_rht: float # per-token K1, +RHT on col direction + t_prod: float # prod K1 hadamard_transform_amax (per-tensor scalar amax) + t_pt_g: float + t_pt_rht_g: float + t_prod_g: float + + +# Default mask seed; matches prod's `te-nvfp4-build-overrides.mdc` convention. +_RHT_MASK_DEFAULT: int = 0xACE1 + + +def _bench_shape(M: int, K: int, *, device: torch.device, + with_rht: bool = False, + mask_t: int = _RHT_MASK_DEFAULT) -> ShapeBench: + """Composite K1+K2 timing at one (M, K) shape. + pt = per-token (no RHT), pt_rht = per-token + col-wise 16-pt RHT + (NaN unless with_rht=True), pten = per-tensor + RHT + SR (prod baseline). + """ a = torch.randn((M, K), dtype=torch.bfloat16, device=device) - # Per-tensor quantizer + A output tensor. quantizer = _make_baseline_quantizer() dst_a = quantizer.make_empty(a.shape, dtype=torch.bfloat16, device=device) - # Per-token A-side buffers: BLOCK_K=16 (1x16 e4m3 inner SF). + # Per-token A-side buffers reused across no-RHT and +RHT paths. BLOCK_K = 16 ra_a = torch.empty((M,), dtype=torch.float32, device=device) ca_a = torch.empty((K,), dtype=torch.float32, device=device) @@ -130,15 +162,8 @@ def _baseline_quant_fn(): def _pt_full_quant_fn(): tex.nvfp4_per_token_quantize( - a, - q_row_a, - s_dec_row_a, - ra_a, - q_col_a, - s_dec_col_a, - ca_a, - True, - True, + a, q_row_a, s_dec_row_a, ra_a, q_col_a, s_dec_col_a, ca_a, True, True, + with_rht=False, random_sign_mask_t=0, ) t_pten = cuda_time_ms(_baseline_quant_fn) @@ -146,13 +171,87 @@ def _pt_full_quant_fn(): t_pten_g = cuda_graph_time_ms(_baseline_quant_fn) t_pt_g = cuda_graph_time_ms(_pt_full_quant_fn) - return ShapeBench(M=M, K=K, t_pt=t_pt, t_pten=t_pten, t_pt_g=t_pt_g, t_pten_g=t_pten_g) + if with_rht: + def _pt_full_quant_rht_fn(): + tex.nvfp4_per_token_quantize( + a, q_row_a, s_dec_row_a, ra_a, q_col_a, s_dec_col_a, ca_a, True, True, + with_rht=True, random_sign_mask_t=mask_t, + ) + + t_pt_rht = cuda_time_ms(_pt_full_quant_rht_fn) + t_pt_rht_g = cuda_graph_time_ms(_pt_full_quant_rht_fn) + else: + t_pt_rht = float("nan") + t_pt_rht_g = float("nan") + + return ShapeBench( + M=M, K=K, + t_pt=t_pt, t_pt_rht=t_pt_rht, t_pten=t_pten, + t_pt_g=t_pt_g, t_pt_rht_g=t_pt_rht_g, t_pten_g=t_pten_g, + ) + + +def _bench_shape_k1_only(M: int, K: int, *, device: torch.device, + with_rht: bool = False, + mask_t: int = _RHT_MASK_DEFAULT) -> K1ShapeBench: + """K1-only timing. pt = per-token (no RHT), pt_rht = per-token + col RHT + (NaN unless with_rht=True), prod = hadamard_transform_amax (scalar amax; + NOT apples-to-apples but the closest prod K1 reference). + """ + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + # Per-token K1 amax buffers (vectors). + ra_pt = torch.empty((M,), dtype=torch.float32, device=device) + ca_pt = torch.empty((K,), dtype=torch.float32, device=device) + + # prod K1 amax buffers (scalars). + ra_prod = torch.empty((1,), dtype=torch.float32, device=device) + ca_prod = torch.empty((1,), dtype=torch.float32, device=device) + + def _pt_k1_fn(): + tex.nvfp4_per_token_amax( + a, ra_pt, ca_pt, True, True, + with_rht=False, random_sign_mask_t=0, + ) + + def _prod_k1_fn(): + # row pre-RHT + col post-RHT scalar amax; both numel=1 buffers. + tex.hadamard_transform_amax(a, ra_prod, ca_prod, mask_t) + + t_pt = cuda_time_ms(_pt_k1_fn) + t_prod = cuda_time_ms(_prod_k1_fn) + t_pt_g = cuda_graph_time_ms(_pt_k1_fn) + t_prod_g = cuda_graph_time_ms(_prod_k1_fn) + + if with_rht: + ra_pt_rht = torch.empty((M,), dtype=torch.float32, device=device) + ca_pt_rht = torch.empty((K,), dtype=torch.float32, device=device) + + def _pt_k1_rht_fn(): + tex.nvfp4_per_token_amax( + a, ra_pt_rht, ca_pt_rht, True, True, + with_rht=True, random_sign_mask_t=mask_t, + ) + + t_pt_rht = cuda_time_ms(_pt_k1_rht_fn) + t_pt_rht_g = cuda_graph_time_ms(_pt_k1_rht_fn) + else: + t_pt_rht = float("nan") + t_pt_rht_g = float("nan") + + return K1ShapeBench( + M=M, K=K, + t_pt=t_pt, t_pt_rht=t_pt_rht, t_prod=t_prod, + t_pt_g=t_pt_g, t_pt_rht_g=t_pt_rht_g, t_prod_g=t_prod_g, + ) # 6x3 sweep matching bench_nvfp4_per_token_group.py: M in {1024..32768}, K in {2048,4096,8192}. _M_VALUES: Tuple[int, ...] = (1024, 2048, 4096, 8192, 16384, 32768) _K_VALUES: Tuple[int, ...] = (2048, 4096, 8192) -_DEFAULT_SHAPES: Tuple[Tuple[int, int], ...] = tuple((m, k) for m in _M_VALUES for k in _K_VALUES) +_DEFAULT_SHAPES: Tuple[Tuple[int, int], ...] = tuple( + (m, k) for m in _M_VALUES for k in _K_VALUES +) def _parse_shape(s: str) -> Tuple[int, int]: @@ -168,37 +267,109 @@ def _ratio(num: float, den: float) -> float: return num / den -def main() -> int: - parser = argparse.ArgumentParser( - description="Benchmark NVFP4 per-token K1+K2 quant vs per-tensor production NVFP4." +def _print_composite_table_2way(records: List[ShapeBench]) -> None: + """2-way composite (no RHT). ratio = per-token / per-tensor (< 1.0 wins).""" + w_pt, w_pten, w_ratio = 14, 15, 8 + block_w = w_pt + 1 + w_pten + 1 + w_ratio + header1 = ( + f"{'':>7} {'':>6}" + f" |{'Eager, unit (ms)':^{block_w}}" + f" |{'Graph, unit (ms)':^{block_w}}" ) - parser.add_argument( - "--shapes", - type=_parse_shape, - nargs="+", - default=None, - help=( - "Shapes to bench, in MxK form (e.g. 4096x4096). " - "Default: an internally-chosen production-shape sweep." - ), + header2 = ( + f"{'M':>7} {'K':>6}" + f" |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + f" |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" ) - args = parser.parse_args() + print(header1) + print(header2) + print("-" * len(header2)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt, rec.t_pten) + ratio_g = _ratio(rec.t_pt_g, rec.t_pten_g) - if not _has_sm100(): - print("SKIP: NVFP4 per-token quant requires SM100 (Blackwell).", file=sys.stderr) - return 1 + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" - device = torch.device("cuda") - shapes = list(args.shapes) if args.shapes else list(_DEFAULT_SHAPES) + print( + f"{rec.M:>7} {rec.K:>6}" + f" |" + f"{rec.t_pt:>{w_pt}.4f} {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" + f" |" + f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" + ) + + +def _print_composite_table(records: List[ShapeBench]) -> None: + """3-way composite (--rht). ratio = per-token (+rht) / per-tensor.""" + w_pt, w_pt_rht, w_pten, w_ratio = 12, 12, 13, 8 + block_w = w_pt + 1 + w_pt_rht + 1 + w_pten + 1 + w_ratio + header1 = ( + f"{'':>7} {'':>6}" + f" |{'Eager, unit (ms)':^{block_w}}" + f" |{'Graph, unit (ms)':^{block_w}}" + ) + header2 = ( + f"{'M':>7} {'K':>6}" + f" |" + f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + f" |" + f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + header3 = ( + f"{'':>7} {'':>6}" + f" |" + f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + f" |" + f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + ) + print(header1) + print(header2) + print(header3) + print("-" * len(header2)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt_rht, rec.t_pten) + ratio_g = _ratio(rec.t_pt_rht_g, rec.t_pten_g) - records: List[ShapeBench] = [_bench_shape(M, K, device=device) for (M, K) in shapes] + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" + print( + f"{rec.M:>7} {rec.K:>6}" + f" |" + f"{rec.t_pt:>{w_pt}.4f} {rec.t_pt_rht:>{w_pt_rht}.4f}" + f" {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" + f" |" + f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pt_rht_g:>{w_pt_rht}.4f}" + f" {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" + ) + + +def _print_k1_2way_table(records: List[K1ShapeBench]) -> None: + """2-way K1 (default --k1-only). pt_K1 vs prod_K1; NOT apples-to-apples + (per-token K1 outputs M+K floats, prod outputs 2 scalars). + """ + print("K1-only: pt vs prod (NOT apples-to-apples; output shapes differ).") header = ( f"{'M':>7} {'K':>6}" - " |" - f"{'per-token':>10} {'per-tensor':>11} {'ratio':>8}" - " |" - f"{'per-token(Graph)':>17} {'per-tensor(Graph)':>18} {'ratio(Graph)':>13}" + f" |" + f"{'pt_K1':>9} {'prod_K1':>9} {'ratio':>8}" + f" |" + f"{'pt_K1(Graph)':>14} {'prod_K1(Graph)':>16} {'ratio(Graph)':>13}" ) print(header) print("-" * len(header)) @@ -207,18 +378,172 @@ def main() -> int: if prev_M is not None and rec.M != prev_M: print() prev_M = rec.M - ratio = _ratio(rec.t_pt, rec.t_pten) - ratio_g = _ratio(rec.t_pt_g, rec.t_pten_g) + ratio = _ratio(rec.t_pt, rec.t_prod) + ratio_g = _ratio(rec.t_pt_g, rec.t_prod_g) ratio_s = "nan" if math.isnan(ratio) else f"{ratio:.2f}x" ratio_g_s = "nan" if math.isnan(ratio_g) else f"{ratio_g:.2f}x" print( f"{rec.M:>7} {rec.K:>6}" - " |" - f"{rec.t_pt:>10.4f} {rec.t_pten:>11.4f} {ratio_s:>8}" - " |" - f"{rec.t_pt_g:>17.4f} {rec.t_pten_g:>18.4f} {ratio_g_s:>13}" + f" |" + f"{rec.t_pt:>9.4f} {rec.t_prod:>9.4f} {ratio_s:>8}" + f" |" + f"{rec.t_pt_g:>14.4f} {rec.t_prod_g:>16.4f} {ratio_g_s:>13}" ) + +def _print_k1_rht_cost_table(records: List[K1ShapeBench]) -> None: + """Table A: pt_K1 vs pt_K1+RHT (apples-to-apples; same output shapes).""" + print("Table A -- K1-only RHT cost (pt = per-token, +RHT = col-wise FHT).") + header = ( + f"{'M':>7} {'K':>6}" + f" |" + f"{'pt_K1':>9} {'pt_K1+RHT':>11} {'ratio':>8}" + f" |" + f"{'pt_K1(Graph)':>14} {'pt_K1+RHT(Graph)':>18} {'ratio(Graph)':>13}" + ) + print(header) + print("-" * len(header)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt_rht, rec.t_pt) + ratio_g = _ratio(rec.t_pt_rht_g, rec.t_pt_g) + ratio_s = "nan" if math.isnan(ratio) else f"{ratio:.2f}x" + ratio_g_s = "nan" if math.isnan(ratio_g) else f"{ratio_g:.2f}x" + print( + f"{rec.M:>7} {rec.K:>6}" + f" |" + f"{rec.t_pt:>9.4f} {rec.t_pt_rht:>11.4f} {ratio_s:>8}" + f" |" + f"{rec.t_pt_g:>14.4f} {rec.t_pt_rht_g:>18.4f} {ratio_g_s:>13}" + ) + + +def _print_k1_vs_prod_table(records: List[K1ShapeBench]) -> None: + """Table B: pt_K1+RHT vs prod_K1 (NOT apples-to-apples; output shapes + differ -- 2 scalars vs M+K floats). Fast-floor reference only. + """ + print("Table B -- K1-only vs prod (NOT apples-to-apples; output shapes differ).") + header = ( + f"{'M':>7} {'K':>6}" + f" |" + f"{'pt_K1+RHT':>11} {'prod_K1':>9} {'ratio':>8}" + f" |" + f"{'pt_K1+RHT(Graph)':>18} {'prod_K1(Graph)':>16} {'ratio(Graph)':>13}" + ) + print(header) + print("-" * len(header)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt_rht, rec.t_prod) + ratio_g = _ratio(rec.t_pt_rht_g, rec.t_prod_g) + ratio_s = "nan" if math.isnan(ratio) else f"{ratio:.2f}x" + ratio_g_s = "nan" if math.isnan(ratio_g) else f"{ratio_g:.2f}x" + print( + f"{rec.M:>7} {rec.K:>6}" + f" |" + f"{rec.t_pt_rht:>11.4f} {rec.t_prod:>9.4f} {ratio_s:>8}" + f" |" + f"{rec.t_pt_rht_g:>18.4f} {rec.t_prod_g:>16.4f} {ratio_g_s:>13}" + ) + + +def _print_composite_legend(*, with_rht: bool, rht_mask: int) -> None: + """Prose legend mapping table labels to their C++ entry points.""" + print() + print("Legend:") + if with_rht: + print(" per-token (ms) = tex.nvfp4_per_token_quantize(a, ..., rowwise+colwise,") + print(" with_rht=False)") + print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") + print(f" per-token (+rht) (ms) = same, but with_rht=True + random_sign_mask_t=0x{rht_mask:04X}.") + print(" Applies a 16-point RHT along the columnwise direction in") + print(" BOTH K1 amax and K2 cast; rowwise stays raw. Length-16") + print(" matches the 1x16 inner-SF block of NVFP4, so each scale") + print(" window is decorrelated.") + print(" per-tensor (ms) = tex.quantize(a, NVFP4Quantizer(rht+sr), ...)") + print(" = nvte_quantize_with_hadamard_transform") + print(" (1 fused launch: rowwise quant + col-wise RHT + col quant,") + print(" prod baseline).") + print(" ratio = per-token (+rht) / per-tensor") + print(" ** < 1.0 = this PR wins vs prod baseline **") + else: + print(" per-token (ms) = tex.nvfp4_per_token_quantize(a, ..., rowwise+colwise, with_rht=False)") + print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") + print(" per-tensor (ms) = tex.quantize(a, NVFP4Quantizer(rht+sr), ...)") + print(" = nvte_quantize_with_hadamard_transform") + print(" (1 fused launch: rowwise quant + col-wise RHT + col quant,") + print(" prod baseline).") + print(" ratio = per-token / per-tensor ** < 1.0 = per-token wins vs prod baseline **") + print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Benchmark NVFP4 per-token K1+K2 quant vs per-tensor production NVFP4." + ) + parser.add_argument( + "--shapes", type=_parse_shape, nargs="+", default=None, + help="Shapes to bench, in MxK form (e.g. 4096x4096). " + "Default: an internally-chosen production-shape sweep.", + ) + parser.add_argument( + "--rht", action="store_true", + help="Also time the per-token + RHT path (col-wise 16-pt RHT in K1 + K2). " + "Default OFF: prints a 2-way table (per-token vs per-tensor). With " + "--rht: prints a 3-way table with one ratio " + "(per-token (+rht) / per-tensor).", + ) + parser.add_argument( + "--k1-only", action="store_true", + help="K1-only mode (no K2 cast). Without --rht: 2-way table (pt_K1 " + "vs prod_K1). With --rht: two tables back-to-back -- (A) RHT cost " + "pt_K1 vs pt_K1+RHT (apples-to-apples) and (B) pt_K1+RHT vs prod_K1 " + "(context only; output shapes differ).", + ) + parser.add_argument( + "--rht-mask", type=lambda s: int(s, 0), default=_RHT_MASK_DEFAULT, + help="16-bit random sign mask for the RHT path (only matters with --rht). " + f"Default 0x{_RHT_MASK_DEFAULT:04X}; accepts hex (0x...) or decimal.", + ) + args = parser.parse_args() + + if not _has_sm100(): + print("SKIP: NVFP4 per-token quant requires SM100 (Blackwell).", file=sys.stderr) + return 1 + + device = torch.device("cuda") + shapes = list(args.shapes) if args.shapes else list(_DEFAULT_SHAPES) + mask = args.rht_mask & 0xFFFF + + if args.k1_only: + records_k1: List[K1ShapeBench] = [ + _bench_shape_k1_only(M, K, device=device, + with_rht=args.rht, mask_t=mask) + for (M, K) in shapes + ] + if args.rht: + _print_k1_rht_cost_table(records_k1) + print() + _print_k1_vs_prod_table(records_k1) + else: + _print_k1_2way_table(records_k1) + else: + records: List[ShapeBench] = [ + _bench_shape(M, K, device=device, with_rht=args.rht, mask_t=mask) + for (M, K) in shapes + ] + if args.rht: + _print_composite_table(records) + else: + _print_composite_table_2way(records) + _print_composite_legend(with_rht=args.rht, rht_mask=mask) + return 0 diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py index 7c62db0693..1111382a19 100644 --- a/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py @@ -1,12 +1,23 @@ -"""Bench: NVFP4 per-token grouped (K1+K2 fused) vs per-tensor+RHT baseline. +"""Bench NVFP4 per-token grouped K1+K2 quant vs per-tensor RHT+SR baseline. -18-row sweep at fixed N=8 splits: sum_M in {1024..32768} x K in {2048,4096,8192}. -Both eager and CUDA Graphs columns reported on every row (ratio < 1.0 wins). +Modes: + * default: 2-way (per-token vs per-tensor). Ratio = pt / pten. + * ``--rht``: 3-way (adds per-token + col-wise 16-pt RHT). Ratio = + per-token (+rht) / per-tensor. + +Default sweep: N=8 equal splits, sum_M in {1024..32768} x K in {2048,4096,8192}. Requires bf16, K % 128 == 0, every split % 128 == 0, num_splits <= 64. + +CLI: + --shapes SUMMxK ... custom shapes (default: 18-row sweep) + --num-splits N equal splits per shape (default 8) + --rht enable 3-way RHT comparison + --rht-mask 0x... 16-bit RHT sign pattern (default 0xACE1) """ from __future__ import annotations +import argparse import math import statistics import sys @@ -39,11 +50,10 @@ def _make_baseline_quantizer_list(num_splits: int) -> List[NVFP4Quantizer]: return [q] * num_splits -def cuda_graph_time_ms(fn: Callable[[], object], *, warmup: int = 5, iters: int = 50) -> float: - """Median g.replay() time of fn captured into a CUDA Graph, in ms. - - Returns nan if capture fails (e.g. some C-API does an incompatible sync). - """ +def cuda_graph_time_ms( + fn: Callable[[], object], *, warmup: int = 5, iters: int = 50 +) -> float: + """Median g.replay() time of fn under CUDA Graphs, in ms (nan on capture failure).""" try: side = torch.cuda.Stream() side.wait_stream(torch.cuda.current_stream()) @@ -70,11 +80,18 @@ def cuda_graph_time_ms(fn: Callable[[], object], *, warmup: int = 5, iters: int return statistics.median(starts[i].elapsed_time(ends[i]) for i in range(iters)) -def _time_grouped(x_concat, split_sections, rowwise, columnwise, n_iters=20, n_warmup=5): +# Default RHT mask seed; matches te-nvfp4-build-overrides.mdc convention. +_RHT_MASK_DEFAULT: int = 0xACE1 + + +def _time_grouped(x_concat, split_sections, rowwise, columnwise, + *, with_rht: bool = False, mask: int = _RHT_MASK_DEFAULT, + n_iters: int = 20, n_warmup: int = 5) -> float: """Per-token grouped via the BULK Python wrapper. Allocation in-loop.""" for _ in range(n_warmup): _ = nvfp4_per_token_group_quantize( - x_concat, split_sections, rowwise=rowwise, columnwise=columnwise + x_concat, split_sections, rowwise=rowwise, columnwise=columnwise, + with_rht=with_rht, random_sign_mask_t=mask, ) torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) @@ -82,7 +99,8 @@ def _time_grouped(x_concat, split_sections, rowwise, columnwise, n_iters=20, n_w start.record() for _ in range(n_iters): _ = nvfp4_per_token_group_quantize( - x_concat, split_sections, rowwise=rowwise, columnwise=columnwise + x_concat, split_sections, rowwise=rowwise, columnwise=columnwise, + with_rht=with_rht, random_sign_mask_t=mask, ) stop.record() torch.cuda.synchronize() @@ -104,69 +122,193 @@ def _time_split_quantize(x_concat, split_sections, quantizer_list, n_iters=20, n return start.elapsed_time(stop) / n_iters # ms -def _time_split_quantize_graph(x_concat, split_sections, quantizer_list, n_iters=20, n_warmup=5): +def _time_split_quantize_graph(x_concat, split_sections, quantizer_list, + n_iters=20, n_warmup=5): """Per-tensor grouped under CUDA Graphs replay.""" - def fn() -> None: _ = tex.split_quantize(x_concat, split_sections, quantizer_list) return cuda_graph_time_ms(fn, warmup=n_warmup, iters=n_iters) -def _time_grouped_graph(x_concat, split_sections, rowwise, columnwise, n_iters=20, n_warmup=5): +def _time_grouped_graph(x_concat, split_sections, rowwise, columnwise, + *, with_rht: bool = False, mask: int = _RHT_MASK_DEFAULT, + n_iters: int = 20, n_warmup: int = 5) -> float: """Per-token grouped under CUDA Graphs replay.""" - def fn() -> None: _ = nvfp4_per_token_group_quantize( - x_concat, split_sections, rowwise=rowwise, columnwise=columnwise + x_concat, split_sections, rowwise=rowwise, columnwise=columnwise, + with_rht=with_rht, random_sign_mask_t=mask, ) return cuda_graph_time_ms(fn, warmup=n_warmup, iters=n_iters) -# N = 8 equal splits (MoE-typical), sum_M in {1024..32768}, K in {2048..8192}. -_NUM_SPLITS: int = 8 +# Default sweep: N = 8 equal splits (MoE-typical), sum_M in {1024..32768}, +# K in {2048..8192}. Override either via the CLI flags below. +_DEFAULT_NUM_SPLITS: int = 8 +_DEFAULT_SUM_M_VALUES: Tuple[int, ...] = (1024, 2048, 4096, 8192, 16384, 32768) +_DEFAULT_K_VALUES: Tuple[int, ...] = (2048, 4096, 8192) -_SUM_M_VALUES: List[int] = [1024, 2048, 4096, 8192, 16384, 32768] -_K_VALUES: List[int] = [2048, 4096, 8192] -_BENCH_CASES: List[Tuple[List[int], int]] = [] -for _sum_M in _SUM_M_VALUES: - _M_i = _sum_M // _NUM_SPLITS - for _K in _K_VALUES: - _BENCH_CASES.append(([_M_i] * _NUM_SPLITS, _K)) +def _parse_shape(s: str) -> Tuple[int, int]: + """Parse a `sum_MxK` CLI argument.""" + parts = s.split("x") + if len(parts) != 2: + raise argparse.ArgumentTypeError(f"Shape must be sum_MxK, got '{s}'") + return tuple(int(p) for p in parts) # type: ignore[return-value] -def main() -> None: +def _build_bench_cases( + shapes: List[Tuple[int, int]], num_splits: int +) -> List[Tuple[List[int], int]]: + """Turn (sum_M, K) pairs into (split_sections, K) cases; each split + must be a multiple of 128. + """ + cases: List[Tuple[List[int], int]] = [] + for sum_M, K in shapes: + if sum_M % num_splits != 0: + raise argparse.ArgumentTypeError( + f"sum_M={sum_M} not divisible by num_splits={num_splits}" + ) + M_i = sum_M // num_splits + if M_i % 128 != 0: + raise argparse.ArgumentTypeError( + f"sum_M={sum_M} / num_splits={num_splits} = M_i={M_i} must be a " + f"multiple of 128 (NVFP4 per-token kernel constraint)" + ) + if K % 128 != 0: + raise argparse.ArgumentTypeError( + f"K={K} must be a multiple of 128" + ) + cases.append(([M_i] * num_splits, K)) + return cases + + +def main() -> int: + parser = argparse.ArgumentParser( + description=( + "Bench NVFP4 per-token grouped K1+K2 quant. Three-way: " + "per-token (no RHT) / per-token+RHT / per-tensor (RHT+SR)." + ) + ) + parser.add_argument( + "--shapes", type=_parse_shape, nargs="+", default=None, + help="Shapes to bench, in sum_MxK form (e.g. 8192x4096). " + "Default: a 6x3 = 18-row internally-chosen sweep.", + ) + parser.add_argument( + "--num-splits", type=int, default=_DEFAULT_NUM_SPLITS, + help=f"Number of equal splits per shape (default {_DEFAULT_NUM_SPLITS}; " + f"<= 64). M_i = sum_M / num_splits must be a multiple of 128.", + ) + parser.add_argument( + "--rht", action="store_true", + help="Enable 3-way table with per-token + col-wise 16-pt RHT path. " + "Default OFF prints 2-way (per-token vs per-tensor).", + ) + parser.add_argument( + "--rht-mask", type=lambda s: int(s, 0), default=_RHT_MASK_DEFAULT, + help=f"16-bit RHT sign mask (default 0x{_RHT_MASK_DEFAULT:04X}; accepts " + "hex/dec). Only affects per-token+RHT; per-tensor uses its own mask.", + ) + args = parser.parse_args() + if not torch.cuda.is_available(): print("CUDA unavailable, skipping bench.") - return + return 1 cap = torch.cuda.get_device_capability() if cap[0] < 10: print(f"NVFP4 per-token requires SM100+ (got SM{cap[0]}.{cap[1]}); skipping.") - return + return 1 + if args.num_splits <= 0 or args.num_splits > 64: + print(f"--num-splits must be in [1, 64], got {args.num_splits}") + return 2 + + if args.shapes is not None: + shapes_in = [tuple(s) for s in args.shapes] + else: + shapes_in = [ + (sm, k) for sm in _DEFAULT_SUM_M_VALUES for k in _DEFAULT_K_VALUES + ] + bench_cases = _build_bench_cases(shapes_in, args.num_splits) + rht_mask: int = args.rht_mask & 0xFFFF + with_rht: bool = args.rht device = torch.device("cuda") print(f"# Device: {torch.cuda.get_device_name(0)} (cap {cap[0]}.{cap[1]})") - print(f"# Split structure: N={_NUM_SPLITS} equal splits, M_i = sum_M / {_NUM_SPLITS}") + print(f"# Split structure: N={args.num_splits} equal splits, " + f"M_i = sum_M / {args.num_splits}") + if with_rht: + print(f"# RHT mask: 0x{rht_mask:04X} (per-token+RHT col-wise; per-tensor uses its own internal mask)") + else: + print("# RHT: disabled (pass --rht to enable 3-way per-token / per-token (+rht) / per-tensor table)") print() # Per-tensor baseline quantizer is fixed to row+col, so both enabled. rowwise = True columnwise = True - header = ( - f"{'sum_M':>6} {'K':>5}" - " |" - f"{'per-token':>10} {'per-tensor':>10} {'ratio':>8}" - " |" - f"{'per-token(Graph)':>17} {'per-tensor(Graph)':>17} {'ratio(Graph)':>13}" - ) - print(header) - print("-" * len(header)) + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" + + def _ratio(num: float, den: float) -> float: + if den <= 0 or math.isnan(num) or math.isnan(den): + return float("nan") + return num / den + + # Multi-line header: section label + column names (+ `(+rht)` sub-label + # row in 3-way mode), then separator + data rows. + if with_rht: + w_pt, w_pt_rht, w_pten, w_ratio = 12, 12, 13, 8 + block_w = w_pt + 1 + w_pt_rht + 1 + w_pten + 1 + w_ratio + header1 = ( + f"{'':>6} {'':>5}" + f" |{'Eager, unit (ms)':^{block_w}}" + f" |{'Graph, unit (ms)':^{block_w}}" + ) + header2 = ( + f"{'sum_M':>6} {'K':>5}" + f" |" + f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + f" |" + f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + header3 = ( + f"{'':>6} {'':>5}" + f" |" + f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + f" |" + f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + ) + print(header1) + print(header2) + print(header3) + else: + w_pt, w_pten, w_ratio = 14, 15, 8 + block_w = w_pt + 1 + w_pten + 1 + w_ratio + header1 = ( + f"{'':>6} {'':>5}" + f" |{'Eager, unit (ms)':^{block_w}}" + f" |{'Graph, unit (ms)':^{block_w}}" + ) + header2 = ( + f"{'sum_M':>6} {'K':>5}" + f" |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + f" |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + print(header1) + print(header2) + print("-" * len(header2)) prev_sum_M = None - for split_sections, K in _BENCH_CASES: + for split_sections, K in bench_cases: sum_M = sum(split_sections) num_splits = len(split_sections) @@ -175,38 +317,82 @@ def main() -> None: print() prev_sum_M = sum_M - x_concat = (torch.randn((sum_M, K), dtype=torch.bfloat16, device=device) * 3.0).contiguous() + x_concat = ( + torch.randn((sum_M, K), dtype=torch.bfloat16, device=device) * 3.0 + ).contiguous() quantizer_list = _make_baseline_quantizer_list(num_splits) - t_pt = _time_grouped(x_concat, split_sections, rowwise, columnwise) + t_pt = _time_grouped(x_concat, split_sections, rowwise, columnwise, + with_rht=False) t_pten = _time_split_quantize(x_concat, split_sections, quantizer_list) - ratio = t_pt / t_pten if t_pten > 0 else float("nan") + t_pt_g = _time_grouped_graph( + x_concat, split_sections, rowwise, columnwise, with_rht=False, + ) + t_pten_g = _time_split_quantize_graph( + x_concat, split_sections, quantizer_list, + ) - t_pt_g = _time_grouped_graph(x_concat, split_sections, rowwise, columnwise) - t_pten_g = _time_split_quantize_graph(x_concat, split_sections, quantizer_list) - if math.isnan(t_pt_g) or math.isnan(t_pten_g) or t_pten_g <= 0: - ratio_g = float("nan") - graph_cells = f"{t_pt_g:>17.4f} {t_pten_g:>17.4f} {'nan':>13}" + if with_rht: + t_pt_rht = _time_grouped(x_concat, split_sections, rowwise, columnwise, + with_rht=True, mask=rht_mask) + t_pt_rht_g = _time_grouped_graph( + x_concat, split_sections, rowwise, columnwise, + with_rht=True, mask=rht_mask, + ) + + ratio_eager = _ratio(t_pt_rht, t_pten) + ratio_graph = _ratio(t_pt_rht_g, t_pten_g) + + print( + f"{sum_M:>6d} {K:>5d}" + f" |" + f"{t_pt:>{w_pt}.4f} {t_pt_rht:>{w_pt_rht}.4f}" + f" {t_pten:>{w_pten}.4f} {_fmt(ratio_eager):>{w_ratio}}" + f" |" + f"{t_pt_g:>{w_pt}.4f} {t_pt_rht_g:>{w_pt_rht}.4f}" + f" {t_pten_g:>{w_pten}.4f} {_fmt(ratio_graph):>{w_ratio}}" + ) else: - ratio_g = t_pt_g / t_pten_g - graph_cells = f"{t_pt_g:>17.4f} {t_pten_g:>17.4f} {ratio_g:>12.2f}x" - - print(f"{sum_M:>6d} {K:>5d} |{t_pt:>10.4f} {t_pten:>10.4f} {ratio:>7.2f}x |{graph_cells}") + ratio_eager = _ratio(t_pt, t_pten) + ratio_graph = _ratio(t_pt_g, t_pten_g) + print( + f"{sum_M:>6d} {K:>5d}" + f" |" + f"{t_pt:>{w_pt}.4f} {t_pten:>{w_pten}.4f} {_fmt(ratio_eager):>{w_ratio}}" + f" |" + f"{t_pt_g:>{w_pt}.4f} {t_pten_g:>{w_pten}.4f} {_fmt(ratio_graph):>{w_ratio}}" + ) del x_concat, quantizer_list torch.cuda.empty_cache() print() print("Legend:") - print(" per-token = nvfp4_per_token_group_quantize(x, splits, rowwise+colwise)") - print(" = K1 fused amax + K2 fused cast (2 launches), this PR") - print(" per-tensor = tex.split_quantize(x, splits, [NVFP4Quantizer(rht+sr)]*N)") - print(" = nvte_group_hadamard_transform_amax") - print(" + nvte_group_hadamard_transform_cast_fusion (2 launches)") - print(" ratio = per-token / per-tensor ** < 1.0 = this PR wins **") - print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided,") - print(" pure kernel-level wall time, ALL rows)") + if with_rht: + print(" per-token (ms) = nvfp4_per_token_group_quantize(x, splits,") + print(" rowwise+colwise, with_rht=False)") + print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") + print(f" per-token (+rht) (ms) = same, but with_rht=True + random_sign_mask_t=0x{rht_mask:04X}.") + print(" Applies a 16-point RHT along the columnwise direction in") + print(" BOTH K1 amax and K2 cast; rowwise stays raw. Length-16") + print(" matches the 1x16 inner-SF block of NVFP4, so each scale") + print(" window is decorrelated.") + print(" per-tensor (ms) = tex.split_quantize(x, splits, [NVFP4Quantizer(rht+sr)]*N)") + print(" = nvte_group_hadamard_transform_amax") + print(" + nvte_group_hadamard_transform_cast_fusion") + print(" (2 launches, prod baseline).") + print(" ratio = per-token (+rht) / per-tensor") + print(" ** < 1.0 = this PR wins vs prod baseline **") + else: + print(" per-token (ms) = nvfp4_per_token_group_quantize(x, splits, rowwise+colwise, with_rht=False)") + print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") + print(" per-tensor (ms) = tex.split_quantize(x, splits, [NVFP4Quantizer(rht+sr)]*N)") + print(" = nvte_group_hadamard_transform_amax") + print(" + nvte_group_hadamard_transform_cast_fusion (2 launches, prod baseline).") + print(" ratio = per-token / per-tensor ** < 1.0 = per-token wins vs prod baseline **") + print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") + return 0 if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token.py b/tests/pytorch/nvfp4/test_nvfp4_per_token.py index 727b5792d2..90c3f87592 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_per_token.py +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token.py @@ -5,12 +5,15 @@ """Correctness tests for NVFP4 per-token cast + cuBLAS LT NVFP4 GEMM. Covers byte-equal kernel-vs-reference quantize parity, K1/K2 split-vs-composite -parity, dequant + fp32 reference, and a cuBLAS LT NVFP4 GEMM smoke. Requires -bf16 input, M % 128 == 0, K % 128 == 0; GEMM tests gated by SM100. +parity, dequant + fp32 reference, optional RHT (K1 amax + K2 cast), and a +cuBLAS LT NVFP4 GEMM smoke. Requires bf16 input, M % 128 == 0, K % 128 == 0; +GEMM and RHT tests gated by SM100. """ from __future__ import annotations +from typing import Tuple + import pytest import torch @@ -410,3 +413,339 @@ def test_per_token_gemm_rejects_beta_nonzero() -> None: b_q.row_amax, beta=1.0, ) + + +# ============================================================================= +# (5) RHT correctness: K1 amax + K2 cast with optional col-wise RHT. +# Opt-in via with_rht=True + random_sign_mask_t=; row direction never +# sees RHT. with_rht=False is byte-equal to the pre-RHT path. +# ============================================================================= + +_RHT_SHAPES = [ + (128, 128), + (256, 256), + (128, 1024), # K > single 64x64 sub-tile along col + (1024, 128), # M > single 64x64 sub-tile along row + (512, 512), +] + + +def _walsh_hadamard_16(device: torch.device) -> torch.Tensor: + """16x16 Sylvester / Walsh-Hadamard matrix, +/-1 entries (unnormalized).""" + H = torch.tensor([[1.0]], dtype=torch.float32, device=device) + for _ in range(4): + top = torch.cat([H, H], dim=1) + bot = torch.cat([H, -H], dim=1) + H = torch.cat([top, bot], dim=0) + return H + + +def _sign_diag_16(mask: int, device: torch.device) -> torch.Tensor: + """16-elt +/-1 vector; s_i = -1 iff bit i of `mask` is set.""" + bits = torch.tensor( + [1 - 2 * ((mask >> i) & 1) for i in range(16)], + dtype=torch.float32, device=device, + ) + return bits + + +def _reference_col_amax_rht(x_bf16: torch.Tensor, mask: int) -> torch.Tensor: + """PyTorch reference for the per-token col-wise RHT amax: max over + 16-row blocks of |H * D * x_block| / 4. FHT may permute element order + but |y|.max() is permutation-invariant. + """ + M, K = x_bf16.shape + assert M % 16 == 0, "Test setup error: M must be a multiple of 16." + H = _walsh_hadamard_16(x_bf16.device) + sign = _sign_diag_16(mask, x_bf16.device) + x = x_bf16.to(torch.float32) + blocks = x.reshape(M // 16, 16, K) + masked = blocks * sign.view(1, 16, 1) + rotated = torch.einsum("ij,bjk->bik", H, masked) + return (rotated.abs() / 4.0).reshape(-1, K).amax(dim=0) + + +def _reference_amax_raw(x_bf16: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Raw per-row + per-col absolute max (no RHT, bf16 -> fp32 first).""" + x = x_bf16.to(torch.float32) + return x.abs().amax(dim=1), x.abs().amax(dim=0) + + +def _allocate_per_token_buffers(M: int, K: int, device: torch.device): + """Match the layout that ``tex.nvfp4_per_token_quantize`` writes.""" + return { + "q_row": torch.empty((M, K // 2), dtype=torch.uint8, device=device), + "s_row": torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device), + "ra": torch.empty((M,), dtype=torch.float32, device=device), + "q_col": torch.empty((K, M // 2), dtype=torch.uint8, device=device), + "s_col": torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device), + "ca": torch.empty((K,), dtype=torch.float32, device=device), + } + + +def _dequant_fp4_with_outer_amax( + q_packed: torch.Tensor, # (R, C // 2) uint8 packed FP4 + s_dec: torch.Tensor, # (R, C // 16) e4m3 held as uint8 + outer_amax: torch.Tensor, # (R,) fp32 +) -> torch.Tensor: + """Decode a rowwise FP4 tensor back to fp32 using the kernel's own + arithmetic: x_hat = qcode * s_dec_e4m3 * (6 / S_enc_row), + S_enc_row = (448 * 6) / max(outer_amax, 1e-12). + """ + R, half_C = q_packed.shape + C = half_C * 2 + s_dec_f = s_dec.view(torch.float8_e4m3fn).to(torch.float32) + + lo = (q_packed & 0x0F).to(torch.int8) + hi = ((q_packed >> 4) & 0x0F).to(torch.int8) + interleaved = torch.stack([lo, hi], dim=-1).reshape(R, C) + # NVFP4 E2M1 LUT (sign-magnitude): 0000..0111 map to {0, 0.5, 1, 1.5, + # 2, 3, 4, 6}; 1000..1111 are the negatives. + fp4_lut = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], + dtype=torch.float32, device=q_packed.device, + ) + fp4_val = fp4_lut[interleaved.to(torch.int64)] + + fp8_max = 448.0 + fp4_max = 6.0 + safe_amax = torch.clamp(outer_amax, min=1e-12) + S_enc_row = (fp8_max * fp4_max) / safe_amax + inv_S = (1.0 / S_enc_row).unsqueeze(1) + + block_scale_inv = s_dec_f * inv_S + block_scale_inv = block_scale_inv.repeat_interleave(BLOCK_K, dim=1) + + return fp4_val * block_scale_inv + + +# ----- (5a) K1 RHT: standalone amax kernel ---------------------------------- + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _RHT_SHAPES) +def test_per_token_k1_with_rht_false_equals_raw_amax(M: int, K: int) -> None: + """Regression: with_rht=False reproduces raw bf16->fp32 amax along each axis.""" + torch.manual_seed(0xABCD * (M + 1) + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + row_amax = torch.empty((M,), dtype=torch.float32, device=device) + col_amax = torch.empty((K,), dtype=torch.float32, device=device) + + tex.nvfp4_per_token_amax( + x, row_amax, col_amax, True, True, + with_rht=False, random_sign_mask_t=0, + ) + + ref_row, ref_col = _reference_amax_raw(x) + torch.testing.assert_close(row_amax, ref_row, rtol=0.0, atol=0.0, + msg=f"row_amax mismatch at ({M}, {K})") + torch.testing.assert_close(col_amax, ref_col, rtol=0.0, atol=0.0, + msg=f"col_amax mismatch at ({M}, {K})") + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _RHT_SHAPES) +@pytest.mark.parametrize("mask", [0x0000, 0xACE1, 0xFFFF, 0x5A5A]) +def test_per_token_k1_with_rht_matches_reference( + M: int, K: int, mask: int, +) -> None: + """with_rht=True col_amax matches max|H*D*x_block|/4; rowwise stays raw.""" + torch.manual_seed(0xDEAD * (M + 7) + (K + 3) + mask) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + row_amax = torch.empty((M,), dtype=torch.float32, device=device) + col_amax = torch.empty((K,), dtype=torch.float32, device=device) + + tex.nvfp4_per_token_amax( + x, row_amax, col_amax, True, True, + with_rht=True, random_sign_mask_t=mask, + ) + + ref_row, _ = _reference_amax_raw(x) + torch.testing.assert_close(row_amax, ref_row, rtol=0.0, atol=0.0, + msg=f"row_amax mismatch at ({M}, {K}, mask=0x{mask:04X})") + + # Col tolerance accounts for bf16->fp32 promotion noise + butterfly + # summation order vs. einsum reduction order. + ref_col = _reference_col_amax_rht(x, mask) + torch.testing.assert_close( + col_amax, ref_col, rtol=2e-3, atol=1e-4, + msg=f"col_amax (RHT) mismatch at ({M}, {K}, mask=0x{mask:04X})", + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", [(128, 128), (256, 512)]) +def test_per_token_k1_with_rht_zero_mask_is_hadamard_only(M: int, K: int) -> None: + """mask=0 -> D=I; col_amax equals bare Hadamard amax max|H*x_block|/4.""" + torch.manual_seed(0xC0DE * (M + 11) + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + row_amax = torch.empty((M,), dtype=torch.float32, device=device) + col_amax = torch.empty((K,), dtype=torch.float32, device=device) + + tex.nvfp4_per_token_amax( + x, row_amax, col_amax, True, True, + with_rht=True, random_sign_mask_t=0, + ) + + H = _walsh_hadamard_16(device) + x_fp32 = x.to(torch.float32) + blocks = x_fp32.reshape(M // 16, 16, K) + rotated = torch.einsum("ij,bjk->bik", H, blocks) + ref_col = (rotated.abs() / 4.0).reshape(-1, K).amax(dim=0) + + torch.testing.assert_close( + col_amax, ref_col, rtol=2e-3, atol=1e-4, + msg=f"col_amax (RHT, mask=0) mismatch at ({M}, {K})", + ) + + +# ----- (5b) K2 + composite RHT: encode kernel and composite quantize -------- + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _RHT_SHAPES) +def test_per_token_composite_with_rht_false_byte_equal(M: int, K: int) -> None: + """Regression: with_rht=False composite byte-equals the default (no-kwargs) path.""" + torch.manual_seed(0xCAFE * (M + 1) + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + bufs_default = _allocate_per_token_buffers(M, K, device) + bufs_explicit = _allocate_per_token_buffers(M, K, device) + + tex.nvfp4_per_token_quantize( + x, bufs_default["q_row"], bufs_default["s_row"], bufs_default["ra"], + bufs_default["q_col"], bufs_default["s_col"], bufs_default["ca"], + True, True, + ) + tex.nvfp4_per_token_quantize( + x, bufs_explicit["q_row"], bufs_explicit["s_row"], bufs_explicit["ra"], + bufs_explicit["q_col"], bufs_explicit["s_col"], bufs_explicit["ca"], + True, True, + with_rht=False, random_sign_mask_t=0xACE1, + ) + + for k in ("q_row", "s_row", "ra", "q_col", "s_col", "ca"): + assert torch.equal(bufs_default[k], bufs_explicit[k]), ( + f"with_rht=False not byte-equal to default path on `{k}` at ({M}, {K})" + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _RHT_SHAPES) +def test_per_token_composite_rowwise_unchanged_under_rht(M: int, K: int) -> None: + """Rowwise FP4 + inner SF + row amax byte-equal across with_rht=False / True.""" + torch.manual_seed(0xBEEF * (M + 3) + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + bufs_no_rht = _allocate_per_token_buffers(M, K, device) + bufs_with_rht = _allocate_per_token_buffers(M, K, device) + + tex.nvfp4_per_token_quantize( + x, bufs_no_rht["q_row"], bufs_no_rht["s_row"], bufs_no_rht["ra"], + bufs_no_rht["q_col"], bufs_no_rht["s_col"], bufs_no_rht["ca"], + True, True, + with_rht=False, random_sign_mask_t=0, + ) + tex.nvfp4_per_token_quantize( + x, bufs_with_rht["q_row"], bufs_with_rht["s_row"], bufs_with_rht["ra"], + bufs_with_rht["q_col"], bufs_with_rht["s_col"], bufs_with_rht["ca"], + True, True, + with_rht=True, random_sign_mask_t=0xACE1, + ) + + for k in ("q_row", "s_row", "ra"): + assert torch.equal(bufs_no_rht[k], bufs_with_rht[k]), ( + f"rowwise output differs between with_rht=False/True on `{k}` " + f"at ({M}, {K}) -- rowwise should never see RHT." + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", [(128, 128), (256, 512), (512, 512)]) +@pytest.mark.parametrize("mask", [0x0000, 0xACE1, 0xFFFF]) +def test_per_token_composite_with_rht_col_dequant_matches_reference( + M: int, K: int, mask: int, +) -> None: + """Dequant'd col FP4 (with_rht=True) ~ H*D*x_block/sqrt(16); checks + column-aggregate median + p99 relative error (FP4's 16-code grain and + butterfly permutation make element-wise comparison too loose). + """ + torch.manual_seed(0xFEED * (M + 5) + K + mask) + device = torch.device("cuda") + # Scale down so most blocks land in non-saturating FP4 (else we measure + # clamping noise, not RHT). + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) * 0.5 + + bufs = _allocate_per_token_buffers(M, K, device) + tex.nvfp4_per_token_quantize( + x, bufs["q_row"], bufs["s_row"], bufs["ra"], + bufs["q_col"], bufs["s_col"], bufs["ca"], + True, True, + with_rht=True, random_sign_mask_t=mask, + ) + + H = _walsh_hadamard_16(device) + sign = _sign_diag_16(mask, device) + x_fp32 = x.to(torch.float32) + blocks = x_fp32.reshape(M // 16, 16, K) + masked = blocks * sign.view(1, 16, 1) + rotated = torch.einsum("ij,bjk->bik", H, masked) # (M/16, 16, K) + y_ref = rotated.reshape(M, K) / 4.0 # (M, K) + y_ref_col_view = y_ref.transpose(0, 1).contiguous() # (K, M) + + y_kernel = _dequant_fp4_with_outer_amax( + bufs["q_col"], bufs["s_col"], bufs["ca"], + ) # (K, M) + + diff = (y_kernel - y_ref_col_view).abs() + col_outer = bufs["ca"].unsqueeze(1).clamp(min=1e-6) + rel = diff / col_outer + p99 = torch.quantile(rel.flatten(), 0.99).item() + median = rel.median().item() + assert median < 0.1, ( + f"median per-element relative error too large: {median:.4f} > 0.1 " + f"at ({M}, {K}, mask=0x{mask:04X})" + ) + assert p99 < 0.5, ( + f"p99 per-element relative error too large: {p99:.4f} > 0.5 " + f"at ({M}, {K}, mask=0x{mask:04X})" + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", [(128, 128), (256, 256)]) +def test_per_token_composite_with_rht_col_amax_matches_k1( + M: int, K: int, +) -> None: + """Composite col_amax byte-equals standalone K1 amax with the same mask.""" + torch.manual_seed(0xDADA * (M + 13) + K) + device = torch.device("cuda") + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + mask = 0xACE1 + + bufs = _allocate_per_token_buffers(M, K, device) + tex.nvfp4_per_token_quantize( + x, bufs["q_row"], bufs["s_row"], bufs["ra"], + bufs["q_col"], bufs["s_col"], bufs["ca"], + True, True, + with_rht=True, random_sign_mask_t=mask, + ) + + ra_k1 = torch.empty((M,), dtype=torch.float32, device=device) + ca_k1 = torch.empty((K,), dtype=torch.float32, device=device) + tex.nvfp4_per_token_amax( + x, ra_k1, ca_k1, True, True, + with_rht=True, random_sign_mask_t=mask, + ) + + torch.testing.assert_close(bufs["ca"], ca_k1, rtol=0.0, atol=0.0, + msg=f"composite ca != K1-only ca at ({M}, {K})") + torch.testing.assert_close(bufs["ra"], ra_k1, rtol=0.0, atol=0.0, + msg=f"composite ra != K1-only ra at ({M}, {K})") diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py b/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py index 7c617200cd..5dd2588021 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py @@ -1,12 +1,13 @@ """Correctness tests for grouped (multi-tensor) NVFP4 per-token cast. The grouped kernel must be byte-equal to a for-loop of single-tensor -calls. Covers composite K1+K2, K1-only, single-split, and many-split. +calls. Covers composite K1+K2, K1-only, single-split, many-split, and +optional RHT (random Hadamard transform) on the column direction. """ from __future__ import annotations -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple import pytest import torch @@ -21,6 +22,9 @@ RefNVFP4TensorPerToken, nvfp4_per_token_quantize, ) +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_per_token_group import ( + nvfp4_per_token_group_quantize, +) def _has_fp4() -> bool: @@ -356,3 +360,187 @@ def test_group_many_splits_byte_equal(n_splits: int) -> None: sut[i].columnwise_data, oracle[i].columnwise_data, atol=0.0, rtol=0.0 ) torch.testing.assert_close(sut[i].col_amax, oracle[i].col_amax, atol=0.0, rtol=0.0) + + +# ============================================================================= +# (5) RHT correctness: grouped K1+K2 with optional col-wise RHT. +# Contract: each split's 6 outputs MUST byte-equal single-tensor with the +# same mask. Row direction never sees RHT. +# ============================================================================= + +_RHT_GROUP_SHAPES: List[Tuple[List[int], int]] = [ + ([128, 128], 128), # 2 splits, smallest legal shape + ([128, 256, 128], 256), # 3 splits, mixed sizes + ([256, 256, 256, 256], 512), # 4 equal splits, larger K + ([128, 384], 128), # 2 splits, very asymmetric +] + + +def _rht_pt_buffers(M: int, K: int, device: torch.device): + """Match the layout that ``tex.nvfp4_per_token_quantize`` writes.""" + return { + "q_row": torch.empty((M, K // 2), dtype=torch.uint8, device=device), + "s_row": torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device), + "ra": torch.empty((M,), dtype=torch.float32, device=device), + "q_col": torch.empty((K, M // 2), dtype=torch.uint8, device=device), + "s_col": torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device), + "ca": torch.empty((K,), dtype=torch.float32, device=device), + } + + +def _split_views(x_concat: torch.Tensor, splits: Sequence[int]) -> List[torch.Tensor]: + out, off = [], 0 + for s in splits: + out.append(x_concat[off : off + s].contiguous()) + off += int(s) + return out + + +@_GATED_FP4 +@pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES) +def test_group_with_rht_false_byte_equal_to_default( + splits: List[int], K: int, +) -> None: + """Regression: with_rht=False grouped byte-equals the default (no-kwargs) path.""" + torch.manual_seed(0xCAFE * (sum(splits) + 1) + K + len(splits)) + device = torch.device("cuda") + sum_M = sum(splits) + x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() + + outs_default = nvfp4_per_token_group_quantize( + x, splits, rowwise=True, columnwise=True, + ) + outs_explicit_false = nvfp4_per_token_group_quantize( + x, splits, rowwise=True, columnwise=True, + with_rht=False, random_sign_mask_t=0xACE1, + ) + + assert len(outs_default) == len(outs_explicit_false) == len(splits) + for i, (a, b) in enumerate(zip(outs_default, outs_explicit_false)): + for attr in ("data", "scale", "row_amax", + "columnwise_data", "columnwise_scale", "col_amax"): + ta, tb = getattr(a, attr), getattr(b, attr) + assert torch.equal(ta, tb), ( + f"split[{i}].{attr} differs between default and explicit " + f"with_rht=False at K={K}, splits={splits}" + ) + + +@_GATED_FP4 +@pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES) +def test_group_rowwise_unchanged_under_rht( + splits: List[int], K: int, +) -> None: + """Rowwise outputs byte-equal across with_rht=False / True.""" + torch.manual_seed(0xBEEF * (sum(splits) + 3) + K) + device = torch.device("cuda") + sum_M = sum(splits) + x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() + + outs_no_rht = nvfp4_per_token_group_quantize( + x, splits, rowwise=True, columnwise=True, + with_rht=False, random_sign_mask_t=0, + ) + outs_with_rht = nvfp4_per_token_group_quantize( + x, splits, rowwise=True, columnwise=True, + with_rht=True, random_sign_mask_t=0xACE1, + ) + + for i, (a, b) in enumerate(zip(outs_no_rht, outs_with_rht)): + for attr in ("data", "scale", "row_amax"): + ta, tb = getattr(a, attr), getattr(b, attr) + assert torch.equal(ta, tb), ( + f"split[{i}].{attr} differs between with_rht=False and =True " + f"on the ROW direction at K={K}, splits={splits} -- " + f"rowwise should never see RHT." + ) + + +@_GATED_FP4 +@pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES) +@pytest.mark.parametrize("mask", [0x0000, 0xACE1, 0xFFFF]) +def test_group_with_rht_equals_single_tensor_per_split( + splits: List[int], K: int, mask: int, +) -> None: + """Each split's 6 outputs byte-equal single-tensor with the same mask.""" + torch.manual_seed(0xDADA * (sum(splits) + 11) + K + mask) + device = torch.device("cuda") + sum_M = sum(splits) + x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() + + outs_grouped = nvfp4_per_token_group_quantize( + x, splits, rowwise=True, columnwise=True, + with_rht=True, random_sign_mask_t=mask, + ) + + x_splits = _split_views(x, splits) + for i, (x_i, out_g) in enumerate(zip(x_splits, outs_grouped)): + M_i = x_i.size(0) + bufs = _rht_pt_buffers(M_i, K, device) + tex.nvfp4_per_token_quantize( + x_i, bufs["q_row"], bufs["s_row"], bufs["ra"], + bufs["q_col"], bufs["s_col"], bufs["ca"], + True, True, + with_rht=True, random_sign_mask_t=mask, + ) + + mapping = { + "data": ("q_row", out_g.data), + "scale": ("s_row", out_g.scale.view(torch.uint8)), + "row_amax": ("ra", out_g.row_amax), + "columnwise_data": ("q_col", out_g.columnwise_data), + "columnwise_scale": ("s_col", out_g.columnwise_scale.view(torch.uint8)), + "col_amax": ("ca", out_g.col_amax), + } + for attr, (single_key, grouped_t) in mapping.items(): + single_t = bufs[single_key] + assert single_t.shape == grouped_t.shape, ( + f"split[{i}].{attr} shape mismatch: grouped={grouped_t.shape}, " + f"single-tensor={single_t.shape} at K={K}, splits={splits}, mask=0x{mask:04X}" + ) + assert torch.equal(grouped_t, single_t), ( + f"split[{i}].{attr} grouped result differs from single-tensor " + f"reference at K={K}, splits={splits}, mask=0x{mask:04X}" + ) + + +@_GATED_FP4 +@pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES[:2]) +def test_group_k1_amax_matches_single_tensor_per_split_under_rht( + splits: List[int], K: int, +) -> None: + """Grouped K1 amax byte-equals single-tensor K1 per split. Isolates K1 + via the lighter nvfp4_per_token_group_amax binding to catch K1-vs-K2 + divergences earlier than the full composite check. + """ + torch.manual_seed(0x1234 * (sum(splits) + 7) + K) + device = torch.device("cuda") + sum_M = sum(splits) + x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() + mask = 0xACE1 + + row_amax_list = [ + torch.empty((int(s),), dtype=torch.float32, device=device) for s in splits + ] + col_amax_list = [ + torch.empty((K,), dtype=torch.float32, device=device) for _ in splits + ] + tex.nvfp4_per_token_group_amax( + x, [int(s) for s in splits], row_amax_list, col_amax_list, + True, True, + with_rht=True, random_sign_mask_t=mask, + ) + + x_splits = _split_views(x, splits) + for i, (x_i, ra_g, ca_g) in enumerate(zip(x_splits, row_amax_list, col_amax_list)): + M_i = x_i.size(0) + ra_s = torch.empty((M_i,), dtype=torch.float32, device=device) + ca_s = torch.empty((K,), dtype=torch.float32, device=device) + tex.nvfp4_per_token_amax( + x_i, ra_s, ca_s, True, True, + with_rht=True, random_sign_mask_t=mask, + ) + torch.testing.assert_close(ra_g, ra_s, rtol=0.0, atol=0.0, + msg=f"split[{i}] row_amax mismatch (K1 only)") + torch.testing.assert_close(ca_g, ca_s, rtol=0.0, atol=0.0, + msg=f"split[{i}] col_amax mismatch (K1 only)") diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu index c9e7eec391..03ef547834 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu @@ -29,11 +29,11 @@ #include -#include "common/cast/core/common.cuh" -#include "common/cast/nvfp4/core_nvfp4.cuh" #include "common/common.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" +#include "common/cast/core/common.cuh" +#include "common/cast/nvfp4/core_nvfp4.cuh" namespace transformer_engine { namespace nvfp4_per_token { @@ -41,76 +41,73 @@ namespace nvfp4_per_token { #if FP4_TYPE_SUPPORTED using dispatch::common::align_smem_ptr_per_TMA_requirements; -using dispatch::nvfp4::nvfp4_scale_t; using dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; using dispatch::nvfp4::quantization_SF::compute_decoding_scaling_factor; +using dispatch::nvfp4::nvfp4_scale_t; -constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows of input -constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols of input -constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height -constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width -constexpr int THREADS_NUM = 128; // threads per CTA -constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM -constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) -constexpr int PREFETCH_STAGES = 1; // 1-stage prefetch overlap +constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows of input +constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols of input +constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height +constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width +constexpr int THREADS_NUM = 128; // threads per CTA +constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM +constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) +constexpr int PREFETCH_STAGES = 1; // 1-stage prefetch overlap constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; // = 2 ping-pong input buffers // Derived (chunk / tile / stage) constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; // 2 constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X; // 2 -constexpr int STAGES = TILES_Y * TILES_X; // 4 +constexpr int STAGES = TILES_Y * TILES_X; // 4 constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; // 8 inner blocks per row of the chunk constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; // 8 inner blocks per col of the chunk -constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 -constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 // Encode helpers' thread layout (rowwise pass: 4x32 = K-dim x M-dim) -constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 -constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 -constexpr int THREADS_PER_SCALE_ROWWISE = - SCALE_DIM / ELTS_PER_THREAD; // 1 (each block owned by 1 thread) +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 (each block owned by 1 thread) constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 -// Encode helpers' thread layout (colwise pass: tid.X for col, warp for M-block) -constexpr int THREADS_X_TR = TILE_DIM_X / 2; // 32 cols per warp -constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; // 4 (warps) - // Buffer dimensions (input bf16 SMEM tiles + FP4 output SMEM tiles for TMA store) constexpr int BUFF_IN_DIM_Y = TILE_DIM_Y; constexpr int BUFF_IN_DIM_X = TILE_DIM_X; -constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; // elements +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; // elements constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; -constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (2 fp4 per byte) -constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; +constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (2 fp4 per byte) +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; constexpr int BUFF_OUT_TR_DIM_Y = TILE_DIM_X; -constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 -constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; -constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 ping-pong (matches input) -constexpr int BUFFS_NUM_OUT_TR = 2; // 2 ping-pong for transpose +constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 ping-pong (matches input) +constexpr int BUFFS_NUM_OUT_TR = 2; // 2 ping-pong for transpose // Manual swizzling parameters to reduce SMEM bank conflicts on rowwise loads constexpr int PACK_SIZE = 8; -constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 -constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 -using IType = bf16; +using IType = bf16; using IType2 = ptx::FPx2; // = ptx::bf16x2 -using IType3D = IType[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; -using IType2x3D = IType2[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using IType3D = IType [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +using IType2x3D = IType2 [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; -using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; // Compute the per-block (1x16) byte-equal arithmetic and emit FP4 codes into // SMEM rowwise output buffer + e4m3 scale into SMEM rowwise scale buffer. __device__ __forceinline__ void rowwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_ptr, + const IType* __restrict__ sIn_ptr, + fp4e2m1x2* __restrict__ sOut_ptr, nvfp4_scale_t* __restrict__ sSFrowwise_ptr, - const float* __restrict__ sRowAmax, // [CHUNK_DIM_Y], indexed by chunk-local row - const int stage_Y, const int stage_X, const int buff_in, const int buff_out) { + const float* __restrict__ sRowAmax, // [CHUNK_DIM_Y], indexed by chunk-local row + const int stage_Y, const int stage_X, + const int buff_in, const int buff_out) { const auto& sIn = *reinterpret_cast(sIn_ptr); auto& sOut = *reinterpret_cast(sOut_ptr); auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); @@ -118,14 +115,12 @@ __device__ __forceinline__ void rowwise_scaling_per_token( const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; - const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 - const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 - const int thread_offset_X_rowwise = - tid_X_rowwise * ELTS_PER_THREAD; // K-elt offset in tile (0/16/32/48) + const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; // K-elt offset in tile (0/16/32/48) - const int SF_thread_offset_rowwise_X = - tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; // = tid_X_rowwise here + const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; // = tid_X_rowwise here const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); const int stage_rowwise_scales_offset_X = @@ -157,8 +152,8 @@ __device__ __forceinline__ void rowwise_scaling_per_token( ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); } } - const float block_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + const float block_amax = static_cast( + __hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); // Byte-equal compute path (matches the Python reference in // ``NVFP4QuantizerPerTokenRef``): @@ -198,39 +193,92 @@ __device__ __forceinline__ void rowwise_scaling_per_token( } } -// Compute the per-block (1x16, along M) byte-equal arithmetic for the columnwise -// pass; emit transposed FP4 + e4m3 scale into SMEM. +// Randomized Hadamard Transform helpers (per-thread, 16-wide). Used by the +// optional col-wise RHT path (kWithRht=true) in K1 amax and K2 colwise cast; +// K1 and K2 must consume identical helper output for the encoded FP4 and +// outer SF to be self-consistent (mismatch -> saturated codes / wrong SF). + +// Apply +/-1 sign diagonal D then a 16-pt Walsh-Hadamard butterfly in place. +// Output is NOT normalized; caller multiplies by k16HadamardNorm (0.25). +// Sign-flip is a branchless XOR on the fp32 sign bit (bit-exact == r = -r on +// finite fp32, which is all this helper sees from bf16 SMEM reads). +__device__ __forceinline__ void apply_signed_fht16_inplace( + float r[16], uint32_t random_sign_mask) { +#pragma unroll + for (int i = 0; i < 16; ++i) { + const uint32_t bits = __float_as_uint(r[i]); + const uint32_t flip = ((random_sign_mask >> i) & 1u) << 31; + r[i] = __uint_as_float(bits ^ flip); + } +#pragma unroll + for (int stride = 1; stride < 16; stride <<= 1) { +#pragma unroll + for (int g = 0; g < 16; g += stride << 1) { +#pragma unroll + for (int j = 0; j < stride; ++j) { + const float a = r[g + j]; + const float b = r[g + j + stride]; + r[g + j] = a + b; + r[g + j + stride] = a - b; + } + } + } +} + +__device__ __forceinline__ float amax_16_abs(const float r[16]) { + float m = 0.f; +#pragma unroll + for (int i = 0; i < 16; ++i) m = fmaxf(m, fabsf(r[i])); + return m; +} + +// 1/sqrt(16) normalization for the 16-pt Hadamard so H*H^T = I after sign +// scaling. Applied once per block on K1 amax / K2 block_scale. +constexpr float k16HadamardNorm = 0.25f; + +// Per-block (1x16 along M) columnwise FP4 cast; writes transposed FP4 + +// e4m3 SF to SMEM. When kWithRht=true, each thread's 16-row strip is rotated +// through the FHT with random_sign_mask_t; K1 amax must use the same mask so +// sColAmax already reflects the rotated columns. +template __device__ __forceinline__ void colwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_tr_ptr, + const IType* __restrict__ sIn_ptr, + fp4e2m1x2* __restrict__ sOut_tr_ptr, nvfp4_scale_t* __restrict__ sSFcolwise_ptr, - const float* __restrict__ sColAmax, // [CHUNK_DIM_X], indexed by chunk-local col - const int stage_Y, const int stage_X, const int buff_in, const int buff_out_tr) { + const float* __restrict__ sColAmax, // [CHUNK_DIM_X], indexed by chunk-local col + const int stage_Y, const int stage_X, + const int buff_in, const int buff_out_tr, + const uint32_t random_sign_mask_t = 0u) { const auto& sIn2x = *reinterpret_cast(sIn_ptr); auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); - const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 + const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 (M-block index in tile) - const int tid_X_colwise = thread_lane; // 0..31 (col-pair index in tile) + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 (M-block index in tile) + const int tid_X_colwise = thread_lane; // 0..31 (col-pair index in tile) - const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; // 0/16/32/48 - const int thread_offset_X_colwise = tid_X_colwise * 2; // 0/2/.../62 (2 cols per thread) + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; // 0/16/32/48 + const int thread_offset_X_colwise = tid_X_colwise * 2; // 0/2/.../62 (2 cols per thread) const int in_thread_offset_Y = thread_offset_Y_colwise; - const int in_thread_offset_X = thread_offset_X_colwise / 2; // index into IType2[] + const int in_thread_offset_X = thread_offset_X_colwise / 2; // index into IType2[] - const int out_tr_thread_offset_Y = thread_offset_X_colwise; // transpose: X becomes Y - const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; // /2 for fp4e2m1x2 byte index + const int out_tr_thread_offset_Y = thread_offset_X_colwise; // transpose: X becomes Y + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; // /2 for fp4e2m1x2 byte index - const int scale_tr_offset_Y = - (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; // chunk-local col index (×1) - const int scale_tr_offset_X = - (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; // chunk-local M-block index + const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; // chunk-local col index (×1) + const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; // chunk-local M-block index __align__(8) IType rIn[2][SCALE_DIM]; - // Read 2 columns x 16 rows, accumulate per-column amax. + // RHT staging in fp32 from FHT through mul_cvt_4x: avoids the lossy + // fp32->bf16->fp32 round-trip and lets us fold the 0.25 normalization into + // block_scale. Untouched by the non-RHT instantiation (nvcc DCE). + float rRht[2][SCALE_DIM]; + + // Non-RHT path accumulates the 1x16 block amax during the load; RHT path + // recomputes it after the butterfly so we skip abs_max_2x here. IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll for (int i = 0; i < SCALE_DIM; ++i) { @@ -238,11 +286,35 @@ __device__ __forceinline__ void colwise_scaling_per_token( ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); rIn[0][i] = elt_pair.x; rIn[1][i] = elt_pair.y; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + if constexpr (!kWithRht) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + } + } + + // 1x16 block amax used to calibrate the inner FP4 scale. + float block_amax[2]; + if constexpr (kWithRht) { +#pragma unroll + for (int w = 0; w < 2; ++w) { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + rRht[w][i] = static_cast(rIn[w][i]); + } + apply_signed_fht16_inplace(rRht[w], random_sign_mask_t); + float local_max = 0.f; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + local_max = fmaxf(local_max, fabsf(rRht[w][i])); + } + // amax(|r * 0.25|) == amax(|r|) * 0.25 (exact: 0.25 = 2^-2). One + // post-amax mul instead of 16 per-element muls; matching 0.25 folded + // into block_scale_rht below. + block_amax[w] = local_max * k16HadamardNorm; + } + } else { + block_amax[0] = static_cast(__habs(thread_amax_2x.x)); + block_amax[1] = static_cast(__habs(thread_amax_2x.y)); } - // NOTE: thread_amax_2x.x is the amax of column .x; thread_amax_2x.y is amax of column .y. - const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), - static_cast(__habs(thread_amax_2x.y))}; #pragma unroll for (int w = 0; w < 2; ++w) { @@ -258,14 +330,24 @@ __device__ __forceinline__ void colwise_scaling_per_token( // Store e4m3 scale to SMEM colwise SF buffer. sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = s_dec; - // Cast 16 elements to FP4 via 4x mul_cvt_4x (4 elements per call -> 4 calls). - // The 16 rIn[w][...] values are bf16; pack into IType2 pairs. + // 4x mul_cvt_4x emits 16 FP4 codes. RHT path feeds fp32 staging so we + // skip the bf16 round-trip; block_scale_rht folds in 0.25. fp4e2m1x4 qu[4]; + if constexpr (kWithRht) { + const float block_scale_rht = block_scale * k16HadamardNorm; +#pragma unroll + for (int e = 0; e < 4; ++e) { + const ptx::floatx2 in01{rRht[w][4 * e + 0], rRht[w][4 * e + 1]}; + const ptx::floatx2 in23{rRht[w][4 * e + 2], rRht[w][4 * e + 3]}; + ptx::mul_cvt_4x(qu[e], in01, in23, block_scale_rht); + } + } else { #pragma unroll - for (int e = 0; e < 4; ++e) { - IType2 in01{rIn[w][4 * e + 0], rIn[w][4 * e + 1]}; - IType2 in23{rIn[w][4 * e + 2], rIn[w][4 * e + 3]}; - ptx::mul_cvt_4x(qu[e], in01, in23, block_scale); + for (int e = 0; e < 4; ++e) { + IType2 in01{rIn[w][4 * e + 0], rIn[w][4 * e + 1]}; + IType2 in23{rIn[w][4 * e + 2], rIn[w][4 * e + 3]}; + ptx::mul_cvt_4x(qu[e], in01, in23, block_scale); + } } // Pack 4 fp4e2m1x4 (= 16 fp4) into a 64-bit value and store to SMEM transpose buffer. @@ -280,16 +362,22 @@ __device__ __forceinline__ void colwise_scaling_per_token( // ============================================================================= // Kernel 2: per-token encode (rowwise + optional colwise transpose). +// kWithRht=true: col-wise FP4 cast over RHT-rotated strips, matching K1's +// RHT-rotated columnwise_amax. Row direction never sees RHT. // ============================================================================= -template -__global__ void __launch_bounds__(THREADS_NUM) - per_token_encode_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - const __grid_constant__ CUtensorMap tensor_map_output_t, - nvfp4_scale_t* const scales_ptr, nvfp4_scale_t* const scales_t_ptr, - const float* const row_amax_in, const float* const col_amax_in, - const float* noop, const size_t rows, const size_t cols, - const size_t scale_stride, const size_t scale_stride_t) { +template +__global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t* const scales_ptr, + nvfp4_scale_t* const scales_t_ptr, + const float* const row_amax_in, + const float* const col_amax_in, + const float* noop, + const size_t rows, const size_t cols, + const size_t scale_stride, const size_t scale_stride_t, + const uint32_t random_sign_mask_t) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -320,26 +408,24 @@ __global__ void __launch_bounds__(THREADS_NUM) constexpr int out_mem_colwise_data = DO_COL ? buff_size_aligned_out_t : 0; constexpr int out_mem_rowwise_scales = DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) - : 0; + TMA_SHMEM_ALIGNMENT) : 0; constexpr int out_mem_colwise_scales = DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) - : 0; + TMA_SHMEM_ALIGNMENT) : 0; extern __shared__ unsigned char dynamic_shmem[]; unsigned char* dshmem = align_smem_ptr_per_TMA_requirements(dynamic_shmem); - IType* sIn_ptr = reinterpret_cast(dshmem); - fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); - fp4e2m1x2* sOut_tr_ptr = - reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data); + IType* sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); + fp4e2m1x2* sOut_tr_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data); nvfp4_scale_t* sSFrowwise_ptr = reinterpret_cast( dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data); - nvfp4_scale_t* sSFcolwise_ptr = - reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data + - out_mem_colwise_data + out_mem_rowwise_scales); + nvfp4_scale_t* sSFcolwise_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data + + out_mem_rowwise_scales); // Per-CTA row/col amax SMEM cache (128 floats each). __shared__ float sRowAmax[CHUNK_DIM_Y]; @@ -393,8 +479,8 @@ __global__ void __launch_bounds__(THREADS_NUM) uint64_t* dst = reinterpret_cast(&sIn[buff_in]); const uint64_t* src = reinterpret_cast(&tensor_map_input); ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); - ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, - &IN_buff_readable_mbar[buff_in]); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + dst, src, global_offset_X, global_offset_Y, &IN_buff_readable_mbar[buff_in]); } } @@ -422,17 +508,18 @@ __global__ void __launch_bounds__(THREADS_NUM) if (leading_thread) { uint64_t* dst = reinterpret_cast(&sIn[next_prefetch_buff]); const uint64_t* src = reinterpret_cast(&tensor_map_input); - ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); - ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, next_global_offset_X, - next_global_offset_Y, - &IN_buff_readable_mbar[next_prefetch_buff]); + ptx::mbarrier_arrive_expect_tx( + &IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + dst, src, next_global_offset_X, next_global_offset_Y, + &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for current stage's input to land. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], - IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta( + &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // Wait for any prior TMA store to have finished reading the output SMEM @@ -441,12 +528,14 @@ __global__ void __launch_bounds__(THREADS_NUM) // ----- Compute: rowwise + colwise from the same SMEM tile ----- if (DO_ROW) { - rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, sRowAmax, stage_Y, stage_X, - buff_in, buff_out); + rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, + sRowAmax, stage_Y, stage_X, buff_in, buff_out); } if (DO_COL) { - colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, sColAmax, stage_Y, stage_X, - buff_in, buff_out_tr); + colwise_scaling_per_token( + sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, + sColAmax, stage_Y, stage_X, buff_in, buff_out_tr, + random_sign_mask_t); } // Fence + sync so all threads' SMEM writes are visible to TMA store. @@ -463,20 +552,22 @@ __global__ void __launch_bounds__(THREADS_NUM) if (DO_ROW) { auto& sOut = *reinterpret_cast(sOut_ptr); ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&tensor_map_output), + global_offset_X, global_offset_Y, reinterpret_cast(&sOut[buff_out])); } if (DO_COL) { auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_t), global_offset_X_tr, - global_offset_Y_tr, reinterpret_cast(&sOut_tr[buff_out_tr])); + reinterpret_cast(&tensor_map_output_t), + global_offset_X_tr, global_offset_Y_tr, + reinterpret_cast(&sOut_tr[buff_out_tr])); } ptx::cp_async_bulk_commit_group(); } - buff_in = (buff_in + 1) % BUFFS_NUM; - buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_in = (buff_in + 1) % BUFFS_NUM; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; } // end of stages @@ -492,7 +583,8 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t row_global = scales_block_offset_Y_rowwise + row; if (row_global < rows) { ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); - const size_t scale_idx_global = row_global * scale_stride + scales_block_offset_X_rowwise; + const size_t scale_idx_global = + row_global * scale_stride + scales_block_offset_X_rowwise; scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); } } @@ -507,7 +599,8 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; if (row_tr_global < cols) { ScalesVec& scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); - const size_t scale_idx_global = row_tr_global * scale_stride_t + scales_block_offset_X_tr; + const size_t scale_idx_global = + row_tr_global * scale_stride_t + scales_block_offset_X_tr; scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count); } } @@ -538,13 +631,18 @@ __global__ void __launch_bounds__(THREADS_NUM) // if t in [stage_Y*64, stage_Y*64+64): scan 64 cols of sub-tile for row t // if t in [stage_X*64, stage_X*64+64): scan 64 rows of sub-tile for col t // After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot. +// +// kWithRht=true: col-wise amax over RHT-rotated 16-row strips (per-thread +// FHT with random_sign_mask_t). Row direction never sees RHT. // ============================================================================= -template -__global__ void __launch_bounds__(THREADS_NUM) - per_token_amax_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - float* __restrict__ row_amax_out, // [M], nullptr if !DO_ROW - float* __restrict__ col_amax_out, // [K], nullptr if !DO_COL - const float* noop, const size_t rows, const size_t cols) { +template +__global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + float* __restrict__ row_amax_out, // [M], nullptr if !DO_ROW + float* __restrict__ col_amax_out, // [K], nullptr if !DO_COL + const float* noop, + const size_t rows, const size_t cols, + const uint32_t random_sign_mask_t) { // col-only; low 16 bits = signs #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -579,8 +677,8 @@ __global__ void __launch_bounds__(THREADS_NUM) // i.e., this thread contributes to row partial in stages // where stage_Y == tid / 64. // col owned: col_base + tid -> stage_X == tid / 64. - const int my_row_stage_Y = tid / TILE_DIM_Y; // 0 or 1 - const int my_col_stage_X = tid / TILE_DIM_X; // 0 or 1 + const int my_row_stage_Y = tid / TILE_DIM_Y; // 0 or 1 + const int my_col_stage_X = tid / TILE_DIM_X; // 0 or 1 const int my_row_in_subtile = tid % TILE_DIM_Y; // 0..63 const int my_col_in_subtile = tid % TILE_DIM_X; // 0..63 @@ -605,8 +703,8 @@ __global__ void __launch_bounds__(THREADS_NUM) ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[buff_in]), - reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, - &IN_buff_readable_mbar[buff_in]); + reinterpret_cast(&tensor_map_input), + global_offset_X, global_offset_Y, &IN_buff_readable_mbar[buff_in]); } } @@ -627,18 +725,20 @@ __global__ void __launch_bounds__(THREADS_NUM) const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; if (leading_thread) { - ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::mbarrier_arrive_expect_tx( + &IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[next_prefetch_buff]), - reinterpret_cast(&tensor_map_input), next_global_offset_X, - next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); + reinterpret_cast(&tensor_map_input), + next_global_offset_X, next_global_offset_Y, + &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for this stage's tile. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], - IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta( + &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // ----- Row partial update: walk this thread's row across the sub-tile ----- @@ -662,26 +762,41 @@ __global__ void __launch_bounds__(THREADS_NUM) for (int p = 0; p < 4; ++p) { ptx::abs_max_2x(amax_2x, amax_2x, pairs[p]); } - local_max = - fmaxf(local_max, static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); + local_max = fmaxf(local_max, + static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); } row_partial = local_max; } // ----- Col partial update: walk this thread's col down the sub-tile ----- if (DO_COL && stage_X == my_col_stage_X) { - // Scan 64 rows for our col. Single-column access pattern (1 byte stride - // per row in SMEM); we read 1 bf16 at a time. Bank conflicts mitigated - // by 64-wide tile (column stride = TILE_DIM_X * 2 = 128 bytes, which is - // 1 bank * 32 rows; with 32 threads on different cols, conflicts hit - // groups of 32 -> serialized 32-way, accepted for v1). - float local_max = col_partial; + if constexpr (kWithRht) { + // 4 contiguous 16-row blocks per sub-tile, one FHT per block; amax + // is taken over the rotated values. +#pragma unroll + for (int blk = 0; blk < TILE_DIM_Y / 16; ++blk) { + float r[16]; #pragma unroll - for (int e = 0; e < TILE_DIM_Y; ++e) { - const IType v = sIn[buff_in][e][my_col_in_subtile]; - local_max = fmaxf(local_max, fabsf(static_cast(v))); + for (int i = 0; i < 16; ++i) { + r[i] = static_cast(sIn[buff_in][blk * 16 + i][my_col_in_subtile]); + } + apply_signed_fht16_inplace(r, random_sign_mask_t); + col_partial = fmaxf(col_partial, amax_16_abs(r) * k16HadamardNorm); + } + } else { + // Scan 64 rows for our col. Single-column access pattern (1 byte stride + // per row in SMEM); we read 1 bf16 at a time. Bank conflicts mitigated + // by 64-wide tile (column stride = TILE_DIM_X * 2 = 128 bytes, which is + // 1 bank * 32 rows; with 32 threads on different cols, conflicts hit + // groups of 32 -> serialized 32-way, accepted for v1). + float local_max = col_partial; +#pragma unroll + for (int e = 0; e < TILE_DIM_Y; ++e) { + const IType v = sIn[buff_in][e][my_col_in_subtile]; + local_max = fmaxf(local_max, fabsf(static_cast(v))); + } + col_partial = local_max; } - col_partial = local_max; } __syncthreads(); @@ -714,9 +829,13 @@ __global__ void __launch_bounds__(THREADS_NUM) // ============================================================================= #if FP4_TYPE_SUPPORTED -// Launch Kernel 1 (amax). Writes only to output->amax / output->columnwise_amax; -// other output fields untouched. Pre-zeroes the amax buffers (atomicMax identity). -inline void launch_amax(const Tensor& input, Tensor* output, const Tensor& noop, +// Launch Kernel 1 (amax). Pre-zeroes the amax buffers (atomicMax identity). +// with_rht=true applies a 16-pt RHT on the col direction before amax; +// random_sign_mask_t carries the 16-bit sign pattern (ignored when false). +inline void launch_amax(const Tensor& input, Tensor* output, + const Tensor& noop, + const bool with_rht, + const uint32_t random_sign_mask_t, cudaStream_t stream) { const size_t M = input.flat_first_dim(); const size_t K = input.flat_last_dim(); @@ -727,49 +846,63 @@ inline void launch_amax(const Tensor& input, Tensor* output, const Tensor& noop, // Pre-zero amax buffers (atomicMaxFloat identity for non-negative values). if (do_row) { - NVTE_CHECK(output->amax.numel() == M, "Per-token amax: output->amax numel must equal M = ", M, + NVTE_CHECK(output->amax.numel() == M, + "Per-token amax: output->amax numel must equal M = ", M, ", got ", output->amax.numel()); NVTE_CHECK_CUDA(cudaMemsetAsync(output->amax.dptr, 0, M * sizeof(float), stream)); } if (do_col) { NVTE_CHECK(output->columnwise_amax.numel() == K, - "Per-token amax: output->columnwise_amax numel must equal K = ", K, ", got ", - output->columnwise_amax.numel()); + "Per-token amax: output->columnwise_amax numel must equal K = ", K, + ", got ", output->columnwise_amax.numel()); NVTE_CHECK_CUDA(cudaMemsetAsync(output->columnwise_amax.dptr, 0, K * sizeof(float), stream)); } checkCuDriverContext(stream); alignas(64) CUtensorMap tmap_in{}; - create_2D_tensor_map(tmap_in, input.data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); + create_2D_tensor_map(tmap_in, input.data, M, K, + TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; constexpr int buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; // + align pad - dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(M / CHUNK_DIM_Y), 1); + dim3 grid(static_cast(K / CHUNK_DIM_X), + static_cast(M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); - const float* noop_ptr = - (noop.data.dptr != nullptr) ? reinterpret_cast(noop.data.dptr) : nullptr; - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - do_row, DO_ROW, TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { - auto kernel = per_token_amax_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tmap_in, do_row ? reinterpret_cast(output->amax.dptr) : nullptr, - do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr, noop_ptr, M, - K); - });); + const float* noop_ptr = (noop.data.dptr != nullptr) + ? reinterpret_cast(noop.data.dptr) + : nullptr; + + // RHT only matters when colwise amax is computed; collapse to the + // kWithRht=false instantiation otherwise. + const bool with_rht_effective = with_rht && do_col; + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, + TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { + auto kernel = per_token_amax_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tmap_in, + do_row ? reinterpret_cast(output->amax.dptr) : nullptr, + do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr, + noop_ptr, M, K, random_sign_mask_t); + }))); NVTE_CHECK_CUDA(cudaGetLastError()); } // Launch Kernel 2 (encode). Requires output->amax / columnwise_amax to be pre-filled // (by a prior launch_amax call or by an external caller); writes // output->data / scale_inv / columnwise_data / columnwise_scale_inv. -inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noop, +// with_rht=true requires K1 amax to have been launched with the SAME mask; +// the composite per_token_quantize path threads this automatically. +inline void launch_encode(const Tensor& input, Tensor* output, + const Tensor& noop, + const bool with_rht, + const uint32_t random_sign_mask_t, cudaStream_t stream) { const size_t M = input.flat_first_dim(); const size_t K = input.flat_last_dim(); @@ -801,13 +934,15 @@ inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noo alignas(64) CUtensorMap tmap_out{}; alignas(64) CUtensorMap tmap_out_t{}; - create_2D_tensor_map(tmap_in, input.data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); + create_2D_tensor_map(tmap_in, input.data, M, K, + TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); if (do_row) { - create_2D_tensor_map(tmap_out, output->data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, 4); + create_2D_tensor_map(tmap_out, output->data, M, K, + TILE_DIM_Y, TILE_DIM_X, K, 0, 4); } if (do_col) { - create_2D_tensor_map(tmap_out_t, output->columnwise_data, K, M, TILE_DIM_X, TILE_DIM_Y, M, 0, - 4); + create_2D_tensor_map(tmap_out_t, output->columnwise_data, K, M, + TILE_DIM_X, TILE_DIM_Y, M, 0, 4); } constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; @@ -817,21 +952,29 @@ inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noo DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); constexpr int buff_size_aligned_out_t = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_scales = DIVUP_TO_MULTIPLE( - CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_scales_t = DIVUP_TO_MULTIPLE( - CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales = + DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_t = + DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT); // Total dyn SMEM: input + output FP4 (row + col) + SF (row + col) + 128B align. - const int dshmem_size = buff_size_aligned_in + (do_row ? buff_size_aligned_out : 0) + - (do_col ? buff_size_aligned_out_t : 0) + (do_row ? buff_size_scales : 0) + - (do_col ? buff_size_scales_t : 0) + TMA_SHMEM_ALIGNMENT; - - dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(M / CHUNK_DIM_Y), 1); + const int dshmem_size = + buff_size_aligned_in + + (do_row ? buff_size_aligned_out : 0) + + (do_col ? buff_size_aligned_out_t : 0) + + (do_row ? buff_size_scales : 0) + + (do_col ? buff_size_scales_t : 0) + + TMA_SHMEM_ALIGNMENT; + + dim3 grid(static_cast(K / CHUNK_DIM_X), + static_cast(M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); - const float* noop_ptr = - (noop.data.dptr != nullptr) ? reinterpret_cast(noop.data.dptr) : nullptr; + const float* noop_ptr = (noop.data.dptr != nullptr) + ? reinterpret_cast(noop.data.dptr) + : nullptr; const size_t scale_stride = do_row ? output->scale_inv.shape[1] : 0; const size_t scale_stride_t = do_col ? output->columnwise_scale_inv.shape[1] : 0; @@ -839,18 +982,26 @@ inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noo do_row ? reinterpret_cast(output->scale_inv.dptr) : nullptr; nvfp4_scale_t* scales_t_ptr = do_col ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const float* row_amax_in = do_row ? reinterpret_cast(output->amax.dptr) : nullptr; + const float* row_amax_in = + do_row ? reinterpret_cast(output->amax.dptr) : nullptr; const float* col_amax_in = do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr; - TRANSFORMER_ENGINE_SWITCH_CONDITION( - do_row, DO_ROW, TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { - auto kernel = per_token_encode_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>(tmap_in, tmap_out, tmap_out_t, scales_ptr, - scales_t_ptr, row_amax_in, col_amax_in, - noop_ptr, M, K, scale_stride, scale_stride_t); - });); + // RHT only matters when colwise FP4 is produced; collapse to the + // kWithRht=false instantiation for rowwise-only callers. + const bool with_rht_effective = with_rht && do_col; + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, + TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { + auto kernel = per_token_encode_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tmap_in, tmap_out, tmap_out_t, + scales_ptr, scales_t_ptr, + row_amax_in, col_amax_in, + noop_ptr, M, K, scale_stride, scale_stride_t, + random_sign_mask_t); + }))); NVTE_CHECK_CUDA(cudaGetLastError()); } #endif // FP4_TYPE_SUPPORTED @@ -866,14 +1017,15 @@ inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noo // Output constraints differ by entry point (see validate_*_output helpers below). inline void validate_input_shape(const Tensor& input) { NVTE_CHECK(input.has_data(), "Per-token cast: input has no data."); - NVTE_CHECK(input.dtype() == DType::kBFloat16, "Per-token cast is bf16-only. Got dtype enum ", + NVTE_CHECK(input.dtype() == DType::kBFloat16, + "Per-token cast is bf16-only. Got dtype enum ", static_cast(input.dtype())); const size_t M = input.flat_first_dim(); const size_t K = input.flat_last_dim(); - NVTE_CHECK(M % CHUNK_DIM_Y == 0, "Per-token cast: M must be a multiple of ", CHUNK_DIM_Y, - ", got M=", M); - NVTE_CHECK(K % CHUNK_DIM_X == 0, "Per-token cast: K must be a multiple of ", CHUNK_DIM_X, - ", got K=", K); + NVTE_CHECK(M % CHUNK_DIM_Y == 0, + "Per-token cast: M must be a multiple of ", CHUNK_DIM_Y, ", got M=", M); + NVTE_CHECK(K % CHUNK_DIM_X == 0, + "Per-token cast: K must be a multiple of ", CHUNK_DIM_X, ", got K=", K); } // K1 (amax-only) requires at least one amax buffer allocated; FP4 output is not used. @@ -892,42 +1044,61 @@ inline void validate_encode_output(const Tensor* output) { "Per-token cast emits compact (non-swizzled) inner SF."); } -void per_token_amax_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, +// K1 amax with optional col-wise RHT. with_rht=false is byte-equal to the +// pre-RHT per-token K1 path regardless of random_sign_mask_t. +void per_token_amax_blocked_impl(const Tensor& input, const Tensor& noop, + Tensor* output, + const bool with_rht, + const uint32_t random_sign_mask_t, cudaStream_t stream) { validate_input_shape(input); validate_amax_output(output); if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; - launch_amax(input, output, noop, stream); + launch_amax(input, output, noop, with_rht, random_sign_mask_t, stream); } -void per_token_encode_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, +// K2 encode with optional col-wise RHT. Caller must have filled +// output->columnwise_amax via K1 amax with the SAME with_rht/mask, else the +// inner SF + FP4 codes are calibrated against mismatched data and saturate. +void per_token_encode_blocked_impl(const Tensor& input, const Tensor& noop, + Tensor* output, + const bool with_rht, + const uint32_t random_sign_mask_t, cudaStream_t stream) { validate_input_shape(input); validate_encode_output(output); if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; - launch_encode(input, output, noop, stream); + launch_encode(input, output, noop, with_rht, random_sign_mask_t, stream); } -void per_token_quantize_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, +// Composite K1+K2. Both launches receive the same with_rht / mask so the +// colwise amax and FP4 cast see byte-identical data. +void per_token_quantize_blocked_impl(const Tensor& input, const Tensor& noop, + Tensor* output, + const bool with_rht, + const uint32_t random_sign_mask_t, cudaStream_t stream) { validate_input_shape(input); validate_encode_output(output); if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; - launch_amax(input, output, noop, stream); - launch_encode(input, output, noop, stream); + launch_amax(input, output, noop, with_rht, random_sign_mask_t, stream); + launch_encode(input, output, noop, with_rht, random_sign_mask_t, stream); } bool can_use_per_token(size_t M, size_t K, DType dtype) { return (dtype == DType::kBFloat16) && (M % CHUNK_DIM_Y == 0) && (K % CHUNK_DIM_X == 0); } #else // !FP4_TYPE_SUPPORTED -void per_token_amax_blocked_impl(const Tensor&, const Tensor&, Tensor*, cudaStream_t) { +void per_token_amax_blocked_impl(const Tensor&, const Tensor&, Tensor*, + bool, uint32_t, cudaStream_t) { NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); } -void per_token_encode_blocked_impl(const Tensor&, const Tensor&, Tensor*, cudaStream_t) { +void per_token_encode_blocked_impl(const Tensor&, const Tensor&, Tensor*, + bool, uint32_t, cudaStream_t) { NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); } -void per_token_quantize_blocked_impl(const Tensor&, const Tensor&, Tensor*, cudaStream_t) { +void per_token_quantize_blocked_impl(const Tensor&, const Tensor&, Tensor*, + bool, uint32_t, cudaStream_t) { NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); } bool can_use_per_token(size_t, size_t, DType) { return false; } @@ -940,7 +1111,10 @@ bool can_use_per_token(size_t, size_t, DType) { return false; } // C-API entry points // ============================================================================= -void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NVTETensor output, +void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, + NVTETensor output, + const int with_rht, + const int random_sign_mask_t, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_amax); @@ -949,18 +1123,25 @@ void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NV Tensor* output_tensor = convertNVTETensorCheck(output); Tensor dummy_noop; const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; - nvfp4_per_token::per_token_amax_blocked_impl(*input_tensor, *noop_tensor, output_tensor, stream); + // C-API takes `int` to match prod's nvte_hadamard_transform_amax convention; + // internally we treat the low 16 bits as a uint32_t bitmask. + nvfp4_per_token::per_token_amax_blocked_impl( + *input_tensor, *noop_tensor, output_tensor, + with_rht != 0, + static_cast(random_sign_mask_t) & 0xFFFFu, + stream); #else - (void)input; - (void)noop; - (void)output; - (void)stream; + (void)input; (void)noop; (void)output; (void)with_rht; + (void)random_sign_mask_t; (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif } -void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output, - cudaStream_t stream) { +void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, + NVTETensor output, + const int with_rht, + const int random_sign_mask_t, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_encode); using namespace transformer_engine; @@ -968,19 +1149,25 @@ void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, Tensor* output_tensor = convertNVTETensorCheck(output); Tensor dummy_noop; const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; - nvfp4_per_token::per_token_encode_blocked_impl(*input_tensor, *noop_tensor, output_tensor, - stream); + // C-API mirrors nvte_nvfp4_per_token_amax: `int` for cross-language ABI + // safety, internal kernel arg is uint32_t with only the low 16 bits used. + nvfp4_per_token::per_token_encode_blocked_impl( + *input_tensor, *noop_tensor, output_tensor, + with_rht != 0, + static_cast(random_sign_mask_t) & 0xFFFFu, + stream); #else - (void)input; - (void)noop; - (void)output; - (void)stream; + (void)input; (void)noop; (void)output; (void)with_rht; + (void)random_sign_mask_t; (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif } -void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output, - cudaStream_t stream) { +void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, + NVTETensor output, + const int with_rht, + const int random_sign_mask_t, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_quantize); using namespace transformer_engine; @@ -988,13 +1175,14 @@ void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop Tensor* output_tensor = convertNVTETensorCheck(output); Tensor dummy_noop; const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; - nvfp4_per_token::per_token_quantize_blocked_impl(*input_tensor, *noop_tensor, output_tensor, - stream); + nvfp4_per_token::per_token_quantize_blocked_impl( + *input_tensor, *noop_tensor, output_tensor, + with_rht != 0, + static_cast(random_sign_mask_t) & 0xFFFFu, + stream); #else - (void)input; - (void)noop; - (void)output; - (void)stream; + (void)input; (void)noop; (void)output; (void)with_rht; + (void)random_sign_mask_t; (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif } diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu index 69d00fb139..9eb049443b 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu @@ -30,9 +30,9 @@ namespace nvfp4_per_token_group { #if FP4_TYPE_SUPPORTED -using dispatch::nvfp4::nvfp4_scale_t; using dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; using dispatch::nvfp4::quantization_SF::compute_decoding_scaling_factor; +using dispatch::nvfp4::nvfp4_scale_t; using ptx::FPx2; constexpr int kInnerK = 16; // NVFP4 inner block: 16 elements per e4m3 SF @@ -43,14 +43,14 @@ constexpr int kMaxTensorsPerKernel = 64; // Per-launch arg table; passed as __grid_constant__ for constant-cache reads. struct NVFP4PerTokenMultiArgs { // K1 outputs (per-tensor pointers; one fp32 array per tensor) - void* row_amax_list[kMaxTensorsPerKernel]; // each: float* (M_i,) - void* col_amax_list[kMaxTensorsPerKernel]; // each: float* (K,) + void* row_amax_list[kMaxTensorsPerKernel]; // each: float* (M_i,) + void* col_amax_list[kMaxTensorsPerKernel]; // each: float* (K,) // K2 outputs (per-tensor pointers; FP4 codes + e4m3 inner SF) - void* q_row_list[kMaxTensorsPerKernel]; // each: uint8* (M_i, K/2) - void* s_dec_row_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (M_i, K/16) - void* q_col_list[kMaxTensorsPerKernel]; // each: uint8* (K, M_i/2) - void* s_dec_col_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (K, M_i/16) + void* q_row_list[kMaxTensorsPerKernel]; // each: uint8* (M_i, K/2) + void* s_dec_row_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (M_i, K/16) + void* q_col_list[kMaxTensorsPerKernel]; // each: uint8* (K, M_i/2) + void* s_dec_col_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (K, M_i/16) // Shared layout info int split_sections_range[kMaxTensorsPerKernel + 1]; // prefix sum w/ leading 0 @@ -69,10 +69,10 @@ __device__ __forceinline__ int GetTensorId(const NVFP4PerTokenMultiArgs& args, i // per-tensor buffer via tensor_id lookup at CTA entry. namespace fused { -constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows -constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols -constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height -constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width +constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows +constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols +constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height +constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width constexpr int THREADS_NUM = 128; constexpr int PREFETCH_STAGES = 1; constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; @@ -88,16 +88,54 @@ using FusedIType = bf16; using FusedIType2 = ptx::FPx2; using FusedIType3D = FusedIType[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +// Randomized Hadamard Transform helpers (per-thread, 16-wide). Direct copy +// of the single-tensor helpers in quantize_nvfp4_per_token.cu; K1 and K2 +// must consume identical output for FP4 + outer SF to be self-consistent. +// TODO: hoist into a shared core header. +__device__ __forceinline__ void apply_signed_fht16_inplace( + float r[16], uint32_t random_sign_mask) { +#pragma unroll + for (int i = 0; i < 16; ++i) { + const uint32_t bits = __float_as_uint(r[i]); + const uint32_t flip = ((random_sign_mask >> i) & 1u) << 31; + r[i] = __uint_as_float(bits ^ flip); + } +#pragma unroll + for (int stride = 1; stride < 16; stride <<= 1) { +#pragma unroll + for (int g = 0; g < 16; g += stride << 1) { +#pragma unroll + for (int j = 0; j < stride; ++j) { + const float a = r[g + j]; + const float b = r[g + j + stride]; + r[g + j] = a + b; + r[g + j + stride] = a - b; + } + } + } +} + +__device__ __forceinline__ float amax_16_abs(const float r[16]) { + float m = 0.f; +#pragma unroll + for (int i = 0; i < 16; ++i) m = fmaxf(m, fabsf(r[i])); + return m; +} + +// 1/sqrt(16) Hadamard normalization, folded once per 1x16 block. +constexpr float k16HadamardNorm = 0.25f; + // Pre-zero amax buffers (identity for atomicMax). template -__global__ void group_per_token_fused_zero_amax_kernel(NVFP4PerTokenMultiArgs args, int K) { +__global__ void group_per_token_fused_zero_amax_kernel(NVFP4PerTokenMultiArgs args, + int K) { const int tensor_id = blockIdx.x; if (tensor_id >= args.num_tensors) return; if (DO_ROW) { float* row_amax = reinterpret_cast(args.row_amax_list[tensor_id]); if (row_amax != nullptr) { - const int M_i = - args.split_sections_range[tensor_id + 1] - args.split_sections_range[tensor_id]; + const int M_i = args.split_sections_range[tensor_id + 1] - + args.split_sections_range[tensor_id]; for (int m = threadIdx.x; m < M_i; m += blockDim.x) { row_amax[m] = 0.0f; } @@ -113,11 +151,15 @@ __global__ void group_per_token_fused_zero_amax_kernel(NVFP4PerTokenMultiArgs ar } } -template +// kWithRht=true: col-wise amax over RHT-rotated 16-row strips. Row direction +// never sees RHT. +template __global__ void __launch_bounds__(THREADS_NUM) group_per_token_fused_amax_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ NVFP4PerTokenMultiArgs args, - const float* noop, const size_t rows, const size_t cols) { + const float* noop, const size_t rows, + const size_t cols, + const uint32_t random_sign_mask_t) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -131,7 +173,8 @@ __global__ void __launch_bounds__(THREADS_NUM) DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); extern __shared__ unsigned char dynamic_shmem[]; - unsigned char* dshmem = dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + unsigned char* dshmem = + dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); FusedIType* sIn_ptr = reinterpret_cast(dshmem); auto& sIn = *reinterpret_cast(sIn_ptr); @@ -146,8 +189,10 @@ __global__ void __launch_bounds__(THREADS_NUM) // Tile lies fully inside one tensor (split_sections[i] % 128 == 0). const int tensor_id = GetTensorId(args, block_offset_Y); const int local_row_base = block_offset_Y - args.split_sections_range[tensor_id]; - float* row_amax_out = DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; - float* col_amax_out = DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; + float* row_amax_out = + DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; + float* col_amax_out = + DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; // Each thread owns chunk-row `tid` (for row amax) and chunk-col `tid` (for col amax). float row_partial = 0.f; @@ -178,8 +223,8 @@ __global__ void __launch_bounds__(THREADS_NUM) ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[buff_in]), - reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, - &IN_buff_readable_mbar[buff_in]); + reinterpret_cast(&tensor_map_input), global_offset_X, + global_offset_Y, &IN_buff_readable_mbar[buff_in]); } } @@ -200,18 +245,20 @@ __global__ void __launch_bounds__(THREADS_NUM) const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; if (leading_thread) { - ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], + shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[next_prefetch_buff]), - reinterpret_cast(&tensor_map_input), next_global_offset_X, - next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); + reinterpret_cast(&tensor_map_input), + next_global_offset_X, next_global_offset_Y, + &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for this stage's tile. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], - IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta( + &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // Row partial: rotate e-iter by bank group to split warp into 8 groups. @@ -228,21 +275,36 @@ __global__ void __launch_bounds__(THREADS_NUM) for (int p = 0; p < 4; ++p) { ptx::abs_max_2x(amax_2x, amax_2x, pairs[p]); } - local_max = - fmaxf(local_max, static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); + local_max = fmaxf(local_max, static_cast( + __hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); } row_partial = local_max; } // Col partial: 1 thread per column scans down 64 rows of the sub-tile. if (DO_COL && stage_X == my_col_stage_X) { - float local_max = col_partial; + if constexpr (kWithRht) { + // 4 contiguous 16-row blocks per sub-tile, one FHT per block; 0.25 + // is folded post-amax (exact, since 0.25 = 2^-2). +#pragma unroll + for (int blk = 0; blk < TILE_DIM_Y / 16; ++blk) { + float r[16]; #pragma unroll - for (int e = 0; e < TILE_DIM_Y; ++e) { - const FusedIType v = sIn[buff_in][e][my_col_in_subtile]; - local_max = fmaxf(local_max, fabsf(static_cast(v))); + for (int i = 0; i < 16; ++i) { + r[i] = static_cast(sIn[buff_in][blk * 16 + i][my_col_in_subtile]); + } + apply_signed_fht16_inplace(r, random_sign_mask_t); + col_partial = fmaxf(col_partial, amax_16_abs(r) * k16HadamardNorm); + } + } else { + float local_max = col_partial; +#pragma unroll + for (int e = 0; e < TILE_DIM_Y; ++e) { + const FusedIType v = sIn[buff_in][e][my_col_in_subtile]; + local_max = fmaxf(local_max, fabsf(static_cast(v))); + } + col_partial = local_max; } - col_partial = local_max; } __syncthreads(); @@ -269,59 +331,59 @@ __global__ void __launch_bounds__(THREADS_NUM) (void)noop; (void)rows; (void)cols; + (void)random_sign_mask_t; NVTE_DEVICE_ERROR("Fused grouped per-token amax kernel requires SM 10.0+ (Blackwell)."); #endif // __CUDA_ARCH__ >= 1000 } // K2 (encode) constants + helpers; byte-equal port of the single-tensor // per-token cooperative 4x32 / 32x4 threading + ld_shared_b128 + mul_cvt_4x. -constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM -constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) +constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM +constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; // 8 constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; // 8 -constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 -constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 // Rowwise pass: 4 (K-dim) x 32 (M-dim) -> 1 NVFP4 block per thread. -constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 -constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 -constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 -constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 - -// Colwise pass: tid.X = col-pair, warp = M-block (32 x 4). -constexpr int THREADS_X_TR = TILE_DIM_X / 2; // 32 -constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; // 4 +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 +constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 // Output / SF SMEM buffer dims (sub-tile sized, double-buffered for ping-pong). -constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; -constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (fp4e2m1x2 bytes) -constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; +constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (fp4e2m1x2 bytes) +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; constexpr int BUFF_OUT_TR_DIM_Y = TILE_DIM_X; -constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 -constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; -constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 -constexpr int BUFFS_NUM_OUT_TR = 2; +constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 +constexpr int BUFFS_NUM_OUT_TR = 2; // Manual SMEM swizzling parameters (matches single-tensor encode kernel). -constexpr int PACK_SIZE = 8; -constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 -constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 -constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 - -using IType = FusedIType; -using IType2 = FusedIType2; -using IType2x3D = IType2[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; -using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 + +using IType = FusedIType; +using IType2 = FusedIType2; +using IType2x3D = IType2 [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; -using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; // Rowwise encode helper: reads sRowAmax (pre-populated by K1), writes FP4 + // e4m3 SFs into sOut / sSFrowwise. Byte-equal to the single-tensor version. __device__ __forceinline__ void rowwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_ptr, - nvfp4_scale_t* __restrict__ sSFrowwise_ptr, const float* __restrict__ sRowAmax, - const int stage_Y, const int stage_X, const int buff_in, const int buff_out) { + const IType* __restrict__ sIn_ptr, + fp4e2m1x2* __restrict__ sOut_ptr, + nvfp4_scale_t* __restrict__ sSFrowwise_ptr, + const float* __restrict__ sRowAmax, + const int stage_Y, const int stage_X, + const int buff_in, const int buff_out) { const auto& sIn = *reinterpret_cast(sIn_ptr); auto& sOut = *reinterpret_cast(sOut_ptr); auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); @@ -329,8 +391,8 @@ __device__ __forceinline__ void rowwise_scaling_per_token( const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; - const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 - const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; @@ -363,8 +425,8 @@ __device__ __forceinline__ void rowwise_scaling_per_token( ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); } } - const float block_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + const float block_amax = static_cast( + __hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax, S_enc); const float s_dec_f = static_cast(s_dec); @@ -392,20 +454,27 @@ __device__ __forceinline__ void rowwise_scaling_per_token( } } -// Colwise encode helper. Byte-equal to the single-tensor version. +// Colwise encode helper. kWithRht=true rotates each thread's 16-row strip +// via the FHT before block_amax + cast; K1 amax must have used the same +// mask so the per-col outer amax matches. +template __device__ __forceinline__ void colwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_tr_ptr, - nvfp4_scale_t* __restrict__ sSFcolwise_ptr, const float* __restrict__ sColAmax, - const int stage_Y, const int stage_X, const int buff_in, const int buff_out_tr) { + const IType* __restrict__ sIn_ptr, + fp4e2m1x2* __restrict__ sOut_tr_ptr, + nvfp4_scale_t* __restrict__ sSFcolwise_ptr, + const float* __restrict__ sColAmax, + const int stage_Y, const int stage_X, + const int buff_in, const int buff_out_tr, + const uint32_t random_sign_mask_t = 0u) { const auto& sIn2x = *reinterpret_cast(sIn_ptr); auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); - const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 + const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 - const int tid_X_colwise = thread_lane; // 0..31 + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 + const int tid_X_colwise = thread_lane; // 0..31 const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; const int thread_offset_X_colwise = tid_X_colwise * 2; @@ -420,6 +489,9 @@ __device__ __forceinline__ void colwise_scaling_per_token( const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; __align__(8) IType rIn[2][SCALE_DIM]; + // RHT staging in fp32 (DCE'd in the non-RHT instantiation). + float rRht[2][SCALE_DIM]; + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll for (int i = 0; i < SCALE_DIM; ++i) { @@ -427,10 +499,33 @@ __device__ __forceinline__ void colwise_scaling_per_token( ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); rIn[0][i] = elt_pair.x; rIn[1][i] = elt_pair.y; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + if constexpr (!kWithRht) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + } + } + + float block_amax[2]; + if constexpr (kWithRht) { +#pragma unroll + for (int w = 0; w < 2; ++w) { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + rRht[w][i] = static_cast(rIn[w][i]); + } + apply_signed_fht16_inplace(rRht[w], random_sign_mask_t); + float local_max = 0.f; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + local_max = fmaxf(local_max, fabsf(rRht[w][i])); + } + // amax(|r * 0.25|) == amax(|r|) * 0.25; 0.25 also folded into + // block_scale_rht below (bit-exact: 0.25 = 2^-2). + block_amax[w] = local_max * k16HadamardNorm; + } + } else { + block_amax[0] = static_cast(__habs(thread_amax_2x.x)); + block_amax[1] = static_cast(__habs(thread_amax_2x.y)); } - const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), - static_cast(__habs(thread_amax_2x.y))}; #pragma unroll for (int w = 0; w < 2; ++w) { @@ -445,11 +540,22 @@ __device__ __forceinline__ void colwise_scaling_per_token( sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = s_dec; fp4e2m1x4 qu[4]; + if constexpr (kWithRht) { + // ptx::floatx2 keeps mul_cvt_4x's input fp32 (no bf16 round-trip). + const float block_scale_rht = block_scale * k16HadamardNorm; #pragma unroll - for (int e = 0; e < 4; ++e) { - IType2 in01{rIn[w][4 * e + 0], rIn[w][4 * e + 1]}; - IType2 in23{rIn[w][4 * e + 2], rIn[w][4 * e + 3]}; - ptx::mul_cvt_4x(qu[e], in01, in23, block_scale); + for (int e = 0; e < 4; ++e) { + const ptx::floatx2 in01{rRht[w][4 * e + 0], rRht[w][4 * e + 1]}; + const ptx::floatx2 in23{rRht[w][4 * e + 2], rRht[w][4 * e + 3]}; + ptx::mul_cvt_4x(qu[e], in01, in23, block_scale_rht); + } + } else { +#pragma unroll + for (int e = 0; e < 4; ++e) { + IType2 in01{rIn[w][4 * e + 0], rIn[w][4 * e + 1]}; + IType2 in23{rIn[w][4 * e + 2], rIn[w][4 * e + 3]}; + ptx::mul_cvt_4x(qu[e], in01, in23, block_scale); + } } uint64_t out_pack_16x = (static_cast(*reinterpret_cast(&qu[0])) << 0) | @@ -463,11 +569,15 @@ __device__ __forceinline__ void colwise_scaling_per_token( // Fused K2: TMA-loads input, runs cooperative row+col encode helpers, scatters // FP4 + SFs to per-tensor outputs via st.global (multi-dest, no TMA store). -template +// kWithRht=true (and DO_COL=true): col-wise FHT with random_sign_mask_t, +// matching the K1 amax launch. Row direction never sees RHT. +template __global__ void __launch_bounds__(THREADS_NUM) group_per_token_fused_cast_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ NVFP4PerTokenMultiArgs args, - const float* noop, const size_t rows, const size_t cols) { + const float* noop, const size_t rows, + const size_t cols, + const uint32_t random_sign_mask_t) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -489,26 +599,25 @@ __global__ void __launch_bounds__(THREADS_NUM) constexpr int out_mem_colwise_data = DO_COL ? buff_size_aligned_out_t : 0; constexpr int out_mem_rowwise_scales = DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) - : 0; + TMA_SHMEM_ALIGNMENT) : 0; constexpr int out_mem_colwise_scales = DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) - : 0; + TMA_SHMEM_ALIGNMENT) : 0; (void)out_mem_colwise_scales; extern __shared__ unsigned char dynamic_shmem[]; - unsigned char* dshmem = dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + unsigned char* dshmem = + dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); - IType* sIn_ptr = reinterpret_cast(dshmem); - fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); - fp4e2m1x2* sOut_tr_ptr = - reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data); + IType* sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); + fp4e2m1x2* sOut_tr_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data); nvfp4_scale_t* sSFrowwise_ptr = reinterpret_cast( dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data); - nvfp4_scale_t* sSFcolwise_ptr = - reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data + - out_mem_colwise_data + out_mem_rowwise_scales); + nvfp4_scale_t* sSFcolwise_ptr = reinterpret_cast( + dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data + + out_mem_rowwise_scales); __shared__ float sRowAmax[CHUNK_DIM_Y]; __shared__ float sColAmax[CHUNK_DIM_X]; @@ -526,21 +635,22 @@ __global__ void __launch_bounds__(THREADS_NUM) // Chunk Y stays inside one tensor (split_sections[i] % 128 == 0). const int tensor_id = GetTensorId(args, block_offset_Y); const int local_row_base = block_offset_Y - args.split_sections_range[tensor_id]; - const int M_t = args.split_sections_range[tensor_id + 1] - args.split_sections_range[tensor_id]; + const int M_t = args.split_sections_range[tensor_id + 1] - + args.split_sections_range[tensor_id]; // Per-tensor output bases (one constant-cache lookup per CTA). - uint8_t* const q_row_base = - DO_ROW ? reinterpret_cast(args.q_row_list[tensor_id]) : nullptr; - uint8_t* const q_col_base = - DO_COL ? reinterpret_cast(args.q_col_list[tensor_id]) : nullptr; - nvfp4_scale_t* const s_dec_row_base = - DO_ROW ? reinterpret_cast(args.s_dec_row_list[tensor_id]) : nullptr; - nvfp4_scale_t* const s_dec_col_base = - DO_COL ? reinterpret_cast(args.s_dec_col_list[tensor_id]) : nullptr; - const float* const row_amax_base = - DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; - const float* const col_amax_base = - DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; + uint8_t* const q_row_base = DO_ROW + ? reinterpret_cast(args.q_row_list[tensor_id]) : nullptr; + uint8_t* const q_col_base = DO_COL + ? reinterpret_cast(args.q_col_list[tensor_id]) : nullptr; + nvfp4_scale_t* const s_dec_row_base = DO_ROW + ? reinterpret_cast(args.s_dec_row_list[tensor_id]) : nullptr; + nvfp4_scale_t* const s_dec_col_base = DO_COL + ? reinterpret_cast(args.s_dec_col_list[tensor_id]) : nullptr; + const float* const row_amax_base = DO_ROW + ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; + const float* const col_amax_base = DO_COL + ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; const size_t data_stride_row = static_cast(cols) / 2; const size_t data_stride_col = static_cast(M_t) / 2; @@ -576,13 +686,13 @@ __global__ void __launch_bounds__(THREADS_NUM) ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in_p], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[buff_in_p]), - reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, - &IN_buff_readable_mbar[buff_in_p]); + reinterpret_cast(&tensor_map_input), global_offset_X, + global_offset_Y, &IN_buff_readable_mbar[buff_in_p]); } } - int buff_in = 0; - int buff_out = 0; + int buff_in = 0; + int buff_out = 0; int buff_out_tr = 0; int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; @@ -599,28 +709,31 @@ __global__ void __launch_bounds__(THREADS_NUM) const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; if (leading_thread) { - ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], + shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[next_prefetch_buff]), - reinterpret_cast(&tensor_map_input), next_global_offset_X, - next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); + reinterpret_cast(&tensor_map_input), + next_global_offset_X, next_global_offset_Y, + &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for current stage's input tile to land. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], - IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta( + &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // 4x32 cooperative row + col encode helpers. if (DO_ROW) { - rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, sRowAmax, stage_Y, stage_X, - buff_in, buff_out); + rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, + sRowAmax, stage_Y, stage_X, buff_in, buff_out); } if (DO_COL) { - colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, sColAmax, stage_Y, stage_X, - buff_in, buff_out_tr); + colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, + sColAmax, stage_Y, stage_X, buff_in, buff_out_tr, + random_sign_mask_t); } // Make helper SMEM writes visible before the scatter epilogue. @@ -631,10 +744,13 @@ __global__ void __launch_bounds__(THREADS_NUM) if (DO_ROW) { auto& sOut = *reinterpret_cast(sOut_ptr); const int row_in_subtile = static_cast(threadIdx.x) >> 1; // 0..63 - const int half = static_cast(threadIdx.x) & 1; // 0..1 - const int local_row = local_row_base + stage_Y * TILE_DIM_Y + row_in_subtile; - const int byte_off_X = (block_offset_X / 2) + stage_X * (TILE_DIM_X / 2) + half * 16; - const uint4* src = reinterpret_cast(&sOut[buff_out][row_in_subtile][half * 16]); + const int half = static_cast(threadIdx.x) & 1; // 0..1 + const int local_row = local_row_base + stage_Y * TILE_DIM_Y + row_in_subtile; + const int byte_off_X = (block_offset_X / 2) + + stage_X * (TILE_DIM_X / 2) + + half * 16; + const uint4* src = reinterpret_cast( + &sOut[buff_out][row_in_subtile][half * 16]); uint4* dst = reinterpret_cast( q_row_base + static_cast(local_row) * data_stride_row + byte_off_X); *dst = *src; @@ -642,11 +758,13 @@ __global__ void __launch_bounds__(THREADS_NUM) if (DO_COL) { auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); const int col_in_subtile = static_cast(threadIdx.x) >> 1; // 0..63 - const int half = static_cast(threadIdx.x) & 1; // 0..1 - const int global_col = block_offset_X + stage_X * TILE_DIM_X + col_in_subtile; - const int byte_off_M = (local_row_base / 2) + stage_Y * (TILE_DIM_Y / 2) + half * 16; - const uint4* src = - reinterpret_cast(&sOut_tr[buff_out_tr][col_in_subtile][half * 16]); + const int half = static_cast(threadIdx.x) & 1; // 0..1 + const int global_col = block_offset_X + stage_X * TILE_DIM_X + col_in_subtile; + const int byte_off_M = (local_row_base / 2) + + stage_Y * (TILE_DIM_Y / 2) + + half * 16; + const uint4* src = reinterpret_cast( + &sOut_tr[buff_out_tr][col_in_subtile][half * 16]); uint4* dst = reinterpret_cast( q_col_base + static_cast(global_col) * data_stride_col + byte_off_M); *dst = *src; @@ -655,8 +773,8 @@ __global__ void __launch_bounds__(THREADS_NUM) // Sync so the scatter completes before next stage overwrites the buffer. __syncthreads(); - buff_in = (buff_in + 1) % BUFFS_NUM; - buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_in = (buff_in + 1) % BUFFS_NUM; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; } @@ -664,11 +782,13 @@ __global__ void __launch_bounds__(THREADS_NUM) if (DO_ROW) { auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); using ScalesVec = Vec; - const size_t scales_block_offset_X_rowwise = static_cast(ctaid_X) * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_X_rowwise = + static_cast(ctaid_X) * SCALES_PER_CHUNK_X; for (int row = static_cast(threadIdx.x); row < CHUNK_DIM_Y; row += THREADS_NUM) { ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); const size_t local_row = static_cast(local_row_base) + row; - const size_t scale_idx_global = local_row * scale_stride_row + scales_block_offset_X_rowwise; + const size_t scale_idx_global = + local_row * scale_stride_row + scales_block_offset_X_rowwise; scales_vec.store_to_elts(&s_dec_row_base[scale_idx_global], 0, SCALES_PER_CHUNK_X); } } @@ -677,10 +797,12 @@ __global__ void __launch_bounds__(THREADS_NUM) using ScalesVec = Vec; // M-block offset within s_dec_col[global_col] (shape (K, M_i/16) row-major). const size_t local_block_offset_M = static_cast(local_row_base) / SCALE_DIM; - for (int row_tr = static_cast(threadIdx.x); row_tr < CHUNK_DIM_X; row_tr += THREADS_NUM) { + for (int row_tr = static_cast(threadIdx.x); row_tr < CHUNK_DIM_X; + row_tr += THREADS_NUM) { ScalesVec& scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); const size_t global_col = static_cast(block_offset_X) + row_tr; - const size_t scale_idx_global = global_col * scale_stride_col + local_block_offset_M; + const size_t scale_idx_global = + global_col * scale_stride_col + local_block_offset_M; scales_vec.store_to_elts(&s_dec_col_base[scale_idx_global], 0, SCALES_PER_CHUNK_Y); } } @@ -697,15 +819,21 @@ __global__ void __launch_bounds__(THREADS_NUM) (void)noop; (void)rows; (void)cols; + (void)random_sign_mask_t; NVTE_DEVICE_ERROR("Fused grouped per-token cast kernel requires SM 10.0+ (Blackwell)."); #endif // __CUDA_ARCH__ >= 1000 } // Host launcher for the fused K2 path. bf16-only. +// with_rht=true applies a 16-pt RHT on the col direction; K1 amax must have +// used the same flag + mask, else inner SF + FP4 saturate against mismatched +// data. inline void launch_grouped_fused_cast_bf16(const NVFP4PerTokenMultiArgs& args, - const SimpleTensor& input_data, int sum_M, int K, - bool do_row, bool do_col, const float* noop, - cudaStream_t stream) { + const SimpleTensor& input_data, int sum_M, + int K, bool do_row, bool do_col, + bool with_rht, + uint32_t random_sign_mask_t, + const float* noop, cudaStream_t stream) { if (!do_row && !do_col) return; checkCuDriverContext(stream); @@ -714,41 +842,52 @@ inline void launch_grouped_fused_cast_bf16(const NVFP4PerTokenMultiArgs& args, create_2D_tensor_map(tmap_in, input_data, sum_M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(FusedIType) * 8); - dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(sum_M / CHUNK_DIM_Y), 1); + dim3 grid(static_cast(K / CHUNK_DIM_X), + static_cast(sum_M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); - TRANSFORMER_ENGINE_SWITCH_CONDITION( - do_row, DO_ROW, TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { - constexpr int sz_in = - DIVUP_TO_MULTIPLE(BUFFS_NUM * BUFF_IN_SIZE * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); - constexpr int sz_out_r = - DO_ROW ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT) : 0; - constexpr int sz_out_c = - DO_COL ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT) - : 0; - constexpr int sz_sf_r = - DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) - : 0; - constexpr int sz_sf_c = - DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) - : 0; - constexpr int dshmem_size = - sz_in + sz_out_r + sz_out_c + sz_sf_r + sz_sf_c + TMA_SHMEM_ALIGNMENT; - auto kernel = group_per_token_fused_cast_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tmap_in, args, noop, static_cast(sum_M), static_cast(K)); - });); + // Collapse to kWithRht=false when no colwise output is requested. + const bool with_rht_effective = with_rht && do_col; + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, + TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { + constexpr int sz_in = DIVUP_TO_MULTIPLE( + BUFFS_NUM * BUFF_IN_SIZE * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); + constexpr int sz_out_r = DO_ROW + ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT) : 0; + constexpr int sz_out_c = DO_COL + ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int sz_sf_r = DO_ROW + ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int sz_sf_c = DO_COL + ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int dshmem_size = sz_in + sz_out_r + sz_out_c + sz_sf_r + sz_sf_c + + TMA_SHMEM_ALIGNMENT; + auto kernel = group_per_token_fused_cast_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>(tmap_in, args, noop, + static_cast(sum_M), + static_cast(K), + random_sign_mask_t); + }));); NVTE_CHECK_CUDA(cudaGetLastError()); } // Host launcher for the fused K1 path. bf16-only. +// with_rht=true applies a 16-pt RHT on the col amax (rowwise raw). The +// downstream K2 cast MUST use the same flag + mask. inline void launch_grouped_fused_amax_bf16(const NVFP4PerTokenMultiArgs& args, - const SimpleTensor& input_data, int sum_M, int K, - bool do_row, bool do_col, const float* noop, - cudaStream_t stream) { + const SimpleTensor& input_data, int sum_M, + int K, bool do_row, bool do_col, + bool with_rht, + uint32_t random_sign_mask_t, + const float* noop, cudaStream_t stream) { if (!do_row && !do_col) return; // Pre-zero amax slots (atomicMax identity). @@ -779,16 +918,23 @@ inline void launch_grouped_fused_amax_bf16(const NVFP4PerTokenMultiArgs& args, DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; - dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(sum_M / CHUNK_DIM_Y), 1); + dim3 grid(static_cast(K / CHUNK_DIM_X), + static_cast(sum_M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); - TRANSFORMER_ENGINE_SWITCH_CONDITION( - do_row, DO_ROW, TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, { - auto kernel = group_per_token_fused_amax_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tmap_in, args, noop, static_cast(sum_M), static_cast(K)); - });); + // Collapse to kWithRht=false when no colwise amax is requested. + const bool with_rht_effective = with_rht && do_col; + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, + TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { + auto kernel = group_per_token_fused_amax_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>(tmap_in, args, noop, + static_cast(sum_M), + static_cast(K), + random_sign_mask_t); + }));); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -811,18 +957,19 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector& outputs, args->split_sections_range[0] = 0; for (size_t i = 0; i < num_tensors; ++i) { Tensor* o = outputs[i]; - NVTE_CHECK(split_sections[i] % 128 == 0, "split_sections[", i, "] = ", split_sections[i], - " must be a multiple of 128"); + NVTE_CHECK(split_sections[i] % 128 == 0, "split_sections[", i, + "] = ", split_sections[i], " must be a multiple of 128"); args->split_sections_range[i + 1] = args->split_sections_range[i] + static_cast(split_sections[i]); if (split_sections[i] == 0) continue; if (which_buffers & kBufRowAmax) { - NVTE_CHECK(o->amax.dptr != nullptr, "NVFP4 per-token grouped: outputs[", i, - "].amax must be allocated for rowwise"); + NVTE_CHECK(o->amax.dptr != nullptr, + "NVFP4 per-token grouped: outputs[", i, "].amax must be allocated for rowwise"); args->row_amax_list[i] = o->amax.dptr; } if (which_buffers & kBufColAmax) { - NVTE_CHECK(o->columnwise_amax.dptr != nullptr, "NVFP4 per-token grouped: outputs[", i, + NVTE_CHECK(o->columnwise_amax.dptr != nullptr, + "NVFP4 per-token grouped: outputs[", i, "].columnwise_amax must be allocated for columnwise"); args->col_amax_list[i] = o->columnwise_amax.dptr; } @@ -834,9 +981,10 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector& outputs, args->s_dec_row_list[i] = o->scale_inv.dptr; } if (which_buffers & kBufColCast) { - NVTE_CHECK(o->columnwise_data.dptr != nullptr && o->columnwise_scale_inv.dptr != nullptr, - "NVFP4 per-token grouped: outputs[", i, - "].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast"); + NVTE_CHECK( + o->columnwise_data.dptr != nullptr && o->columnwise_scale_inv.dptr != nullptr, + "NVFP4 per-token grouped: outputs[", i, + "].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast"); args->q_col_list[i] = o->columnwise_data.dptr; args->s_dec_col_list[i] = o->columnwise_scale_inv.dptr; } @@ -848,9 +996,13 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector& outputs, } // Host entry. do_amax / do_cast select K1 / K2 phases (composite = both). +// with_rht / mask are threaded into BOTH K1 and K2; the caller must use the +// same flag/mask if they invoke amax + cast separately. void quantize_per_token_grouped(const Tensor& input, std::vector& outputs, const size_t* split_sections, size_t num_tensors, bool rowwise, - bool columnwise, bool do_amax, bool do_cast, cudaStream_t stream) { + bool columnwise, bool do_amax, bool do_cast, + bool with_rht, uint32_t random_sign_mask_t, + cudaStream_t stream) { NVTE_CHECK(num_tensors > 0, "NVFP4 per-token grouped: num_tensors must be > 0"); NVTE_CHECK(num_tensors <= static_cast(kMaxTensorsPerKernel), "NVFP4 per-token grouped: num_tensors (", num_tensors, @@ -865,7 +1017,8 @@ void quantize_per_token_grouped(const Tensor& input, std::vector& outpu const int sum_M = static_cast(input.flat_first_dim()); const int K = static_cast(input.flat_last_dim()); if (sum_M == 0 || K == 0) return; - NVTE_CHECK(K % 128 == 0, "NVFP4 per-token grouped: K (", K, ") must be a multiple of 128"); + NVTE_CHECK(K % 128 == 0, + "NVFP4 per-token grouped: K (", K, ") must be a multiple of 128"); int which_buffers = 0; if ((do_amax || do_cast) && rowwise) which_buffers |= kBufRowAmax; @@ -881,12 +1034,16 @@ void quantize_per_token_grouped(const Tensor& input, std::vector& outpu fused::launch_grouped_fused_amax_bf16(args, input.data, sum_M, K, /*do_row=*/rowwise, /*do_col=*/columnwise, + /*with_rht=*/with_rht, + /*random_sign_mask_t=*/random_sign_mask_t, /*noop=*/nullptr, stream); } if (do_cast) { fused::launch_grouped_fused_cast_bf16(args, input.data, sum_M, K, /*do_row=*/rowwise, /*do_col=*/columnwise, + /*with_rht=*/with_rht, + /*random_sign_mask_t=*/random_sign_mask_t, /*noop=*/nullptr, stream); } } @@ -911,17 +1068,25 @@ std::vector collect_outputs(NVTETensor* outputs, si } // namespace void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, bool rowwise, - bool columnwise, cudaStream_t stream) { + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, + int with_rht, int random_sign_mask_t, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_group_nvfp4_per_token_amax); using namespace transformer_engine; if (num_tensors == 0) return; const Tensor* in = convertNVTETensorCheck(input); std::vector outs = collect_outputs(outputs, num_tensors); - nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, rowwise, - columnwise, - /*do_amax=*/true, /*do_cast=*/false, stream); + // C-API mirrors nvte_nvfp4_per_token_amax: `int` for cross-language ABI + // safety; internal kernel arg is uint32_t with only the low 16 bits used. + nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, + rowwise, columnwise, + /*do_amax=*/true, /*do_cast=*/false, + /*with_rht=*/with_rht != 0, + /*random_sign_mask_t=*/ + static_cast(random_sign_mask_t) & 0xFFFFu, + stream); #else (void)input; (void)outputs; @@ -929,23 +1094,31 @@ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs (void)num_tensors; (void)rowwise; (void)columnwise; + (void)with_rht; + (void)random_sign_mask_t; (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif } void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, bool rowwise, - bool columnwise, cudaStream_t stream) { + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, + int with_rht, int random_sign_mask_t, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_group_nvfp4_per_token_cast); using namespace transformer_engine; if (num_tensors == 0) return; const Tensor* in = convertNVTETensorCheck(input); std::vector outs = collect_outputs(outputs, num_tensors); - nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, rowwise, - columnwise, - /*do_amax=*/false, /*do_cast=*/true, stream); + nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, + rowwise, columnwise, + /*do_amax=*/false, /*do_cast=*/true, + /*with_rht=*/with_rht != 0, + /*random_sign_mask_t=*/ + static_cast(random_sign_mask_t) & 0xFFFFu, + stream); #else (void)input; (void)outputs; @@ -953,6 +1126,8 @@ void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs (void)num_tensors; (void)rowwise; (void)columnwise; + (void)with_rht; + (void)random_sign_mask_t; (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif @@ -960,16 +1135,22 @@ void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* outputs, const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, cudaStream_t stream) { + bool rowwise, bool columnwise, + int with_rht, int random_sign_mask_t, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_group_nvfp4_per_token_quantize); using namespace transformer_engine; if (num_tensors == 0) return; const Tensor* in = convertNVTETensorCheck(input); std::vector outs = collect_outputs(outputs, num_tensors); - nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, rowwise, - columnwise, - /*do_amax=*/true, /*do_cast=*/true, stream); + nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, + rowwise, columnwise, + /*do_amax=*/true, /*do_cast=*/true, + /*with_rht=*/with_rht != 0, + /*random_sign_mask_t=*/ + static_cast(random_sign_mask_t) & 0xFFFFu, + stream); #else (void)input; (void)outputs; @@ -977,6 +1158,8 @@ void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* out (void)num_tensors; (void)rowwise; (void)columnwise; + (void)with_rht; + (void)random_sign_mask_t; (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif diff --git a/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h index e21be53047..9533dddc4d 100644 --- a/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h +++ b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h @@ -15,30 +15,58 @@ extern "C" { #endif + /*! \brief Composite K1+K2: per-row + per-col amax (K1) then FP4 + 1x16 * e4m3 SF encode (K2), back-to-back on the same stream. * - * This is the production entry point for the per-token cast on bf16 + - * 128-aligned shapes. + * Production entry point for the per-token cast on bf16 + 128-aligned shapes. + * + * \param[in] with_rht non-zero -> apply 16-pt RHT on the col direction in + * both K1 and K2. Rowwise stays raw; zero is byte-equal + * to the pre-RHT path. + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern shared by + * K1 and K2. Ignored when with_rht == 0. */ -void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output, - cudaStream_t stream); +void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, + NVTETensor output, + int with_rht, + int random_sign_mask_t, + cudaStream_t stream); /*! \brief Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. * Pre-zeroes the amax buffers and merges per-CTA partials into * ``output->amax`` (size [M]) / ``output->columnwise_amax`` * (size [K]). Does NOT touch FP4 data / scale_inv slots. + * + * \param[in] with_rht non-zero -> apply 16-pt RHT on the col direction + * before columnwise_amax (rowwise stays raw); zero is + * byte-equal to the pre-RHT K1. + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; ignored + * when with_rht == 0. Type matches prod's + * nvte_hadamard_transform_amax convention. */ -void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NVTETensor output, +void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, + NVTETensor output, + int with_rht, + int random_sign_mask_t, cudaStream_t stream); /*! \brief Kernel 2 in isolation: FP4 + 1x16 e4m3 SF encode given a * pre-filled ``output->amax`` / ``output->columnwise_amax``. Reads * the outer amax buffer(s) and writes the FP4 data / scale_inv * tensors only. + * + * \param[in] with_rht non-zero -> col-wise cast applies the same 16-pt RHT + * that K1 amax must have used (caller's responsibility + * to thread the same flag + mask through K1 and K2). + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; ignored + * when with_rht == 0. */ -void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output, - cudaStream_t stream); +void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, + NVTETensor output, + int with_rht, + int random_sign_mask_t, + cudaStream_t stream); /*! \brief Returns 1 iff the per-token kernels accept ``(M, K, dtype)``. * @@ -58,7 +86,8 @@ int nvte_nvfp4_per_token_can_dispatch(size_t M, size_t K, int input_dtype_enum); * d[i, j] = d[i, j] * row_amax_a[i] * row_amax_b[j] */ void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a, - const NVTETensor row_amax_b, cudaStream_t stream); + const NVTETensor row_amax_b, + cudaStream_t stream); /* ============================================================================ * Grouped (multi-tensor) per-token quantize. @@ -71,11 +100,18 @@ void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a, * \param[in] num_tensors <= 64 * \param[in] rowwise emit per-row amax in `outputs[i].amax` * \param[in] columnwise emit per-col amax in `outputs[i].columnwise_amax` + * \param[in] with_rht non-zero -> 16-pt RHT on the col direction + * (rowwise stays raw). + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; must + * match the value passed to the matching cast + * if amax + cast are launched separately. * \param[in] stream CUDA stream */ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, bool rowwise, - bool columnwise, cudaStream_t stream); + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, + int with_rht, int random_sign_mask_t, + cudaStream_t stream); /*! \brief Grouped per-token encode (FP4 + 1x16 e4m3 inner SF) using the * row_amax / col_amax values already populated by @@ -89,11 +125,18 @@ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs * \param[in] num_tensors <= 64 * \param[in] rowwise emit per-row FP4 + inner SF * \param[in] columnwise emit per-col FP4 + inner SF + * \param[in] with_rht must match the preceding amax call's + * with_rht; applies the same 16-pt RHT on the + * colwise cast. + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; must + * match K1. * \param[in] stream CUDA stream */ void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, bool rowwise, - bool columnwise, cudaStream_t stream); + const size_t* split_sections, size_t num_tensors, + bool rowwise, bool columnwise, + int with_rht, int random_sign_mask_t, + cudaStream_t stream); /*! \brief Composite K1+K2 grouped per-token quantize. Calls the amax + cast * kernels on the same stream. This is the external API @@ -109,11 +152,18 @@ void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs * \param[in] num_tensors <= 64 * \param[in] rowwise emit rowwise output * \param[in] columnwise emit columnwise output + * \param[in] with_rht non-zero -> 16-pt RHT on the col direction + * in BOTH K1 and K2; zero is byte-equal to the + * pre-RHT path. + * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern shared + * between K1 and K2; ignored when with_rht==0. * \param[in] stream CUDA stream */ void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* outputs, const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, cudaStream_t stream); + bool rowwise, bool columnwise, + int with_rht, int random_sign_mask_t, + cudaStream_t stream); #ifdef __cplusplus } diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 535fac017e..a1cee4461e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -412,7 +412,8 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads void compute_amax(const at::Tensor &tensor, at::Tensor &amax); void hadamard_transform_amax(const at::Tensor &tensor, at::Tensor &rowwise_amax, - at::Tensor &columnwise_amax, int64_t rht_matrix_random_sign_mask); + at::Tensor &columnwise_amax, + int64_t rht_matrix_random_sign_mask); void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, @@ -451,46 +452,58 @@ void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwi const at::Tensor &scale_inv_colwise, int rows, int cols, size_t start_offset); -void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row, - at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, - at::Tensor col_amax, bool rowwise, bool columnwise); +void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row, + at::Tensor s_dec_row, at::Tensor row_amax, + at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, + bool with_rht, int64_t random_sign_mask_t); -void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, at::Tensor col_amax, - bool rowwise, bool columnwise); +void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, + at::Tensor col_amax, bool rowwise, bool columnwise, + bool with_rht, int64_t random_sign_mask_t); -void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row, - at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, - at::Tensor col_amax, bool rowwise, bool columnwise); +void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row, + at::Tensor s_dec_row, at::Tensor row_amax, + at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, + bool with_rht, int64_t random_sign_mask_t); void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, const at::Tensor &row_amax_b); void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, const at::Tensor &a_sf, const at::Tensor &b_sf, - const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, at::Tensor d, - const at::Tensor &workspace, int64_t m, int64_t n, int64_t k, - double alpha, double beta); + const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, + at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, + int64_t k, double alpha, double beta); // Bench-only per-tensor twin of nvfp4_per_token_gemm: scalar amaxes folded // into cuBLAS LT alpha via the amax slot; no trailing post-scale. void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, - const at::Tensor &a_sf, const at::Tensor &b_sf, const at::Tensor &a_amax, - const at::Tensor &b_amax, at::Tensor d, const at::Tensor &workspace, - int64_t m, int64_t n, int64_t k, double alpha, double beta); + const at::Tensor &a_sf, const at::Tensor &b_sf, + const at::Tensor &a_amax, const at::Tensor &b_amax, + at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, + int64_t k, double alpha, double beta); +// with_rht=true applies a 16-pt RHT on the col direction in BOTH K1 and K2; +// random_sign_mask_t low 16 bits = sign pattern (ignored when with_rht=false). void nvfp4_per_token_group_quantize( const at::Tensor &input, const std::vector &split_sections, std::vector q_row_list, std::vector s_dec_row_list, std::vector row_amax_list, std::vector q_col_list, - std::vector s_dec_col_list, std::vector col_amax_list, bool rowwise, - bool columnwise); + std::vector s_dec_col_list, std::vector col_amax_list, + bool rowwise, bool columnwise, + bool with_rht, int64_t random_sign_mask_t); // Amax-only variant of the grouped quantize. Useful for multi-rank training -// where amax is allReduced before the cast pass. -void nvfp4_per_token_group_amax(const at::Tensor &input, const std::vector &split_sections, +// where amax is allReduced before the cast pass. Caller must thread the +// matching with_rht / mask into the subsequent cast launch. +void nvfp4_per_token_group_amax(const at::Tensor &input, + const std::vector &split_sections, std::vector row_amax_list, std::vector col_amax_list, bool rowwise, - bool columnwise); + bool columnwise, + bool with_rht, int64_t random_sign_mask_t); // Bulk grouped quantize: allocate-view-dispatch all in one pybind hop. // Returns 6 per-split vectors (q_row, s_dec_row_fp8, row_amax, q_col, @@ -498,8 +511,9 @@ void nvfp4_per_token_group_amax(const at::Tensor &input, const std::vector, std::vector, std::vector, std::vector, std::vector, std::vector> nvfp4_per_token_group_quantize_bulk(const at::Tensor &input, - const std::vector &split_sections, bool rowwise, - bool columnwise); + const std::vector &split_sections, + bool rowwise, bool columnwise, + bool with_rht, int64_t random_sign_mask_t); /*************************************************************************************************** * Rotary positional embedding diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp index d9831ed5e7..8498f2f249 100644 --- a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp @@ -25,10 +25,11 @@ namespace { // Validates the input and assembles ``out_te`` for all 3 modes; caller // dispatches to the right C-API entry on the caller's stream. -void assemble_per_token_tensors(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, - at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, - at::Tensor col_amax, bool rowwise, bool columnwise, int mode, - TensorWrapper& in_te, TensorWrapper& out_te) { +void assemble_per_token_tensors(const at::Tensor& input, + at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, + at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, + bool rowwise, bool columnwise, int mode, + TensorWrapper& in_te, TensorWrapper& out_te) { TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); @@ -52,8 +53,8 @@ void assemble_per_token_tensors(const at::Tensor& input, at::Tensor q_row, at::T TORCH_CHECK(row_amax.is_cuda() && row_amax.is_contiguous(), "row_amax must be a contiguous CUDA tensor"); TORCH_CHECK(row_amax.scalar_type() == at::ScalarType::Float, "row_amax must be float32"); - TORCH_CHECK(row_amax.numel() == M, "row_amax numel mismatch: expected M=", M, ", got ", - row_amax.numel()); + TORCH_CHECK(row_amax.numel() == M, + "row_amax numel mismatch: expected M=", M, ", got ", row_amax.numel()); out_te.set_amax(row_amax.data_ptr(), DType::kFloat32, std::vector{static_cast(M)}); @@ -62,14 +63,15 @@ void assemble_per_token_tensors(const at::Tensor& input, at::Tensor q_row, at::T "q_row must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_row.is_cuda() && s_dec_row.is_contiguous(), "s_dec_row must be a contiguous CUDA tensor"); - TORCH_CHECK(q_row.scalar_type() == at::ScalarType::Byte, "q_row must be uint8 (FP4 packed)"); + TORCH_CHECK(q_row.scalar_type() == at::ScalarType::Byte, + "q_row must be uint8 (FP4 packed)"); TORCH_CHECK(s_dec_row.scalar_type() == at::ScalarType::Byte, "s_dec_row must be uint8 (FP8 e4m3 raw bytes)"); - TORCH_CHECK(q_row.numel() == M * K / 2, "q_row numel mismatch: expected M*K/2=", M * K / 2, - ", got ", q_row.numel()); + TORCH_CHECK(q_row.numel() == M * K / 2, + "q_row numel mismatch: expected M*K/2=", M * K / 2, ", got ", q_row.numel()); TORCH_CHECK(s_dec_row.numel() == M * K / 16, - "s_dec_row numel mismatch: expected M*K/16=", M * K / 16, ", got ", - s_dec_row.numel()); + "s_dec_row numel mismatch: expected M*K/16=", M * K / 16, + ", got ", s_dec_row.numel()); out_te.set_rowwise_data(q_row.data_ptr(), DType::kFloat4E2M1, in_shape); out_te.set_rowwise_scale_inv( s_dec_row.data_ptr(), DType::kFloat8E4M3, @@ -80,8 +82,8 @@ void assemble_per_token_tensors(const at::Tensor& input, at::Tensor q_row, at::T TORCH_CHECK(col_amax.is_cuda() && col_amax.is_contiguous(), "col_amax must be a contiguous CUDA tensor"); TORCH_CHECK(col_amax.scalar_type() == at::ScalarType::Float, "col_amax must be float32"); - TORCH_CHECK(col_amax.numel() == K, "col_amax numel mismatch: expected K=", K, ", got ", - col_amax.numel()); + TORCH_CHECK(col_amax.numel() == K, + "col_amax numel mismatch: expected K=", K, ", got ", col_amax.numel()); out_te.set_columnwise_amax(col_amax.data_ptr(), DType::kFloat32, std::vector{static_cast(K)}); @@ -90,14 +92,15 @@ void assemble_per_token_tensors(const at::Tensor& input, at::Tensor q_row, at::T "q_col must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_col.is_cuda() && s_dec_col.is_contiguous(), "s_dec_col must be a contiguous CUDA tensor"); - TORCH_CHECK(q_col.scalar_type() == at::ScalarType::Byte, "q_col must be uint8 (FP4 packed)"); + TORCH_CHECK(q_col.scalar_type() == at::ScalarType::Byte, + "q_col must be uint8 (FP4 packed)"); TORCH_CHECK(s_dec_col.scalar_type() == at::ScalarType::Byte, "s_dec_col must be uint8 (FP8 e4m3 raw bytes)"); - TORCH_CHECK(q_col.numel() == K * M / 2, "q_col numel mismatch: expected K*M/2=", K * M / 2, - ", got ", q_col.numel()); + TORCH_CHECK(q_col.numel() == K * M / 2, + "q_col numel mismatch: expected K*M/2=", K * M / 2, ", got ", q_col.numel()); TORCH_CHECK(s_dec_col.numel() == K * M / 16, - "s_dec_col numel mismatch: expected K*M/16=", K * M / 16, ", got ", - s_dec_col.numel()); + "s_dec_col numel mismatch: expected K*M/16=", K * M / 16, + ", got ", s_dec_col.numel()); out_te.set_columnwise_data( q_col.data_ptr(), DType::kFloat4E2M1, std::vector{static_cast(K), static_cast(M)}); @@ -110,45 +113,71 @@ void assemble_per_token_tensors(const at::Tensor& input, at::Tensor q_row, at::T } // namespace -// Production composite (K1 + K2 back-to-back). -void nvfp4_per_token_quantize(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, - at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, - at::Tensor col_amax, bool rowwise, bool columnwise) { +// Production composite (K1 + K2 back-to-back). with_rht=true enables the +// 16-pt col-wise RHT in BOTH K1 and K2 so outer + inner SFs stay consistent. +void nvfp4_per_token_quantize(const at::Tensor& input, + at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, + at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, + bool rowwise, bool columnwise, + bool with_rht, int64_t random_sign_mask_t) { TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); - assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, - columnwise, /*mode=*/0, in_te, out_te); + assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, + q_col, s_dec_col, col_amax, + rowwise, columnwise, /*mode=*/0, in_te, out_te); const auto stream = at::cuda::getCurrentCUDAStream(); - nvte_nvfp4_per_token_quantize(in_te.data(), nullptr, out_te.data(), stream); + nvte_nvfp4_per_token_quantize( + in_te.data(), nullptr, out_te.data(), + with_rht ? 1 : 0, + static_cast(random_sign_mask_t & 0xFFFF), + stream); } -// K1-only (diagnostic / bench): populates only amax buffers. -void nvfp4_per_token_amax(const at::Tensor& input, at::Tensor row_amax, at::Tensor col_amax, - bool rowwise, bool columnwise) { +// K1-only (diagnostic / bench): populates only amax buffers. with_rht=true +// applies the 16-pt col-wise RHT before amax (rowwise unaffected); +// random_sign_mask_t low 16 bits = sign-flip pattern. +void nvfp4_per_token_amax(const at::Tensor& input, at::Tensor row_amax, + at::Tensor col_amax, bool rowwise, bool columnwise, + bool with_rht, int64_t random_sign_mask_t) { at::Tensor empty_u8; // not consumed by K1 TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); - assemble_per_token_tensors(input, empty_u8, empty_u8, row_amax, empty_u8, empty_u8, col_amax, - rowwise, columnwise, /*mode=*/1, in_te, out_te); + assemble_per_token_tensors(input, empty_u8, empty_u8, row_amax, + empty_u8, empty_u8, col_amax, + rowwise, columnwise, /*mode=*/1, in_te, out_te); const auto stream = at::cuda::getCurrentCUDAStream(); - nvte_nvfp4_per_token_amax(in_te.data(), nullptr, out_te.data(), stream); + // C-API matches prod's `int` convention; only low 16 bits are consumed. + nvte_nvfp4_per_token_amax( + in_te.data(), nullptr, out_te.data(), + with_rht ? 1 : 0, + static_cast(random_sign_mask_t & 0xFFFF), + stream); } // K2-only (diagnostic / bench): reads pre-filled amax buffers, emits FP4 + SFs. -void nvfp4_per_token_encode(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, - at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, - at::Tensor col_amax, bool rowwise, bool columnwise) { +// with_rht=true requires col_amax to have been produced by an earlier K1 +// amax call with the SAME mask, else inner SFs are miscalibrated. +void nvfp4_per_token_encode(const at::Tensor& input, + at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, + at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, + bool rowwise, bool columnwise, + bool with_rht, int64_t random_sign_mask_t) { TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); - assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, - columnwise, /*mode=*/2, in_te, out_te); + assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, + q_col, s_dec_col, col_amax, + rowwise, columnwise, /*mode=*/2, in_te, out_te); const auto stream = at::cuda::getCurrentCUDAStream(); - nvte_nvfp4_per_token_encode(in_te.data(), nullptr, out_te.data(), stream); + nvte_nvfp4_per_token_encode( + in_te.data(), nullptr, out_te.data(), + with_rht ? 1 : 0, + static_cast(random_sign_mask_t & 0xFFFF), + stream); } // Apply per-token post-scale to a GEMM output (see nvfp4_per_token.h for math). -void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor& row_amax_a, - const at::Tensor& row_amax_b) { +void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, + const at::Tensor &row_amax_b) { TORCH_CHECK(d.is_cuda() && d.is_contiguous(), "d must be a contiguous CUDA tensor"); TORCH_CHECK(row_amax_a.is_cuda() && row_amax_a.is_contiguous(), "row_amax_a must be a contiguous CUDA tensor"); @@ -161,16 +190,16 @@ void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor& row_amax_a, const int64_t M = d.size(0); const int64_t N = d.size(1); - TORCH_CHECK(row_amax_a.numel() == M, "row_amax_a numel mismatch: expected M=", M, ", got ", - row_amax_a.numel()); - TORCH_CHECK(row_amax_b.numel() == N, "row_amax_b numel mismatch: expected N=", N, ", got ", - row_amax_b.numel()); + TORCH_CHECK(row_amax_a.numel() == M, + "row_amax_a numel mismatch: expected M=", M, ", got ", row_amax_a.numel()); + TORCH_CHECK(row_amax_b.numel() == N, + "row_amax_b numel mismatch: expected N=", N, ", got ", row_amax_b.numel()); const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper d_te = makeTransformerEngineTensor( - d.data_ptr(), std::vector{static_cast(M), static_cast(N)}, - DType::kBFloat16); + d.data_ptr(), + std::vector{static_cast(M), static_cast(N)}, DType::kBFloat16); TensorWrapper ra_te = makeTransformerEngineTensor( row_amax_a.data_ptr(), std::vector{static_cast(M)}, DType::kFloat32); TensorWrapper rb_te = makeTransformerEngineTensor( @@ -182,11 +211,11 @@ void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor& row_amax_a, // End-to-end NVFP4 per-token GEMM: swizzle compact SFs -> cuBLAS LT NVFP4 // GEMM (operand amax pinned to 1.0 to cancel the 2688^2 inner-SF factor) -> // per-row post-scale. beta must be 0.0. Math in nvfp4_per_token.h. -void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, - const at::Tensor& a_sf, const at::Tensor& b_sf, - const at::Tensor& a_row_amax, const at::Tensor& b_row_amax, at::Tensor d, - const at::Tensor& workspace, int64_t m, int64_t n, int64_t k, - double alpha, double beta) { +void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, + const at::Tensor &a_sf, const at::Tensor &b_sf, + const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, + at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, + int64_t k, double alpha, double beta) { TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && a_row_amax.is_cuda() && b_row_amax.is_cuda() && d.is_cuda() && workspace.is_cuda(), @@ -205,13 +234,14 @@ void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); TORCH_CHECK(workspace.scalar_type() == at::ScalarType::Byte, "workspace must be uint8"); - TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, "a_data/b_data/d must be 2D"); + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, + "a_data/b_data/d must be 2D"); TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, - "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", a_data.size(0), - ", ", a_data.size(1), ")"); + "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", + a_data.size(0), ", ", a_data.size(1), ")"); TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, - "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", b_data.size(0), - ", ", b_data.size(1), ")"); + "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", + b_data.size(0), ", ", b_data.size(1), ")"); TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, "), got (", d.size(0), ", ", d.size(1), ")"); @@ -220,10 +250,10 @@ void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, "a_sf numel mismatch: expected M*K/16=", m * k / 16, ", got ", a_sf.numel()); TORCH_CHECK(b_sf.numel() == static_cast(n * k / 16), "b_sf numel mismatch: expected N*K/16=", n * k / 16, ", got ", b_sf.numel()); - TORCH_CHECK(a_row_amax.numel() == m, "a_row_amax numel mismatch: expected M=", m, ", got ", - a_row_amax.numel()); - TORCH_CHECK(b_row_amax.numel() == n, "b_row_amax numel mismatch: expected N=", n, ", got ", - b_row_amax.numel()); + TORCH_CHECK(a_row_amax.numel() == m, + "a_row_amax numel mismatch: expected M=", m, ", got ", a_row_amax.numel()); + TORCH_CHECK(b_row_amax.numel() == n, + "b_row_amax numel mismatch: expected N=", n, ", got ", b_row_amax.numel()); TORCH_CHECK(static_cast(beta) == 0.0f, "nvfp4_per_token_gemm: beta != 0 not yet supported. Got beta=", beta); @@ -293,7 +323,8 @@ void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, b_te.set_with_gemm_swizzled_scales(true); TensorWrapper d_te = makeTransformerEngineTensor( - d.data_ptr(), std::vector{static_cast(m), static_cast(n)}, + d.data_ptr(), + std::vector{static_cast(m), static_cast(n)}, DType::kBFloat16); TensorWrapper workspace_te = makeTransformerEngineTensor( @@ -322,10 +353,11 @@ void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, // Per-tensor twin of nvfp4_per_token_gemm: scalar amax goes through cuBLAS's // own amax slot (no post-scale). Bench-only apples-to-apples baseline. -void nvfp4_per_tensor_gemm(const at::Tensor& a_data, const at::Tensor& b_data, - const at::Tensor& a_sf, const at::Tensor& b_sf, const at::Tensor& a_amax, - const at::Tensor& b_amax, at::Tensor d, const at::Tensor& workspace, - int64_t m, int64_t n, int64_t k, double alpha, double beta) { +void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, + const at::Tensor &a_sf, const at::Tensor &b_sf, + const at::Tensor &a_amax, const at::Tensor &b_amax, + at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, + int64_t k, double alpha, double beta) { TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && a_amax.is_cuda() && b_amax.is_cuda() && d.is_cuda() && workspace.is_cuda(), "All tensors must be CUDA tensors"); @@ -342,13 +374,14 @@ void nvfp4_per_tensor_gemm(const at::Tensor& a_data, const at::Tensor& b_data, TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); TORCH_CHECK(workspace.scalar_type() == at::ScalarType::Byte, "workspace must be uint8"); - TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, "a_data/b_data/d must be 2D"); + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, + "a_data/b_data/d must be 2D"); TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, - "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", a_data.size(0), - ", ", a_data.size(1), ")"); + "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", + a_data.size(0), ", ", a_data.size(1), ")"); TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, - "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", b_data.size(0), - ", ", b_data.size(1), ")"); + "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", + b_data.size(0), ", ", b_data.size(1), ")"); TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, "), got (", d.size(0), ", ", d.size(1), ")"); @@ -413,7 +446,8 @@ void nvfp4_per_tensor_gemm(const at::Tensor& a_data, const at::Tensor& b_data, b_te.set_with_gemm_swizzled_scales(true); TensorWrapper d_te = makeTransformerEngineTensor( - d.data_ptr(), std::vector{static_cast(m), static_cast(n)}, + d.data_ptr(), + std::vector{static_cast(m), static_cast(n)}, DType::kBFloat16); TensorWrapper workspace_te = makeTransformerEngineTensor( @@ -436,13 +470,13 @@ void nvfp4_per_tensor_gemm(const at::Tensor& a_data, const at::Tensor& b_data, // Disabled direction's lists are ignored. namespace { -void build_per_token_output_wrapper(TensorWrapper& out_te, int64_t M_i, int64_t K, bool rowwise, - bool columnwise, const at::Tensor& q_row, - const at::Tensor& s_dec_row, const at::Tensor& row_amax, - const at::Tensor& q_col, const at::Tensor& s_dec_col, - const at::Tensor& col_amax) { +void build_per_token_output_wrapper( + TensorWrapper& out_te, int64_t M_i, int64_t K, bool rowwise, bool columnwise, + const at::Tensor& q_row, const at::Tensor& s_dec_row, const at::Tensor& row_amax, + const at::Tensor& q_col, const at::Tensor& s_dec_col, const at::Tensor& col_amax) { if (rowwise) { - TORCH_CHECK(q_row.is_cuda() && q_row.is_contiguous(), "q_row must be a contiguous CUDA tensor"); + TORCH_CHECK(q_row.is_cuda() && q_row.is_contiguous(), + "q_row must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_row.is_cuda() && s_dec_row.is_contiguous(), "s_dec_row must be a contiguous CUDA tensor"); TORCH_CHECK(row_amax.is_cuda() && row_amax.is_contiguous(), @@ -454,8 +488,9 @@ void build_per_token_output_wrapper(TensorWrapper& out_te, int64_t M_i, int64_t M_i * K / 2, ", got ", q_row.numel()); TORCH_CHECK(s_dec_row.numel() == M_i * K / 16, "s_dec_row numel mismatch for split"); TORCH_CHECK(row_amax.numel() == M_i, "row_amax numel mismatch for split"); - out_te.set_rowwise_data(q_row.data_ptr(), DType::kFloat4E2M1, - std::vector{static_cast(M_i), static_cast(K)}); + out_te.set_rowwise_data( + q_row.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(M_i), static_cast(K)}); out_te.set_rowwise_scale_inv( s_dec_row.data_ptr(), DType::kFloat8E4M3, std::vector{static_cast(M_i), static_cast(K / 16)}); @@ -463,7 +498,8 @@ void build_per_token_output_wrapper(TensorWrapper& out_te, int64_t M_i, int64_t std::vector{static_cast(M_i)}); } if (columnwise) { - TORCH_CHECK(q_col.is_cuda() && q_col.is_contiguous(), "q_col must be a contiguous CUDA tensor"); + TORCH_CHECK(q_col.is_cuda() && q_col.is_contiguous(), + "q_col must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_col.is_cuda() && s_dec_col.is_contiguous(), "s_dec_col must be a contiguous CUDA tensor"); TORCH_CHECK(col_amax.is_cuda() && col_amax.is_contiguous(), @@ -499,10 +535,13 @@ void nvfp4_per_token_group_quantize( const at::Tensor& input, const std::vector& split_sections, std::vector q_row_list, std::vector s_dec_row_list, std::vector row_amax_list, std::vector q_col_list, - std::vector s_dec_col_list, std::vector col_amax_list, bool rowwise, - bool columnwise) { - TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); - TORCH_CHECK(input.is_cuda() && input.is_contiguous(), "input must be a contiguous CUDA tensor"); + std::vector s_dec_col_list, std::vector col_amax_list, + bool rowwise, bool columnwise, + bool with_rht, int64_t random_sign_mask_t) { + TORCH_CHECK(rowwise || columnwise, + "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(input.is_cuda() && input.is_contiguous(), + "input must be a contiguous CUDA tensor"); TORCH_CHECK(input.dim() == 2, "input must be 2D"); const int64_t sum_M = input.size(0); const int64_t K = input.size(1); @@ -513,8 +552,8 @@ void nvfp4_per_token_group_quantize( int64_t acc = 0; for (size_t i = 0; i < num_tensors; ++i) { TORCH_CHECK(split_sections[i] >= 0, "split_sections[", i, "] must be non-negative"); - TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, "] = ", split_sections[i], - " must be a multiple of 64"); + TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, + "] = ", split_sections[i], " must be a multiple of 64"); acc += split_sections[i]; } TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, " must equal input.size(0) = ", sum_M); @@ -534,8 +573,8 @@ void nvfp4_per_token_group_quantize( const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper in_te = makeTransformerEngineTensor( - input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, - in_dtype); + input.data_ptr(), + std::vector{static_cast(sum_M), static_cast(K)}, in_dtype); // One TensorWrapper per split; raw NVTETensor handles go into `handles`. std::vector wrappers; @@ -554,24 +593,34 @@ void nvfp4_per_token_group_quantize( continue; // empty split is allowed (skipped inside the kernel) } build_per_token_output_wrapper( - wrappers.back(), M_i, K, rowwise, columnwise, rowwise ? q_row_list[i] : empty_dummy, - rowwise ? s_dec_row_list[i] : empty_dummy, rowwise ? row_amax_list[i] : empty_dummy, - columnwise ? q_col_list[i] : empty_dummy, columnwise ? s_dec_col_list[i] : empty_dummy, + wrappers.back(), M_i, K, rowwise, columnwise, + rowwise ? q_row_list[i] : empty_dummy, + rowwise ? s_dec_row_list[i] : empty_dummy, + rowwise ? row_amax_list[i] : empty_dummy, + columnwise ? q_col_list[i] : empty_dummy, + columnwise ? s_dec_col_list[i] : empty_dummy, columnwise ? col_amax_list[i] : empty_dummy); handles.push_back(wrappers.back().data()); } - nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), split_sections_sz.data(), - num_tensors, rowwise, columnwise, stream); + nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), + split_sections_sz.data(), num_tensors, rowwise, + columnwise, + static_cast(with_rht), + static_cast(random_sign_mask_t), + stream); } // Amax-only grouped variant (K1 only); for allReduce-before-cast flows. -void nvfp4_per_token_group_amax(const at::Tensor& input, const std::vector& split_sections, +void nvfp4_per_token_group_amax(const at::Tensor& input, + const std::vector& split_sections, std::vector row_amax_list, std::vector col_amax_list, bool rowwise, - bool columnwise) { + bool columnwise, + bool with_rht, int64_t random_sign_mask_t) { TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); - TORCH_CHECK(input.is_cuda() && input.is_contiguous(), "input must be a contiguous CUDA tensor"); + TORCH_CHECK(input.is_cuda() && input.is_contiguous(), + "input must be a contiguous CUDA tensor"); TORCH_CHECK(input.dim() == 2, "input must be 2D"); const int64_t sum_M = input.size(0); const int64_t K = input.size(1); @@ -579,19 +628,21 @@ void nvfp4_per_token_group_amax(const at::Tensor& input, const std::vector 0, "split_sections must not be empty"); int64_t acc = 0; for (size_t i = 0; i < num_tensors; ++i) { - TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, "] must be a multiple of 64"); + TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, + "] must be a multiple of 64"); acc += split_sections[i]; } TORCH_CHECK(acc == sum_M, "sum(split_sections) must equal input.size(0)"); if (rowwise) TORCH_CHECK(row_amax_list.size() == num_tensors, "row_amax_list size mismatch"); - if (columnwise) TORCH_CHECK(col_amax_list.size() == num_tensors, "col_amax_list size mismatch"); + if (columnwise) + TORCH_CHECK(col_amax_list.size() == num_tensors, "col_amax_list size mismatch"); const DType in_dtype = resolve_input_dtype(input); const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper in_te = makeTransformerEngineTensor( - input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, - in_dtype); + input.data_ptr(), + std::vector{static_cast(sum_M), static_cast(K)}, in_dtype); std::vector wrappers; wrappers.reserve(num_tensors); @@ -625,7 +676,10 @@ void nvfp4_per_token_group_amax(const at::Tensor& input, const std::vector(with_rht), + static_cast(random_sign_mask_t), + stream); } // BULK grouped per-token quantize: alloc + view + dispatch in ONE C++ call. @@ -635,13 +689,15 @@ std::tuple, std::vector, std::vector, std::vector, std::vector> nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, const std::vector& split_sections, bool rowwise, - bool columnwise) { + bool columnwise, + bool with_rht, int64_t random_sign_mask_t) { // Validation mirrors _validate_per_token_group_input in Python. - TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(rowwise || columnwise, + "At least one of rowwise/columnwise must be True."); TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(input.is_contiguous(), "x_concat must be contiguous (row-major)"); - TORCH_CHECK(input.dim() == 2, "nvfp4_per_token_group_quantize expects a 2D input, got ", - input.dim(), "D"); + TORCH_CHECK(input.dim() == 2, + "nvfp4_per_token_group_quantize expects a 2D input, got ", input.dim(), "D"); TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, "Per-token grouped kernel is bf16-only; got dtype ", input.scalar_type()); @@ -650,13 +706,13 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, constexpr int64_t kPerTokenTile = 128; constexpr int64_t kBlockK = 16; - TORCH_CHECK(K % kPerTokenTile == 0, "Per-token grouped kernel requires K % ", kPerTokenTile, - " == 0; got K=", K); + TORCH_CHECK(K % kPerTokenTile == 0, + "Per-token grouped kernel requires K % ", kPerTokenTile, " == 0; got K=", K); const size_t num_tensors = split_sections.size(); TORCH_CHECK(num_tensors > 0, "split_sections must not be empty"); - TORCH_CHECK(num_tensors <= 64, "num_tensors must be <= 64 (kernel arg-struct cap); got ", - num_tensors); + TORCH_CHECK(num_tensors <= 64, + "num_tensors must be <= 64 (kernel arg-struct cap); got ", num_tensors); int64_t acc = 0; for (size_t i = 0; i < num_tensors; ++i) { @@ -666,7 +722,8 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, " must be a multiple of ", kPerTokenTile); acc += M_i; } - TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, " must equal input.size(0) = ", sum_M); + TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, + " must equal input.size(0) = ", sum_M); // Bulk allocation: one at::empty per output type, covers all splits. auto opts_u8 = input.options().dtype(at::kByte); @@ -716,7 +773,8 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, if (columnwise) { auto q_col_flat = q_col_bulk.narrow(0, K * m_off / 2, K * M_i / 2); q_col_list.emplace_back(q_col_flat.view({K, M_i / 2})); - auto s_dec_col_flat = s_dec_col_bulk.narrow(0, K * m_off / kBlockK, K * M_i / kBlockK); + auto s_dec_col_flat = + s_dec_col_bulk.narrow(0, K * m_off / kBlockK, K * M_i / kBlockK); s_dec_col_u8_list.emplace_back(s_dec_col_flat.view({K, M_i / kBlockK})); col_amax_list.emplace_back(col_amax_bulk.select(0, static_cast(i))); s_dec_col_fp8_list.emplace_back(s_dec_col_u8_list.back().view(at::kFloat8_e4m3fn)); @@ -727,7 +785,8 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, // Dispatch K1+K2 grouped kernel via the same C-API the thin entry uses. const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper in_te = makeTransformerEngineTensor( - input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, + input.data_ptr(), + std::vector{static_cast(sum_M), static_cast(K)}, DType::kBFloat16); std::vector wrappers; @@ -742,15 +801,22 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, split_sections_sz[i] = static_cast(M_i); wrappers.emplace_back(NVTE_NVFP4_1D_SCALING); build_per_token_output_wrapper( - wrappers.back(), M_i, K, rowwise, columnwise, rowwise ? q_row_list[i] : empty_dummy, - rowwise ? s_dec_row_u8_list[i] : empty_dummy, rowwise ? row_amax_list[i] : empty_dummy, - columnwise ? q_col_list[i] : empty_dummy, columnwise ? s_dec_col_u8_list[i] : empty_dummy, + wrappers.back(), M_i, K, rowwise, columnwise, + rowwise ? q_row_list[i] : empty_dummy, + rowwise ? s_dec_row_u8_list[i] : empty_dummy, + rowwise ? row_amax_list[i] : empty_dummy, + columnwise ? q_col_list[i] : empty_dummy, + columnwise ? s_dec_col_u8_list[i] : empty_dummy, columnwise ? col_amax_list[i] : empty_dummy); handles.push_back(wrappers.back().data()); } - nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), split_sections_sz.data(), - num_tensors, rowwise, columnwise, stream); + nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), + split_sections_sz.data(), num_tensors, rowwise, + columnwise, + static_cast(with_rht), + static_cast(random_sign_mask_t), + stream); return std::make_tuple(std::move(q_row_list), std::move(s_dec_row_fp8_list), std::move(row_amax_list), std::move(q_col_list), diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 46d00ba9c2..e97e5837a8 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -346,7 +346,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "K1 of the NVFP4Quantizer RHT+post_rht_amax path: rowwise (pre-RHT) + " "columnwise (RHT(input.T)) amax in one launch. Bench-only entry.", py::arg("input"), py::arg("rowwise_amax"), py::arg("columnwise_amax"), - py::arg("rht_matrix_random_sign_mask"), py::call_guard()); + py::arg("rht_matrix_random_sign_mask"), + py::call_guard()); m.def("fused_amax_and_scale_update_after_reduction", &transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction", @@ -398,52 +399,75 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("nvfp4_per_token_quantize", &transformer_engine::pytorch::nvfp4_per_token_quantize, "NVFP4 per-token cast (composite K1 amax + K2 encode). Same FP4 + 1x16 " "e4m3 SF layout as per-tensor, but outer amax is per-row/per-col. " - "Requires bf16 input, M % 128 == 0, K % 128 == 0.", + "Requires bf16 input, M % 128 == 0, K % 128 == 0. " + "with_rht=True applies a 16-pt col-wise RHT in both K1 and K2.", py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), - py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), py::arg("rowwise"), - py::arg("columnwise")); - m.def("nvfp4_per_token_amax", &transformer_engine::pytorch::nvfp4_per_token_amax, - "K1-only: per-row/per-col outer amax via TMA + atomicMax. Bench/diagnostic.", - py::arg("input"), py::arg("row_amax"), py::arg("col_amax"), py::arg("rowwise"), - py::arg("columnwise")); - m.def("nvfp4_per_token_encode", &transformer_engine::pytorch::nvfp4_per_token_encode, - "K2-only: FP4 + e4m3 SF encode given pre-filled amax buffers. Bench/diagnostic.", + py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), + py::arg("rowwise"), py::arg("columnwise"), + py::arg("with_rht") = false, + py::arg("random_sign_mask_t") = static_cast(0xACE1)); + m.def("nvfp4_per_token_amax", + &transformer_engine::pytorch::nvfp4_per_token_amax, + "K1-only: per-row/per-col outer amax via TMA + atomicMax. Bench/diagnostic. " + "with_rht=True applies a 16-pt col-wise RHT before amax.", + py::arg("input"), py::arg("row_amax"), py::arg("col_amax"), + py::arg("rowwise"), py::arg("columnwise"), + py::arg("with_rht") = false, + py::arg("random_sign_mask_t") = static_cast(0xACE1)); + m.def("nvfp4_per_token_encode", + &transformer_engine::pytorch::nvfp4_per_token_encode, + "K2-only: FP4 + e4m3 SF encode given pre-filled amax buffers. Bench/diagnostic. " + "with_rht=True requires col_amax produced by a K1 launch with the same mask.", py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), - py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), py::arg("rowwise"), - py::arg("columnwise")); + py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), + py::arg("rowwise"), py::arg("columnwise"), + py::arg("with_rht") = false, + py::arg("random_sign_mask_t") = static_cast(0xACE1)); m.def("nvfp4_per_token_post_scale", &transformer_engine::pytorch::nvfp4_per_token_post_scale, - "Apply d[i,j] *= row_amax_a[i] * row_amax_b[j] in-place on bf16 D.", py::arg("d"), - py::arg("row_amax_a"), py::arg("row_amax_b")); + "Apply d[i,j] *= row_amax_a[i] * row_amax_b[j] in-place on bf16 D.", + py::arg("d"), py::arg("row_amax_a"), py::arg("row_amax_b")); m.def("nvfp4_per_token_gemm", &transformer_engine::pytorch::nvfp4_per_token_gemm, "End-to-end NVFP4 per-token GEMM: swizzle compact SFs, cuBLAS LT NVFP4 " "GEMM, then row*col post-scale to recover C = A @ B^T. beta must be 0.", py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), - py::arg("a_row_amax"), py::arg("b_row_amax"), py::arg("d"), py::arg("workspace"), - py::arg("m"), py::arg("n"), py::arg("k"), py::arg("alpha"), py::arg("beta")); + py::arg("a_row_amax"), py::arg("b_row_amax"), py::arg("d"), + py::arg("workspace"), py::arg("m"), py::arg("n"), py::arg("k"), + py::arg("alpha"), py::arg("beta")); m.def("nvfp4_per_tensor_gemm", &transformer_engine::pytorch::nvfp4_per_tensor_gemm, "Skinny prod NVFP4 GEMM twin of nvfp4_per_token_gemm: per-tensor amaxes " "folded into cuBLAS alpha, no trailing post-scale. Bench-only.", - py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), py::arg("a_amax"), - py::arg("b_amax"), py::arg("d"), py::arg("workspace"), py::arg("m"), py::arg("n"), - py::arg("k"), py::arg("alpha"), py::arg("beta")); + py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), + py::arg("a_amax"), py::arg("b_amax"), py::arg("d"), + py::arg("workspace"), py::arg("m"), py::arg("n"), py::arg("k"), + py::arg("alpha"), py::arg("beta")); m.def("nvfp4_per_token_group_quantize", &transformer_engine::pytorch::nvfp4_per_token_group_quantize, "Grouped (multi-tensor) NVFP4 per-token cast: K1 + K2 across <= 64 splits " - "of a single (sum_M, K) input. Byte-equal to a for-loop of single-tensor.", + "of a single (sum_M, K) input. Byte-equal to a for-loop of single-tensor. " + "with_rht=True applies a 16-pt col-wise RHT in both K1 and K2.", py::arg("input"), py::arg("split_sections"), py::arg("q_row_list"), py::arg("s_dec_row_list"), py::arg("row_amax_list"), py::arg("q_col_list"), py::arg("s_dec_col_list"), py::arg("col_amax_list"), py::arg("rowwise"), - py::arg("columnwise")); - m.def("nvfp4_per_token_group_amax", &transformer_engine::pytorch::nvfp4_per_token_group_amax, - "K1-only variant of nvfp4_per_token_group_quantize: only fills amax slots.", + py::arg("columnwise"), + py::arg("with_rht") = false, + py::arg("random_sign_mask_t") = static_cast(0xACE1)); + m.def("nvfp4_per_token_group_amax", + &transformer_engine::pytorch::nvfp4_per_token_group_amax, + "K1-only variant of nvfp4_per_token_group_quantize: only fills amax slots. " + "with_rht / random_sign_mask_t must match the trailing cast launch.", py::arg("input"), py::arg("split_sections"), py::arg("row_amax_list"), - py::arg("col_amax_list"), py::arg("rowwise"), py::arg("columnwise")); + py::arg("col_amax_list"), py::arg("rowwise"), py::arg("columnwise"), + py::arg("with_rht") = false, + py::arg("random_sign_mask_t") = static_cast(0xACE1)); m.def("nvfp4_per_token_group_quantize_bulk", &transformer_engine::pytorch::nvfp4_per_token_group_quantize_bulk, "Bulk grouped quantize: allocates per-split buffers + view-slices inside " "the binding (one pybind hop instead of 1 + 6N), then dispatches the K1+K2 " - "kernel. Returns 6 per-split tensor lists; empty for disabled directions.", - py::arg("input"), py::arg("split_sections"), py::arg("rowwise"), py::arg("columnwise")); + "kernel. with_rht=True applies a 16-pt col-wise RHT in both K1 and K2.", + py::arg("input"), py::arg("split_sections"), py::arg("rowwise"), + py::arg("columnwise"), + py::arg("with_rht") = false, + py::arg("random_sign_mask_t") = static_cast(0xACE1)); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py index 30e5b8e51e..8c86d03b98 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py @@ -146,9 +146,8 @@ def _quantize_2d(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, tor # kernel. All-zero blocks: s_dec saturates to 0, naive S_enc/0 # would NaN; short-circuit to 0 to mirror the kernel. zero_blk = decode_scale_back_fp32 == 0 - denom = torch.where( - zero_blk, torch.ones_like(decode_scale_back_fp32), decode_scale_back_fp32 - ) + denom = torch.where(zero_blk, torch.ones_like(decode_scale_back_fp32), + decode_scale_back_fp32) encode_scale = S_enc_per_blk / denom encode_scale = torch.where(zero_blk, torch.zeros_like(encode_scale), encode_scale) encode_scale = torch.minimum(encode_scale, fp32_max) @@ -201,14 +200,19 @@ def _validate_per_token_input(x: torch.Tensor) -> Tuple[int, int]: ) M, K = x.shape if M % _PER_TOKEN_TILE != 0: - raise ValueError(f"Per-token kernel requires M % {_PER_TOKEN_TILE} == 0; got M={M}") + raise ValueError( + f"Per-token kernel requires M % {_PER_TOKEN_TILE} == 0; got M={M}" + ) if K % _PER_TOKEN_TILE != 0: - raise ValueError(f"Per-token kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}") + raise ValueError( + f"Per-token kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}" + ) return M, K def nvfp4_per_token_quantize( - x: torch.Tensor, *, rowwise: bool = True, columnwise: bool = False + x: torch.Tensor, *, rowwise: bool = True, columnwise: bool = False, + with_rht: bool = False, random_sign_mask_t: int = 0xACE1, ) -> RefNVFP4TensorPerToken: """Production NVFP4 per-token cast through ``tex.nvfp4_per_token_quantize``. @@ -226,6 +230,10 @@ def nvfp4_per_token_quantize( before forwarding to the GEMM; ``gemm_nvfp4_per_token`` handles this automatically. + ``with_rht=True`` applies a 16-pt col-wise RHT in BOTH K1 and K2 so + outer + inner SF stay self-consistent (rowwise never sees RHT). + ``random_sign_mask_t`` low 16 bits = sign pattern (default ``0xACE1``). + Raises ``ValueError`` on non-bf16 input or non-128-aligned shapes. """ # Import lazily so the module does not require the binary at import time. @@ -257,7 +265,8 @@ def nvfp4_per_token_quantize( q_col, s_dec_col, col_amax = empty, empty, empty_f32 tex.nvfp4_per_token_quantize( - x, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, columnwise + x, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, columnwise, + with_rht=with_rht, random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, ) out = RefNVFP4TensorPerToken() @@ -281,12 +290,9 @@ def nvfp4_per_token_quantize( # above; the composite handles K1 + K2 ordering on the same stream. # ============================================================================ - def nvfp4_per_token_amax( - x: torch.Tensor, - *, - rowwise: bool = True, - columnwise: bool = True, + x: torch.Tensor, *, rowwise: bool = True, columnwise: bool = True, + with_rht: bool = False, random_sign_mask_t: int = 0xACE1, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: """Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. Returns ``(row_amax, col_amax)``; either may be ``None`` if the @@ -296,6 +302,10 @@ def nvfp4_per_token_amax( ``HadamardAmaxTmaKernel``. Production callers should use the composite ``nvfp4_per_token_quantize`` instead. + ``with_rht=True`` applies a 16-pt col-wise RHT before amax; rowwise + never sees RHT. ``random_sign_mask_t`` low 16 bits = sign pattern + (default ``0xACE1``). + Raises ``ValueError`` on non-bf16 input or non-128-aligned shapes. """ import transformer_engine_torch as tex # type: ignore @@ -316,7 +326,10 @@ def nvfp4_per_token_amax( else torch.empty(0, dtype=torch.float32, device=device) ) - tex.nvfp4_per_token_amax(x, row_amax, col_amax, rowwise, columnwise) + tex.nvfp4_per_token_amax( + x, row_amax, col_amax, rowwise, columnwise, + with_rht=with_rht, random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, + ) return (row_amax if rowwise else None, col_amax if columnwise else None) @@ -328,6 +341,8 @@ def nvfp4_per_token_encode( col_amax: Optional[torch.Tensor] = None, rowwise: bool = True, columnwise: bool = True, + with_rht: bool = False, + random_sign_mask_t: int = 0xACE1, ) -> RefNVFP4TensorPerToken: """Kernel 2 in isolation: FP4 + e4m3 SF encode given pre-filled amax buffer(s). @@ -341,6 +356,9 @@ def nvfp4_per_token_encode( per-tensor cast pass. Production callers should use the composite ``nvfp4_per_token_quantize`` instead. + ``with_rht=True`` requires ``col_amax`` produced by a prior K1 call + with the SAME mask, else inner SF / FP4 saturate. + Raises ``ValueError`` on non-bf16 input, non-128-aligned shapes, or missing / mis-shaped amax buffers. """ @@ -372,15 +390,8 @@ def nvfp4_per_token_encode( q_col, s_dec_col, col_amax_t = empty, empty, empty_f32 tex.nvfp4_per_token_encode( - x, - q_row, - s_dec_row, - row_amax_t, - q_col, - s_dec_col, - col_amax_t, - rowwise, - columnwise, + x, q_row, s_dec_row, row_amax_t, q_col, s_dec_col, col_amax_t, rowwise, columnwise, + with_rht=with_rht, random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, ) out = RefNVFP4TensorPerToken() diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py index 25c6a324ad..ba3aabeaa6 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py @@ -29,14 +29,20 @@ def _validate_per_token_group_input( ``(sum_M, K)``. """ if x_concat.ndim != 2: - raise ValueError(f"nvfp4_per_token_group_quantize expects a 2D input, got {x_concat.ndim}D") + raise ValueError( + f"nvfp4_per_token_group_quantize expects a 2D input, got {x_concat.ndim}D" + ) if not x_concat.is_contiguous(): raise ValueError("x_concat must be contiguous (row-major)") if x_concat.dtype != torch.bfloat16: - raise ValueError(f"Per-token grouped kernel is bf16-only; got dtype {x_concat.dtype}.") + raise ValueError( + f"Per-token grouped kernel is bf16-only; got dtype {x_concat.dtype}." + ) sum_M, K = x_concat.shape if K % _PER_TOKEN_TILE != 0: - raise ValueError(f"Per-token grouped kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}") + raise ValueError( + f"Per-token grouped kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}" + ) if len(split_sections) == 0: raise ValueError("split_sections must not be empty") if len(split_sections) > 64: @@ -48,22 +54,41 @@ def _validate_per_token_group_input( if M_i <= 0: raise ValueError(f"split_sections[{i}] must be > 0, got {M_i}") if M_i % _PER_TOKEN_TILE != 0: - raise ValueError(f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}") + raise ValueError( + f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}" + ) acc += M_i if acc != sum_M: - raise ValueError(f"sum(split_sections) = {acc} must equal input.size(0) = {sum_M}") + raise ValueError( + f"sum(split_sections) = {acc} must equal input.size(0) = {sum_M}" + ) return sum_M, K +# Default RHT sign-flip mask seed; matches the single-tensor wrapper. +_RHT_MASK_DEFAULT: int = 0xACE1 + + def nvfp4_per_token_group_quantize( x_concat: torch.Tensor, split_sections: Sequence[int], *, rowwise: bool = True, columnwise: bool = False, + with_rht: bool = False, + random_sign_mask_t: int = _RHT_MASK_DEFAULT, ) -> List[RefNVFP4TensorPerToken]: """Grouped NVFP4 per-token cast; returns N RefNVFP4TensorPerToken splits. + Args: + x_concat: (sum_M, K) bf16, row-major contiguous. + split_sections: per-split row counts (each a multiple of 128). + rowwise / columnwise: which directions to emit. + with_rht: True -> apply a 16-pt col-wise RHT in BOTH K1 and K2; + downstream GEMM must consume RHT-rotated weights to stay + unbiased. Rowwise never sees RHT. + random_sign_mask_t: low 16 bits = sign pattern shared by K1+K2. + Raises ``ValueError`` on shape / dtype / split-size violations. """ import transformer_engine_torch as tex # type: ignore @@ -83,7 +108,11 @@ def nvfp4_per_token_group_quantize( q_col_list, s_dec_col_list, col_amax_list, - ) = tex.nvfp4_per_token_group_quantize_bulk(x_concat, split_sections_list, rowwise, columnwise) + ) = tex.nvfp4_per_token_group_quantize_bulk( + x_concat, split_sections_list, rowwise, columnwise, + with_rht=bool(with_rht), + random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, + ) outs: List[RefNVFP4TensorPerToken] = [] for i in range(N): From 1f436839dd0ac6e6c4a019a39be8c6fc46be9c32 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 May 2026 08:25:16 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/nvfp4/bench_nvfp4_per_token.py | 233 +++++---- .../nvfp4/bench_nvfp4_per_token_group.py | 206 +++++--- tests/pytorch/nvfp4/test_nvfp4_per_token.py | 223 ++++++--- .../nvfp4/test_nvfp4_per_token_group.py | 147 ++++-- .../cast/nvfp4/quantize_nvfp4_per_token.cu | 449 ++++++++--------- .../nvfp4/quantize_nvfp4_per_token_group.cu | 451 ++++++++---------- .../transformer_engine/nvfp4_per_token.h | 40 +- transformer_engine/pytorch/csrc/extensions.h | 55 +-- .../csrc/extensions/nvfp4_per_token.cpp | 297 +++++------- .../pytorch/csrc/extensions/pybind.cpp | 55 +-- .../quantization_nvfp4_per_token.py | 65 ++- .../quantization_nvfp4_per_token_group.py | 25 +- 12 files changed, 1171 insertions(+), 1075 deletions(-) diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py index 1663dce1f8..7da8ec9519 100644 --- a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py @@ -50,9 +50,7 @@ def cuda_time_ms(fn: Callable[[], None], *, warmup: int = 5, iters: int = 50) -> return statistics.median(samples) -def cuda_graph_time_ms( - fn: Callable[[], object], *, warmup: int = 5, iters: int = 50 -) -> float: +def cuda_graph_time_ms(fn: Callable[[], object], *, warmup: int = 5, iters: int = 50) -> float: """Median g.replay() wall time of fn captured into a CUDA Graph (kernel-only floor). Returns nan if capture fails. @@ -111,12 +109,12 @@ def _has_sm100() -> bool: class ShapeBench: M: int K: int - t_pt: float # per-token full K1+K2, no RHT (Eager pybind, ms) - t_pt_rht: float # per-token full K1+K2, +RHT col-wise (Eager pybind, ms) - t_pten: float # per-tensor full K1+K2 with RHT+SR (Eager pybind, ms) - t_pt_g: float # per-token under CUDA Graphs replay (ms) - t_pt_rht_g: float # per-token+RHT under CUDA Graphs replay (ms) - t_pten_g: float # per-tensor under CUDA Graphs replay (ms) + t_pt: float # per-token full K1+K2, no RHT (Eager pybind, ms) + t_pt_rht: float # per-token full K1+K2, +RHT col-wise (Eager pybind, ms) + t_pten: float # per-tensor full K1+K2 with RHT+SR (Eager pybind, ms) + t_pt_g: float # per-token under CUDA Graphs replay (ms) + t_pt_rht_g: float # per-token+RHT under CUDA Graphs replay (ms) + t_pten_g: float # per-tensor under CUDA Graphs replay (ms) @dataclass @@ -124,9 +122,9 @@ class K1ShapeBench: M: int K: int # K1-only timings: 3 paths x 2 modes (Eager + CUDA Graphs). - t_pt: float # per-token K1, no RHT (rowwise+columnwise amax vectors) - t_pt_rht: float # per-token K1, +RHT on col direction - t_prod: float # prod K1 hadamard_transform_amax (per-tensor scalar amax) + t_pt: float # per-token K1, no RHT (rowwise+columnwise amax vectors) + t_pt_rht: float # per-token K1, +RHT on col direction + t_prod: float # prod K1 hadamard_transform_amax (per-tensor scalar amax) t_pt_g: float t_pt_rht_g: float t_prod_g: float @@ -136,9 +134,9 @@ class K1ShapeBench: _RHT_MASK_DEFAULT: int = 0xACE1 -def _bench_shape(M: int, K: int, *, device: torch.device, - with_rht: bool = False, - mask_t: int = _RHT_MASK_DEFAULT) -> ShapeBench: +def _bench_shape( + M: int, K: int, *, device: torch.device, with_rht: bool = False, mask_t: int = _RHT_MASK_DEFAULT +) -> ShapeBench: """Composite K1+K2 timing at one (M, K) shape. pt = per-token (no RHT), pt_rht = per-token + col-wise 16-pt RHT (NaN unless with_rht=True), pten = per-tensor + RHT + SR (prod baseline). @@ -162,8 +160,17 @@ def _baseline_quant_fn(): def _pt_full_quant_fn(): tex.nvfp4_per_token_quantize( - a, q_row_a, s_dec_row_a, ra_a, q_col_a, s_dec_col_a, ca_a, True, True, - with_rht=False, random_sign_mask_t=0, + a, + q_row_a, + s_dec_row_a, + ra_a, + q_col_a, + s_dec_col_a, + ca_a, + True, + True, + with_rht=False, + random_sign_mask_t=0, ) t_pten = cuda_time_ms(_baseline_quant_fn) @@ -172,10 +179,20 @@ def _pt_full_quant_fn(): t_pt_g = cuda_graph_time_ms(_pt_full_quant_fn) if with_rht: + def _pt_full_quant_rht_fn(): tex.nvfp4_per_token_quantize( - a, q_row_a, s_dec_row_a, ra_a, q_col_a, s_dec_col_a, ca_a, True, True, - with_rht=True, random_sign_mask_t=mask_t, + a, + q_row_a, + s_dec_row_a, + ra_a, + q_col_a, + s_dec_col_a, + ca_a, + True, + True, + with_rht=True, + random_sign_mask_t=mask_t, ) t_pt_rht = cuda_time_ms(_pt_full_quant_rht_fn) @@ -185,15 +202,20 @@ def _pt_full_quant_rht_fn(): t_pt_rht_g = float("nan") return ShapeBench( - M=M, K=K, - t_pt=t_pt, t_pt_rht=t_pt_rht, t_pten=t_pten, - t_pt_g=t_pt_g, t_pt_rht_g=t_pt_rht_g, t_pten_g=t_pten_g, + M=M, + K=K, + t_pt=t_pt, + t_pt_rht=t_pt_rht, + t_pten=t_pten, + t_pt_g=t_pt_g, + t_pt_rht_g=t_pt_rht_g, + t_pten_g=t_pten_g, ) -def _bench_shape_k1_only(M: int, K: int, *, device: torch.device, - with_rht: bool = False, - mask_t: int = _RHT_MASK_DEFAULT) -> K1ShapeBench: +def _bench_shape_k1_only( + M: int, K: int, *, device: torch.device, with_rht: bool = False, mask_t: int = _RHT_MASK_DEFAULT +) -> K1ShapeBench: """K1-only timing. pt = per-token (no RHT), pt_rht = per-token + col RHT (NaN unless with_rht=True), prod = hadamard_transform_amax (scalar amax; NOT apples-to-apples but the closest prod K1 reference). @@ -210,8 +232,13 @@ def _bench_shape_k1_only(M: int, K: int, *, device: torch.device, def _pt_k1_fn(): tex.nvfp4_per_token_amax( - a, ra_pt, ca_pt, True, True, - with_rht=False, random_sign_mask_t=0, + a, + ra_pt, + ca_pt, + True, + True, + with_rht=False, + random_sign_mask_t=0, ) def _prod_k1_fn(): @@ -229,8 +256,13 @@ def _prod_k1_fn(): def _pt_k1_rht_fn(): tex.nvfp4_per_token_amax( - a, ra_pt_rht, ca_pt_rht, True, True, - with_rht=True, random_sign_mask_t=mask_t, + a, + ra_pt_rht, + ca_pt_rht, + True, + True, + with_rht=True, + random_sign_mask_t=mask_t, ) t_pt_rht = cuda_time_ms(_pt_k1_rht_fn) @@ -240,18 +272,21 @@ def _pt_k1_rht_fn(): t_pt_rht_g = float("nan") return K1ShapeBench( - M=M, K=K, - t_pt=t_pt, t_pt_rht=t_pt_rht, t_prod=t_prod, - t_pt_g=t_pt_g, t_pt_rht_g=t_pt_rht_g, t_prod_g=t_prod_g, + M=M, + K=K, + t_pt=t_pt, + t_pt_rht=t_pt_rht, + t_prod=t_prod, + t_pt_g=t_pt_g, + t_pt_rht_g=t_pt_rht_g, + t_prod_g=t_prod_g, ) # 6x3 sweep matching bench_nvfp4_per_token_group.py: M in {1024..32768}, K in {2048,4096,8192}. _M_VALUES: Tuple[int, ...] = (1024, 2048, 4096, 8192, 16384, 32768) _K_VALUES: Tuple[int, ...] = (2048, 4096, 8192) -_DEFAULT_SHAPES: Tuple[Tuple[int, int], ...] = tuple( - (m, k) for m in _M_VALUES for k in _K_VALUES -) +_DEFAULT_SHAPES: Tuple[Tuple[int, int], ...] = tuple((m, k) for m in _M_VALUES for k in _K_VALUES) def _parse_shape(s: str) -> Tuple[int, int]: @@ -271,16 +306,12 @@ def _print_composite_table_2way(records: List[ShapeBench]) -> None: """2-way composite (no RHT). ratio = per-token / per-tensor (< 1.0 wins).""" w_pt, w_pten, w_ratio = 14, 15, 8 block_w = w_pt + 1 + w_pten + 1 + w_ratio - header1 = ( - f"{'':>7} {'':>6}" - f" |{'Eager, unit (ms)':^{block_w}}" - f" |{'Graph, unit (ms)':^{block_w}}" - ) + header1 = f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" header2 = ( f"{'M':>7} {'K':>6}" - f" |" + " |" f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" - f" |" + " |" f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" ) print(header1) @@ -299,9 +330,9 @@ def _fmt(r: float) -> str: print( f"{rec.M:>7} {rec.K:>6}" - f" |" + " |" f"{rec.t_pt:>{w_pt}.4f} {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" - f" |" + " |" f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" ) @@ -310,26 +341,22 @@ def _print_composite_table(records: List[ShapeBench]) -> None: """3-way composite (--rht). ratio = per-token (+rht) / per-tensor.""" w_pt, w_pt_rht, w_pten, w_ratio = 12, 12, 13, 8 block_w = w_pt + 1 + w_pt_rht + 1 + w_pten + 1 + w_ratio - header1 = ( - f"{'':>7} {'':>6}" - f" |{'Eager, unit (ms)':^{block_w}}" - f" |{'Graph, unit (ms)':^{block_w}}" - ) + header1 = f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" header2 = ( f"{'M':>7} {'K':>6}" - f" |" + " |" f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" - f" |" + " |" f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" ) header3 = ( f"{'':>7} {'':>6}" - f" |" + " |" f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" f" {'':>{w_pten}} {'':>{w_ratio}}" - f" |" + " |" f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" f" {'':>{w_pten}} {'':>{w_ratio}}" ) @@ -350,10 +377,10 @@ def _fmt(r: float) -> str: print( f"{rec.M:>7} {rec.K:>6}" - f" |" + " |" f"{rec.t_pt:>{w_pt}.4f} {rec.t_pt_rht:>{w_pt_rht}.4f}" f" {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" - f" |" + " |" f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pt_rht_g:>{w_pt_rht}.4f}" f" {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" ) @@ -366,9 +393,9 @@ def _print_k1_2way_table(records: List[K1ShapeBench]) -> None: print("K1-only: pt vs prod (NOT apples-to-apples; output shapes differ).") header = ( f"{'M':>7} {'K':>6}" - f" |" + " |" f"{'pt_K1':>9} {'prod_K1':>9} {'ratio':>8}" - f" |" + " |" f"{'pt_K1(Graph)':>14} {'prod_K1(Graph)':>16} {'ratio(Graph)':>13}" ) print(header) @@ -384,9 +411,9 @@ def _print_k1_2way_table(records: List[K1ShapeBench]) -> None: ratio_g_s = "nan" if math.isnan(ratio_g) else f"{ratio_g:.2f}x" print( f"{rec.M:>7} {rec.K:>6}" - f" |" + " |" f"{rec.t_pt:>9.4f} {rec.t_prod:>9.4f} {ratio_s:>8}" - f" |" + " |" f"{rec.t_pt_g:>14.4f} {rec.t_prod_g:>16.4f} {ratio_g_s:>13}" ) @@ -396,9 +423,9 @@ def _print_k1_rht_cost_table(records: List[K1ShapeBench]) -> None: print("Table A -- K1-only RHT cost (pt = per-token, +RHT = col-wise FHT).") header = ( f"{'M':>7} {'K':>6}" - f" |" + " |" f"{'pt_K1':>9} {'pt_K1+RHT':>11} {'ratio':>8}" - f" |" + " |" f"{'pt_K1(Graph)':>14} {'pt_K1+RHT(Graph)':>18} {'ratio(Graph)':>13}" ) print(header) @@ -414,9 +441,9 @@ def _print_k1_rht_cost_table(records: List[K1ShapeBench]) -> None: ratio_g_s = "nan" if math.isnan(ratio_g) else f"{ratio_g:.2f}x" print( f"{rec.M:>7} {rec.K:>6}" - f" |" + " |" f"{rec.t_pt:>9.4f} {rec.t_pt_rht:>11.4f} {ratio_s:>8}" - f" |" + " |" f"{rec.t_pt_g:>14.4f} {rec.t_pt_rht_g:>18.4f} {ratio_g_s:>13}" ) @@ -428,9 +455,9 @@ def _print_k1_vs_prod_table(records: List[K1ShapeBench]) -> None: print("Table B -- K1-only vs prod (NOT apples-to-apples; output shapes differ).") header = ( f"{'M':>7} {'K':>6}" - f" |" + " |" f"{'pt_K1+RHT':>11} {'prod_K1':>9} {'ratio':>8}" - f" |" + " |" f"{'pt_K1+RHT(Graph)':>18} {'prod_K1(Graph)':>16} {'ratio(Graph)':>13}" ) print(header) @@ -446,9 +473,9 @@ def _print_k1_vs_prod_table(records: List[K1ShapeBench]) -> None: ratio_g_s = "nan" if math.isnan(ratio_g) else f"{ratio_g:.2f}x" print( f"{rec.M:>7} {rec.K:>6}" - f" |" + " |" f"{rec.t_pt_rht:>11.4f} {rec.t_prod:>9.4f} {ratio_s:>8}" - f" |" + " |" f"{rec.t_pt_rht_g:>18.4f} {rec.t_prod_g:>16.4f} {ratio_g_s:>13}" ) @@ -461,25 +488,36 @@ def _print_composite_legend(*, with_rht: bool, rht_mask: int) -> None: print(" per-token (ms) = tex.nvfp4_per_token_quantize(a, ..., rowwise+colwise,") print(" with_rht=False)") print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") - print(f" per-token (+rht) (ms) = same, but with_rht=True + random_sign_mask_t=0x{rht_mask:04X}.") + print( + " per-token (+rht) (ms) = same, but with_rht=True +" + f" random_sign_mask_t=0x{rht_mask:04X}." + ) print(" Applies a 16-point RHT along the columnwise direction in") print(" BOTH K1 amax and K2 cast; rowwise stays raw. Length-16") print(" matches the 1x16 inner-SF block of NVFP4, so each scale") print(" window is decorrelated.") print(" per-tensor (ms) = tex.quantize(a, NVFP4Quantizer(rht+sr), ...)") print(" = nvte_quantize_with_hadamard_transform") - print(" (1 fused launch: rowwise quant + col-wise RHT + col quant,") + print( + " (1 fused launch: rowwise quant + col-wise RHT + col quant," + ) print(" prod baseline).") print(" ratio = per-token (+rht) / per-tensor") print(" ** < 1.0 = this PR wins vs prod baseline **") else: - print(" per-token (ms) = tex.nvfp4_per_token_quantize(a, ..., rowwise+colwise, with_rht=False)") + print( + " per-token (ms) = tex.nvfp4_per_token_quantize(a, ..., rowwise+colwise," + " with_rht=False)" + ) print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") print(" per-tensor (ms) = tex.quantize(a, NVFP4Quantizer(rht+sr), ...)") print(" = nvte_quantize_with_hadamard_transform") print(" (1 fused launch: rowwise quant + col-wise RHT + col quant,") print(" prod baseline).") - print(" ratio = per-token / per-tensor ** < 1.0 = per-token wins vs prod baseline **") + print( + " ratio = per-token / per-tensor ** < 1.0 = per-token wins vs prod" + " baseline **" + ) print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") @@ -488,28 +526,43 @@ def main() -> int: description="Benchmark NVFP4 per-token K1+K2 quant vs per-tensor production NVFP4." ) parser.add_argument( - "--shapes", type=_parse_shape, nargs="+", default=None, - help="Shapes to bench, in MxK form (e.g. 4096x4096). " - "Default: an internally-chosen production-shape sweep.", + "--shapes", + type=_parse_shape, + nargs="+", + default=None, + help=( + "Shapes to bench, in MxK form (e.g. 4096x4096). " + "Default: an internally-chosen production-shape sweep." + ), ) parser.add_argument( - "--rht", action="store_true", - help="Also time the per-token + RHT path (col-wise 16-pt RHT in K1 + K2). " - "Default OFF: prints a 2-way table (per-token vs per-tensor). With " - "--rht: prints a 3-way table with one ratio " - "(per-token (+rht) / per-tensor).", + "--rht", + action="store_true", + help=( + "Also time the per-token + RHT path (col-wise 16-pt RHT in K1 + K2). " + "Default OFF: prints a 2-way table (per-token vs per-tensor). With " + "--rht: prints a 3-way table with one ratio " + "(per-token (+rht) / per-tensor)." + ), ) parser.add_argument( - "--k1-only", action="store_true", - help="K1-only mode (no K2 cast). Without --rht: 2-way table (pt_K1 " - "vs prod_K1). With --rht: two tables back-to-back -- (A) RHT cost " - "pt_K1 vs pt_K1+RHT (apples-to-apples) and (B) pt_K1+RHT vs prod_K1 " - "(context only; output shapes differ).", + "--k1-only", + action="store_true", + help=( + "K1-only mode (no K2 cast). Without --rht: 2-way table (pt_K1 " + "vs prod_K1). With --rht: two tables back-to-back -- (A) RHT cost " + "pt_K1 vs pt_K1+RHT (apples-to-apples) and (B) pt_K1+RHT vs prod_K1 " + "(context only; output shapes differ)." + ), ) parser.add_argument( - "--rht-mask", type=lambda s: int(s, 0), default=_RHT_MASK_DEFAULT, - help="16-bit random sign mask for the RHT path (only matters with --rht). " - f"Default 0x{_RHT_MASK_DEFAULT:04X}; accepts hex (0x...) or decimal.", + "--rht-mask", + type=lambda s: int(s, 0), + default=_RHT_MASK_DEFAULT, + help=( + "16-bit random sign mask for the RHT path (only matters with --rht). " + f"Default 0x{_RHT_MASK_DEFAULT:04X}; accepts hex (0x...) or decimal." + ), ) args = parser.parse_args() @@ -523,8 +576,7 @@ def main() -> int: if args.k1_only: records_k1: List[K1ShapeBench] = [ - _bench_shape_k1_only(M, K, device=device, - with_rht=args.rht, mask_t=mask) + _bench_shape_k1_only(M, K, device=device, with_rht=args.rht, mask_t=mask) for (M, K) in shapes ] if args.rht: @@ -535,8 +587,7 @@ def main() -> int: _print_k1_2way_table(records_k1) else: records: List[ShapeBench] = [ - _bench_shape(M, K, device=device, with_rht=args.rht, mask_t=mask) - for (M, K) in shapes + _bench_shape(M, K, device=device, with_rht=args.rht, mask_t=mask) for (M, K) in shapes ] if args.rht: _print_composite_table(records) diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py index 1111382a19..d6f3a50da5 100644 --- a/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py @@ -50,9 +50,7 @@ def _make_baseline_quantizer_list(num_splits: int) -> List[NVFP4Quantizer]: return [q] * num_splits -def cuda_graph_time_ms( - fn: Callable[[], object], *, warmup: int = 5, iters: int = 50 -) -> float: +def cuda_graph_time_ms(fn: Callable[[], object], *, warmup: int = 5, iters: int = 50) -> float: """Median g.replay() time of fn under CUDA Graphs, in ms (nan on capture failure).""" try: side = torch.cuda.Stream() @@ -84,14 +82,26 @@ def cuda_graph_time_ms( _RHT_MASK_DEFAULT: int = 0xACE1 -def _time_grouped(x_concat, split_sections, rowwise, columnwise, - *, with_rht: bool = False, mask: int = _RHT_MASK_DEFAULT, - n_iters: int = 20, n_warmup: int = 5) -> float: +def _time_grouped( + x_concat, + split_sections, + rowwise, + columnwise, + *, + with_rht: bool = False, + mask: int = _RHT_MASK_DEFAULT, + n_iters: int = 20, + n_warmup: int = 5, +) -> float: """Per-token grouped via the BULK Python wrapper. Allocation in-loop.""" for _ in range(n_warmup): _ = nvfp4_per_token_group_quantize( - x_concat, split_sections, rowwise=rowwise, columnwise=columnwise, - with_rht=with_rht, random_sign_mask_t=mask, + x_concat, + split_sections, + rowwise=rowwise, + columnwise=columnwise, + with_rht=with_rht, + random_sign_mask_t=mask, ) torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) @@ -99,8 +109,12 @@ def _time_grouped(x_concat, split_sections, rowwise, columnwise, start.record() for _ in range(n_iters): _ = nvfp4_per_token_group_quantize( - x_concat, split_sections, rowwise=rowwise, columnwise=columnwise, - with_rht=with_rht, random_sign_mask_t=mask, + x_concat, + split_sections, + rowwise=rowwise, + columnwise=columnwise, + with_rht=with_rht, + random_sign_mask_t=mask, ) stop.record() torch.cuda.synchronize() @@ -122,23 +136,36 @@ def _time_split_quantize(x_concat, split_sections, quantizer_list, n_iters=20, n return start.elapsed_time(stop) / n_iters # ms -def _time_split_quantize_graph(x_concat, split_sections, quantizer_list, - n_iters=20, n_warmup=5): +def _time_split_quantize_graph(x_concat, split_sections, quantizer_list, n_iters=20, n_warmup=5): """Per-tensor grouped under CUDA Graphs replay.""" + def fn() -> None: _ = tex.split_quantize(x_concat, split_sections, quantizer_list) return cuda_graph_time_ms(fn, warmup=n_warmup, iters=n_iters) -def _time_grouped_graph(x_concat, split_sections, rowwise, columnwise, - *, with_rht: bool = False, mask: int = _RHT_MASK_DEFAULT, - n_iters: int = 20, n_warmup: int = 5) -> float: +def _time_grouped_graph( + x_concat, + split_sections, + rowwise, + columnwise, + *, + with_rht: bool = False, + mask: int = _RHT_MASK_DEFAULT, + n_iters: int = 20, + n_warmup: int = 5, +) -> float: """Per-token grouped under CUDA Graphs replay.""" + def fn() -> None: _ = nvfp4_per_token_group_quantize( - x_concat, split_sections, rowwise=rowwise, columnwise=columnwise, - with_rht=with_rht, random_sign_mask_t=mask, + x_concat, + split_sections, + rowwise=rowwise, + columnwise=columnwise, + with_rht=with_rht, + random_sign_mask_t=mask, ) return cuda_graph_time_ms(fn, warmup=n_warmup, iters=n_iters) @@ -175,12 +202,10 @@ def _build_bench_cases( if M_i % 128 != 0: raise argparse.ArgumentTypeError( f"sum_M={sum_M} / num_splits={num_splits} = M_i={M_i} must be a " - f"multiple of 128 (NVFP4 per-token kernel constraint)" + "multiple of 128 (NVFP4 per-token kernel constraint)" ) if K % 128 != 0: - raise argparse.ArgumentTypeError( - f"K={K} must be a multiple of 128" - ) + raise argparse.ArgumentTypeError(f"K={K} must be a multiple of 128") cases.append(([M_i] * num_splits, K)) return cases @@ -193,24 +218,40 @@ def main() -> int: ) ) parser.add_argument( - "--shapes", type=_parse_shape, nargs="+", default=None, - help="Shapes to bench, in sum_MxK form (e.g. 8192x4096). " - "Default: a 6x3 = 18-row internally-chosen sweep.", + "--shapes", + type=_parse_shape, + nargs="+", + default=None, + help=( + "Shapes to bench, in sum_MxK form (e.g. 8192x4096). " + "Default: a 6x3 = 18-row internally-chosen sweep." + ), ) parser.add_argument( - "--num-splits", type=int, default=_DEFAULT_NUM_SPLITS, - help=f"Number of equal splits per shape (default {_DEFAULT_NUM_SPLITS}; " - f"<= 64). M_i = sum_M / num_splits must be a multiple of 128.", + "--num-splits", + type=int, + default=_DEFAULT_NUM_SPLITS, + help=( + f"Number of equal splits per shape (default {_DEFAULT_NUM_SPLITS}; " + "<= 64). M_i = sum_M / num_splits must be a multiple of 128." + ), ) parser.add_argument( - "--rht", action="store_true", - help="Enable 3-way table with per-token + col-wise 16-pt RHT path. " - "Default OFF prints 2-way (per-token vs per-tensor).", + "--rht", + action="store_true", + help=( + "Enable 3-way table with per-token + col-wise 16-pt RHT path. " + "Default OFF prints 2-way (per-token vs per-tensor)." + ), ) parser.add_argument( - "--rht-mask", type=lambda s: int(s, 0), default=_RHT_MASK_DEFAULT, - help=f"16-bit RHT sign mask (default 0x{_RHT_MASK_DEFAULT:04X}; accepts " - "hex/dec). Only affects per-token+RHT; per-tensor uses its own mask.", + "--rht-mask", + type=lambda s: int(s, 0), + default=_RHT_MASK_DEFAULT, + help=( + f"16-bit RHT sign mask (default 0x{_RHT_MASK_DEFAULT:04X}; accepts " + "hex/dec). Only affects per-token+RHT; per-tensor uses its own mask." + ), ) args = parser.parse_args() @@ -228,21 +269,24 @@ def main() -> int: if args.shapes is not None: shapes_in = [tuple(s) for s in args.shapes] else: - shapes_in = [ - (sm, k) for sm in _DEFAULT_SUM_M_VALUES for k in _DEFAULT_K_VALUES - ] + shapes_in = [(sm, k) for sm in _DEFAULT_SUM_M_VALUES for k in _DEFAULT_K_VALUES] bench_cases = _build_bench_cases(shapes_in, args.num_splits) rht_mask: int = args.rht_mask & 0xFFFF with_rht: bool = args.rht device = torch.device("cuda") print(f"# Device: {torch.cuda.get_device_name(0)} (cap {cap[0]}.{cap[1]})") - print(f"# Split structure: N={args.num_splits} equal splits, " - f"M_i = sum_M / {args.num_splits}") + print(f"# Split structure: N={args.num_splits} equal splits, M_i = sum_M / {args.num_splits}") if with_rht: - print(f"# RHT mask: 0x{rht_mask:04X} (per-token+RHT col-wise; per-tensor uses its own internal mask)") + print( + f"# RHT mask: 0x{rht_mask:04X} (per-token+RHT col-wise; per-tensor uses its own" + " internal mask)" + ) else: - print("# RHT: disabled (pass --rht to enable 3-way per-token / per-token (+rht) / per-tensor table)") + print( + "# RHT: disabled (pass --rht to enable 3-way per-token / per-token (+rht) / per-tensor" + " table)" + ) print() # Per-tensor baseline quantizer is fixed to row+col, so both enabled. @@ -263,25 +307,23 @@ def _ratio(num: float, den: float) -> float: w_pt, w_pt_rht, w_pten, w_ratio = 12, 12, 13, 8 block_w = w_pt + 1 + w_pt_rht + 1 + w_pten + 1 + w_ratio header1 = ( - f"{'':>6} {'':>5}" - f" |{'Eager, unit (ms)':^{block_w}}" - f" |{'Graph, unit (ms)':^{block_w}}" + f"{'':>6} {'':>5} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" ) header2 = ( f"{'sum_M':>6} {'K':>5}" - f" |" + " |" f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" - f" |" + " |" f"{'per-token':>{w_pt}} {'per-token':>{w_pt_rht}}" f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" ) header3 = ( f"{'':>6} {'':>5}" - f" |" + " |" f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" f" {'':>{w_pten}} {'':>{w_ratio}}" - f" |" + " |" f"{'':>{w_pt}} {'(+rht)':>{w_pt_rht}}" f" {'':>{w_pten}} {'':>{w_ratio}}" ) @@ -292,15 +334,13 @@ def _ratio(num: float, den: float) -> float: w_pt, w_pten, w_ratio = 14, 15, 8 block_w = w_pt + 1 + w_pten + 1 + w_ratio header1 = ( - f"{'':>6} {'':>5}" - f" |{'Eager, unit (ms)':^{block_w}}" - f" |{'Graph, unit (ms)':^{block_w}}" + f"{'':>6} {'':>5} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" ) header2 = ( f"{'sum_M':>6} {'K':>5}" - f" |" + " |" f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" - f" |" + " |" f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" ) print(header1) @@ -317,27 +357,35 @@ def _ratio(num: float, den: float) -> float: print() prev_sum_M = sum_M - x_concat = ( - torch.randn((sum_M, K), dtype=torch.bfloat16, device=device) * 3.0 - ).contiguous() + x_concat = (torch.randn((sum_M, K), dtype=torch.bfloat16, device=device) * 3.0).contiguous() quantizer_list = _make_baseline_quantizer_list(num_splits) - t_pt = _time_grouped(x_concat, split_sections, rowwise, columnwise, - with_rht=False) + t_pt = _time_grouped(x_concat, split_sections, rowwise, columnwise, with_rht=False) t_pten = _time_split_quantize(x_concat, split_sections, quantizer_list) t_pt_g = _time_grouped_graph( - x_concat, split_sections, rowwise, columnwise, with_rht=False, + x_concat, + split_sections, + rowwise, + columnwise, + with_rht=False, ) t_pten_g = _time_split_quantize_graph( - x_concat, split_sections, quantizer_list, + x_concat, + split_sections, + quantizer_list, ) if with_rht: - t_pt_rht = _time_grouped(x_concat, split_sections, rowwise, columnwise, - with_rht=True, mask=rht_mask) + t_pt_rht = _time_grouped( + x_concat, split_sections, rowwise, columnwise, with_rht=True, mask=rht_mask + ) t_pt_rht_g = _time_grouped_graph( - x_concat, split_sections, rowwise, columnwise, - with_rht=True, mask=rht_mask, + x_concat, + split_sections, + rowwise, + columnwise, + with_rht=True, + mask=rht_mask, ) ratio_eager = _ratio(t_pt_rht, t_pten) @@ -345,10 +393,10 @@ def _ratio(num: float, den: float) -> float: print( f"{sum_M:>6d} {K:>5d}" - f" |" + " |" f"{t_pt:>{w_pt}.4f} {t_pt_rht:>{w_pt_rht}.4f}" f" {t_pten:>{w_pten}.4f} {_fmt(ratio_eager):>{w_ratio}}" - f" |" + " |" f"{t_pt_g:>{w_pt}.4f} {t_pt_rht_g:>{w_pt_rht}.4f}" f" {t_pten_g:>{w_pten}.4f} {_fmt(ratio_graph):>{w_ratio}}" ) @@ -357,9 +405,9 @@ def _ratio(num: float, den: float) -> float: ratio_graph = _ratio(t_pt_g, t_pten_g) print( f"{sum_M:>6d} {K:>5d}" - f" |" + " |" f"{t_pt:>{w_pt}.4f} {t_pten:>{w_pten}.4f} {_fmt(ratio_eager):>{w_ratio}}" - f" |" + " |" f"{t_pt_g:>{w_pt}.4f} {t_pten_g:>{w_pten}.4f} {_fmt(ratio_graph):>{w_ratio}}" ) @@ -372,24 +420,38 @@ def _ratio(num: float, den: float) -> float: print(" per-token (ms) = nvfp4_per_token_group_quantize(x, splits,") print(" rowwise+colwise, with_rht=False)") print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") - print(f" per-token (+rht) (ms) = same, but with_rht=True + random_sign_mask_t=0x{rht_mask:04X}.") + print( + " per-token (+rht) (ms) = same, but with_rht=True +" + f" random_sign_mask_t=0x{rht_mask:04X}." + ) print(" Applies a 16-point RHT along the columnwise direction in") print(" BOTH K1 amax and K2 cast; rowwise stays raw. Length-16") print(" matches the 1x16 inner-SF block of NVFP4, so each scale") print(" window is decorrelated.") - print(" per-tensor (ms) = tex.split_quantize(x, splits, [NVFP4Quantizer(rht+sr)]*N)") + print( + " per-tensor (ms) = tex.split_quantize(x, splits, [NVFP4Quantizer(rht+sr)]*N)" + ) print(" = nvte_group_hadamard_transform_amax") print(" + nvte_group_hadamard_transform_cast_fusion") print(" (2 launches, prod baseline).") print(" ratio = per-token (+rht) / per-tensor") print(" ** < 1.0 = this PR wins vs prod baseline **") else: - print(" per-token (ms) = nvfp4_per_token_group_quantize(x, splits, rowwise+colwise, with_rht=False)") + print( + " per-token (ms) = nvfp4_per_token_group_quantize(x, splits, rowwise+colwise," + " with_rht=False)" + ) print(" = K1 fused amax + K2 fused cast (2 launches), no RHT.") print(" per-tensor (ms) = tex.split_quantize(x, splits, [NVFP4Quantizer(rht+sr)]*N)") print(" = nvte_group_hadamard_transform_amax") - print(" + nvte_group_hadamard_transform_cast_fusion (2 launches, prod baseline).") - print(" ratio = per-token / per-tensor ** < 1.0 = per-token wins vs prod baseline **") + print( + " + nvte_group_hadamard_transform_cast_fusion (2 launches, prod" + " baseline)." + ) + print( + " ratio = per-token / per-tensor ** < 1.0 = per-token wins vs prod" + " baseline **" + ) print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") return 0 diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token.py b/tests/pytorch/nvfp4/test_nvfp4_per_token.py index 90c3f87592..cfe2a9f4aa 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_per_token.py +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token.py @@ -444,7 +444,8 @@ def _sign_diag_16(mask: int, device: torch.device) -> torch.Tensor: """16-elt +/-1 vector; s_i = -1 iff bit i of `mask` is set.""" bits = torch.tensor( [1 - 2 * ((mask >> i) & 1) for i in range(16)], - dtype=torch.float32, device=device, + dtype=torch.float32, + device=device, ) return bits @@ -474,18 +475,18 @@ def _reference_amax_raw(x_bf16: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso def _allocate_per_token_buffers(M: int, K: int, device: torch.device): """Match the layout that ``tex.nvfp4_per_token_quantize`` writes.""" return { - "q_row": torch.empty((M, K // 2), dtype=torch.uint8, device=device), - "s_row": torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device), - "ra": torch.empty((M,), dtype=torch.float32, device=device), - "q_col": torch.empty((K, M // 2), dtype=torch.uint8, device=device), - "s_col": torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device), - "ca": torch.empty((K,), dtype=torch.float32, device=device), + "q_row": torch.empty((M, K // 2), dtype=torch.uint8, device=device), + "s_row": torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device), + "ra": torch.empty((M,), dtype=torch.float32, device=device), + "q_col": torch.empty((K, M // 2), dtype=torch.uint8, device=device), + "s_col": torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device), + "ca": torch.empty((K,), dtype=torch.float32, device=device), } def _dequant_fp4_with_outer_amax( - q_packed: torch.Tensor, # (R, C // 2) uint8 packed FP4 - s_dec: torch.Tensor, # (R, C // 16) e4m3 held as uint8 + q_packed: torch.Tensor, # (R, C // 2) uint8 packed FP4 + s_dec: torch.Tensor, # (R, C // 16) e4m3 held as uint8 outer_amax: torch.Tensor, # (R,) fp32 ) -> torch.Tensor: """Decode a rowwise FP4 tensor back to fp32 using the kernel's own @@ -502,9 +503,9 @@ def _dequant_fp4_with_outer_amax( # NVFP4 E2M1 LUT (sign-magnitude): 0000..0111 map to {0, 0.5, 1, 1.5, # 2, 3, 4, 6}; 1000..1111 are the negatives. fp4_lut = torch.tensor( - [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], - dtype=torch.float32, device=q_packed.device, + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], + dtype=torch.float32, + device=q_packed.device, ) fp4_val = fp4_lut[interleaved.to(torch.int64)] @@ -522,6 +523,7 @@ def _dequant_fp4_with_outer_amax( # ----- (5a) K1 RHT: standalone amax kernel ---------------------------------- + @_GATED_SM100 @pytest.mark.parametrize("M,K", _RHT_SHAPES) def test_per_token_k1_with_rht_false_equals_raw_amax(M: int, K: int) -> None: @@ -534,22 +536,31 @@ def test_per_token_k1_with_rht_false_equals_raw_amax(M: int, K: int) -> None: col_amax = torch.empty((K,), dtype=torch.float32, device=device) tex.nvfp4_per_token_amax( - x, row_amax, col_amax, True, True, - with_rht=False, random_sign_mask_t=0, + x, + row_amax, + col_amax, + True, + True, + with_rht=False, + random_sign_mask_t=0, ) ref_row, ref_col = _reference_amax_raw(x) - torch.testing.assert_close(row_amax, ref_row, rtol=0.0, atol=0.0, - msg=f"row_amax mismatch at ({M}, {K})") - torch.testing.assert_close(col_amax, ref_col, rtol=0.0, atol=0.0, - msg=f"col_amax mismatch at ({M}, {K})") + torch.testing.assert_close( + row_amax, ref_row, rtol=0.0, atol=0.0, msg=f"row_amax mismatch at ({M}, {K})" + ) + torch.testing.assert_close( + col_amax, ref_col, rtol=0.0, atol=0.0, msg=f"col_amax mismatch at ({M}, {K})" + ) @_GATED_SM100 @pytest.mark.parametrize("M,K", _RHT_SHAPES) @pytest.mark.parametrize("mask", [0x0000, 0xACE1, 0xFFFF, 0x5A5A]) def test_per_token_k1_with_rht_matches_reference( - M: int, K: int, mask: int, + M: int, + K: int, + mask: int, ) -> None: """with_rht=True col_amax matches max|H*D*x_block|/4; rowwise stays raw.""" torch.manual_seed(0xDEAD * (M + 7) + (K + 3) + mask) @@ -560,19 +571,32 @@ def test_per_token_k1_with_rht_matches_reference( col_amax = torch.empty((K,), dtype=torch.float32, device=device) tex.nvfp4_per_token_amax( - x, row_amax, col_amax, True, True, - with_rht=True, random_sign_mask_t=mask, + x, + row_amax, + col_amax, + True, + True, + with_rht=True, + random_sign_mask_t=mask, ) ref_row, _ = _reference_amax_raw(x) - torch.testing.assert_close(row_amax, ref_row, rtol=0.0, atol=0.0, - msg=f"row_amax mismatch at ({M}, {K}, mask=0x{mask:04X})") + torch.testing.assert_close( + row_amax, + ref_row, + rtol=0.0, + atol=0.0, + msg=f"row_amax mismatch at ({M}, {K}, mask=0x{mask:04X})", + ) # Col tolerance accounts for bf16->fp32 promotion noise + butterfly # summation order vs. einsum reduction order. ref_col = _reference_col_amax_rht(x, mask) torch.testing.assert_close( - col_amax, ref_col, rtol=2e-3, atol=1e-4, + col_amax, + ref_col, + rtol=2e-3, + atol=1e-4, msg=f"col_amax (RHT) mismatch at ({M}, {K}, mask=0x{mask:04X})", ) @@ -589,8 +613,13 @@ def test_per_token_k1_with_rht_zero_mask_is_hadamard_only(M: int, K: int) -> Non col_amax = torch.empty((K,), dtype=torch.float32, device=device) tex.nvfp4_per_token_amax( - x, row_amax, col_amax, True, True, - with_rht=True, random_sign_mask_t=0, + x, + row_amax, + col_amax, + True, + True, + with_rht=True, + random_sign_mask_t=0, ) H = _walsh_hadamard_16(device) @@ -600,13 +629,17 @@ def test_per_token_k1_with_rht_zero_mask_is_hadamard_only(M: int, K: int) -> Non ref_col = (rotated.abs() / 4.0).reshape(-1, K).amax(dim=0) torch.testing.assert_close( - col_amax, ref_col, rtol=2e-3, atol=1e-4, + col_amax, + ref_col, + rtol=2e-3, + atol=1e-4, msg=f"col_amax (RHT, mask=0) mismatch at ({M}, {K})", ) # ----- (5b) K2 + composite RHT: encode kernel and composite quantize -------- + @_GATED_SM100 @pytest.mark.parametrize("M,K", _RHT_SHAPES) def test_per_token_composite_with_rht_false_byte_equal(M: int, K: int) -> None: @@ -619,21 +652,34 @@ def test_per_token_composite_with_rht_false_byte_equal(M: int, K: int) -> None: bufs_explicit = _allocate_per_token_buffers(M, K, device) tex.nvfp4_per_token_quantize( - x, bufs_default["q_row"], bufs_default["s_row"], bufs_default["ra"], - bufs_default["q_col"], bufs_default["s_col"], bufs_default["ca"], - True, True, + x, + bufs_default["q_row"], + bufs_default["s_row"], + bufs_default["ra"], + bufs_default["q_col"], + bufs_default["s_col"], + bufs_default["ca"], + True, + True, ) tex.nvfp4_per_token_quantize( - x, bufs_explicit["q_row"], bufs_explicit["s_row"], bufs_explicit["ra"], - bufs_explicit["q_col"], bufs_explicit["s_col"], bufs_explicit["ca"], - True, True, - with_rht=False, random_sign_mask_t=0xACE1, + x, + bufs_explicit["q_row"], + bufs_explicit["s_row"], + bufs_explicit["ra"], + bufs_explicit["q_col"], + bufs_explicit["s_col"], + bufs_explicit["ca"], + True, + True, + with_rht=False, + random_sign_mask_t=0xACE1, ) for k in ("q_row", "s_row", "ra", "q_col", "s_col", "ca"): - assert torch.equal(bufs_default[k], bufs_explicit[k]), ( - f"with_rht=False not byte-equal to default path on `{k}` at ({M}, {K})" - ) + assert torch.equal( + bufs_default[k], bufs_explicit[k] + ), f"with_rht=False not byte-equal to default path on `{k}` at ({M}, {K})" @_GATED_SM100 @@ -648,16 +694,30 @@ def test_per_token_composite_rowwise_unchanged_under_rht(M: int, K: int) -> None bufs_with_rht = _allocate_per_token_buffers(M, K, device) tex.nvfp4_per_token_quantize( - x, bufs_no_rht["q_row"], bufs_no_rht["s_row"], bufs_no_rht["ra"], - bufs_no_rht["q_col"], bufs_no_rht["s_col"], bufs_no_rht["ca"], - True, True, - with_rht=False, random_sign_mask_t=0, + x, + bufs_no_rht["q_row"], + bufs_no_rht["s_row"], + bufs_no_rht["ra"], + bufs_no_rht["q_col"], + bufs_no_rht["s_col"], + bufs_no_rht["ca"], + True, + True, + with_rht=False, + random_sign_mask_t=0, ) tex.nvfp4_per_token_quantize( - x, bufs_with_rht["q_row"], bufs_with_rht["s_row"], bufs_with_rht["ra"], - bufs_with_rht["q_col"], bufs_with_rht["s_col"], bufs_with_rht["ca"], - True, True, - with_rht=True, random_sign_mask_t=0xACE1, + x, + bufs_with_rht["q_row"], + bufs_with_rht["s_row"], + bufs_with_rht["ra"], + bufs_with_rht["q_col"], + bufs_with_rht["s_col"], + bufs_with_rht["ca"], + True, + True, + with_rht=True, + random_sign_mask_t=0xACE1, ) for k in ("q_row", "s_row", "ra"): @@ -671,7 +731,9 @@ def test_per_token_composite_rowwise_unchanged_under_rht(M: int, K: int) -> None @pytest.mark.parametrize("M,K", [(128, 128), (256, 512), (512, 512)]) @pytest.mark.parametrize("mask", [0x0000, 0xACE1, 0xFFFF]) def test_per_token_composite_with_rht_col_dequant_matches_reference( - M: int, K: int, mask: int, + M: int, + K: int, + mask: int, ) -> None: """Dequant'd col FP4 (with_rht=True) ~ H*D*x_block/sqrt(16); checks column-aggregate median + p99 relative error (FP4's 16-code grain and @@ -685,10 +747,17 @@ def test_per_token_composite_with_rht_col_dequant_matches_reference( bufs = _allocate_per_token_buffers(M, K, device) tex.nvfp4_per_token_quantize( - x, bufs["q_row"], bufs["s_row"], bufs["ra"], - bufs["q_col"], bufs["s_col"], bufs["ca"], - True, True, - with_rht=True, random_sign_mask_t=mask, + x, + bufs["q_row"], + bufs["s_row"], + bufs["ra"], + bufs["q_col"], + bufs["s_col"], + bufs["ca"], + True, + True, + with_rht=True, + random_sign_mask_t=mask, ) H = _walsh_hadamard_16(device) @@ -696,12 +765,14 @@ def test_per_token_composite_with_rht_col_dequant_matches_reference( x_fp32 = x.to(torch.float32) blocks = x_fp32.reshape(M // 16, 16, K) masked = blocks * sign.view(1, 16, 1) - rotated = torch.einsum("ij,bjk->bik", H, masked) # (M/16, 16, K) - y_ref = rotated.reshape(M, K) / 4.0 # (M, K) - y_ref_col_view = y_ref.transpose(0, 1).contiguous() # (K, M) + rotated = torch.einsum("ij,bjk->bik", H, masked) # (M/16, 16, K) + y_ref = rotated.reshape(M, K) / 4.0 # (M, K) + y_ref_col_view = y_ref.transpose(0, 1).contiguous() # (K, M) y_kernel = _dequant_fp4_with_outer_amax( - bufs["q_col"], bufs["s_col"], bufs["ca"], + bufs["q_col"], + bufs["s_col"], + bufs["ca"], ) # (K, M) diff = (y_kernel - y_ref_col_view).abs() @@ -713,16 +784,16 @@ def test_per_token_composite_with_rht_col_dequant_matches_reference( f"median per-element relative error too large: {median:.4f} > 0.1 " f"at ({M}, {K}, mask=0x{mask:04X})" ) - assert p99 < 0.5, ( - f"p99 per-element relative error too large: {p99:.4f} > 0.5 " - f"at ({M}, {K}, mask=0x{mask:04X})" - ) + assert ( + p99 < 0.5 + ), f"p99 per-element relative error too large: {p99:.4f} > 0.5 at ({M}, {K}, mask=0x{mask:04X})" @_GATED_SM100 @pytest.mark.parametrize("M,K", [(128, 128), (256, 256)]) def test_per_token_composite_with_rht_col_amax_matches_k1( - M: int, K: int, + M: int, + K: int, ) -> None: """Composite col_amax byte-equals standalone K1 amax with the same mask.""" torch.manual_seed(0xDADA * (M + 13) + K) @@ -732,20 +803,34 @@ def test_per_token_composite_with_rht_col_amax_matches_k1( bufs = _allocate_per_token_buffers(M, K, device) tex.nvfp4_per_token_quantize( - x, bufs["q_row"], bufs["s_row"], bufs["ra"], - bufs["q_col"], bufs["s_col"], bufs["ca"], - True, True, - with_rht=True, random_sign_mask_t=mask, + x, + bufs["q_row"], + bufs["s_row"], + bufs["ra"], + bufs["q_col"], + bufs["s_col"], + bufs["ca"], + True, + True, + with_rht=True, + random_sign_mask_t=mask, ) ra_k1 = torch.empty((M,), dtype=torch.float32, device=device) ca_k1 = torch.empty((K,), dtype=torch.float32, device=device) tex.nvfp4_per_token_amax( - x, ra_k1, ca_k1, True, True, - with_rht=True, random_sign_mask_t=mask, + x, + ra_k1, + ca_k1, + True, + True, + with_rht=True, + random_sign_mask_t=mask, ) - torch.testing.assert_close(bufs["ca"], ca_k1, rtol=0.0, atol=0.0, - msg=f"composite ca != K1-only ca at ({M}, {K})") - torch.testing.assert_close(bufs["ra"], ra_k1, rtol=0.0, atol=0.0, - msg=f"composite ra != K1-only ra at ({M}, {K})") + torch.testing.assert_close( + bufs["ca"], ca_k1, rtol=0.0, atol=0.0, msg=f"composite ca != K1-only ca at ({M}, {K})" + ) + torch.testing.assert_close( + bufs["ra"], ra_k1, rtol=0.0, atol=0.0, msg=f"composite ra != K1-only ra at ({M}, {K})" + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py b/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py index 5dd2588021..21fedabaab 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token_group.py @@ -369,22 +369,22 @@ def test_group_many_splits_byte_equal(n_splits: int) -> None: # ============================================================================= _RHT_GROUP_SHAPES: List[Tuple[List[int], int]] = [ - ([128, 128], 128), # 2 splits, smallest legal shape - ([128, 256, 128], 256), # 3 splits, mixed sizes - ([256, 256, 256, 256], 512), # 4 equal splits, larger K - ([128, 384], 128), # 2 splits, very asymmetric + ([128, 128], 128), # 2 splits, smallest legal shape + ([128, 256, 128], 256), # 3 splits, mixed sizes + ([256, 256, 256, 256], 512), # 4 equal splits, larger K + ([128, 384], 128), # 2 splits, very asymmetric ] def _rht_pt_buffers(M: int, K: int, device: torch.device): """Match the layout that ``tex.nvfp4_per_token_quantize`` writes.""" return { - "q_row": torch.empty((M, K // 2), dtype=torch.uint8, device=device), - "s_row": torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device), - "ra": torch.empty((M,), dtype=torch.float32, device=device), - "q_col": torch.empty((K, M // 2), dtype=torch.uint8, device=device), - "s_col": torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device), - "ca": torch.empty((K,), dtype=torch.float32, device=device), + "q_row": torch.empty((M, K // 2), dtype=torch.uint8, device=device), + "s_row": torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device), + "ra": torch.empty((M,), dtype=torch.float32, device=device), + "q_col": torch.empty((K, M // 2), dtype=torch.uint8, device=device), + "s_col": torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device), + "ca": torch.empty((K,), dtype=torch.float32, device=device), } @@ -399,7 +399,8 @@ def _split_views(x_concat: torch.Tensor, splits: Sequence[int]) -> List[torch.Te @_GATED_FP4 @pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES) def test_group_with_rht_false_byte_equal_to_default( - splits: List[int], K: int, + splits: List[int], + K: int, ) -> None: """Regression: with_rht=False grouped byte-equals the default (no-kwargs) path.""" torch.manual_seed(0xCAFE * (sum(splits) + 1) + K + len(splits)) @@ -408,17 +409,30 @@ def test_group_with_rht_false_byte_equal_to_default( x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() outs_default = nvfp4_per_token_group_quantize( - x, splits, rowwise=True, columnwise=True, + x, + splits, + rowwise=True, + columnwise=True, ) outs_explicit_false = nvfp4_per_token_group_quantize( - x, splits, rowwise=True, columnwise=True, - with_rht=False, random_sign_mask_t=0xACE1, + x, + splits, + rowwise=True, + columnwise=True, + with_rht=False, + random_sign_mask_t=0xACE1, ) assert len(outs_default) == len(outs_explicit_false) == len(splits) for i, (a, b) in enumerate(zip(outs_default, outs_explicit_false)): - for attr in ("data", "scale", "row_amax", - "columnwise_data", "columnwise_scale", "col_amax"): + for attr in ( + "data", + "scale", + "row_amax", + "columnwise_data", + "columnwise_scale", + "col_amax", + ): ta, tb = getattr(a, attr), getattr(b, attr) assert torch.equal(ta, tb), ( f"split[{i}].{attr} differs between default and explicit " @@ -429,7 +443,8 @@ def test_group_with_rht_false_byte_equal_to_default( @_GATED_FP4 @pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES) def test_group_rowwise_unchanged_under_rht( - splits: List[int], K: int, + splits: List[int], + K: int, ) -> None: """Rowwise outputs byte-equal across with_rht=False / True.""" torch.manual_seed(0xBEEF * (sum(splits) + 3) + K) @@ -438,12 +453,20 @@ def test_group_rowwise_unchanged_under_rht( x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() outs_no_rht = nvfp4_per_token_group_quantize( - x, splits, rowwise=True, columnwise=True, - with_rht=False, random_sign_mask_t=0, + x, + splits, + rowwise=True, + columnwise=True, + with_rht=False, + random_sign_mask_t=0, ) outs_with_rht = nvfp4_per_token_group_quantize( - x, splits, rowwise=True, columnwise=True, - with_rht=True, random_sign_mask_t=0xACE1, + x, + splits, + rowwise=True, + columnwise=True, + with_rht=True, + random_sign_mask_t=0xACE1, ) for i, (a, b) in enumerate(zip(outs_no_rht, outs_with_rht)): @@ -452,7 +475,7 @@ def test_group_rowwise_unchanged_under_rht( assert torch.equal(ta, tb), ( f"split[{i}].{attr} differs between with_rht=False and =True " f"on the ROW direction at K={K}, splits={splits} -- " - f"rowwise should never see RHT." + "rowwise should never see RHT." ) @@ -460,7 +483,9 @@ def test_group_rowwise_unchanged_under_rht( @pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES) @pytest.mark.parametrize("mask", [0x0000, 0xACE1, 0xFFFF]) def test_group_with_rht_equals_single_tensor_per_split( - splits: List[int], K: int, mask: int, + splits: List[int], + K: int, + mask: int, ) -> None: """Each split's 6 outputs byte-equal single-tensor with the same mask.""" torch.manual_seed(0xDADA * (sum(splits) + 11) + K + mask) @@ -469,8 +494,12 @@ def test_group_with_rht_equals_single_tensor_per_split( x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() outs_grouped = nvfp4_per_token_group_quantize( - x, splits, rowwise=True, columnwise=True, - with_rht=True, random_sign_mask_t=mask, + x, + splits, + rowwise=True, + columnwise=True, + with_rht=True, + random_sign_mask_t=mask, ) x_splits = _split_views(x, splits) @@ -478,19 +507,26 @@ def test_group_with_rht_equals_single_tensor_per_split( M_i = x_i.size(0) bufs = _rht_pt_buffers(M_i, K, device) tex.nvfp4_per_token_quantize( - x_i, bufs["q_row"], bufs["s_row"], bufs["ra"], - bufs["q_col"], bufs["s_col"], bufs["ca"], - True, True, - with_rht=True, random_sign_mask_t=mask, + x_i, + bufs["q_row"], + bufs["s_row"], + bufs["ra"], + bufs["q_col"], + bufs["s_col"], + bufs["ca"], + True, + True, + with_rht=True, + random_sign_mask_t=mask, ) mapping = { - "data": ("q_row", out_g.data), - "scale": ("s_row", out_g.scale.view(torch.uint8)), - "row_amax": ("ra", out_g.row_amax), - "columnwise_data": ("q_col", out_g.columnwise_data), - "columnwise_scale": ("s_col", out_g.columnwise_scale.view(torch.uint8)), - "col_amax": ("ca", out_g.col_amax), + "data": ("q_row", out_g.data), + "scale": ("s_row", out_g.scale.view(torch.uint8)), + "row_amax": ("ra", out_g.row_amax), + "columnwise_data": ("q_col", out_g.columnwise_data), + "columnwise_scale": ("s_col", out_g.columnwise_scale.view(torch.uint8)), + "col_amax": ("ca", out_g.col_amax), } for attr, (single_key, grouped_t) in mapping.items(): single_t = bufs[single_key] @@ -507,7 +543,8 @@ def test_group_with_rht_equals_single_tensor_per_split( @_GATED_FP4 @pytest.mark.parametrize("splits,K", _RHT_GROUP_SHAPES[:2]) def test_group_k1_amax_matches_single_tensor_per_split_under_rht( - splits: List[int], K: int, + splits: List[int], + K: int, ) -> None: """Grouped K1 amax byte-equals single-tensor K1 per split. Isolates K1 via the lighter nvfp4_per_token_group_amax binding to catch K1-vs-K2 @@ -519,16 +556,17 @@ def test_group_k1_amax_matches_single_tensor_per_split_under_rht( x = torch.randn((sum_M, K), dtype=torch.bfloat16, device=device).contiguous() mask = 0xACE1 - row_amax_list = [ - torch.empty((int(s),), dtype=torch.float32, device=device) for s in splits - ] - col_amax_list = [ - torch.empty((K,), dtype=torch.float32, device=device) for _ in splits - ] + row_amax_list = [torch.empty((int(s),), dtype=torch.float32, device=device) for s in splits] + col_amax_list = [torch.empty((K,), dtype=torch.float32, device=device) for _ in splits] tex.nvfp4_per_token_group_amax( - x, [int(s) for s in splits], row_amax_list, col_amax_list, - True, True, - with_rht=True, random_sign_mask_t=mask, + x, + [int(s) for s in splits], + row_amax_list, + col_amax_list, + True, + True, + with_rht=True, + random_sign_mask_t=mask, ) x_splits = _split_views(x, splits) @@ -537,10 +575,17 @@ def test_group_k1_amax_matches_single_tensor_per_split_under_rht( ra_s = torch.empty((M_i,), dtype=torch.float32, device=device) ca_s = torch.empty((K,), dtype=torch.float32, device=device) tex.nvfp4_per_token_amax( - x_i, ra_s, ca_s, True, True, - with_rht=True, random_sign_mask_t=mask, + x_i, + ra_s, + ca_s, + True, + True, + with_rht=True, + random_sign_mask_t=mask, + ) + torch.testing.assert_close( + ra_g, ra_s, rtol=0.0, atol=0.0, msg=f"split[{i}] row_amax mismatch (K1 only)" + ) + torch.testing.assert_close( + ca_g, ca_s, rtol=0.0, atol=0.0, msg=f"split[{i}] col_amax mismatch (K1 only)" ) - torch.testing.assert_close(ra_g, ra_s, rtol=0.0, atol=0.0, - msg=f"split[{i}] row_amax mismatch (K1 only)") - torch.testing.assert_close(ca_g, ca_s, rtol=0.0, atol=0.0, - msg=f"split[{i}] col_amax mismatch (K1 only)") diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu index 03ef547834..a6f08e9f74 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu @@ -29,11 +29,11 @@ #include +#include "common/cast/core/common.cuh" +#include "common/cast/nvfp4/core_nvfp4.cuh" #include "common/common.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" -#include "common/cast/core/common.cuh" -#include "common/cast/nvfp4/core_nvfp4.cuh" namespace transformer_engine { namespace nvfp4_per_token { @@ -41,73 +41,72 @@ namespace nvfp4_per_token { #if FP4_TYPE_SUPPORTED using dispatch::common::align_smem_ptr_per_TMA_requirements; +using dispatch::nvfp4::nvfp4_scale_t; using dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; using dispatch::nvfp4::quantization_SF::compute_decoding_scaling_factor; -using dispatch::nvfp4::nvfp4_scale_t; -constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows of input -constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols of input -constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height -constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width -constexpr int THREADS_NUM = 128; // threads per CTA -constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM -constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) -constexpr int PREFETCH_STAGES = 1; // 1-stage prefetch overlap +constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows of input +constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols of input +constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height +constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width +constexpr int THREADS_NUM = 128; // threads per CTA +constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM +constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) +constexpr int PREFETCH_STAGES = 1; // 1-stage prefetch overlap constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; // = 2 ping-pong input buffers // Derived (chunk / tile / stage) constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; // 2 constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X; // 2 -constexpr int STAGES = TILES_Y * TILES_X; // 4 +constexpr int STAGES = TILES_Y * TILES_X; // 4 constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; // 8 inner blocks per row of the chunk constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; // 8 inner blocks per col of the chunk -constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 -constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 // Encode helpers' thread layout (rowwise pass: 4x32 = K-dim x M-dim) -constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 -constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 -constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 (each block owned by 1 thread) +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 +constexpr int THREADS_PER_SCALE_ROWWISE = + SCALE_DIM / ELTS_PER_THREAD; // 1 (each block owned by 1 thread) constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 // Buffer dimensions (input bf16 SMEM tiles + FP4 output SMEM tiles for TMA store) constexpr int BUFF_IN_DIM_Y = TILE_DIM_Y; constexpr int BUFF_IN_DIM_X = TILE_DIM_X; -constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; // elements +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; // elements constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; -constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (2 fp4 per byte) -constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; +constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (2 fp4 per byte) +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; constexpr int BUFF_OUT_TR_DIM_Y = TILE_DIM_X; -constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 -constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; -constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 ping-pong (matches input) -constexpr int BUFFS_NUM_OUT_TR = 2; // 2 ping-pong for transpose +constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 ping-pong (matches input) +constexpr int BUFFS_NUM_OUT_TR = 2; // 2 ping-pong for transpose // Manual swizzling parameters to reduce SMEM bank conflicts on rowwise loads constexpr int PACK_SIZE = 8; -constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 -constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 -using IType = bf16; +using IType = bf16; using IType2 = ptx::FPx2; // = ptx::bf16x2 -using IType3D = IType [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; -using IType2x3D = IType2 [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using IType3D = IType[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +using IType2x3D = IType2[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; -using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; // Compute the per-block (1x16) byte-equal arithmetic and emit FP4 codes into // SMEM rowwise output buffer + e4m3 scale into SMEM rowwise scale buffer. __device__ __forceinline__ void rowwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, - fp4e2m1x2* __restrict__ sOut_ptr, + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_ptr, nvfp4_scale_t* __restrict__ sSFrowwise_ptr, - const float* __restrict__ sRowAmax, // [CHUNK_DIM_Y], indexed by chunk-local row - const int stage_Y, const int stage_X, - const int buff_in, const int buff_out) { + const float* __restrict__ sRowAmax, // [CHUNK_DIM_Y], indexed by chunk-local row + const int stage_Y, const int stage_X, const int buff_in, const int buff_out) { const auto& sIn = *reinterpret_cast(sIn_ptr); auto& sOut = *reinterpret_cast(sOut_ptr); auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); @@ -115,12 +114,14 @@ __device__ __forceinline__ void rowwise_scaling_per_token( const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; - const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 - const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 - const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; // K-elt offset in tile (0/16/32/48) + const int thread_offset_X_rowwise = + tid_X_rowwise * ELTS_PER_THREAD; // K-elt offset in tile (0/16/32/48) - const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; // = tid_X_rowwise here + const int SF_thread_offset_rowwise_X = + tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; // = tid_X_rowwise here const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); const int stage_rowwise_scales_offset_X = @@ -152,8 +153,8 @@ __device__ __forceinline__ void rowwise_scaling_per_token( ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); } } - const float block_amax = static_cast( - __hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + const float block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); // Byte-equal compute path (matches the Python reference in // ``NVFP4QuantizerPerTokenRef``): @@ -202,8 +203,7 @@ __device__ __forceinline__ void rowwise_scaling_per_token( // Output is NOT normalized; caller multiplies by k16HadamardNorm (0.25). // Sign-flip is a branchless XOR on the fp32 sign bit (bit-exact == r = -r on // finite fp32, which is all this helper sees from bf16 SMEM reads). -__device__ __forceinline__ void apply_signed_fht16_inplace( - float r[16], uint32_t random_sign_mask) { +__device__ __forceinline__ void apply_signed_fht16_inplace(float r[16], uint32_t random_sign_mask) { #pragma unroll for (int i = 0; i < 16; ++i) { const uint32_t bits = __float_as_uint(r[i]); @@ -218,8 +218,8 @@ __device__ __forceinline__ void apply_signed_fht16_inplace( for (int j = 0; j < stride; ++j) { const float a = r[g + j]; const float b = r[g + j + stride]; - r[g + j] = a + b; - r[g + j + stride] = a - b; + r[g + j] = a + b; + r[g + j + stride] = a - b; } } } @@ -242,34 +242,34 @@ constexpr float k16HadamardNorm = 0.25f; // sColAmax already reflects the rotated columns. template __device__ __forceinline__ void colwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, - fp4e2m1x2* __restrict__ sOut_tr_ptr, + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_tr_ptr, nvfp4_scale_t* __restrict__ sSFcolwise_ptr, - const float* __restrict__ sColAmax, // [CHUNK_DIM_X], indexed by chunk-local col - const int stage_Y, const int stage_X, - const int buff_in, const int buff_out_tr, + const float* __restrict__ sColAmax, // [CHUNK_DIM_X], indexed by chunk-local col + const int stage_Y, const int stage_X, const int buff_in, const int buff_out_tr, const uint32_t random_sign_mask_t = 0u) { const auto& sIn2x = *reinterpret_cast(sIn_ptr); auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); - const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 + const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 (M-block index in tile) - const int tid_X_colwise = thread_lane; // 0..31 (col-pair index in tile) + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 (M-block index in tile) + const int tid_X_colwise = thread_lane; // 0..31 (col-pair index in tile) - const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; // 0/16/32/48 - const int thread_offset_X_colwise = tid_X_colwise * 2; // 0/2/.../62 (2 cols per thread) + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; // 0/16/32/48 + const int thread_offset_X_colwise = tid_X_colwise * 2; // 0/2/.../62 (2 cols per thread) const int in_thread_offset_Y = thread_offset_Y_colwise; - const int in_thread_offset_X = thread_offset_X_colwise / 2; // index into IType2[] + const int in_thread_offset_X = thread_offset_X_colwise / 2; // index into IType2[] - const int out_tr_thread_offset_Y = thread_offset_X_colwise; // transpose: X becomes Y - const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; // /2 for fp4e2m1x2 byte index + const int out_tr_thread_offset_Y = thread_offset_X_colwise; // transpose: X becomes Y + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; // /2 for fp4e2m1x2 byte index - const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; // chunk-local col index (×1) - const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; // chunk-local M-block index + const int scale_tr_offset_Y = + (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; // chunk-local col index (×1) + const int scale_tr_offset_X = + (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; // chunk-local M-block index __align__(8) IType rIn[2][SCALE_DIM]; // RHT staging in fp32 from FHT through mul_cvt_4x: avoids the lossy @@ -366,18 +366,15 @@ __device__ __forceinline__ void colwise_scaling_per_token( // RHT-rotated columnwise_amax. Row direction never sees RHT. // ============================================================================= template -__global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( - const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - const __grid_constant__ CUtensorMap tensor_map_output_t, - nvfp4_scale_t* const scales_ptr, - nvfp4_scale_t* const scales_t_ptr, - const float* const row_amax_in, - const float* const col_amax_in, - const float* noop, - const size_t rows, const size_t cols, - const size_t scale_stride, const size_t scale_stride_t, - const uint32_t random_sign_mask_t) { +__global__ void __launch_bounds__(THREADS_NUM) + per_token_encode_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t* const scales_ptr, nvfp4_scale_t* const scales_t_ptr, + const float* const row_amax_in, const float* const col_amax_in, + const float* noop, const size_t rows, const size_t cols, + const size_t scale_stride, const size_t scale_stride_t, + const uint32_t random_sign_mask_t) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -408,24 +405,26 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( constexpr int out_mem_colwise_data = DO_COL ? buff_size_aligned_out_t : 0; constexpr int out_mem_rowwise_scales = DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) : 0; + TMA_SHMEM_ALIGNMENT) + : 0; constexpr int out_mem_colwise_scales = DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) : 0; + TMA_SHMEM_ALIGNMENT) + : 0; extern __shared__ unsigned char dynamic_shmem[]; unsigned char* dshmem = align_smem_ptr_per_TMA_requirements(dynamic_shmem); - IType* sIn_ptr = reinterpret_cast(dshmem); - fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); - fp4e2m1x2* sOut_tr_ptr = reinterpret_cast( - dshmem + buff_size_aligned_in + out_mem_rowwise_data); + IType* sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); + fp4e2m1x2* sOut_tr_ptr = + reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data); nvfp4_scale_t* sSFrowwise_ptr = reinterpret_cast( dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data); - nvfp4_scale_t* sSFcolwise_ptr = reinterpret_cast( - dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data - + out_mem_rowwise_scales); + nvfp4_scale_t* sSFcolwise_ptr = + reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data + + out_mem_colwise_data + out_mem_rowwise_scales); // Per-CTA row/col amax SMEM cache (128 floats each). __shared__ float sRowAmax[CHUNK_DIM_Y]; @@ -479,8 +478,8 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( uint64_t* dst = reinterpret_cast(&sIn[buff_in]); const uint64_t* src = reinterpret_cast(&tensor_map_input); ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); - ptx::cp_async_bulk_tensor_2d_global_to_shared( - dst, src, global_offset_X, global_offset_Y, &IN_buff_readable_mbar[buff_in]); + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + &IN_buff_readable_mbar[buff_in]); } } @@ -508,18 +507,17 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( if (leading_thread) { uint64_t* dst = reinterpret_cast(&sIn[next_prefetch_buff]); const uint64_t* src = reinterpret_cast(&tensor_map_input); - ptx::mbarrier_arrive_expect_tx( - &IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); - ptx::cp_async_bulk_tensor_2d_global_to_shared( - dst, src, next_global_offset_X, next_global_offset_Y, - &IN_buff_readable_mbar[next_prefetch_buff]); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, next_global_offset_X, + next_global_offset_Y, + &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for current stage's input to land. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta( - &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // Wait for any prior TMA store to have finished reading the output SMEM @@ -528,14 +526,12 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( // ----- Compute: rowwise + colwise from the same SMEM tile ----- if (DO_ROW) { - rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, - sRowAmax, stage_Y, stage_X, buff_in, buff_out); + rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, sRowAmax, stage_Y, stage_X, + buff_in, buff_out); } if (DO_COL) { - colwise_scaling_per_token( - sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, - sColAmax, stage_Y, stage_X, buff_in, buff_out_tr, - random_sign_mask_t); + colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, sColAmax, stage_Y, + stage_X, buff_in, buff_out_tr, random_sign_mask_t); } // Fence + sync so all threads' SMEM writes are visible to TMA store. @@ -552,22 +548,20 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( if (DO_ROW) { auto& sOut = *reinterpret_cast(sOut_ptr); ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), - global_offset_X, global_offset_Y, + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, reinterpret_cast(&sOut[buff_out])); } if (DO_COL) { auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_t), - global_offset_X_tr, global_offset_Y_tr, - reinterpret_cast(&sOut_tr[buff_out_tr])); + reinterpret_cast(&tensor_map_output_t), global_offset_X_tr, + global_offset_Y_tr, reinterpret_cast(&sOut_tr[buff_out_tr])); } ptx::cp_async_bulk_commit_group(); } - buff_in = (buff_in + 1) % BUFFS_NUM; - buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_in = (buff_in + 1) % BUFFS_NUM; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; } // end of stages @@ -583,8 +577,7 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( const size_t row_global = scales_block_offset_Y_rowwise + row; if (row_global < rows) { ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); - const size_t scale_idx_global = - row_global * scale_stride + scales_block_offset_X_rowwise; + const size_t scale_idx_global = row_global * scale_stride + scales_block_offset_X_rowwise; scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); } } @@ -599,8 +592,7 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; if (row_tr_global < cols) { ScalesVec& scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); - const size_t scale_idx_global = - row_tr_global * scale_stride_t + scales_block_offset_X_tr; + const size_t scale_idx_global = row_tr_global * scale_stride_t + scales_block_offset_X_tr; scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count); } } @@ -636,13 +628,12 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel( // FHT with random_sign_mask_t). Row direction never sees RHT. // ============================================================================= template -__global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( - const __grid_constant__ CUtensorMap tensor_map_input, - float* __restrict__ row_amax_out, // [M], nullptr if !DO_ROW - float* __restrict__ col_amax_out, // [K], nullptr if !DO_COL - const float* noop, - const size_t rows, const size_t cols, - const uint32_t random_sign_mask_t) { // col-only; low 16 bits = signs +__global__ void __launch_bounds__(THREADS_NUM) + per_token_amax_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + float* __restrict__ row_amax_out, // [M], nullptr if !DO_ROW + float* __restrict__ col_amax_out, // [K], nullptr if !DO_COL + const float* noop, const size_t rows, const size_t cols, + const uint32_t random_sign_mask_t) { // col-only; low 16 bits = signs #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -677,8 +668,8 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( // i.e., this thread contributes to row partial in stages // where stage_Y == tid / 64. // col owned: col_base + tid -> stage_X == tid / 64. - const int my_row_stage_Y = tid / TILE_DIM_Y; // 0 or 1 - const int my_col_stage_X = tid / TILE_DIM_X; // 0 or 1 + const int my_row_stage_Y = tid / TILE_DIM_Y; // 0 or 1 + const int my_col_stage_X = tid / TILE_DIM_X; // 0 or 1 const int my_row_in_subtile = tid % TILE_DIM_Y; // 0..63 const int my_col_in_subtile = tid % TILE_DIM_X; // 0..63 @@ -703,8 +694,8 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[buff_in]), - reinterpret_cast(&tensor_map_input), - global_offset_X, global_offset_Y, &IN_buff_readable_mbar[buff_in]); + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + &IN_buff_readable_mbar[buff_in]); } } @@ -725,20 +716,18 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; if (leading_thread) { - ptx::mbarrier_arrive_expect_tx( - &IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[next_prefetch_buff]), - reinterpret_cast(&tensor_map_input), - next_global_offset_X, next_global_offset_Y, - &IN_buff_readable_mbar[next_prefetch_buff]); + reinterpret_cast(&tensor_map_input), next_global_offset_X, + next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for this stage's tile. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta( - &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // ----- Row partial update: walk this thread's row across the sub-tile ----- @@ -762,8 +751,8 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( for (int p = 0; p < 4; ++p) { ptx::abs_max_2x(amax_2x, amax_2x, pairs[p]); } - local_max = fmaxf(local_max, - static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); + local_max = + fmaxf(local_max, static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); } row_partial = local_max; } @@ -832,10 +821,8 @@ __global__ void __launch_bounds__(THREADS_NUM) per_token_amax_kernel( // Launch Kernel 1 (amax). Pre-zeroes the amax buffers (atomicMax identity). // with_rht=true applies a 16-pt RHT on the col direction before amax; // random_sign_mask_t carries the 16-bit sign pattern (ignored when false). -inline void launch_amax(const Tensor& input, Tensor* output, - const Tensor& noop, - const bool with_rht, - const uint32_t random_sign_mask_t, +inline void launch_amax(const Tensor& input, Tensor* output, const Tensor& noop, + const bool with_rht, const uint32_t random_sign_mask_t, cudaStream_t stream) { const size_t M = input.flat_first_dim(); const size_t K = input.flat_last_dim(); @@ -846,50 +833,46 @@ inline void launch_amax(const Tensor& input, Tensor* output, // Pre-zero amax buffers (atomicMaxFloat identity for non-negative values). if (do_row) { - NVTE_CHECK(output->amax.numel() == M, - "Per-token amax: output->amax numel must equal M = ", M, + NVTE_CHECK(output->amax.numel() == M, "Per-token amax: output->amax numel must equal M = ", M, ", got ", output->amax.numel()); NVTE_CHECK_CUDA(cudaMemsetAsync(output->amax.dptr, 0, M * sizeof(float), stream)); } if (do_col) { NVTE_CHECK(output->columnwise_amax.numel() == K, - "Per-token amax: output->columnwise_amax numel must equal K = ", K, - ", got ", output->columnwise_amax.numel()); + "Per-token amax: output->columnwise_amax numel must equal K = ", K, ", got ", + output->columnwise_amax.numel()); NVTE_CHECK_CUDA(cudaMemsetAsync(output->columnwise_amax.dptr, 0, K * sizeof(float), stream)); } checkCuDriverContext(stream); alignas(64) CUtensorMap tmap_in{}; - create_2D_tensor_map(tmap_in, input.data, M, K, - TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); + create_2D_tensor_map(tmap_in, input.data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; constexpr int buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; // + align pad - dim3 grid(static_cast(K / CHUNK_DIM_X), - static_cast(M / CHUNK_DIM_Y), 1); + dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); - const float* noop_ptr = (noop.data.dptr != nullptr) - ? reinterpret_cast(noop.data.dptr) - : nullptr; + const float* noop_ptr = + (noop.data.dptr != nullptr) ? reinterpret_cast(noop.data.dptr) : nullptr; // RHT only matters when colwise amax is computed; collapse to the // kWithRht=false instantiation otherwise. const bool with_rht_effective = with_rht && do_col; - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, - TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_col, DO_COL, TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { auto kernel = per_token_amax_kernel; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); kernel<<>>( - tmap_in, - do_row ? reinterpret_cast(output->amax.dptr) : nullptr, - do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr, - noop_ptr, M, K, random_sign_mask_t); + tmap_in, do_row ? reinterpret_cast(output->amax.dptr) : nullptr, + do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr, noop_ptr, + M, K, random_sign_mask_t); }))); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -899,10 +882,8 @@ inline void launch_amax(const Tensor& input, Tensor* output, // output->data / scale_inv / columnwise_data / columnwise_scale_inv. // with_rht=true requires K1 amax to have been launched with the SAME mask; // the composite per_token_quantize path threads this automatically. -inline void launch_encode(const Tensor& input, Tensor* output, - const Tensor& noop, - const bool with_rht, - const uint32_t random_sign_mask_t, +inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noop, + const bool with_rht, const uint32_t random_sign_mask_t, cudaStream_t stream) { const size_t M = input.flat_first_dim(); const size_t K = input.flat_last_dim(); @@ -934,15 +915,13 @@ inline void launch_encode(const Tensor& input, Tensor* output, alignas(64) CUtensorMap tmap_out{}; alignas(64) CUtensorMap tmap_out_t{}; - create_2D_tensor_map(tmap_in, input.data, M, K, - TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); + create_2D_tensor_map(tmap_in, input.data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(IType) * 8); if (do_row) { - create_2D_tensor_map(tmap_out, output->data, M, K, - TILE_DIM_Y, TILE_DIM_X, K, 0, 4); + create_2D_tensor_map(tmap_out, output->data, M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, 4); } if (do_col) { - create_2D_tensor_map(tmap_out_t, output->columnwise_data, K, M, - TILE_DIM_X, TILE_DIM_Y, M, 0, 4); + create_2D_tensor_map(tmap_out_t, output->columnwise_data, K, M, TILE_DIM_X, TILE_DIM_Y, M, 0, + 4); } constexpr int buff_elems_total_in = BUFFS_NUM * BUFF_IN_SIZE; @@ -952,29 +931,21 @@ inline void launch_encode(const Tensor& input, Tensor* output, DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); constexpr int buff_size_aligned_out_t = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_scales = - DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_scales_t = - DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales = DIVUP_TO_MULTIPLE( + CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_t = DIVUP_TO_MULTIPLE( + CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); // Total dyn SMEM: input + output FP4 (row + col) + SF (row + col) + 128B align. - const int dshmem_size = - buff_size_aligned_in - + (do_row ? buff_size_aligned_out : 0) - + (do_col ? buff_size_aligned_out_t : 0) - + (do_row ? buff_size_scales : 0) - + (do_col ? buff_size_scales_t : 0) - + TMA_SHMEM_ALIGNMENT; - - dim3 grid(static_cast(K / CHUNK_DIM_X), - static_cast(M / CHUNK_DIM_Y), 1); + const int dshmem_size = buff_size_aligned_in + (do_row ? buff_size_aligned_out : 0) + + (do_col ? buff_size_aligned_out_t : 0) + (do_row ? buff_size_scales : 0) + + (do_col ? buff_size_scales_t : 0) + TMA_SHMEM_ALIGNMENT; + + dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); - const float* noop_ptr = (noop.data.dptr != nullptr) - ? reinterpret_cast(noop.data.dptr) - : nullptr; + const float* noop_ptr = + (noop.data.dptr != nullptr) ? reinterpret_cast(noop.data.dptr) : nullptr; const size_t scale_stride = do_row ? output->scale_inv.shape[1] : 0; const size_t scale_stride_t = do_col ? output->columnwise_scale_inv.shape[1] : 0; @@ -982,25 +953,22 @@ inline void launch_encode(const Tensor& input, Tensor* output, do_row ? reinterpret_cast(output->scale_inv.dptr) : nullptr; nvfp4_scale_t* scales_t_ptr = do_col ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const float* row_amax_in = - do_row ? reinterpret_cast(output->amax.dptr) : nullptr; + const float* row_amax_in = do_row ? reinterpret_cast(output->amax.dptr) : nullptr; const float* col_amax_in = do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr; // RHT only matters when colwise FP4 is produced; collapse to the // kWithRht=false instantiation for rowwise-only callers. const bool with_rht_effective = with_rht && do_col; - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, - TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_col, DO_COL, TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { auto kernel = per_token_encode_kernel; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); kernel<<>>( - tmap_in, tmap_out, tmap_out_t, - scales_ptr, scales_t_ptr, - row_amax_in, col_amax_in, - noop_ptr, M, K, scale_stride, scale_stride_t, - random_sign_mask_t); + tmap_in, tmap_out, tmap_out_t, scales_ptr, scales_t_ptr, row_amax_in, col_amax_in, + noop_ptr, M, K, scale_stride, scale_stride_t, random_sign_mask_t); }))); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -1017,15 +985,14 @@ inline void launch_encode(const Tensor& input, Tensor* output, // Output constraints differ by entry point (see validate_*_output helpers below). inline void validate_input_shape(const Tensor& input) { NVTE_CHECK(input.has_data(), "Per-token cast: input has no data."); - NVTE_CHECK(input.dtype() == DType::kBFloat16, - "Per-token cast is bf16-only. Got dtype enum ", + NVTE_CHECK(input.dtype() == DType::kBFloat16, "Per-token cast is bf16-only. Got dtype enum ", static_cast(input.dtype())); const size_t M = input.flat_first_dim(); const size_t K = input.flat_last_dim(); - NVTE_CHECK(M % CHUNK_DIM_Y == 0, - "Per-token cast: M must be a multiple of ", CHUNK_DIM_Y, ", got M=", M); - NVTE_CHECK(K % CHUNK_DIM_X == 0, - "Per-token cast: K must be a multiple of ", CHUNK_DIM_X, ", got K=", K); + NVTE_CHECK(M % CHUNK_DIM_Y == 0, "Per-token cast: M must be a multiple of ", CHUNK_DIM_Y, + ", got M=", M); + NVTE_CHECK(K % CHUNK_DIM_X == 0, "Per-token cast: K must be a multiple of ", CHUNK_DIM_X, + ", got K=", K); } // K1 (amax-only) requires at least one amax buffer allocated; FP4 output is not used. @@ -1046,10 +1013,8 @@ inline void validate_encode_output(const Tensor* output) { // K1 amax with optional col-wise RHT. with_rht=false is byte-equal to the // pre-RHT per-token K1 path regardless of random_sign_mask_t. -void per_token_amax_blocked_impl(const Tensor& input, const Tensor& noop, - Tensor* output, - const bool with_rht, - const uint32_t random_sign_mask_t, +void per_token_amax_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, + const bool with_rht, const uint32_t random_sign_mask_t, cudaStream_t stream) { validate_input_shape(input); validate_amax_output(output); @@ -1060,10 +1025,8 @@ void per_token_amax_blocked_impl(const Tensor& input, const Tensor& noop, // K2 encode with optional col-wise RHT. Caller must have filled // output->columnwise_amax via K1 amax with the SAME with_rht/mask, else the // inner SF + FP4 codes are calibrated against mismatched data and saturate. -void per_token_encode_blocked_impl(const Tensor& input, const Tensor& noop, - Tensor* output, - const bool with_rht, - const uint32_t random_sign_mask_t, +void per_token_encode_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, + const bool with_rht, const uint32_t random_sign_mask_t, cudaStream_t stream) { validate_input_shape(input); validate_encode_output(output); @@ -1073,10 +1036,8 @@ void per_token_encode_blocked_impl(const Tensor& input, const Tensor& noop, // Composite K1+K2. Both launches receive the same with_rht / mask so the // colwise amax and FP4 cast see byte-identical data. -void per_token_quantize_blocked_impl(const Tensor& input, const Tensor& noop, - Tensor* output, - const bool with_rht, - const uint32_t random_sign_mask_t, +void per_token_quantize_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, + const bool with_rht, const uint32_t random_sign_mask_t, cudaStream_t stream) { validate_input_shape(input); validate_encode_output(output); @@ -1089,16 +1050,16 @@ bool can_use_per_token(size_t M, size_t K, DType dtype) { return (dtype == DType::kBFloat16) && (M % CHUNK_DIM_Y == 0) && (K % CHUNK_DIM_X == 0); } #else // !FP4_TYPE_SUPPORTED -void per_token_amax_blocked_impl(const Tensor&, const Tensor&, Tensor*, - bool, uint32_t, cudaStream_t) { +void per_token_amax_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, uint32_t, + cudaStream_t) { NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); } -void per_token_encode_blocked_impl(const Tensor&, const Tensor&, Tensor*, - bool, uint32_t, cudaStream_t) { +void per_token_encode_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, uint32_t, + cudaStream_t) { NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); } -void per_token_quantize_blocked_impl(const Tensor&, const Tensor&, Tensor*, - bool, uint32_t, cudaStream_t) { +void per_token_quantize_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, uint32_t, + cudaStream_t) { NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); } bool can_use_per_token(size_t, size_t, DType) { return false; } @@ -1111,10 +1072,8 @@ bool can_use_per_token(size_t, size_t, DType) { return false; } // C-API entry points // ============================================================================= -void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, - NVTETensor output, - const int with_rht, - const int random_sign_mask_t, +void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NVTETensor output, + const int with_rht, const int random_sign_mask_t, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_amax); @@ -1126,22 +1085,22 @@ void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, // C-API takes `int` to match prod's nvte_hadamard_transform_amax convention; // internally we treat the low 16 bits as a uint32_t bitmask. nvfp4_per_token::per_token_amax_blocked_impl( - *input_tensor, *noop_tensor, output_tensor, - with_rht != 0, - static_cast(random_sign_mask_t) & 0xFFFFu, - stream); + *input_tensor, *noop_tensor, output_tensor, with_rht != 0, + static_cast(random_sign_mask_t) & 0xFFFFu, stream); #else - (void)input; (void)noop; (void)output; (void)with_rht; - (void)random_sign_mask_t; (void)stream; + (void)input; + (void)noop; + (void)output; + (void)with_rht; + (void)random_sign_mask_t; + (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif } -void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, - NVTETensor output, - const int with_rht, - const int random_sign_mask_t, - cudaStream_t stream) { +void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output, + const int with_rht, const int random_sign_mask_t, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_encode); using namespace transformer_engine; @@ -1152,22 +1111,22 @@ void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, // C-API mirrors nvte_nvfp4_per_token_amax: `int` for cross-language ABI // safety, internal kernel arg is uint32_t with only the low 16 bits used. nvfp4_per_token::per_token_encode_blocked_impl( - *input_tensor, *noop_tensor, output_tensor, - with_rht != 0, - static_cast(random_sign_mask_t) & 0xFFFFu, - stream); + *input_tensor, *noop_tensor, output_tensor, with_rht != 0, + static_cast(random_sign_mask_t) & 0xFFFFu, stream); #else - (void)input; (void)noop; (void)output; (void)with_rht; - (void)random_sign_mask_t; (void)stream; + (void)input; + (void)noop; + (void)output; + (void)with_rht; + (void)random_sign_mask_t; + (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif } -void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, - NVTETensor output, - const int with_rht, - const int random_sign_mask_t, - cudaStream_t stream) { +void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output, + const int with_rht, const int random_sign_mask_t, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_quantize); using namespace transformer_engine; @@ -1176,13 +1135,15 @@ void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop Tensor dummy_noop; const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; nvfp4_per_token::per_token_quantize_blocked_impl( - *input_tensor, *noop_tensor, output_tensor, - with_rht != 0, - static_cast(random_sign_mask_t) & 0xFFFFu, - stream); + *input_tensor, *noop_tensor, output_tensor, with_rht != 0, + static_cast(random_sign_mask_t) & 0xFFFFu, stream); #else - (void)input; (void)noop; (void)output; (void)with_rht; - (void)random_sign_mask_t; (void)stream; + (void)input; + (void)noop; + (void)output; + (void)with_rht; + (void)random_sign_mask_t; + (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif } diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu index 9eb049443b..a0be8d184a 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu @@ -30,9 +30,9 @@ namespace nvfp4_per_token_group { #if FP4_TYPE_SUPPORTED +using dispatch::nvfp4::nvfp4_scale_t; using dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; using dispatch::nvfp4::quantization_SF::compute_decoding_scaling_factor; -using dispatch::nvfp4::nvfp4_scale_t; using ptx::FPx2; constexpr int kInnerK = 16; // NVFP4 inner block: 16 elements per e4m3 SF @@ -43,14 +43,14 @@ constexpr int kMaxTensorsPerKernel = 64; // Per-launch arg table; passed as __grid_constant__ for constant-cache reads. struct NVFP4PerTokenMultiArgs { // K1 outputs (per-tensor pointers; one fp32 array per tensor) - void* row_amax_list[kMaxTensorsPerKernel]; // each: float* (M_i,) - void* col_amax_list[kMaxTensorsPerKernel]; // each: float* (K,) + void* row_amax_list[kMaxTensorsPerKernel]; // each: float* (M_i,) + void* col_amax_list[kMaxTensorsPerKernel]; // each: float* (K,) // K2 outputs (per-tensor pointers; FP4 codes + e4m3 inner SF) - void* q_row_list[kMaxTensorsPerKernel]; // each: uint8* (M_i, K/2) - void* s_dec_row_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (M_i, K/16) - void* q_col_list[kMaxTensorsPerKernel]; // each: uint8* (K, M_i/2) - void* s_dec_col_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (K, M_i/16) + void* q_row_list[kMaxTensorsPerKernel]; // each: uint8* (M_i, K/2) + void* s_dec_row_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (M_i, K/16) + void* q_col_list[kMaxTensorsPerKernel]; // each: uint8* (K, M_i/2) + void* s_dec_col_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (K, M_i/16) // Shared layout info int split_sections_range[kMaxTensorsPerKernel + 1]; // prefix sum w/ leading 0 @@ -69,10 +69,10 @@ __device__ __forceinline__ int GetTensorId(const NVFP4PerTokenMultiArgs& args, i // per-tensor buffer via tensor_id lookup at CTA entry. namespace fused { -constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows -constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols -constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height -constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width +constexpr int CHUNK_DIM_Y = 128; // CTA covers this many rows +constexpr int CHUNK_DIM_X = 128; // CTA covers this many cols +constexpr int TILE_DIM_Y = 64; // TMA bulk-2D box height +constexpr int TILE_DIM_X = 64; // TMA bulk-2D box width constexpr int THREADS_NUM = 128; constexpr int PREFETCH_STAGES = 1; constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; @@ -92,8 +92,7 @@ using FusedIType3D = FusedIType[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; // of the single-tensor helpers in quantize_nvfp4_per_token.cu; K1 and K2 // must consume identical output for FP4 + outer SF to be self-consistent. // TODO: hoist into a shared core header. -__device__ __forceinline__ void apply_signed_fht16_inplace( - float r[16], uint32_t random_sign_mask) { +__device__ __forceinline__ void apply_signed_fht16_inplace(float r[16], uint32_t random_sign_mask) { #pragma unroll for (int i = 0; i < 16; ++i) { const uint32_t bits = __float_as_uint(r[i]); @@ -108,8 +107,8 @@ __device__ __forceinline__ void apply_signed_fht16_inplace( for (int j = 0; j < stride; ++j) { const float a = r[g + j]; const float b = r[g + j + stride]; - r[g + j] = a + b; - r[g + j + stride] = a - b; + r[g + j] = a + b; + r[g + j + stride] = a - b; } } } @@ -127,15 +126,14 @@ constexpr float k16HadamardNorm = 0.25f; // Pre-zero amax buffers (identity for atomicMax). template -__global__ void group_per_token_fused_zero_amax_kernel(NVFP4PerTokenMultiArgs args, - int K) { +__global__ void group_per_token_fused_zero_amax_kernel(NVFP4PerTokenMultiArgs args, int K) { const int tensor_id = blockIdx.x; if (tensor_id >= args.num_tensors) return; if (DO_ROW) { float* row_amax = reinterpret_cast(args.row_amax_list[tensor_id]); if (row_amax != nullptr) { - const int M_i = args.split_sections_range[tensor_id + 1] - - args.split_sections_range[tensor_id]; + const int M_i = + args.split_sections_range[tensor_id + 1] - args.split_sections_range[tensor_id]; for (int m = threadIdx.x; m < M_i; m += blockDim.x) { row_amax[m] = 0.0f; } @@ -157,8 +155,7 @@ template __global__ void __launch_bounds__(THREADS_NUM) group_per_token_fused_amax_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ NVFP4PerTokenMultiArgs args, - const float* noop, const size_t rows, - const size_t cols, + const float* noop, const size_t rows, const size_t cols, const uint32_t random_sign_mask_t) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { @@ -173,8 +170,7 @@ __global__ void __launch_bounds__(THREADS_NUM) DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); extern __shared__ unsigned char dynamic_shmem[]; - unsigned char* dshmem = - dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + unsigned char* dshmem = dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); FusedIType* sIn_ptr = reinterpret_cast(dshmem); auto& sIn = *reinterpret_cast(sIn_ptr); @@ -189,10 +185,8 @@ __global__ void __launch_bounds__(THREADS_NUM) // Tile lies fully inside one tensor (split_sections[i] % 128 == 0). const int tensor_id = GetTensorId(args, block_offset_Y); const int local_row_base = block_offset_Y - args.split_sections_range[tensor_id]; - float* row_amax_out = - DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; - float* col_amax_out = - DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; + float* row_amax_out = DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; + float* col_amax_out = DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; // Each thread owns chunk-row `tid` (for row amax) and chunk-col `tid` (for col amax). float row_partial = 0.f; @@ -223,8 +217,8 @@ __global__ void __launch_bounds__(THREADS_NUM) ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[buff_in]), - reinterpret_cast(&tensor_map_input), global_offset_X, - global_offset_Y, &IN_buff_readable_mbar[buff_in]); + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + &IN_buff_readable_mbar[buff_in]); } } @@ -245,20 +239,18 @@ __global__ void __launch_bounds__(THREADS_NUM) const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; if (leading_thread) { - ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], - shmem_buff_size); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[next_prefetch_buff]), - reinterpret_cast(&tensor_map_input), - next_global_offset_X, next_global_offset_Y, - &IN_buff_readable_mbar[next_prefetch_buff]); + reinterpret_cast(&tensor_map_input), next_global_offset_X, + next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for this stage's tile. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta( - &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // Row partial: rotate e-iter by bank group to split warp into 8 groups. @@ -275,8 +267,8 @@ __global__ void __launch_bounds__(THREADS_NUM) for (int p = 0; p < 4; ++p) { ptx::abs_max_2x(amax_2x, amax_2x, pairs[p]); } - local_max = fmaxf(local_max, static_cast( - __hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); + local_max = + fmaxf(local_max, static_cast(__hmax(__habs(amax_2x.x), __habs(amax_2x.y)))); } row_partial = local_max; } @@ -338,52 +330,49 @@ __global__ void __launch_bounds__(THREADS_NUM) // K2 (encode) constants + helpers; byte-equal port of the single-tensor // per-token cooperative 4x32 / 32x4 threading + ld_shared_b128 + mul_cvt_4x. -constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM -constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) +constexpr int ELTS_PER_THREAD = 16; // = NVFP4 block size = SCALE_DIM +constexpr int SCALE_DIM = 16; // NVFP4 inner block (1x16) constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; // 8 constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; // 8 -constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 -constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 4 +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; // 4 // Rowwise pass: 4 (K-dim) x 32 (M-dim) -> 1 NVFP4 block per thread. -constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 -constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 -constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 -constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; // 4 +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 32 +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; // 1 +constexpr int ITERATIONS_NORMAL = TILE_DIM_Y / THREADS_Y_ROWWISE; // 2 // Output / SF SMEM buffer dims (sub-tile sized, double-buffered for ping-pong). -constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; -constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (fp4e2m1x2 bytes) -constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; +constexpr int BUFF_OUT_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (TILE_DIM_X * 4) / 8; // 32 (fp4e2m1x2 bytes) +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; constexpr int BUFF_OUT_TR_DIM_Y = TILE_DIM_X; -constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 -constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; -constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 -constexpr int BUFFS_NUM_OUT_TR = 2; +constexpr int BUFF_OUT_TR_DIM_X = (TILE_DIM_Y * 4) / 8; // 32 +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; // 2 +constexpr int BUFFS_NUM_OUT_TR = 2; // Manual SMEM swizzling parameters (matches single-tensor encode kernel). -constexpr int PACK_SIZE = 8; -constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 -constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 -constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 - -using IType = FusedIType; -using IType2 = FusedIType2; -using IType2x3D = IType2 [BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; -using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; // 2 +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; // 16 + +using IType = FusedIType; +using IType2 = FusedIType2; +using IType2x3D = IType2[BUFFS_NUM][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; -using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; // Rowwise encode helper: reads sRowAmax (pre-populated by K1), writes FP4 + // e4m3 SFs into sOut / sSFrowwise. Byte-equal to the single-tensor version. __device__ __forceinline__ void rowwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, - fp4e2m1x2* __restrict__ sOut_ptr, - nvfp4_scale_t* __restrict__ sSFrowwise_ptr, - const float* __restrict__ sRowAmax, - const int stage_Y, const int stage_X, - const int buff_in, const int buff_out) { + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_ptr, + nvfp4_scale_t* __restrict__ sSFrowwise_ptr, const float* __restrict__ sRowAmax, + const int stage_Y, const int stage_X, const int buff_in, const int buff_out) { const auto& sIn = *reinterpret_cast(sIn_ptr); auto& sOut = *reinterpret_cast(sOut_ptr); auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); @@ -391,8 +380,8 @@ __device__ __forceinline__ void rowwise_scaling_per_token( const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; - const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 - const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; // 0..31 + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; // 0..3 const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; @@ -425,8 +414,8 @@ __device__ __forceinline__ void rowwise_scaling_per_token( ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); } } - const float block_amax = static_cast( - __hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + const float block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); const fp8e4m3 s_dec = compute_decoding_scaling_factor(block_amax, S_enc); const float s_dec_f = static_cast(s_dec); @@ -459,22 +448,19 @@ __device__ __forceinline__ void rowwise_scaling_per_token( // mask so the per-col outer amax matches. template __device__ __forceinline__ void colwise_scaling_per_token( - const IType* __restrict__ sIn_ptr, - fp4e2m1x2* __restrict__ sOut_tr_ptr, - nvfp4_scale_t* __restrict__ sSFcolwise_ptr, - const float* __restrict__ sColAmax, - const int stage_Y, const int stage_X, - const int buff_in, const int buff_out_tr, + const IType* __restrict__ sIn_ptr, fp4e2m1x2* __restrict__ sOut_tr_ptr, + nvfp4_scale_t* __restrict__ sSFcolwise_ptr, const float* __restrict__ sColAmax, + const int stage_Y, const int stage_X, const int buff_in, const int buff_out_tr, const uint32_t random_sign_mask_t = 0u) { const auto& sIn2x = *reinterpret_cast(sIn_ptr); auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); auto& sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); - const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 + const int warp = threadIdx.x / THREADS_PER_WARP; // 0..3 const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 - const int tid_X_colwise = thread_lane; // 0..31 + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; // 0..3 + const int tid_X_colwise = thread_lane; // 0..31 const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; const int thread_offset_X_colwise = tid_X_colwise * 2; @@ -575,8 +561,7 @@ template __global__ void __launch_bounds__(THREADS_NUM) group_per_token_fused_cast_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ NVFP4PerTokenMultiArgs args, - const float* noop, const size_t rows, - const size_t cols, + const float* noop, const size_t rows, const size_t cols, const uint32_t random_sign_mask_t) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { @@ -599,25 +584,26 @@ __global__ void __launch_bounds__(THREADS_NUM) constexpr int out_mem_colwise_data = DO_COL ? buff_size_aligned_out_t : 0; constexpr int out_mem_rowwise_scales = DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) : 0; + TMA_SHMEM_ALIGNMENT) + : 0; constexpr int out_mem_colwise_scales = DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) : 0; + TMA_SHMEM_ALIGNMENT) + : 0; (void)out_mem_colwise_scales; extern __shared__ unsigned char dynamic_shmem[]; - unsigned char* dshmem = - dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + unsigned char* dshmem = dispatch::common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); - IType* sIn_ptr = reinterpret_cast(dshmem); - fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); - fp4e2m1x2* sOut_tr_ptr = reinterpret_cast( - dshmem + buff_size_aligned_in + out_mem_rowwise_data); + IType* sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2* sOut_ptr = reinterpret_cast(dshmem + buff_size_aligned_in); + fp4e2m1x2* sOut_tr_ptr = + reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data); nvfp4_scale_t* sSFrowwise_ptr = reinterpret_cast( dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data); - nvfp4_scale_t* sSFcolwise_ptr = reinterpret_cast( - dshmem + buff_size_aligned_in + out_mem_rowwise_data + out_mem_colwise_data - + out_mem_rowwise_scales); + nvfp4_scale_t* sSFcolwise_ptr = + reinterpret_cast(dshmem + buff_size_aligned_in + out_mem_rowwise_data + + out_mem_colwise_data + out_mem_rowwise_scales); __shared__ float sRowAmax[CHUNK_DIM_Y]; __shared__ float sColAmax[CHUNK_DIM_X]; @@ -635,22 +621,21 @@ __global__ void __launch_bounds__(THREADS_NUM) // Chunk Y stays inside one tensor (split_sections[i] % 128 == 0). const int tensor_id = GetTensorId(args, block_offset_Y); const int local_row_base = block_offset_Y - args.split_sections_range[tensor_id]; - const int M_t = args.split_sections_range[tensor_id + 1] - - args.split_sections_range[tensor_id]; + const int M_t = args.split_sections_range[tensor_id + 1] - args.split_sections_range[tensor_id]; // Per-tensor output bases (one constant-cache lookup per CTA). - uint8_t* const q_row_base = DO_ROW - ? reinterpret_cast(args.q_row_list[tensor_id]) : nullptr; - uint8_t* const q_col_base = DO_COL - ? reinterpret_cast(args.q_col_list[tensor_id]) : nullptr; - nvfp4_scale_t* const s_dec_row_base = DO_ROW - ? reinterpret_cast(args.s_dec_row_list[tensor_id]) : nullptr; - nvfp4_scale_t* const s_dec_col_base = DO_COL - ? reinterpret_cast(args.s_dec_col_list[tensor_id]) : nullptr; - const float* const row_amax_base = DO_ROW - ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; - const float* const col_amax_base = DO_COL - ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; + uint8_t* const q_row_base = + DO_ROW ? reinterpret_cast(args.q_row_list[tensor_id]) : nullptr; + uint8_t* const q_col_base = + DO_COL ? reinterpret_cast(args.q_col_list[tensor_id]) : nullptr; + nvfp4_scale_t* const s_dec_row_base = + DO_ROW ? reinterpret_cast(args.s_dec_row_list[tensor_id]) : nullptr; + nvfp4_scale_t* const s_dec_col_base = + DO_COL ? reinterpret_cast(args.s_dec_col_list[tensor_id]) : nullptr; + const float* const row_amax_base = + DO_ROW ? reinterpret_cast(args.row_amax_list[tensor_id]) : nullptr; + const float* const col_amax_base = + DO_COL ? reinterpret_cast(args.col_amax_list[tensor_id]) : nullptr; const size_t data_stride_row = static_cast(cols) / 2; const size_t data_stride_col = static_cast(M_t) / 2; @@ -686,13 +671,13 @@ __global__ void __launch_bounds__(THREADS_NUM) ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[buff_in_p], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[buff_in_p]), - reinterpret_cast(&tensor_map_input), global_offset_X, - global_offset_Y, &IN_buff_readable_mbar[buff_in_p]); + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + &IN_buff_readable_mbar[buff_in_p]); } } - int buff_in = 0; - int buff_out = 0; + int buff_in = 0; + int buff_out = 0; int buff_out_tr = 0; int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; @@ -709,31 +694,28 @@ __global__ void __launch_bounds__(THREADS_NUM) const int next_global_offset_Y = block_offset_Y + next_stage_Y * TILE_DIM_Y; const int next_global_offset_X = block_offset_X + next_stage_X * TILE_DIM_X; if (leading_thread) { - ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], - shmem_buff_size); + ptx::mbarrier_arrive_expect_tx(&IN_buff_readable_mbar[next_prefetch_buff], shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&sIn[next_prefetch_buff]), - reinterpret_cast(&tensor_map_input), - next_global_offset_X, next_global_offset_Y, - &IN_buff_readable_mbar[next_prefetch_buff]); + reinterpret_cast(&tensor_map_input), next_global_offset_X, + next_global_offset_Y, &IN_buff_readable_mbar[next_prefetch_buff]); } ptx::fence_proxy_async_shared_cta(); } // Wait for current stage's input tile to land. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta( - &IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; // 4x32 cooperative row + col encode helpers. if (DO_ROW) { - rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, - sRowAmax, stage_Y, stage_X, buff_in, buff_out); + rowwise_scaling_per_token(sIn_ptr, sOut_ptr, sSFrowwise_ptr, sRowAmax, stage_Y, stage_X, + buff_in, buff_out); } if (DO_COL) { - colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, - sColAmax, stage_Y, stage_X, buff_in, buff_out_tr, - random_sign_mask_t); + colwise_scaling_per_token(sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, sColAmax, stage_Y, + stage_X, buff_in, buff_out_tr, random_sign_mask_t); } // Make helper SMEM writes visible before the scatter epilogue. @@ -744,13 +726,10 @@ __global__ void __launch_bounds__(THREADS_NUM) if (DO_ROW) { auto& sOut = *reinterpret_cast(sOut_ptr); const int row_in_subtile = static_cast(threadIdx.x) >> 1; // 0..63 - const int half = static_cast(threadIdx.x) & 1; // 0..1 - const int local_row = local_row_base + stage_Y * TILE_DIM_Y + row_in_subtile; - const int byte_off_X = (block_offset_X / 2) - + stage_X * (TILE_DIM_X / 2) - + half * 16; - const uint4* src = reinterpret_cast( - &sOut[buff_out][row_in_subtile][half * 16]); + const int half = static_cast(threadIdx.x) & 1; // 0..1 + const int local_row = local_row_base + stage_Y * TILE_DIM_Y + row_in_subtile; + const int byte_off_X = (block_offset_X / 2) + stage_X * (TILE_DIM_X / 2) + half * 16; + const uint4* src = reinterpret_cast(&sOut[buff_out][row_in_subtile][half * 16]); uint4* dst = reinterpret_cast( q_row_base + static_cast(local_row) * data_stride_row + byte_off_X); *dst = *src; @@ -758,13 +737,11 @@ __global__ void __launch_bounds__(THREADS_NUM) if (DO_COL) { auto& sOut_tr = *reinterpret_cast(sOut_tr_ptr); const int col_in_subtile = static_cast(threadIdx.x) >> 1; // 0..63 - const int half = static_cast(threadIdx.x) & 1; // 0..1 - const int global_col = block_offset_X + stage_X * TILE_DIM_X + col_in_subtile; - const int byte_off_M = (local_row_base / 2) - + stage_Y * (TILE_DIM_Y / 2) - + half * 16; - const uint4* src = reinterpret_cast( - &sOut_tr[buff_out_tr][col_in_subtile][half * 16]); + const int half = static_cast(threadIdx.x) & 1; // 0..1 + const int global_col = block_offset_X + stage_X * TILE_DIM_X + col_in_subtile; + const int byte_off_M = (local_row_base / 2) + stage_Y * (TILE_DIM_Y / 2) + half * 16; + const uint4* src = + reinterpret_cast(&sOut_tr[buff_out_tr][col_in_subtile][half * 16]); uint4* dst = reinterpret_cast( q_col_base + static_cast(global_col) * data_stride_col + byte_off_M); *dst = *src; @@ -773,8 +750,8 @@ __global__ void __launch_bounds__(THREADS_NUM) // Sync so the scatter completes before next stage overwrites the buffer. __syncthreads(); - buff_in = (buff_in + 1) % BUFFS_NUM; - buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_in = (buff_in + 1) % BUFFS_NUM; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; } @@ -782,13 +759,11 @@ __global__ void __launch_bounds__(THREADS_NUM) if (DO_ROW) { auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); using ScalesVec = Vec; - const size_t scales_block_offset_X_rowwise = - static_cast(ctaid_X) * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_X_rowwise = static_cast(ctaid_X) * SCALES_PER_CHUNK_X; for (int row = static_cast(threadIdx.x); row < CHUNK_DIM_Y; row += THREADS_NUM) { ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); const size_t local_row = static_cast(local_row_base) + row; - const size_t scale_idx_global = - local_row * scale_stride_row + scales_block_offset_X_rowwise; + const size_t scale_idx_global = local_row * scale_stride_row + scales_block_offset_X_rowwise; scales_vec.store_to_elts(&s_dec_row_base[scale_idx_global], 0, SCALES_PER_CHUNK_X); } } @@ -797,12 +772,10 @@ __global__ void __launch_bounds__(THREADS_NUM) using ScalesVec = Vec; // M-block offset within s_dec_col[global_col] (shape (K, M_i/16) row-major). const size_t local_block_offset_M = static_cast(local_row_base) / SCALE_DIM; - for (int row_tr = static_cast(threadIdx.x); row_tr < CHUNK_DIM_X; - row_tr += THREADS_NUM) { + for (int row_tr = static_cast(threadIdx.x); row_tr < CHUNK_DIM_X; row_tr += THREADS_NUM) { ScalesVec& scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); const size_t global_col = static_cast(block_offset_X) + row_tr; - const size_t scale_idx_global = - global_col * scale_stride_col + local_block_offset_M; + const size_t scale_idx_global = global_col * scale_stride_col + local_block_offset_M; scales_vec.store_to_elts(&s_dec_col_base[scale_idx_global], 0, SCALES_PER_CHUNK_Y); } } @@ -829,11 +802,10 @@ __global__ void __launch_bounds__(THREADS_NUM) // used the same flag + mask, else inner SF + FP4 saturate against mismatched // data. inline void launch_grouped_fused_cast_bf16(const NVFP4PerTokenMultiArgs& args, - const SimpleTensor& input_data, int sum_M, - int K, bool do_row, bool do_col, - bool with_rht, - uint32_t random_sign_mask_t, - const float* noop, cudaStream_t stream) { + const SimpleTensor& input_data, int sum_M, int K, + bool do_row, bool do_col, bool with_rht, + uint32_t random_sign_mask_t, const float* noop, + cudaStream_t stream) { if (!do_row && !do_col) return; checkCuDriverContext(stream); @@ -842,40 +814,38 @@ inline void launch_grouped_fused_cast_bf16(const NVFP4PerTokenMultiArgs& args, create_2D_tensor_map(tmap_in, input_data, sum_M, K, TILE_DIM_Y, TILE_DIM_X, K, 0, sizeof(FusedIType) * 8); - dim3 grid(static_cast(K / CHUNK_DIM_X), - static_cast(sum_M / CHUNK_DIM_Y), 1); + dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(sum_M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); // Collapse to kWithRht=false when no colwise output is requested. const bool with_rht_effective = with_rht && do_col; - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, - TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { - constexpr int sz_in = DIVUP_TO_MULTIPLE( - BUFFS_NUM * BUFF_IN_SIZE * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); - constexpr int sz_out_r = DO_ROW - ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT) : 0; - constexpr int sz_out_c = DO_COL - ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT) - : 0; - constexpr int sz_sf_r = DO_ROW - ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) - : 0; - constexpr int sz_sf_c = DO_COL - ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), - TMA_SHMEM_ALIGNMENT) - : 0; - constexpr int dshmem_size = sz_in + sz_out_r + sz_out_c + sz_sf_r + sz_sf_c - + TMA_SHMEM_ALIGNMENT; - auto kernel = group_per_token_fused_cast_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>(tmap_in, args, noop, - static_cast(sum_M), - static_cast(K), - random_sign_mask_t); - }));); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_col, DO_COL, TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { + constexpr int sz_in = DIVUP_TO_MULTIPLE(BUFFS_NUM * BUFF_IN_SIZE * sizeof(FusedIType), + TMA_SHMEM_ALIGNMENT); + constexpr int sz_out_r = + DO_ROW ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT) : 0; + constexpr int sz_out_c = + DO_COL ? DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int sz_sf_r = + DO_ROW ? DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int sz_sf_c = + DO_COL ? DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), + TMA_SHMEM_ALIGNMENT) + : 0; + constexpr int dshmem_size = + sz_in + sz_out_r + sz_out_c + sz_sf_r + sz_sf_c + TMA_SHMEM_ALIGNMENT; + auto kernel = group_per_token_fused_cast_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tmap_in, args, noop, static_cast(sum_M), static_cast(K), + random_sign_mask_t); + }));); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -883,11 +853,10 @@ inline void launch_grouped_fused_cast_bf16(const NVFP4PerTokenMultiArgs& args, // with_rht=true applies a 16-pt RHT on the col amax (rowwise raw). The // downstream K2 cast MUST use the same flag + mask. inline void launch_grouped_fused_amax_bf16(const NVFP4PerTokenMultiArgs& args, - const SimpleTensor& input_data, int sum_M, - int K, bool do_row, bool do_col, - bool with_rht, - uint32_t random_sign_mask_t, - const float* noop, cudaStream_t stream) { + const SimpleTensor& input_data, int sum_M, int K, + bool do_row, bool do_col, bool with_rht, + uint32_t random_sign_mask_t, const float* noop, + cudaStream_t stream) { if (!do_row && !do_col) return; // Pre-zero amax slots (atomicMax identity). @@ -918,23 +887,21 @@ inline void launch_grouped_fused_amax_bf16(const NVFP4PerTokenMultiArgs& args, DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(FusedIType), TMA_SHMEM_ALIGNMENT); constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; - dim3 grid(static_cast(K / CHUNK_DIM_X), - static_cast(sum_M / CHUNK_DIM_Y), 1); + dim3 grid(static_cast(K / CHUNK_DIM_X), static_cast(sum_M / CHUNK_DIM_Y), 1); dim3 block(THREADS_NUM, 1, 1); // Collapse to kWithRht=false when no colwise amax is requested. const bool with_rht_effective = with_rht && do_col; - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_row, DO_ROW, - TRANSFORMER_ENGINE_SWITCH_CONDITION(do_col, DO_COL, - TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { - auto kernel = group_per_token_fused_amax_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>(tmap_in, args, noop, - static_cast(sum_M), - static_cast(K), - random_sign_mask_t); - }));); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_row, DO_ROW, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + do_col, DO_COL, TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { + auto kernel = group_per_token_fused_amax_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tmap_in, args, noop, static_cast(sum_M), static_cast(K), + random_sign_mask_t); + }));); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -957,19 +924,18 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector& outputs, args->split_sections_range[0] = 0; for (size_t i = 0; i < num_tensors; ++i) { Tensor* o = outputs[i]; - NVTE_CHECK(split_sections[i] % 128 == 0, "split_sections[", i, - "] = ", split_sections[i], " must be a multiple of 128"); + NVTE_CHECK(split_sections[i] % 128 == 0, "split_sections[", i, "] = ", split_sections[i], + " must be a multiple of 128"); args->split_sections_range[i + 1] = args->split_sections_range[i] + static_cast(split_sections[i]); if (split_sections[i] == 0) continue; if (which_buffers & kBufRowAmax) { - NVTE_CHECK(o->amax.dptr != nullptr, - "NVFP4 per-token grouped: outputs[", i, "].amax must be allocated for rowwise"); + NVTE_CHECK(o->amax.dptr != nullptr, "NVFP4 per-token grouped: outputs[", i, + "].amax must be allocated for rowwise"); args->row_amax_list[i] = o->amax.dptr; } if (which_buffers & kBufColAmax) { - NVTE_CHECK(o->columnwise_amax.dptr != nullptr, - "NVFP4 per-token grouped: outputs[", i, + NVTE_CHECK(o->columnwise_amax.dptr != nullptr, "NVFP4 per-token grouped: outputs[", i, "].columnwise_amax must be allocated for columnwise"); args->col_amax_list[i] = o->columnwise_amax.dptr; } @@ -981,10 +947,9 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector& outputs, args->s_dec_row_list[i] = o->scale_inv.dptr; } if (which_buffers & kBufColCast) { - NVTE_CHECK( - o->columnwise_data.dptr != nullptr && o->columnwise_scale_inv.dptr != nullptr, - "NVFP4 per-token grouped: outputs[", i, - "].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast"); + NVTE_CHECK(o->columnwise_data.dptr != nullptr && o->columnwise_scale_inv.dptr != nullptr, + "NVFP4 per-token grouped: outputs[", i, + "].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast"); args->q_col_list[i] = o->columnwise_data.dptr; args->s_dec_col_list[i] = o->columnwise_scale_inv.dptr; } @@ -1000,9 +965,8 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector& outputs, // same flag/mask if they invoke amax + cast separately. void quantize_per_token_grouped(const Tensor& input, std::vector& outputs, const size_t* split_sections, size_t num_tensors, bool rowwise, - bool columnwise, bool do_amax, bool do_cast, - bool with_rht, uint32_t random_sign_mask_t, - cudaStream_t stream) { + bool columnwise, bool do_amax, bool do_cast, bool with_rht, + uint32_t random_sign_mask_t, cudaStream_t stream) { NVTE_CHECK(num_tensors > 0, "NVFP4 per-token grouped: num_tensors must be > 0"); NVTE_CHECK(num_tensors <= static_cast(kMaxTensorsPerKernel), "NVFP4 per-token grouped: num_tensors (", num_tensors, @@ -1017,8 +981,7 @@ void quantize_per_token_grouped(const Tensor& input, std::vector& outpu const int sum_M = static_cast(input.flat_first_dim()); const int K = static_cast(input.flat_last_dim()); if (sum_M == 0 || K == 0) return; - NVTE_CHECK(K % 128 == 0, - "NVFP4 per-token grouped: K (", K, ") must be a multiple of 128"); + NVTE_CHECK(K % 128 == 0, "NVFP4 per-token grouped: K (", K, ") must be a multiple of 128"); int which_buffers = 0; if ((do_amax || do_cast) && rowwise) which_buffers |= kBufRowAmax; @@ -1068,9 +1031,8 @@ std::vector collect_outputs(NVTETensor* outputs, si } // namespace void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, - int with_rht, int random_sign_mask_t, + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, int with_rht, int random_sign_mask_t, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_group_nvfp4_per_token_amax); @@ -1080,13 +1042,12 @@ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs std::vector outs = collect_outputs(outputs, num_tensors); // C-API mirrors nvte_nvfp4_per_token_amax: `int` for cross-language ABI // safety; internal kernel arg is uint32_t with only the low 16 bits used. - nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, - rowwise, columnwise, - /*do_amax=*/true, /*do_cast=*/false, - /*with_rht=*/with_rht != 0, - /*random_sign_mask_t=*/ - static_cast(random_sign_mask_t) & 0xFFFFu, - stream); + nvfp4_per_token_group::quantize_per_token_grouped( + *in, outs, split_sections, num_tensors, rowwise, columnwise, + /*do_amax=*/true, /*do_cast=*/false, + /*with_rht=*/with_rht != 0, + /*random_sign_mask_t=*/ + static_cast(random_sign_mask_t) & 0xFFFFu, stream); #else (void)input; (void)outputs; @@ -1102,9 +1063,8 @@ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs } void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, - int with_rht, int random_sign_mask_t, + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, int with_rht, int random_sign_mask_t, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_group_nvfp4_per_token_cast); @@ -1112,13 +1072,12 @@ void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs if (num_tensors == 0) return; const Tensor* in = convertNVTETensorCheck(input); std::vector outs = collect_outputs(outputs, num_tensors); - nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, - rowwise, columnwise, - /*do_amax=*/false, /*do_cast=*/true, - /*with_rht=*/with_rht != 0, - /*random_sign_mask_t=*/ - static_cast(random_sign_mask_t) & 0xFFFFu, - stream); + nvfp4_per_token_group::quantize_per_token_grouped( + *in, outs, split_sections, num_tensors, rowwise, columnwise, + /*do_amax=*/false, /*do_cast=*/true, + /*with_rht=*/with_rht != 0, + /*random_sign_mask_t=*/ + static_cast(random_sign_mask_t) & 0xFFFFu, stream); #else (void)input; (void)outputs; @@ -1135,22 +1094,20 @@ void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* outputs, const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, - int with_rht, int random_sign_mask_t, - cudaStream_t stream) { + bool rowwise, bool columnwise, int with_rht, + int random_sign_mask_t, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_group_nvfp4_per_token_quantize); using namespace transformer_engine; if (num_tensors == 0) return; const Tensor* in = convertNVTETensorCheck(input); std::vector outs = collect_outputs(outputs, num_tensors); - nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, - rowwise, columnwise, - /*do_amax=*/true, /*do_cast=*/true, - /*with_rht=*/with_rht != 0, - /*random_sign_mask_t=*/ - static_cast(random_sign_mask_t) & 0xFFFFu, - stream); + nvfp4_per_token_group::quantize_per_token_grouped( + *in, outs, split_sections, num_tensors, rowwise, columnwise, + /*do_amax=*/true, /*do_cast=*/true, + /*with_rht=*/with_rht != 0, + /*random_sign_mask_t=*/ + static_cast(random_sign_mask_t) & 0xFFFFu, stream); #else (void)input; (void)outputs; diff --git a/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h index 9533dddc4d..395b061b52 100644 --- a/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h +++ b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h @@ -15,7 +15,6 @@ extern "C" { #endif - /*! \brief Composite K1+K2: per-row + per-col amax (K1) then FP4 + 1x16 * e4m3 SF encode (K2), back-to-back on the same stream. * @@ -27,11 +26,8 @@ extern "C" { * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern shared by * K1 and K2. Ignored when with_rht == 0. */ -void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, - NVTETensor output, - int with_rht, - int random_sign_mask_t, - cudaStream_t stream); +void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output, + int with_rht, int random_sign_mask_t, cudaStream_t stream); /*! \brief Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. * Pre-zeroes the amax buffers and merges per-CTA partials into @@ -45,11 +41,8 @@ void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop * when with_rht == 0. Type matches prod's * nvte_hadamard_transform_amax convention. */ -void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, - NVTETensor output, - int with_rht, - int random_sign_mask_t, - cudaStream_t stream); +void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NVTETensor output, + int with_rht, int random_sign_mask_t, cudaStream_t stream); /*! \brief Kernel 2 in isolation: FP4 + 1x16 e4m3 SF encode given a * pre-filled ``output->amax`` / ``output->columnwise_amax``. Reads @@ -62,11 +55,8 @@ void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; ignored * when with_rht == 0. */ -void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, - NVTETensor output, - int with_rht, - int random_sign_mask_t, - cudaStream_t stream); +void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output, + int with_rht, int random_sign_mask_t, cudaStream_t stream); /*! \brief Returns 1 iff the per-token kernels accept ``(M, K, dtype)``. * @@ -86,8 +76,7 @@ int nvte_nvfp4_per_token_can_dispatch(size_t M, size_t K, int input_dtype_enum); * d[i, j] = d[i, j] * row_amax_a[i] * row_amax_b[j] */ void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a, - const NVTETensor row_amax_b, - cudaStream_t stream); + const NVTETensor row_amax_b, cudaStream_t stream); /* ============================================================================ * Grouped (multi-tensor) per-token quantize. @@ -108,9 +97,8 @@ void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a, * \param[in] stream CUDA stream */ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, - int with_rht, int random_sign_mask_t, + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, int with_rht, int random_sign_mask_t, cudaStream_t stream); /*! \brief Grouped per-token encode (FP4 + 1x16 e4m3 inner SF) using the @@ -133,9 +121,8 @@ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs * \param[in] stream CUDA stream */ void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs, - const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, - int with_rht, int random_sign_mask_t, + const size_t* split_sections, size_t num_tensors, bool rowwise, + bool columnwise, int with_rht, int random_sign_mask_t, cudaStream_t stream); /*! \brief Composite K1+K2 grouped per-token quantize. Calls the amax + cast @@ -161,9 +148,8 @@ void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs */ void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* outputs, const size_t* split_sections, size_t num_tensors, - bool rowwise, bool columnwise, - int with_rht, int random_sign_mask_t, - cudaStream_t stream); + bool rowwise, bool columnwise, int with_rht, + int random_sign_mask_t, cudaStream_t stream); #ifdef __cplusplus } diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index a1cee4461e..53bf118425 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -412,8 +412,7 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads void compute_amax(const at::Tensor &tensor, at::Tensor &amax); void hadamard_transform_amax(const at::Tensor &tensor, at::Tensor &rowwise_amax, - at::Tensor &columnwise_amax, - int64_t rht_matrix_random_sign_mask); + at::Tensor &columnwise_amax, int64_t rht_matrix_random_sign_mask); void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, @@ -452,38 +451,34 @@ void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwi const at::Tensor &scale_inv_colwise, int rows, int cols, size_t start_offset); -void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row, - at::Tensor s_dec_row, at::Tensor row_amax, - at::Tensor q_col, at::Tensor s_dec_col, - at::Tensor col_amax, bool rowwise, bool columnwise, - bool with_rht, int64_t random_sign_mask_t); +void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, + int64_t random_sign_mask_t); -void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, - at::Tensor col_amax, bool rowwise, bool columnwise, - bool with_rht, int64_t random_sign_mask_t); +void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, at::Tensor col_amax, + bool rowwise, bool columnwise, bool with_rht, int64_t random_sign_mask_t); -void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row, - at::Tensor s_dec_row, at::Tensor row_amax, - at::Tensor q_col, at::Tensor s_dec_col, - at::Tensor col_amax, bool rowwise, bool columnwise, - bool with_rht, int64_t random_sign_mask_t); +void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, + int64_t random_sign_mask_t); void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, const at::Tensor &row_amax_b); void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, const at::Tensor &a_sf, const at::Tensor &b_sf, - const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, - at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, - int64_t k, double alpha, double beta); + const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, at::Tensor d, + const at::Tensor &workspace, int64_t m, int64_t n, int64_t k, + double alpha, double beta); // Bench-only per-tensor twin of nvfp4_per_token_gemm: scalar amaxes folded // into cuBLAS LT alpha via the amax slot; no trailing post-scale. void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, - const at::Tensor &a_sf, const at::Tensor &b_sf, - const at::Tensor &a_amax, const at::Tensor &b_amax, - at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, - int64_t k, double alpha, double beta); + const at::Tensor &a_sf, const at::Tensor &b_sf, const at::Tensor &a_amax, + const at::Tensor &b_amax, at::Tensor d, const at::Tensor &workspace, + int64_t m, int64_t n, int64_t k, double alpha, double beta); // with_rht=true applies a 16-pt RHT on the col direction in BOTH K1 and K2; // random_sign_mask_t low 16 bits = sign pattern (ignored when with_rht=false). @@ -491,19 +486,16 @@ void nvfp4_per_token_group_quantize( const at::Tensor &input, const std::vector &split_sections, std::vector q_row_list, std::vector s_dec_row_list, std::vector row_amax_list, std::vector q_col_list, - std::vector s_dec_col_list, std::vector col_amax_list, - bool rowwise, bool columnwise, - bool with_rht, int64_t random_sign_mask_t); + std::vector s_dec_col_list, std::vector col_amax_list, bool rowwise, + bool columnwise, bool with_rht, int64_t random_sign_mask_t); // Amax-only variant of the grouped quantize. Useful for multi-rank training // where amax is allReduced before the cast pass. Caller must thread the // matching with_rht / mask into the subsequent cast launch. -void nvfp4_per_token_group_amax(const at::Tensor &input, - const std::vector &split_sections, +void nvfp4_per_token_group_amax(const at::Tensor &input, const std::vector &split_sections, std::vector row_amax_list, std::vector col_amax_list, bool rowwise, - bool columnwise, - bool with_rht, int64_t random_sign_mask_t); + bool columnwise, bool with_rht, int64_t random_sign_mask_t); // Bulk grouped quantize: allocate-view-dispatch all in one pybind hop. // Returns 6 per-split vectors (q_row, s_dec_row_fp8, row_amax, q_col, @@ -511,9 +503,8 @@ void nvfp4_per_token_group_amax(const at::Tensor &input, std::tuple, std::vector, std::vector, std::vector, std::vector, std::vector> nvfp4_per_token_group_quantize_bulk(const at::Tensor &input, - const std::vector &split_sections, - bool rowwise, bool columnwise, - bool with_rht, int64_t random_sign_mask_t); + const std::vector &split_sections, bool rowwise, + bool columnwise, bool with_rht, int64_t random_sign_mask_t); /*************************************************************************************************** * Rotary positional embedding diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp index 8498f2f249..4ef13afdae 100644 --- a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp @@ -25,11 +25,10 @@ namespace { // Validates the input and assembles ``out_te`` for all 3 modes; caller // dispatches to the right C-API entry on the caller's stream. -void assemble_per_token_tensors(const at::Tensor& input, - at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, - at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, - bool rowwise, bool columnwise, int mode, - TensorWrapper& in_te, TensorWrapper& out_te) { +void assemble_per_token_tensors(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, int mode, + TensorWrapper& in_te, TensorWrapper& out_te) { TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); @@ -53,8 +52,8 @@ void assemble_per_token_tensors(const at::Tensor& input, TORCH_CHECK(row_amax.is_cuda() && row_amax.is_contiguous(), "row_amax must be a contiguous CUDA tensor"); TORCH_CHECK(row_amax.scalar_type() == at::ScalarType::Float, "row_amax must be float32"); - TORCH_CHECK(row_amax.numel() == M, - "row_amax numel mismatch: expected M=", M, ", got ", row_amax.numel()); + TORCH_CHECK(row_amax.numel() == M, "row_amax numel mismatch: expected M=", M, ", got ", + row_amax.numel()); out_te.set_amax(row_amax.data_ptr(), DType::kFloat32, std::vector{static_cast(M)}); @@ -63,15 +62,14 @@ void assemble_per_token_tensors(const at::Tensor& input, "q_row must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_row.is_cuda() && s_dec_row.is_contiguous(), "s_dec_row must be a contiguous CUDA tensor"); - TORCH_CHECK(q_row.scalar_type() == at::ScalarType::Byte, - "q_row must be uint8 (FP4 packed)"); + TORCH_CHECK(q_row.scalar_type() == at::ScalarType::Byte, "q_row must be uint8 (FP4 packed)"); TORCH_CHECK(s_dec_row.scalar_type() == at::ScalarType::Byte, "s_dec_row must be uint8 (FP8 e4m3 raw bytes)"); - TORCH_CHECK(q_row.numel() == M * K / 2, - "q_row numel mismatch: expected M*K/2=", M * K / 2, ", got ", q_row.numel()); + TORCH_CHECK(q_row.numel() == M * K / 2, "q_row numel mismatch: expected M*K/2=", M * K / 2, + ", got ", q_row.numel()); TORCH_CHECK(s_dec_row.numel() == M * K / 16, - "s_dec_row numel mismatch: expected M*K/16=", M * K / 16, - ", got ", s_dec_row.numel()); + "s_dec_row numel mismatch: expected M*K/16=", M * K / 16, ", got ", + s_dec_row.numel()); out_te.set_rowwise_data(q_row.data_ptr(), DType::kFloat4E2M1, in_shape); out_te.set_rowwise_scale_inv( s_dec_row.data_ptr(), DType::kFloat8E4M3, @@ -82,8 +80,8 @@ void assemble_per_token_tensors(const at::Tensor& input, TORCH_CHECK(col_amax.is_cuda() && col_amax.is_contiguous(), "col_amax must be a contiguous CUDA tensor"); TORCH_CHECK(col_amax.scalar_type() == at::ScalarType::Float, "col_amax must be float32"); - TORCH_CHECK(col_amax.numel() == K, - "col_amax numel mismatch: expected K=", K, ", got ", col_amax.numel()); + TORCH_CHECK(col_amax.numel() == K, "col_amax numel mismatch: expected K=", K, ", got ", + col_amax.numel()); out_te.set_columnwise_amax(col_amax.data_ptr(), DType::kFloat32, std::vector{static_cast(K)}); @@ -92,15 +90,14 @@ void assemble_per_token_tensors(const at::Tensor& input, "q_col must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_col.is_cuda() && s_dec_col.is_contiguous(), "s_dec_col must be a contiguous CUDA tensor"); - TORCH_CHECK(q_col.scalar_type() == at::ScalarType::Byte, - "q_col must be uint8 (FP4 packed)"); + TORCH_CHECK(q_col.scalar_type() == at::ScalarType::Byte, "q_col must be uint8 (FP4 packed)"); TORCH_CHECK(s_dec_col.scalar_type() == at::ScalarType::Byte, "s_dec_col must be uint8 (FP8 e4m3 raw bytes)"); - TORCH_CHECK(q_col.numel() == K * M / 2, - "q_col numel mismatch: expected K*M/2=", K * M / 2, ", got ", q_col.numel()); + TORCH_CHECK(q_col.numel() == K * M / 2, "q_col numel mismatch: expected K*M/2=", K * M / 2, + ", got ", q_col.numel()); TORCH_CHECK(s_dec_col.numel() == K * M / 16, - "s_dec_col numel mismatch: expected K*M/16=", K * M / 16, - ", got ", s_dec_col.numel()); + "s_dec_col numel mismatch: expected K*M/16=", K * M / 16, ", got ", + s_dec_col.numel()); out_te.set_columnwise_data( q_col.data_ptr(), DType::kFloat4E2M1, std::vector{static_cast(K), static_cast(M)}); @@ -115,69 +112,55 @@ void assemble_per_token_tensors(const at::Tensor& input, // Production composite (K1 + K2 back-to-back). with_rht=true enables the // 16-pt col-wise RHT in BOTH K1 and K2 so outer + inner SFs stay consistent. -void nvfp4_per_token_quantize(const at::Tensor& input, - at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, - at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, - bool rowwise, bool columnwise, - bool with_rht, int64_t random_sign_mask_t) { +void nvfp4_per_token_quantize(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, + int64_t random_sign_mask_t) { TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); - assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, - q_col, s_dec_col, col_amax, - rowwise, columnwise, /*mode=*/0, in_te, out_te); + assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, + columnwise, /*mode=*/0, in_te, out_te); const auto stream = at::cuda::getCurrentCUDAStream(); - nvte_nvfp4_per_token_quantize( - in_te.data(), nullptr, out_te.data(), - with_rht ? 1 : 0, - static_cast(random_sign_mask_t & 0xFFFF), - stream); + nvte_nvfp4_per_token_quantize(in_te.data(), nullptr, out_te.data(), with_rht ? 1 : 0, + static_cast(random_sign_mask_t & 0xFFFF), stream); } // K1-only (diagnostic / bench): populates only amax buffers. with_rht=true // applies the 16-pt col-wise RHT before amax (rowwise unaffected); // random_sign_mask_t low 16 bits = sign-flip pattern. -void nvfp4_per_token_amax(const at::Tensor& input, at::Tensor row_amax, - at::Tensor col_amax, bool rowwise, bool columnwise, - bool with_rht, int64_t random_sign_mask_t) { +void nvfp4_per_token_amax(const at::Tensor& input, at::Tensor row_amax, at::Tensor col_amax, + bool rowwise, bool columnwise, bool with_rht, + int64_t random_sign_mask_t) { at::Tensor empty_u8; // not consumed by K1 TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); - assemble_per_token_tensors(input, empty_u8, empty_u8, row_amax, - empty_u8, empty_u8, col_amax, - rowwise, columnwise, /*mode=*/1, in_te, out_te); + assemble_per_token_tensors(input, empty_u8, empty_u8, row_amax, empty_u8, empty_u8, col_amax, + rowwise, columnwise, /*mode=*/1, in_te, out_te); const auto stream = at::cuda::getCurrentCUDAStream(); // C-API matches prod's `int` convention; only low 16 bits are consumed. - nvte_nvfp4_per_token_amax( - in_te.data(), nullptr, out_te.data(), - with_rht ? 1 : 0, - static_cast(random_sign_mask_t & 0xFFFF), - stream); + nvte_nvfp4_per_token_amax(in_te.data(), nullptr, out_te.data(), with_rht ? 1 : 0, + static_cast(random_sign_mask_t & 0xFFFF), stream); } // K2-only (diagnostic / bench): reads pre-filled amax buffers, emits FP4 + SFs. // with_rht=true requires col_amax to have been produced by an earlier K1 // amax call with the SAME mask, else inner SFs are miscalibrated. -void nvfp4_per_token_encode(const at::Tensor& input, - at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, - at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, - bool rowwise, bool columnwise, - bool with_rht, int64_t random_sign_mask_t) { +void nvfp4_per_token_encode(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, + at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, + at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, + int64_t random_sign_mask_t) { TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); - assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, - q_col, s_dec_col, col_amax, - rowwise, columnwise, /*mode=*/2, in_te, out_te); + assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, + columnwise, /*mode=*/2, in_te, out_te); const auto stream = at::cuda::getCurrentCUDAStream(); - nvte_nvfp4_per_token_encode( - in_te.data(), nullptr, out_te.data(), - with_rht ? 1 : 0, - static_cast(random_sign_mask_t & 0xFFFF), - stream); + nvte_nvfp4_per_token_encode(in_te.data(), nullptr, out_te.data(), with_rht ? 1 : 0, + static_cast(random_sign_mask_t & 0xFFFF), stream); } // Apply per-token post-scale to a GEMM output (see nvfp4_per_token.h for math). -void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, - const at::Tensor &row_amax_b) { +void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor& row_amax_a, + const at::Tensor& row_amax_b) { TORCH_CHECK(d.is_cuda() && d.is_contiguous(), "d must be a contiguous CUDA tensor"); TORCH_CHECK(row_amax_a.is_cuda() && row_amax_a.is_contiguous(), "row_amax_a must be a contiguous CUDA tensor"); @@ -190,16 +173,16 @@ void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, const int64_t M = d.size(0); const int64_t N = d.size(1); - TORCH_CHECK(row_amax_a.numel() == M, - "row_amax_a numel mismatch: expected M=", M, ", got ", row_amax_a.numel()); - TORCH_CHECK(row_amax_b.numel() == N, - "row_amax_b numel mismatch: expected N=", N, ", got ", row_amax_b.numel()); + TORCH_CHECK(row_amax_a.numel() == M, "row_amax_a numel mismatch: expected M=", M, ", got ", + row_amax_a.numel()); + TORCH_CHECK(row_amax_b.numel() == N, "row_amax_b numel mismatch: expected N=", N, ", got ", + row_amax_b.numel()); const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper d_te = makeTransformerEngineTensor( - d.data_ptr(), - std::vector{static_cast(M), static_cast(N)}, DType::kBFloat16); + d.data_ptr(), std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16); TensorWrapper ra_te = makeTransformerEngineTensor( row_amax_a.data_ptr(), std::vector{static_cast(M)}, DType::kFloat32); TensorWrapper rb_te = makeTransformerEngineTensor( @@ -211,11 +194,11 @@ void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, // End-to-end NVFP4 per-token GEMM: swizzle compact SFs -> cuBLAS LT NVFP4 // GEMM (operand amax pinned to 1.0 to cancel the 2688^2 inner-SF factor) -> // per-row post-scale. beta must be 0.0. Math in nvfp4_per_token.h. -void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, - const at::Tensor &a_sf, const at::Tensor &b_sf, - const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, - at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, - int64_t k, double alpha, double beta) { +void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, + const at::Tensor& a_sf, const at::Tensor& b_sf, + const at::Tensor& a_row_amax, const at::Tensor& b_row_amax, at::Tensor d, + const at::Tensor& workspace, int64_t m, int64_t n, int64_t k, + double alpha, double beta) { TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && a_row_amax.is_cuda() && b_row_amax.is_cuda() && d.is_cuda() && workspace.is_cuda(), @@ -234,14 +217,13 @@ void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); TORCH_CHECK(workspace.scalar_type() == at::ScalarType::Byte, "workspace must be uint8"); - TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, - "a_data/b_data/d must be 2D"); + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, "a_data/b_data/d must be 2D"); TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, - "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", - a_data.size(0), ", ", a_data.size(1), ")"); + "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", a_data.size(0), + ", ", a_data.size(1), ")"); TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, - "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", - b_data.size(0), ", ", b_data.size(1), ")"); + "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", b_data.size(0), + ", ", b_data.size(1), ")"); TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, "), got (", d.size(0), ", ", d.size(1), ")"); @@ -250,10 +232,10 @@ void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, "a_sf numel mismatch: expected M*K/16=", m * k / 16, ", got ", a_sf.numel()); TORCH_CHECK(b_sf.numel() == static_cast(n * k / 16), "b_sf numel mismatch: expected N*K/16=", n * k / 16, ", got ", b_sf.numel()); - TORCH_CHECK(a_row_amax.numel() == m, - "a_row_amax numel mismatch: expected M=", m, ", got ", a_row_amax.numel()); - TORCH_CHECK(b_row_amax.numel() == n, - "b_row_amax numel mismatch: expected N=", n, ", got ", b_row_amax.numel()); + TORCH_CHECK(a_row_amax.numel() == m, "a_row_amax numel mismatch: expected M=", m, ", got ", + a_row_amax.numel()); + TORCH_CHECK(b_row_amax.numel() == n, "b_row_amax numel mismatch: expected N=", n, ", got ", + b_row_amax.numel()); TORCH_CHECK(static_cast(beta) == 0.0f, "nvfp4_per_token_gemm: beta != 0 not yet supported. Got beta=", beta); @@ -323,8 +305,7 @@ void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, b_te.set_with_gemm_swizzled_scales(true); TensorWrapper d_te = makeTransformerEngineTensor( - d.data_ptr(), - std::vector{static_cast(m), static_cast(n)}, + d.data_ptr(), std::vector{static_cast(m), static_cast(n)}, DType::kBFloat16); TensorWrapper workspace_te = makeTransformerEngineTensor( @@ -353,11 +334,10 @@ void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, // Per-tensor twin of nvfp4_per_token_gemm: scalar amax goes through cuBLAS's // own amax slot (no post-scale). Bench-only apples-to-apples baseline. -void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, - const at::Tensor &a_sf, const at::Tensor &b_sf, - const at::Tensor &a_amax, const at::Tensor &b_amax, - at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, - int64_t k, double alpha, double beta) { +void nvfp4_per_tensor_gemm(const at::Tensor& a_data, const at::Tensor& b_data, + const at::Tensor& a_sf, const at::Tensor& b_sf, const at::Tensor& a_amax, + const at::Tensor& b_amax, at::Tensor d, const at::Tensor& workspace, + int64_t m, int64_t n, int64_t k, double alpha, double beta) { TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && a_amax.is_cuda() && b_amax.is_cuda() && d.is_cuda() && workspace.is_cuda(), "All tensors must be CUDA tensors"); @@ -374,14 +354,13 @@ void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, TORCH_CHECK(d.scalar_type() == at::ScalarType::BFloat16, "d must be bfloat16"); TORCH_CHECK(workspace.scalar_type() == at::ScalarType::Byte, "workspace must be uint8"); - TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, - "a_data/b_data/d must be 2D"); + TORCH_CHECK(a_data.dim() == 2 && b_data.dim() == 2 && d.dim() == 2, "a_data/b_data/d must be 2D"); TORCH_CHECK(a_data.size(0) == m && a_data.size(1) * 2 == k, - "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", - a_data.size(0), ", ", a_data.size(1), ")"); + "a_data shape mismatch: expected (M=", m, ", K/2=", k / 2, "), got (", a_data.size(0), + ", ", a_data.size(1), ")"); TORCH_CHECK(b_data.size(0) == n && b_data.size(1) * 2 == k, - "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", - b_data.size(0), ", ", b_data.size(1), ")"); + "b_data shape mismatch: expected (N=", n, ", K/2=", k / 2, "), got (", b_data.size(0), + ", ", b_data.size(1), ")"); TORCH_CHECK(d.size(0) == m && d.size(1) == n, "d shape mismatch: expected (M=", m, ", N=", n, "), got (", d.size(0), ", ", d.size(1), ")"); @@ -446,8 +425,7 @@ void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, b_te.set_with_gemm_swizzled_scales(true); TensorWrapper d_te = makeTransformerEngineTensor( - d.data_ptr(), - std::vector{static_cast(m), static_cast(n)}, + d.data_ptr(), std::vector{static_cast(m), static_cast(n)}, DType::kBFloat16); TensorWrapper workspace_te = makeTransformerEngineTensor( @@ -470,13 +448,13 @@ void nvfp4_per_tensor_gemm(const at::Tensor &a_data, const at::Tensor &b_data, // Disabled direction's lists are ignored. namespace { -void build_per_token_output_wrapper( - TensorWrapper& out_te, int64_t M_i, int64_t K, bool rowwise, bool columnwise, - const at::Tensor& q_row, const at::Tensor& s_dec_row, const at::Tensor& row_amax, - const at::Tensor& q_col, const at::Tensor& s_dec_col, const at::Tensor& col_amax) { +void build_per_token_output_wrapper(TensorWrapper& out_te, int64_t M_i, int64_t K, bool rowwise, + bool columnwise, const at::Tensor& q_row, + const at::Tensor& s_dec_row, const at::Tensor& row_amax, + const at::Tensor& q_col, const at::Tensor& s_dec_col, + const at::Tensor& col_amax) { if (rowwise) { - TORCH_CHECK(q_row.is_cuda() && q_row.is_contiguous(), - "q_row must be a contiguous CUDA tensor"); + TORCH_CHECK(q_row.is_cuda() && q_row.is_contiguous(), "q_row must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_row.is_cuda() && s_dec_row.is_contiguous(), "s_dec_row must be a contiguous CUDA tensor"); TORCH_CHECK(row_amax.is_cuda() && row_amax.is_contiguous(), @@ -488,9 +466,8 @@ void build_per_token_output_wrapper( M_i * K / 2, ", got ", q_row.numel()); TORCH_CHECK(s_dec_row.numel() == M_i * K / 16, "s_dec_row numel mismatch for split"); TORCH_CHECK(row_amax.numel() == M_i, "row_amax numel mismatch for split"); - out_te.set_rowwise_data( - q_row.data_ptr(), DType::kFloat4E2M1, - std::vector{static_cast(M_i), static_cast(K)}); + out_te.set_rowwise_data(q_row.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(M_i), static_cast(K)}); out_te.set_rowwise_scale_inv( s_dec_row.data_ptr(), DType::kFloat8E4M3, std::vector{static_cast(M_i), static_cast(K / 16)}); @@ -498,8 +475,7 @@ void build_per_token_output_wrapper( std::vector{static_cast(M_i)}); } if (columnwise) { - TORCH_CHECK(q_col.is_cuda() && q_col.is_contiguous(), - "q_col must be a contiguous CUDA tensor"); + TORCH_CHECK(q_col.is_cuda() && q_col.is_contiguous(), "q_col must be a contiguous CUDA tensor"); TORCH_CHECK(s_dec_col.is_cuda() && s_dec_col.is_contiguous(), "s_dec_col must be a contiguous CUDA tensor"); TORCH_CHECK(col_amax.is_cuda() && col_amax.is_contiguous(), @@ -535,13 +511,10 @@ void nvfp4_per_token_group_quantize( const at::Tensor& input, const std::vector& split_sections, std::vector q_row_list, std::vector s_dec_row_list, std::vector row_amax_list, std::vector q_col_list, - std::vector s_dec_col_list, std::vector col_amax_list, - bool rowwise, bool columnwise, - bool with_rht, int64_t random_sign_mask_t) { - TORCH_CHECK(rowwise || columnwise, - "At least one of rowwise/columnwise must be True."); - TORCH_CHECK(input.is_cuda() && input.is_contiguous(), - "input must be a contiguous CUDA tensor"); + std::vector s_dec_col_list, std::vector col_amax_list, bool rowwise, + bool columnwise, bool with_rht, int64_t random_sign_mask_t) { + TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(input.is_cuda() && input.is_contiguous(), "input must be a contiguous CUDA tensor"); TORCH_CHECK(input.dim() == 2, "input must be 2D"); const int64_t sum_M = input.size(0); const int64_t K = input.size(1); @@ -552,8 +525,8 @@ void nvfp4_per_token_group_quantize( int64_t acc = 0; for (size_t i = 0; i < num_tensors; ++i) { TORCH_CHECK(split_sections[i] >= 0, "split_sections[", i, "] must be non-negative"); - TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, - "] = ", split_sections[i], " must be a multiple of 64"); + TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, "] = ", split_sections[i], + " must be a multiple of 64"); acc += split_sections[i]; } TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, " must equal input.size(0) = ", sum_M); @@ -573,8 +546,8 @@ void nvfp4_per_token_group_quantize( const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper in_te = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(sum_M), static_cast(K)}, in_dtype); + input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, + in_dtype); // One TensorWrapper per split; raw NVTETensor handles go into `handles`. std::vector wrappers; @@ -593,34 +566,25 @@ void nvfp4_per_token_group_quantize( continue; // empty split is allowed (skipped inside the kernel) } build_per_token_output_wrapper( - wrappers.back(), M_i, K, rowwise, columnwise, - rowwise ? q_row_list[i] : empty_dummy, - rowwise ? s_dec_row_list[i] : empty_dummy, - rowwise ? row_amax_list[i] : empty_dummy, - columnwise ? q_col_list[i] : empty_dummy, - columnwise ? s_dec_col_list[i] : empty_dummy, + wrappers.back(), M_i, K, rowwise, columnwise, rowwise ? q_row_list[i] : empty_dummy, + rowwise ? s_dec_row_list[i] : empty_dummy, rowwise ? row_amax_list[i] : empty_dummy, + columnwise ? q_col_list[i] : empty_dummy, columnwise ? s_dec_col_list[i] : empty_dummy, columnwise ? col_amax_list[i] : empty_dummy); handles.push_back(wrappers.back().data()); } - nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), - split_sections_sz.data(), num_tensors, rowwise, - columnwise, - static_cast(with_rht), - static_cast(random_sign_mask_t), - stream); + nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), split_sections_sz.data(), + num_tensors, rowwise, columnwise, static_cast(with_rht), + static_cast(random_sign_mask_t), stream); } // Amax-only grouped variant (K1 only); for allReduce-before-cast flows. -void nvfp4_per_token_group_amax(const at::Tensor& input, - const std::vector& split_sections, +void nvfp4_per_token_group_amax(const at::Tensor& input, const std::vector& split_sections, std::vector row_amax_list, std::vector col_amax_list, bool rowwise, - bool columnwise, - bool with_rht, int64_t random_sign_mask_t) { + bool columnwise, bool with_rht, int64_t random_sign_mask_t) { TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); - TORCH_CHECK(input.is_cuda() && input.is_contiguous(), - "input must be a contiguous CUDA tensor"); + TORCH_CHECK(input.is_cuda() && input.is_contiguous(), "input must be a contiguous CUDA tensor"); TORCH_CHECK(input.dim() == 2, "input must be 2D"); const int64_t sum_M = input.size(0); const int64_t K = input.size(1); @@ -628,21 +592,19 @@ void nvfp4_per_token_group_amax(const at::Tensor& input, TORCH_CHECK(num_tensors > 0, "split_sections must not be empty"); int64_t acc = 0; for (size_t i = 0; i < num_tensors; ++i) { - TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, - "] must be a multiple of 64"); + TORCH_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, "] must be a multiple of 64"); acc += split_sections[i]; } TORCH_CHECK(acc == sum_M, "sum(split_sections) must equal input.size(0)"); if (rowwise) TORCH_CHECK(row_amax_list.size() == num_tensors, "row_amax_list size mismatch"); - if (columnwise) - TORCH_CHECK(col_amax_list.size() == num_tensors, "col_amax_list size mismatch"); + if (columnwise) TORCH_CHECK(col_amax_list.size() == num_tensors, "col_amax_list size mismatch"); const DType in_dtype = resolve_input_dtype(input); const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper in_te = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(sum_M), static_cast(K)}, in_dtype); + input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, + in_dtype); std::vector wrappers; wrappers.reserve(num_tensors); @@ -676,10 +638,8 @@ void nvfp4_per_token_group_amax(const at::Tensor& input, } nvte_group_nvfp4_per_token_amax(in_te.data(), handles.data(), split_sections_sz.data(), - num_tensors, rowwise, columnwise, - static_cast(with_rht), - static_cast(random_sign_mask_t), - stream); + num_tensors, rowwise, columnwise, static_cast(with_rht), + static_cast(random_sign_mask_t), stream); } // BULK grouped per-token quantize: alloc + view + dispatch in ONE C++ call. @@ -689,15 +649,13 @@ std::tuple, std::vector, std::vector, std::vector, std::vector> nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, const std::vector& split_sections, bool rowwise, - bool columnwise, - bool with_rht, int64_t random_sign_mask_t) { + bool columnwise, bool with_rht, int64_t random_sign_mask_t) { // Validation mirrors _validate_per_token_group_input in Python. - TORCH_CHECK(rowwise || columnwise, - "At least one of rowwise/columnwise must be True."); + TORCH_CHECK(rowwise || columnwise, "At least one of rowwise/columnwise must be True."); TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(input.is_contiguous(), "x_concat must be contiguous (row-major)"); - TORCH_CHECK(input.dim() == 2, - "nvfp4_per_token_group_quantize expects a 2D input, got ", input.dim(), "D"); + TORCH_CHECK(input.dim() == 2, "nvfp4_per_token_group_quantize expects a 2D input, got ", + input.dim(), "D"); TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, "Per-token grouped kernel is bf16-only; got dtype ", input.scalar_type()); @@ -706,13 +664,13 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, constexpr int64_t kPerTokenTile = 128; constexpr int64_t kBlockK = 16; - TORCH_CHECK(K % kPerTokenTile == 0, - "Per-token grouped kernel requires K % ", kPerTokenTile, " == 0; got K=", K); + TORCH_CHECK(K % kPerTokenTile == 0, "Per-token grouped kernel requires K % ", kPerTokenTile, + " == 0; got K=", K); const size_t num_tensors = split_sections.size(); TORCH_CHECK(num_tensors > 0, "split_sections must not be empty"); - TORCH_CHECK(num_tensors <= 64, - "num_tensors must be <= 64 (kernel arg-struct cap); got ", num_tensors); + TORCH_CHECK(num_tensors <= 64, "num_tensors must be <= 64 (kernel arg-struct cap); got ", + num_tensors); int64_t acc = 0; for (size_t i = 0; i < num_tensors; ++i) { @@ -722,8 +680,7 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, " must be a multiple of ", kPerTokenTile); acc += M_i; } - TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, - " must equal input.size(0) = ", sum_M); + TORCH_CHECK(acc == sum_M, "sum(split_sections) = ", acc, " must equal input.size(0) = ", sum_M); // Bulk allocation: one at::empty per output type, covers all splits. auto opts_u8 = input.options().dtype(at::kByte); @@ -773,8 +730,7 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, if (columnwise) { auto q_col_flat = q_col_bulk.narrow(0, K * m_off / 2, K * M_i / 2); q_col_list.emplace_back(q_col_flat.view({K, M_i / 2})); - auto s_dec_col_flat = - s_dec_col_bulk.narrow(0, K * m_off / kBlockK, K * M_i / kBlockK); + auto s_dec_col_flat = s_dec_col_bulk.narrow(0, K * m_off / kBlockK, K * M_i / kBlockK); s_dec_col_u8_list.emplace_back(s_dec_col_flat.view({K, M_i / kBlockK})); col_amax_list.emplace_back(col_amax_bulk.select(0, static_cast(i))); s_dec_col_fp8_list.emplace_back(s_dec_col_u8_list.back().view(at::kFloat8_e4m3fn)); @@ -785,8 +741,7 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, // Dispatch K1+K2 grouped kernel via the same C-API the thin entry uses. const auto stream = at::cuda::getCurrentCUDAStream(); TensorWrapper in_te = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(sum_M), static_cast(K)}, + input.data_ptr(), std::vector{static_cast(sum_M), static_cast(K)}, DType::kBFloat16); std::vector wrappers; @@ -801,22 +756,16 @@ nvfp4_per_token_group_quantize_bulk(const at::Tensor& input, split_sections_sz[i] = static_cast(M_i); wrappers.emplace_back(NVTE_NVFP4_1D_SCALING); build_per_token_output_wrapper( - wrappers.back(), M_i, K, rowwise, columnwise, - rowwise ? q_row_list[i] : empty_dummy, - rowwise ? s_dec_row_u8_list[i] : empty_dummy, - rowwise ? row_amax_list[i] : empty_dummy, - columnwise ? q_col_list[i] : empty_dummy, - columnwise ? s_dec_col_u8_list[i] : empty_dummy, + wrappers.back(), M_i, K, rowwise, columnwise, rowwise ? q_row_list[i] : empty_dummy, + rowwise ? s_dec_row_u8_list[i] : empty_dummy, rowwise ? row_amax_list[i] : empty_dummy, + columnwise ? q_col_list[i] : empty_dummy, columnwise ? s_dec_col_u8_list[i] : empty_dummy, columnwise ? col_amax_list[i] : empty_dummy); handles.push_back(wrappers.back().data()); } - nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), - split_sections_sz.data(), num_tensors, rowwise, - columnwise, - static_cast(with_rht), - static_cast(random_sign_mask_t), - stream); + nvte_group_nvfp4_per_token_quantize(in_te.data(), handles.data(), split_sections_sz.data(), + num_tensors, rowwise, columnwise, static_cast(with_rht), + static_cast(random_sign_mask_t), stream); return std::make_tuple(std::move(q_row_list), std::move(s_dec_row_fp8_list), std::move(row_amax_list), std::move(q_col_list), diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index e97e5837a8..507595d172 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -346,8 +346,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "K1 of the NVFP4Quantizer RHT+post_rht_amax path: rowwise (pre-RHT) + " "columnwise (RHT(input.T)) amax in one launch. Bench-only entry.", py::arg("input"), py::arg("rowwise_amax"), py::arg("columnwise_amax"), - py::arg("rht_matrix_random_sign_mask"), - py::call_guard()); + py::arg("rht_matrix_random_sign_mask"), py::call_guard()); m.def("fused_amax_and_scale_update_after_reduction", &transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction", @@ -402,44 +401,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Requires bf16 input, M % 128 == 0, K % 128 == 0. " "with_rht=True applies a 16-pt col-wise RHT in both K1 and K2.", py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), - py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), - py::arg("rowwise"), py::arg("columnwise"), - py::arg("with_rht") = false, + py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), py::arg("rowwise"), + py::arg("columnwise"), py::arg("with_rht") = false, py::arg("random_sign_mask_t") = static_cast(0xACE1)); - m.def("nvfp4_per_token_amax", - &transformer_engine::pytorch::nvfp4_per_token_amax, + m.def("nvfp4_per_token_amax", &transformer_engine::pytorch::nvfp4_per_token_amax, "K1-only: per-row/per-col outer amax via TMA + atomicMax. Bench/diagnostic. " "with_rht=True applies a 16-pt col-wise RHT before amax.", - py::arg("input"), py::arg("row_amax"), py::arg("col_amax"), - py::arg("rowwise"), py::arg("columnwise"), - py::arg("with_rht") = false, + py::arg("input"), py::arg("row_amax"), py::arg("col_amax"), py::arg("rowwise"), + py::arg("columnwise"), py::arg("with_rht") = false, py::arg("random_sign_mask_t") = static_cast(0xACE1)); - m.def("nvfp4_per_token_encode", - &transformer_engine::pytorch::nvfp4_per_token_encode, + m.def("nvfp4_per_token_encode", &transformer_engine::pytorch::nvfp4_per_token_encode, "K2-only: FP4 + e4m3 SF encode given pre-filled amax buffers. Bench/diagnostic. " "with_rht=True requires col_amax produced by a K1 launch with the same mask.", py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), - py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), - py::arg("rowwise"), py::arg("columnwise"), - py::arg("with_rht") = false, + py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), py::arg("rowwise"), + py::arg("columnwise"), py::arg("with_rht") = false, py::arg("random_sign_mask_t") = static_cast(0xACE1)); m.def("nvfp4_per_token_post_scale", &transformer_engine::pytorch::nvfp4_per_token_post_scale, - "Apply d[i,j] *= row_amax_a[i] * row_amax_b[j] in-place on bf16 D.", - py::arg("d"), py::arg("row_amax_a"), py::arg("row_amax_b")); + "Apply d[i,j] *= row_amax_a[i] * row_amax_b[j] in-place on bf16 D.", py::arg("d"), + py::arg("row_amax_a"), py::arg("row_amax_b")); m.def("nvfp4_per_token_gemm", &transformer_engine::pytorch::nvfp4_per_token_gemm, "End-to-end NVFP4 per-token GEMM: swizzle compact SFs, cuBLAS LT NVFP4 " "GEMM, then row*col post-scale to recover C = A @ B^T. beta must be 0.", py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), - py::arg("a_row_amax"), py::arg("b_row_amax"), py::arg("d"), - py::arg("workspace"), py::arg("m"), py::arg("n"), py::arg("k"), - py::arg("alpha"), py::arg("beta")); + py::arg("a_row_amax"), py::arg("b_row_amax"), py::arg("d"), py::arg("workspace"), + py::arg("m"), py::arg("n"), py::arg("k"), py::arg("alpha"), py::arg("beta")); m.def("nvfp4_per_tensor_gemm", &transformer_engine::pytorch::nvfp4_per_tensor_gemm, "Skinny prod NVFP4 GEMM twin of nvfp4_per_token_gemm: per-tensor amaxes " "folded into cuBLAS alpha, no trailing post-scale. Bench-only.", - py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), - py::arg("a_amax"), py::arg("b_amax"), py::arg("d"), - py::arg("workspace"), py::arg("m"), py::arg("n"), py::arg("k"), - py::arg("alpha"), py::arg("beta")); + py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), py::arg("a_amax"), + py::arg("b_amax"), py::arg("d"), py::arg("workspace"), py::arg("m"), py::arg("n"), + py::arg("k"), py::arg("alpha"), py::arg("beta")); m.def("nvfp4_per_token_group_quantize", &transformer_engine::pytorch::nvfp4_per_token_group_quantize, "Grouped (multi-tensor) NVFP4 per-token cast: K1 + K2 across <= 64 splits " @@ -448,26 +440,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("input"), py::arg("split_sections"), py::arg("q_row_list"), py::arg("s_dec_row_list"), py::arg("row_amax_list"), py::arg("q_col_list"), py::arg("s_dec_col_list"), py::arg("col_amax_list"), py::arg("rowwise"), - py::arg("columnwise"), - py::arg("with_rht") = false, + py::arg("columnwise"), py::arg("with_rht") = false, py::arg("random_sign_mask_t") = static_cast(0xACE1)); - m.def("nvfp4_per_token_group_amax", - &transformer_engine::pytorch::nvfp4_per_token_group_amax, + m.def("nvfp4_per_token_group_amax", &transformer_engine::pytorch::nvfp4_per_token_group_amax, "K1-only variant of nvfp4_per_token_group_quantize: only fills amax slots. " "with_rht / random_sign_mask_t must match the trailing cast launch.", py::arg("input"), py::arg("split_sections"), py::arg("row_amax_list"), py::arg("col_amax_list"), py::arg("rowwise"), py::arg("columnwise"), - py::arg("with_rht") = false, - py::arg("random_sign_mask_t") = static_cast(0xACE1)); + py::arg("with_rht") = false, py::arg("random_sign_mask_t") = static_cast(0xACE1)); m.def("nvfp4_per_token_group_quantize_bulk", &transformer_engine::pytorch::nvfp4_per_token_group_quantize_bulk, "Bulk grouped quantize: allocates per-split buffers + view-slices inside " "the binding (one pybind hop instead of 1 + 6N), then dispatches the K1+K2 " "kernel. with_rht=True applies a 16-pt col-wise RHT in both K1 and K2.", - py::arg("input"), py::arg("split_sections"), py::arg("rowwise"), - py::arg("columnwise"), - py::arg("with_rht") = false, - py::arg("random_sign_mask_t") = static_cast(0xACE1)); + py::arg("input"), py::arg("split_sections"), py::arg("rowwise"), py::arg("columnwise"), + py::arg("with_rht") = false, py::arg("random_sign_mask_t") = static_cast(0xACE1)); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py index 8c86d03b98..a2be19f73c 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py @@ -146,8 +146,9 @@ def _quantize_2d(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, tor # kernel. All-zero blocks: s_dec saturates to 0, naive S_enc/0 # would NaN; short-circuit to 0 to mirror the kernel. zero_blk = decode_scale_back_fp32 == 0 - denom = torch.where(zero_blk, torch.ones_like(decode_scale_back_fp32), - decode_scale_back_fp32) + denom = torch.where( + zero_blk, torch.ones_like(decode_scale_back_fp32), decode_scale_back_fp32 + ) encode_scale = S_enc_per_blk / denom encode_scale = torch.where(zero_blk, torch.zeros_like(encode_scale), encode_scale) encode_scale = torch.minimum(encode_scale, fp32_max) @@ -200,19 +201,19 @@ def _validate_per_token_input(x: torch.Tensor) -> Tuple[int, int]: ) M, K = x.shape if M % _PER_TOKEN_TILE != 0: - raise ValueError( - f"Per-token kernel requires M % {_PER_TOKEN_TILE} == 0; got M={M}" - ) + raise ValueError(f"Per-token kernel requires M % {_PER_TOKEN_TILE} == 0; got M={M}") if K % _PER_TOKEN_TILE != 0: - raise ValueError( - f"Per-token kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}" - ) + raise ValueError(f"Per-token kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}") return M, K def nvfp4_per_token_quantize( - x: torch.Tensor, *, rowwise: bool = True, columnwise: bool = False, - with_rht: bool = False, random_sign_mask_t: int = 0xACE1, + x: torch.Tensor, + *, + rowwise: bool = True, + columnwise: bool = False, + with_rht: bool = False, + random_sign_mask_t: int = 0xACE1, ) -> RefNVFP4TensorPerToken: """Production NVFP4 per-token cast through ``tex.nvfp4_per_token_quantize``. @@ -265,8 +266,17 @@ def nvfp4_per_token_quantize( q_col, s_dec_col, col_amax = empty, empty, empty_f32 tex.nvfp4_per_token_quantize( - x, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, columnwise, - with_rht=with_rht, random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, + x, + q_row, + s_dec_row, + row_amax, + q_col, + s_dec_col, + col_amax, + rowwise, + columnwise, + with_rht=with_rht, + random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, ) out = RefNVFP4TensorPerToken() @@ -290,9 +300,14 @@ def nvfp4_per_token_quantize( # above; the composite handles K1 + K2 ordering on the same stream. # ============================================================================ + def nvfp4_per_token_amax( - x: torch.Tensor, *, rowwise: bool = True, columnwise: bool = True, - with_rht: bool = False, random_sign_mask_t: int = 0xACE1, + x: torch.Tensor, + *, + rowwise: bool = True, + columnwise: bool = True, + with_rht: bool = False, + random_sign_mask_t: int = 0xACE1, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: """Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. Returns ``(row_amax, col_amax)``; either may be ``None`` if the @@ -327,8 +342,13 @@ def nvfp4_per_token_amax( ) tex.nvfp4_per_token_amax( - x, row_amax, col_amax, rowwise, columnwise, - with_rht=with_rht, random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, + x, + row_amax, + col_amax, + rowwise, + columnwise, + with_rht=with_rht, + random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, ) return (row_amax if rowwise else None, col_amax if columnwise else None) @@ -390,8 +410,17 @@ def nvfp4_per_token_encode( q_col, s_dec_col, col_amax_t = empty, empty, empty_f32 tex.nvfp4_per_token_encode( - x, q_row, s_dec_row, row_amax_t, q_col, s_dec_col, col_amax_t, rowwise, columnwise, - with_rht=with_rht, random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, + x, + q_row, + s_dec_row, + row_amax_t, + q_col, + s_dec_col, + col_amax_t, + rowwise, + columnwise, + with_rht=with_rht, + random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, ) out = RefNVFP4TensorPerToken() diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py index ba3aabeaa6..7cd48bad7b 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py @@ -29,20 +29,14 @@ def _validate_per_token_group_input( ``(sum_M, K)``. """ if x_concat.ndim != 2: - raise ValueError( - f"nvfp4_per_token_group_quantize expects a 2D input, got {x_concat.ndim}D" - ) + raise ValueError(f"nvfp4_per_token_group_quantize expects a 2D input, got {x_concat.ndim}D") if not x_concat.is_contiguous(): raise ValueError("x_concat must be contiguous (row-major)") if x_concat.dtype != torch.bfloat16: - raise ValueError( - f"Per-token grouped kernel is bf16-only; got dtype {x_concat.dtype}." - ) + raise ValueError(f"Per-token grouped kernel is bf16-only; got dtype {x_concat.dtype}.") sum_M, K = x_concat.shape if K % _PER_TOKEN_TILE != 0: - raise ValueError( - f"Per-token grouped kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}" - ) + raise ValueError(f"Per-token grouped kernel requires K % {_PER_TOKEN_TILE} == 0; got K={K}") if len(split_sections) == 0: raise ValueError("split_sections must not be empty") if len(split_sections) > 64: @@ -54,14 +48,10 @@ def _validate_per_token_group_input( if M_i <= 0: raise ValueError(f"split_sections[{i}] must be > 0, got {M_i}") if M_i % _PER_TOKEN_TILE != 0: - raise ValueError( - f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}" - ) + raise ValueError(f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}") acc += M_i if acc != sum_M: - raise ValueError( - f"sum(split_sections) = {acc} must equal input.size(0) = {sum_M}" - ) + raise ValueError(f"sum(split_sections) = {acc} must equal input.size(0) = {sum_M}") return sum_M, K @@ -109,7 +99,10 @@ def nvfp4_per_token_group_quantize( s_dec_col_list, col_amax_list, ) = tex.nvfp4_per_token_group_quantize_bulk( - x_concat, split_sections_list, rowwise, columnwise, + x_concat, + split_sections_list, + rowwise, + columnwise, with_rht=bool(with_rht), random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, ) From 2d2899a3951d564be8ce059a52c99b2fa6409670 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Fri, 29 May 2026 08:04:32 -0700 Subject: [PATCH 5/6] Fuse rowwise SF swizzle into NVFP4 per-token K2 + bench scaffolding Add an optional fused-swizzle path to the NVFP4 per-token K2 encode kernel: when with_swizzle=True the rowwise scale_inv is emitted directly in the cuBLAS LT 128Mx4K swizzled tile layout, skipping the downstream nvte_swizzle_scaling_factors launch. The colwise scale_inv stays in the compact M-major layout (rowwise-only fusion for now). The new code path is gated by a kWithSwizzle template parameter on per_token_encode_kernel. The scatter epilogue uses thread mapping b=tid&3, ty=tid>>2 to give each warp a coalesced 128-byte gmem store, and packs two K-tiles into one uint64_t SMEM load (2-way bank conflict instead of 4-way). Pre-existing code path is byte-equal. with_swizzle is threaded through nvte_nvfp4_per_token_{quantize,encode}, their PyTorch bindings, and the nvfp4_per_token_{quantize,encode} Python recipes. nvfp4_per_token_gemm takes new a_sf_swizzled / b_sf_swizzled flags so the caller opts into the fast path per operand (mirrors prod NVFP4 GEMM's per-operand swizzle). Add tex.nvfp4_per_token_swizzle_rowwise_sf -- a thin wrapper around nvte_swizzle_scaling_factors that does one standalone per-operand swizzle launch. Bench-only; lets --qs attribute swizzle cost separately from K1+K2 and from cuBLAS LT GEMM. Bench (bench_nvfp4_per_token.py): add --qs mode (K1+K2 + standalone swizzle, no GEMM) with two modifiers -- --pair (2 operands, matches one prod GEMM call's quant+swizzle pipeline) and --fuse (adds a per-token (fuse) column for the K2-fused path). The existing --swizzle end-to-end mode also gains the fused-swizzle column. --pair / --fuse auto-imply --qs to avoid silent fall-through to the default --composite table. Tests (test_nvfp4_per_token.py): byte-equality of the fused-swizzle rowwise SF vs a pure-Python permutation reference, byte-equality of all other outputs (FP4 data, colwise SF, row/col amax) vs with_swizzle=False, and numerical equivalence of the end-to-end GEMM via both code paths. Perf at K=N=4096, Graph mode: fused-swizzle path is ~7-35% faster than the unfused per-token pipeline (--qs) and reaches up to ~2.6x faster than per-tensor at small M. Co-authored-by: Zhongbo Zhu Co-authored-by: Jiaxing Qi Signed-off-by: Cael Ling --- tests/pytorch/nvfp4/bench_nvfp4_per_token.py | 597 +++++++++++++++++- tests/pytorch/nvfp4/test_nvfp4_per_token.py | 119 ++++ .../cast/nvfp4/quantize_nvfp4_per_token.cu | 138 ++-- .../transformer_engine/nvfp4_per_token.h | 12 +- transformer_engine/pytorch/csrc/extensions.h | 14 +- .../csrc/extensions/nvfp4_per_token.cpp | 95 ++- .../pytorch/csrc/extensions/pybind.cpp | 28 +- .../custom_recipes/gemm_nvfp4_per_token.py | 8 + .../quantization_nvfp4_per_token.py | 45 +- 9 files changed, 919 insertions(+), 137 deletions(-) diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py index 7da8ec9519..633aefea8a 100644 --- a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py @@ -4,15 +4,29 @@ """Bench NVFP4 per-token K1+K2 quant vs per-tensor RHT+SR baseline. -Quant-only (no GEMM). bf16, M % 128 == 0, K % 128 == 0. +bf16, M % 128 == 0, K % 128 == 0. Modes: - * default: 2-way composite (per-token vs per-tensor). Ratio = pt / pten. - * ``--rht``: 3-way composite (adds per-token + col-wise 16-pt RHT). Ratio = - per-token (+rht) / per-tensor. - * ``--k1-only``: K1 in isolation. Without ``--rht``: pt_K1 vs prod_K1. - With ``--rht``: (A) pt_K1 vs pt_K1+RHT (apples-to-apples) and - (B) pt_K1+RHT vs prod_K1 (NOT apples-to-apples; output shapes differ). + * default: 2-way quant-only (per-token vs per-tensor). Ratio = pt / pten. + * ``--rht``: 3-way quant-only (adds per-token + col-wise 16-pt RHT). + * ``--swizzle``: 3-way END-TO-END (quant + swizzle + cuBLAS LT NVFP4 GEMM). + Compares per-token (separate swizzle launch) vs per-token (fused + swizzle in K2) vs per-tensor. Ratio = per-token (+swizzle) / per-tensor. + * ``--qs``: 2-way K1+K2 + standalone rowwise swizzle. NO GEMM. + - default (solo, 1 tensor): K1+K2(A) + swizzle(A); apples-to-apples + with --composite (which is also 1-tensor) -- the delta vs --composite + is the pure marginal swizzle cost. + - ``--pair`` (2 tensors): K1+K2(A) + K1+K2(B) + swizzle(A) + swizzle(B); + matches prod NVFP4 GEMM's per-call quant+swizzle pipeline (1 swizzle + per operand). Use this when you want "one GEMM call's worth of + non-GEMM cost". + - ``--fuse``: also bench per-token with fused-swizzle K2 (K2 directly + writes the rowwise SF in cuBLAS LT swizzled layout; no separate + swizzle launch). Prints a 3-way table: per-token / per-token(fuse) / + per-tensor. The (fuse) column saves 1 swizzle launch/operand vs the + non-fuse column. + Ratio = per-token / per-tensor (3-way mode adds a per-token(fuse) column). + * ``--k1-only``: K1 in isolation (orthogonal to --swizzle / --qs). """ from __future__ import annotations @@ -130,6 +144,35 @@ class K1ShapeBench: t_prod_g: float +@dataclass +class E2EShapeBench: + """End-to-end (quant + GEMM) timing for --swizzle mode. N is bound to M.""" + + M: int + K: int + t_pt: float # per-token (no fused swizzle): quant + ext swizzle + GEMM + t_pt_swz: float # per-token (fused swizzle): quant_with_swizzle=True + GEMM + t_pten: float # per-tensor: NVFP4Quantizer + cuBLAS LT GEMM + t_pt_g: float + t_pt_swz_g: float + t_pten_g: float + + +@dataclass +class QSShapeBench: + """K1+K2 + rowwise swizzle, no GEMM. solo=3 launches, --pair=6, + --fuse adds per-token-fused column (K2 emits swizzled SF in 1 launch).""" + + M: int + K: int + t_pt: float # per-token K1+K2 + ext swizzle (1 or 2 operands depending on pair) + t_pten: float # per-tensor K1+K2 + ext swizzle (matching operand count) + t_pt_g: float + t_pten_g: float + t_pt_swz: float = float("nan") # per-token K1+K2 with fused swizzle (no ext swz launch) + t_pt_swz_g: float = float("nan") + + # Default mask seed; matches prod's `te-nvfp4-build-overrides.mdc` convention. _RHT_MASK_DEFAULT: int = 0xACE1 @@ -137,10 +180,8 @@ class K1ShapeBench: def _bench_shape( M: int, K: int, *, device: torch.device, with_rht: bool = False, mask_t: int = _RHT_MASK_DEFAULT ) -> ShapeBench: - """Composite K1+K2 timing at one (M, K) shape. - pt = per-token (no RHT), pt_rht = per-token + col-wise 16-pt RHT - (NaN unless with_rht=True), pten = per-tensor + RHT + SR (prod baseline). - """ + """Composite K1+K2 at (M, K). pt = per-token (no RHT); pt_rht = +col-wise + 16-pt RHT (NaN unless with_rht=True); pten = per-tensor + RHT + SR.""" a = torch.randn((M, K), dtype=torch.bfloat16, device=device) quantizer = _make_baseline_quantizer() @@ -213,13 +254,450 @@ def _pt_full_quant_rht_fn(): ) +def _bench_shape_e2e_swizzle( + M: int, + K: int, + *, + device: torch.device, + with_rht: bool = False, + mask_t: int = _RHT_MASK_DEFAULT, +) -> E2EShapeBench: + """E2E (quant + cuBLAS LT NVFP4 GEMM) for --swizzle, square N=M. + pt: ext swizzle; pt_swz: fused-swizzle K2 (no internal swz launch); + pten: NVFP4Quantizer + nvfp4_per_tensor_gemm baseline.""" + from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace + + N = M # square; cuBLAS LT NVFP4 is TN-only -- A: (M, K), B: (N, K) + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) + d = torch.empty((M, N), dtype=torch.bfloat16, device=device) + # torch.device("cuda").index is None (no explicit device index); resolve to + # an actual GPU index via the allocated tensor so get_cublas_workspace + # creates the workspace on the right CUDA device instead of CPU. + workspace = get_cublas_workspace(a.device.index, ub=False, grouped_gemm=False) + + # Per-token quant produces row + col directions on every call (matches the + # per-tensor baseline below which does both in one kernel). GEMM consumes + # only the rowwise side; the col allocation is realistic prod overhead. + BLOCK_K = 16 + + def _alloc_pt(R, C): + return ( + torch.empty((R, C // 2), dtype=torch.uint8, device=device), + torch.empty((R, C // BLOCK_K), dtype=torch.uint8, device=device), + torch.empty((R,), dtype=torch.float32, device=device), + torch.empty((C, R // 2), dtype=torch.uint8, device=device), + torch.empty((C, R // BLOCK_K), dtype=torch.uint8, device=device), + torch.empty((C,), dtype=torch.float32, device=device), + ) + + a_qr, a_sr, a_ra, a_qc, a_sc, a_ca = _alloc_pt(M, K) + b_qr, b_sr, b_ra, b_qc, b_sc, b_ca = _alloc_pt(N, K) + + def _pt_quant(t, qr, sr, ra_buf, qc, sc, ca_buf, *, fused_swizzle: bool): + tex.nvfp4_per_token_quantize( + t, qr, sr, ra_buf, qc, sc, ca_buf, + True, True, # rowwise + columnwise (apples-to-apples vs per-tensor) + with_rht=with_rht, + random_sign_mask_t=mask_t if with_rht else 0, + with_swizzle=fused_swizzle, + ) + + def _pt_e2e_ext_swizzle(): + _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca, fused_swizzle=False) + _pt_quant(b, b_qr, b_sr, b_ra, b_qc, b_sc, b_ca, fused_swizzle=False) + tex.nvfp4_per_token_gemm( + a_qr, b_qr, a_sr.reshape(-1), b_sr.reshape(-1), + a_ra, b_ra, d, workspace, M, N, K, 1.0, 0.0, + a_sf_swizzled=False, b_sf_swizzled=False, + ) + + def _pt_e2e_fused_swizzle(): + _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca, fused_swizzle=True) + _pt_quant(b, b_qr, b_sr, b_ra, b_qc, b_sc, b_ca, fused_swizzle=True) + tex.nvfp4_per_token_gemm( + a_qr, b_qr, a_sr.reshape(-1), b_sr.reshape(-1), + a_ra, b_ra, d, workspace, M, N, K, 1.0, 0.0, + a_sf_swizzled=True, b_sf_swizzled=True, + ) + + # Per-tensor path: NVFP4Quantizer (RHT+SR) + bench-only nvfp4_per_tensor_gemm. + quantizer = _make_baseline_quantizer() + dst_a = quantizer.make_empty(a.shape, dtype=torch.bfloat16, device=device) + dst_b = quantizer.make_empty(b.shape, dtype=torch.bfloat16, device=device) + + def _pten_e2e(): + tex.quantize(a, quantizer, dst_a, None) + tex.quantize(b, quantizer, dst_b, None) + tex.nvfp4_per_tensor_gemm( + dst_a._rowwise_data, dst_b._rowwise_data, + dst_a._rowwise_scale_inv, dst_b._rowwise_scale_inv, + dst_a._amax_rowwise, dst_b._amax_rowwise, + d, workspace, M, N, K, 1.0, 0.0, + ) + + t_pt = cuda_time_ms(_pt_e2e_ext_swizzle) + t_pt_swz = cuda_time_ms(_pt_e2e_fused_swizzle) + t_pten = cuda_time_ms(_pten_e2e) + t_pt_g = cuda_graph_time_ms(_pt_e2e_ext_swizzle) + t_pt_swz_g = cuda_graph_time_ms(_pt_e2e_fused_swizzle) + t_pten_g = cuda_graph_time_ms(_pten_e2e) + + return E2EShapeBench( + M=M, K=K, + t_pt=t_pt, t_pt_swz=t_pt_swz, t_pten=t_pten, + t_pt_g=t_pt_g, t_pt_swz_g=t_pt_swz_g, t_pten_g=t_pten_g, + ) + + +def _bench_shape_qs( + M: int, + K: int, + *, + device: torch.device, + with_rht: bool = False, + mask_t: int = _RHT_MASK_DEFAULT, + pair: bool = False, + fuse: bool = False, +) -> QSShapeBench: + """K1+K2 + standalone rowwise swizzle, no GEMM. solo=3 launches/operand, + --pair=6 (A+B). Swizzle binding identical across pt/pten -- only K1+K2 differs.""" + N = M # square; matches --swizzle's apples-to-apples convention. + a = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + BLOCK_K = 16 + + def _alloc_pt(R, C): + return ( + torch.empty((R, C // 2), dtype=torch.uint8, device=device), + torch.empty((R, C // BLOCK_K), dtype=torch.uint8, device=device), + torch.empty((R,), dtype=torch.float32, device=device), + torch.empty((C, R // 2), dtype=torch.uint8, device=device), + torch.empty((C, R // BLOCK_K), dtype=torch.uint8, device=device), + torch.empty((C,), dtype=torch.float32, device=device), + ) + + a_qr, a_sr, a_ra, a_qc, a_sc, a_ca = _alloc_pt(M, K) + a_sr_swz = torch.empty(a_sr.numel(), dtype=torch.uint8, device=device) + + # B-side allocation only when --pair (avoids spurious HBM pressure in solo). + if pair: + b = torch.randn((N, K), dtype=torch.bfloat16, device=device) + b_qr, b_sr, b_ra, b_qc, b_sc, b_ca = _alloc_pt(N, K) + b_sr_swz = torch.empty(b_sr.numel(), dtype=torch.uint8, device=device) + + def _pt_quant(t, qr, sr, ra_buf, qc, sc, ca_buf): + tex.nvfp4_per_token_quantize( + t, qr, sr, ra_buf, qc, sc, ca_buf, + True, True, + with_rht=with_rht, + random_sign_mask_t=mask_t if with_rht else 0, + with_swizzle=False, # explicit external swizzle, see below + ) + + if pair: + def _pt_qs(): + _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca) + _pt_quant(b, b_qr, b_sr, b_ra, b_qc, b_sc, b_ca) + tex.nvfp4_per_token_swizzle_rowwise_sf(a_qr, a_sr.reshape(-1), a_sr_swz) + tex.nvfp4_per_token_swizzle_rowwise_sf(b_qr, b_sr.reshape(-1), b_sr_swz) + else: + def _pt_qs(): + _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca) + tex.nvfp4_per_token_swizzle_rowwise_sf(a_qr, a_sr.reshape(-1), a_sr_swz) + + # Per-tensor baseline path: NVFP4Quantizer (RHT+SR), reuse internal storage. + quantizer = _make_baseline_quantizer() + dst_a = quantizer.make_empty(a.shape, dtype=torch.bfloat16, device=device) + pten_a_sr_swz = torch.empty(dst_a._rowwise_scale_inv.numel(), dtype=torch.uint8, device=device) + if pair: + dst_b = quantizer.make_empty(b.shape, dtype=torch.bfloat16, device=device) + pten_b_sr_swz = torch.empty( + dst_b._rowwise_scale_inv.numel(), dtype=torch.uint8, device=device + ) + + if pair: + def _pten_qs(): + tex.quantize(a, quantizer, dst_a, None) + tex.quantize(b, quantizer, dst_b, None) + tex.nvfp4_per_token_swizzle_rowwise_sf( + dst_a._rowwise_data, dst_a._rowwise_scale_inv.reshape(-1), pten_a_sr_swz + ) + tex.nvfp4_per_token_swizzle_rowwise_sf( + dst_b._rowwise_data, dst_b._rowwise_scale_inv.reshape(-1), pten_b_sr_swz + ) + else: + def _pten_qs(): + tex.quantize(a, quantizer, dst_a, None) + tex.nvfp4_per_token_swizzle_rowwise_sf( + dst_a._rowwise_data, dst_a._rowwise_scale_inv.reshape(-1), pten_a_sr_swz + ) + + t_pt = cuda_time_ms(_pt_qs) + t_pten = cuda_time_ms(_pten_qs) + t_pt_g = cuda_graph_time_ms(_pt_qs) + t_pten_g = cuda_graph_time_ms(_pten_qs) + + t_pt_swz = float("nan") + t_pt_swz_g = float("nan") + if fuse: + # Fused-swizzle K2: writes rowwise SF directly in swizzled layout + # (same numel as compact, just byte-permuted). No external swizzle + # launch -- K1+K2 alone is the full pipeline. + a_qr_f, a_sr_f, a_ra_f, a_qc_f, a_sc_f, a_ca_f = _alloc_pt(M, K) + if pair: + b_qr_f, b_sr_f, b_ra_f, b_qc_f, b_sc_f, b_ca_f = _alloc_pt(N, K) + + def _pt_quant_fused(t, qr, sr, ra_buf, qc, sc, ca_buf): + tex.nvfp4_per_token_quantize( + t, qr, sr, ra_buf, qc, sc, ca_buf, + True, True, + with_rht=with_rht, + random_sign_mask_t=mask_t if with_rht else 0, + with_swizzle=True, # <-- fused: K2 emits swizzled rowwise SF + ) + + if pair: + def _pt_qs_fused(): + _pt_quant_fused(a, a_qr_f, a_sr_f, a_ra_f, a_qc_f, a_sc_f, a_ca_f) + _pt_quant_fused(b, b_qr_f, b_sr_f, b_ra_f, b_qc_f, b_sc_f, b_ca_f) + else: + def _pt_qs_fused(): + _pt_quant_fused(a, a_qr_f, a_sr_f, a_ra_f, a_qc_f, a_sc_f, a_ca_f) + + t_pt_swz = cuda_time_ms(_pt_qs_fused) + t_pt_swz_g = cuda_graph_time_ms(_pt_qs_fused) + + return QSShapeBench( + M=M, K=K, t_pt=t_pt, t_pten=t_pten, t_pt_g=t_pt_g, t_pten_g=t_pten_g, + t_pt_swz=t_pt_swz, t_pt_swz_g=t_pt_swz_g, + ) + + +def _print_qs_table(records: List[QSShapeBench], *, fuse: bool) -> None: + """K1+K2 + rowwise swizzle (no GEMM). 2-way default, 3-way w/ --fuse. + Ratio = per-token(fuse if --fuse else plain) / per-tensor.""" + + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" + + if not fuse: + w_pt, w_pten, w_ratio = 14, 15, 8 + block_w = w_pt + 1 + w_pten + 1 + w_ratio + header1 = f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + header2 = ( + f"{'M':>7} {'K':>6}" + " |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + " |" + f"{'per-token':>{w_pt}} {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + print(header1) + print(header2) + print("-" * len(header2)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt, rec.t_pten) + ratio_g = _ratio(rec.t_pt_g, rec.t_pten_g) + print( + f"{rec.M:>7} {rec.K:>6}" + " |" + f"{rec.t_pt:>{w_pt}.4f} {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" + " |" + f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" + ) + return + + # 3-way with fuse column + w_pt, w_swz, w_pten, w_ratio = 12, 14, 13, 8 + block_w = w_pt + 1 + w_swz + 1 + w_pten + 1 + w_ratio + header1 = f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + header2 = ( + f"{'M':>7} {'K':>6}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_swz}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_swz}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + header3 = ( + f"{'':>7} {'':>6}" + " |" + f"{'':>{w_pt}} {'(fuse)':>{w_swz}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + " |" + f"{'':>{w_pt}} {'(fuse)':>{w_swz}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + ) + print(header1) + print(header2) + print(header3) + print("-" * len(header2)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + # 3-way ratio uses the fused-swizzle column vs per-tensor. + ratio = _ratio(rec.t_pt_swz, rec.t_pten) + ratio_g = _ratio(rec.t_pt_swz_g, rec.t_pten_g) + print( + f"{rec.M:>7} {rec.K:>6}" + " |" + f"{rec.t_pt:>{w_pt}.4f} {rec.t_pt_swz:>{w_swz}.4f}" + f" {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" + " |" + f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pt_swz_g:>{w_swz}.4f}" + f" {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" + ) + + +def _print_qs_legend(*, with_rht: bool, rht_mask: int, pair: bool, fuse: bool) -> None: + print() + n_tensors = 2 if pair else 1 + n_launches_ext = 3 * n_tensors # K1+K2+swz per tensor + n_launches_fused = 2 * n_tensors # K1+K2 only per tensor (swizzle folded into K2) + mode_tag = "--pair, 2 operands" if pair else "default solo, 1 operand" + n_kernels_tag = ( + f"ext-swz pipeline {n_launches_ext} launches" + + (f" / fused pipeline {n_launches_fused} launches" if fuse else "") + ) + print(f"Legend (K1+K2 + rowwise swizzle; NO GEMM; mode = {mode_tag}; {n_kernels_tag}):") + rht_suffix = ( + f"with_rht=True + random_sign_mask_t=0x{rht_mask:04X}" if with_rht else "with_rht=False" + ) + print( + f" per-token (ms) = {n_tensors} x nvfp4_per_token_quantize({rht_suffix})" + " # K1+K2 each" + ) + print( + f" + {n_tensors} x nvfp4_per_token_swizzle_rowwise_sf" + " # 1 swz each" + ) + print(" K1 = nvfp4_per_token_amax (per-row/per-col vec amax)") + print(" K2 = nvfp4_per_token_encode (cast + e4m3 SF + optional RHT)") + if fuse: + print( + f" per-token (fuse) (ms) = {n_tensors} x nvfp4_per_token_quantize(..., " + "with_swizzle=True)" + ) + print(" # K1+K2 each; K2 directly emits the swizzled rowwise") + print(" # SF in cuBLAS LT layout (no separate swizzle launch).") + print( + f" per-tensor (ms) = {n_tensors} x tex.quantize(NVFP4Quantizer(rht+sr))" + " # K1+K2 each" + ) + print( + f" + {n_tensors} x nvfp4_per_token_swizzle_rowwise_sf" + " # 1 swz each" + ) + print(" K1 = nvte_hadamard_transform_amax (post-RHT scalar amax)") + print(" K2 = nvte_quantize_with_hadamard_transform") + print(" (RHT + SR + cast fusion, rowwise + columnwise)") + if fuse: + print(" The (fuse) column saves 1 swizzle launch/operand vs the non-fuse column;") + print(" the K2 byte-output is identical (verified by pytest byte-equality test).") + if not pair: + print( + " solo mode is apples-to-apples with --composite (also 1 operand): the delta" + ) + print( + " per-token(--qs) - per-token(--composite) ~= one nvte_swizzle launch." + ) + else: + print( + " --pair mode = one prod NVFP4 GEMM call's quant+swizzle pipeline " + "(1 swz/operand)." + ) + if fuse: + print(" ratio = per-token(fuse) / per-tensor") + else: + print(" ratio = per-token / per-tensor") + print(" ** < 1.0 = this PR wins vs prod K1+K2+swizzle path **") + print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") + + +def _print_e2e_swizzle_table(records: List[E2EShapeBench]) -> None: + """3-way end-to-end (--swizzle). ratio = per-token (+swizzle) / per-tensor.""" + w_pt, w_swz, w_pten, w_ratio = 12, 14, 13, 8 + block_w = w_pt + 1 + w_swz + 1 + w_pten + 1 + w_ratio + header1 = f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + header2 = ( + f"{'M':>7} {'K':>6}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_swz}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + " |" + f"{'per-token':>{w_pt}} {'per-token':>{w_swz}}" + f" {'per-tensor':>{w_pten}} {'ratio':>{w_ratio}}" + ) + header3 = ( + f"{'':>7} {'':>6}" + " |" + f"{'':>{w_pt}} {'(+swizzle)':>{w_swz}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + " |" + f"{'':>{w_pt}} {'(+swizzle)':>{w_swz}}" + f" {'':>{w_pten}} {'':>{w_ratio}}" + ) + print(header1) + print(header2) + print(header3) + print("-" * len(header2)) + prev_M = None + for rec in records: + if prev_M is not None and rec.M != prev_M: + print() + prev_M = rec.M + ratio = _ratio(rec.t_pt_swz, rec.t_pten) + ratio_g = _ratio(rec.t_pt_swz_g, rec.t_pten_g) + + def _fmt(r: float) -> str: + return "nan" if math.isnan(r) else f"{r:.2f}x" + + print( + f"{rec.M:>7} {rec.K:>6}" + " |" + f"{rec.t_pt:>{w_pt}.4f} {rec.t_pt_swz:>{w_swz}.4f}" + f" {rec.t_pten:>{w_pten}.4f} {_fmt(ratio):>{w_ratio}}" + " |" + f"{rec.t_pt_g:>{w_pt}.4f} {rec.t_pt_swz_g:>{w_swz}.4f}" + f" {rec.t_pten_g:>{w_pten}.4f} {_fmt(ratio_g):>{w_ratio}}" + ) + + +def _print_e2e_swizzle_legend(*, with_rht: bool, rht_mask: int) -> None: + print() + print("Legend (end-to-end quant + cuBLAS LT NVFP4 GEMM, square N=M):") + rht_suffix = ( + f"with_rht=True + random_sign_mask_t=0x{rht_mask:04X}" if with_rht else "with_rht=False" + ) + print(f" per-token (ms) = nvfp4_per_token_quantize({rht_suffix}) +") + print(" nvfp4_per_token_gemm(sf_swizzled=False)") + print(" -> K1 + K2 + 2 swizzle launches + cuBLAS LT GEMM") + print(" + per-token post-scale.") + print(f" per-token (+swizzle) (ms) = nvfp4_per_token_quantize({rht_suffix},") + print(" with_swizzle=True) +") + print(" nvfp4_per_token_gemm(sf_swizzled=True)") + print(" -> K1 + K2 (fused swizzle) + cuBLAS LT GEMM") + print(" + per-token post-scale. (2 launches saved.)") + print(" per-tensor (ms) = tex.quantize(a, NVFP4Quantizer(rht+sr)) +") + print(" nvfp4_per_tensor_gemm (cuBLAS LT NVFP4)") + print(" -> fused RHT+quant + 2 swizzle launches + GEMM.") + print(" ratio = per-token (+swizzle) / per-tensor") + print(" ** < 1.0 = this PR wins vs prod baseline **") + print(" (Graph) suffix = same under CUDA Graphs replay (Python + alloc elided).") + + def _bench_shape_k1_only( M: int, K: int, *, device: torch.device, with_rht: bool = False, mask_t: int = _RHT_MASK_DEFAULT ) -> K1ShapeBench: - """K1-only timing. pt = per-token (no RHT), pt_rht = per-token + col RHT - (NaN unless with_rht=True), prod = hadamard_transform_amax (scalar amax; - NOT apples-to-apples but the closest prod K1 reference). - """ + """K1-only. pt = per-token (no RHT); pt_rht = +col RHT (NaN unless --rht); + prod = hadamard_transform_amax (scalar amax; not apples-to-apples).""" a = torch.randn((M, K), dtype=torch.bfloat16, device=device) # Per-token K1 amax buffers (vectors). @@ -387,9 +865,8 @@ def _fmt(r: float) -> str: def _print_k1_2way_table(records: List[K1ShapeBench]) -> None: - """2-way K1 (default --k1-only). pt_K1 vs prod_K1; NOT apples-to-apples - (per-token K1 outputs M+K floats, prod outputs 2 scalars). - """ + """2-way K1 (default --k1-only). pt_K1 vs prod_K1 (not apples-to-apples: + per-token outputs M+K floats, prod outputs 2 scalars).""" print("K1-only: pt vs prod (NOT apples-to-apples; output shapes differ).") header = ( f"{'M':>7} {'K':>6}" @@ -449,9 +926,8 @@ def _print_k1_rht_cost_table(records: List[K1ShapeBench]) -> None: def _print_k1_vs_prod_table(records: List[K1ShapeBench]) -> None: - """Table B: pt_K1+RHT vs prod_K1 (NOT apples-to-apples; output shapes - differ -- 2 scalars vs M+K floats). Fast-floor reference only. - """ + """Table B: pt_K1+RHT vs prod_K1 (not apples-to-apples; 2 scalars + vs M+K floats). Fast-floor reference only.""" print("Table B -- K1-only vs prod (NOT apples-to-apples; output shapes differ).") header = ( f"{'M':>7} {'K':>6}" @@ -555,6 +1031,47 @@ def main() -> int: "(context only; output shapes differ)." ), ) + parser.add_argument( + "--swizzle", + action="store_true", + help=( + "End-to-end mode: quant + cuBLAS LT NVFP4 GEMM (square N=M). " + "Prints a 3-way table: per-token (external swizzle) vs per-token " + "(fused swizzle in K2, sf_swizzled=True) vs per-tensor. Ratio = " + "per-token (+swizzle) / per-tensor. --rht composes (adds 16-pt " + "col-wise RHT to the per-token paths)." + ), + ) + parser.add_argument( + "--qs", + action="store_true", + help=( + "K1+K2 + standalone rowwise swizzle. NO GEMM. 2-way table: " + "per-token vs per-tensor. Default solo (1 operand, 3 launches) is " + "apples-to-apples with --composite; add --pair for 2-operand " + "(6 launches, matches prod NVFP4 GEMM's per-call pipeline). " + "--rht composes." + ), + ) + parser.add_argument( + "--pair", + action="store_true", + help=( + "Modifier for --qs: bench the 2-operand (A + B) pipeline, matching " + "what prod NVFP4 GEMM does per call (1 K1+K2 + 1 swizzle per " + "operand). Default (no --pair) is solo (1 operand)." + ), + ) + parser.add_argument( + "--fuse", + action="store_true", + help=( + "Modifier for --qs: also bench per-token with fused-swizzle K2 " + "(K2 directly emits the rowwise SF in cuBLAS LT swizzled layout; " + "no separate swizzle launch). Adds a 'per-token(fuse)' column to " + "the table, and the ratio switches to per-token(fuse) / per-tensor." + ), + ) parser.add_argument( "--rht-mask", type=lambda s: int(s, 0), @@ -574,6 +1091,27 @@ def main() -> int: shapes = list(args.shapes) if args.shapes else list(_DEFAULT_SHAPES) mask = args.rht_mask & 0xFFFF + # --pair / --fuse are modifiers for --qs; auto-imply --qs if either is set + # alone, so we don't silently fall through to --composite default and bake + # a confusing "looks-like-the-modifier-worked-but-didnt" table. + if (args.pair or args.fuse) and not args.qs: + modifiers = [] + if args.pair: + modifiers.append("--pair") + if args.fuse: + modifiers.append("--fuse") + print( + f"INFO: {' / '.join(modifiers)} implies --qs; running --qs " + f"{' '.join(modifiers)} (K1+K2 + swizzle, no GEMM).", + file=sys.stderr, + ) + args.qs = True + + exclusive = sum(int(x) for x in (args.k1_only, args.swizzle, args.qs)) + if exclusive > 1: + print("ERROR: --k1-only, --swizzle, and --qs are mutually exclusive.", file=sys.stderr) + return 2 + if args.k1_only: records_k1: List[K1ShapeBench] = [ _bench_shape_k1_only(M, K, device=device, with_rht=args.rht, mask_t=mask) @@ -585,6 +1123,23 @@ def main() -> int: _print_k1_vs_prod_table(records_k1) else: _print_k1_2way_table(records_k1) + elif args.swizzle: + records_e2e: List[E2EShapeBench] = [ + _bench_shape_e2e_swizzle(M, K, device=device, with_rht=args.rht, mask_t=mask) + for (M, K) in shapes + ] + _print_e2e_swizzle_table(records_e2e) + _print_e2e_swizzle_legend(with_rht=args.rht, rht_mask=mask) + elif args.qs: + records_qs: List[QSShapeBench] = [ + _bench_shape_qs( + M, K, device=device, with_rht=args.rht, mask_t=mask, + pair=args.pair, fuse=args.fuse, + ) + for (M, K) in shapes + ] + _print_qs_table(records_qs, fuse=args.fuse) + _print_qs_legend(with_rht=args.rht, rht_mask=mask, pair=args.pair, fuse=args.fuse) else: records: List[ShapeBench] = [ _bench_shape(M, K, device=device, with_rht=args.rht, mask_t=mask) for (M, K) in shapes diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token.py b/tests/pytorch/nvfp4/test_nvfp4_per_token.py index cfe2a9f4aa..951b5b482e 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_per_token.py +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token.py @@ -834,3 +834,122 @@ def test_per_token_composite_with_rht_col_amax_matches_k1( torch.testing.assert_close( bufs["ra"], ra_k1, rtol=0.0, atol=0.0, msg=f"composite ra != K1-only ra at ({M}, {K})" ) + + +# ============================================================================= +# (6) Fused-swizzle correctness: K2 with_swizzle=True emits rowwise SF in +# cuBLAS LT layout. Tests cover byte-equal vs Python reference, other-outputs +# identical to with_swizzle=False, and GEMM fast-path numerical equivalence. +# ============================================================================= + +_SWIZZLE_SHAPES = [ + (128, 128), + (256, 256), + (512, 512), + (256, 1024), + (1024, 256), +] + + +def _swizzle_sf_reference(sf_m_major: torch.Tensor) -> torch.Tensor: + """Reference M-major (M, K_SF) e4m3 -> cuBLAS LT swizzled flat bytes + (128Mx4K tile, 16-byte slot = 4 M-stripes x 4 K-bytes stripe-major).""" + M, K_SF = sf_m_major.shape + assert M % 128 == 0 + assert K_SF % 4 == 0 + device = sf_m_major.device + sf_u8 = sf_m_major.contiguous().view(torch.uint8) + out = torch.empty(M * K_SF, dtype=torch.uint8, device=device) + + m_idx = torch.arange(M, device=device, dtype=torch.int64).view(M, 1).expand(M, K_SF) + k_idx = torch.arange(K_SF, device=device, dtype=torch.int64).view(1, K_SF).expand(M, K_SF) + m_tile = m_idx // 128 + k_tile = k_idx // 4 + out_idx = ( + m_tile * 128 * K_SF + + k_tile * 512 + + (m_idx % 32) * 16 + + ((m_idx % 128) // 32) * 4 + + (k_idx % 4) + ) + out[out_idx.reshape(-1)] = sf_u8.reshape(-1) + return out + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _SWIZZLE_SHAPES) +def test_per_token_with_swizzle_sf_byte_equal_to_reference(M: int, K: int) -> None: + """Fused-swizzle rowwise scale_inv matches the Python byte-permutation + reference of the M-major SF (covers both rowwise-only and rowwise+colwise). + """ + device = torch.device("cuda") + torch.manual_seed(0) + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + out_plain = nvfp4_per_token_quantize(x, rowwise=True, columnwise=True, with_swizzle=False) + out_swz = nvfp4_per_token_quantize(x, rowwise=True, columnwise=True, with_swizzle=True) + + ref_swz_sf = _swizzle_sf_reference(out_plain.scale.view(torch.uint8)) + got_swz_sf = out_swz.scale.view(torch.uint8).reshape(-1) + + torch.testing.assert_close( + got_swz_sf, + ref_swz_sf, + rtol=0, + atol=0, + msg=f"fused-swizzle rowwise SF mismatch at ({M}, {K})", + ) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", _SWIZZLE_SHAPES) +def test_per_token_with_swizzle_other_outputs_unchanged(M: int, K: int) -> None: + """Only the rowwise scale_inv layout differs: FP4 data, row_amax, colwise + data / scale_inv / col_amax must be byte-identical between with_swizzle + True and False. + """ + device = torch.device("cuda") + torch.manual_seed(0) + x = torch.randn((M, K), dtype=torch.bfloat16, device=device) + + out_plain = nvfp4_per_token_quantize(x, rowwise=True, columnwise=True, with_swizzle=False) + out_swz = nvfp4_per_token_quantize(x, rowwise=True, columnwise=True, with_swizzle=True) + + torch.testing.assert_close(out_swz.data, out_plain.data, rtol=0, atol=0) + torch.testing.assert_close(out_swz.row_amax, out_plain.row_amax, rtol=0, atol=0) + torch.testing.assert_close(out_swz.columnwise_data, out_plain.columnwise_data, rtol=0, atol=0) + torch.testing.assert_close(out_swz.columnwise_scale, out_plain.columnwise_scale, rtol=0, atol=0) + torch.testing.assert_close(out_swz.col_amax, out_plain.col_amax, rtol=0, atol=0) + + +@_GATED_SM100 +@pytest.mark.parametrize("M,K", [(256, 256), (512, 1024), (1024, 512)]) +def test_per_token_gemm_with_fused_swizzle_matches_unswizzled(M: int, K: int) -> None: + """E2E GEMM two paths: (A) with_swizzle=False + ext swizzle (sf_swizzled=False) + vs (B) with_swizzle=True + sf_swizzled=True. Same SF bytes to cuBLAS LT, + so C outputs must be byte-equal.""" + device = torch.device("cuda") + torch.manual_seed(0) + N = M # square; GEMM is TN with M, N free + A = torch.randn((M, K), dtype=torch.bfloat16, device=device) + B = torch.randn((N, K), dtype=torch.bfloat16, device=device) + + a_plain = nvfp4_per_token_quantize(A, rowwise=True, columnwise=False, with_swizzle=False) + b_plain = nvfp4_per_token_quantize(B, rowwise=True, columnwise=False, with_swizzle=False) + c_unswz = nvfp4_per_token_gemm( + a_plain.data, a_plain.scale, a_plain.row_amax, + b_plain.data, b_plain.scale, b_plain.row_amax, + ) + + a_swz = nvfp4_per_token_quantize(A, rowwise=True, columnwise=False, with_swizzle=True) + b_swz = nvfp4_per_token_quantize(B, rowwise=True, columnwise=False, with_swizzle=True) + c_swz = nvfp4_per_token_gemm( + a_swz.data, a_swz.scale, a_swz.row_amax, + b_swz.data, b_swz.scale, b_swz.row_amax, + a_sf_swizzled=True, b_sf_swizzled=True, + ) + + torch.testing.assert_close( + c_swz, c_unswz, rtol=0, atol=0, + msg=f"fused-swizzle GEMM output != unswizzled-input GEMM at ({M}, {K})", + ) diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu index a6f08e9f74..ae804ef147 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu @@ -362,10 +362,10 @@ __device__ __forceinline__ void colwise_scaling_per_token( // ============================================================================= // Kernel 2: per-token encode (rowwise + optional colwise transpose). -// kWithRht=true: col-wise FP4 cast over RHT-rotated strips, matching K1's -// RHT-rotated columnwise_amax. Row direction never sees RHT. +// kWithRht=true: col-wise FP4 cast over RHT-rotated strips (row never sees RHT). +// kWithSwizzle=true: rowwise SF emitted directly in cuBLAS LT 128x4 tile layout. // ============================================================================= -template +template __global__ void __launch_bounds__(THREADS_NUM) per_token_encode_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -565,20 +565,49 @@ __global__ void __launch_bounds__(THREADS_NUM) buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; } // end of stages - // Vectorized SF scatter to global (chunk-end batch). Mirrors the - // production tuned 1D scale-store epilogue. + // Vectorized SF scatter. kWithSwizzle=false: compact M-major (downstream + // nvte_swizzle_scaling_factors re-permutes). kWithSwizzle=true: emit cuBLAS + // LT 128Mx4K tile layout directly; thread mapping below is perf-critical. if (DO_ROW) { auto& sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); - using ScalesVec = Vec; - const int chunk_cols = static_cast(cols) - block_offset_X; - const int count = min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM); - - for (size_t row = threadIdx.x; row < CHUNK_DIM_Y; row += THREADS_NUM) { - const size_t row_global = scales_block_offset_Y_rowwise + row; - if (row_global < rows) { - ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); - const size_t scale_idx_global = row_global * scale_stride + scales_block_offset_X_rowwise; - scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); + if constexpr (kWithSwizzle) { + // uint64_t SMEM load below assumes each sSFrowwise row is exactly 8 bytes + // (= 2 K-tiles of 4 K-bytes each); any other geometry needs a different + // pack/store split. + static_assert(SCALES_PER_CHUNK_X == 8, + "fused-swizzle rowwise scatter assumes SCALES_PER_CHUNK_X == 8"); + const int tid = threadIdx.x; + const int b = tid & 3; // M-stripe [0, 4), fast axis -> coalesced gmem + const int ty = tid >> 2; // slot index within K-tile [0, 32) + const int lm = b * 32 + ty; // [0, 128) + const size_t M_tile_idx = ctaid_Y; + const size_t K_tile_global_base = ctaid_X * (SCALES_PER_CHUNK_X / 4); // 2 + + // Single 8-byte SMEM load (vs 2 x 4-byte) halves the SMEM access count + // and degrades the bank conflict from 4-way to 2-way (each lane touches + // 2 adjacent banks at the same lm row instead of 1 bank twice). + const uint64_t packed_all = *reinterpret_cast(&sSFrowwise[lm][0]); + const uint32_t packed_lo = static_cast(packed_all); + const uint32_t packed_hi = static_cast(packed_all >> 32); + + const size_t base_byte = M_tile_idx * CHUNK_DIM_Y * scale_stride + + K_tile_global_base * 512 + static_cast(ty) * 16 + + static_cast(b) * 4; + *reinterpret_cast(&scales_ptr[base_byte]) = packed_lo; + *reinterpret_cast(&scales_ptr[base_byte + 512]) = packed_hi; + } else { + using ScalesVec = Vec; + const int chunk_cols = static_cast(cols) - block_offset_X; + const int count = min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM); + + for (size_t row = threadIdx.x; row < CHUNK_DIM_Y; row += THREADS_NUM) { + const size_t row_global = scales_block_offset_Y_rowwise + row; + if (row_global < rows) { + ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); + const size_t scale_idx_global = + row_global * scale_stride + scales_block_offset_X_rowwise; + scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); + } } } } @@ -877,14 +906,12 @@ inline void launch_amax(const Tensor& input, Tensor* output, const Tensor& noop, NVTE_CHECK_CUDA(cudaGetLastError()); } -// Launch Kernel 2 (encode). Requires output->amax / columnwise_amax to be pre-filled -// (by a prior launch_amax call or by an external caller); writes -// output->data / scale_inv / columnwise_data / columnwise_scale_inv. -// with_rht=true requires K1 amax to have been launched with the SAME mask; -// the composite per_token_quantize path threads this automatically. +// Launch K2 encode. Requires pre-filled amax/columnwise_amax; writes data + +// scale_inv (both directions). with_rht requires K1 to have run with the +// SAME mask. with_swizzle: rowwise SF in cuBLAS LT layout (rowwise-only). inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noop, const bool with_rht, const uint32_t random_sign_mask_t, - cudaStream_t stream) { + const bool with_swizzle, cudaStream_t stream) { const size_t M = input.flat_first_dim(); const size_t K = input.flat_last_dim(); @@ -957,19 +984,26 @@ inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noo const float* col_amax_in = do_col ? reinterpret_cast(output->columnwise_amax.dptr) : nullptr; - // RHT only matters when colwise FP4 is produced; collapse to the - // kWithRht=false instantiation for rowwise-only callers. + // RHT only matters with colwise FP4 -> collapse to kWithRht=false for + // rowwise-only callers; swizzle only matters with rowwise FP4 -> + // collapse to kWithSwizzle=false for colwise-only callers. const bool with_rht_effective = with_rht && do_col; + const bool with_swizzle_effective = with_swizzle && do_row; TRANSFORMER_ENGINE_SWITCH_CONDITION( do_row, DO_ROW, TRANSFORMER_ENGINE_SWITCH_CONDITION( - do_col, DO_COL, TRANSFORMER_ENGINE_SWITCH_CONDITION(with_rht_effective, kWithRht, { - auto kernel = per_token_encode_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tmap_in, tmap_out, tmap_out_t, scales_ptr, scales_t_ptr, row_amax_in, col_amax_in, - noop_ptr, M, K, scale_stride, scale_stride_t, random_sign_mask_t); - }))); + do_col, DO_COL, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + with_rht_effective, kWithRht, + TRANSFORMER_ENGINE_SWITCH_CONDITION(with_swizzle_effective, kWithSwizzle, { + auto kernel = per_token_encode_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tmap_in, tmap_out, tmap_out_t, scales_ptr, scales_t_ptr, row_amax_in, + col_amax_in, noop_ptr, M, K, scale_stride, scale_stride_t, + random_sign_mask_t); + })))); NVTE_CHECK_CUDA(cudaGetLastError()); } #endif // FP4_TYPE_SUPPORTED @@ -1002,13 +1036,17 @@ inline void validate_amax_output(const Tensor* output) { "must be allocated."); } -// K2 (encode) and composite require at least one FP4 output buffer allocated. -inline void validate_encode_output(const Tensor* output) { +// K2 / composite require >=1 FP4 output buffer. with_swizzle=true: rowwise +// SF emitted in cuBLAS LT swizzled layout (caller sets with_gemm_swizzled_scales). +inline void validate_encode_output(const Tensor* output, const bool with_swizzle) { NVTE_CHECK(output->has_data() || output->has_columnwise_data(), "Per-token K2 (encode): at least one of rowwise/columnwise FP4 output " "must be allocated."); - NVTE_CHECK(!output->with_gemm_swizzled_scales, - "Per-token cast emits compact (non-swizzled) inner SF."); + if (!with_swizzle) { + NVTE_CHECK(!output->with_gemm_swizzled_scales, + "Per-token cast emits compact (non-swizzled) inner SF unless " + "with_swizzle=true is passed."); + } } // K1 amax with optional col-wise RHT. with_rht=false is byte-equal to the @@ -1022,28 +1060,28 @@ void per_token_amax_blocked_impl(const Tensor& input, const Tensor& noop, Tensor launch_amax(input, output, noop, with_rht, random_sign_mask_t, stream); } -// K2 encode with optional col-wise RHT. Caller must have filled -// output->columnwise_amax via K1 amax with the SAME with_rht/mask, else the -// inner SF + FP4 codes are calibrated against mismatched data and saturate. +// K2 encode with optional col-wise RHT + fused rowwise swizzle. with_rht +// requires K1 amax to have been launched with the SAME mask, else the inner +// SF + FP4 codes are calibrated against mismatched data and saturate. void per_token_encode_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, const bool with_rht, const uint32_t random_sign_mask_t, - cudaStream_t stream) { + const bool with_swizzle, cudaStream_t stream) { validate_input_shape(input); - validate_encode_output(output); + validate_encode_output(output, with_swizzle); if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; - launch_encode(input, output, noop, with_rht, random_sign_mask_t, stream); + launch_encode(input, output, noop, with_rht, random_sign_mask_t, with_swizzle, stream); } // Composite K1+K2. Both launches receive the same with_rht / mask so the // colwise amax and FP4 cast see byte-identical data. void per_token_quantize_blocked_impl(const Tensor& input, const Tensor& noop, Tensor* output, const bool with_rht, const uint32_t random_sign_mask_t, - cudaStream_t stream) { + const bool with_swizzle, cudaStream_t stream) { validate_input_shape(input); - validate_encode_output(output); + validate_encode_output(output, with_swizzle); if (input.flat_first_dim() == 0 || input.flat_last_dim() == 0) return; launch_amax(input, output, noop, with_rht, random_sign_mask_t, stream); - launch_encode(input, output, noop, with_rht, random_sign_mask_t, stream); + launch_encode(input, output, noop, with_rht, random_sign_mask_t, with_swizzle, stream); } bool can_use_per_token(size_t M, size_t K, DType dtype) { @@ -1054,11 +1092,11 @@ void per_token_amax_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, ui cudaStream_t) { NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); } -void per_token_encode_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, uint32_t, +void per_token_encode_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, uint32_t, bool, cudaStream_t) { NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); } -void per_token_quantize_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, uint32_t, +void per_token_quantize_blocked_impl(const Tensor&, const Tensor&, Tensor*, bool, uint32_t, bool, cudaStream_t) { NVTE_ERROR("NVFP4 requires SM100 (Blackwell); build with sm_100a/sm_100f."); } @@ -1100,7 +1138,7 @@ void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NV void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output, const int with_rht, const int random_sign_mask_t, - cudaStream_t stream) { + const int with_swizzle, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_encode); using namespace transformer_engine; @@ -1112,13 +1150,14 @@ void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, // safety, internal kernel arg is uint32_t with only the low 16 bits used. nvfp4_per_token::per_token_encode_blocked_impl( *input_tensor, *noop_tensor, output_tensor, with_rht != 0, - static_cast(random_sign_mask_t) & 0xFFFFu, stream); + static_cast(random_sign_mask_t) & 0xFFFFu, with_swizzle != 0, stream); #else (void)input; (void)noop; (void)output; (void)with_rht; (void)random_sign_mask_t; + (void)with_swizzle; (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif @@ -1126,7 +1165,7 @@ void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output, const int with_rht, const int random_sign_mask_t, - cudaStream_t stream) { + const int with_swizzle, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED NVTE_API_CALL(nvte_nvfp4_per_token_quantize); using namespace transformer_engine; @@ -1136,13 +1175,14 @@ void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop const Tensor* noop_tensor = (noop != nullptr) ? convertNVTETensorCheck(noop) : &dummy_noop; nvfp4_per_token::per_token_quantize_blocked_impl( *input_tensor, *noop_tensor, output_tensor, with_rht != 0, - static_cast(random_sign_mask_t) & 0xFFFFu, stream); + static_cast(random_sign_mask_t) & 0xFFFFu, with_swizzle != 0, stream); #else (void)input; (void)noop; (void)output; (void)with_rht; (void)random_sign_mask_t; + (void)with_swizzle; (void)stream; NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif diff --git a/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h index 395b061b52..c8b6b630f0 100644 --- a/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h +++ b/transformer_engine/common/include/transformer_engine/nvfp4_per_token.h @@ -25,9 +25,13 @@ extern "C" { * to the pre-RHT path. * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern shared by * K1 and K2. Ignored when with_rht == 0. + * \param[in] with_swizzle non-zero -> K2 emits rowwise scale_inv directly + * in the cuBLAS LT swizzled tile layout (rowwise only; + * colwise stays compact M-major). */ void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output, - int with_rht, int random_sign_mask_t, cudaStream_t stream); + int with_rht, int random_sign_mask_t, int with_swizzle, + cudaStream_t stream); /*! \brief Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax. * Pre-zeroes the amax buffers and merges per-CTA partials into @@ -54,9 +58,13 @@ void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NV * to thread the same flag + mask through K1 and K2). * \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; ignored * when with_rht == 0. + * \param[in] with_swizzle non-zero -> write rowwise scale_inv directly in + * the cuBLAS LT swizzled tile layout (rowwise only; + * colwise stays compact M-major). */ void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output, - int with_rht, int random_sign_mask_t, cudaStream_t stream); + int with_rht, int random_sign_mask_t, int with_swizzle, + cudaStream_t stream); /*! \brief Returns 1 iff the per-token kernels accept ``(M, K, dtype)``. * diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 53bf118425..075f735e06 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -451,10 +451,13 @@ void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwi const at::Tensor &scale_inv_colwise, int rows, int cols, size_t start_offset); +// with_swizzle=true makes K2 write rowwise scale_inv in the cuBLAS LT +// swizzled tile layout (skips the standalone nvte_swizzle_scaling_factors). +// Has no effect on colwise scale_inv (rowwise-only for now). void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, - int64_t random_sign_mask_t); + int64_t random_sign_mask_t, bool with_swizzle); void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, int64_t random_sign_mask_t); @@ -462,16 +465,21 @@ void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, at::Tens void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, - int64_t random_sign_mask_t); + int64_t random_sign_mask_t, bool with_swizzle); void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor &row_amax_a, const at::Tensor &row_amax_b); +// Standalone rowwise-SF swizzle for one NVFP4 operand. One launch per call, +// mirrors prod NVFP4 GEMM's per-operand swizzle. Used by --qs bench mode. +void nvfp4_per_token_swizzle_rowwise_sf(const at::Tensor &data, const at::Tensor &sf_in, + at::Tensor sf_out); + void nvfp4_per_token_gemm(const at::Tensor &a_data, const at::Tensor &b_data, const at::Tensor &a_sf, const at::Tensor &b_sf, const at::Tensor &a_row_amax, const at::Tensor &b_row_amax, at::Tensor d, const at::Tensor &workspace, int64_t m, int64_t n, int64_t k, - double alpha, double beta); + double alpha, double beta, bool a_sf_swizzled, bool b_sf_swizzled); // Bench-only per-tensor twin of nvfp4_per_token_gemm: scalar amaxes folded // into cuBLAS LT alpha via the amax slot; no trailing post-scale. diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp index 4ef13afdae..49b01566e5 100644 --- a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp @@ -110,19 +110,22 @@ void assemble_per_token_tensors(const at::Tensor& input, at::Tensor q_row, at::T } // namespace -// Production composite (K1 + K2 back-to-back). with_rht=true enables the -// 16-pt col-wise RHT in BOTH K1 and K2 so outer + inner SFs stay consistent. +// Composite K1 + K2 (back-to-back). with_rht: 16-pt col-wise RHT in both +// (keeps outer + inner SFs consistent). with_swizzle: K2 emits rowwise +// scale_inv in cuBLAS LT swizzled layout (skips downstream swizzle). void nvfp4_per_token_quantize(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, - int64_t random_sign_mask_t) { + int64_t random_sign_mask_t, bool with_swizzle) { TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, columnwise, /*mode=*/0, in_te, out_te); + if (with_swizzle) out_te.set_with_gemm_swizzled_scales(true); const auto stream = at::cuda::getCurrentCUDAStream(); nvte_nvfp4_per_token_quantize(in_te.data(), nullptr, out_te.data(), with_rht ? 1 : 0, - static_cast(random_sign_mask_t & 0xFFFF), stream); + static_cast(random_sign_mask_t & 0xFFFF), + with_swizzle ? 1 : 0, stream); } // K1-only (diagnostic / bench): populates only amax buffers. with_rht=true @@ -142,20 +145,22 @@ void nvfp4_per_token_amax(const at::Tensor& input, at::Tensor row_amax, at::Tens static_cast(random_sign_mask_t & 0xFFFF), stream); } -// K2-only (diagnostic / bench): reads pre-filled amax buffers, emits FP4 + SFs. -// with_rht=true requires col_amax to have been produced by an earlier K1 -// amax call with the SAME mask, else inner SFs are miscalibrated. +// K2-only (bench): reads pre-filled amax, emits FP4 + SFs. with_rht needs +// col_amax from K1 with the SAME mask (else inner SFs miscalibrate). +// with_swizzle: rowwise scale_inv in cuBLAS LT swizzled layout. void nvfp4_per_token_encode(const at::Tensor& input, at::Tensor q_row, at::Tensor s_dec_row, at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col, at::Tensor col_amax, bool rowwise, bool columnwise, bool with_rht, - int64_t random_sign_mask_t) { + int64_t random_sign_mask_t, bool with_swizzle) { TensorWrapper in_te; TensorWrapper out_te(NVTE_NVFP4_1D_SCALING); assemble_per_token_tensors(input, q_row, s_dec_row, row_amax, q_col, s_dec_col, col_amax, rowwise, columnwise, /*mode=*/2, in_te, out_te); + if (with_swizzle) out_te.set_with_gemm_swizzled_scales(true); const auto stream = at::cuda::getCurrentCUDAStream(); nvte_nvfp4_per_token_encode(in_te.data(), nullptr, out_te.data(), with_rht ? 1 : 0, - static_cast(random_sign_mask_t & 0xFFFF), stream); + static_cast(random_sign_mask_t & 0xFFFF), + with_swizzle ? 1 : 0, stream); } // Apply per-token post-scale to a GEMM output (see nvfp4_per_token.h for math). @@ -191,14 +196,52 @@ void nvfp4_per_token_post_scale(at::Tensor d, const at::Tensor& row_amax_a, nvte_nvfp4_per_token_post_scale(d_te.data(), ra_te.data(), rb_te.data(), stream); } -// End-to-end NVFP4 per-token GEMM: swizzle compact SFs -> cuBLAS LT NVFP4 -// GEMM (operand amax pinned to 1.0 to cancel the 2688^2 inner-SF factor) -> -// per-row post-scale. beta must be 0.0. Math in nvfp4_per_token.h. +// Standalone rowwise-SF swizzle for one NVFP4 operand: 1 launch == +// 1 nvte_swizzle_scaling_factors, mirrors prod's per-operand swizzle. +// Bench-only (--qs); sf_in M-major (M, K/16) -> sf_out swizzled. +void nvfp4_per_token_swizzle_rowwise_sf(const at::Tensor& data, const at::Tensor& sf_in, + at::Tensor sf_out) { + TORCH_CHECK(data.is_cuda() && sf_in.is_cuda() && sf_out.is_cuda(), + "All tensors must be CUDA tensors"); + TORCH_CHECK(data.is_contiguous() && sf_in.is_contiguous() && sf_out.is_contiguous(), + "All tensors must be contiguous"); + TORCH_CHECK(data.scalar_type() == at::ScalarType::Byte, "data must be uint8 (FP4 packed)"); + TORCH_CHECK(sf_in.scalar_type() == at::ScalarType::Byte, "sf_in must be uint8 (FP8 e4m3)"); + TORCH_CHECK(sf_out.scalar_type() == at::ScalarType::Byte, "sf_out must be uint8 (FP8 e4m3)"); + TORCH_CHECK(data.dim() == 2, "data must be 2D (M, K/2)"); + TORCH_CHECK(sf_in.numel() == sf_out.numel(), + "sf_in/sf_out numel mismatch: ", sf_in.numel(), " vs ", sf_out.numel()); + + const int64_t m = data.size(0); + const int64_t k = data.size(1) * 2; // FP4 packed + TORCH_CHECK(k % 16 == 0, "k must be a multiple of 16 (NVFP4 inner SFVecSize), got ", k); + TORCH_CHECK(sf_in.numel() == m * k / 16, "sf_in numel mismatch: expected m*k/16=", m * k / 16, + ", got ", sf_in.numel()); + + const std::vector data_shape = {static_cast(m), static_cast(k)}; + const std::vector sf_shape = {static_cast(m), static_cast(k / 16)}; + + TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); + in_nvte.set_rowwise_data(data.data_ptr(), DType::kFloat4E2M1, data_shape); + in_nvte.set_rowwise_scale_inv(sf_in.data_ptr(), DType::kFloat8E4M3, sf_shape); + + TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); + out_nvte.set_rowwise_data(data.data_ptr(), DType::kFloat4E2M1, data_shape); + out_nvte.set_rowwise_scale_inv(sf_out.data_ptr(), DType::kFloat8E4M3, sf_shape); + out_nvte.set_with_gemm_swizzled_scales(true); + + const auto stream = at::cuda::getCurrentCUDAStream(); + nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); +} + +// E2E NVFP4 per-token GEMM: swizzle SFs -> cuBLAS LT (amax pinned to 1.0 +// to cancel 2688^2 inner-SF) -> per-row post-scale. beta must be 0. +// a_sf_swizzled/b_sf_swizzled=true skips the in-binding swizzle for that operand. void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, const at::Tensor& a_sf, const at::Tensor& b_sf, const at::Tensor& a_row_amax, const at::Tensor& b_row_amax, at::Tensor d, const at::Tensor& workspace, int64_t m, int64_t n, int64_t k, - double alpha, double beta) { + double alpha, double beta, bool a_sf_swizzled, bool b_sf_swizzled) { TORCH_CHECK(a_data.is_cuda() && b_data.is_cuda() && a_sf.is_cuda() && b_sf.is_cuda() && a_row_amax.is_cuda() && b_row_amax.is_cuda() && d.is_cuda() && workspace.is_cuda(), @@ -247,31 +290,37 @@ void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, const std::vector a_sf_shape = {static_cast(m), static_cast(k / 16)}; const std::vector b_sf_shape = {static_cast(n), static_cast(k / 16)}; - // Swizzled SF buffers (cuBLAS LT requires swizzled layout). + // SF buffers for cuBLAS LT: reuse caller's buffer if already swizzled, + // else allocate a swizzled copy. 0/1/2 swizzle launches total. auto byte_opts = a_sf.options().dtype(at::kByte); - at::Tensor a_sf_swizzled = at::empty({a_sf.numel()}, byte_opts); - at::Tensor b_sf_swizzled = at::empty({b_sf.numel()}, byte_opts); - - { + at::Tensor a_sf_buf; + at::Tensor b_sf_buf; + if (a_sf_swizzled) { + a_sf_buf = a_sf; + } else { + a_sf_buf = at::empty({a_sf.numel()}, byte_opts); TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); in_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); in_nvte.set_rowwise_scale_inv(a_sf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); out_nvte.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); - out_nvte.set_rowwise_scale_inv(a_sf_swizzled.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + out_nvte.set_rowwise_scale_inv(a_sf_buf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); out_nvte.set_with_gemm_swizzled_scales(true); nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); } - { + if (b_sf_swizzled) { + b_sf_buf = b_sf; + } else { + b_sf_buf = at::empty({b_sf.numel()}, byte_opts); TensorWrapper in_nvte(NVTE_NVFP4_1D_SCALING); in_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); in_nvte.set_rowwise_scale_inv(b_sf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); TensorWrapper out_nvte(NVTE_NVFP4_1D_SCALING); out_nvte.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); - out_nvte.set_rowwise_scale_inv(b_sf_swizzled.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + out_nvte.set_rowwise_scale_inv(b_sf_buf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); out_nvte.set_with_gemm_swizzled_scales(true); nvte_swizzle_scaling_factors(in_nvte.data(), out_nvte.data(), stream); @@ -294,13 +343,13 @@ void nvfp4_per_token_gemm(const at::Tensor& a_data, const at::Tensor& b_data, // Assemble A's NVTE tensor: NVFP4_1D_SCALING + swizzled SF + amax=1.0. TensorWrapper a_te(NVTE_NVFP4_1D_SCALING); a_te.set_rowwise_data(a_data.data_ptr(), DType::kFloat4E2M1, a_data_shape); - a_te.set_rowwise_scale_inv(a_sf_swizzled.data_ptr(), DType::kFloat8E4M3, a_sf_shape); + a_te.set_rowwise_scale_inv(a_sf_buf.data_ptr(), DType::kFloat8E4M3, a_sf_shape); a_te.set_amax(amax_one.data_ptr(), DType::kFloat32, std::vector{1}); a_te.set_with_gemm_swizzled_scales(true); TensorWrapper b_te(NVTE_NVFP4_1D_SCALING); b_te.set_rowwise_data(b_data.data_ptr(), DType::kFloat4E2M1, b_data_shape); - b_te.set_rowwise_scale_inv(b_sf_swizzled.data_ptr(), DType::kFloat8E4M3, b_sf_shape); + b_te.set_rowwise_scale_inv(b_sf_buf.data_ptr(), DType::kFloat8E4M3, b_sf_shape); b_te.set_amax(amax_one.data_ptr(), DType::kFloat32, std::vector{1}); b_te.set_with_gemm_swizzled_scales(true); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 507595d172..0bebaaae86 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -396,14 +396,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("scale_inv_colwise"), py::arg("rows"), py::arg("cols"), py::arg("start_offset"), py::call_guard()); m.def("nvfp4_per_token_quantize", &transformer_engine::pytorch::nvfp4_per_token_quantize, - "NVFP4 per-token cast (composite K1 amax + K2 encode). Same FP4 + 1x16 " - "e4m3 SF layout as per-tensor, but outer amax is per-row/per-col. " - "Requires bf16 input, M % 128 == 0, K % 128 == 0. " - "with_rht=True applies a 16-pt col-wise RHT in both K1 and K2.", + "NVFP4 per-token cast (composite K1 amax + K2 encode). " + "with_rht=True: 16-pt col-wise RHT in K1+K2; " + "with_swizzle=True: rowwise scale_inv in cuBLAS LT swizzled layout.", py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), py::arg("rowwise"), py::arg("columnwise"), py::arg("with_rht") = false, - py::arg("random_sign_mask_t") = static_cast(0xACE1)); + py::arg("random_sign_mask_t") = static_cast(0xACE1), + py::arg("with_swizzle") = false); m.def("nvfp4_per_token_amax", &transformer_engine::pytorch::nvfp4_per_token_amax, "K1-only: per-row/per-col outer amax via TMA + atomicMax. Bench/diagnostic. " "with_rht=True applies a 16-pt col-wise RHT before amax.", @@ -412,20 +412,28 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("random_sign_mask_t") = static_cast(0xACE1)); m.def("nvfp4_per_token_encode", &transformer_engine::pytorch::nvfp4_per_token_encode, "K2-only: FP4 + e4m3 SF encode given pre-filled amax buffers. Bench/diagnostic. " - "with_rht=True requires col_amax produced by a K1 launch with the same mask.", + "with_rht=True requires col_amax produced by a K1 launch with the same mask; " + "with_swizzle=True writes rowwise scale_inv directly in the swizzled layout.", py::arg("input"), py::arg("q_row"), py::arg("s_dec_row"), py::arg("row_amax"), py::arg("q_col"), py::arg("s_dec_col"), py::arg("col_amax"), py::arg("rowwise"), py::arg("columnwise"), py::arg("with_rht") = false, - py::arg("random_sign_mask_t") = static_cast(0xACE1)); + py::arg("random_sign_mask_t") = static_cast(0xACE1), + py::arg("with_swizzle") = false); m.def("nvfp4_per_token_post_scale", &transformer_engine::pytorch::nvfp4_per_token_post_scale, "Apply d[i,j] *= row_amax_a[i] * row_amax_b[j] in-place on bf16 D.", py::arg("d"), py::arg("row_amax_a"), py::arg("row_amax_b")); + m.def("nvfp4_per_token_swizzle_rowwise_sf", + &transformer_engine::pytorch::nvfp4_per_token_swizzle_rowwise_sf, + "Standalone rowwise SF swizzle (1 launch); mirrors prod's per-operand swizzle. " + "data (M, K/2) FP4; sf_in (M, K/16) M-major; sf_out (M, K/16) swizzled.", + py::arg("data"), py::arg("sf_in"), py::arg("sf_out")); m.def("nvfp4_per_token_gemm", &transformer_engine::pytorch::nvfp4_per_token_gemm, - "End-to-end NVFP4 per-token GEMM: swizzle compact SFs, cuBLAS LT NVFP4 " - "GEMM, then row*col post-scale to recover C = A @ B^T. beta must be 0.", + "E2E NVFP4 per-token GEMM: swizzle SFs -> cuBLAS LT -> row*col post-scale. " + "beta must be 0. a_sf_swizzled/b_sf_swizzled=True skips that operand's swizzle.", py::arg("a_data"), py::arg("b_data"), py::arg("a_sf"), py::arg("b_sf"), py::arg("a_row_amax"), py::arg("b_row_amax"), py::arg("d"), py::arg("workspace"), - py::arg("m"), py::arg("n"), py::arg("k"), py::arg("alpha"), py::arg("beta")); + py::arg("m"), py::arg("n"), py::arg("k"), py::arg("alpha"), py::arg("beta"), + py::arg("a_sf_swizzled") = false, py::arg("b_sf_swizzled") = false); m.def("nvfp4_per_tensor_gemm", &transformer_engine::pytorch::nvfp4_per_tensor_gemm, "Skinny prod NVFP4 GEMM twin of nvfp4_per_token_gemm: per-tensor amaxes " "folded into cuBLAS alpha, no trailing post-scale. Bench-only.", diff --git a/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py b/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py index 74fc9dc228..033c98729e 100644 --- a/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py +++ b/transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py @@ -123,12 +123,18 @@ def nvfp4_per_token_gemm( alpha: float = 1.0, beta: float = 0.0, out_dtype: torch.dtype = torch.bfloat16, + a_sf_swizzled: bool = False, + b_sf_swizzled: bool = False, ) -> torch.Tensor: """Production C = alpha * (A @ B^T) via cuBLAS LT NVFP4 + per-token post-scale. Binding swizzles compact SFs in-flight, runs cuBLAS LT NVFP4 with operand amaxes pinned to 1.0, then applies the row_amax_A * row_amax_B post-scale. Output is bf16 (cuBLAS LT NVFP4 locks D to bf16/fp32); beta != 0 unsupported. + + ``a_sf_swizzled`` / ``b_sf_swizzled = True`` skips the in-binding swizzle + for that operand (caller's SF is already in the cuBLAS LT swizzled layout + e.g. from ``nvfp4_per_token_quantize(..., with_swizzle=True)``). """ import transformer_engine_torch as tex # type: ignore @@ -204,6 +210,8 @@ def nvfp4_per_token_gemm( K, float(alpha), float(beta), + a_sf_swizzled=a_sf_swizzled, + b_sf_swizzled=b_sf_swizzled, ) return out_bf16 if out_dtype is torch.bfloat16 else out_bf16.to(out_dtype) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py index a2be19f73c..7af52c89ac 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py @@ -214,27 +214,18 @@ def nvfp4_per_token_quantize( columnwise: bool = False, with_rht: bool = False, random_sign_mask_t: int = 0xACE1, + with_swizzle: bool = False, ) -> RefNVFP4TensorPerToken: """Production NVFP4 per-token cast through ``tex.nvfp4_per_token_quantize``. - Backed by the TMA + mbarrier + 64x64 sub-tile pipeline - (``common/cast/nvfp4/quantize_nvfp4_per_token.cu``). The C-API - runs K1 (per-row + per-col amax) and K2 (FP4 + e4m3 SF encode) back- - to-back on the same stream. - - Returns a ``RefNVFP4TensorPerToken`` populated with the kernel - output (compact, non-swizzled scales). The Python-level container is - the same as the reference for symmetry; only the source of the - values differs. - - For cuBLAS LT consumption, the caller must swizzle the inner SF - before forwarding to the GEMM; ``gemm_nvfp4_per_token`` handles - this automatically. - - ``with_rht=True`` applies a 16-pt col-wise RHT in BOTH K1 and K2 so - outer + inner SF stay self-consistent (rowwise never sees RHT). + Composite K1 (per-row/per-col amax) + K2 (FP4 + e4m3 SF) on the same + stream. ``with_rht``: 16-pt col-wise RHT in K1+K2 (rowwise unaffected); ``random_sign_mask_t`` low 16 bits = sign pattern (default ``0xACE1``). + ``with_swizzle=True``: rowwise ``scale_inv`` in cuBLAS LT layout + (colwise stays compact). Downstream ``nvfp4_per_token_gemm`` must + use ``sf_swizzled=True`` to skip its built-in swizzle. + Raises ``ValueError`` on non-bf16 input or non-128-aligned shapes. """ # Import lazily so the module does not require the binary at import time. @@ -277,6 +268,7 @@ def nvfp4_per_token_quantize( columnwise, with_rht=with_rht, random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, + with_swizzle=with_swizzle, ) out = RefNVFP4TensorPerToken() @@ -363,21 +355,15 @@ def nvfp4_per_token_encode( columnwise: bool = True, with_rht: bool = False, random_sign_mask_t: int = 0xACE1, + with_swizzle: bool = False, ) -> RefNVFP4TensorPerToken: - """Kernel 2 in isolation: FP4 + e4m3 SF encode given pre-filled - amax buffer(s). - - ``row_amax`` of shape ``(M,)`` is required when ``rowwise=True``; same - for ``col_amax`` of shape ``(K,)`` when ``columnwise=True``. The - buffers are typically produced by a prior - ``nvfp4_per_token_amax`` call. - - Lets the benchmark compare K2 wall-time against the production - per-tensor cast pass. Production callers should use the composite - ``nvfp4_per_token_quantize`` instead. + """K2 in isolation: FP4 + e4m3 SF given pre-filled amax buffer(s) + (``row_amax`` ``(M,)`` and/or ``col_amax`` ``(K,)`` from a prior + ``nvfp4_per_token_amax`` call). - ``with_rht=True`` requires ``col_amax`` produced by a prior K1 call - with the SAME mask, else inner SF / FP4 saturate. + ``with_rht=True`` requires ``col_amax`` from a K1 call with the SAME + mask. ``with_swizzle=True`` emits rowwise ``scale_inv`` in cuBLAS LT + swizzled layout (skips a downstream swizzle launch). Raises ``ValueError`` on non-bf16 input, non-128-aligned shapes, or missing / mis-shaped amax buffers. @@ -421,6 +407,7 @@ def nvfp4_per_token_encode( columnwise, with_rht=with_rht, random_sign_mask_t=int(random_sign_mask_t) & 0xFFFF, + with_swizzle=with_swizzle, ) out = RefNVFP4TensorPerToken() From 15a24ab433b87af65fd806ddcf869900bfcc84ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 May 2026 15:07:04 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/nvfp4/bench_nvfp4_per_token.py | 147 +++++++++++++----- tests/pytorch/nvfp4/test_nvfp4_per_token.py | 24 ++- .../cast/nvfp4/quantize_nvfp4_per_token.cu | 11 +- .../csrc/extensions/nvfp4_per_token.cpp | 12 +- 4 files changed, 138 insertions(+), 56 deletions(-) diff --git a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py index 633aefea8a..b4182b8efd 100644 --- a/tests/pytorch/nvfp4/bench_nvfp4_per_token.py +++ b/tests/pytorch/nvfp4/bench_nvfp4_per_token.py @@ -296,8 +296,15 @@ def _alloc_pt(R, C): def _pt_quant(t, qr, sr, ra_buf, qc, sc, ca_buf, *, fused_swizzle: bool): tex.nvfp4_per_token_quantize( - t, qr, sr, ra_buf, qc, sc, ca_buf, - True, True, # rowwise + columnwise (apples-to-apples vs per-tensor) + t, + qr, + sr, + ra_buf, + qc, + sc, + ca_buf, + True, + True, # rowwise + columnwise (apples-to-apples vs per-tensor) with_rht=with_rht, random_sign_mask_t=mask_t if with_rht else 0, with_swizzle=fused_swizzle, @@ -307,18 +314,42 @@ def _pt_e2e_ext_swizzle(): _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca, fused_swizzle=False) _pt_quant(b, b_qr, b_sr, b_ra, b_qc, b_sc, b_ca, fused_swizzle=False) tex.nvfp4_per_token_gemm( - a_qr, b_qr, a_sr.reshape(-1), b_sr.reshape(-1), - a_ra, b_ra, d, workspace, M, N, K, 1.0, 0.0, - a_sf_swizzled=False, b_sf_swizzled=False, + a_qr, + b_qr, + a_sr.reshape(-1), + b_sr.reshape(-1), + a_ra, + b_ra, + d, + workspace, + M, + N, + K, + 1.0, + 0.0, + a_sf_swizzled=False, + b_sf_swizzled=False, ) def _pt_e2e_fused_swizzle(): _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca, fused_swizzle=True) _pt_quant(b, b_qr, b_sr, b_ra, b_qc, b_sc, b_ca, fused_swizzle=True) tex.nvfp4_per_token_gemm( - a_qr, b_qr, a_sr.reshape(-1), b_sr.reshape(-1), - a_ra, b_ra, d, workspace, M, N, K, 1.0, 0.0, - a_sf_swizzled=True, b_sf_swizzled=True, + a_qr, + b_qr, + a_sr.reshape(-1), + b_sr.reshape(-1), + a_ra, + b_ra, + d, + workspace, + M, + N, + K, + 1.0, + 0.0, + a_sf_swizzled=True, + b_sf_swizzled=True, ) # Per-tensor path: NVFP4Quantizer (RHT+SR) + bench-only nvfp4_per_tensor_gemm. @@ -330,10 +361,19 @@ def _pten_e2e(): tex.quantize(a, quantizer, dst_a, None) tex.quantize(b, quantizer, dst_b, None) tex.nvfp4_per_tensor_gemm( - dst_a._rowwise_data, dst_b._rowwise_data, - dst_a._rowwise_scale_inv, dst_b._rowwise_scale_inv, - dst_a._amax_rowwise, dst_b._amax_rowwise, - d, workspace, M, N, K, 1.0, 0.0, + dst_a._rowwise_data, + dst_b._rowwise_data, + dst_a._rowwise_scale_inv, + dst_b._rowwise_scale_inv, + dst_a._amax_rowwise, + dst_b._amax_rowwise, + d, + workspace, + M, + N, + K, + 1.0, + 0.0, ) t_pt = cuda_time_ms(_pt_e2e_ext_swizzle) @@ -344,9 +384,14 @@ def _pten_e2e(): t_pten_g = cuda_graph_time_ms(_pten_e2e) return E2EShapeBench( - M=M, K=K, - t_pt=t_pt, t_pt_swz=t_pt_swz, t_pten=t_pten, - t_pt_g=t_pt_g, t_pt_swz_g=t_pt_swz_g, t_pten_g=t_pten_g, + M=M, + K=K, + t_pt=t_pt, + t_pt_swz=t_pt_swz, + t_pten=t_pten, + t_pt_g=t_pt_g, + t_pt_swz_g=t_pt_swz_g, + t_pten_g=t_pten_g, ) @@ -388,20 +433,30 @@ def _alloc_pt(R, C): def _pt_quant(t, qr, sr, ra_buf, qc, sc, ca_buf): tex.nvfp4_per_token_quantize( - t, qr, sr, ra_buf, qc, sc, ca_buf, - True, True, + t, + qr, + sr, + ra_buf, + qc, + sc, + ca_buf, + True, + True, with_rht=with_rht, random_sign_mask_t=mask_t if with_rht else 0, with_swizzle=False, # explicit external swizzle, see below ) if pair: + def _pt_qs(): _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca) _pt_quant(b, b_qr, b_sr, b_ra, b_qc, b_sc, b_ca) tex.nvfp4_per_token_swizzle_rowwise_sf(a_qr, a_sr.reshape(-1), a_sr_swz) tex.nvfp4_per_token_swizzle_rowwise_sf(b_qr, b_sr.reshape(-1), b_sr_swz) + else: + def _pt_qs(): _pt_quant(a, a_qr, a_sr, a_ra, a_qc, a_sc, a_ca) tex.nvfp4_per_token_swizzle_rowwise_sf(a_qr, a_sr.reshape(-1), a_sr_swz) @@ -417,6 +472,7 @@ def _pt_qs(): ) if pair: + def _pten_qs(): tex.quantize(a, quantizer, dst_a, None) tex.quantize(b, quantizer, dst_b, None) @@ -426,7 +482,9 @@ def _pten_qs(): tex.nvfp4_per_token_swizzle_rowwise_sf( dst_b._rowwise_data, dst_b._rowwise_scale_inv.reshape(-1), pten_b_sr_swz ) + else: + def _pten_qs(): tex.quantize(a, quantizer, dst_a, None) tex.nvfp4_per_token_swizzle_rowwise_sf( @@ -450,18 +508,28 @@ def _pten_qs(): def _pt_quant_fused(t, qr, sr, ra_buf, qc, sc, ca_buf): tex.nvfp4_per_token_quantize( - t, qr, sr, ra_buf, qc, sc, ca_buf, - True, True, + t, + qr, + sr, + ra_buf, + qc, + sc, + ca_buf, + True, + True, with_rht=with_rht, random_sign_mask_t=mask_t if with_rht else 0, with_swizzle=True, # <-- fused: K2 emits swizzled rowwise SF ) if pair: + def _pt_qs_fused(): _pt_quant_fused(a, a_qr_f, a_sr_f, a_ra_f, a_qc_f, a_sc_f, a_ca_f) _pt_quant_fused(b, b_qr_f, b_sr_f, b_ra_f, b_qc_f, b_sc_f, b_ca_f) + else: + def _pt_qs_fused(): _pt_quant_fused(a, a_qr_f, a_sr_f, a_ra_f, a_qc_f, a_sc_f, a_ca_f) @@ -469,8 +537,14 @@ def _pt_qs_fused(): t_pt_swz_g = cuda_graph_time_ms(_pt_qs_fused) return QSShapeBench( - M=M, K=K, t_pt=t_pt, t_pten=t_pten, t_pt_g=t_pt_g, t_pten_g=t_pten_g, - t_pt_swz=t_pt_swz, t_pt_swz_g=t_pt_swz_g, + M=M, + K=K, + t_pt=t_pt, + t_pten=t_pten, + t_pt_g=t_pt_g, + t_pten_g=t_pten_g, + t_pt_swz=t_pt_swz, + t_pt_swz_g=t_pt_swz_g, ) @@ -484,7 +558,9 @@ def _fmt(r: float) -> str: if not fuse: w_pt, w_pten, w_ratio = 14, 15, 8 block_w = w_pt + 1 + w_pten + 1 + w_ratio - header1 = f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + header1 = ( + f"{'':>7} {'':>6} |{'Eager, unit (ms)':^{block_w}} |{'Graph, unit (ms)':^{block_w}}" + ) header2 = ( f"{'M':>7} {'K':>6}" " |" @@ -562,9 +638,8 @@ def _print_qs_legend(*, with_rht: bool, rht_mask: int, pair: bool, fuse: bool) - n_launches_ext = 3 * n_tensors # K1+K2+swz per tensor n_launches_fused = 2 * n_tensors # K1+K2 only per tensor (swizzle folded into K2) mode_tag = "--pair, 2 operands" if pair else "default solo, 1 operand" - n_kernels_tag = ( - f"ext-swz pipeline {n_launches_ext} launches" - + (f" / fused pipeline {n_launches_fused} launches" if fuse else "") + n_kernels_tag = f"ext-swz pipeline {n_launches_ext} launches" + ( + f" / fused pipeline {n_launches_fused} launches" if fuse else "" ) print(f"Legend (K1+K2 + rowwise swizzle; NO GEMM; mode = {mode_tag}; {n_kernels_tag}):") rht_suffix = ( @@ -602,17 +677,10 @@ def _print_qs_legend(*, with_rht: bool, rht_mask: int, pair: bool, fuse: bool) - print(" The (fuse) column saves 1 swizzle launch/operand vs the non-fuse column;") print(" the K2 byte-output is identical (verified by pytest byte-equality test).") if not pair: - print( - " solo mode is apples-to-apples with --composite (also 1 operand): the delta" - ) - print( - " per-token(--qs) - per-token(--composite) ~= one nvte_swizzle launch." - ) + print(" solo mode is apples-to-apples with --composite (also 1 operand): the delta") + print(" per-token(--qs) - per-token(--composite) ~= one nvte_swizzle launch.") else: - print( - " --pair mode = one prod NVFP4 GEMM call's quant+swizzle pipeline " - "(1 swz/operand)." - ) + print(" --pair mode = one prod NVFP4 GEMM call's quant+swizzle pipeline (1 swz/operand).") if fuse: print(" ratio = per-token(fuse) / per-tensor") else: @@ -1133,8 +1201,13 @@ def main() -> int: elif args.qs: records_qs: List[QSShapeBench] = [ _bench_shape_qs( - M, K, device=device, with_rht=args.rht, mask_t=mask, - pair=args.pair, fuse=args.fuse, + M, + K, + device=device, + with_rht=args.rht, + mask_t=mask, + pair=args.pair, + fuse=args.fuse, ) for (M, K) in shapes ] diff --git a/tests/pytorch/nvfp4/test_nvfp4_per_token.py b/tests/pytorch/nvfp4/test_nvfp4_per_token.py index 951b5b482e..5d99d2919b 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_per_token.py +++ b/tests/pytorch/nvfp4/test_nvfp4_per_token.py @@ -937,19 +937,31 @@ def test_per_token_gemm_with_fused_swizzle_matches_unswizzled(M: int, K: int) -> a_plain = nvfp4_per_token_quantize(A, rowwise=True, columnwise=False, with_swizzle=False) b_plain = nvfp4_per_token_quantize(B, rowwise=True, columnwise=False, with_swizzle=False) c_unswz = nvfp4_per_token_gemm( - a_plain.data, a_plain.scale, a_plain.row_amax, - b_plain.data, b_plain.scale, b_plain.row_amax, + a_plain.data, + a_plain.scale, + a_plain.row_amax, + b_plain.data, + b_plain.scale, + b_plain.row_amax, ) a_swz = nvfp4_per_token_quantize(A, rowwise=True, columnwise=False, with_swizzle=True) b_swz = nvfp4_per_token_quantize(B, rowwise=True, columnwise=False, with_swizzle=True) c_swz = nvfp4_per_token_gemm( - a_swz.data, a_swz.scale, a_swz.row_amax, - b_swz.data, b_swz.scale, b_swz.row_amax, - a_sf_swizzled=True, b_sf_swizzled=True, + a_swz.data, + a_swz.scale, + a_swz.row_amax, + b_swz.data, + b_swz.scale, + b_swz.row_amax, + a_sf_swizzled=True, + b_sf_swizzled=True, ) torch.testing.assert_close( - c_swz, c_unswz, rtol=0, atol=0, + c_swz, + c_unswz, + rtol=0, + atol=0, msg=f"fused-swizzle GEMM output != unswizzled-input GEMM at ({M}, {K})", ) diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu index ae804ef147..4efad9237a 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu @@ -590,9 +590,8 @@ __global__ void __launch_bounds__(THREADS_NUM) const uint32_t packed_lo = static_cast(packed_all); const uint32_t packed_hi = static_cast(packed_all >> 32); - const size_t base_byte = M_tile_idx * CHUNK_DIM_Y * scale_stride + - K_tile_global_base * 512 + static_cast(ty) * 16 + - static_cast(b) * 4; + const size_t base_byte = M_tile_idx * CHUNK_DIM_Y * scale_stride + K_tile_global_base * 512 + + static_cast(ty) * 16 + static_cast(b) * 4; *reinterpret_cast(&scales_ptr[base_byte]) = packed_lo; *reinterpret_cast(&scales_ptr[base_byte + 512]) = packed_hi; } else { @@ -604,8 +603,7 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t row_global = scales_block_offset_Y_rowwise + row; if (row_global < rows) { ScalesVec& scales_vec = *reinterpret_cast(sSFrowwise[row]); - const size_t scale_idx_global = - row_global * scale_stride + scales_block_offset_X_rowwise; + const size_t scale_idx_global = row_global * scale_stride + scales_block_offset_X_rowwise; scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); } } @@ -1001,8 +999,7 @@ inline void launch_encode(const Tensor& input, Tensor* output, const Tensor& noo dshmem_size); kernel<<>>( tmap_in, tmap_out, tmap_out_t, scales_ptr, scales_t_ptr, row_amax_in, - col_amax_in, noop_ptr, M, K, scale_stride, scale_stride_t, - random_sign_mask_t); + col_amax_in, noop_ptr, M, K, scale_stride, scale_stride_t, random_sign_mask_t); })))); NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp index 49b01566e5..bdcf12efbb 100644 --- a/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp @@ -124,8 +124,8 @@ void nvfp4_per_token_quantize(const at::Tensor& input, at::Tensor q_row, at::Ten if (with_swizzle) out_te.set_with_gemm_swizzled_scales(true); const auto stream = at::cuda::getCurrentCUDAStream(); nvte_nvfp4_per_token_quantize(in_te.data(), nullptr, out_te.data(), with_rht ? 1 : 0, - static_cast(random_sign_mask_t & 0xFFFF), - with_swizzle ? 1 : 0, stream); + static_cast(random_sign_mask_t & 0xFFFF), with_swizzle ? 1 : 0, + stream); } // K1-only (diagnostic / bench): populates only amax buffers. with_rht=true @@ -159,8 +159,8 @@ void nvfp4_per_token_encode(const at::Tensor& input, at::Tensor q_row, at::Tenso if (with_swizzle) out_te.set_with_gemm_swizzled_scales(true); const auto stream = at::cuda::getCurrentCUDAStream(); nvte_nvfp4_per_token_encode(in_te.data(), nullptr, out_te.data(), with_rht ? 1 : 0, - static_cast(random_sign_mask_t & 0xFFFF), - with_swizzle ? 1 : 0, stream); + static_cast(random_sign_mask_t & 0xFFFF), with_swizzle ? 1 : 0, + stream); } // Apply per-token post-scale to a GEMM output (see nvfp4_per_token.h for math). @@ -209,8 +209,8 @@ void nvfp4_per_token_swizzle_rowwise_sf(const at::Tensor& data, const at::Tensor TORCH_CHECK(sf_in.scalar_type() == at::ScalarType::Byte, "sf_in must be uint8 (FP8 e4m3)"); TORCH_CHECK(sf_out.scalar_type() == at::ScalarType::Byte, "sf_out must be uint8 (FP8 e4m3)"); TORCH_CHECK(data.dim() == 2, "data must be 2D (M, K/2)"); - TORCH_CHECK(sf_in.numel() == sf_out.numel(), - "sf_in/sf_out numel mismatch: ", sf_in.numel(), " vs ", sf_out.numel()); + TORCH_CHECK(sf_in.numel() == sf_out.numel(), "sf_in/sf_out numel mismatch: ", sf_in.numel(), + " vs ", sf_out.numel()); const int64_t m = data.size(0); const int64_t k = data.size(1) * 2; // FP4 packed