Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,230 changes: 1,230 additions & 0 deletions tests/pytorch/nvfp4/bench_nvfp4_per_token.py

Large diffs are not rendered by default.

460 changes: 460 additions & 0 deletions tests/pytorch/nvfp4/bench_nvfp4_per_token_group.py

Large diffs are not rendered by default.

967 changes: 967 additions & 0 deletions tests/pytorch/nvfp4/test_nvfp4_per_token.py

Large diffs are not rendered by default.

591 changes: 591 additions & 0 deletions tests/pytorch/nvfp4/test_nvfp4_per_token_group.py

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,11 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
recipe/nvfp4.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)
cast/nvfp4/quantize_nvfp4_per_token.cu
cast/nvfp4/quantize_nvfp4_per_token_group.cu
gemm/nvfp4_per_token_post_scale.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)

# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
Expand Down
1,192 changes: 1,192 additions & 0 deletions transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu

Large diffs are not rendered by default.

1,123 changes: 1,123 additions & 0 deletions transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu

Large diffs are not rendered by default.

140 changes: 140 additions & 0 deletions transformer_engine/common/gemm/nvfp4_per_token_post_scale.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*************************************************************************
* Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

/*! \file nvfp4_per_token_post_scale.cu
* \brief NVFP4 per-token GEMM-output post-scale: d[i,j] *= r_A[i] * r_B[j].
*
* Standalone bf16 epilogue applied after cuBLAS LT NVFP4 GEMM with the
* operand amaxes pinned to 1.0. See nvfp4_per_token.h for the math chain.
*/

#include <transformer_engine/nvfp4_per_token.h>

#include "../common.h"
#include "../util/logging.h"
#include "../util/ptx.cuh"

namespace transformer_engine {
namespace nvfp4_per_token {

namespace {

// Tiling parameters for the post-scale kernel.
//
// Each block tiles kTileRows (16) rows x kTileCols (256) cols of the output:
// the amaxes are loaded once into SMEM, then each thread handles
// kElemsPerThread (8) cols via a single 16-byte int4 LD/ST.
//
// The 16-byte vector access requires 16-byte-aligned addresses (i.e. N a
// multiple of 8 bf16 elements together with an aligned base pointer). The
// host wrapper in this file does not check M or N alignment.
constexpr int kTileCols = 256;
constexpr int kTileRows = 16;
constexpr int kElemsPerThread = 8;  // bf16x8 = 16-byte vector
constexpr int kThreadsX = kTileCols / kElemsPerThread;
constexpr int kThreadsY = kTileRows;
constexpr int kThreadsPerBlock = kThreadsX * kThreadsY;
static_assert(kTileCols % kElemsPerThread == 0, "kTileCols must be a multiple of kElemsPerThread");
static_assert(kElemsPerThread * sizeof(__nv_bfloat16) == sizeof(int4),
              "kElemsPerThread bf16 must pack into a single int4 (16 bytes)");
// The kernel fills s_col_amax[kTileCols] with one thread per entry, so the
// block must contain at least kTileCols threads or part of the column amax
// tile would be left uninitialized.
static_assert(kThreadsPerBlock >= kTileCols,
              "kThreadsPerBlock must be >= kTileCols for the cooperative column-amax load");
// CUDA caps a thread block at 1024 threads.
static_assert(kThreadsPerBlock <= 1024, "kThreadsPerBlock exceeds the CUDA per-block thread limit");

// In-place post-scale of a row-major (M, N) bf16 matrix:
//   d[i, j] = d[i, j] * row_amax_a[i] * row_amax_b[j]
//
// Launch layout: block = (kThreadsX, kThreadsY), one block per
// kTileRows x kTileCols output tile; blockIdx.x walks column tiles and
// blockIdx.y walks row tiles. Uses (kTileRows + kTileCols) floats of static
// shared memory. Rows / columns past M / N are bounds-checked, so the grid
// may over-cover the matrix.
//
// Each thread owns kElemsPerThread consecutive columns of one row. They are
// moved with a single 16-byte vector load/store only when the whole chunk is
// in range AND its address is 16-byte aligned; otherwise the thread falls
// back to element-wise accesses. The alignment test is what keeps the kernel
// valid when N is not a multiple of kElemsPerThread or when `d` itself is not
// 16-byte aligned (a misaligned int4 access is an error on the device).
__global__ void __launch_bounds__(kThreadsPerBlock)
    per_token_post_scale_kernel(__nv_bfloat16* __restrict__ d, const float* __restrict__ row_amax_a,
                                const float* __restrict__ row_amax_b, const int M, const int N) {
  __shared__ float s_row_amax[kTileRows];
  __shared__ float s_col_amax[kTileCols];

  const int row_tile = blockIdx.y * kTileRows;
  const int col_tile = blockIdx.x * kTileCols;

  // Cooperatively load row + col amaxes into SMEM (272 floats / 512 threads).
  // Out-of-range entries are zero-filled; they are never consumed because the
  // matching threads exit (or stop) on the bounds checks below.
  const int tid = threadIdx.y * kThreadsX + threadIdx.x;
  if (tid < kTileRows) {
    const int gi = row_tile + tid;
    s_row_amax[tid] = (gi < M) ? row_amax_a[gi] : 0.0f;
  }
  if (tid < kTileCols) {
    const int gj = col_tile + tid;
    s_col_amax[tid] = (gj < N) ? row_amax_b[gj] : 0.0f;
  }
  // Every thread reaches this barrier: the early return comes after it.
  __syncthreads();

  const int i = row_tile + threadIdx.y;
  const int j0 = col_tile + threadIdx.x * kElemsPerThread;
  if (i >= M || j0 >= N) return;

  const float a = s_row_amax[threadIdx.y];
  const size_t base = static_cast<size_t>(i) * N + j0;
  __nv_bfloat16* const chunk_ptr = d + base;

  const bool full_chunk = (j0 + kElemsPerThread <= N);
  // Row starts are only 16-byte aligned when N % kElemsPerThread == 0 and the
  // base pointer is aligned, so test the actual address instead of assuming.
  const bool aligned = (reinterpret_cast<size_t>(chunk_ptr) % sizeof(int4)) == 0;

  // Fast path = 16-byte aligned LD/ST; slow path = boundary / unaligned fallback.
  if (full_chunk && aligned) {
    // __align__(16) is required for the int4 reinterpret_cast to be defined.
    __nv_bfloat16 __align__(16) chunk[kElemsPerThread];
    *reinterpret_cast<int4*>(chunk) = *reinterpret_cast<const int4*>(chunk_ptr);
#pragma unroll
    for (int e = 0; e < kElemsPerThread; ++e) {
      const float b = s_col_amax[threadIdx.x * kElemsPerThread + e];
      const float current = static_cast<float>(chunk[e]);
      chunk[e] = static_cast<__nv_bfloat16>(current * a * b);
    }
    *reinterpret_cast<int4*>(chunk_ptr) = *reinterpret_cast<const int4*>(chunk);
  } else {
#pragma unroll
    for (int e = 0; e < kElemsPerThread; ++e) {
      const int j = j0 + e;
      if (j >= N) break;
      const float b = s_col_amax[threadIdx.x * kElemsPerThread + e];
      const float current = static_cast<float>(chunk_ptr[e]);
      chunk_ptr[e] = static_cast<__nv_bfloat16>(current * a * b);
    }
  }
}

} // namespace

/*! \brief Host launcher for the per-token post-scale: d[i,j] *= row_amax_a[i] * row_amax_b[j].
 *
 *  Validates the operands and launches per_token_post_scale_kernel on
 *  \p stream. The update is done in place on \p d. Empty problems
 *  (M == 0 or N == 0) return without launching anything.
 *
 *  \param[in,out] d           2D BF16 tensor of shape (M, N).
 *  \param[in]     row_amax_a  FP32 tensor with M elements (one factor per row of d).
 *  \param[in]     row_amax_b  FP32 tensor with N elements (one factor per column of d).
 *  \param[in]     stream      CUDA stream the kernel is enqueued on.
 *
 *  Fails an NVTE_CHECK if a dtype / rank / element-count requirement is
 *  violated, if a dimension is too large for the kernel's int indexing or
 *  the CUDA grid limits, or if an amax tensor has no data pointer.
 */
void per_token_post_scale(Tensor* d, const Tensor& row_amax_a, const Tensor& row_amax_b,
                          cudaStream_t stream) {
  NVTE_CHECK(d->has_data(), "NVFP4 per-token post-scale: d has no data.");
  NVTE_CHECK(d->data.dtype == DType::kBFloat16,
             "NVFP4 per-token post-scale: d must be BF16 (got non-BF16 dtype).");
  NVTE_CHECK(row_amax_a.data.dtype == DType::kFloat32,
             "NVFP4 per-token post-scale: row_amax_a must be FP32.");
  NVTE_CHECK(row_amax_b.data.dtype == DType::kFloat32,
             "NVFP4 per-token post-scale: row_amax_b must be FP32.");

  const auto& d_shape = d->data.shape;
  NVTE_CHECK(d_shape.size() == 2,
             "NVFP4 per-token post-scale: d must be 2D, got rank=", d_shape.size());

  // Validate the extents BEFORE narrowing to int. Row tiles are mapped to
  // gridDim.y, which CUDA limits to 65535 blocks; columns are indexed with
  // int arithmetic of the form (N + kTileCols - 1), which must not overflow.
  constexpr size_t kMaxRows = static_cast<size_t>(65535) * kTileRows;
  constexpr size_t kMaxCols = static_cast<size_t>(0x7fffffff) - kTileCols;
  NVTE_CHECK(static_cast<size_t>(d_shape[0]) <= kMaxRows,
             "NVFP4 per-token post-scale: M=", d_shape[0], " exceeds the supported maximum of ",
             kMaxRows, " rows.");
  NVTE_CHECK(static_cast<size_t>(d_shape[1]) <= kMaxCols,
             "NVFP4 per-token post-scale: N=", d_shape[1], " exceeds the supported maximum of ",
             kMaxCols, " columns.");

  const int M = static_cast<int>(d_shape[0]);
  const int N = static_cast<int>(d_shape[1]);
  NVTE_CHECK(row_amax_a.data.numel() == static_cast<size_t>(M),
             "NVFP4 per-token post-scale: row_amax_a numel must equal M=", M, ", got ",
             row_amax_a.data.numel());
  NVTE_CHECK(row_amax_b.data.numel() == static_cast<size_t>(N),
             "NVFP4 per-token post-scale: row_amax_b numel must equal N=", N, ", got ",
             row_amax_b.data.numel());

  // Nothing to scale; also avoids launching a zero-sized grid.
  if (M == 0 || N == 0) {
    return;
  }

  // Checked after the empty-shape return so zero-element amax tensors without
  // a buffer stay accepted; for non-empty shapes the kernel dereferences both.
  NVTE_CHECK(row_amax_a.data.dptr != nullptr,
             "NVFP4 per-token post-scale: row_amax_a has no data.");
  NVTE_CHECK(row_amax_b.data.dptr != nullptr,
             "NVFP4 per-token post-scale: row_amax_b has no data.");

  // 32 x 16 threads = 512/block; covers 256 cols x 16 rows = 4096 elems/block.
  dim3 block(kThreadsX, kThreadsY, 1);
  dim3 grid((N + kTileCols - 1) / kTileCols, (M + kTileRows - 1) / kTileRows, 1);
  per_token_post_scale_kernel<<<grid, block, 0, stream>>>(
      reinterpret_cast<__nv_bfloat16*>(d->data.dptr),
      reinterpret_cast<const float*>(row_amax_a.data.dptr),
      reinterpret_cast<const float*>(row_amax_b.data.dptr), M, N);
  // Launch-configuration errors are reported here; in-kernel faults surface
  // at a later synchronizing call on the stream.
  NVTE_CHECK_CUDA(cudaGetLastError());
}

} // namespace nvfp4_per_token
} // namespace transformer_engine

// C API entry point: unwraps the opaque NVTETensor handles and forwards to
// transformer_engine::nvfp4_per_token::per_token_post_scale on `stream`.
void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a,
                                     const NVTETensor row_amax_b, cudaStream_t stream) {
  NVTE_API_CALL(nvte_nvfp4_per_token_post_scale);
  using namespace transformer_engine;

  // Convert each handle up front so the forwarding call reads as plain arguments.
  auto* out_tensor = convertNVTETensorCheck(d);
  const auto& amax_rows = *convertNVTETensorCheck(row_amax_a);
  const auto& amax_cols = *convertNVTETensorCheck(row_amax_b);

  nvfp4_per_token::per_token_post_scale(out_tensor, amax_rows, amax_cols, stream);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*************************************************************************
* Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#ifndef TRANSFORMER_ENGINE_NVFP4_PER_TOKEN_H_
#define TRANSFORMER_ENGINE_NVFP4_PER_TOKEN_H_

#include <cuda_runtime_api.h>

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Composite K1+K2: per-row + per-col amax (K1) then FP4 + 1x16
* e4m3 SF encode (K2), back-to-back on the same stream.
*
* Production entry point for the per-token cast on bf16 + 128-aligned shapes.
*
* \param[in] with_rht non-zero -> apply 16-pt RHT on the col direction in
* both K1 and K2. Rowwise stays raw; zero is byte-equal
* to the pre-RHT path.
* \param[in] random_sign_mask_t low 16 bits = sign-flip pattern shared by
* K1 and K2. Ignored when with_rht == 0.
* \param[in] with_swizzle non-zero -> K2 emits rowwise scale_inv directly
* in the cuBLAS LT swizzled tile layout (rowwise only;
* colwise stays compact M-major).
*/
void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output,
int with_rht, int random_sign_mask_t, int with_swizzle,
cudaStream_t stream);

/*! \brief Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax.
* Pre-zeroes the amax buffers and merges per-CTA partials into
* ``output->amax`` (size [M]) / ``output->columnwise_amax``
* (size [K]). Does NOT touch FP4 data / scale_inv slots.
*
* \param[in] with_rht non-zero -> apply 16-pt RHT on the col direction
* before columnwise_amax (rowwise stays raw); zero is
* byte-equal to the pre-RHT K1.
* \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; ignored
* when with_rht == 0. Type matches prod's
* nvte_hadamard_transform_amax convention.
*/
void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NVTETensor output,
int with_rht, int random_sign_mask_t, cudaStream_t stream);

/*! \brief Kernel 2 in isolation: FP4 + 1x16 e4m3 SF encode given a
* pre-filled ``output->amax`` / ``output->columnwise_amax``. Reads
* the outer amax buffer(s) and writes the FP4 data / scale_inv
* tensors only.
*
* \param[in] with_rht non-zero -> col-wise cast applies the same 16-pt RHT
* that K1 amax must have used (caller's responsibility
* to thread the same flag + mask through K1 and K2).
* \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; ignored
* when with_rht == 0.
* \param[in] with_swizzle non-zero -> write rowwise scale_inv directly in
* the cuBLAS LT swizzled tile layout (rowwise only;
* colwise stays compact M-major).
*/
void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output,
int with_rht, int random_sign_mask_t, int with_swizzle,
cudaStream_t stream);

/*! \brief Returns 1 iff the per-token kernels accept ``(M, K, dtype)``.
*
* Currently returns 1 iff ``dtype`` is bf16 AND ``M % 128 == 0`` AND
* ``K % 128 == 0``. Cheap host-side query (no CUDA call).
*
* \param[in] M first-dim (rows).
* \param[in] K last-dim (cols).
* \param[in] input_dtype_enum NVTE_DType cast to int.
*/
int nvte_nvfp4_per_token_can_dispatch(size_t M, size_t K, int input_dtype_enum);

/*! \brief Apply per-row * per-col outer-scale to a (M, N) bf16 GEMM output.
*
* Computes, in place on ``d``:
*
* d[i, j] = d[i, j] * row_amax_a[i] * row_amax_b[j]
*
* \param[in,out] d 2D BF16 tensor of shape (M, N); overwritten with
* the scaled values.
* \param[in] row_amax_a FP32 tensor with exactly M elements (one factor
* per row of ``d``).
* \param[in] row_amax_b FP32 tensor with exactly N elements (one factor
* per column of ``d``).
* \param[in] stream CUDA stream the scaling kernel is launched on.
*/
void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a,
const NVTETensor row_amax_b, cudaStream_t stream);

/*! \brief Grouped (multi-tensor) per-token amax: fills the amax /
* columnwise_amax slots of each output tensor.
*
* \param[in] input (sum_M, K) bf16/fp32, row-major contiguous
* \param[in,out] outputs array of `num_tensors` NVTETensors; on
* return, amax/columnwise_amax slots are filled.
* \param[in] split_sections array of `num_tensors` size_t values,
* each a multiple of 64; sum must equal sum_M.
* \param[in] num_tensors <= 64
* \param[in] rowwise emit per-row amax in `outputs[i].amax`
* \param[in] columnwise emit per-col amax in `outputs[i].columnwise_amax`
* \param[in] with_rht non-zero -> 16-pt RHT on the col direction
* (rowwise stays raw).
* \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; must
* match the value passed to the matching cast
* if amax + cast are launched separately.
* \param[in] stream CUDA stream
*/
void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs,
const size_t* split_sections, size_t num_tensors, bool rowwise,
bool columnwise, int with_rht, int random_sign_mask_t,
cudaStream_t stream);

/*! \brief Grouped per-token encode (FP4 + 1x16 e4m3 inner SF) using the
* row_amax / col_amax values already populated by
* `nvte_group_nvfp4_per_token_amax`.
*
* \param[in] input same as `nvte_group_nvfp4_per_token_amax`
* \param[in,out] outputs on entry: amax/columnwise_amax populated;
* on return: data/scale_inv + columnwise_data/
* columnwise_scale_inv populated.
* \param[in] split_sections same as `nvte_group_nvfp4_per_token_amax`
* \param[in] num_tensors <= 64
* \param[in] rowwise emit per-row FP4 + inner SF
* \param[in] columnwise emit per-col FP4 + inner SF
* \param[in] with_rht must match the preceding amax call's
* with_rht; applies the same 16-pt RHT on the
* colwise cast.
* \param[in] random_sign_mask_t low 16 bits = sign-flip pattern; must
* match K1.
* \param[in] stream CUDA stream
*/
void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs,
const size_t* split_sections, size_t num_tensors, bool rowwise,
bool columnwise, int with_rht, int random_sign_mask_t,
cudaStream_t stream);

/*! \brief Composite K1+K2 grouped per-token quantize. Calls the amax + cast
* kernels on the same stream. This is the external API
* `tex.split_quantize(per_token=True)` should call.
*
* \param[in] input (sum_M, K) bf16/fp32, row-major contiguous
* \param[in,out] outputs on entry: amax / columnwise_amax / data /
* scale_inv / columnwise_data /
* columnwise_scale_inv slots allocated;
* on return: all populated.
* \param[in] split_sections array of `num_tensors` size_t values,
* each a multiple of 64; sum must equal sum_M.
* \param[in] num_tensors <= 64
* \param[in] rowwise emit rowwise output
* \param[in] columnwise emit columnwise output
* \param[in] with_rht non-zero -> 16-pt RHT on the col direction
* in BOTH K1 and K2; zero is byte-equal to the
* pre-RHT path.
* \param[in] random_sign_mask_t low 16 bits = sign-flip pattern shared
* between K1 and K2; ignored when with_rht==0.
* \param[in] stream CUDA stream
*/
void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* outputs,
const size_t* split_sections, size_t num_tensors,
bool rowwise, bool columnwise, int with_rht,
int random_sign_mask_t, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif // TRANSFORMER_ENGINE_NVFP4_PER_TOKEN_H_
Loading