Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 84 additions & 4 deletions transformer_engine/pytorch/triton_kernels/norms_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
te_dtype_to_triton_dtype,
)
from ..quantized_tensor import Quantizer
from .utils import num_programs, block_size, use_blocked, make_ln_out
from .utils import num_programs, block_size, use_blocked, make_ln_out, get_num_sms
from .common import get_fp8_max
from .rmsnorm import (
_rmsnorm_fwd_triton,
Expand Down Expand Up @@ -122,6 +122,44 @@ def _get_fp8_transpose_configs():
True: _rmsnorm_bwd_dg_reduce_triton,
False: _rmsnorm_bwd_dg_reduce_triton_impl,
}


# Occupancy target for row packing: ``_rows_per_pid`` multiplies this by
# ``num_sms`` to get the minimum number of programs a packed launch must keep.
_TARGET_PRGMS_PER_CU = 8
# Multiplier applied to the ``num_programs(...)`` result to cap the RMSNorm
# backward program count when more than one row is packed per program.
_MAX_PRGMS_PER_CU = 16
# 2**31. The launch code sets NEEDS_I64_OFFSETS once the largest
# ``row_stride * rows + cols`` estimate reaches this value.
_I32_OFFSET_LIMIT = 1 << 31

# Largest hidden size for which the non-blocked, non-FP8 RMSNorm forward uses
# ``N // ROWS_PER_PID`` programs; above it the forward falls back to
# ``num_programs(...)``.
_FWD_FLAT_GRID_MAX_H = 2048

# (hidden-size threshold, max rows per program) rungs. ``_rows_per_pid`` scans
# them in order and takes the cap of the first rung with hidden <= threshold;
# when no rung matches, the cap is 1.
_FWD_CAP_LADDER = ((128, 16), (256, 4), (512, 2))
_BWD_CAP_LADDER = ((128, 4), (512, 2))


def _rows_per_pid(rows, hidden, num_sms, cap_ladder):
    """Choose how many rows each program handles in a narrow-H RMSNorm launch.

    The result is a power of two. Starting from 1, the count is doubled for
    as long as the doubled value

    * stays within the cap that ``cap_ladder`` assigns to ``hidden`` (the cap
      of the first ``(threshold, cap)`` rung with ``hidden <= threshold``, or
      1 when no rung matches),
    * divides ``rows`` exactly, so the kernel needs no row-tail mask, and
    * still leaves at least ``_TARGET_PRGMS_PER_CU * num_sms`` programs.

    If no doubling was possible, a small packing factor is still used to
    amortise the gamma load when divisibility allows it: 4 for
    ``hidden <= 128`` with ``rows`` divisible by 4, otherwise 2 for
    ``hidden <= 512`` with ``rows`` even. In every other case the result
    is 1.

    Args:
        rows: Number of rows in the launch.
        hidden: Hidden (column) size of each row.
        num_sms: Multiplier for the per-unit program target.
        cap_ladder: Iterable of ``(threshold, cap)`` pairs, scanned in order.

    Returns:
        The rows-per-program count.
    """
    min_programs = _TARGET_PRGMS_PER_CU * num_sms

    # First rung that covers ``hidden`` decides the cap; default is no packing.
    limit = next(
        (rung_cap for threshold, rung_cap in cap_ladder if hidden <= threshold),
        1,
    )

    packed = 1
    while True:
        doubled = packed * 2
        if doubled > limit or rows % doubled != 0 or rows // doubled < min_programs:
            break
        packed = doubled

    if packed != 1:
        return packed

    # Program-count target not reachable: fall back to light packing.
    if hidden <= 128 and rows % 4 == 0:
        return 4
    if hidden <= 512 and rows % 2 == 0:
        return 2
    return 1


# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd
def te_rmsnorm_fwd_triton(
input: torch.Tensor,
Expand Down Expand Up @@ -208,7 +246,27 @@ def _te_norm_fwd_triton(
IS_FP8_CURRENT_SCALING = isinstance(quantizer, Float8CurrentScalingQuantizer)
BLOCK_SIZE = block_size(input_tensor, norm=kernel)
USE_BLOCKED = use_blocked(input_tensor)
NUM_PRGMS = N if kernel=='layer' else num_programs(input_tensor, sm_margin)

# Row-packing (RMSNorm, non-blocked, non-FP8 only)
ROWS_PER_PID_FWD = 1
if kernel == 'rms' and not USE_BLOCKED and not IS_FP8:
ROWS_PER_PID_FWD = _rows_per_pid(N, H, get_num_sms(sm_margin), _FWD_CAP_LADDER)

# HOIST_GAMMA: load gamma once per program when the program processes more than one row
HOIST_GAMMA_FWD = True
if kernel == 'layer':
# LayerNorm requires one program per row on both paths.
NUM_PRGMS = N
elif kernel == 'rms' and not USE_BLOCKED:
use_flat = (not IS_FP8) and (H <= _FWD_FLAT_GRID_MAX_H)
if use_flat:
NUM_PRGMS = N // ROWS_PER_PID_FWD
Comment thread
AllenFarcas marked this conversation as resolved.
HOIST_GAMMA_FWD = ROWS_PER_PID_FWD > 1
else:
NUM_PRGMS = num_programs(input_tensor, sm_margin)
HOIST_GAMMA_FWD = True # persistent loop reuses gamma across rows
else:
NUM_PRGMS = num_programs(input_tensor, sm_margin)
MAKE_TRANSPOSE = False
APPLY_ATOMIC = N < 512 or kernel == 'rms'
ATOMIC_REDUCTION_BLOCK_SIZE=256
Expand Down Expand Up @@ -301,6 +359,10 @@ def _te_norm_fwd_triton(
out_ptr.data_ptr() % 16 == 0 and
output_row_stride * getattr(out_ptr.dtype, 'itemsize', 1) % 16 == 0
)
kwargs["ROWS_PER_PID"] = ROWS_PER_PID_FWD
kwargs["HOIST_GAMMA"] = HOIST_GAMMA_FWD
max_off = max(input_row_stride, output_row_stride) * N + H
kwargs["NEEDS_I64_OFFSETS"] = max_off >= _I32_OFFSET_LIMIT

kernel_func[grid_fwd](**kwargs)

Expand Down Expand Up @@ -358,9 +420,23 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma,
M, N = x_.shape
blk_size = block_size(x_)
USE_BLOCKED = use_blocked(x_)
NUM_PRGMS = num_programs(x_, sm_margin)

# Multi-row-per-program for narrow N
ROWS_PER_PID_BWD = 1
if not USE_BLOCKED:
ROWS_PER_PID_BWD = _rows_per_pid(M, N, get_num_sms(sm_margin), _BWD_CAP_LADDER)

base_prgms = num_programs(x_, sm_margin)
if USE_BLOCKED:
NUM_PRGMS = base_prgms
dg_tmp_rows = M
elif ROWS_PER_PID_BWD > 1:
NUM_PRGMS = min(M // ROWS_PER_PID_BWD, _MAX_PRGMS_PER_CU * base_prgms)
dg_tmp_rows = NUM_PRGMS
else:
NUM_PRGMS = base_prgms
dg_tmp_rows = NUM_PRGMS
need_reduction = NUM_PRGMS > 1
dg_tmp_rows = M if USE_BLOCKED else NUM_PRGMS
dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None

input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) * x_.dtype.itemsize % 16 == 0)
Expand All @@ -370,6 +446,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma,
dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) * dg_target.dtype.itemsize % 16 == 0)

grid_bwd = lambda meta: (NUM_PRGMS, )
# See forward: gate i64 offsets on actual overflow to stay on the fast i32 path.
needs_i64_offsets_bwd = max(x_.stride(0), dz_.stride(0), dx.stride(0)) * M + N >= _I32_OFFSET_LIMIT
bwd_kernel = _rmsnorm_bwd_kernels[autotune]
bwd_kwargs = dict(
n_rows=M, n_cols=N,
Expand All @@ -380,6 +458,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma,
GRAD_OUTPUT_ALIGNED_16=grad_output_aligned_16,
DX_ALIGNED_16=dx_aligned_16,
DG_ALIGNED_16=dg_aligned_16,
ROWS_PER_PID=ROWS_PER_PID_BWD,
NEEDS_I64_OFFSETS=needs_i64_offsets_bwd,
)
if not autotune:
bwd_kwargs["num_warps"] = 8
Expand Down
162 changes: 109 additions & 53 deletions transformer_engine/pytorch/triton_kernels/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,32 @@
import torch
import triton
import triton.language as tl
from itertools import product

def get_autotune_config():
return [triton.Config({'waves_per_eu': we}, num_warps=nw) for (we, nw) in product([0, 1, 2, 4], [4, 8, 16])]
return [triton.Config({'waves_per_eu': 0}, num_warps=nw) for nw in (1, 4, 8, 16)]


def _prune_rms_configs(configs, named_args, **kwargs):
    """Filter autotune configs by how well ``num_warps`` matches ``n_cols``.

    A config is kept when its thread count (``num_warps * 64``) lies within
    ``[n_cols / 16, n_cols * 2]``: fewer threads is treated as too serial,
    more as leaving too many lanes idle. Shared by the fwd and bwd kernels.

    ``n_cols`` is read from ``named_args`` first and from ``kwargs`` second.
    The input ``configs`` object is returned unchanged when ``n_cols`` is not
    available or when the filter would remove every config.

    Args:
        configs: Candidate configs; each must expose ``num_warps``.
        named_args: Mapping of kernel argument names to values.
        **kwargs: Extra keyword arguments; may carry ``n_cols``.

    Returns:
        The surviving configs as a new list, or ``configs`` itself as fallback.
    """
    width = named_args.get('n_cols')
    if width is None:
        width = kwargs.get('n_cols')
    if width is None:
        return configs

    def fits(cfg):
        lanes = cfg.num_warps * 64  # AMD wavefront = 64
        # Reject when too serial (below width/16) or mostly idle (above 2*width).
        return not (lanes < width / 16 or lanes > width * 2)

    kept = [cfg for cfg in configs if fits(cfg)]
    # Never hand the autotuner an empty candidate list.
    return kept or configs


# TODO(micky774) Implement fused MXFP8 quantization within the kernel
Expand All @@ -32,6 +54,9 @@ def _rmsnorm_fwd_triton_impl(
FP8_MAX: tl.constexpr,
INPUT_ALIGNED_16: tl.constexpr,
OUTPUT_ALIGNED_16: tl.constexpr,
ROWS_PER_PID: tl.constexpr = 1,
NEEDS_I64_OFFSETS: tl.constexpr = False,
HOIST_GAMMA: tl.constexpr = True,
):

row_start = tl.program_id(0)
Expand Down Expand Up @@ -128,51 +153,71 @@ def _rmsnorm_fwd_triton_impl(

else:
mask = col_offsets < n_cols
# gamma is invariant across rows -- load + ZERO_CENTERED adjustment once per program.
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
if (ZERO_CENTERED_GAMMA):
g += 1
inv_n_cols = 1.0 / n_cols
for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2):
input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets
if INPUT_ALIGNED_16:
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
row_norm = row * row
row_norm = tl.sum(row_norm, axis=-1)
norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon)

# Store rsigma (norm_factor)
rsigma_output_ptr = rsigma_ptr + row_idx
tl.store(rsigma_output_ptr, norm_factor)

rms_norm = row * norm_factor * g

output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
if OUTPUT_ALIGNED_16:
output_ptrs = tl.multiple_of(output_ptrs, (16, ))
if IS_FP8:
amax_temp = tl.max(tl.abs(rms_norm), axis=-1)
amax = tl.maximum(amax, amax_temp)
rms_norm = rms_norm * scale
rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX)
tl.store(output_ptrs, rms_norm.to(output_type), mask=mask)
if HOIST_GAMMA:
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
if (ZERO_CENTERED_GAMMA):
g += 1
n_chunks = tl.cdiv(n_rows, ROWS_PER_PID)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: n_chunks = tl.cdiv(n_rows, ROWS_PER_PID) rounds up, and the inner tl.static_range(ROWS_PER_PID) writes row_idx = base_row + i to rsigma_ptr + row_idx, output_ptrs, etc. with only a column mask. Correctness depends entirely on the dispatcher in norms_common.py returning ROWS_PER_PID > 1 only when n_rows % ROWS_PER_PID == 0 (which _rows_per_pid does enforce today).

Since the kernel default is ROWS_PER_PID: tl.constexpr = 1, a future caller passing a non-divisor value would silently corrupt memory. Worth a short comment near n_chunks documenting the divisibility invariant (or, more defensively, a row_idx < n_rows mask on the inner stores).

for chunk_idx in tl.range(row_start, n_chunks, NUM_PRGMS, num_stages=2):
base_row = chunk_idx * ROWS_PER_PID
for i in tl.static_range(ROWS_PER_PID):
row_idx = base_row + i
if NEEDS_I64_OFFSETS:
row_off = row_idx.to(tl.int64)
else:
row_off = row_idx
input_ptrs = input_ptr + row_off * input_row_stride + col_offsets
if INPUT_ALIGNED_16:
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
if HOIST_GAMMA and ROWS_PER_PID == 1:
row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
else:
row = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
row_norm = tl.sum(row * row, axis=-1)
norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon)

# Store rsigma (norm_factor)
tl.store(rsigma_ptr + row_idx, norm_factor)

if not HOIST_GAMMA:
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
if (ZERO_CENTERED_GAMMA):
g += 1
rms_norm = row * norm_factor * g

output_ptrs = output_ptr + row_off * output_row_stride + col_offsets
if OUTPUT_ALIGNED_16:
output_ptrs = tl.multiple_of(output_ptrs, (16, ))
if IS_FP8:
amax_temp = tl.max(tl.abs(rms_norm), axis=-1)
amax = tl.maximum(amax, amax_temp)
rms_norm = rms_norm * scale
rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX)
tl.store(output_ptrs, rms_norm.to(output_type), mask=mask)
if IS_FP8:
tl.atomic_max(q_amax_ptr, amax, sem="relaxed")
if row_start == 0:
scale = tl.load(q_scale_ptr)
scale_inv = tl.fdiv(1.0, scale)
tl.store(scale_inv_ptr, scale_inv)

autotune_dec = triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
autotune_dec = triton.autotune(
configs=get_autotune_config(),
key=['n_rows', 'n_cols'],
prune_configs_by={'early_config_prune': _prune_rms_configs},
use_cuda_graph=True,
)
_rmsnorm_fwd_triton = autotune_dec(_rmsnorm_fwd_triton_impl)

@triton.jit
def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr, input_row_stride, output_row_stride,
n_rows, n_cols, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr,
USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr,
INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr,
DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr):
DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr,
ROWS_PER_PID: tl.constexpr = 1,
NEEDS_I64_OFFSETS: tl.constexpr = False):
row_start = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
inv_n_cols = 1.0 / n_cols
Expand Down Expand Up @@ -290,39 +335,50 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p
if (ZERO_CENTERED_GAMMA):
g += 1.

for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2):
input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets
grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets
dx_ptrs = dx_ptr + row_idx * input_row_stride + col_offsets
# Persistent outer loop with multi-row-per-chunk.
n_chunks = tl.cdiv(n_rows, ROWS_PER_PID)
for chunk_idx in tl.range(row_start, n_chunks, NUM_PRGMS, num_stages=2):
base_row = chunk_idx * ROWS_PER_PID
for i in tl.static_range(ROWS_PER_PID):
row_idx = base_row + i
if NEEDS_I64_OFFSETS:
row_off = row_idx.to(tl.int64)
else:
row_off = row_idx
input_ptrs = input_ptr + row_off * input_row_stride + col_offsets
grad_output_ptrs = grad_output_ptr + row_off * output_row_stride + col_offsets
dx_ptrs = dx_ptr + row_off * input_row_stride + col_offsets

if INPUT_ALIGNED_16:
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
if GRAD_OUTPUT_ALIGNED_16:
grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, ))
if DX_ALIGNED_16:
dx_ptrs = tl.multiple_of(dx_ptrs, (16, ))
if INPUT_ALIGNED_16:
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
if GRAD_OUTPUT_ALIGNED_16:
grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, ))
if DX_ALIGNED_16:
dx_ptrs = tl.multiple_of(dx_ptrs, (16, ))

x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32)
x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32)

norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32)
grad_sum = tl.sum(grad_output * x * g, axis=0)
c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols
norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32)
grad_sum = tl.sum(grad_output * x * g, axis=0)
c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols

grad_input = norm_factor * (grad_output * g - c_scalar * x)
tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask)
grad_input = norm_factor * (grad_output * g - c_scalar * x)
tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask)

dg = grad_output * x * norm_factor
dg_col_redux += dg.to(tl.float32)
dg = grad_output * x * norm_factor
dg_col_redux += dg.to(tl.float32)

# Each program owns exactly one dg_tmp partial row (index == row_start).
tl.store(dg_ptr + row_start * n_cols + col_offsets, dg_col_redux, mask=mask)


# Autotune wrapper. Mirrors the fwd autotune layout so callers can toggle
# autotune via the same flag.
# Autotune wrapper. Mirrors the fwd autotune layout (same config set + prune)
# so callers can toggle autotune via the same flag.
_rmsnorm_bwd_triton = triton.autotune(
configs=get_autotune_config(),
key=['n_rows', 'n_cols'],
prune_configs_by={'early_config_prune': _prune_rms_configs},
use_cuda_graph=True,
)(_rmsnorm_bwd_triton_impl)

Expand Down
9 changes: 4 additions & 5 deletions transformer_engine/pytorch/triton_kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8CurrentScalingQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.utils import get_sm_count
from .common import te_dtype_to_torch_dtype

def get_ln_sm_margin(sm_margin_type):
Expand Down Expand Up @@ -34,12 +35,10 @@ def get_inf_ln_sm_margin():


def get_num_sms(sm_margin=None):
num_sms = torch.cuda.get_device_properties(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use get_sm_count from transformer_engine.pytorch.utils

torch.cuda.current_device()
).multi_processor_count
n = get_sm_count()
if sm_margin is not None and sm_margin > 0:
num_sms = max(num_sms - int(sm_margin), 1)
return num_sms
n = max(n - int(sm_margin), 1)
return n


def num_programs(x, sm_margin=None):
Expand Down