From e588542b0680193a83701fc72079ad3f483b2c6c Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 23 Feb 2026 04:13:45 -0800 Subject: [PATCH 01/63] update --- gemlite/quant_utils.py | 260 +++++++++++++++---- gemlite/triton_kernels/gemm_kernels.py | 338 ++++++++++++++++++++++--- 2 files changed, 505 insertions(+), 93 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 96ec3ad..b0409bc 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -265,83 +265,241 @@ def round_triton_amd(tensor): else: round_triton = round_triton_nvidia +# @triton.jit +# def scale_activations_per_token_kernel( +# tensor_ptr, scale_ptr, y_ptr, +# M, K, +# stride_m, stride_k, stride_sm, +# ROUND: tl.constexpr, +# UNROLL: tl.constexpr, +# min_val: tl.constexpr, +# max_val: tl.constexpr, +# fp32_scale: tl.constexpr, +# BLOCK_M: tl.constexpr, +# BLOCK_K: tl.constexpr, +# ): +# pid_m = tl.program_id(0) * UNROLL +# pid_k = tl.program_id(1) + +# offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) +# offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + +# for m in range(UNROLL): +# mask = ((offs_m < M)[:, None] & (offs_k < K)[None, :]).to(tl.int1) +# in_ptrs = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k +# tensor = tl.load(tensor_ptr + in_ptrs, mask=mask, other=0.0) +# if fp32_scale: +# tensor = tensor.to(tl.float32) + +# scales_x = tl.max(tl.abs(tensor), axis=1, keep_dims=True) +# scales_x /= max_val +# scales_x = tl.maximum(scales_x, 1e-6) +# tensor /= scales_x +# tensor = tl.minimum(tl.maximum(tensor, min_val), max_val) + +# if ROUND: +# tensor = round_triton(tensor) + +# tl.store(scale_ptr + offs_m[:, None] * stride_sm, scales_x) +# tl.store(y_ptr + in_ptrs, tensor, mask=mask) +# offs_m += BLOCK_M + +# def scale_activations_per_token_triton( +# tensor: Tensor, w_dtype: torch.dtype, fp32_scale: bool = True +# ) -> Tuple[Tensor, Tensor]: +# min_val, max_val = get_dtype_range(w_dtype) +# x_shape = tensor.shape +# tensor = tensor.view(-1, tensor.shape[-1]) +# M, K = tensor.shape +# scales = torch.empty( +# (M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device +# ) +# y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) + +# UNROLL = 1 # max(1, M // 128) +# BLOCK_M = 1 +# BLOCK_K = triton.next_power_of_2(K) +# grid = (triton.cdiv(M, BLOCK_M * UNROLL), triton.cdiv(K, BLOCK_K)) + +# ROUND = not w_dtype.is_floating_point + +# scale_activations_per_token_kernel[grid]( +# tensor, +# scales, +# y, +# M, +# K, +# tensor.stride(0), +# tensor.stride(1), +# scales.stride(0), +# min_val=min_val, +# max_val=max_val, +# fp32_scale=fp32_scale, +# ROUND=ROUND, +# UNROLL=UNROLL, +# BLOCK_M=BLOCK_M, +# BLOCK_K=BLOCK_K, +# num_stages=1, +# num_warps=4, +# ) + +# return y.view(x_shape), scales + + +# from typing import Tuple + +# @triton.autotune( +# configs=[ +# triton.Config({'BLOCK_M': 1}, num_warps=8, num_stages=1), +# triton.Config({'BLOCK_M': 2}, num_warps=8, num_stages=1), +# triton.Config({'BLOCK_M': 4}, num_warps=8, num_stages=1), +# ], +# key=['M', 'K'] +# ) +# @triton.jit +# def scale_activations_single_pass_kernel( +# tensor_ptr, scale_ptr, y_ptr, +# M, K, +# stride_m, stride_k, stride_sm, +# min_val: tl.constexpr, max_val: tl.constexpr, +# fp32_scale: tl.constexpr, ROUND: tl.constexpr, +# BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, +# ): +# pid_m = tl.program_id(0) + +# offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) +# offs_k = tl.arange(0, BLOCK_K) +# m_mask = offs_m < M +# k_mask = offs_k < K +# mask = m_mask[:, None] & k_mask[None, :] + +# offsets = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k + +# tensor = tl.load(tensor_ptr + offsets, mask=mask, other=0.0) + +# if fp32_scale: +# tensor = tensor.to(tl.float32) + +# scales_x = tl.max(tl.abs(tensor), axis=1) / max_val +# scales_x = tl.maximum(scales_x, 1e-6) +# tensor = tensor / scales_x[:, None] +# tensor = tl.minimum(tl.maximum(tensor, min_val), max_val) + +# if ROUND: +# tensor = round_triton(tensor) + +# tl.store(scale_ptr + offs_m * stride_sm, scales_x, mask=m_mask) +# tl.store(y_ptr + offsets, tensor, mask=mask) + +# def scale_activations_per_token_triton( +# tensor: torch.Tensor, w_dtype: torch.dtype, fp32_scale: bool = True +# ) -> Tuple[torch.Tensor, torch.Tensor]: + +# min_val, max_val = get_dtype_range(w_dtype) + +# x_shape = tensor.shape +# tensor = tensor.view(-1, tensor.shape[-1]) +# M, K = tensor.shape + +# scales = torch.empty((M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device) +# y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) + +# grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), ) + +# BLOCK_K = triton.next_power_of_2(K) +# ROUND = not w_dtype.is_floating_point + +# scale_activations_single_pass_kernel[grid]( +# tensor, scales, y, +# M, K, +# tensor.stride(0), tensor.stride(1), +# scales.stride(0), +# min_val=min_val, max_val=max_val, +# fp32_scale=fp32_scale, ROUND=ROUND, +# BLOCK_K=BLOCK_K +# ) + +# return y.view(x_shape), scales + + + +from typing import Tuple + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 1}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_M': 2}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_M': 4}, num_warps=8, num_stages=1), + ], + key=['M', 'K'] +) @triton.jit -def scale_activations_per_token_kernel( +def scale_activations_persistent_kernel( tensor_ptr, scale_ptr, y_ptr, M, K, stride_m, stride_k, stride_sm, - ROUND: tl.constexpr, - UNROLL: tl.constexpr, - min_val: tl.constexpr, - max_val: tl.constexpr, - fp32_scale: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, + min_val: tl.constexpr, max_val: tl.constexpr, + fp32_scale: tl.constexpr, ROUND: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, ): - pid_m = tl.program_id(0) * UNROLL - pid_k = tl.program_id(1) + start_pid = tl.program_id(0) + num_programs = tl.num_programs(0) + num_tiles = tl.cdiv(M, BLOCK_M) + + offs_k = tl.arange(0, BLOCK_K) + k_mask = offs_k < K - offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + for pid_m in range(start_pid, num_tiles, num_programs): + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + m_mask = offs_m < M + mask = m_mask[:, None] & k_mask[None, :] - for m in range(UNROLL): - mask = ((offs_m < M)[:, None] & (offs_k < K)[None, :]).to(tl.int1) - in_ptrs = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k - tensor = tl.load(tensor_ptr + in_ptrs, mask=mask, other=0.0) + offsets = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k + tensor = tl.load(tensor_ptr + offsets, mask=mask, other=0.0) + if fp32_scale: tensor = tensor.to(tl.float32) - - scales_x = tl.max(tl.abs(tensor), axis=1, keep_dims=True) - scales_x /= max_val + + scales_x = tl.max(tl.abs(tensor), axis=1) / max_val scales_x = tl.maximum(scales_x, 1e-6) - tensor /= scales_x + tensor = tensor / scales_x[:, None] tensor = tl.minimum(tl.maximum(tensor, min_val), max_val) if ROUND: tensor = round_triton(tensor) + + tl.store(y_ptr + offsets, tensor, mask=mask) + tl.store(scale_ptr + offs_m * stride_sm, scales_x, mask=m_mask) + + - tl.store(scale_ptr + offs_m[:, None] * stride_sm, scales_x) - tl.store(y_ptr + in_ptrs, tensor, mask=mask) - offs_m += BLOCK_M - - +NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count def scale_activations_per_token_triton( - tensor: Tensor, w_dtype: torch.dtype, fp32_scale: bool = True -) -> Tuple[Tensor, Tensor]: + tensor: torch.Tensor, w_dtype: torch.dtype, fp32_scale: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + min_val, max_val = get_dtype_range(w_dtype) + x_shape = tensor.shape tensor = tensor.view(-1, tensor.shape[-1]) M, K = tensor.shape - scales = torch.empty( - (M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device - ) + + scales = torch.empty((M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device) y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) - UNROLL = 1 # max(1, M // 128) - BLOCK_M = 1 - BLOCK_K = triton.next_power_of_2(K) - grid = (triton.cdiv(M, BLOCK_M * UNROLL), triton.cdiv(K, BLOCK_K)) + grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META['BLOCK_M'])), ) + BLOCK_K = triton.next_power_of_2(K) ROUND = not w_dtype.is_floating_point - scale_activations_per_token_kernel[grid]( - tensor, - scales, - y, - M, - K, - tensor.stride(0), - tensor.stride(1), + scale_activations_persistent_kernel[grid]( + tensor, scales, y, + M, K, + tensor.stride(0), tensor.stride(1), scales.stride(0), - min_val=min_val, - max_val=max_val, - fp32_scale=fp32_scale, - ROUND=ROUND, - UNROLL=UNROLL, - BLOCK_M=BLOCK_M, - BLOCK_K=BLOCK_K, - num_stages=1, - num_warps=4, + min_val=min_val, max_val=max_val, + fp32_scale=fp32_scale, ROUND=ROUND, + BLOCK_K=BLOCK_K ) return y.view(x_shape), scales diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 2afafa5..5b19df3 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -63,54 +63,61 @@ def kernel_config_pruner(configs, nargs, **kwargs): elif m <= 256: block_size_m = min(max(block_size_m, 64), 256) #m: [128...256] elif m > 256: block_size_m = min(max(block_size_m, 64), 256) #m > 256 - #Constraint: BLOCK_SIZE_K >= group_size, only for load_as_block = False - if(load_scales_as_block): - num_stages = max(num_stages, 2) #for dot_scaled kernels with pipelined loads - if(e > 1): - block_size_k = max(block_size_k, 64) #m16n8k64 - else: - block_size_k = max(block_size_k, 32) #m16n8k32 - else: - block_size_k = min(block_size_k, g) + # #Constraint: BLOCK_SIZE_K >= group_size, only for load_as_block = False + # if(load_scales_as_block): + # num_stages = max(num_stages, 2) #for dot_scaled kernels with pipelined loads + # if(e > 1): + # block_size_k = max(block_size_k, 64) #m16n8k64 + # else: + # block_size_k = max(block_size_k, 32) #m16n8k32 + # else: + # block_size_k = min(block_size_k, g) block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) #Hint: skip block_size_n > block_size_k for col-major non-packed data. - #Nvidia - if not IS_HIP: - if e > 1 and not load_scales_as_block: - #Limit num stages when data is packed - num_stages = min(num_stages, 4) - if(e == 1 and num_stages == 1): - #skip num_stages=1 for non-packed weights - continue - - #Avoid OOM - while num_stages > 0 and not load_scales_as_block: #TODO: revisit MXFP case - shared_mem = (block_size_m * block_size_k * a_sizeof + block_size_k * block_size_n * b_sizeof) - if(e > 1): - shared_mem += block_size_k * block_size_n * a_sizeof - shared_mem *= num_stages - if int(shared_mem) <= gpu_shared_memory: - break - num_stages -= 1 + # #Nvidia + # if not IS_HIP: + # if e > 1 and not load_scales_as_block: + # #Limit num stages when data is packed + # num_stages = min(num_stages, 4) + # if(e == 1 and num_stages == 1): + # #skip num_stages=1 for non-packed weights + # continue + + # #Avoid OOM + # while num_stages > 0 and not load_scales_as_block: #TODO: revisit MXFP case + # shared_mem = (block_size_m * block_size_k * a_sizeof + block_size_k * block_size_n * b_sizeof) + # if(e > 1): + # shared_mem += block_size_k * block_size_n * a_sizeof + # shared_mem *= num_stages + # if int(shared_mem) <= gpu_shared_memory: + # break + # num_stages -= 1 if(num_stages == 0): continue #config too large ########################################### - if(load_scales_as_block):#tmp MXFP fix - block_size_k = min(block_size_k, 256) + # if(load_scales_as_block):#tmp MXFP fix + # block_size_k = min(block_size_k, 256) ########################################### key = (block_size_m, block_size_n, block_size_k, group_size_m, A_load_order, num_stages, num_warps) + + EVEN_M = (m % block_size_m == 0) + EVEN_N = (n % block_size_n == 0) + EVEN_K = (k % block_size_k == 0) new_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": group_size_m, + "EVEN_M": EVEN_M, + "EVEN_N": EVEN_N, + "EVEN_K": EVEN_K, "A_load_order": A_load_order, "NUM_STAGES": num_stages, } @@ -167,6 +174,17 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=3)) + + # + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + + configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) return configs def get_default_config_nvidia(): @@ -278,12 +296,13 @@ def gemm_INT_kernel( ######### tuning params ######### BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_K: tl.constexpr, EVEN_N: tl.constexpr, A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# - meta_evict_policy: tl.constexpr = '', - a_evict: tl.constexpr = '', - b_evict: tl.constexpr = '', + meta_evict_policy: tl.constexpr = "evict_last", + a_evict: tl.constexpr = "", + b_evict: tl.constexpr = "evict_first", ): """ Based on https://github.com/fpgaminer/GPTQ-triton @@ -313,7 +332,7 @@ def gemm_INT_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - + #Offsets ############################################################################################################# if data_contiguous: @@ -347,12 +366,18 @@ def gemm_INT_kernel( for k in range(num_pid_k): if(A_load_order == 0): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) b = tl.load(b_ptrs, eviction_policy=b_evict) if(A_load_order == 1): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) #Meta-data loading policy if(W_group_mode > 0): @@ -372,13 +397,19 @@ def gemm_INT_kernel( zeros = None if(A_load_order == 2): #Mid load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) # Unpack and dequantize b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) if(A_load_order == 3): #Late load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) #Dot acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype) @@ -386,6 +417,10 @@ def gemm_INT_kernel( #Advance a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K_P * stride_bk + + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) + ############################################################################################################# #Channel-wise scaling @@ -413,6 +448,164 @@ def gemm_INT_kernel( tl.store(c_ptrs, acc, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + +# TMA descriptors require a global memory allocation +from typing import Optional +def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) +triton.set_allocator(alloc_fn) + +@triton.autotune( + configs = get_autotune_config(), + key = KEYS, + prune_configs_by = {'early_config_prune': kernel_config_pruner}, + use_cuda_graph = AUTOTUNE.USE_CUDA_GRAPH, +) +@triton.jit +def gemm_INT_kernel_persistent_tma( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, scales_a_ptr, + M, N, K, M_CLOSEST, + ######### Quant parms ######### + W_nbits: tl.constexpr, + group_size: tl.constexpr, + unpack_mask: tl.constexpr, + elements_per_sample: tl.constexpr, + ################################# + type_id: tl.constexpr, + a_sizeof: tl.constexpr, + b_sizeof: tl.constexpr, + ######### Strides ######### + stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bk: tl.constexpr, stride_bn: tl.constexpr, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, + stride_meta_a_m: tl.constexpr, stride_meta_a_g: tl.constexpr, + stride_meta_g: tl.constexpr, stride_meta_n: tl.constexpr, + ######### Dtypes ######### + load_scales_as_block: tl.constexpr, #False + input_dtype: tl.constexpr, + output_dtype: tl.constexpr, + acc_dtype: tl.constexpr, + meta_dtype: tl.constexpr, + ######### Meta-data mode ######### + channel_scale_mode: tl.constexpr, + W_group_mode: tl.constexpr, + zero_is_scalar: tl.constexpr, + ######### tuning params ######### + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_K: tl.constexpr, EVEN_N: tl.constexpr, + ################################# + A_load_order: tl.constexpr = 0, + data_contiguous: tl.constexpr = True, + ################################# + meta_evict_policy: tl.constexpr = '', + a_evict: tl.constexpr = '', + b_evict: tl.constexpr = '', + NUM_SMS: tl.constexpr = 8, +): + """ + Persistent + TMA version. + A: (M, K) fp16/bf16 + B_packed: (K//elements_per_sample, N) int32 + scales/zeros: (num_groups, N) or other depending on W_group_mode + """ + + # --------------------------- + # Persistent tiling setup + # --------------------------- + start_pid = tl.program_id(0).to(tl.int32) + + grid_m = tl.cdiv(M, BLOCK_SIZE_M) + grid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_tiles = grid_m * grid_n + width = GROUP_SIZE_M * grid_n # tiles per "group stripe" + + a_desc = tl.make_tensor_descriptor( + a_ptr, + [M, K], + [stride_am, stride_ak], + [BLOCK_SIZE_M, BLOCK_SIZE_K] + ) + + # b_desc = tl.make_tensor_descriptor( + # b_ptr, + # [K, N], + # [stride_bk, stride_bn], + # [BLOCK_SIZE_K, BLOCK_SIZE_N] + # ) + + #transposed : use self.W_q = self.W_q.contiguous().t() + b_desc = tl.make_tensor_descriptor( + b_ptr, + [N, K], + [stride_bn, stride_bk], + [BLOCK_SIZE_N, BLOCK_SIZE_K] + ) + + # # Precompute unpack shifts (vector length = elements_per_sample) + # # shifts = [0, W_nbits, 2*W_nbits, ...] + # shifts = (tl.arange(0, elements_per_sample) * W_nbits).to(tl.int32) + + # # Optional scalar zero + # if zero_is_scalar: + # zero_scalar = tl.load(zeros_ptr, eviction_policy="evict_last") + + ############################################################################################################# + # Main loop + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): + group_id = tile_id // width + first_m = group_id * GROUP_SIZE_M + gs = tl.minimum(grid_m - first_m, GROUP_SIZE_M) + + pid_m = first_m + (tile_id % gs) + pid_n = (tile_id % width) // gs + + rm = pid_m * BLOCK_SIZE_M + rn = pid_n * BLOCK_SIZE_N + + # Accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + # Column indices for this tile (used for metadata + store) + offs_n = rn + tl.arange(0, BLOCK_SIZE_N) + n_mask = offs_n < N + + # K loop + for k in tl.range(0, K, BLOCK_SIZE_K): + a = tl.load_tensor_descriptor(a_desc, [rm, k]) + + k_packed = k // elements_per_sample + #b = tl.load_tensor_descriptor(b_desc, [k_packed, rn]) + b = tl.load_tensor_descriptor(b_desc, [rn, k_packed]).T #Transposed + + acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype) + + ############################################################################################################# + # Channel-wise scaling + offs_m = rm + tl.arange(0, BLOCK_SIZE_M) + m_mask = offs_m < M + if channel_scale_mode == 1: # weight-only + # expects a 1D per-N scale at scales_ptr (same as your original) + scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) + acc = acc.to(meta_dtype) * scales_b[None, :] + + elif channel_scale_mode == 2: # activation-only + scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) + acc = acc.to(meta_dtype) * scales_a[:, None] + + elif channel_scale_mode == 3: # weight + activation + scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) + scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) + acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + + acc = acc.to(output_dtype) + + ############################################################################################################# + # Store + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :]) + @triton.autotune( configs = get_autotune_config(), key = KEYS, @@ -452,12 +645,13 @@ def gemm_MX_kernel( ######### tuning params ######### BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_K: tl.constexpr, EVEN_N: tl.constexpr, A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# - meta_evict_policy: tl.constexpr = '', - a_evict: tl.constexpr = '', - b_evict: tl.constexpr = '', + meta_evict_policy: tl.constexpr = "evict_last", + a_evict: tl.constexpr = "", + b_evict: tl.constexpr = "", meta_scale_norm: tl.constexpr = (0.05 ** 2), ): @@ -510,7 +704,11 @@ def gemm_MX_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k, num_stages=NUM_STAGES): - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + b = tl.load(b_ptrs, eviction_policy=b_evict) k_m = k * BLOCK_SIZE_K_S @@ -525,6 +723,9 @@ def gemm_MX_kernel( a_ptrs += BLOCK_SIZE_K_A * stride_ak b_ptrs += BLOCK_SIZE_K_B * stride_bk + + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) #NVFP4 meta-scale if(group_size == 16): @@ -554,7 +755,7 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] M_CLOSEST = get_closest_m(M) - + #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype]) @@ -600,6 +801,59 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x return output +# # Persistent version +# NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count +# def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, +# W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, +# input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, +# channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, +# ) -> Tensor: + +# M = x.shape[0] +# K = W_q.shape[0] * elements_per_sample +# N = W_q.shape[1] +# M_CLOSEST = get_closest_m(M) +# load_scales_as_block = False + +# output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype]) + +# if scales_x is not None: +# stride_meta_a_m, stride_meta_a_g = scales_x.stride(0), scales_x.stride(1) +# else: +# stride_meta_a_m, stride_meta_a_g = 0, 0 + +# grid = (NUM_SMS,) + +# gemm_INT_kernel_persistent_tma[grid]( +# x, W_q, output, +# scales, zeros, scales_x, +# M, N, K, M_CLOSEST, +# ############################################# +# W_nbits, group_size, unpack_mask, elements_per_sample, +# type_id, x.dtype.itemsize, W_q.dtype.itemsize, +# ############################################### +# x.stride(0), x.stride(1), +# W_q.stride(0), W_q.stride(1), +# output.stride(0), output.stride(1), +# stride_meta_a_m, stride_meta_a_g, +# scales.stride(0), scales.stride(1), +# ################################################ +# load_scales_as_block = load_scales_as_block, +# input_dtype = DTYPE_TO_TRITON[input_dtype], +# output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], +# acc_dtype = DTYPE_TO_TRITON[acc_dtype], +# meta_dtype = DTYPE_TO_TRITON[meta_dtype], +# ################################################ +# channel_scale_mode = channel_scale_mode, +# W_group_mode = W_group_mode, +# zero_is_scalar = zeros.numel() == 1, +# data_contiguous = data_contiguous, +# NUM_SMS = NUM_SMS, +# ) + + +# return output + class gemm: kernel = [gemm_INT_kernel, gemm_MX_kernel] From 2b9cd5960355459c0cc748515fda9db891b3bdf8 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 23 Feb 2026 07:05:53 -0800 Subject: [PATCH 02/63] update --- gemlite/core.py | 17 +++- gemlite/quant_utils.py | 41 ++++---- gemlite/triton_kernels/gemm_kernels.py | 130 +++++++++++++++++++++---- 3 files changed, 149 insertions(+), 39 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 30fa424..192dbf7 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -399,7 +399,7 @@ def pack( if(self.W_q is None): raise Exception('Weights were not packed, please check your W_q.dtype') - + #Bias / device self.device = self.W_q.device self.bias = None if (bias is None) else bias.to(device=self.device) @@ -495,6 +495,21 @@ def pack( self.scales = self.scales.T self.W_group_mode = 2 self.channel_scale_mode = 0 + + ################ + # TMA + K, N = self.W_q.shape + + if(self.input_dtype in [DType.MXFP4, DType.NVFP4]): + K *= 2 + group_size = 2 * self.W_q.numel() // self.scales.numel() + else: + group_size = self.W_q.numel() // self.scales.numel() + self.scales = self.scales.reshape(1, N // 128, K // group_size // 4, 2, 256).contiguous() + self.W_q = self.W_q.contiguous().T + + print(self.scales.stride(), self.scales.shape) + ################ if(self.scales is not None): self.meta_dtype = TORCH_TO_DTYPE[self.scales.dtype] diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index b0409bc..3abdc40 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -716,31 +716,34 @@ def scale_activations_mxfp8_triton_v2( pad_m = (group_size - M % group_size) % group_size M_padded = M + pad_m - out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + # out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) + # scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + + out = tensor.to(w_dtype) + scales = torch.full((M_padded, K // group_size), 120, device=tensor.device, dtype=torch.uint8) #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) BLOCK_SIZE_M = group_size grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(K, group_size)) device_index = tensor.device.index - scale_activations_mxfp8_triton_kernel_v2[grid]( - tensor, - out, - scales, - M, K, - tensor.stride(0), tensor.stride(1), - scales.stride(0), scales.stride(1), - out.stride(0), out.stride(1), - ######################### - min_val=min_val, - max_val=max_val, - eps_exp=eps_exp, - GROUP_SIZE=group_size, - BLOCK_SIZE_M=BLOCK_SIZE_M, - num_stages=2, - num_warps=4, - ) + # scale_activations_mxfp8_triton_kernel_v2[grid]( + # tensor, + # out, + # scales, + # M, K, + # tensor.stride(0), tensor.stride(1), + # scales.stride(0), scales.stride(1), + # out.stride(0), out.stride(1), + # ######################### + # min_val=min_val, + # max_val=max_val, + # eps_exp=eps_exp, + # GROUP_SIZE=group_size, + # BLOCK_SIZE_M=BLOCK_SIZE_M, + # num_stages=1, + # num_warps=4, + # ) return out, scales diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 5b19df3..47500d0 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -75,6 +75,14 @@ def kernel_config_pruner(configs, nargs, **kwargs): block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) + + + ###################################################### + if block_size_n % 128 > 0: + block_size_n = 128 + if block_size_k % 128 > 0: + block_size_k = 128 + ###################################################### #Hint: skip block_size_n > block_size_k for col-major non-packed data. @@ -616,7 +624,8 @@ def gemm_INT_kernel_persistent_tma( def gemm_MX_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, M_CLOSEST, + #M, N, K, M_CLOSEST, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, M_CLOSEST: tl.constexpr, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, @@ -627,9 +636,9 @@ def gemm_MX_kernel( a_sizeof: tl.constexpr, b_sizeof: tl.constexpr, ######### Strides ######### - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, + stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bk: tl.constexpr, stride_bn: tl.constexpr, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, stride_meta_a_m: tl.constexpr, stride_meta_a_g: tl.constexpr, stride_meta_n: tl.constexpr, stride_meta_g: tl.constexpr, ######### Dtypes ######### @@ -697,27 +706,92 @@ def gemm_MX_kernel( offs_k_scales = tl.arange(0, BLOCK_SIZE_K_S) offs_n_b_scales = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] + + + a_desc = tl.make_tensor_descriptor( + a_ptr, + [M, K // elements_per_sample_a], + [stride_am, stride_ak], + [BLOCK_SIZE_M, BLOCK_SIZE_K_A] + ) + + # Transposed + b_desc = tl.make_tensor_descriptor( + b_ptr, + [N, K // elements_per_sample], + [stride_bn, stride_bk], + [BLOCK_SIZE_N, BLOCK_SIZE_K_B] + ) + + # 2. 5D TMA Descriptors for Scales + rep_m: tl.constexpr = BLOCK_SIZE_M // 128 + rep_n: tl.constexpr = BLOCK_SIZE_N // 128 + rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 + + # shape_b1: tl.constexpr = N // 128 + # shape_b2: tl.constexpr = K // group_size // 4 + # stride_b4: tl.constexpr = 1 + # stride_b3: tl.constexpr = 256 + # stride_b2: tl.constexpr = 512 + # stride_b1: tl.constexpr = 512 * shape_b2 + # stride_b0: tl.constexpr = stride_b1 * shape_b1 + + #(8388608, 65536, 512, 256, 1) torch.Size([1, 128, 128, 2, 256]) + shape_b1: tl.constexpr = 128 + shape_b2: tl.constexpr = 128 + + stride_b0: tl.constexpr = 8388608 + stride_b1: tl.constexpr = 65536 + stride_b2: tl.constexpr = 512 + stride_b3: tl.constexpr = 256 + stride_b4: tl.constexpr = 1 + + # REQUIRES BLOCK_SIZE_K / BLOCK_SIZE_N to be multiples of 128 + scales_b_desc = tl.make_tensor_descriptor( + scales_ptr, + [1, shape_b1, shape_b2, 2, 256], + [stride_b0, stride_b1, stride_b2, stride_b3, stride_b4], + [1, rep_n, rep_k, 2, 256] + ) + #B-scales if(channel_scale_mode == 4): scales_a_ptrs = scales_a_ptr + offs_am[:, None] * stride_meta_a_m + offs_k_scales[None, :] * stride_meta_a_g + + # Used in channel-wise MXPF8 version + #scales_b = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k, num_stages=NUM_STAGES): - if EVEN_M and EVEN_K: - a = tl.load(a_ptrs, eviction_policy=a_evict) - else: - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + #for k in tl.range(num_pid_k): + # if EVEN_M and EVEN_K: + # a = tl.load(a_ptrs, eviction_policy=a_evict) + # else: + # a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) + # b = tl.load(b_ptrs, eviction_policy=b_evict) + + a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_A]) + b = tl.load_tensor_descriptor(b_desc, [k * BLOCK_SIZE_K_B, pid_n * BLOCK_SIZE_N]).T k_m = k * BLOCK_SIZE_K_S - scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) - + #scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + + # 5D Scale Loads and Unpacking + offs_scale_m = pid_m * rep_m + offs_scale_n = pid_n * rep_n + offs_scale_k = k * rep_k + + scale_b = tl.load_tensor_descriptor(scales_b_desc, [0, offs_scale_n, offs_scale_k, 0, 0]) + scales_b = scale_b.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) + #https://github.com/triton-lang/triton/blob/main/python/tutorials/10-block-scaled-matmul.py#L220C1-L221C117 + if(channel_scale_mode == 4): scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) else: - scales_a = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + scales_a = scales_a_1s acc = tl.dot_scaled(a, scales_a, a_dtype, b, scales_b, b_dtype, acc) @@ -740,19 +814,32 @@ def gemm_MX_kernel( acc = acc.to(dtype) * (scales_a[:, None] * scales_b[None, :]) ############################################################################################################# - #Output - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) - tl.store(c_ptrs, acc, mask=mask) + # #Output + # offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + # offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + # c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) + # mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) + # tl.store(c_ptrs, acc, mask=mask) + + c_desc = tl.make_tensor_descriptor( + c_ptr, + [M, N], + [stride_cm, stride_cn], + [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], value=acc) + + +PRINTED = False def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, ) -> Tensor: + + global PRINTED M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] M_CLOSEST = get_closest_m(M) @@ -773,7 +860,7 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x gemm_kernel = gemm_INT_kernel load_scales_as_block = False - gemm_kernel[grid]( + compiled_kernel = gemm_kernel[grid]( x, W_q, output, scales, zeros, scales_x, M, N, K, M_CLOSEST, @@ -798,6 +885,11 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, ) + + if PRINTED == False: + with open('kernel.ptx', 'w') as f: + f.write(compiled_kernel.asm['ptx']) + PRINTED = True return output From baeda2365a646c29f56930a19dd7ff2079900f3c Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 23 Feb 2026 08:13:20 -0800 Subject: [PATCH 03/63] update --- gemlite/core.py | 27 +++--- gemlite/quant_utils.py | 80 ++++++++-------- gemlite/triton_kernels/gemm_kernels.py | 127 +++++++++++++------------ 3 files changed, 123 insertions(+), 111 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 192dbf7..afff0f7 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -496,20 +496,25 @@ def pack( self.W_group_mode = 2 self.channel_scale_mode = 0 - ################ + ################################ # TMA - K, N = self.W_q.shape + # K, N = self.W_q.shape - if(self.input_dtype in [DType.MXFP4, DType.NVFP4]): - K *= 2 - group_size = 2 * self.W_q.numel() // self.scales.numel() - else: - group_size = self.W_q.numel() // self.scales.numel() - self.scales = self.scales.reshape(1, N // 128, K // group_size // 4, 2, 256).contiguous() - self.W_q = self.W_q.contiguous().T + # if(self.input_dtype in [DType.MXFP4, DType.NVFP4]): + # K *= 2 + # group_size = 2 * self.W_q.numel() // self.scales.numel() + # else: + # group_size = self.W_q.numel() // self.scales.numel() - print(self.scales.stride(), self.scales.shape) - ################ + # #self.scales = self.scales.contiguous().T # Transposed 2D TMA layout + # #self.scales = self.scales.reshape(1, N // 128, K // group_size // 4, 2, 256).contiguous() # 5D TMA layout for the scales: + + #self.W_q = self.W_q.contiguous().T #Transposed for tma + + #self.W_q = self.W_q.contiguous() + + #print(self.scales.stride(), self.scales.shape) + ################################ if(self.scales is not None): self.meta_dtype = TORCH_TO_DTYPE[self.scales.dtype] diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 3abdc40..ed8c618 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -716,34 +716,31 @@ def scale_activations_mxfp8_triton_v2( pad_m = (group_size - M % group_size) % group_size M_padded = M + pad_m - # out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) - # scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) - out = tensor.to(w_dtype) - scales = torch.full((M_padded, K // group_size), 120, device=tensor.device, dtype=torch.uint8) - #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) BLOCK_SIZE_M = group_size grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(K, group_size)) device_index = tensor.device.index - # scale_activations_mxfp8_triton_kernel_v2[grid]( - # tensor, - # out, - # scales, - # M, K, - # tensor.stride(0), tensor.stride(1), - # scales.stride(0), scales.stride(1), - # out.stride(0), out.stride(1), - # ######################### - # min_val=min_val, - # max_val=max_val, - # eps_exp=eps_exp, - # GROUP_SIZE=group_size, - # BLOCK_SIZE_M=BLOCK_SIZE_M, - # num_stages=1, - # num_warps=4, - # ) + scale_activations_mxfp8_triton_kernel_v2[grid]( + tensor, + out, + scales, + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + ######################### + min_val=min_val, + max_val=max_val, + eps_exp=eps_exp, + GROUP_SIZE=group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + num_stages=1, + num_warps=4, + ) return out, scales @@ -989,8 +986,11 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: pad_m = (group_size - M % group_size) % group_size M_padded = M + pad_m - out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + # out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + # scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + + out = tensor.to(torch.uint8)[:M, :K // 2] + scales = torch.full((M_padded, K // group_size), 120, device=tensor.device, dtype=torch.uint8) #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) BLOCK_SIZE_M = group_size @@ -998,22 +998,22 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(K, group_size)) device_index = tensor.device.index - scale_activations_mxfp4_triton_kernel_v2[grid]( - tensor, - out, - scales, - thr_pos[device_index], - M, K, - tensor.stride(0), tensor.stride(1), - scales.stride(0), scales.stride(1), - out.stride(0), out.stride(1), - ######################### - eps_exp=eps_exp, - GROUP_SIZE=group_size, - BLOCK_SIZE_M=BLOCK_SIZE_M, - num_stages=2, - num_warps=4, - ) + # scale_activations_mxfp4_triton_kernel_v2[grid]( + # tensor, + # out, + # scales, + # thr_pos[device_index], + # M, K, + # tensor.stride(0), tensor.stride(1), + # scales.stride(0), scales.stride(1), + # out.stride(0), out.stride(1), + # ######################### + # eps_exp=eps_exp, + # GROUP_SIZE=group_size, + # BLOCK_SIZE_M=BLOCK_SIZE_M, + # num_stages=2, + # num_warps=4, + # ) return out, scales diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 47500d0..1238c1a 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -78,10 +78,11 @@ def kernel_config_pruner(configs, nargs, **kwargs): ###################################################### - if block_size_n % 128 > 0: - block_size_n = 128 - if block_size_k % 128 > 0: - block_size_k = 128 + # # FOR TMA + # if block_size_n % 128 > 0: + # block_size_n = 128 + # if block_size_k % 128 > 0: + # block_size_k = 128 ###################################################### #Hint: skip block_size_n > block_size_k for col-major non-packed data. @@ -108,8 +109,9 @@ def kernel_config_pruner(configs, nargs, **kwargs): if(num_stages == 0): continue #config too large ########################################### - # if(load_scales_as_block):#tmp MXFP fix - # block_size_k = min(block_size_k, 256) + if(load_scales_as_block):#tmp MXFP fix + block_size_k = max(block_size_k, 64) + block_size_n = max(block_size_n, 64) ########################################### key = (block_size_m, block_size_n, block_size_k, group_size_m, A_load_order, num_stages, num_warps) @@ -188,7 +190,9 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) @@ -708,52 +712,47 @@ def gemm_MX_kernel( scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] - a_desc = tl.make_tensor_descriptor( - a_ptr, - [M, K // elements_per_sample_a], - [stride_am, stride_ak], - [BLOCK_SIZE_M, BLOCK_SIZE_K_A] - ) + # a_desc = tl.make_tensor_descriptor( + # a_ptr, + # [M, K // elements_per_sample_a], + # [stride_am, stride_ak], + # [BLOCK_SIZE_M, BLOCK_SIZE_K_A] + # ) - # Transposed - b_desc = tl.make_tensor_descriptor( - b_ptr, - [N, K // elements_per_sample], - [stride_bn, stride_bk], - [BLOCK_SIZE_N, BLOCK_SIZE_K_B] - ) + # # Transposed + # b_desc = tl.make_tensor_descriptor( + # b_ptr, + # [N, K // elements_per_sample], + # [stride_bn, stride_bk], + # [BLOCK_SIZE_N, BLOCK_SIZE_K_B] + # ) - # 2. 5D TMA Descriptors for Scales - rep_m: tl.constexpr = BLOCK_SIZE_M // 128 - rep_n: tl.constexpr = BLOCK_SIZE_N // 128 - rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 - - # shape_b1: tl.constexpr = N // 128 - # shape_b2: tl.constexpr = K // group_size // 4 + # # 2. 5D TMA Descriptors for Scales: #(8388608, 65536, 512, 256, 1) torch.Size([1, 128, 128, 2, 256]) + # rep_m: tl.constexpr = BLOCK_SIZE_M // 128 + # rep_n: tl.constexpr = BLOCK_SIZE_N // 128 + # rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 + # scales_b_shape1: tl.constexpr = N // 128 + # scales_b_shape2: tl.constexpr = K // group_size // 4 # stride_b4: tl.constexpr = 1 # stride_b3: tl.constexpr = 256 # stride_b2: tl.constexpr = 512 - # stride_b1: tl.constexpr = 512 * shape_b2 - # stride_b0: tl.constexpr = stride_b1 * shape_b1 - - #(8388608, 65536, 512, 256, 1) torch.Size([1, 128, 128, 2, 256]) - shape_b1: tl.constexpr = 128 - shape_b2: tl.constexpr = 128 - - stride_b0: tl.constexpr = 8388608 - stride_b1: tl.constexpr = 65536 - stride_b2: tl.constexpr = 512 - stride_b3: tl.constexpr = 256 - stride_b4: tl.constexpr = 1 - - # REQUIRES BLOCK_SIZE_K / BLOCK_SIZE_N to be multiples of 128 - scales_b_desc = tl.make_tensor_descriptor( - scales_ptr, - [1, shape_b1, shape_b2, 2, 256], - [stride_b0, stride_b1, stride_b2, stride_b3, stride_b4], - [1, rep_n, rep_k, 2, 256] - ) + # stride_b1: tl.constexpr = 512 * scales_b_shape2 + # stride_b0: tl.constexpr = stride_b1 * scales_b_shape1 + # # REQUIRES BLOCK_SIZE_K / BLOCK_SIZE_N to be multiples of 128 + # scales_b_desc = tl.make_tensor_descriptor( + # scales_ptr, + # [1, scales_b_shape1, scales_b_shape2, 2, 256], + # [stride_b0, stride_b1, stride_b2, stride_b3, stride_b4], + # [1, rep_n, rep_k, 2, 256] + # ) + # # 2D TMA - transposed + # scales_b_desc = tl.make_tensor_descriptor( + # scales_ptr, + # [K // group_size, N], + # [stride_meta_g, stride_meta_n], + # [BLOCK_SIZE_K_S, BLOCK_SIZE_N], + # ) #B-scales if(channel_scale_mode == 4): @@ -766,32 +765,40 @@ def gemm_MX_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k, num_stages=NUM_STAGES): #for k in tl.range(num_pid_k): - # if EVEN_M and EVEN_K: - # a = tl.load(a_ptrs, eviction_policy=a_evict) - # else: - # a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - # b = tl.load(b_ptrs, eviction_policy=b_evict) + b = tl.load(b_ptrs, eviction_policy=b_evict) - a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_A]) - b = tl.load_tensor_descriptor(b_desc, [k * BLOCK_SIZE_K_B, pid_n * BLOCK_SIZE_N]).T + #a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_A]) + #b = tl.load_tensor_descriptor(b_desc, [k * BLOCK_SIZE_K_B, pid_n * BLOCK_SIZE_N]).T k_m = k * BLOCK_SIZE_K_S - #scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + + #################################################################################### + # NO TMA + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + + # # 2D TMA + # scales_b = tl.load_tensor_descriptor(scales_b_desc, [k * BLOCK_SIZE_K_S, pid_n * BLOCK_SIZE_N]).T + # 5D Scale Loads and Unpacking - offs_scale_m = pid_m * rep_m - offs_scale_n = pid_n * rep_n - offs_scale_k = k * rep_k + # offs_scale_m = pid_m * rep_m + # offs_scale_n = pid_n * rep_n + # offs_scale_k = k * rep_k - scale_b = tl.load_tensor_descriptor(scales_b_desc, [0, offs_scale_n, offs_scale_k, 0, 0]) - scales_b = scale_b.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) + #scale_b = tl.load_tensor_descriptor(scales_b_desc, [0, offs_scale_n, offs_scale_k, 0, 0]) + #scales_b = scale_b.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) #https://github.com/triton-lang/triton/blob/main/python/tutorials/10-block-scaled-matmul.py#L220C1-L221C117 if(channel_scale_mode == 4): scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) else: scales_a = scales_a_1s + #################################################################################### acc = tl.dot_scaled(a, scales_a, a_dtype, b, scales_b, b_dtype, acc) From 0dc7d0e2253c5c13533adf573203c72381fb26e5 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 23 Feb 2026 08:22:39 -0800 Subject: [PATCH 04/63] update --- gemlite/quant_utils.py | 59 +++++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index ed8c618..63376bc 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -952,6 +952,25 @@ def scale_activations_mxfp4_triton_kernel_v2( #Load mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + + ################################################# + # 1. Device-Side TMA Descriptors + # tensor_desc = tl.make_tensor_descriptor( + # tensor_ptr, + # [M, K], + # [stride_m_t, stride_k_t], + # [BLOCK_SIZE_M, GROUP_SIZE] + # ) + + # out_desc = tl.make_tensor_descriptor( + # out_ptr, + # [M, K // 2], + # [stride_m_o, stride_k_o], + # [BLOCK_SIZE_M, HALF_GROUP_SIZE] + # ) + + #tensor = tl.load_tensor_descriptor(tensor_desc, [pid_m * BLOCK_SIZE_M, pid_k * GROUP_SIZE]).to(tl.float32) + ################################################# tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) #next power of 2 via log @@ -970,6 +989,7 @@ def scale_activations_mxfp4_triton_kernel_v2( offs_k = pid_k * HALF_GROUP_SIZE + tl.arange(0, HALF_GROUP_SIZE) out_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < (K // 2))).to(tl.int1) tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask) + #tl.store_tensor_descriptor(out_desc, [pid_m * BLOCK_SIZE_M, pid_k * HALF_GROUP_SIZE], out) offs_k = pid_k * 1 + tl.arange(0, 1) tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales_log2) @@ -986,34 +1006,31 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: pad_m = (group_size - M % group_size) % group_size M_padded = M + pad_m - # out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - # scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) - out = tensor.to(torch.uint8)[:M, :K // 2] - scales = torch.full((M_padded, K // group_size), 120, device=tensor.device, dtype=torch.uint8) - #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) BLOCK_SIZE_M = group_size grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(K, group_size)) device_index = tensor.device.index - # scale_activations_mxfp4_triton_kernel_v2[grid]( - # tensor, - # out, - # scales, - # thr_pos[device_index], - # M, K, - # tensor.stride(0), tensor.stride(1), - # scales.stride(0), scales.stride(1), - # out.stride(0), out.stride(1), - # ######################### - # eps_exp=eps_exp, - # GROUP_SIZE=group_size, - # BLOCK_SIZE_M=BLOCK_SIZE_M, - # num_stages=2, - # num_warps=4, - # ) + scale_activations_mxfp4_triton_kernel_v2[grid]( + tensor, + out, + scales, + thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + ######################### + eps_exp=eps_exp, + GROUP_SIZE=group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + num_stages=1, + num_warps=4, + ) return out, scales From da98055cb1850f343a3efdf1b4109b24e31a2f0a Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 24 Feb 2026 08:00:18 -0800 Subject: [PATCH 05/63] update --- gemlite/core.py | 12 +- gemlite/quant_utils.py | 293 ++++++++++++------------- gemlite/triton_kernels/gemm_kernels.py | 136 ++++++------ tests/test_gemlitelineartriton.py | 2 +- tests/test_mxfp.py | 29 ++- 5 files changed, 240 insertions(+), 232 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index afff0f7..6977fa5 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -497,7 +497,7 @@ def pack( self.channel_scale_mode = 0 ################################ - # TMA + # # TMA # K, N = self.W_q.shape # if(self.input_dtype in [DType.MXFP4, DType.NVFP4]): @@ -506,14 +506,12 @@ def pack( # else: # group_size = self.W_q.numel() // self.scales.numel() - # #self.scales = self.scales.contiguous().T # Transposed 2D TMA layout - # #self.scales = self.scales.reshape(1, N // 128, K // group_size // 4, 2, 256).contiguous() # 5D TMA layout for the scales: - - #self.W_q = self.W_q.contiguous().T #Transposed for tma + # self.W_q = self.W_q.contiguous().T #Transposed for tma - #self.W_q = self.W_q.contiguous() + # #self.scales = self.scales.contiguous().T # Transposed 2D TMA layout + # #self.scales = self.scales.reshape(1, N // 128, K // group_size // 4, 2, 256).contiguous() # 5D TMA layout for the scales: - #print(self.scales.stride(), self.scales.shape) + # #print(self.scales.stride(), self.scales.shape) ################################ if(self.scales is not None): diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 63376bc..c9586cf 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -18,6 +18,7 @@ def get_dtype_range(compute_dtype: torch.dtype) -> float: dtype_info = torch.iinfo(compute_dtype) return dtype_info.min, dtype_info.max +NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count NVFP4_META_SCALE = 0.05 #Temporary NVFP logic #################################################################################################################### #MXFP4 / NVFP4 weight quantizer @@ -265,165 +266,158 @@ def round_triton_amd(tensor): else: round_triton = round_triton_nvidia -# @triton.jit -# def scale_activations_per_token_kernel( -# tensor_ptr, scale_ptr, y_ptr, -# M, K, -# stride_m, stride_k, stride_sm, -# ROUND: tl.constexpr, -# UNROLL: tl.constexpr, -# min_val: tl.constexpr, -# max_val: tl.constexpr, -# fp32_scale: tl.constexpr, -# BLOCK_M: tl.constexpr, -# BLOCK_K: tl.constexpr, -# ): -# pid_m = tl.program_id(0) * UNROLL -# pid_k = tl.program_id(1) - -# offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) -# offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - -# for m in range(UNROLL): -# mask = ((offs_m < M)[:, None] & (offs_k < K)[None, :]).to(tl.int1) -# in_ptrs = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k -# tensor = tl.load(tensor_ptr + in_ptrs, mask=mask, other=0.0) -# if fp32_scale: -# tensor = tensor.to(tl.float32) - -# scales_x = tl.max(tl.abs(tensor), axis=1, keep_dims=True) -# scales_x /= max_val -# scales_x = tl.maximum(scales_x, 1e-6) -# tensor /= scales_x -# tensor = tl.minimum(tl.maximum(tensor, min_val), max_val) - -# if ROUND: -# tensor = round_triton(tensor) - -# tl.store(scale_ptr + offs_m[:, None] * stride_sm, scales_x) -# tl.store(y_ptr + in_ptrs, tensor, mask=mask) -# offs_m += BLOCK_M - -# def scale_activations_per_token_triton( -# tensor: Tensor, w_dtype: torch.dtype, fp32_scale: bool = True -# ) -> Tuple[Tensor, Tensor]: -# min_val, max_val = get_dtype_range(w_dtype) -# x_shape = tensor.shape -# tensor = tensor.view(-1, tensor.shape[-1]) -# M, K = tensor.shape -# scales = torch.empty( -# (M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device -# ) -# y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) - -# UNROLL = 1 # max(1, M // 128) -# BLOCK_M = 1 -# BLOCK_K = triton.next_power_of_2(K) -# grid = (triton.cdiv(M, BLOCK_M * UNROLL), triton.cdiv(K, BLOCK_K)) - -# ROUND = not w_dtype.is_floating_point - -# scale_activations_per_token_kernel[grid]( -# tensor, -# scales, -# y, -# M, -# K, -# tensor.stride(0), -# tensor.stride(1), -# scales.stride(0), -# min_val=min_val, -# max_val=max_val, -# fp32_scale=fp32_scale, -# ROUND=ROUND, -# UNROLL=UNROLL, -# BLOCK_M=BLOCK_M, -# BLOCK_K=BLOCK_K, -# num_stages=1, -# num_warps=4, -# ) - -# return y.view(x_shape), scales - - -# from typing import Tuple - -# @triton.autotune( -# configs=[ -# triton.Config({'BLOCK_M': 1}, num_warps=8, num_stages=1), -# triton.Config({'BLOCK_M': 2}, num_warps=8, num_stages=1), -# triton.Config({'BLOCK_M': 4}, num_warps=8, num_stages=1), -# ], -# key=['M', 'K'] -# ) -# @triton.jit -# def scale_activations_single_pass_kernel( -# tensor_ptr, scale_ptr, y_ptr, -# M, K, -# stride_m, stride_k, stride_sm, -# min_val: tl.constexpr, max_val: tl.constexpr, -# fp32_scale: tl.constexpr, ROUND: tl.constexpr, -# BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, -# ): -# pid_m = tl.program_id(0) - -# offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) -# offs_k = tl.arange(0, BLOCK_K) -# m_mask = offs_m < M -# k_mask = offs_k < K -# mask = m_mask[:, None] & k_mask[None, :] +@triton.jit +def scale_activations_per_token_triton_v1_kernel( + tensor_ptr, scale_ptr, y_ptr, + M, K, + stride_m, stride_k, stride_sm, + ROUND: tl.constexpr, + UNROLL: tl.constexpr, + min_val: tl.constexpr, + max_val: tl.constexpr, + fp32_scale: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_m = tl.program_id(0) * UNROLL + pid_k = tl.program_id(1) -# offsets = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k + offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) -# tensor = tl.load(tensor_ptr + offsets, mask=mask, other=0.0) - -# if fp32_scale: -# tensor = tensor.to(tl.float32) + for m in range(UNROLL): + mask = ((offs_m < M)[:, None] & (offs_k < K)[None, :]).to(tl.int1) + in_ptrs = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k + tensor = tl.load(tensor_ptr + in_ptrs, mask=mask, other=0.0) + if fp32_scale: + tensor = tensor.to(tl.float32) -# scales_x = tl.max(tl.abs(tensor), axis=1) / max_val -# scales_x = tl.maximum(scales_x, 1e-6) -# tensor = tensor / scales_x[:, None] -# tensor = tl.minimum(tl.maximum(tensor, min_val), max_val) + scales_x = tl.max(tl.abs(tensor), axis=1, keep_dims=True) + scales_x /= max_val + scales_x = tl.maximum(scales_x, 1e-6) + tensor /= scales_x + tensor = tl.minimum(tl.maximum(tensor, min_val), max_val) -# if ROUND: -# tensor = round_triton(tensor) + if ROUND: + tensor = round_triton(tensor) -# tl.store(scale_ptr + offs_m * stride_sm, scales_x, mask=m_mask) -# tl.store(y_ptr + offsets, tensor, mask=mask) + tl.store(scale_ptr + offs_m[:, None] * stride_sm, scales_x) + tl.store(y_ptr + in_ptrs, tensor, mask=mask) + offs_m += BLOCK_M -# def scale_activations_per_token_triton( -# tensor: torch.Tensor, w_dtype: torch.dtype, fp32_scale: bool = True -# ) -> Tuple[torch.Tensor, torch.Tensor]: - -# min_val, max_val = get_dtype_range(w_dtype) +def scale_activations_per_token_triton_v1( + tensor: Tensor, w_dtype: torch.dtype, fp32_scale: bool = True +) -> Tuple[Tensor, Tensor]: + min_val, max_val = get_dtype_range(w_dtype) + x_shape = tensor.shape + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + scales = torch.empty( + (M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device + ) + y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) + + UNROLL = 1 # max(1, M // 128) + BLOCK_M = 1 + BLOCK_K = triton.next_power_of_2(K) + grid = (triton.cdiv(M, BLOCK_M * UNROLL), triton.cdiv(K, BLOCK_K)) + + ROUND = not w_dtype.is_floating_point + + scale_activations_per_token_triton_v1_kernel[grid]( + tensor, + scales, + y, + M, + K, + tensor.stride(0), + tensor.stride(1), + scales.stride(0), + min_val=min_val, + max_val=max_val, + fp32_scale=fp32_scale, + ROUND=ROUND, + UNROLL=UNROLL, + BLOCK_M=BLOCK_M, + BLOCK_K=BLOCK_K, + num_stages=1, + num_warps=4, + ) + + return y.view(x_shape), scales + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 1}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_M': 2}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_M': 4}, num_warps=8, num_stages=1), + ], + key=['M', 'K'] +) +@triton.jit +def scale_activations_per_token_triton_v2_kernel( + tensor_ptr, scale_ptr, y_ptr, + M, K, + stride_m, stride_k, stride_sm, + min_val: tl.constexpr, max_val: tl.constexpr, + fp32_scale: tl.constexpr, ROUND: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, +): + pid_m = tl.program_id(0) -# x_shape = tensor.shape -# tensor = tensor.view(-1, tensor.shape[-1]) -# M, K = tensor.shape + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_K) + m_mask = offs_m < M + k_mask = offs_k < K + mask = m_mask[:, None] & k_mask[None, :] + + offsets = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k + + tensor = tl.load(tensor_ptr + offsets, mask=mask, other=0.0) -# scales = torch.empty((M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device) -# y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) + if fp32_scale: + tensor = tensor.to(tl.float32) -# grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), ) + scales_x = tl.max(tl.abs(tensor), axis=1) / max_val + scales_x = tl.maximum(scales_x, 1e-6) + tensor = tensor / scales_x[:, None] + tensor = tl.minimum(tl.maximum(tensor, min_val), max_val) -# BLOCK_K = triton.next_power_of_2(K) -# ROUND = not w_dtype.is_floating_point + if ROUND: + tensor = round_triton(tensor) -# scale_activations_single_pass_kernel[grid]( -# tensor, scales, y, -# M, K, -# tensor.stride(0), tensor.stride(1), -# scales.stride(0), -# min_val=min_val, max_val=max_val, -# fp32_scale=fp32_scale, ROUND=ROUND, -# BLOCK_K=BLOCK_K -# ) + tl.store(scale_ptr + offs_m * stride_sm, scales_x, mask=m_mask) + tl.store(y_ptr + offsets, tensor, mask=mask) -# return y.view(x_shape), scales +def scale_activations_per_token_triton_v2( + tensor: torch.Tensor, w_dtype: torch.dtype, fp32_scale: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + + min_val, max_val = get_dtype_range(w_dtype) + + x_shape = tensor.shape + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + scales = torch.empty((M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device) + y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), ) + BLOCK_K = triton.next_power_of_2(K) + ROUND = not w_dtype.is_floating_point -from typing import Tuple + scale_activations_per_token_triton_v2_kernel[grid]( + tensor, scales, y, + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), + min_val=min_val, max_val=max_val, + fp32_scale=fp32_scale, ROUND=ROUND, + BLOCK_K=BLOCK_K + ) + + return y.view(x_shape), scales @triton.autotune( configs=[ @@ -434,7 +428,7 @@ def round_triton_amd(tensor): key=['M', 'K'] ) @triton.jit -def scale_activations_persistent_kernel( +def scale_activations_per_token_triton_v3_kernel( tensor_ptr, scale_ptr, y_ptr, M, K, stride_m, stride_k, stride_sm, @@ -471,10 +465,7 @@ def scale_activations_persistent_kernel( tl.store(y_ptr + offsets, tensor, mask=mask) tl.store(scale_ptr + offs_m * stride_sm, scales_x, mask=m_mask) - - -NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count -def scale_activations_per_token_triton( +def scale_activations_per_token_triton_v3( tensor: torch.Tensor, w_dtype: torch.dtype, fp32_scale: bool = True ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -492,7 +483,7 @@ def scale_activations_per_token_triton( BLOCK_K = triton.next_power_of_2(K) ROUND = not w_dtype.is_floating_point - scale_activations_persistent_kernel[grid]( + scale_activations_per_token_triton_v3_kernel[grid]( tensor, scales, y, M, K, tensor.stride(0), tensor.stride(1), @@ -1132,7 +1123,7 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso return out, scales #################################################################################################################### -scale_activations_per_token = scale_activations_per_token_triton +scale_activations_per_token = scale_activations_per_token_triton_v3 scale_activations_mxfp8 = scale_activations_mxfp8_triton_v2 scale_activations_mxfp4 = scale_activations_mxfp4_triton_v2 scale_activations_nvfp4 = scale_activations_nvfp4_triton_v2 diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 1238c1a..e172d51 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -23,7 +23,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): t = nargs['type_id'] a_sizeof = nargs['a_sizeof'] b_sizeof = nargs['b_sizeof'] - + #Check cache if(MATMUL_TYPE in GEMLITE_TRITON_CONFIG_CACHE): signature = str(tuple([get_closest_m(m), n, k, g, e, t])) @@ -38,6 +38,10 @@ def kernel_config_pruner(configs, nargs, **kwargs): config.pop('reg_dec_producer', None) config.pop('reg_inc_consumer', None) config["NUM_STAGES"] = num_stages + + config['EVEN_M'] = (m % config['BLOCK_SIZE_M'] == 0) + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) yield triton.Config(config, num_stages=num_stages, num_warps=num_warps) return @@ -76,7 +80,6 @@ def kernel_config_pruner(configs, nargs, **kwargs): block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) - ###################################################### # # FOR TMA # if block_size_n % 128 > 0: @@ -109,25 +112,25 @@ def kernel_config_pruner(configs, nargs, **kwargs): if(num_stages == 0): continue #config too large ########################################### - if(load_scales_as_block):#tmp MXFP fix + if(load_scales_as_block):#tmp MXFP fix with TMA block_size_k = max(block_size_k, 64) block_size_n = max(block_size_n, 64) ########################################### key = (block_size_m, block_size_n, block_size_k, group_size_m, A_load_order, num_stages, num_warps) - EVEN_M = (m % block_size_m == 0) - EVEN_N = (n % block_size_n == 0) - EVEN_K = (k % block_size_k == 0) + even_m = (m % block_size_m == 0) + even_n = (n % block_size_n == 0) + even_k = (k % block_size_k == 0) new_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": group_size_m, - "EVEN_M": EVEN_M, - "EVEN_N": EVEN_N, - "EVEN_K": EVEN_K, + "EVEN_M": even_m, + "EVEN_N": even_n, + "EVEN_K": even_k, "A_load_order": A_load_order, "NUM_STAGES": num_stages, } @@ -200,7 +203,7 @@ def get_fast_autotune_config_nvidia(): return configs def get_default_config_nvidia(): - return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':4}, num_warps=4, num_stages=4),] + return [triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':4}, num_warps=4, num_stages=4),] ######################################################################################################################################################################## #AMD - Instinct MI300X @@ -249,7 +252,7 @@ def get_fast_autotune_config_amd(): return configs def get_default_config_amd(): - return [triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2),] + return [triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2),] ######################################################################################################################################################################## if IS_HIP: @@ -308,10 +311,13 @@ def gemm_INT_kernel( ######### tuning params ######### BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, - EVEN_M: tl.constexpr, EVEN_K: tl.constexpr, EVEN_N: tl.constexpr, A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = False, + EVEN_N: tl.constexpr = False, + ################################# meta_evict_policy: tl.constexpr = "evict_last", a_evict: tl.constexpr = "", b_evict: tl.constexpr = "evict_first", @@ -506,7 +512,10 @@ def gemm_INT_kernel_persistent_tma( ######### tuning params ######### BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, - EVEN_M: tl.constexpr, EVEN_K: tl.constexpr, EVEN_N: tl.constexpr, + ################################# + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = False, + EVEN_N: tl.constexpr = False, ################################# A_load_order: tl.constexpr = 0, data_contiguous: tl.constexpr = True, @@ -658,14 +667,19 @@ def gemm_MX_kernel( ######### tuning params ######### BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, - EVEN_M: tl.constexpr, EVEN_K: tl.constexpr, EVEN_N: tl.constexpr, A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = False, + EVEN_N: tl.constexpr = False, + ################################# meta_evict_policy: tl.constexpr = "evict_last", a_evict: tl.constexpr = "", b_evict: tl.constexpr = "", meta_scale_norm: tl.constexpr = (0.05 ** 2), + ################################# + use_tma: tl.constexpr = False, ): pid = tl.program_id(axis=0) @@ -712,20 +726,21 @@ def gemm_MX_kernel( scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] - # a_desc = tl.make_tensor_descriptor( - # a_ptr, - # [M, K // elements_per_sample_a], - # [stride_am, stride_ak], - # [BLOCK_SIZE_M, BLOCK_SIZE_K_A] - # ) - - # # Transposed - # b_desc = tl.make_tensor_descriptor( - # b_ptr, - # [N, K // elements_per_sample], - # [stride_bn, stride_bk], - # [BLOCK_SIZE_N, BLOCK_SIZE_K_B] - # ) + if use_tma: + a_desc = tl.make_tensor_descriptor( + a_ptr, + [M, K // elements_per_sample_a], + [stride_am, stride_ak], + [BLOCK_SIZE_M, BLOCK_SIZE_K_A] + ) + + # Transposed + b_desc = tl.make_tensor_descriptor( + b_ptr, + [N, K // elements_per_sample], + [stride_bn, stride_bk], + [BLOCK_SIZE_N, BLOCK_SIZE_K_B] + ) # # 2. 5D TMA Descriptors for Scales: #(8388608, 65536, 512, 256, 1) torch.Size([1, 128, 128, 2, 256]) # rep_m: tl.constexpr = BLOCK_SIZE_M // 128 @@ -759,29 +774,26 @@ def gemm_MX_kernel( scales_a_ptrs = scales_a_ptr + offs_am[:, None] * stride_meta_a_m + offs_k_scales[None, :] * stride_meta_a_g # Used in channel-wise MXPF8 version - #scales_b = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k, num_stages=NUM_STAGES): - #for k in tl.range(num_pid_k): - if EVEN_M and EVEN_K: - a = tl.load(a_ptrs, eviction_policy=a_evict) + # Load A and B tiles + if use_tma: + a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_A]) + b = tl.load_tensor_descriptor(b_desc, [k * BLOCK_SIZE_K_B, pid_n * BLOCK_SIZE_N]).T else: - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - - b = tl.load(b_ptrs, eviction_policy=b_evict) - - #a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_A]) - #b = tl.load_tensor_descriptor(b_desc, [k * BLOCK_SIZE_K_B, pid_n * BLOCK_SIZE_N]).T + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - k_m = k * BLOCK_SIZE_K_S - + b = tl.load(b_ptrs, eviction_policy=b_evict) #################################################################################### + k_m = k * BLOCK_SIZE_K_S # NO TMA scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) - # # 2D TMA # scales_b = tl.load_tensor_descriptor(scales_b_desc, [k * BLOCK_SIZE_K_S, pid_n * BLOCK_SIZE_N]).T @@ -802,11 +814,11 @@ def gemm_MX_kernel( acc = tl.dot_scaled(a, scales_a, a_dtype, b, scales_b, b_dtype, acc) - a_ptrs += BLOCK_SIZE_K_A * stride_ak - b_ptrs += BLOCK_SIZE_K_B * stride_bk - - if not EVEN_K: - a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) + if not use_tma: + a_ptrs += BLOCK_SIZE_K_A * stride_ak + b_ptrs += BLOCK_SIZE_K_B * stride_bk + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) #NVFP4 meta-scale if(group_size == 16): @@ -819,25 +831,25 @@ def gemm_MX_kernel( scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=dtype) acc = acc.to(dtype) * (scales_a[:, None] * scales_b[None, :]) - + ############################################################################################################# - # #Output - # offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - # offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - # c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - # mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) - # tl.store(c_ptrs, acc, mask=mask) - - c_desc = tl.make_tensor_descriptor( - c_ptr, - [M, N], - [stride_cm, stride_cn], - [BLOCK_SIZE_M, BLOCK_SIZE_N] - ) - tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], value=acc) + #Output + if use_tma: + c_desc = tl.make_tensor_descriptor( + c_ptr, + [M, N], + [stride_cm, stride_cn], + [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], value=acc) + else: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) + mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) + tl.store(c_ptrs, acc, mask=mask) - PRINTED = False def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index bc0d942..bc2a4fc 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -19,7 +19,7 @@ def is_fp8_supported(): gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] reset_config() -set_autotune(False) +#set_autotune(False) KERNEL.ENABLE_CACHING = False diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index 950009d..7efc966 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -14,14 +14,14 @@ def is_fp8_supported(device_index=0): device = 'cuda:0' compute_dtype = torch.bfloat16 #float16, bfloat16 -matmul_types = ['GEMM_SPLITK', 'GEMM'] #TODO: add GEMV use-cases +matmul_types = ['GEMM'] #GEMM_SPLITK #TODO: add GEMV use-cases reset_config() set_autotune(False) KERNEL.ENABLE_CACHING = False torch.random.manual_seed(0) in_features, out_features = 4096, 2048 -batch_sizes = [1, 4, 16] +batch_sizes = [16, 64, 512] linear_layer = torch.nn.Linear(in_features=in_features, out_features=out_features, device=device, dtype=compute_dtype, bias=False) linear_layer.weight.data /= 10. linear_layer.weight.requires_grad = False @@ -42,13 +42,20 @@ def eval(self, gemlite_linear, tol: float = 1e-3): err = (y_ref - y_gem).abs().mean().item() self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + # @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + # def test_A16W8_MXFP(self): + # gemlite_linear = A16W8_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + # self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features)) + # self.assertTrue(not gemlite_linear.scaled_activations) + # self.eval(gemlite_linear, tol = 2e-4) + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_A16W8_MXFP(self): - gemlite_linear = A16W8_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + def test_A8W8_MXFP_post_scale_dynamic(self): + gemlite_linear = A8W8_MXFP_dynamic(device=device, dtype=compute_dtype, post_scale=True).from_linear(linear_layer, del_orig=False) self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features)) - self.assertTrue(not gemlite_linear.scaled_activations) + self.assertTrue(gemlite_linear.scaled_activations) self.eval(gemlite_linear, tol = 2e-4) - + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_A8W8_MXFP_dynamic(self): gemlite_linear = A8W8_MXFP_dynamic(device=device, dtype=compute_dtype, post_scale=False).from_linear(linear_layer, del_orig=False) @@ -56,11 +63,11 @@ def test_A8W8_MXFP_dynamic(self): self.assertTrue(gemlite_linear.scaled_activations) self.eval(gemlite_linear, tol = 2e-4) - def test_A16W4_MXFP(self): - gemlite_linear = A16W4_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) - self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) - self.assertTrue(not gemlite_linear.scaled_activations) - self.eval(gemlite_linear, tol = 7e-4) + # def test_A16W4_MXFP(self): + # gemlite_linear = A16W4_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + # self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) + # self.assertTrue(not gemlite_linear.scaled_activations) + # self.eval(gemlite_linear, tol = 7e-4) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_A8W4_MXFP_dynamic(self): From fc181613fccca17109474453c5bd95676461d8c5 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 24 Feb 2026 08:25:57 -0800 Subject: [PATCH 06/63] update --- gemlite/quant_utils.py | 98 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 92 insertions(+), 6 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index c9586cf..106eea0 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -499,14 +499,14 @@ def scale_activations_per_token_triton_v3( #MXFP8 #################################################################################################################### @triton.jit -def next_power_of_2_log_triton(val, eps: tl.constexpr): +def next_power_of_2_log_triton(val, eps_exp: tl.constexpr): exp = tl.ceil(tl.log2(val)).to(tl.int32) exp = tl.maximum(tl.minimum(exp, 254), 127 + eps_exp) - scales = tl.where(exp >= 0, 1 << scales_log2, 1.0 / (1 << (-exp))) + scales = tl.where(exp >= 0, 1 << exp, 1.0 / (1 << (-exp))) return scales, exp @triton.jit -def next_power_of_2_logapprox_triton(val, eps_exp: tl.constexpr): +def next_power_of_2_ptx_triton(val, eps_exp: tl.constexpr): exp = tl.inline_asm_elementwise( """ { @@ -537,7 +537,7 @@ def next_power_of_2_bitwise_triton(val, eps_exp: tl.constexpr): scales = tl.cast(yi, tl.float32, bitcast=True) return scales, exp -next_power_of_2_triton = next_power_of_2_bitwise_triton +next_power_of_2_triton = next_power_of_2_ptx_triton @torch.compile(fullgraph=True) def scale_activations_mxfp8_torch( @@ -697,7 +697,7 @@ def scale_activations_mxfp8_triton_v2( ) -> Tuple[torch.Tensor, torch.Tensor]: group_size: int = 32 eps_exp: int = -30 - eps: float = 2 ** -30 + eps: float = 2 ** eps_exp min_val, max_val = get_dtype_range(w_dtype) tensor = tensor.contiguous() @@ -735,6 +735,92 @@ def scale_activations_mxfp8_triton_v2( return out, scales +@triton.jit +def scale_activations_mxfp8_triton_kernel_v3( + tensor_ptr, + out_ptr, + scales_ptr, + M, K, + stride_m_t, stride_k_t, + stride_m_s, stride_k_s, + stride_m_o, stride_k_o, + ######################### + min_val: tl.constexpr, + max_val: tl.constexpr, + eps_exp: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + + #Load + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + #next power of 2 via log + scales, scales_log2 = next_power_of_2_triton(tl.max(tl.abs(tensor), axis=1, keep_dims=True) / max_val, eps_exp) + + #Map to index + out = tensor / scales + out = tl.clamp(out, min=min_val, max=max_val) + out = out.to(out_dtype) + + #Store + out_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask) + + offs_k = pid_k * 1 + tl.arange(0, 1) + tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales_log2) + + +def scale_activations_mxfp8_triton_v3( + tensor: torch.Tensor, w_dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + group_size: int = 32 + eps_exp: int = -30 + eps: float = 2 ** -30 + min_val, max_val = get_dtype_range(w_dtype) + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + + #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) + BLOCK_SIZE_M = group_size + grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(K, group_size)) + device_index = tensor.device.index + + scale_activations_mxfp8_triton_kernel_v3[grid]( + tensor, + out, + scales, + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + ######################### + min_val=min_val, + max_val=max_val, + eps_exp=eps_exp, + GROUP_SIZE=group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + num_stages=1, + num_warps=4, + ) + + return out, scales #################################################################################################################### #MXPF4 / NVFP4 @@ -1124,6 +1210,6 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso #################################################################################################################### scale_activations_per_token = scale_activations_per_token_triton_v3 -scale_activations_mxfp8 = scale_activations_mxfp8_triton_v2 +scale_activations_mxfp8 = scale_activations_mxfp8_triton_v3 scale_activations_mxfp4 = scale_activations_mxfp4_triton_v2 scale_activations_nvfp4 = scale_activations_nvfp4_triton_v2 From 0d02f97f37ced13103457bfad8a0ea8f0ccb63fc Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 24 Feb 2026 09:18:23 -0800 Subject: [PATCH 07/63] update --- gemlite/quant_utils.py | 47 +++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 106eea0..e859bf6 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -501,30 +501,41 @@ def scale_activations_per_token_triton_v3( @triton.jit def next_power_of_2_log_triton(val, eps_exp: tl.constexpr): exp = tl.ceil(tl.log2(val)).to(tl.int32) + exp = exp + 127 exp = tl.maximum(tl.minimum(exp, 254), 127 + eps_exp) - scales = tl.where(exp >= 0, 1 << exp, 1.0 / (1 << (-exp))) + scales = tl.cast(exp << 23, tl.float32, bitcast=True) return scales, exp @triton.jit def next_power_of_2_ptx_triton(val, eps_exp: tl.constexpr): - exp = tl.inline_asm_elementwise( - """ - { - lg2.approx.f32 $1, $1; - cvt.rpi.f32.f32 $1, $1; - cvt.rzi.s32.f32 $0, $1; - } + scales, biased_exp = tl.inline_asm_elementwise( + f""" + {{ + .reg .f32 f_log; + .reg .f32 f_ceil; + .reg .s32 r_exp; + .reg .f32 f_clamped; + + lg2.approx.f32 f_log, $2; + cvt.rpi.f32.f32 f_ceil, f_log; + cvt.rzi.s32.f32 r_exp, f_ceil; + + max.s32 r_exp, r_exp, {eps_exp}; + min.s32 r_exp, r_exp, 127; + + add.s32 $1, r_exp, 127; + cvt.rn.f32.s32 f_clamped, r_exp; + ex2.approx.f32 $0, f_clamped; + }} """, - "=r,r", + "=f,=r,f", [val], - dtype=tl.int32, + dtype=(tl.float32, tl.int32), is_pure=True, pack=1 ) - - exp = tl.maximum(tl.minimum(exp, 254), 127 + eps_exp) - scales = tl.where(exp >= 0, 1 << exp, 1.0 / (1 << (-exp))) - return scales, exp + + return scales, biased_exp @triton.jit def next_power_of_2_bitwise_triton(val, eps_exp: tl.constexpr): @@ -533,12 +544,14 @@ def next_power_of_2_bitwise_triton(val, eps_exp: tl.constexpr): mant = xi & 0x7FFFFF exp += tl.where(mant != 0, 1, 0) exp = tl.maximum(tl.minimum(exp, 254), 127 + eps_exp) - yi = exp << 23 - scales = tl.cast(yi, tl.float32, bitcast=True) + scales = tl.cast(exp << 23, tl.float32, bitcast=True) return scales, exp -next_power_of_2_triton = next_power_of_2_ptx_triton +next_power_of_2_triton = next_power_of_2_bitwise_triton +#################################################################################################################### +#MXFP8 +#################################################################################################################### @torch.compile(fullgraph=True) def scale_activations_mxfp8_torch( tensor: Tensor, w_dtype: torch.dtype = torch.float8_e4m3fn From 1a66408e9a2f454fb04d535386d1a221cf8642cc Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 24 Feb 2026 10:37:28 -0800 Subject: [PATCH 08/63] update --- gemlite/quant_utils.py | 433 ++++++++++++++++++++++------------------- tests/test_mxfp.py | 9 +- 2 files changed, 234 insertions(+), 208 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index e859bf6..b876849 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -228,6 +228,29 @@ def dequantize(self, W_q, scales, shape = None, dtype = None): #################################################################################################################### #INT8 / FP8 activations #################################################################################################################### +def prune_large_blocks(configs, named_args, **kwargs): + M = named_args['M'] + + pruned = [] + for config in configs: + if config.kwargs['BLOCK_SIZE_M'] <= M: + pruned.append(config) + + if not pruned: + for config in configs: + new_kwargs = config.kwargs.copy() + new_kwargs['BLOCK_SIZE_M'] = 16 + + pruned.append( + triton.Config( + new_kwargs, + num_warps=config.num_warps, + num_stages=config.num_stages + ) + ) + + return pruned + # Main activation scaling functions @torch.compile(fullgraph=True) def scale_activations_per_token_torch( @@ -250,7 +273,7 @@ def scale_activations_per_token_torch( if not w_dtype.is_floating_point: out.round_() - out = out.to(dtype=w_dtype) + out = out.to(dtype=w_dtype) return out.view(out_shape), scales @triton.jit @@ -270,20 +293,22 @@ def round_triton_amd(tensor): def scale_activations_per_token_triton_v1_kernel( tensor_ptr, scale_ptr, y_ptr, M, K, - stride_m, stride_k, stride_sm, + stride_m: tl.constexpr, + stride_k: tl.constexpr, + stride_sm: tl.constexpr, ROUND: tl.constexpr, UNROLL: tl.constexpr, min_val: tl.constexpr, max_val: tl.constexpr, fp32_scale: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): pid_m = tl.program_id(0) * UNROLL pid_k = tl.program_id(1) - offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) for m in range(UNROLL): mask = ((offs_m < M)[:, None] & (offs_k < K)[None, :]).to(tl.int1) @@ -303,7 +328,7 @@ def scale_activations_per_token_triton_v1_kernel( tl.store(scale_ptr + offs_m[:, None] * stride_sm, scales_x) tl.store(y_ptr + in_ptrs, tensor, mask=mask) - offs_m += BLOCK_M + offs_m += BLOCK_SIZE_M def scale_activations_per_token_triton_v1( tensor: Tensor, w_dtype: torch.dtype, fp32_scale: bool = True @@ -318,9 +343,9 @@ def scale_activations_per_token_triton_v1( y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) UNROLL = 1 # max(1, M // 128) - BLOCK_M = 1 - BLOCK_K = triton.next_power_of_2(K) - grid = (triton.cdiv(M, BLOCK_M * UNROLL), triton.cdiv(K, BLOCK_K)) + BLOCK_SIZE_M = 1 + BLOCK_SIZE_K = triton.next_power_of_2(K) + grid = (triton.cdiv(M, BLOCK_SIZE_M * UNROLL), triton.cdiv(K, BLOCK_SIZE_K)) ROUND = not w_dtype.is_floating_point @@ -338,8 +363,8 @@ def scale_activations_per_token_triton_v1( fp32_scale=fp32_scale, ROUND=ROUND, UNROLL=UNROLL, - BLOCK_M=BLOCK_M, - BLOCK_K=BLOCK_K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_K=BLOCK_SIZE_K, num_stages=1, num_warps=4, ) @@ -348,9 +373,9 @@ def scale_activations_per_token_triton_v1( @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 1}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_M': 2}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_M': 4}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 1}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 2}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 4}, num_warps=8, num_stages=1), ], key=['M', 'K'] ) @@ -358,15 +383,20 @@ def scale_activations_per_token_triton_v1( def scale_activations_per_token_triton_v2_kernel( tensor_ptr, scale_ptr, y_ptr, M, K, - stride_m, stride_k, stride_sm, - min_val: tl.constexpr, max_val: tl.constexpr, - fp32_scale: tl.constexpr, ROUND: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, + stride_m: tl.constexpr, + stride_k: tl.constexpr, + stride_sm: tl.constexpr, + min_val: tl.constexpr, + max_val: tl.constexpr, + fp32_scale: tl.constexpr, + ROUND: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): pid_m = tl.program_id(0) - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, BLOCK_K) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) m_mask = offs_m < M k_mask = offs_k < K mask = m_mask[:, None] & k_mask[None, :] @@ -402,9 +432,9 @@ def scale_activations_per_token_triton_v2( scales = torch.empty((M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device) y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), ) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), ) - BLOCK_K = triton.next_power_of_2(K) + BLOCK_SIZE_K = triton.next_power_of_2(K) ROUND = not w_dtype.is_floating_point scale_activations_per_token_triton_v2_kernel[grid]( @@ -414,16 +444,16 @@ def scale_activations_per_token_triton_v2( scales.stride(0), min_val=min_val, max_val=max_val, fp32_scale=fp32_scale, ROUND=ROUND, - BLOCK_K=BLOCK_K + BLOCK_SIZE_K=BLOCK_SIZE_K ) return y.view(x_shape), scales @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 1}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_M': 2}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_M': 4}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 1}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 2}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 4}, num_warps=8, num_stages=1), ], key=['M', 'K'] ) @@ -431,20 +461,25 @@ def scale_activations_per_token_triton_v2( def scale_activations_per_token_triton_v3_kernel( tensor_ptr, scale_ptr, y_ptr, M, K, - stride_m, stride_k, stride_sm, - min_val: tl.constexpr, max_val: tl.constexpr, - fp32_scale: tl.constexpr, ROUND: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, + stride_m: tl.constexpr, + stride_k: tl.constexpr, + stride_sm: tl.constexpr, + min_val: tl.constexpr, + max_val: tl.constexpr, + fp32_scale: tl.constexpr, + ROUND: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): start_pid = tl.program_id(0) num_programs = tl.num_programs(0) - num_tiles = tl.cdiv(M, BLOCK_M) + num_tiles = tl.cdiv(M, BLOCK_SIZE_M) - offs_k = tl.arange(0, BLOCK_K) + offs_k = tl.arange(0, BLOCK_SIZE_K) k_mask = offs_k < K for pid_m in range(start_pid, num_tiles, num_programs): - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) m_mask = offs_m < M mask = m_mask[:, None] & k_mask[None, :] @@ -478,9 +513,9 @@ def scale_activations_per_token_triton_v3( scales = torch.empty((M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device) y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) - grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META['BLOCK_M'])), ) + grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META['BLOCK_SIZE_M'])), ) - BLOCK_K = triton.next_power_of_2(K) + BLOCK_SIZE_K = triton.next_power_of_2(K) ROUND = not w_dtype.is_floating_point scale_activations_per_token_triton_v3_kernel[grid]( @@ -490,7 +525,7 @@ def scale_activations_per_token_triton_v3( scales.stride(0), min_val=min_val, max_val=max_val, fp32_scale=fp32_scale, ROUND=ROUND, - BLOCK_K=BLOCK_K + BLOCK_SIZE_K=BLOCK_SIZE_K ) return y.view(x_shape), scales @@ -667,9 +702,13 @@ def scale_activations_mxfp8_triton_kernel_v2( out_ptr, scales_ptr, M, K, - stride_m_t, stride_k_t, - stride_m_s, stride_k_s, - stride_m_o, stride_k_o, + ######################### + stride_m_t: tl.constexpr, + stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, + stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, + stride_k_o: tl.constexpr, ######################### min_val: tl.constexpr, max_val: tl.constexpr, @@ -748,15 +787,33 @@ def scale_activations_mxfp8_triton_v2( return out, scales +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) @triton.jit def scale_activations_mxfp8_triton_kernel_v3( tensor_ptr, out_ptr, scales_ptr, M, K, - stride_m_t, stride_k_t, - stride_m_s, stride_k_s, - stride_m_o, stride_k_o, + ######################### + stride_m_t: tl.constexpr, + stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, + stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, + stride_k_o: tl.constexpr, ######################### min_val: tl.constexpr, max_val: tl.constexpr, @@ -764,40 +821,43 @@ def scale_activations_mxfp8_triton_kernel_v3( GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, ): - pid_m = tl.program_id(axis=0) - pid_k = tl.program_id(axis=1) - out_dtype: tl.constexpr = out_ptr.dtype.element_ty - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_k = pid_k * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + + tensor_block_ptr = tl.make_block_ptr( + base=tensor_ptr, shape=(M, K), strides=(stride_m_t, stride_k_t), + offsets=(pid_m * BLOCK_SIZE_M, pid_k * GROUP_SIZE), + block_shape=(BLOCK_SIZE_M, GROUP_SIZE), order=(1, 0) + ) + + out_block_ptr = tl.make_block_ptr( + base=out_ptr, shape=(M, K), strides=(stride_m_o, stride_k_o), + offsets=(pid_m * BLOCK_SIZE_M, pid_k * GROUP_SIZE), + block_shape=(BLOCK_SIZE_M, GROUP_SIZE), order=(1, 0) + ) - #Load - mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) - tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) - tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + tensor = tl.load(tensor_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) - #next power of 2 via log - scales, scales_log2 = next_power_of_2_triton(tl.max(tl.abs(tensor), axis=1, keep_dims=True) / max_val, eps_exp) + abs_max = tl.max(tl.abs(tensor), axis=1, keep_dims=True) + scales, scales_log2 = next_power_of_2_triton(abs_max / max_val, eps_exp) - #Map to index - out = tensor / scales + out = tensor * (1.0 / scales) out = tl.clamp(out, min=min_val, max=max_val) - out = out.to(out_dtype) + out = out.to(out_ptr.dtype.element_ty) - #Store - out_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) - tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask) - - offs_k = pid_k * 1 + tl.arange(0, 1) - tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales_log2) + tl.store(out_block_ptr, out, boundary_check=(0, 1)) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask_m = offs_m < M + scales_ptrs = scales_ptr + (offs_m * stride_m_s + pid_k * stride_k_s) + tl.store(scales_ptrs, tl.reshape(scales_log2, (BLOCK_SIZE_M, )), mask=mask_m) def scale_activations_mxfp8_triton_v3( tensor: torch.Tensor, w_dtype: torch.dtype = torch.float8_e4m3fn ) -> Tuple[torch.Tensor, torch.Tensor]: group_size: int = 32 eps_exp: int = -30 - eps: float = 2 ** -30 + eps: float = 2 ** eps_exp min_val, max_val = get_dtype_range(w_dtype) tensor = tensor.contiguous() @@ -810,9 +870,7 @@ def scale_activations_mxfp8_triton_v3( out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) - #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) - BLOCK_SIZE_M = group_size - grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(K, group_size)) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index scale_activations_mxfp8_triton_kernel_v3[grid]( @@ -828,9 +886,6 @@ def scale_activations_mxfp8_triton_v3( max_val=max_val, eps_exp=eps_exp, GROUP_SIZE=group_size, - BLOCK_SIZE_M=BLOCK_SIZE_M, - num_stages=1, - num_warps=4, ) return out, scales @@ -933,104 +988,56 @@ def scale_activations_nvfp4_torch(tensor: Tensor) -> Tuple[Tensor, Tensor]: ) return W_q, scales +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) @triton.jit -def scale_activations_mxfp4_triton_kernel_v1( - tensor_ptr, - out_ptr, - scales_ptr, - thr_pos_ptr, - E, - eps_exp: tl.constexpr, - UNROLL: tl.constexpr, - GROUP_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) * UNROLL - - HALF_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 - out_dtype: tl.constexpr = out_ptr.dtype.element_ty - thr_pos = tl.load(thr_pos_ptr + tl.arange(0, 8), eviction_policy='evict_last')[None, :] - - for m in range(UNROLL): - #Load - offs = pid * GROUP_SIZE + tl.arange(0, GROUP_SIZE) - mask = (offs < E).to(tl.int1) - tensor = tl.load(tensor_ptr + offs, mask=mask, other=0.0).to(tl.float32) - - scales, scales_log2 = next_power_of_2_triton(tl.max(tl.abs(tensor)) / 6., eps_exp) - - #Map to index - wq = tensor / scales - idx_abs = tl.sum(tl.abs(wq[:, None]) > thr_pos, axis=1) - out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) - - #Pack - lo, hi = tl.split(out.reshape((HALF_GROUP_SIZE, 2), can_reorder=False)) - out = lo | (hi << 4) - - #Store - offs_out = pid * HALF_GROUP_SIZE + tl.arange(0, HALF_GROUP_SIZE) - tl.store(out_ptr + offs_out, out) - tl.store(scales_ptr + pid, scales_log2) - - pid += 1 - -def scale_activations_mxfp4_triton_v1(tensor: Tensor) -> Tuple[Tensor, Tensor]: - group_size: int = 32 - eps_exp: int = -30 - eps = 2 ** eps_exp - tensor = tensor.contiguous() - - orig_shape = tensor.shape - tensor = tensor.view(-1, tensor.shape[-1]) - inter_shape = (tensor.shape[0], tensor.shape[1] // 2) - pad_rows = (group_size - inter_shape[0] % group_size) % group_size - post_pad_shape = (inter_shape[0] + pad_rows, inter_shape[1]) - E = tensor.numel() - - UNROLL = min(triton.cdiv(triton.cdiv(E, group_size), get_num_SMs(tensor.device)), 1) - - out = torch.empty(inter_shape, device=tensor.device, dtype=torch.uint8) - scales = torch.empty( - (post_pad_shape[0], post_pad_shape[1] * 2 // group_size), - device=tensor.device, - dtype=torch.uint8, - ) - device_index = tensor.device.index - - grid = lambda meta: (triton.cdiv(E // UNROLL, group_size), ) - scale_activations_mxfp4_triton_kernel_v1[grid]( - tensor, - out, - scales, - thr_pos[device_index], - E, - eps_exp=eps_exp, - UNROLL=UNROLL, - GROUP_SIZE=group_size, - num_stages=1, - num_warps=4, - ) - - return out, scales - - -@triton.jit -def scale_activations_mxfp4_triton_kernel_v2( +def scale_activations_mxfp4_triton_kernel( tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, M, K, - stride_m_t, stride_k_t, - stride_m_s, stride_k_s, - stride_m_o, stride_k_o, + ######################### + stride_m_t: tl.constexpr, + stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, + stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, + stride_k_o: tl.constexpr, ######################### eps_exp: tl.constexpr, GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, + use_tma: tl.constexpr = False, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) + + if use_tma: + tensor_desc = tl.make_tensor_descriptor( + tensor_ptr, + [M, K], + [stride_m_t, stride_k_t], + [BLOCK_SIZE_M, GROUP_SIZE] + ) + out_desc = tl.make_tensor_descriptor( + out_ptr, + [M, K // 2], + [stride_m_o, stride_k_o], + [BLOCK_SIZE_M, HALF_GROUP_SIZE] + ) HALF_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 out_dtype: tl.constexpr = out_ptr.dtype.element_ty @@ -1043,25 +1050,10 @@ def scale_activations_mxfp4_triton_kernel_v2( mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) - ################################################# - # 1. Device-Side TMA Descriptors - # tensor_desc = tl.make_tensor_descriptor( - # tensor_ptr, - # [M, K], - # [stride_m_t, stride_k_t], - # [BLOCK_SIZE_M, GROUP_SIZE] - # ) - - # out_desc = tl.make_tensor_descriptor( - # out_ptr, - # [M, K // 2], - # [stride_m_o, stride_k_o], - # [BLOCK_SIZE_M, HALF_GROUP_SIZE] - # ) - - #tensor = tl.load_tensor_descriptor(tensor_desc, [pid_m * BLOCK_SIZE_M, pid_k * GROUP_SIZE]).to(tl.float32) - ################################################# - tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + if use_tma: + tensor = tl.load_tensor_descriptor(tensor_desc, [pid_m * BLOCK_SIZE_M, pid_k * GROUP_SIZE]).to(tl.float32) + else: + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) #next power of 2 via log scales, scales_log2 = next_power_of_2_triton(tl.max(tl.abs(tensor), axis=1, keep_dims=True) / 6., eps_exp) @@ -1078,13 +1070,15 @@ def scale_activations_mxfp4_triton_kernel_v2( #Store offs_k = pid_k * HALF_GROUP_SIZE + tl.arange(0, HALF_GROUP_SIZE) out_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < (K // 2))).to(tl.int1) - tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask) - #tl.store_tensor_descriptor(out_desc, [pid_m * BLOCK_SIZE_M, pid_k * HALF_GROUP_SIZE], out) + if use_tma: + tl.store_tensor_descriptor(out_desc, [pid_m * BLOCK_SIZE_M, pid_k * HALF_GROUP_SIZE], out) + else: + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask) offs_k = pid_k * 1 + tl.arange(0, 1) tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales_log2) -def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: +def scale_activations_mxfp4_triton(tensor: Tensor) -> Tuple[Tensor, Tensor]: group_size: int = 32 eps_exp: int = -30 eps: float = 2 ** eps_exp @@ -1098,14 +1092,11 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) - - #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) - BLOCK_SIZE_M = group_size - - grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(K, group_size)) + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index - scale_activations_mxfp4_triton_kernel_v2[grid]( + scale_activations_mxfp4_triton_kernel[grid]( tensor, out, scales, @@ -1117,32 +1108,62 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: ######################### eps_exp=eps_exp, GROUP_SIZE=group_size, - BLOCK_SIZE_M=BLOCK_SIZE_M, - num_stages=1, - num_warps=4, ) return out, scales + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) @triton.jit -def scale_activations_nvfp4_triton_kernel_v2( +def scale_activations_nvfp4_triton_kernel( tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, M, K, - stride_m_t, stride_k_t, - stride_m_s, stride_k_s, - stride_m_o, stride_k_o, + ######################### + stride_m_t: tl.constexpr, + stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, + stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, + stride_k_o: tl.constexpr, ######################### eps: tl.constexpr, GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, meta_scales: tl.constexpr = NVFP4_META_SCALE, + use_tma: tl.constexpr = False, ): - pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) + + if use_tma: + tensor_desc = tl.make_tensor_descriptor( + tensor_ptr, + [M, K], + [stride_m_t, stride_k_t], + [BLOCK_SIZE_M, GROUP_SIZE] + ) + out_desc = tl.make_tensor_descriptor( + out_ptr, + [M, K // 2], + [stride_m_o, stride_k_o], + [BLOCK_SIZE_M, HALF_GROUP_SIZE] + ) fp8_dtype: tl.constexpr = tl.float8e4nv max_fp8: tl.constexpr = 448. @@ -1157,8 +1178,12 @@ def scale_activations_nvfp4_triton_kernel_v2( #Load mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) - tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + if use_tma: + tensor = tl.load_tensor_descriptor(tensor_desc, [pid_m * BLOCK_SIZE_M, pid_k * GROUP_SIZE]).to(tl.float32) + else: + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + #FP8 scales scales = tl.max(tl.abs(tensor), axis=1, keep_dims=True) / (6. * meta_scales) scales = tl.minimum(scales, max_fp8).to(fp8_dtype) @@ -1175,14 +1200,17 @@ def scale_activations_nvfp4_triton_kernel_v2( #Store offs_k = pid_k * HALF_GROUP_SIZE + tl.arange(0, HALF_GROUP_SIZE) - out_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < (K // 2))).to(tl.int1) - tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask) + out_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < (K // 2))).to(tl.int1) + if use_tma: + tl.store_tensor_descriptor(out_desc, [pid_m * BLOCK_SIZE_M, pid_k * HALF_GROUP_SIZE], out) + else: + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask) offs_k = pid_k + tl.arange(0, 1) tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales) -def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: group_size: int = 16 eps: float = 1e-6 fp8_dtype = torch.float8_e4m3fn #Nvidia only @@ -1197,12 +1225,10 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) - #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) - BLOCK_SIZE_M = group_size - grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(K, group_size)) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index - scale_activations_nvfp4_triton_kernel_v2[grid]( + scale_activations_nvfp4_triton_kernel[grid]( tensor, out, scales, @@ -1214,9 +1240,6 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso ######################### eps=eps, GROUP_SIZE=group_size, - BLOCK_SIZE_M=BLOCK_SIZE_M, - num_stages=2, - num_warps=4, ) return out, scales @@ -1224,5 +1247,5 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso #################################################################################################################### scale_activations_per_token = scale_activations_per_token_triton_v3 scale_activations_mxfp8 = scale_activations_mxfp8_triton_v3 -scale_activations_mxfp4 = scale_activations_mxfp4_triton_v2 -scale_activations_nvfp4 = scale_activations_nvfp4_triton_v2 +scale_activations_mxfp4 = scale_activations_mxfp4_triton +scale_activations_nvfp4 = scale_activations_nvfp4_triton diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index 7efc966..264cb8d 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -20,12 +20,15 @@ def is_fp8_supported(device_index=0): KERNEL.ENABLE_CACHING = False torch.random.manual_seed(0) -in_features, out_features = 4096, 2048 -batch_sizes = [16, 64, 512] +in_features, out_features = 4032, 2000 +batch_sizes = [50] linear_layer = torch.nn.Linear(in_features=in_features, out_features=out_features, device=device, dtype=compute_dtype, bias=False) linear_layer.weight.data /= 10. linear_layer.weight.requires_grad = False + +assert in_features % 32 == 0, "in_features must be divisible by 32 for the current implementation" + #Pre-cache data for faster processing input_data = {} for batch_size in batch_sizes: @@ -83,7 +86,7 @@ def test_A4W4_MXFP_dynamic(self): self.eval(gemlite_linear, tol = 1e-3) def test_A4W4_NVFP_dynamic(self): - gemlite_linear = A4W4_MXFP_dynamic(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + gemlite_linear = A4W4_NVFP_dynamic(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) self.assertTrue(gemlite_linear.scaled_activations) self.eval(gemlite_linear, tol = 1e-3) From 590ee0a2162d2697d0063a0bb16ef052f4aa6103 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 24 Feb 2026 11:20:14 -0800 Subject: [PATCH 09/63] update --- gemlite/triton_kernels/gemm_kernels.py | 7 +- gemlite/triton_kernels/gemm_splitK_kernels.py | 125 +++++++++++++----- gemlite/triton_kernels/utils.py | 5 +- tests/test_mxfp.py | 6 +- 4 files changed, 101 insertions(+), 42 deletions(-) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index e172d51..2ffd7df 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -309,8 +309,11 @@ def gemm_INT_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_STAGES: tl.constexpr, A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 1576b92..7c8adf3 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -38,7 +38,11 @@ def kernel_config_pruner(configs, nargs, **kwargs): config.pop('reg_dec_producer', None) config.pop('reg_inc_consumer', None) config["NUM_STAGES"] = num_stages - + + config['EVEN_M'] = (m % config['BLOCK_SIZE_M'] == 0) + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) + yield triton.Config(config, num_stages=num_stages, num_warps=num_warps, @@ -85,7 +89,6 @@ def kernel_config_pruner(configs, nargs, **kwargs): #Constraint: K needs to be divisible by BLOCK_SIZE_K * SPLIT_K while split_k > 1 and not is_divisible(k, block_size_k * split_k): - #while split_k > 1 and k > block_size_k * split_k: split_k //= 2 #Nvidia @@ -97,18 +100,19 @@ def kernel_config_pruner(configs, nargs, **kwargs): #skip num_stages=1 for non-packed weights continue - #Avoid OOM - while num_stages > 0: #TODO: revisit MXFP case - shared_mem = (block_size_m * block_size_k * a_sizeof + block_size_k * block_size_n * b_sizeof) - if(e > 1 and not load_scales_as_block): - shared_mem += block_size_k * block_size_n * a_sizeof - shared_mem *= num_stages - if int(shared_mem) <= gpu_shared_memory: - break - num_stages -= 1 + # #Avoid OOM + # while num_stages > 0: #TODO: revisit MXFP case + # shared_mem = (block_size_m * block_size_k * a_sizeof + block_size_k * block_size_n * b_sizeof) + # if(e > 1 and not load_scales_as_block): + # shared_mem += block_size_k * block_size_n * a_sizeof + # shared_mem *= num_stages + # if int(shared_mem) <= gpu_shared_memory: + # break + # num_stages -= 1 - if(num_stages == 0): continue #config too large + # if(num_stages == 0): continue #config too large + split_k = max(split_k, 1) ########################################### if(load_scales_as_block):#tmp MXFP fix block_size_k = min(block_size_k, 256) @@ -116,12 +120,20 @@ def kernel_config_pruner(configs, nargs, **kwargs): key = (block_size_m, block_size_n, block_size_k, group_size_m, split_k, A_load_order, num_stages, num_warps) + even_m = (m % block_size_m == 0) + even_n = (n % block_size_n == 0) + even_k = (k % block_size_k == 0) + + new_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": group_size_m, "SPLIT_K": split_k, + "EVEN_M": even_m, + "EVEN_N": even_n, + "EVEN_K": even_k, "A_load_order": A_load_order, "NUM_STAGES": num_stages, } @@ -195,7 +207,7 @@ def get_fast_autotune_config_nvidia(): return configs def get_default_config_nvidia(): - return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2)] + return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2)] ######################################################################################################################################################################## #AMD - Instinct MI300X @@ -247,7 +259,7 @@ def get_fast_autotune_config_amd(): return configs def get_default_config_amd(): - return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2)] + return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2)] ######################################################################################################################################################################## if IS_HIP: @@ -305,11 +317,20 @@ def gemm_splitK_INT_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, NUM_STAGES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + ################################# + NUM_STAGES: tl.constexpr, A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = False, + EVEN_N: tl.constexpr = False, + ################################# meta_evict_policy: tl.constexpr = '', atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', @@ -379,13 +400,19 @@ def gemm_splitK_INT_kernel( for k in range(num_pid_k): - if(A_load_order == 0): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if(A_load_order == 0): #Early load + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) b = tl.load(b_ptrs, eviction_policy=b_evict) if(A_load_order == 1): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) #Meta-data loading policy if(W_group_mode > 0): @@ -405,13 +432,19 @@ def gemm_splitK_INT_kernel( zeros = None if(A_load_order == 2): #Mid load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) # Unpack and dequantize b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) if(A_load_order == 3): #Late load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) #Dot acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype) @@ -419,22 +452,25 @@ def gemm_splitK_INT_kernel( #Advance a_ptrs += BLOCK_SIZE_K_U * stride_ak b_ptrs += BLOCK_SIZE_K_P * stride_bk + + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) ############################################################################################################# #Channel-wise scaling if(channel_scale_mode == 1): #weight-only scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) - acc = acc.to(meta_dtype) * scales_b[None, :] + acc = acc.to(meta_dtype) * scales_b[None, :] if(channel_scale_mode == 2): #activation-only scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) if(channel_scale_mode == 3): #weight + activation scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) ############################################################################################################# #Output @@ -473,8 +509,10 @@ def gemm_splitK_MX_kernel( stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, - stride_meta_a_m: tl.constexpr, stride_meta_a_g: tl.constexpr, - stride_meta_n: tl.constexpr, stride_meta_g: tl.constexpr, + stride_meta_a_m: tl.constexpr, + stride_meta_a_g: tl.constexpr, + stride_meta_n: tl.constexpr, + stride_meta_g: tl.constexpr, ######### Dtypes ######### load_scales_as_block, #True | IF FALSE, RESTRICT BLOCK_SIZE_K <= 32 input_dtype: tl.constexpr, @@ -486,11 +524,20 @@ def gemm_splitK_MX_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, NUM_STAGES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + NUM_STAGES: tl.constexpr, + ################################# A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = False, + EVEN_N: tl.constexpr = False, + ################################# meta_evict_policy: tl.constexpr = 'evict_first', atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', @@ -542,16 +589,23 @@ def gemm_splitK_MX_kernel( BLOCK_SIZE_K_S: tl.constexpr = BLOCK_SIZE_K // group_size offs_k_scales = tl.arange(0, BLOCK_SIZE_K_S) offs_n_b_scales = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - #B scales - scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] + #B scales: [BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] + scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #A scales if(channel_scale_mode == 4): scales_a_ptrs = scales_a_ptr + offs_am[:, None] * stride_meta_a_m + offs_k_scales[None, :] * stride_meta_a_g + # Used in channel-wise MXPF8 version + scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - for k in tl.range(num_pid_k, num_stages=NUM_STAGES): - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + for k in tl.range(num_pid_k): + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + b = tl.load(b_ptrs, eviction_policy=b_evict) #k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32) @@ -561,12 +615,15 @@ def gemm_splitK_MX_kernel( if(channel_scale_mode == 4): scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) else: - scales_a = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + scales_a = scales_a_1s acc = tl.dot_scaled(a, scales_a, a_dtype, b, scales_b, b_dtype, acc) a_ptrs += BLOCK_SIZE_K_A * stride_ak b_ptrs += BLOCK_SIZE_K_B * stride_bk + + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) #NVFP4 meta-scale if(group_size == 16): @@ -578,7 +635,7 @@ def gemm_splitK_MX_kernel( dtype: tl.constexpr = c_ptr.dtype.element_ty scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=dtype) - acc = acc.to(dtype) * (scales_a[:, None] * scales_b[None, :]) + acc = acc.to(dtype) * (scales_a[:, None] * scales_b[None, :]) ############################################################################################################# #Output diff --git a/gemlite/triton_kernels/utils.py b/gemlite/triton_kernels/utils.py index 16b51ab..77a4c31 100755 --- a/gemlite/triton_kernels/utils.py +++ b/gemlite/triton_kernels/utils.py @@ -123,9 +123,8 @@ def gpu_supports_float16_acc( def gpu_supports_bfloat16_atomicadd(): - #Triton tl.atomic_add doens't support bfloat16 even for Hopper and above. - #return torch.cuda.get_device_capability()[0] >= 9 #Hopper and above - return False + #Triton tl.atomic_add doens't support bfloat16 on older GPUs. + return torch.cuda.get_device_capability()[0] >= 9 #Hopper and above NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count def get_num_SMs(device): diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index 264cb8d..70c96f8 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -14,14 +14,14 @@ def is_fp8_supported(device_index=0): device = 'cuda:0' compute_dtype = torch.bfloat16 #float16, bfloat16 -matmul_types = ['GEMM'] #GEMM_SPLITK #TODO: add GEMV use-cases +matmul_types = ['GEMM_SPLITK', 'GEMM'] #GEMM_SPLITK #TODO: add GEMV use-cases reset_config() set_autotune(False) KERNEL.ENABLE_CACHING = False torch.random.manual_seed(0) -in_features, out_features = 4032, 2000 -batch_sizes = [50] +in_features, out_features = 4032, 2048 +batch_sizes = [30, 32, 60, 100, 128] linear_layer = torch.nn.Linear(in_features=in_features, out_features=out_features, device=device, dtype=compute_dtype, bias=False) linear_layer.weight.data /= 10. linear_layer.weight.requires_grad = False From cf124c61964ad5e50bd4ac8837ffec94f6461eb5 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 24 Feb 2026 11:43:57 -0800 Subject: [PATCH 10/63] update --- gemlite/triton_kernels/gemm_kernels.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 2ffd7df..bd0aa40 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -882,7 +882,7 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x gemm_kernel = gemm_INT_kernel load_scales_as_block = False - compiled_kernel = gemm_kernel[grid]( + gemm_kernel[grid]( x, W_q, output, scales, zeros, scales_x, M, N, K, M_CLOSEST, @@ -908,11 +908,6 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x data_contiguous = data_contiguous, ) - if PRINTED == False: - with open('kernel.ptx', 'w') as f: - f.write(compiled_kernel.asm['ptx']) - PRINTED = True - return output # # Persistent version From 0715390464770883775c501bb60ef6214eaf3efe Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 2 Mar 2026 04:07:34 -0800 Subject: [PATCH 11/63] update --- tests/test_mxfp.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index 70c96f8..e82ac7c 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -14,14 +14,14 @@ def is_fp8_supported(device_index=0): device = 'cuda:0' compute_dtype = torch.bfloat16 #float16, bfloat16 -matmul_types = ['GEMM_SPLITK', 'GEMM'] #GEMM_SPLITK #TODO: add GEMV use-cases +matmul_types = ['GEMM_SPLITK', 'GEMM'] #TODO: fix gemv mxfp bugs and enable GEMV testing reset_config() set_autotune(False) KERNEL.ENABLE_CACHING = False torch.random.manual_seed(0) in_features, out_features = 4032, 2048 -batch_sizes = [30, 32, 60, 100, 128] +batch_sizes = [1, 30, 32, 60, 100, 128] linear_layer = torch.nn.Linear(in_features=in_features, out_features=out_features, device=device, dtype=compute_dtype, bias=False) linear_layer.weight.data /= 10. linear_layer.weight.requires_grad = False @@ -45,12 +45,12 @@ def eval(self, gemlite_linear, tol: float = 1e-3): err = (y_ref - y_gem).abs().mean().item() self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) - # @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - # def test_A16W8_MXFP(self): - # gemlite_linear = A16W8_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) - # self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features)) - # self.assertTrue(not gemlite_linear.scaled_activations) - # self.eval(gemlite_linear, tol = 2e-4) + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_A16W8_MXFP(self): + gemlite_linear = A16W8_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features)) + self.assertTrue(not gemlite_linear.scaled_activations) + self.eval(gemlite_linear, tol = 2e-4) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_A8W8_MXFP_post_scale_dynamic(self): @@ -66,11 +66,11 @@ def test_A8W8_MXFP_dynamic(self): self.assertTrue(gemlite_linear.scaled_activations) self.eval(gemlite_linear, tol = 2e-4) - # def test_A16W4_MXFP(self): - # gemlite_linear = A16W4_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) - # self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) - # self.assertTrue(not gemlite_linear.scaled_activations) - # self.eval(gemlite_linear, tol = 7e-4) + def test_A16W4_MXFP(self): + gemlite_linear = A16W4_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) + self.assertTrue(not gemlite_linear.scaled_activations) + self.eval(gemlite_linear, tol = 7e-4) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_A8W4_MXFP_dynamic(self): From 5a6ce1df4d732e435186de30b43c4150d5238382 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 2 Mar 2026 10:38:49 -0800 Subject: [PATCH 12/63] update --- gemlite/core.py | 4 ++-- tests/test_gemlitelineartriton.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 6977fa5..2160117 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -100,7 +100,7 @@ def set_acc_dtype(dtype): def get_default_gemv(W_nbits: int, mx_dtype: bool = False) -> str: #TODO: adapt mx for IS_HIP = True if mx_dtype: - return 'GEMM_SPLITK' #TODO: fix mxf bugs in GEMV outputs garbage. + return 'GEMM_SPLITK' #TODO:'GEMV' if (W_nbits < 8) else 'GEMM_SPLITK' -> Revisit NVFP4 failing test. else: return 'GEMV_REVSPLITK' if (W_nbits < 8) else 'GEMV_SPLITK' @@ -121,7 +121,7 @@ def enable_activation_scaling(batch_size): Only works with the MXFP format - use with A8W4_MXFP/A4W4_MXFP. """ return True - #return batch_size >= 32 + #return batch_size >= 2 #TODO: Needs Triton fix https://github.com/triton-lang/triton/pull/9577 #Main functional forward call diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index bc2a4fc..2f236f5 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -14,7 +14,7 @@ def is_fp8_supported(): return capability >= (8, 9) device = 'cuda:0' -compute_dtype = torch.float16 #float16, bfloat16 +compute_dtype = torch.bfloat16 #float16, bfloat16 fp8_dtype = torch.float8_e4m3fn #float8_e4m3fn / torch.float8_e5m2 (Nvidia) gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] From 3095cdf608669ae33477e4fc82e76325895e8e87 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 2 Mar 2026 14:25:38 -0800 Subject: [PATCH 13/63] Fix tests --- gemlite/core.py | 22 +++---- gemlite/triton_kernels/gemv_kernels.py | 66 +++++++++++++------ .../triton_kernels/gemv_revsplitK_kernels.py | 28 +++++--- gemlite/triton_kernels/gemv_splitK_kernels.py | 40 ++++++----- tests/test_gemlitelineartriton.py | 38 ++++++----- tests/test_mxfp.py | 8 +-- 6 files changed, 124 insertions(+), 78 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 2160117..2d3ee3e 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -250,17 +250,17 @@ def __init__( + " W_nbits are supported." ) - if in_features is not None and out_features is not None: - if (in_features % GemLiteLinearTriton.MIN_SIZE != 0) or ( - in_features % group_size != 0 if (group_size is not None) else False - ): - raise NotImplementedError( - "Invalid input shapes: " - + str(in_features) - + " , " - + str(out_features) - + ". in_features should be divisible by 32 or the group_size" - ) + # if in_features is not None and out_features is not None: + # if (in_features % GemLiteLinearTriton.MIN_SIZE != 0) or ( + # in_features % group_size != 0 if (group_size is not None or W_nbits < 16) else False + # ): + # raise NotImplementedError( + # "Invalid input shapes: " + # + str(in_features) + # + " , " + # + str(out_features) + # + ". in_features should be divisible by 32 or the group_size" + # ) #Warning: Input dtype should be the same as dequantize() weights dtype. if input_dtype not in GemLiteLinearTriton.SUPPORTED_DTYPES: diff --git a/gemlite/triton_kernels/gemv_kernels.py b/gemlite/triton_kernels/gemv_kernels.py index ccbf306..4bcf5fa 100755 --- a/gemlite/triton_kernels/gemv_kernels.py +++ b/gemlite/triton_kernels/gemv_kernels.py @@ -62,7 +62,6 @@ def kernel_config_pruner(configs, nargs, **kwargs): #Constraints: BLOCK_SIZE_K <= group_size -> load_scales_as_block is always False for gemvs block_size_k = min(g, block_size_k) #Makes BLOCK_SIZE_K compatible with the group_size - block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) @@ -241,11 +240,16 @@ def gemv_INT_kernel( type_id: tl.constexpr, use_prehook: tl.constexpr, ######### Strides ######### - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_meta_a_m, stride_meta_a_g, - stride_meta_g, stride_meta_n, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_meta_a_m: tl.constexpr, + stride_meta_a_g: tl.constexpr, + stride_meta_g: tl.constexpr, + stride_meta_n: tl.constexpr, ######### Dtypes ######### input_dtype: tl.constexpr, output_dtype: tl.constexpr, @@ -256,11 +260,14 @@ def gemv_INT_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - A_load_order: tl.constexpr, NUM_STAGES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + A_load_order: tl.constexpr, + NUM_STAGES: tl.constexpr, dot_prod_mode:tl.constexpr, data_contiguous: tl.constexpr, - dump_b_val: tl.constexpr = 0, #Improve accuracy mainly for A16W8 with post looop scaling + dump_b_val: tl.constexpr = 0, #Improve accuracy mainly for A16W8 with post loop scaling ##################################### meta_evict_policy: tl.constexpr = '', atomic_mode: tl.constexpr = 'relaxed', @@ -307,12 +314,17 @@ def gemv_INT_kernel( #orig version b_ptrs = b_ptr + (offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn + a_mask = (offs_am[:, None] < M) & (offs_ak[None, :] < K).to(tl.int1) + b_mask = (offs_bk[:, None] < K) & (offs_bn[None, :] < N).to(tl.int1) + #TODO: add EVEN_K / EVEN_N check ################################################################### #Load if(A_load_order == 0): a = tl.load(a_ptrs, eviction_policy=a_evict) + #a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) b = tl.load(b_ptrs, eviction_policy=b_evict) + #b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): a = tl.load(a_ptrs, eviction_policy=a_evict) @@ -354,10 +366,10 @@ def gemv_INT_kernel( a = tl.load(a_ptrs, eviction_policy=a_evict) if(dump_b_val > 0): b = b.to(tl.float32) * dump_b_val - + #Dot product if(dot_prod_mode == 0): - acc = tl.sum(a.reshape((BLOCK_SIZE_K, 1), can_reorder=False).to(acc_dtype) * b.to(acc_dtype), axis=0, keep_dims=True) + acc = tl.sum((a.reshape((BLOCK_SIZE_K, 1), can_reorder=False).to(acc_dtype)) * b.to(acc_dtype), axis=0, keep_dims=True) if(dot_prod_mode == 1): acc = tl.sum(a.reshape((BLOCK_SIZE_K, 1), can_reorder=False) * b.to(input_dtype), axis=0, keep_dims=True) @@ -380,12 +392,14 @@ def gemv_INT_kernel( acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) #################################################################### - #Output: tl.atomic_add only supports 1D fp16 arrays, bfp16 would crash + #Output offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + + @triton.autotune( configs=get_autotune_config(), key = KEYS, @@ -408,11 +422,16 @@ def gemv_MX_kernel( type_id: tl.constexpr, use_prehook: tl.constexpr, ######### Strides ######### - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_meta_a_m, stride_meta_a_g, - stride_meta_g, stride_meta_n, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_meta_a_m: tl.constexpr, + stride_meta_a_g: tl.constexpr, + stride_meta_g: tl.constexpr, + stride_meta_n: tl.constexpr, ######### Dtypes ######### input_dtype: tl.constexpr, output_dtype: tl.constexpr, @@ -423,8 +442,11 @@ def gemv_MX_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - A_load_order: tl.constexpr, NUM_STAGES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + A_load_order: tl.constexpr, + NUM_STAGES: tl.constexpr, dot_prod_mode:tl.constexpr, data_contiguous: tl.constexpr, dump_b_val: tl.constexpr = 0, #Improve accuracy mainly for A16W8 with post looop scaling @@ -474,6 +496,7 @@ def gemv_MX_kernel( a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak #[1, BLOCK_SIZE_K] b_ptrs = b_ptr + offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn #[BLOCK_SIZE_K, BLOCK_SIZE_N] a_mask = ((offs_am[:, None] < M) & (offs_ak[None, :] < (K // elements_per_sample_a))).to(tl.int1) + b_mask = ((offs_bk[:, None] < (K // elements_per_sample)) & (offs_bn[None, :] < N)).to(tl.int1) if(W_nbits == 4): #mxpf4 mapping mapping = tl.load(mapping_ptr + tl.arange(0, 16), eviction_policy='evict_last')[None, :].broadcast_to((BLOCK_SIZE_K, 16)) @@ -484,6 +507,7 @@ def gemv_MX_kernel( a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) b = tl.load(b_ptrs, eviction_policy=b_evict) + #b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) @@ -591,8 +615,8 @@ def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x stride_meta_a_m, stride_meta_a_g = scales_x.stride(0), scales_x.stride(1) else: stride_meta_a_m, stride_meta_a_g = None, None - channel_scale_mode = 0 - + #channel_scale_mode = 0 + dtype = DTYPE_TO_TRITON[input_dtype] if(dtype in [tl.float16, tl.bfloat16, tl.float32]): acc_dtype = dtype diff --git a/gemlite/triton_kernels/gemv_revsplitK_kernels.py b/gemlite/triton_kernels/gemv_revsplitK_kernels.py index bad0d28..f085687 100755 --- a/gemlite/triton_kernels/gemv_revsplitK_kernels.py +++ b/gemlite/triton_kernels/gemv_revsplitK_kernels.py @@ -59,7 +59,6 @@ def kernel_config_pruner(configs, nargs, **kwargs): block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) - #tmp fix autotune getting stuck on the MI300X if IS_HIP: if block_size_n * block_size_k >= 65536: @@ -68,9 +67,10 @@ def kernel_config_pruner(configs, nargs, **kwargs): #Since we load the scales / zeros once per split_k pass, we need this while block_size_k >= 8 and (block_size_k * split_k > g): block_size_k //= 2 + block_size_k = max(block_size_k, 8) - if(not (block_size_k * split_k <= g)): - continue + # if(not (block_size_k * split_k <= g)): + # continue #Block size should be compatible with minimum-packing if(block_size_k < e): @@ -92,7 +92,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): if key in used: continue - + used.add(key) yield triton.Config(new_config, num_stages=num_stages, num_warps=num_warps, pre_hook=pre_hook) @@ -120,6 +120,10 @@ def get_max_autotune_config_nvidia(): #~20 sec/shape def get_fast_autotune_config_nvidia(): configs = [] + #Default + configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':16, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=1, num_stages=1)) + + #Extra configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':16, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=1, num_stages=1)) configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=2, num_stages=2)) @@ -236,10 +240,14 @@ def gemv_INT_revsplitK_kernel( type_id: tl.constexpr, use_prehook: tl.constexpr, ######### Strides ######### - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_meta_g, stride_meta_n, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_meta_g: tl.constexpr, + stride_meta_n: tl.constexpr, ######### Dtypes ######### input_dtype: tl.constexpr, output_dtype: tl.constexpr, @@ -250,7 +258,9 @@ def gemv_INT_revsplitK_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, A_load_order: tl.constexpr, dot_prod_mode: tl.constexpr, data_contiguous: tl.constexpr, diff --git a/gemlite/triton_kernels/gemv_splitK_kernels.py b/gemlite/triton_kernels/gemv_splitK_kernels.py index 9e5ed40..ee71ea5 100755 --- a/gemlite/triton_kernels/gemv_splitK_kernels.py +++ b/gemlite/triton_kernels/gemv_splitK_kernels.py @@ -63,14 +63,15 @@ def kernel_config_pruner(configs, nargs, **kwargs): block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) - #K needs to be divisible by BLOCK_SIZE_K * SPLIT_K: TODO: without this, cuda-graphs breaks. - while block_size_k > 16 and not is_divisible(k, block_size_k * split_k): - block_size_k //=2 + # #K needs to be divisible by BLOCK_SIZE_K * SPLIT_K: TODO: without this, cuda-graphs breaks. + # while block_size_k > 16 and not is_divisible(k, block_size_k * split_k): + # block_size_k //=2 + # block_size_k = min(block_size_k, 16) - #Skip blocks that are either too large or too small - block_area = (block_size_k // split_k) * block_size_n - if(block_area < 1024 or block_area > 4096 * 8): #128 * 8 * num_warps - continue + # #Skip blocks that are either too large or too small + # block_area = (block_size_k // split_k) * block_size_n + # if(block_area < 1024 or block_area > 4096 * 8): #128 * 8 * num_warps + # continue #Block size should be compatible with minimum-packing if(block_size_k < e): @@ -149,9 +150,7 @@ def get_fast_autotune_config_nvidia(): return configs def get_default_config_nvidia(): - config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':2, 'BLOCK_SIZE_K':2048, 'GROUP_SIZE_M':8, 'SPLIT_K': 1, - 'A_load_order':1, 'dot_prod_mode':0}, num_warps=4, num_stages=2) - + config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':2, 'BLOCK_SIZE_K':2048, 'GROUP_SIZE_M':8, 'SPLIT_K': 1, 'A_load_order':1, 'dot_prod_mode':0}, num_warps=4, num_stages=2) return [config] ######################################################################################################################################################################## @@ -249,10 +248,14 @@ def gemv_INT_splitK_kernel( elements_per_sample: tl.constexpr, type_id: tl.constexpr, ######### Strides ######### - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_meta_g, stride_meta_n, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_meta_g: tl.constexpr, + stride_meta_n: tl.constexpr, ######### Dtypes ######### input_dtype: tl.constexpr, output_dtype: tl.constexpr, @@ -263,8 +266,11 @@ def gemv_INT_splitK_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, A_load_order: tl.constexpr, dot_prod_mode: tl.constexpr, data_contiguous: tl.constexpr, @@ -336,6 +342,8 @@ def gemv_INT_splitK_kernel( acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=acc_dtype) if(dot_prod_mode == 1): acc = tl.zeros((1, BLOCK_SIZE_N), dtype=acc_dtype) + + #TODO: EVEN_K / EVEN_N use-case #for k in tl.range(0, num_pid_k, 1, num_stages=1): for k in range(num_pid_k): diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index 2f236f5..d027026 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -17,11 +17,19 @@ def is_fp8_supported(): compute_dtype = torch.bfloat16 #float16, bfloat16 fp8_dtype = torch.float8_e4m3fn #float8_e4m3fn / torch.float8_e5m2 (Nvidia) gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] -matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] +matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] #TODO: Investigate GEMV_SPLITK errors when in_features are not powers of 2 reset_config() -#set_autotune(False) +set_autotune(False) KERNEL.ENABLE_CACHING = False +in_features, out_features = 4096, 2032 +batch_sizes = [1, 5, 100] +W_nbits, group_size = 4, 128 #128 / in_features + +if group_size is None: + group_size = in_features +if group_size < in_features: + in_features = (in_features // group_size) * group_size #ensure divisibility for current implementation def gen_data(in_features, out_features, W_nbits, group_size, dtype=compute_dtype): @@ -39,11 +47,7 @@ def gen_data(in_features, out_features, W_nbits, group_size, dtype=compute_dtype return W, W_q, scales, zeros - -in_features, out_features = 4096, 1024 -batch_sizes = [1, 4] -W_nbits, group_size = 4, 128 #128 / in_features -W, W_q, scales, zeros = gen_data(in_features, out_features, W_nbits=W_nbits, group_size=group_size) +W, W_q, scales, zeros = gen_data(in_features, out_features, W_nbits=W_nbits, group_size=group_size) class TestGemLiteLinearTriton(unittest.TestCase): @@ -83,7 +87,7 @@ def test_serialization(self): y_gem = gemlite_linear_loaded.forward_manual(x, matmul_type=matmul_type) err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) def test_fp16xfp16(self): gemlite_linear = GemLiteLinearTriton(W_nbits=16, @@ -101,7 +105,7 @@ def test_fp16xfp16(self): #Use non-contiguous when data is not packed self.assertTrue(gemlite_linear.data_contiguous == False) - tol = 1e-3 + tol = 2.5e-3 #higher tol for gemv kernels, otherwise 1e-3 is fine for batch_size in batch_sizes: x = (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.) @@ -110,7 +114,7 @@ def test_fp16xfp16(self): if(batch_size>1 and 'GEMV' in matmul_type): continue y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) def test_fp16xWn_asymmetric(self): @@ -144,7 +148,7 @@ def test_fp16xWn_asymmetric(self): if(batch_size>1 and 'GEMV' in matmul_type): continue y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) def test_int8xWn_symmetric_no_activation_scaling(self): @@ -160,7 +164,7 @@ def test_int8xWn_symmetric_no_activation_scaling(self): _scales = torch.randn((out_features, 1), dtype=compute_dtype, device=device) * 1e-4 - gemlite_linear.pack(W_q, scales=_scales, zeros=7, bias=None); + gemlite_linear.pack(W_q, scales=_scales, zeros=7, bias=None) #Weights are unpacked() then shifted by 7 self.assertTrue(gemlite_linear.W_group_mode == 1) @@ -176,7 +180,7 @@ def test_int8xWn_symmetric_no_activation_scaling(self): if(batch_size>1 and 'GEMV' in matmul_type): continue y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) def test_int8xWn_scaled_activations(self): @@ -274,7 +278,7 @@ def test_fp8xfp8(self): if(batch_size>1 and 'GEMV' in matmul_type): continue y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_fp8xfp8_scaled_weights_scaled_activations(self): @@ -309,7 +313,7 @@ def test_fp8xfp8_scaled_weights_scaled_activations(self): if(batch_size>1 and 'GEMV' in matmul_type): continue y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_fp8xWn_scaled_activations(self): @@ -346,7 +350,7 @@ def test_fp8xWn_scaled_activations(self): if(batch_size>1 and 'GEMV' in matmul_type): continue y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_fp8xWn_no_activation_scaling(self): @@ -380,4 +384,4 @@ def test_fp8xWn_no_activation_scaling(self): if(batch_size>1 and 'GEMV' in matmul_type): continue y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index e82ac7c..8dd4e33 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -14,19 +14,18 @@ def is_fp8_supported(device_index=0): device = 'cuda:0' compute_dtype = torch.bfloat16 #float16, bfloat16 -matmul_types = ['GEMM_SPLITK', 'GEMM'] #TODO: fix gemv mxfp bugs and enable GEMV testing +matmul_types = ['GEMM_SPLITK', 'GEMM'] #TODO: improve GEMV mxfp accuracy. reset_config() set_autotune(False) KERNEL.ENABLE_CACHING = False torch.random.manual_seed(0) in_features, out_features = 4032, 2048 -batch_sizes = [1, 30, 32, 60, 100, 128] +batch_sizes = [1]#[1, 30, 32, 60, 100, 128] linear_layer = torch.nn.Linear(in_features=in_features, out_features=out_features, device=device, dtype=compute_dtype, bias=False) linear_layer.weight.data /= 10. linear_layer.weight.requires_grad = False - assert in_features % 32 == 0, "in_features must be divisible by 32 for the current implementation" #Pre-cache data for faster processing @@ -41,9 +40,10 @@ def eval(self, gemlite_linear, tol: float = 1e-3): x = input_data[batch_size] y_ref = linear_layer(x) for matmul_type in matmul_types: + if(batch_size>1 and 'GEMV' in matmul_type): continue y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_A16W8_MXFP(self): From f625b980adbcb2ce25d8f4e62149ab4bf50af903 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 3 Mar 2026 01:15:49 -0800 Subject: [PATCH 14/63] fix more test --- gemlite/core.py | 22 +- gemlite/triton_kernels/gemv_splitK_kernels.py | 56 +- tests/test_gemlitelineartriton.py | 665 +++++++++--------- 3 files changed, 377 insertions(+), 366 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 2d3ee3e..2b559c5 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -250,17 +250,17 @@ def __init__( + " W_nbits are supported." ) - # if in_features is not None and out_features is not None: - # if (in_features % GemLiteLinearTriton.MIN_SIZE != 0) or ( - # in_features % group_size != 0 if (group_size is not None or W_nbits < 16) else False - # ): - # raise NotImplementedError( - # "Invalid input shapes: " - # + str(in_features) - # + " , " - # + str(out_features) - # + ". in_features should be divisible by 32 or the group_size" - # ) + if in_features is not None and out_features is not None: + if (in_features % GemLiteLinearTriton.MIN_SIZE != 0) or ( + (in_features % group_size != 0) if (group_size is not None and W_nbits < 16) else False + ): + raise NotImplementedError( + "Invalid input shapes: " + + str(in_features) + + " , " + + str(out_features) + + ". in_features should be divisible by 32 or the group_size" + ) #Warning: Input dtype should be the same as dequantize() weights dtype. if input_dtype not in GemLiteLinearTriton.SUPPORTED_DTYPES: diff --git a/gemlite/triton_kernels/gemv_splitK_kernels.py b/gemlite/triton_kernels/gemv_splitK_kernels.py index ee71ea5..ff89b3d 100755 --- a/gemlite/triton_kernels/gemv_splitK_kernels.py +++ b/gemlite/triton_kernels/gemv_splitK_kernels.py @@ -37,6 +37,10 @@ def kernel_config_pruner(configs, nargs, **kwargs): config.pop('reg_dec_producer', None) config.pop('reg_inc_consumer', None) + config['EVEN_M'] = (m % config['BLOCK_SIZE_M'] == 0) + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) + yield triton.Config(config, num_stages=num_stages, num_warps=num_warps, @@ -77,6 +81,10 @@ def kernel_config_pruner(configs, nargs, **kwargs): if(block_size_k < e): continue + even_m = (m % block_size_m == 0) + even_n = (n % block_size_n == 0) + even_k = (k % block_size_k == 0) + key = (block_size_m, block_size_n, block_size_k, group_size_m, split_k, A_load_order, dot_prod_mode, num_stages, num_warps) new_config = { @@ -87,6 +95,9 @@ def kernel_config_pruner(configs, nargs, **kwargs): 'SPLIT_K' : split_k, 'A_load_order' : A_load_order, 'dot_prod_mode' : dot_prod_mode, + 'EVEN_M': even_m, + 'EVEN_N': even_n, + 'EVEN_K': even_k, } if IS_HIP: @@ -254,6 +265,8 @@ def gemv_INT_splitK_kernel( stride_bn: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, + stride_meta_a_m: tl.constexpr, + stride_meta_a_g: tl.constexpr, stride_meta_g: tl.constexpr, stride_meta_n: tl.constexpr, ######### Dtypes ######### @@ -276,6 +289,10 @@ def gemv_INT_splitK_kernel( data_contiguous: tl.constexpr, dump_b_val: tl.constexpr = 0, #Improve accuracy mainly for A16W8 with post looop scaling ################################# + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = False, + EVEN_N: tl.constexpr = False, + ################################# meta_evict_policy: tl.constexpr = '', atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', @@ -321,9 +338,9 @@ def gemv_INT_splitK_kernel( #Inputs a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak) - a_mask = ((offs_am[:, None] < M) & (offs_ak[None, :] < K)).to(tl.int1) b_ptrs = b_ptr + ((offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) - + a_mask = ((offs_am[:, None] < M) & (offs_ak[None, :] < K)).to(tl.int1) + #Meta data stuff q_shift = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None] @@ -342,19 +359,23 @@ def gemv_INT_splitK_kernel( acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=acc_dtype) if(dot_prod_mode == 1): acc = tl.zeros((1, BLOCK_SIZE_N), dtype=acc_dtype) - - #TODO: EVEN_K / EVEN_N use-case #for k in tl.range(0, num_pid_k, 1, num_stages=1): for k in range(num_pid_k): if(A_load_order == 0): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) b = tl.load(b_ptrs, eviction_policy=b_evict) if(A_load_order == 1): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) if(W_group_mode > 0): k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32) @@ -373,13 +394,19 @@ def gemv_INT_splitK_kernel( zeros = None if(A_load_order == 2): #Mid load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) # Unpack and dequantize b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) if(A_load_order == 3): #Late load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) if(dump_b_val > 0): b = b.to(tl.float32) * dump_b_val @@ -391,6 +418,10 @@ def gemv_INT_splitK_kernel( #Advance a_ptrs += BLOCK_SIZE_K_U * stride_ak b_ptrs += BLOCK_SIZE_K_P * stride_bk + + #Update mask + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K_U) < K)).to(tl.int1) if(dot_prod_mode == 0): acc = tl.sum(acc, axis=0, keep_dims=True) @@ -445,6 +476,13 @@ def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K']) + device_index = W_q.device.index + + if(scales_x is not None): + stride_meta_a_m, stride_meta_a_g = scales_x.stride(0), scales_x.stride(1) + else: + stride_meta_a_m, stride_meta_a_g = None, None + dtype = DTYPE_TO_TRITON[input_dtype] if(dtype in [tl.float16, tl.bfloat16, tl.float32]): acc_dtype = dtype @@ -464,6 +502,7 @@ def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s x.stride(0), x.stride(1), W_q.stride(0), W_q.stride(1), output.stride(0), output.stride(1), + stride_meta_a_m, stride_meta_a_g, scales.stride(0), scales.stride(1), ################################################ input_dtype = DTYPE_TO_TRITON[input_dtype], @@ -489,4 +528,3 @@ class gemv_splitK: matmul_type = MATMUL_TYPE __all__ = ["gemv_splitK"] - diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index d027026..f95013d 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -17,371 +17,344 @@ def is_fp8_supported(): compute_dtype = torch.bfloat16 #float16, bfloat16 fp8_dtype = torch.float8_e4m3fn #float8_e4m3fn / torch.float8_e5m2 (Nvidia) gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] -matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] #TODO: Investigate GEMV_SPLITK errors when in_features are not powers of 2 +matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] reset_config() set_autotune(False) KERNEL.ENABLE_CACHING = False -in_features, out_features = 4096, 2032 +in_features, out_features = 4032, 2032 batch_sizes = [1, 5, 100] W_nbits, group_size = 4, 128 #128 / in_features if group_size is None: group_size = in_features if group_size < in_features: - in_features = (in_features // group_size) * group_size #ensure divisibility for current implementation + in_features = (in_features // group_size) * group_size #ensure divisibility for current implementation def gen_data(in_features, out_features, W_nbits, group_size, dtype=compute_dtype): - W_q = torch.randint(0, 2**W_nbits - 1, (out_features, in_features), device=device).to(torch.uint8) + W_q = torch.randint(0, 2**W_nbits - 1, (out_features, in_features), device=device).to(torch.uint8) - shape = (out_features, in_features) - gs = W_q.numel() // group_size - scales = torch.ones((gs, 1), device=device, dtype=dtype) * 0.001 - zeros = torch.zeros((gs, 1), device=device, dtype=dtype) * ((2**W_nbits - 1)//2) - W = ((W_q.reshape([-1, group_size]) - zeros) * scales).to(fp8_dtype).to(dtype) + shape = (out_features, in_features) + gs = W_q.numel() // group_size + scales = torch.ones((gs, 1), device=device, dtype=dtype) * 0.001 + zeros = torch.zeros((gs, 1), device=device, dtype=dtype) * ((2**W_nbits - 1)//2) + W = ((W_q.reshape([-1, group_size]) - zeros) * scales).to(fp8_dtype).to(dtype) - zeros = torch.mean(W_q.reshape([-1, group_size]).float() - (W / scales).float(), axis=1, keepdim=True).to(dtype) - W = ((W_q.reshape([-1, group_size]).to(dtype) - zeros) * scales) - W = W.reshape(shape) + zeros = torch.mean(W_q.reshape([-1, group_size]).float() - (W / scales).float(), axis=1, keepdim=True).to(dtype) + W = ((W_q.reshape([-1, group_size]).to(dtype) - zeros) * scales) + W = W.reshape(shape) - return W, W_q, scales, zeros + return W, W_q, scales, zeros W, W_q, scales, zeros = gen_data(in_features, out_features, W_nbits=W_nbits, group_size=group_size) +#Pre-cache data for faster processing +input_data = {} +for batch_size in batch_sizes: + torch.random.manual_seed(0) + input_data[batch_size] = torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10. + class TestGemLiteLinearTriton(unittest.TestCase): - def test_serialization(self): - gemlite_linear = GemLiteLinearTriton(W_nbits, - group_size=group_size, - in_features=in_features, - out_features=out_features, - input_dtype=gemlite_dtype, - output_dtype=gemlite_dtype) - - - gemlite_linear.pack(W_q, scales, zeros, None) - - torch.save(gemlite_linear.state_dict(), 'tmp.pt') - - gemlite_linear_loaded = GemLiteLinearTriton() - gemlite_linear_loaded.load_state_dict(torch.load('tmp.pt')) - - ref_args = gemlite_linear.get_meta_args() - loaded_args = gemlite_linear_loaded.get_meta_args() - for i in range(len(ref_args)): - assert ref_args[i] == loaded_args[i], "meta_args mismatch at " + str(i) - - ref_args = gemlite_linear.get_tensor_args() - loaded_args = gemlite_linear_loaded.get_tensor_args() - for i in range(len(ref_args)): - assert (ref_args[i] - loaded_args[i]).float().abs().mean() == 0, "tensor_args mismatch at " + str(i) - - tol = 1e-7 - for batch_size in batch_sizes: - x = torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10. - for matmul_type in ['GEMM']: - if(batch_size>1 and 'GEMV' in matmul_type): continue - - y_ref = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - y_gem = gemlite_linear_loaded.forward_manual(x, matmul_type=matmul_type) - - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) - - def test_fp16xfp16(self): - gemlite_linear = GemLiteLinearTriton(W_nbits=16, - group_size=None, - in_features=in_features, - out_features=out_features, - input_dtype=gemlite_dtype, - output_dtype=gemlite_dtype, - scaled_activations=False) - - gemlite_linear.pack(W, None, None, None); - - #No weight unpacking / dequant - self.assertTrue(gemlite_linear.W_group_mode == 0 and gemlite_linear.channel_scale_mode == 0) - #Use non-contiguous when data is not packed - self.assertTrue(gemlite_linear.data_contiguous == False) + def eval(self, gemlite_linear, ref_fn, tol: float = 1e-3, input_fn=None, _matmul_types=None): + """ + Shared evaluation method. + Args: + gemlite_linear: the quantized linear layer to test + ref_fn: callable(x) -> y_ref, computes the reference output + tol: error tolerance + input_fn: optional callable(batch_size) -> x, custom input generator. + If None, uses pre-cached input_data. + _matmul_types: optional list of matmul types to test. If None, uses global matmul_types. + """ + if _matmul_types is None: + _matmul_types = matmul_types + + for batch_size in batch_sizes: + if input_fn is not None: + x = input_fn(batch_size) + else: + x = input_data[batch_size] + + y_ref = ref_fn(x) + + for matmul_type in _matmul_types: + if batch_size > 1 and 'GEMV' in matmul_type: + continue + y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) + err = (y_ref - y_gem).abs().mean().item() + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) + + def test_serialization(self): + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=gemlite_dtype, + output_dtype=gemlite_dtype) + + gemlite_linear.pack(W_q, scales, zeros, None) + + torch.save(gemlite_linear.state_dict(), 'tmp.pt') + + gemlite_linear_loaded = GemLiteLinearTriton() + gemlite_linear_loaded.load_state_dict(torch.load('tmp.pt')) + + ref_args = gemlite_linear.get_meta_args() + loaded_args = gemlite_linear_loaded.get_meta_args() + for i in range(len(ref_args)): + assert ref_args[i] == loaded_args[i], "meta_args mismatch at " + str(i) + + ref_args = gemlite_linear.get_tensor_args() + loaded_args = gemlite_linear_loaded.get_tensor_args() + for i in range(len(ref_args)): + assert (ref_args[i] - loaded_args[i]).float().abs().mean() == 0, "tensor_args mismatch at " + str(i) + + def ref_fn(x): + return gemlite_linear.forward_manual(x, matmul_type='GEMM') + + self.eval(gemlite_linear_loaded, ref_fn, tol=1e-7, _matmul_types=['GEMM']) + + def test_fp16xfp16(self): + gemlite_linear = GemLiteLinearTriton(W_nbits=16, + group_size=None, + in_features=in_features, + out_features=out_features, + input_dtype=gemlite_dtype, + output_dtype=gemlite_dtype, + scaled_activations=False) + + gemlite_linear.pack(W, None, None, None) + + #No weight unpacking / dequant + self.assertTrue(gemlite_linear.W_group_mode == 0 and gemlite_linear.channel_scale_mode == 0) + #Use non-contiguous when data is not packed + self.assertTrue(gemlite_linear.data_contiguous == False) + + def ref_fn(x): + return torch.matmul(x.to(compute_dtype), W.T) + + self.eval(gemlite_linear, ref_fn, tol=2.5e-3) #higher tol for gemv kernels, otherwise 1e-3 is fine + + def test_fp16xWn_asymmetric(self): + #FP16 x Wn / asymmetric + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=gemlite_dtype, + output_dtype=gemlite_dtype) + + gemlite_linear.pack(W_q, scales, zeros, None) + + if(group_size == in_features): + #Weights are unpacked() then shift only if group_size == in_features (1) otherwise (3) + self.assertTrue((gemlite_linear.W_group_mode == 1 and gemlite_linear.channel_scale_mode == 1) or + (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 0)) + else: + self.assertTrue(gemlite_linear.W_group_mode in [3, 4] and gemlite_linear.channel_scale_mode == 0) + + #Use-contiguous when data is packed + self.assertTrue(gemlite_linear.data_contiguous == True) + + def ref_fn(x): + return torch.matmul(x.to(compute_dtype), W.T) + + self.eval(gemlite_linear, ref_fn, tol=1e-3) + + def test_int8xWn_symmetric_no_activation_scaling(self): + #INT8 x Wn - symmetric / no scaling activation scaling + + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, #only channelwise is supported + out_features=out_features, + input_dtype=DType.INT8, + output_dtype=DType.FP32, + scaled_activations=False) + + _scales = torch.randn((out_features, 1), dtype=compute_dtype, device=device) * 1e-4 + gemlite_linear.pack(W_q, scales=_scales, zeros=7, bias=None) + + #Weights are unpacked() then shifted by 7 + self.assertTrue(gemlite_linear.W_group_mode == 1) + #Since the scales are channel-wise, we perform scaling post K-sum + self.assertTrue(gemlite_linear.channel_scale_mode == 1) + + def input_fn(batch_size): + return (torch.randint(-10, 10, (batch_size, in_features), device=device)).to(torch.int8) + + def ref_fn(x): + return torch.matmul(x.to(compute_dtype), ((W_q.to(compute_dtype) - 7) * _scales).T) + + self.eval(gemlite_linear, ref_fn, tol=1e-3, input_fn=input_fn) + + def test_int8xWn_scaled_activations(self): + #INT8 x Wn - activation scaling only + + gemlite_linear = GemLiteLinearTriton(W_nbits=W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=DType.INT8, + output_dtype=DType.FP32, + scaled_activations=True) + + gemlite_linear.pack(W_q, scales=None, zeros=7, bias=None) + gemlite_linear.meta_dtype = DType.FP32 + + #Weights are unpacked() then shifted by 7 + self.assertTrue(gemlite_linear.W_group_mode == 1) + #Activations only are scaled + self.assertTrue(gemlite_linear.channel_scale_mode == 2) + + def input_fn(batch_size): + return torch.randn((batch_size, in_features), dtype=torch.float16, device=device) / 20. + + def ref_fn(x): + _x, _x_scaled = scale_activations(x, w_dtype=torch.int8) + return torch.matmul(_x.to(torch.float16), (W_q.to(torch.float16) - 7).T) * _x_scaled + + self.eval(gemlite_linear, ref_fn, tol=5e-3, input_fn=input_fn) + + def test_int8Wn_scaled_weights_scaled_activations(self): + #INT8 x Wn - activation scaling only + + gemlite_linear = GemLiteLinearTriton(W_nbits=8, + group_size=in_features, #only channel-wise supported + in_features=in_features, + out_features=out_features, + input_dtype=DType.INT8, + output_dtype=DType.FP32, + scaled_activations=True) + + _scales = torch.randn((out_features, 1), dtype=compute_dtype, device=device) * 1e-4 + gemlite_linear.pack(W_q, scales=_scales, zeros=7, bias=None) - tol = 2.5e-3 #higher tol for gemv kernels, otherwise 1e-3 is fine - - for batch_size in batch_sizes: - x = (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.) - y_ref = torch.matmul(x.to(compute_dtype), W.T) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) - - - def test_fp16xWn_asymmetric(self): - #FP16 x Wn / asymmetric - gemlite_linear = GemLiteLinearTriton(W_nbits, - group_size=group_size, - in_features=in_features, - out_features=out_features, - input_dtype=gemlite_dtype, - output_dtype=gemlite_dtype) - - - gemlite_linear.pack(W_q, scales, zeros, None); - - if(group_size == in_features): - #Weights are unpacked() then shift only if group_size == in_features (1) otherwise (3) - self.assertTrue((gemlite_linear.W_group_mode == 1 and gemlite_linear.channel_scale_mode == 1) or - (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 0)) - else: - self.assertTrue(gemlite_linear.W_group_mode in [3, 4] and gemlite_linear.channel_scale_mode == 0) - - #Use-contiguous when data is packed - self.assertTrue(gemlite_linear.data_contiguous == True) - - tol = 1e-3 - - for batch_size in batch_sizes: - x = torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10. - y_ref = torch.matmul(x.to(compute_dtype), W.T) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) - - - def test_int8xWn_symmetric_no_activation_scaling(self): - #INT8 x Wn - symmetric / no scaling activation scaling - - gemlite_linear = GemLiteLinearTriton(W_nbits, - group_size=group_size, - in_features=in_features, #only channelwise is supported - out_features=out_features, - input_dtype=DType.INT8, - output_dtype=DType.FP32, - scaled_activations=False) - - - _scales = torch.randn((out_features, 1), dtype=compute_dtype, device=device) * 1e-4 - gemlite_linear.pack(W_q, scales=_scales, zeros=7, bias=None) - - #Weights are unpacked() then shifted by 7 - self.assertTrue(gemlite_linear.W_group_mode == 1) - #Since the scales are channel-wise, we perform scaling post K-sum - self.assertTrue(gemlite_linear.channel_scale_mode == 1) - - tol = 1e-3 - - for batch_size in batch_sizes: - x = (torch.randint(-10, 10, (batch_size, in_features), device=device)).to(torch.int8) - y_ref = torch.matmul(x.to(compute_dtype), ((W_q.to(compute_dtype) - 7) * _scales).T) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) - - - def test_int8xWn_scaled_activations(self): - #INT8 x Wn - activation scaling only - - gemlite_linear = GemLiteLinearTriton(W_nbits=W_nbits, - group_size=group_size, - in_features=in_features, - out_features=out_features, - input_dtype=DType.INT8, - output_dtype=DType.FP32, - scaled_activations=True) - - - gemlite_linear.pack(W_q, scales=None, zeros=7, bias=None) - gemlite_linear.meta_dtype = DType.FP32 - - #Weights are unpacked() then shifted by 7 - self.assertTrue(gemlite_linear.W_group_mode == 1) - #Activations only are scaled - self.assertTrue(gemlite_linear.channel_scale_mode == 2) - - tol = 5e-3 - - for batch_size in batch_sizes: - x = torch.randn((batch_size, in_features), dtype=torch.float16, device=device) / 20. - _x, _x_scaled = scale_activations(x, w_dtype=torch.int8) - y_ref = torch.matmul(_x.to(torch.float16), (W_q.to(torch.float16) - 7).T) * _x_scaled - - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type) - - def test_int8Wn_scaled_weights_scaled_activations(self): - #INT8 x Wn - activation scaling only - - gemlite_linear = GemLiteLinearTriton(W_nbits=8, - group_size=in_features, #only channel-wise supported - in_features=in_features, - out_features=out_features, - input_dtype=DType.INT8, - output_dtype=DType.FP32, - scaled_activations=True) - - _scales = torch.randn((out_features, 1), dtype=compute_dtype, device=device) * 1e-4 - gemlite_linear.pack(W_q, scales=_scales, zeros=7, bias=None); - - #Weights are unpacked() then shifted by 7 if group_size == in_features (1), otherwise (3) - self.assertTrue(gemlite_linear.W_group_mode == 1) - #Activations only are scaled if group_size != in_features (2) otherwise bot are scales merged (3) - self.assertTrue(gemlite_linear.channel_scale_mode == 3) - - tol = 1e-3 - - for batch_size in batch_sizes: - shape = W_q.shape - x = torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10. - _x, _x_scaled = scale_activations(x, w_dtype=torch.int8) - y_ref = torch.matmul(_x.to(compute_dtype), ((W_q.to(compute_dtype) - 7) * _scales).T) * _x_scaled - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type) - - - @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_fp8xfp8(self): - #FP8 x FP8 - no scaling - - gemlite_linear = GemLiteLinearTriton(W_nbits=8, - group_size=None, - in_features=in_features, - out_features=out_features, - input_dtype=TORCH_TO_DTYPE[fp8_dtype], - output_dtype=gemlite_dtype, - scaled_activations=False) - - - gemlite_linear.pack(W.to(fp8_dtype), None, None, None) - - #No weight unpacking / dequant - self.assertTrue(gemlite_linear.W_group_mode == 0) - #No channel-wise scaling - self.assertTrue(gemlite_linear.channel_scale_mode == 0) - - tol = 5e-3 #needs higher tolerance with fp8 - - for batch_size in batch_sizes: - x = (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype) - y_ref = torch.matmul(x.to(compute_dtype), W.T) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) - - @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_fp8xfp8_scaled_weights_scaled_activations(self): - #FP8 x FP8 - both activations and weights are scaled - - gemlite_linear = GemLiteLinearTriton(W_nbits=8, - group_size=in_features, - in_features=in_features, - out_features=out_features, - input_dtype=TORCH_TO_DTYPE[fp8_dtype], - output_dtype=gemlite_dtype, - scaled_activations=True) - - - _scales = torch.randn((1, out_features), dtype=compute_dtype, device=device) * 1e-4 - gemlite_linear.pack(W.to(fp8_dtype), scales=_scales, zeros=None, bias=None); - - #No weight unpacking / dequant - self.assertTrue(gemlite_linear.W_group_mode == 0) - #Both activations and weights are scales - self.assertTrue(gemlite_linear.channel_scale_mode == 3) - - tol = 5e-3 #needs higher tolerance with fp8 - - for batch_size in batch_sizes: - shape = W.shape - x = torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10. - _x, scales_x = scale_activations(x, w_dtype=fp8_dtype) - - y_ref = torch.matmul(_x.to(compute_dtype), W.T) * (_scales * scales_x) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) - - @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_fp8xWn_scaled_activations(self): - #FP8 x Wn - asymmetric, with activation scaling - - gemlite_linear = GemLiteLinearTriton(W_nbits, - group_size=group_size, - in_features=in_features, - out_features=out_features, - input_dtype=TORCH_TO_DTYPE[fp8_dtype], - output_dtype=gemlite_dtype, - scaled_activations=True) - - - gemlite_linear.pack(W_q, scales, zeros, None); - - if(group_size == in_features): - #weight unpacking and shift if group_size == in_features else (3) - self.assertTrue((gemlite_linear.W_group_mode == 1) and (gemlite_linear.channel_scale_mode == 3) or - (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 2)) - else: - #activations and weights are scaled psot accumulation if group_size==in_features else (2) - self.assertTrue(gemlite_linear.W_group_mode in [3, 4]) - self.assertTrue(gemlite_linear.channel_scale_mode == 2) - - - tol = 5e-3 #needs higher tolerance with fp8 - - for batch_size in batch_sizes: - x = (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype).to(compute_dtype) - _x, _scaled_x = scale_activations(x, w_dtype=fp8_dtype) - y_ref = torch.matmul(_x.to(compute_dtype), W.T) * _scaled_x - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) - - @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_fp8xWn_no_activation_scaling(self): - #FP8 x Wn - asymmetric, no activation scaling - - gemlite_linear = GemLiteLinearTriton(W_nbits, - group_size=group_size, - in_features=in_features, - out_features=out_features, - input_dtype=TORCH_TO_DTYPE[fp8_dtype], - output_dtype=gemlite_dtype, - scaled_activations=False) - - gemlite_linear.pack(W_q, scales, zeros, None) - - if(group_size == in_features): - #Weight shift only if group_size==in_features else (3) - self.assertTrue((gemlite_linear.W_group_mode == 1 and gemlite_linear.channel_scale_mode == 1) or - (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 0)) - else: - #weight scaling only - post accumulator if group_size==in_features else (0) - self.assertTrue(gemlite_linear.W_group_mode in [3, 4]) - self.assertTrue(gemlite_linear.channel_scale_mode == 0) - - tol = 5e-3 #needs higher tolerance with fp8 - - for batch_size in batch_sizes: - x = (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype) - y_ref = torch.matmul(x.to(compute_dtype), W.T) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) + #Weights are unpacked() then shifted by 7 if group_size == in_features (1), otherwise (3) + self.assertTrue(gemlite_linear.W_group_mode == 1) + #Activations only are scaled if group_size != in_features (2) otherwise bot are scales merged (3) + self.assertTrue(gemlite_linear.channel_scale_mode == 3) + + def ref_fn(x): + _x, _x_scaled = scale_activations(x, w_dtype=torch.int8) + return torch.matmul(_x.to(compute_dtype), ((W_q.to(compute_dtype) - 7) * _scales).T) * _x_scaled + + self.eval(gemlite_linear, ref_fn, tol=1e-3) + + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_fp8xfp8(self): + #FP8 x FP8 - no scaling + + gemlite_linear = GemLiteLinearTriton(W_nbits=8, + group_size=None, + in_features=in_features, + out_features=out_features, + input_dtype=TORCH_TO_DTYPE[fp8_dtype], + output_dtype=gemlite_dtype, + scaled_activations=False) + + gemlite_linear.pack(W.to(fp8_dtype), None, None, None) + + #No weight unpacking / dequant + self.assertTrue(gemlite_linear.W_group_mode == 0) + #No channel-wise scaling + self.assertTrue(gemlite_linear.channel_scale_mode == 0) + + def input_fn(batch_size): + return (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype) + + def ref_fn(x): + return torch.matmul(x.to(compute_dtype), W.T) + + self.eval(gemlite_linear, ref_fn, tol=5e-3, input_fn=input_fn) #needs higher tolerance with fp8 + + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_fp8xfp8_scaled_weights_scaled_activations(self): + #FP8 x FP8 - both activations and weights are scaled + + gemlite_linear = GemLiteLinearTriton(W_nbits=8, + group_size=in_features, + in_features=in_features, + out_features=out_features, + input_dtype=TORCH_TO_DTYPE[fp8_dtype], + output_dtype=gemlite_dtype, + scaled_activations=True) + + _scales = torch.randn((1, out_features), dtype=compute_dtype, device=device) * 1e-4 + gemlite_linear.pack(W.to(fp8_dtype), scales=_scales, zeros=None, bias=None) + + #No weight unpacking / dequant + self.assertTrue(gemlite_linear.W_group_mode == 0) + #Both activations and weights are scales + self.assertTrue(gemlite_linear.channel_scale_mode == 3) + + def ref_fn(x): + _x, scales_x = scale_activations(x, w_dtype=fp8_dtype) + return torch.matmul(_x.to(compute_dtype), W.T) * (_scales * scales_x) + + self.eval(gemlite_linear, ref_fn, tol=5e-3) #needs higher tolerance with fp8 + + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_fp8xWn_scaled_activations(self): + #FP8 x Wn - asymmetric, with activation scaling + + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=TORCH_TO_DTYPE[fp8_dtype], + output_dtype=gemlite_dtype, + scaled_activations=True) + + gemlite_linear.pack(W_q, scales, zeros, None) + + if(group_size == in_features): + #weight unpacking and shift if group_size == in_features else (3) + self.assertTrue((gemlite_linear.W_group_mode == 1) and (gemlite_linear.channel_scale_mode == 3) or + (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 2)) + else: + #activations and weights are scaled psot accumulation if group_size==in_features else (2) + self.assertTrue(gemlite_linear.W_group_mode in [3, 4]) + self.assertTrue(gemlite_linear.channel_scale_mode == 2) + + def input_fn(batch_size): + return (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype).to(compute_dtype) + + def ref_fn(x): + _x, _scaled_x = scale_activations(x, w_dtype=fp8_dtype) + return torch.matmul(_x.to(compute_dtype), W.T) * _scaled_x + + self.eval(gemlite_linear, ref_fn, tol=5e-3, input_fn=input_fn) #needs higher tolerance with fp8 + + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_fp8xWn_no_activation_scaling(self): + #FP8 x Wn - asymmetric, no activation scaling + + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=TORCH_TO_DTYPE[fp8_dtype], + output_dtype=gemlite_dtype, + scaled_activations=False) + + gemlite_linear.pack(W_q, scales, zeros, None) + + if(group_size == in_features): + #Weight shift only if group_size==in_features else (3) + self.assertTrue((gemlite_linear.W_group_mode == 1 and gemlite_linear.channel_scale_mode == 1) or + (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 0)) + else: + #weight scaling only - post accumulator if group_size==in_features else (0) + self.assertTrue(gemlite_linear.W_group_mode in [3, 4]) + self.assertTrue(gemlite_linear.channel_scale_mode == 0) + + def input_fn(batch_size): + return (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype) + + def ref_fn(x): + return torch.matmul(x.to(compute_dtype), W.T) + + self.eval(gemlite_linear, ref_fn, tol=5e-3, input_fn=input_fn) #needs higher tolerance with fp8 \ No newline at end of file From 23d88d91f395ba7edd103e5d157b4c389a0f25b2 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 3 Mar 2026 02:37:50 -0800 Subject: [PATCH 15/63] add eval flops script --- examples/eval_flops.py | 266 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 266 insertions(+) create mode 100644 examples/eval_flops.py diff --git a/examples/eval_flops.py b/examples/eval_flops.py new file mode 100644 index 0000000..d28bacd --- /dev/null +++ b/examples/eval_flops.py @@ -0,0 +1,266 @@ +import torch +import time, gc +import gemlite +from gemlite.helper import * +import argparse +import torch._dynamo +torch._dynamo.config.recompile_limit = 256 + +device, dtype = 'cuda:0', torch.bfloat16 +repeat = 32 + +#gemlite.reset_cache() +#gemlite.set_autotune("max") +#gemlite.core.enable_activation_scaling(2) + +def get_model(K, N, repeat=repeat): + torch.manual_seed(0) + model = torch.nn.Sequential(*[ + torch.nn.Linear(N, K, dtype=dtype, device=device, bias=False) + for _ in range(repeat) + ]) + model.requires_grad_(False) + return model + + +@torch.no_grad() +def eval_model(model, M, K, iters=50, verbose=False): + torch.manual_seed(0) + t = [] + for i in range(iters): + x = torch.randn(M, K, dtype=dtype, device=device) + torch.cuda.synchronize() + t1 = time.perf_counter() + out = model(x) + torch.cuda.synchronize() + t2 = time.perf_counter() + _time = (t2 - t1) * 1000 + t.append(_time) + if verbose: + print(f"Took: {_time} ms") + t = t[-(iters // 2):] + time_torch = (sum(t) / len(t)) + return time_torch + + +def get_flops(M, K, N, perf_time_ms): + flops_per_linear = 2 * M * N * K + tflops = flops_per_linear / (perf_time_ms * 1e-3) / 1e12 + return tflops + + +def cleanup(model): + del model + torch.cuda.empty_cache() + gc.collect() + torch.cuda.empty_cache() + + +########################################################################################################################### +# Pytorch INT8 dynamic reference +########################################################################################################################### +class NativePyTorchINT8Dynamic(torch.nn.Module): + def __init__(self, linear_layer): + super().__init__() + w_fp16 = linear_layer.weight.data + self.w_scales = w_fp16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-5) / 127.0 + w_int8 = torch.round(w_fp16 / self.w_scales).to(torch.int8) + self.w_int8 = w_int8.contiguous() + self.w_scales = self.w_scales.view(1, -1) + + def forward(self, x): + x_scales = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5) / 127.0 + x_int8 = torch.round(x / x_scales).to(torch.int8) + out_int32 = torch._int_mm(x_int8, self.w_int8.t()) + return out_int32.to(x.dtype) * (x_scales * self.w_scales) + + +def patch_model_native_int8(model): + for i, layer in enumerate(model): + if isinstance(layer, torch.nn.Linear): + model[i] = NativePyTorchINT8Dynamic(layer) + + +########################################################################################################################### +# Pytorch FP8 dynamic reference +########################################################################################################################### +def _to_fp8_and_inv_scale( + x: torch.Tensor, + fp8_dtype: torch.dtype, + dim: int | tuple[int, ...] | None, + keepdim: bool, + clamp_min: float = 1e-12, +): + finfo = torch.finfo(fp8_dtype) + x_fp32 = x.float() + if dim is None: + amax = x_fp32.abs().amax().clamp(min=clamp_min) + else: + amax = x_fp32.abs().amax(dim=dim, keepdim=keepdim).clamp(min=clamp_min) + + scale_gain = (finfo.max / amax) + x_scaled_sat = (x_fp32 * scale_gain).clamp(min=finfo.min, max=finfo.max) + x_fp8 = x_scaled_sat.to(fp8_dtype) + inv_scale = scale_gain.reciprocal().to(torch.float32) + return x_fp8, inv_scale + + +class NativePyTorchFP8Dynamic(torch.nn.Module): + def __init__( + self, + linear_layer: torch.nn.Linear, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + use_fast_accum: bool = False, + ): + super().__init__() + self.fp8_dtype = fp8_dtype + self.use_fast_accum = use_fast_accum + + w_hp = linear_layer.weight.data + w_fp8, w_inv_scale_row = _to_fp8_and_inv_scale(w_hp, fp8_dtype=fp8_dtype, dim=1, keepdim=True) + self.register_buffer("w_fp8", w_fp8.contiguous().t()) + self.register_buffer("w_inv_scale", w_inv_scale_row.view(1, -1).contiguous()) + + if linear_layer.bias is not None: + self.register_buffer("bias", linear_layer.bias.data.contiguous()) + else: + self.bias = None + + def forward(self, x: torch.Tensor): + x_fp8, x_inv_scale = _to_fp8_and_inv_scale(x, fp8_dtype=self.fp8_dtype, dim=-1, keepdim=True) + out = torch._scaled_mm( + x_fp8, + self.w_fp8, + scale_a=x_inv_scale, + scale_b=self.w_inv_scale, + bias=self.bias, + out_dtype=x.dtype, + use_fast_accum=self.use_fast_accum, + ) + if isinstance(out, tuple): + out = out[0] + return out + + +def patch_model_native_fp8(model, fp8_dtype=torch.float8_e4m3fn, use_fast_accum=False): + for i, layer in enumerate(model): + if isinstance(layer, torch.nn.Linear): + model[i] = NativePyTorchFP8Dynamic( + layer, fp8_dtype=fp8_dtype, use_fast_accum=use_fast_accum + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate TFLOP/s for various quantized matmul processors.", + epilog=""" +Examples: + # Run with default parameters + python eval_flops.py + + # Run with specific dimensions: + python eval_flops.py --M 128 --K 4096 --N 4096 + + # Run only specific processors (comma-separated): + python eval_flops.py --processor A16W8_INT8,A8W8_FP8_dynamic + + # Run only BF16 baseline (no quantization): + python eval_flops.py --processor none + + # Available processors: + # A16W8_INT8, A16W8_FP8, A8W8_INT8_dynamic, A8W8_FP8_dynamic, + # A8W8_MXFP_dynamic_post_scale, A8W8_MXFP_dynamic_no_post_scale, + # A4W4_MXFP_dynamic, A4W4_NVFP_dynamic, none (BF16 baseline) + # Use "all" to run every processor. + """, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--M", type=int, default=8192, help="Batch/sequence dimension") + parser.add_argument("--K", type=int, default=16384, help="Input feature dimension") + parser.add_argument("--N", type=int, default=16384, help="Output feature dimension") + parser.add_argument("--processor", type=str, default="all", + help='Comma-separated processor names or "all" (default: all)') + args = parser.parse_args() + + M, K, N = args.M, args.K, args.N + + PROCESSOR_MAP = { + "A16W8_INT8": lambda: A16W8_INT8(), + "A16W8_FP8": lambda: A16W8_FP8(), + "A8W8_INT8_dynamic": lambda: A8W8_INT8_dynamic(), + "A8W8_FP8_dynamic": lambda: A8W8_FP8_dynamic(), + "A8W8_MXFP_dynamic_post_scale": lambda: A8W8_MXFP_dynamic(dtype=dtype, post_scale=True), + "A8W8_MXFP_dynamic": lambda: A8W8_MXFP_dynamic(dtype=dtype, post_scale=False), + "A4W4_MXFP_dynamic": lambda: A4W4_MXFP_dynamic(dtype=dtype), + "A4W4_NVFP_dynamic": lambda: A4W4_NVFP_dynamic(dtype=dtype), + "none": lambda: None, + } + + if args.processor == "all": + processor_names = list(PROCESSOR_MAP.keys()) + else: + processor_names = [p.strip() for p in args.processor.split(",")] + + results = [] + + # ---- GemLite processors ---- + for proc_name in processor_names: + if proc_name not in PROCESSOR_MAP: + print(f"Unknown processor: {proc_name}, skipping.") + continue + + procesor = PROCESSOR_MAP[proc_name]() + + model = get_model(K, N, repeat=repeat) + if procesor is not None: + patch_model(model, device=device, processor=procesor) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M, K) / repeat + label = proc_name if procesor is not None else "FP16 (no processor)" + tflops = get_flops(M, K, N, perf_time_ms) + print(f"Processor: {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + results.append((label, M, K, N, tflops)) + + cleanup(model) + + # ---- PyTorch Native INT8 dynamic reference ---- + if M >= 16: + model = get_model(K, N, repeat=repeat) + patch_model_native_int8(model) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M, K) / repeat + tflops = get_flops(M, K, N, perf_time_ms) + print(f"PyTorch Native INT8 | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + results.append(("PyTorch Native INT8", M, K, N, tflops)) + + cleanup(model) + else: + print(f"Skipping PyTorch Native INT8 for M={M} (requires M >= 16).") + + # ---- PyTorch Native FP8 dynamic reference ---- + model = get_model(K, N, repeat=repeat) + patch_model_native_fp8(model, fp8_dtype=torch.float8_e4m3fn, use_fast_accum=False) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M, K) / repeat + tflops = get_flops(M, K, N, perf_time_ms) + print(f"PyTorch Native FP8 | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + results.append(("PyTorch Native FP8", M, K, N, tflops)) + + cleanup(model) + + # ---- Summary ---- + print("\n" + "=" * 70) + gpu_name = torch.cuda.get_device_name(device) + print(f"SUMMARY (GPU: {gpu_name})") + print("=" * 70) + max_label_len = max(len(r[0]) for r in results) if results else 0 + for label, m, k, n, tflops in results: + print(f" {label:<{max_label_len}} | {m}, {k}, {n} | {tflops:.2f} TFLOP/s") + print("=" * 70) + + +if __name__ == "__main__": + main() From 31c8c40e2d89bb928e2c4cb0437d294532a851e8 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 3 Mar 2026 02:38:11 -0800 Subject: [PATCH 16/63] update store masking --- gemlite/triton_kernels/gemm_kernels.py | 11 ++++-- gemlite/triton_kernels/gemm_splitK_kernels.py | 35 ++++++++++++------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index bd0aa40..09f88e7 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -628,7 +628,11 @@ def gemm_INT_kernel_persistent_tma( ############################################################################################################# # Store c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn - tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :]) + mask = (m_mask[:, None] & n_mask[None, :]).to(tl.int1) + if EVEN_M and EVEN_N: + tl.store(c_ptrs, acc) + else: + tl.store(c_ptrs, acc, mask=mask) @triton.autotune( configs = get_autotune_config(), @@ -850,7 +854,10 @@ def gemm_MX_kernel( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) - tl.store(c_ptrs, acc, mask=mask) + if EVEN_M and EVEN_N: + tl.store(c_ptrs, acc) + else: + tl.store(c_ptrs, acc, mask=mask) PRINTED = False diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 7c8adf3..081e7eb 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -87,9 +87,9 @@ def kernel_config_pruner(configs, nargs, **kwargs): block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) - #Constraint: K needs to be divisible by BLOCK_SIZE_K * SPLIT_K - while split_k > 1 and not is_divisible(k, block_size_k * split_k): - split_k //= 2 + # #Constraint: K needs to be divisible by BLOCK_SIZE_K * SPLIT_K + # while split_k > 1 and not is_divisible(k, block_size_k * split_k): + # split_k //= 2 #Nvidia if not IS_HIP: @@ -100,8 +100,8 @@ def kernel_config_pruner(configs, nargs, **kwargs): #skip num_stages=1 for non-packed weights continue - # #Avoid OOM - # while num_stages > 0: #TODO: revisit MXFP case + # #Avoid OOM: TODO: come up with a better logic, this is too conservative. + # while num_stages > 0: # shared_mem = (block_size_m * block_size_k * a_sizeof + block_size_k * block_size_n * b_sizeof) # if(e > 1 and not load_scales_as_block): # shared_mem += block_size_k * block_size_n * a_sizeof @@ -114,7 +114,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): split_k = max(split_k, 1) ########################################### - if(load_scales_as_block):#tmp MXFP fix + if(load_scales_as_block): #TODO: tmp MXFP TMA fix block_size_k = min(block_size_k, 256) ########################################### @@ -124,7 +124,6 @@ def kernel_config_pruner(configs, nargs, **kwargs): even_n = (n % block_size_n == 0) even_k = (k % block_size_k == 0) - new_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, @@ -478,12 +477,18 @@ def gemm_splitK_INT_kernel( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) if(SPLIT_K > 1): - tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) + if EVEN_M and EVEN_N: + tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + else: + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) else: - tl.store(c_ptrs, acc, mask=mask) + if EVEN_M and EVEN_N: + tl.store(c_ptrs, acc) + else: + tl.store(c_ptrs, acc, mask=mask) @triton.autotune( configs=get_autotune_config(), @@ -645,9 +650,15 @@ def gemm_splitK_MX_kernel( mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) if(SPLIT_K > 1): - tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) + if EVEN_M and EVEN_N: + tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + else: + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) else: - tl.store(c_ptrs, acc, mask=mask) + if EVEN_M and EVEN_N: + tl.store(c_ptrs, acc) + else: + tl.store(c_ptrs, acc, mask=mask) def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, From 488e4945e45190b92b14b275579c6a3e342f324d Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 03:13:12 -0800 Subject: [PATCH 17/63] update configs --- examples/eval_flops.py | 2 + gemlite/triton_kernels/gemm_kernels.py | 53 ++++++++----------- gemlite/triton_kernels/gemm_splitK_kernels.py | 51 +++++++----------- gemlite/triton_kernels/utils.py | 20 ++++++- tests/test_mxfp.py | 24 ++++----- 5 files changed, 73 insertions(+), 77 deletions(-) diff --git a/examples/eval_flops.py b/examples/eval_flops.py index d28bacd..6a12286 100644 --- a/examples/eval_flops.py +++ b/examples/eval_flops.py @@ -187,6 +187,7 @@ def main(): PROCESSOR_MAP = { "A16W8_INT8": lambda: A16W8_INT8(), "A16W8_FP8": lambda: A16W8_FP8(), + "A16W4_HQQ_INT": lambda: A16W4_HQQ_INT(), "A8W8_INT8_dynamic": lambda: A8W8_INT8_dynamic(), "A8W8_FP8_dynamic": lambda: A8W8_FP8_dynamic(), "A8W8_MXFP_dynamic_post_scale": lambda: A8W8_MXFP_dynamic(dtype=dtype, post_scale=True), @@ -194,6 +195,7 @@ def main(): "A4W4_MXFP_dynamic": lambda: A4W4_MXFP_dynamic(dtype=dtype), "A4W4_NVFP_dynamic": lambda: A4W4_NVFP_dynamic(dtype=dtype), "none": lambda: None, + "fp16": lambda: None, } if args.processor == "all": diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 09f88e7..cb0bbb3 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -79,43 +79,32 @@ def kernel_config_pruner(configs, nargs, **kwargs): block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) - ###################################################### # # FOR TMA # if block_size_n % 128 > 0: # block_size_n = 128 # if block_size_k % 128 > 0: # block_size_k = 128 + + if(load_scales_as_block): #tmp MXFP fix with TMA + block_size_k = max(block_size_k, 64) + block_size_n = max(block_size_n, 64) ###################################################### #Hint: skip block_size_n > block_size_k for col-major non-packed data. - # #Nvidia - # if not IS_HIP: - # if e > 1 and not load_scales_as_block: - # #Limit num stages when data is packed - # num_stages = min(num_stages, 4) - # if(e == 1 and num_stages == 1): - # #skip num_stages=1 for non-packed weights - # continue - - # #Avoid OOM - # while num_stages > 0 and not load_scales_as_block: #TODO: revisit MXFP case - # shared_mem = (block_size_m * block_size_k * a_sizeof + block_size_k * block_size_n * b_sizeof) - # if(e > 1): - # shared_mem += block_size_k * block_size_n * a_sizeof - # shared_mem *= num_stages - # if int(shared_mem) <= gpu_shared_memory: - # break - # num_stages -= 1 - - if(num_stages == 0): continue #config too large - - ########################################### - if(load_scales_as_block):#tmp MXFP fix with TMA - block_size_k = max(block_size_k, 64) - block_size_n = max(block_size_n, 64) - ########################################### + if not IS_HIP: + if e == 1 and num_stages == 1: + continue + + # Prune configs that exceed shared memory + estimated_smem = estimate_shared_memory_per_block( + block_size_m, block_size_n, block_size_k, + a_sizeof, b_sizeof, num_stages, e, g, + load_scales_as_block + ) + if estimated_smem > gpu_shared_memory: + continue key = (block_size_m, block_size_n, block_size_k, group_size_m, A_load_order, num_stages, num_warps) @@ -185,10 +174,8 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=3)) - # configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) @@ -197,9 +184,14 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + #MXFP/NVFP configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) return configs def get_default_config_nvidia(): @@ -451,8 +443,7 @@ def gemm_INT_kernel( if(channel_scale_mode == 2): #activation-only scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + acc = acc.to(meta_dtype) * scales_a[:, None] if(channel_scale_mode == 3): #weight + activation scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 081e7eb..79301ae 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -76,7 +76,8 @@ def kernel_config_pruner(configs, nargs, **kwargs): #Constraint: BLOCK_SIZE_K >= group_size, only for load_as_block = False if(load_scales_as_block): - num_stages = max(num_stages, 2) #for dot_scaled kernels with pipelined loads + #num_stages = max(num_stages, 2) #for dot_scaled kernels with pipelined loads + block_size_k = min(block_size_k, 256) #TODO: tmp MXFP TMA fix if(e > 1): block_size_k = max(block_size_k, 64) #m16n8k64 else: @@ -86,37 +87,20 @@ def kernel_config_pruner(configs, nargs, **kwargs): block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) - - # #Constraint: K needs to be divisible by BLOCK_SIZE_K * SPLIT_K - # while split_k > 1 and not is_divisible(k, block_size_k * split_k): - # split_k //= 2 - - #Nvidia + split_k = max(split_k, 1) + if not IS_HIP: - if e > 1 and not load_scales_as_block: - #Limit num stages when data is packed - num_stages = min(num_stages, 4) - if(e == 1 and num_stages == 1): - #skip num_stages=1 for non-packed weights + if e == 1 and num_stages == 1: continue - - # #Avoid OOM: TODO: come up with a better logic, this is too conservative. - # while num_stages > 0: - # shared_mem = (block_size_m * block_size_k * a_sizeof + block_size_k * block_size_n * b_sizeof) - # if(e > 1 and not load_scales_as_block): - # shared_mem += block_size_k * block_size_n * a_sizeof - # shared_mem *= num_stages - # if int(shared_mem) <= gpu_shared_memory: - # break - # num_stages -= 1 - - # if(num_stages == 0): continue #config too large - - split_k = max(split_k, 1) - ########################################### - if(load_scales_as_block): #TODO: tmp MXFP TMA fix - block_size_k = min(block_size_k, 256) - ########################################### + + # Prune configs that exceed shared memory + estimated_smem = estimate_shared_memory_per_block( + block_size_m, block_size_n, block_size_k, + a_sizeof, b_sizeof, num_stages, e, g, + load_scales_as_block + ) + if estimated_smem > gpu_shared_memory: + continue key = (block_size_m, block_size_n, block_size_k, group_size_m, split_k, A_load_order, num_stages, num_warps) @@ -203,6 +187,8 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':512, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':32, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=2)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=2)) return configs def get_default_config_nvidia(): @@ -453,7 +439,7 @@ def gemm_splitK_INT_kernel( b_ptrs += BLOCK_SIZE_K_P * stride_bk if not EVEN_K: - a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K * SPLIT_K) < K)).to(tl.int1) ############################################################################################################# #Channel-wise scaling @@ -463,8 +449,7 @@ def gemm_splitK_INT_kernel( if(channel_scale_mode == 2): #activation-only scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + acc = acc.to(meta_dtype) * scales_a[:, None] if(channel_scale_mode == 3): #weight + activation scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) diff --git a/gemlite/triton_kernels/utils.py b/gemlite/triton_kernels/utils.py index 77a4c31..65a98b1 100755 --- a/gemlite/triton_kernels/utils.py +++ b/gemlite/triton_kernels/utils.py @@ -108,7 +108,7 @@ def is_divisible(dividend, divisor): def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" -def gpu_has_more_shared_memory(ref_gpus = ["a100", "h100", "h200", "h20", "h800", "b100", "b200"]): +def gpu_has_more_shared_memory(ref_gpus = ["a100", "h100", "h200", "h20", "h800", "b100", "b200", "b300", "6000"]): gpu_name = torch.cuda.get_device_properties(0).name.lower() return True in [g in gpu_name for g in ref_gpus] @@ -121,6 +121,24 @@ def gpu_supports_float16_acc( gpu_name = torch.cuda.get_device_properties(0).name.lower() return True in [g in gpu_name for g in ref_gpus] +def estimate_shared_memory_per_block(block_size_m, block_size_n, block_size_k, a_sizeof, b_sizeof, num_stages, e, g, load_scales_as_block): + a_smem = block_size_m * block_size_k * a_sizeof + if load_scales_as_block: + # MX kernels: dot_scaled handles scaling natively, no dequant buffer + b_smem = (block_size_k // e) * block_size_n * b_sizeof + # scales: (BLOCK_N, BLOCK_K // group_size) × meta_sizeof + s_smem = block_size_n * (block_size_k // g) * 1 # uint8 or e4m3 = 1 byte + estimated_smem = (a_smem + b_smem + s_smem) * max(num_stages - 1, 1) + elif e > 1: + # INT packed: need packed B + dequantized B for MMA + b_smem = (block_size_k // e) * block_size_n * b_sizeof + b_smem += block_size_k * block_size_n * a_sizeof + estimated_smem = int((a_smem + b_smem) * num_stages * 1.20) + else: + # INT unpacked (8-bit): exact formula + b_smem = block_size_k * block_size_n * b_sizeof + estimated_smem = (a_smem + b_smem) * max(num_stages - 1, 1) + return estimated_smem def gpu_supports_bfloat16_atomicadd(): #Triton tl.atomic_add doens't support bfloat16 on older GPUs. diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index 8dd4e33..cdec61c 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -21,7 +21,7 @@ def is_fp8_supported(device_index=0): torch.random.manual_seed(0) in_features, out_features = 4032, 2048 -batch_sizes = [1]#[1, 30, 32, 60, 100, 128] +batch_sizes = [1, 30, 32, 60, 100, 128] linear_layer = torch.nn.Linear(in_features=in_features, out_features=out_features, device=device, dtype=compute_dtype, bias=False) linear_layer.weight.data /= 10. linear_layer.weight.requires_grad = False @@ -45,12 +45,12 @@ def eval(self, gemlite_linear, tol: float = 1e-3): err = (y_ref - y_gem).abs().mean().item() self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) - @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_A16W8_MXFP(self): - gemlite_linear = A16W8_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) - self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features)) - self.assertTrue(not gemlite_linear.scaled_activations) - self.eval(gemlite_linear, tol = 2e-4) + # @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + # def test_A16W8_MXFP(self): + # gemlite_linear = A16W8_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + # self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features)) + # self.assertTrue(not gemlite_linear.scaled_activations) + # self.eval(gemlite_linear, tol = 2e-4) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_A8W8_MXFP_post_scale_dynamic(self): @@ -66,11 +66,11 @@ def test_A8W8_MXFP_dynamic(self): self.assertTrue(gemlite_linear.scaled_activations) self.eval(gemlite_linear, tol = 2e-4) - def test_A16W4_MXFP(self): - gemlite_linear = A16W4_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) - self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) - self.assertTrue(not gemlite_linear.scaled_activations) - self.eval(gemlite_linear, tol = 7e-4) + # def test_A16W4_MXFP(self): + # gemlite_linear = A16W4_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + # self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) + # self.assertTrue(not gemlite_linear.scaled_activations) + # self.eval(gemlite_linear, tol = 7e-4) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_A8W4_MXFP_dynamic(self): From afe4105b11de8bc3b731d7101f537e8bcb418573 Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 03:19:56 -0800 Subject: [PATCH 18/63] update configs --- gemlite/triton_kernels/gemm_kernels.py | 35 ++++++++----------- gemlite/triton_kernels/gemm_splitK_kernels.py | 12 ++++--- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index cb0bbb3..78eff89 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -66,30 +66,25 @@ def kernel_config_pruner(configs, nargs, **kwargs): elif m <= 128: block_size_m = min(max(block_size_m, 64), 128) #m: [64...128] elif m <= 256: block_size_m = min(max(block_size_m, 64), 256) #m: [128...256] elif m > 256: block_size_m = min(max(block_size_m, 64), 256) #m > 256 - - # #Constraint: BLOCK_SIZE_K >= group_size, only for load_as_block = False - # if(load_scales_as_block): - # num_stages = max(num_stages, 2) #for dot_scaled kernels with pipelined loads - # if(e > 1): - # block_size_k = max(block_size_k, 64) #m16n8k64 - # else: - # block_size_k = max(block_size_k, 32) #m16n8k32 - # else: - # block_size_k = min(block_size_k, g) block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) - ###################################################### - # # FOR TMA - # if block_size_n % 128 > 0: - # block_size_n = 128 - # if block_size_k % 128 > 0: - # block_size_k = 128 - if(load_scales_as_block): #tmp MXFP fix with TMA - block_size_k = max(block_size_k, 64) - block_size_n = max(block_size_n, 64) - ###################################################### + #Constraints + if(load_scales_as_block): + # FOR TMA + # block_size_k = min(block_size_k, 256) #TODO: tmp MXFP TMA fix + # if block_size_n % 128 > 0: + # block_size_n = 128 + # if block_size_k % 128 > 0: + # block_size_k = 128 + if(e > 1): + block_size_k = max(block_size_k, 64) #m16n8k64 + else: + block_size_k = max(block_size_k, 32) #m16n8k32 + else: + block_size_k = min(block_size_k, g) + #Hint: skip block_size_n > block_size_k for col-major non-packed data. diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 79301ae..daed70a 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -74,10 +74,14 @@ def kernel_config_pruner(configs, nargs, **kwargs): #Only use higher split_k values for smaller m if(m >= 32): split_k = min(split_k, 8) - #Constraint: BLOCK_SIZE_K >= group_size, only for load_as_block = False - if(load_scales_as_block): - #num_stages = max(num_stages, 2) #for dot_scaled kernels with pipelined loads - block_size_k = min(block_size_k, 256) #TODO: tmp MXFP TMA fix + #Constraints + if(load_scales_as_block): + # FOR TMA + # block_size_k = min(block_size_k, 256) #TODO: tmp MXFP TMA fix + # if block_size_n % 128 > 0: + # block_size_n = 128 + # if block_size_k % 128 > 0: + # block_size_k = 128 if(e > 1): block_size_k = max(block_size_k, 64) #m16n8k64 else: From 7aee3611f132a19ee41d06d52482e9d0a2eef1b3 Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 04:00:55 -0800 Subject: [PATCH 19/63] update configs --- gemlite/core.py | 3 ++- gemlite/triton_kernels/gemm_kernels.py | 5 ++--- gemlite/triton_kernels/gemm_splitK_kernels.py | 6 +++--- gemlite/triton_kernels/gemv_kernels.py | 13 +++++++------ gemlite/triton_kernels/gemv_revsplitK_kernels.py | 15 ++++++++------- tests/test_gemlitelineartriton.py | 2 +- tests/test_mxfp.py | 2 +- 7 files changed, 24 insertions(+), 22 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 2b559c5..13cdd0d 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -108,7 +108,8 @@ def get_default_gemv(W_nbits: int, mx_dtype: bool = False) -> str: def get_matmul_type(batch_size: int, W_nbits: int, mx_dtype: bool = False): if batch_size > 64: return "GEMM" - if batch_size > 1: + gemv_limit = 4 if (W_nbits < 8 and not mx_dtype) else 2 # previous 1 + if batch_size > gemv_limit: return "GEMM_SPLITK" else: return get_default_gemv(W_nbits, mx_dtype) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 78eff89..f04ead4 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -85,7 +85,6 @@ def kernel_config_pruner(configs, nargs, **kwargs): else: block_size_k = min(block_size_k, g) - #Hint: skip block_size_n > block_size_k for col-major non-packed data. if not IS_HIP: @@ -600,11 +599,11 @@ def gemm_INT_kernel_persistent_tma( scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * scales_b[None, :] - elif channel_scale_mode == 2: # activation-only + if channel_scale_mode == 2: # activation-only scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * scales_a[:, None] - elif channel_scale_mode == 3: # weight + activation + if channel_scale_mode == 3: # weight + activation scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index daed70a..7425906 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -447,15 +447,15 @@ def gemm_splitK_INT_kernel( ############################################################################################################# #Channel-wise scaling - if(channel_scale_mode == 1): #weight-only + if channel_scale_mode == 1: #weight-only scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * scales_b[None, :] - if(channel_scale_mode == 2): #activation-only + if channel_scale_mode == 2: #activation-only scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * scales_a[:, None] - if(channel_scale_mode == 3): #weight + activation + if channel_scale_mode == 3: #weight + activation scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) diff --git a/gemlite/triton_kernels/gemv_kernels.py b/gemlite/triton_kernels/gemv_kernels.py index 4bcf5fa..538b4aa 100755 --- a/gemlite/triton_kernels/gemv_kernels.py +++ b/gemlite/triton_kernels/gemv_kernels.py @@ -49,7 +49,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): config.pop('num_consumer_groups', None) config.pop('reg_dec_producer', None) config.pop('reg_inc_consumer', None) - configs['NUM_STAGES'] = num_stages + config['NUM_STAGES'] = num_stages yield triton.Config(config, num_stages=num_stages, num_warps=num_warps, pre_hook=pre_hook) return @@ -140,6 +140,8 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':64, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=2)) configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':64, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=2, num_stages=1)) + + configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':1024,'BLOCK_SIZE_K':32, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=1)) return configs @@ -377,16 +379,15 @@ def gemv_INT_kernel( ################################################################## #Channel-wise scaling - if(channel_scale_mode == 1): #weight-only + if channel_scale_mode == 1: #weight-only scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * scales_b[None, :] - if(channel_scale_mode == 2): #activation-only + if channel_scale_mode == 2: #activation-only scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + acc = acc.to(meta_dtype) * scales_a[:, None] - if(channel_scale_mode == 3): #weight + activation + if channel_scale_mode == 3: #weight + activation scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) diff --git a/gemlite/triton_kernels/gemv_revsplitK_kernels.py b/gemlite/triton_kernels/gemv_revsplitK_kernels.py index f085687..a667ff9 100755 --- a/gemlite/triton_kernels/gemv_revsplitK_kernels.py +++ b/gemlite/triton_kernels/gemv_revsplitK_kernels.py @@ -44,7 +44,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): used = set() for config in configs: - block_size_m = 1 #Only 1 allowed here + block_size_m = 1 #next_power_of_2(m) #Only 1 allowed here block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) split_k = 2 @@ -142,6 +142,8 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':16, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=2)) configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':32, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=1)) configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':64, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=2)) + + configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':1024, 'BLOCK_SIZE_K':32, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=1)) return configs @@ -374,16 +376,15 @@ def gemv_INT_revsplitK_kernel( if(dump_b_val > 0): acc /= dump_b_val ############################################################################################################ #Channel-wise scaling - if(channel_scale_mode == 1): #weight-only + if channel_scale_mode == 1: #weight-only scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * scales_b[None, :] - if(channel_scale_mode == 2): #activation-only + if channel_scale_mode == 2: #activation-only scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) - - if(channel_scale_mode == 3): #weight + activation + acc = acc.to(meta_dtype) * scales_a[:, None] + + if channel_scale_mode == 3: #weight + activation scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index f95013d..d8c4936 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -23,7 +23,7 @@ def is_fp8_supported(): KERNEL.ENABLE_CACHING = False in_features, out_features = 4032, 2032 -batch_sizes = [1, 5, 100] +batch_sizes = [1, 3, 5, 100] W_nbits, group_size = 4, 128 #128 / in_features if group_size is None: diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index cdec61c..cace230 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -21,7 +21,7 @@ def is_fp8_supported(device_index=0): torch.random.manual_seed(0) in_features, out_features = 4032, 2048 -batch_sizes = [1, 30, 32, 60, 100, 128] +batch_sizes = [1, 3, 30, 32, 60, 100, 128] linear_layer = torch.nn.Linear(in_features=in_features, out_features=out_features, device=device, dtype=compute_dtype, bias=False) linear_layer.weight.data /= 10. linear_layer.weight.requires_grad = False From 1869027126ae82253264d02cdbc820748eb8c8ba Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 04:18:55 -0800 Subject: [PATCH 20/63] fix --- examples/eval_flops.py | 2 +- tests/test_gemlitelineartriton.py | 2 +- tests/test_mxfp.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/eval_flops.py b/examples/eval_flops.py index 6a12286..acf49df 100644 --- a/examples/eval_flops.py +++ b/examples/eval_flops.py @@ -227,7 +227,7 @@ def main(): cleanup(model) # ---- PyTorch Native INT8 dynamic reference ---- - if M >= 16: + if M > 16: model = get_model(K, N, repeat=repeat) patch_model_native_int8(model) model = torch.compile(model, mode="reduce-overhead", fullgraph=True) diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index d8c4936..0076e61 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -23,7 +23,7 @@ def is_fp8_supported(): KERNEL.ENABLE_CACHING = False in_features, out_features = 4032, 2032 -batch_sizes = [1, 3, 5, 100] +batch_sizes = [1, 3, 5, 16, 30, 65, 100, 250] W_nbits, group_size = 4, 128 #128 / in_features if group_size is None: diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index cace230..0d440cb 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -21,7 +21,7 @@ def is_fp8_supported(device_index=0): torch.random.manual_seed(0) in_features, out_features = 4032, 2048 -batch_sizes = [1, 3, 30, 32, 60, 100, 128] +batch_sizes = [1, 3, 16, 30, 32, 60, 100, 128] linear_layer = torch.nn.Linear(in_features=in_features, out_features=out_features, device=device, dtype=compute_dtype, bias=False) linear_layer.weight.data /= 10. linear_layer.weight.requires_grad = False From 48d03e8c39e498ffb17e821119699f41a3f1313f Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 04:36:05 -0800 Subject: [PATCH 21/63] fix --- gemlite/quant_utils.py | 116 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 115 insertions(+), 1 deletion(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index b876849..c0bb26c 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -530,6 +530,120 @@ def scale_activations_per_token_triton_v3( return y.view(x_shape), scales + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 1, "BLOCK_SIZE_K": 1024}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE_M": 1, "BLOCK_SIZE_K": 2048}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE_M": 1, "BLOCK_SIZE_K": 4096}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE_M": 2, "BLOCK_SIZE_K": 1024}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE_M": 2, "BLOCK_SIZE_K": 2048}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE_M": 2, "BLOCK_SIZE_K": 4096}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE_M": 4, "BLOCK_SIZE_K": 1024}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE_M": 4, "BLOCK_SIZE_K": 2048}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_SIZE_M": 4, "BLOCK_SIZE_K": 4096}, num_warps=8, num_stages=2), + ], + key=["M", "K"], +) +@triton.jit +def scale_activations_per_token_triton_v4_kernel( + tensor_ptr, + scale_ptr, + y_ptr, + M, + K, + stride_m: tl.constexpr, + stride_k: tl.constexpr, + stride_sm: tl.constexpr, + min_val: tl.constexpr, + max_val: tl.constexpr, + fp32_scale: tl.constexpr, + ROUND: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + start_pid = tl.program_id(0) + num_programs = tl.num_programs(0) + num_tiles = tl.cdiv(M, BLOCK_SIZE_M) + + for pid_m in range(start_pid, num_tiles, num_programs): + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + m_mask = offs_m < M + + # Pass 1: streaming amax over K chunks + row_max = tl.zeros([BLOCK_SIZE_M], dtype=tl.float32) + for k_start in range(0, K, BLOCK_SIZE_K): + offs_k = k_start + tl.arange(0, BLOCK_SIZE_K) + k_mask = offs_k < K + mask = m_mask[:, None] & k_mask[None, :] + chunk = tl.load( + tensor_ptr + offs_m[:, None] * stride_m + offs_k[None, :] * stride_k, + mask=mask, + other=0.0, + ) + if fp32_scale: + chunk = chunk.to(tl.float32) + row_max = tl.maximum(row_max, tl.max(tl.abs(chunk), axis=1)) + + scales_x = row_max / max_val + scales_x = tl.maximum(scales_x, 1e-6) + tl.store(scale_ptr + offs_m * stride_sm, scales_x, mask=m_mask) + + # Pass 2: scale, clamp, store + inv_scales = 1.0 / scales_x + for k_start in range(0, K, BLOCK_SIZE_K): + offs_k = k_start + tl.arange(0, BLOCK_SIZE_K) + k_mask = offs_k < K + mask = m_mask[:, None] & k_mask[None, :] + offsets = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k + chunk = tl.load(tensor_ptr + offsets, mask=mask, other=0.0) + if fp32_scale: + chunk = chunk.to(tl.float32) + chunk = chunk * inv_scales[:, None] + chunk = tl.minimum(tl.maximum(chunk, min_val), max_val) + if ROUND: + chunk = round_triton(chunk) + tl.store(y_ptr + offsets, chunk, mask=mask) + + +def scale_activations_per_token_triton_v4( + tensor: torch.Tensor, + w_dtype: torch.dtype, + fp32_scale: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + min_val, max_val = get_dtype_range(w_dtype) + + x_shape = tensor.shape + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + scales = torch.empty( + (M, 1), + dtype=torch.float32 if fp32_scale else tensor.dtype, + device=tensor.device, + ) + y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) + + grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"])),) + + ROUND = not w_dtype.is_floating_point + + scale_activations_per_token_triton_v4_kernel[grid]( + tensor, + scales, + y, + M, + K, + tensor.stride(0), + tensor.stride(1), + scales.stride(0), + min_val=min_val, + max_val=max_val, + fp32_scale=fp32_scale, + ROUND=ROUND, + ) + + return y.view(x_shape), scales #################################################################################################################### #MXFP8 #################################################################################################################### @@ -1245,7 +1359,7 @@ def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, return out, scales #################################################################################################################### -scale_activations_per_token = scale_activations_per_token_triton_v3 +scale_activations_per_token = scale_activations_per_token_triton_v4 scale_activations_mxfp8 = scale_activations_mxfp8_triton_v3 scale_activations_mxfp4 = scale_activations_mxfp4_triton scale_activations_nvfp4 = scale_activations_nvfp4_triton From dcd6c79a7bf2e39dd2c2f2afccd202e7dde7297d Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 04:36:22 -0800 Subject: [PATCH 22/63] fix --- gemlite/quant_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index c0bb26c..05747fb 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -1359,7 +1359,7 @@ def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, return out, scales #################################################################################################################### -scale_activations_per_token = scale_activations_per_token_triton_v4 +scale_activations_per_token = scale_activations_per_token_triton_v3 scale_activations_mxfp8 = scale_activations_mxfp8_triton_v3 scale_activations_mxfp4 = scale_activations_mxfp4_triton scale_activations_nvfp4 = scale_activations_nvfp4_triton From 16a507858b9f5323592c5c823c92189a5831fdd6 Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 05:31:27 -0800 Subject: [PATCH 23/63] fix --- gemlite/triton_kernels/gemm_kernels.py | 9 ++++----- gemlite/triton_kernels/gemm_splitK_kernels.py | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index f04ead4..6c01bf5 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -817,12 +817,11 @@ def gemm_MX_kernel( acc *= meta_scale_norm ############################################################################################################# - #Channel-wise scaling - if(channel_scale_mode == 2): #activation-only + #Channel-wise scaling + if channel_scale_mode == 2: # activation-only dtype: tl.constexpr = c_ptr.dtype.element_ty - scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=dtype) - acc = acc.to(dtype) * (scales_a[:, None] * scales_b[None, :]) + scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1.0, eviction_policy=meta_evict_policy) + acc = acc.to(dtype) * scales_a[:, None] ############################################################################################################# #Output diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 7425906..eadbedb 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -624,12 +624,11 @@ def gemm_splitK_MX_kernel( acc *= meta_scale_norm ############################################################################################################# - #Channel-wise scaling - if(channel_scale_mode == 2): #activation-only + #Channel-wise scaling + if channel_scale_mode == 2: # activation-only dtype: tl.constexpr = c_ptr.dtype.element_ty - scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=dtype) - acc = acc.to(dtype) * (scales_a[:, None] * scales_b[None, :]) + scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1.0, eviction_policy=meta_evict_policy) + acc = acc.to(dtype) * scales_a[:, None] ############################################################################################################# #Output From ffbb01befec5ada2d89d8346fa1407945816e468 Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 05:35:15 -0800 Subject: [PATCH 24/63] use tma for a,b,c mx --- gemlite/core.py | 16 ++++++++-------- gemlite/triton_kernels/gemm_kernels.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 13cdd0d..03c6e83 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -498,16 +498,16 @@ def pack( self.channel_scale_mode = 0 ################################ - # # TMA - # K, N = self.W_q.shape + # TMA + K, N = self.W_q.shape - # if(self.input_dtype in [DType.MXFP4, DType.NVFP4]): - # K *= 2 - # group_size = 2 * self.W_q.numel() // self.scales.numel() - # else: - # group_size = self.W_q.numel() // self.scales.numel() + if(self.input_dtype in [DType.MXFP4, DType.NVFP4]): + K *= 2 + group_size = 2 * self.W_q.numel() // self.scales.numel() + else: + group_size = self.W_q.numel() // self.scales.numel() - # self.W_q = self.W_q.contiguous().T #Transposed for tma + self.W_q = self.W_q.contiguous().T #Transposed for tma # #self.scales = self.scales.contiguous().T # Transposed 2D TMA layout # #self.scales = self.scales.reshape(1, N // 128, K // group_size // 4, 2, 256).contiguous() # 5D TMA layout for the scales: diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 6c01bf5..eddc147 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -671,7 +671,7 @@ def gemm_MX_kernel( b_evict: tl.constexpr = "", meta_scale_norm: tl.constexpr = (0.05 ** 2), ################################# - use_tma: tl.constexpr = False, + use_tma: tl.constexpr = True, ): pid = tl.program_id(axis=0) From a4a753a7387c8eb9fde45fa95d09a8b9ac808177 Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 05:44:03 -0800 Subject: [PATCH 25/63] use tma for a,b,c mx --- gemlite/triton_kernels/gemm_kernels.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index eddc147..a0ca2da 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -733,6 +733,7 @@ def gemm_MX_kernel( [stride_bn, stride_bk], [BLOCK_SIZE_N, BLOCK_SIZE_K_B] ) + # # 2. 5D TMA Descriptors for Scales: #(8388608, 65536, 512, 256, 1) torch.Size([1, 128, 128, 2, 256]) # rep_m: tl.constexpr = BLOCK_SIZE_M // 128 @@ -773,7 +774,7 @@ def gemm_MX_kernel( # Load A and B tiles if use_tma: a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_A]) - b = tl.load_tensor_descriptor(b_desc, [k * BLOCK_SIZE_K_B, pid_n * BLOCK_SIZE_N]).T + b = tl.load_tensor_descriptor(b_desc, [pid_n * BLOCK_SIZE_N, k * BLOCK_SIZE_K_B]).T else: if EVEN_M and EVEN_K: a = tl.load(a_ptrs, eviction_policy=a_evict) From 2b6ce21abcbacd1969e13fd6367bb341995a0d1e Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 06:57:53 -0800 Subject: [PATCH 26/63] update scales --- gemlite/core.py | 4 +- gemlite/triton_kernels/gemm_kernels.py | 38 ++++++++----- gemlite/triton_kernels/gemm_splitK_kernels.py | 55 ++++++++++++++++--- tests/test_mxfp.py | 2 +- 4 files changed, 73 insertions(+), 26 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 03c6e83..0a917d3 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -507,9 +507,9 @@ def pack( else: group_size = self.W_q.numel() // self.scales.numel() - self.W_q = self.W_q.contiguous().T #Transposed for tma + #self.W_q = self.W_q.contiguous().T #Transposed for tma - # #self.scales = self.scales.contiguous().T # Transposed 2D TMA layout + #self.scales = self.scales.contiguous().T # Transposed 2D TMA layout # #self.scales = self.scales.reshape(1, N // 128, K // group_size // 4, 2, 256).contiguous() # 5D TMA layout for the scales: # #print(self.scales.stride(), self.scales.shape) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index a0ca2da..13bd6eb 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -77,11 +77,12 @@ def kernel_config_pruner(configs, nargs, **kwargs): # if block_size_n % 128 > 0: # block_size_n = 128 # if block_size_k % 128 > 0: - # block_size_k = 128 + # block_size_k = 128 if(e > 1): block_size_k = max(block_size_k, 64) #m16n8k64 else: block_size_k = max(block_size_k, 32) #m16n8k32 + #block_size_k = max(block_size_k, 128) #TMA else: block_size_k = min(block_size_k, g) @@ -717,7 +718,6 @@ def gemm_MX_kernel( offs_n_b_scales = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] - if use_tma: a_desc = tl.make_tensor_descriptor( a_ptr, @@ -725,8 +725,7 @@ def gemm_MX_kernel( [stride_am, stride_ak], [BLOCK_SIZE_M, BLOCK_SIZE_K_A] ) - - # Transposed + b_desc = tl.make_tensor_descriptor( b_ptr, [N, K // elements_per_sample], @@ -734,7 +733,21 @@ def gemm_MX_kernel( [BLOCK_SIZE_N, BLOCK_SIZE_K_B] ) - + # 2D TMA - transposed + # scales_a_desc = tl.make_tensor_descriptor( + # scales_a_ptr, + # [M, K // group_size], + # [stride_meta_a_m, stride_meta_a_g], + # [BLOCK_SIZE_M, BLOCK_SIZE_K_S], + # ) + + scales_b_desc = tl.make_tensor_descriptor( + scales_ptr, + [K // group_size, N], + [stride_meta_g, stride_meta_n], + [BLOCK_SIZE_K_S, BLOCK_SIZE_N], + ) + # # 2. 5D TMA Descriptors for Scales: #(8388608, 65536, 512, 256, 1) torch.Size([1, 128, 128, 2, 256]) # rep_m: tl.constexpr = BLOCK_SIZE_M // 128 # rep_n: tl.constexpr = BLOCK_SIZE_N // 128 @@ -754,20 +767,13 @@ def gemm_MX_kernel( # [1, rep_n, rep_k, 2, 256] # ) - # # 2D TMA - transposed - # scales_b_desc = tl.make_tensor_descriptor( - # scales_ptr, - # [K // group_size, N], - # [stride_meta_g, stride_meta_n], - # [BLOCK_SIZE_K_S, BLOCK_SIZE_N], - # ) - #B-scales if(channel_scale_mode == 4): scales_a_ptrs = scales_a_ptr + offs_am[:, None] * stride_meta_a_m + offs_k_scales[None, :] * stride_meta_a_g # Used in channel-wise MXPF8 version scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k, num_stages=NUM_STAGES): @@ -788,7 +794,7 @@ def gemm_MX_kernel( scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) # # 2D TMA - # scales_b = tl.load_tensor_descriptor(scales_b_desc, [k * BLOCK_SIZE_K_S, pid_n * BLOCK_SIZE_N]).T + #scales_b = tl.load_tensor_descriptor(scales_b_desc, [k * BLOCK_SIZE_K_S, pid_n * BLOCK_SIZE_N]).T # 5D Scale Loads and Unpacking # offs_scale_m = pid_m * rep_m @@ -801,6 +807,7 @@ def gemm_MX_kernel( if(channel_scale_mode == 4): scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) + #scales_a = tl.load_tensor_descriptor(scales_a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_S]) else: scales_a = scales_a_1s #################################################################################### @@ -854,7 +861,8 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x global PRINTED - M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] + M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] # W + #M, K, N = x.shape[0], W_q.shape[1] * elements_per_sample, W_q.shape[0] #W.T M_CLOSEST = get_closest_m(M) #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index eadbedb..d6eaaad 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -538,6 +538,7 @@ def gemm_splitK_MX_kernel( b_evict: tl.constexpr = 'evict_first', meta_scale_norm: tl.constexpr = (0.05 ** 2), ################################# + use_tma: tl.constexpr = False, ): pid = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -590,17 +591,53 @@ def gemm_splitK_MX_kernel( if(channel_scale_mode == 4): scales_a_ptrs = scales_a_ptr + offs_am[:, None] * stride_meta_a_m + offs_k_scales[None, :] * stride_meta_a_g + + if use_tma: + a_desc = tl.make_tensor_descriptor( + a_ptr, + [M, K // elements_per_sample_a], + [stride_am, stride_ak], + [BLOCK_SIZE_M, BLOCK_SIZE_K_A] + ) + + b_desc = tl.make_tensor_descriptor( + b_ptr, + [N, K // elements_per_sample], + [stride_bn, stride_bk], + [BLOCK_SIZE_N, BLOCK_SIZE_K_B] + ) + + # 2D TMA - transposed + # scales_a_desc = tl.make_tensor_descriptor( + # scales_a_ptr, + # [M, K // group_size], + # [stride_meta_a_m, stride_meta_a_g], + # [BLOCK_SIZE_M, BLOCK_SIZE_K_S], + # ) + + scales_b_desc = tl.make_tensor_descriptor( + scales_ptr, + [K // group_size, N], + [stride_meta_g, stride_meta_n], + [BLOCK_SIZE_K_S, BLOCK_SIZE_N], + ) + # Used in channel-wise MXPF8 version scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k): - if EVEN_M and EVEN_K: - a = tl.load(a_ptrs, eviction_policy=a_evict) + if use_tma: + a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_A]) + b = tl.load_tensor_descriptor(b_desc, [pid_n * BLOCK_SIZE_N, k * BLOCK_SIZE_K_B]).T else: - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - - b = tl.load(b_ptrs, eviction_policy=b_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + + b = tl.load(b_ptrs, eviction_policy=b_evict) #k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32) k_m = (k * SPLIT_K + pid_k) * BLOCK_SIZE_K_S #OK for BLOCK_SIZE_K >=group_size @@ -616,8 +653,9 @@ def gemm_splitK_MX_kernel( a_ptrs += BLOCK_SIZE_K_A * stride_ak b_ptrs += BLOCK_SIZE_K_B * stride_bk - if not EVEN_K: - a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) + if not use_tma: + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) #NVFP4 meta-scale if(group_size == 16): @@ -654,7 +692,8 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, ) -> Tensor: - M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] + M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] # W + #M, K, N = x.shape[0], W_q.shape[1] * elements_per_sample, W_q.shape[0] #W.T #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" M_CLOSEST = get_closest_m(M) diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index 0d440cb..36bf55f 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -14,7 +14,7 @@ def is_fp8_supported(device_index=0): device = 'cuda:0' compute_dtype = torch.bfloat16 #float16, bfloat16 -matmul_types = ['GEMM_SPLITK', 'GEMM'] #TODO: improve GEMV mxfp accuracy. +matmul_types = ['GEMM', 'GEMM_SPLITK'] #TODO: improve GEMV mxfp accuracy. reset_config() set_autotune(False) KERNEL.ENABLE_CACHING = False From edb6fc3b2c75933d0316bd5d9f2f8098ca04e786 Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 07:34:57 -0800 Subject: [PATCH 27/63] tma stable --- gemlite/triton_kernels/gemm_kernels.py | 13 +++++---- gemlite/triton_kernels/gemm_splitK_kernels.py | 28 +++++++++++++------ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 13bd6eb..6b25b4d 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -733,6 +733,13 @@ def gemm_MX_kernel( [BLOCK_SIZE_N, BLOCK_SIZE_K_B] ) + c_desc = tl.make_tensor_descriptor( + c_ptr, + [M, N], + [stride_cm, stride_cn], + [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + # 2D TMA - transposed # scales_a_desc = tl.make_tensor_descriptor( # scales_a_ptr, @@ -834,12 +841,6 @@ def gemm_MX_kernel( ############################################################################################################# #Output if use_tma: - c_desc = tl.make_tensor_descriptor( - c_ptr, - [M, N], - [stride_cm, stride_cn], - [BLOCK_SIZE_M, BLOCK_SIZE_N] - ) tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], value=acc) else: offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index d6eaaad..43846e4 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -597,14 +597,21 @@ def gemm_splitK_MX_kernel( a_ptr, [M, K // elements_per_sample_a], [stride_am, stride_ak], - [BLOCK_SIZE_M, BLOCK_SIZE_K_A] + [BLOCK_SIZE_M, BLOCK_SIZE_K_A_E] ) b_desc = tl.make_tensor_descriptor( b_ptr, [N, K // elements_per_sample], [stride_bn, stride_bk], - [BLOCK_SIZE_N, BLOCK_SIZE_K_B] + [BLOCK_SIZE_N, BLOCK_SIZE_K_B_E] + ) + + c_desc = tl.make_tensor_descriptor( + c_ptr, + [M, N], + [stride_cm, stride_cn], + [BLOCK_SIZE_M, BLOCK_SIZE_N] ) # 2D TMA - transposed @@ -629,8 +636,8 @@ def gemm_splitK_MX_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k): if use_tma: - a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_A]) - b = tl.load_tensor_descriptor(b_desc, [pid_n * BLOCK_SIZE_N, k * BLOCK_SIZE_K_B]).T + a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, (k * SPLIT_K + pid_k) * BLOCK_SIZE_K_A_E]) + b = tl.load_tensor_descriptor(b_desc, [pid_n * BLOCK_SIZE_N, (k * SPLIT_K + pid_k) * BLOCK_SIZE_K_B_E]).T else: if EVEN_M and EVEN_K: a = tl.load(a_ptrs, eviction_policy=a_evict) @@ -667,7 +674,7 @@ def gemm_splitK_MX_kernel( dtype: tl.constexpr = c_ptr.dtype.element_ty scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1.0, eviction_policy=meta_evict_policy) acc = acc.to(dtype) * scales_a[:, None] - + ############################################################################################################# #Output offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) @@ -679,12 +686,15 @@ def gemm_splitK_MX_kernel( if EVEN_M and EVEN_N: tl.atomic_add(c_ptrs, acc, sem=atomic_mode) else: - tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) else: - if EVEN_M and EVEN_N: - tl.store(c_ptrs, acc) + if use_tma: + tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], value=acc) else: - tl.store(c_ptrs, acc, mask=mask) + if EVEN_M and EVEN_N: + tl.store(c_ptrs, acc) + else: + tl.store(c_ptrs, acc, mask=mask) def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, From 6dfcd1b5edee82578a490be521f3f1118d713435 Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 09:46:32 -0800 Subject: [PATCH 28/63] add mxfp8 v4 activation quant --- gemlite/quant_utils.py | 119 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 05747fb..119ed07 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -1004,6 +1004,123 @@ def scale_activations_mxfp8_triton_v3( return out, scales +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 512}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) +@triton.jit +def scale_activations_mxfp8_triton_kernel_v4( + tensor_ptr, out_ptr, scales_ptr, + M, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + min_val: tl.constexpr, max_val: tl.constexpr, + eps_exp: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) + + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + for tile_m in range(pid, num_m_tiles, num_programs): + offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + m_mask = offs_m < M + + tensor_bp = tl.make_block_ptr( + tensor_ptr, (M, K), (stride_m_t, stride_k_t), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0) + ) + out_bp = tl.make_block_ptr( + out_ptr, (M, K), (stride_m_o, stride_k_o), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0) + ) + + for k_start in range(0, K, BLOCK_SIZE_K): + # Load [BLOCK_M, BLOCK_K] + tensor = tl.load(tensor_bp, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + + # Reshape to [BLOCK_M * GROUPS_PER_BLOCK, GROUP_SIZE] for group-wise reduction + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + # Per-group abs_max → power-of-2 scale + abs_max = tl.max(tl.abs(tensor_flat), axis=1) + scales, scales_log2 = next_power_of_2_bitwise_triton(abs_max / max_val, eps_exp) + + # Quantize: multiply by reciprocal, clamp, cast + out = tensor_flat * (1.0 / scales[:, None]) + out = tl.clamp(out, min=min_val, max=max_val) + out = out.to(out_dtype) + + # Reshape back to [BLOCK_M, BLOCK_K] and store + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + tl.store(out_bp, out, boundary_check=(0, 1)) + + # Store scales: [FLAT_M] → [BLOCK_M, GROUPS_PER_BLOCK] + scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + group_idx = k_start // GROUP_SIZE + offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=m_mask[:, None] & g_mask[None, :] + ) + + tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K)) + out_bp = tl.advance(out_bp, (0, BLOCK_SIZE_K)) + +# ersistent 1D grid, processes multiple K-groups per iteration via reshape +def scale_activations_mxfp8_triton_v4( + tensor: torch.Tensor, w_dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + group_size: int = 32 + eps_exp: int = -30 + min_val, max_val = get_dtype_range(w_dtype) + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + + grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) + + scale_activations_mxfp8_triton_kernel_v4[grid]( + tensor, out, scales, + M, K, + tensor.stride(0), tensor.stride(1), + out.stride(0), out.stride(1), + scales.stride(0), scales.stride(1), + min_val=min_val, max_val=max_val, + eps_exp=eps_exp, + GROUP_SIZE=group_size, + ) + + return out, scales #################################################################################################################### #MXPF4 / NVFP4 #################################################################################################################### @@ -1360,6 +1477,6 @@ def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, #################################################################################################################### scale_activations_per_token = scale_activations_per_token_triton_v3 -scale_activations_mxfp8 = scale_activations_mxfp8_triton_v3 +scale_activations_mxfp8 = scale_activations_mxfp8_triton_v4 scale_activations_mxfp4 = scale_activations_mxfp4_triton scale_activations_nvfp4 = scale_activations_nvfp4_triton From 22b074b8c9c8c5285950d7ae8ef77c2db8e15762 Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Mar 2026 10:33:58 -0800 Subject: [PATCH 29/63] add mxfp4/nvfp4 v3 activation quant --- gemlite/quant_utils.py | 498 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 496 insertions(+), 2 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 119ed07..2f0c97a 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -1475,8 +1475,502 @@ def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, return out, scales +#################################################################################################################### +# MXFP4 v2: persistent 1D grid, processes multiple K-groups per iteration +#################################################################################################################### +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) +@triton.jit +def scale_activations_mxfp4_triton_kernel_v2( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps_exp: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) + + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + thr_pos = tl.load(thr_pos_ptr + tl.arange(0, 8), eviction_policy='evict_last')[None, :] + + for tile_m in range(pid, num_m_tiles, num_programs): + offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + m_mask = offs_m < M + + tensor_bp = tl.make_block_ptr( + tensor_ptr, (M, K), (stride_m_t, stride_k_t), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0) + ) + out_bp = tl.make_block_ptr( + out_ptr, (M, K // 2), (stride_m_o, stride_k_o), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, HALF_BLOCK_K), order=(1, 0) + ) + + for k_start in range(0, K, BLOCK_SIZE_K): + tensor = tl.load(tensor_bp, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + + # Reshape to [FLAT_M, GROUP_SIZE] for group-wise reduction + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + # Per-group power-of-2 scale + scales, scales_log2 = next_power_of_2_bitwise_triton( + tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) / 6., eps_exp + ) + + # Map to FP4 index via threshold comparison + wq = tensor_flat / scales + idx_abs = tl.sum(tl.abs(wq[:, :, None]) > thr_pos[None, :, :], axis=2) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + # Reshape to [BLOCK_M, BLOCK_K] then pack pairs + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + tl.store(out_bp, out, boundary_check=(0, 1)) + + # Store scales: [FLAT_M, 1] → [BLOCK_M, GROUPS_PER_BLOCK] + scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + group_idx = k_start // GROUP_SIZE + offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=m_mask[:, None] & g_mask[None, :] + ) + + tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K)) + out_bp = tl.advance(out_bp, (0, HALF_BLOCK_K)) + + +def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: + group_size: int = 32 + eps_exp: int = -30 + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + + grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) + device_index = tensor.device.index + + scale_activations_mxfp4_triton_kernel_v2[grid]( + tensor, out, scales, thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps_exp=eps_exp, + GROUP_SIZE=group_size, + ) + + return out, scales + + +#################################################################################################################### +# NVFP4 v2: persistent 1D grid, processes multiple K-groups per iteration +#################################################################################################################### +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) +@triton.jit +def scale_activations_nvfp4_triton_kernel_v2( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + meta_scales: tl.constexpr = NVFP4_META_SCALE, +): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) + + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + fp8_dtype: tl.constexpr = tl.float8e4nv + max_fp8: tl.constexpr = 448. + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + thr_pos = tl.load(thr_pos_ptr + tl.arange(0, 8), eviction_policy='evict_last')[None, :] + + for tile_m in range(pid, num_m_tiles, num_programs): + offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + m_mask = offs_m < M + + tensor_bp = tl.make_block_ptr( + tensor_ptr, (M, K), (stride_m_t, stride_k_t), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0) + ) + out_bp = tl.make_block_ptr( + out_ptr, (M, K // 2), (stride_m_o, stride_k_o), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, HALF_BLOCK_K), order=(1, 0) + ) + + for k_start in range(0, K, BLOCK_SIZE_K): + tensor = tl.load(tensor_bp, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + + # Reshape to [FLAT_M, GROUP_SIZE] for group-wise reduction + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + # Per-group FP8 scale + abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) + scales_raw = abs_max / (6. * meta_scales) + scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype) + scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps) + + # Map to FP4 index via threshold comparison + wq = tensor_flat / scales_full + idx_abs = tl.sum(tl.abs(wq[:, :, None]) > thr_pos[None, :, :], axis=2) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + # Reshape to [BLOCK_M, BLOCK_K] then pack pairs + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + tl.store(out_bp, out, boundary_check=(0, 1)) + + # Store scales: [FLAT_M, 1] → [BLOCK_M, GROUPS_PER_BLOCK] + scales_2d = tl.reshape(scales_fp8, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + group_idx = k_start // GROUP_SIZE + offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=m_mask[:, None] & g_mask[None, :] + ) + + tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K)) + out_bp = tl.advance(out_bp, (0, HALF_BLOCK_K)) + + +def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + group_size: int = 16 + eps: float = 1e-6 + fp8_dtype = torch.float8_e4m3fn + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + + grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) + device_index = tensor.device.index + + scale_activations_nvfp4_triton_kernel_v2[grid]( + tensor, out, scales, thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps=eps, + GROUP_SIZE=group_size, + ) + + return out, scales + + +#################################################################################################################### +# MXFP4 v3: 2D grid like v1, but scalar threshold loop to avoid 3D tensor +#################################################################################################################### +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 256}, num_warps=8, num_stages=3), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) +@triton.jit +def scale_activations_mxfp4_triton_kernel_v3( + tensor_ptr, + out_ptr, + scales_ptr, + thr_pos_ptr, + M, K, + ######################### + stride_m_t: tl.constexpr, + stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, + stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, + stride_k_o: tl.constexpr, + ######################### + eps_exp: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + use_tma: tl.constexpr = False, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + HALF_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + # Load 8 thresholds as individual scalars + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + + #Load + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + #next power of 2 via log + scales, scales_log2 = next_power_of_2_triton(tl.max(tl.abs(tensor), axis=1, keep_dims=True) / 6., eps_exp) + + #Map to index via scalar threshold comparisons (avoids 3D intermediate) + wq = tensor / scales + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + #Pack + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_GROUP_SIZE, 2), can_reorder=False)) + out = lo | (hi << 4) + + #Store + offs_k = pid_k * HALF_GROUP_SIZE + tl.arange(0, HALF_GROUP_SIZE) + out_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < (K // 2))).to(tl.int1) + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask) + + offs_k = pid_k * 1 + tl.arange(0, 1) + tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales_log2) + +def scale_activations_mxfp4_triton_v3(tensor: Tensor) -> Tuple[Tensor, Tensor]: + group_size: int = 32 + eps_exp: int = -30 + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) + device_index = tensor.device.index + + scale_activations_mxfp4_triton_kernel_v3[grid]( + tensor, + out, + scales, + thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + ######################### + eps_exp=eps_exp, + GROUP_SIZE=group_size, + ) + + return out, scales + + +#################################################################################################################### +# NVFP4 v3: 2D grid like v1, but scalar threshold loop to avoid 3D tensor +#################################################################################################################### +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 256}, num_warps=8, num_stages=3), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) +@triton.jit +def scale_activations_nvfp4_triton_kernel_v3( + tensor_ptr, + out_ptr, + scales_ptr, + thr_pos_ptr, + M, K, + ######################### + stride_m_t: tl.constexpr, + stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, + stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, + stride_k_o: tl.constexpr, + ######################### + eps: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + meta_scales: tl.constexpr = NVFP4_META_SCALE, + use_tma: tl.constexpr = False, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + fp8_dtype: tl.constexpr = tl.float8e4nv + max_fp8: tl.constexpr = 448. + HALF_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + # Load 8 thresholds as individual scalars + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + + #Load + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + #FP8 scales + scales = tl.max(tl.abs(tensor), axis=1, keep_dims=True) / (6. * meta_scales) + scales = tl.minimum(scales, max_fp8).to(fp8_dtype) + + #Map to index via scalar threshold comparisons (avoids 3D intermediate) + scales_full = tl.maximum(scales.to(tl.float32) * meta_scales, eps) + wq = tensor / scales_full + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + #Pack + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_GROUP_SIZE, 2), can_reorder=False)) + out = lo | (hi << 4) + + #Store + offs_k = pid_k * HALF_GROUP_SIZE + tl.arange(0, HALF_GROUP_SIZE) + out_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < (K // 2))).to(tl.int1) + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask) + + offs_k = pid_k + tl.arange(0, 1) + tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales) + + +def scale_activations_nvfp4_triton_v3(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + group_size: int = 16 + eps: float = 1e-6 + fp8_dtype = torch.float8_e4m3fn + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) + device_index = tensor.device.index + + scale_activations_nvfp4_triton_kernel_v3[grid]( + tensor, + out, + scales, + thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + ######################### + eps=eps, + GROUP_SIZE=group_size, + ) + + return out, scales + + #################################################################################################################### scale_activations_per_token = scale_activations_per_token_triton_v3 scale_activations_mxfp8 = scale_activations_mxfp8_triton_v4 -scale_activations_mxfp4 = scale_activations_mxfp4_triton -scale_activations_nvfp4 = scale_activations_nvfp4_triton +scale_activations_mxfp4 = scale_activations_mxfp4_triton_v3 +scale_activations_nvfp4 = scale_activations_nvfp4_triton_v3 From f686e4e9582b44de5491da9e0edf3a0e227579df Mon Sep 17 00:00:00 2001 From: mobicham Date: Sat, 7 Mar 2026 08:18:03 -0800 Subject: [PATCH 30/63] add flashinfer nvfp4 benchmark --- examples/eval_flops.py | 349 +++++++++++++++++++++++++++++++++-------- 1 file changed, 281 insertions(+), 68 deletions(-) diff --git a/examples/eval_flops.py b/examples/eval_flops.py index acf49df..a62e71a 100644 --- a/examples/eval_flops.py +++ b/examples/eval_flops.py @@ -5,6 +5,8 @@ import argparse import torch._dynamo torch._dynamo.config.recompile_limit = 256 +import torch._inductor.config as _inductor_config +import triton device, dtype = 'cuda:0', torch.bfloat16 repeat = 32 @@ -146,31 +148,300 @@ def patch_model_native_fp8(model, fp8_dtype=torch.float8_e4m3fn, use_fast_accum= for i, layer in enumerate(model): if isinstance(layer, torch.nn.Linear): model[i] = NativePyTorchFP8Dynamic( - layer, fp8_dtype=fp8_dtype, use_fast_accum=use_fast_accum + layer, fp8_dtype=fp8_dtype, use_fast_accum=use_fast_accum, ) +########################################################################################################################### +# flashinfer NVFP4 reference (CUTLASS-based, supports sm_120) +########################################################################################################################### +def _get_flashinfer(): + """Check if flashinfer with NVFP4 support is available.""" + try: + from flashinfer import nvfp4_quantize, mm_fp4, SfLayout + return True, None + except ImportError: + return False, "flashinfer not installed (pip install flashinfer)" + + +# ---- custom_op wrappers for torch.compile compatibility ---- +@torch.library.custom_op("flashinfer_bench::nvfp4_quantize", mutates_args=()) +def _nvfp4_quantize_op( + a: torch.Tensor, a_global_sf: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + from flashinfer import nvfp4_quantize, SfLayout + a_fp4, a_sf = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) + return a_fp4, a_sf + + +@torch.library.register_fake("flashinfer_bench::nvfp4_quantize") +def _nvfp4_quantize_fake( + a: torch.Tensor, a_global_sf: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + M, K = a.shape + a_fp4 = torch.empty((M, K // 2), dtype=torch.uint8, device=a.device) + a_sf = torch.empty((M, K // 16), dtype=torch.uint8, device=a.device) + return a_fp4, a_sf + + +@torch.library.custom_op("flashinfer_bench::mm_fp4", mutates_args=()) +def _mm_fp4_op( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: torch.Tensor, + out_N: int, +) -> torch.Tensor: + from flashinfer import mm_fp4 + return mm_fp4(a, b, a_descale, b_descale, alpha, torch.bfloat16, backend="cutlass") + + +@torch.library.register_fake("flashinfer_bench::mm_fp4") +def _mm_fp4_fake( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: torch.Tensor, + out_N: int, +) -> torch.Tensor: + M = a.shape[0] + return torch.empty((M, out_N), dtype=torch.bfloat16, device=a.device) + + +class FlashinferNVFP4Dynamic(torch.nn.Module): + """ + NVFP4 dynamic quantization using flashinfer CUTLASS backend. + Weights quantized offline in __init__; activations quantized on-the-fly in forward. + Compatible with torch.compile via custom_op wrappers. + """ + + def __init__(self, linear_layer: torch.nn.Linear): + super().__init__() + from flashinfer import nvfp4_quantize, SfLayout + + w_bf16 = linear_layer.weight.data # [N, K] + N, K = w_bf16.shape + + # Quantize weights offline + w_global_sf = (448.0 * 6.0) / w_bf16.float().abs().nan_to_num().amax().clamp(min=1e-12) + w_fp4, w_sf = nvfp4_quantize( + w_bf16, w_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False + ) + + # Store pre-transposed for mm_fp4: b=[K//2, N], b_descale=[K//16, N] + self.register_buffer("w_fp4_t", w_fp4.T.contiguous()) + self.register_buffer("w_sf_t", w_sf.T.contiguous()) + self.register_buffer( + "w_global_sf_inv", + (1.0 / w_global_sf).to(torch.float32).contiguous(), + ) + self.N = N + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Activation quantization: compute global scale with pytorch ops + x_global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) + + # Quantize activation via custom_op (flashinfer CUDA kernel) + x_fp4, x_sf = torch.ops.flashinfer_bench.nvfp4_quantize(x, x_global_sf) + + # alpha = 1 / (x_global_sf * w_global_sf) + alpha = self.w_global_sf_inv / x_global_sf + + # CUTLASS FP4 matmul via custom_op + return torch.ops.flashinfer_bench.mm_fp4( + x_fp4, self.w_fp4_t, x_sf, self.w_sf_t, alpha, self.N + ) + + +def patch_model_flashinfer_nvfp4(model): + for i, layer in enumerate(model): + if isinstance(layer, torch.nn.Linear): + model[i] = FlashinferNVFP4Dynamic(layer) + + +def bench_flashinfer_nvfp4(M, N, K): + """ + Benchmark flashinfer NVFP4 matmul (CUTLASS backend) - raw single matmul, no activation quant. + """ + from flashinfer import nvfp4_quantize, mm_fp4, SfLayout + + a_bf16 = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b_bf16 = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) + + a_global_sf = (448.0 * 6.0) / a_bf16.float().abs().nan_to_num().max() + b_global_sf = (448.0 * 6.0) / b_bf16.float().abs().nan_to_num().max() + + a_fp4, a_sf = nvfp4_quantize(a_bf16, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) + b_fp4, b_sf = nvfp4_quantize(b_bf16, b_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=True) + + alpha = 1.0 / (a_global_sf * b_global_sf) + + ms = triton.testing.do_bench( + lambda: mm_fp4(a_fp4, b_fp4.T, a_sf, b_sf.T, alpha, torch.bfloat16), + warmup=500, rep=500, + ) + return ms + + +########################################################################################################################### +def run_benchmark(proc_name, M, K, N): + """ + Unified benchmark runner. Returns (label, M, K, N, tflops) or None on skip. + Handles gemlite processors, native PyTorch INT8/FP8, and flashinfer NVFP4. + """ + has_flashinfer, fi_err = _get_flashinfer() + + # ---- flashinfer NVFP4 raw (single matmul, no activation quant, triton.do_bench) ---- + if proc_name == "flashinfer_nvfp4_raw": + if not has_flashinfer: + print(f" Skipping {proc_name}: {fi_err}") + return None + M_a = ((M + 127) // 128) * 128 + N_a = ((N + 127) // 128) * 128 + K_a = ((K + 127) // 128) * 128 + try: + ms = bench_flashinfer_nvfp4(M_a, N_a, K_a) + tflops = get_flops(M_a, K_a, N_a, ms) + label = "flashinfer NVFP4 (raw)" + print(f" {label} | {M_a}, {K_a}, {N_a} | {tflops:.2f} TFLOP/s ({ms:.3f} ms)") + return (label, M_a, K_a, N_a, tflops) + except Exception as e: + print(f" flashinfer NVFP4 raw failed: {e}") + return None + + # ---- flashinfer NVFP4 dynamic (torch.compile + activation quant) ---- + if proc_name == "flashinfer_nvfp4_dynamic": + if not has_flashinfer: + print(f" Skipping {proc_name}: {fi_err}") + return None + # Disable cudagraph trees: flashinfer CUTLASS does internal workspace allocs + old_cudagraph = _inductor_config.triton.cudagraph_trees + _inductor_config.triton.cudagraph_trees = False + + model = get_model(K, N, repeat=repeat) + patch_model_flashinfer_nvfp4(model) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M, K) / repeat + tflops = get_flops(M, K, N, perf_time_ms) + label = "flashinfer NVFP4 (dynamic)" + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + + cleanup(model) + _inductor_config.triton.cudagraph_trees = old_cudagraph + return (label, M, K, N, tflops) + + # ---- Native PyTorch INT8 dynamic ---- + if proc_name == "native_int8": + if M <= 16: + print(f" Skipping native_int8 for M={M} (requires M > 16)") + return None + model = get_model(K, N, repeat=repeat) + patch_model_native_int8(model) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M, K) / repeat + tflops = get_flops(M, K, N, perf_time_ms) + label = "PyTorch Native INT8" + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + + cleanup(model) + return (label, M, K, N, tflops) + + # ---- Native PyTorch FP8 dynamic ---- + if proc_name == "native_fp8": + model = get_model(K, N, repeat=repeat) + patch_model_native_fp8(model, fp8_dtype=torch.float8_e4m3fn, use_fast_accum=False) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M, K) / repeat + tflops = get_flops(M, K, N, perf_time_ms) + label = "PyTorch Native FP8" + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + + cleanup(model) + return (label, M, K, N, tflops) + + # ---- GemLite processors + BF16 baseline ---- + GEMLITE_MAP = { + "A16W8_INT8": lambda: A16W8_INT8(), + "A16W8_FP8": lambda: A16W8_FP8(), + "A16W4_HQQ_INT": lambda: A16W4_HQQ_INT(), + "A8W8_INT8_dynamic": lambda: A8W8_INT8_dynamic(), + "A8W8_FP8_dynamic": lambda: A8W8_FP8_dynamic(), + "A8W8_MXFP_dynamic_post_scale": lambda: A8W8_MXFP_dynamic(dtype=dtype, post_scale=True), + "A8W8_MXFP_dynamic": lambda: A8W8_MXFP_dynamic(dtype=dtype, post_scale=False), + "A4W4_MXFP_dynamic": lambda: A4W4_MXFP_dynamic(dtype=dtype), + "A4W4_NVFP_dynamic": lambda: A4W4_NVFP_dynamic(dtype=dtype), + "none": lambda: None, + "fp16": lambda: None, + } + + if proc_name not in GEMLITE_MAP: + print(f" Unknown processor: {proc_name}, skipping.") + return None + + procesor = GEMLITE_MAP[proc_name]() + + model = get_model(K, N, repeat=repeat) + if procesor is not None: + patch_model(model, device=device, processor=procesor) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M, K) / repeat + label = proc_name if procesor is not None else "BF16 (no processor)" + tflops = get_flops(M, K, N, perf_time_ms) + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + + cleanup(model) + return (label, M, K, N, tflops) + + +ALL_PROCESSORS = [ + "none", + "A16W8_INT8", + "A16W8_FP8", + "A16W4_HQQ_INT", + "A8W8_INT8_dynamic", + "A8W8_FP8_dynamic", + "A8W8_MXFP_dynamic_post_scale", + "A8W8_MXFP_dynamic", + "A4W4_MXFP_dynamic", + "A4W4_NVFP_dynamic", + "native_int8", + "native_fp8", + "flashinfer_nvfp4_dynamic", + "flashinfer_nvfp4_raw", +] + + def main(): parser = argparse.ArgumentParser( description="Evaluate TFLOP/s for various quantized matmul processors.", epilog=""" Examples: - # Run with default parameters + # Run with default parameters (all processors) python eval_flops.py # Run with specific dimensions: python eval_flops.py --M 128 --K 4096 --N 4096 # Run only specific processors (comma-separated): - python eval_flops.py --processor A16W8_INT8,A8W8_FP8_dynamic + python eval_flops.py --processor A4W4_MXFP_dynamic,flashinfer_nvfp4_dynamic,native_fp8 # Run only BF16 baseline (no quantization): python eval_flops.py --processor none # Available processors: - # A16W8_INT8, A16W8_FP8, A8W8_INT8_dynamic, A8W8_FP8_dynamic, - # A8W8_MXFP_dynamic_post_scale, A8W8_MXFP_dynamic_no_post_scale, - # A4W4_MXFP_dynamic, A4W4_NVFP_dynamic, none (BF16 baseline) + # GemLite: A16W8_INT8, A16W8_FP8, A16W4_HQQ_INT, + # A8W8_INT8_dynamic, A8W8_FP8_dynamic, + # A8W8_MXFP_dynamic_post_scale, A8W8_MXFP_dynamic, + # A4W4_MXFP_dynamic, A4W4_NVFP_dynamic + # PyTorch: native_int8, native_fp8 + # flashinfer: flashinfer_nvfp4_dynamic, flashinfer_nvfp4_raw + # Baseline: none / fp16 (BF16, no quantization) # Use "all" to run every processor. """, formatter_class=argparse.RawDescriptionHelpFormatter, @@ -184,74 +455,16 @@ def main(): M, K, N = args.M, args.K, args.N - PROCESSOR_MAP = { - "A16W8_INT8": lambda: A16W8_INT8(), - "A16W8_FP8": lambda: A16W8_FP8(), - "A16W4_HQQ_INT": lambda: A16W4_HQQ_INT(), - "A8W8_INT8_dynamic": lambda: A8W8_INT8_dynamic(), - "A8W8_FP8_dynamic": lambda: A8W8_FP8_dynamic(), - "A8W8_MXFP_dynamic_post_scale": lambda: A8W8_MXFP_dynamic(dtype=dtype, post_scale=True), - "A8W8_MXFP_dynamic": lambda: A8W8_MXFP_dynamic(dtype=dtype, post_scale=False), - "A4W4_MXFP_dynamic": lambda: A4W4_MXFP_dynamic(dtype=dtype), - "A4W4_NVFP_dynamic": lambda: A4W4_NVFP_dynamic(dtype=dtype), - "none": lambda: None, - "fp16": lambda: None, - } - if args.processor == "all": - processor_names = list(PROCESSOR_MAP.keys()) + processor_names = list(ALL_PROCESSORS) else: processor_names = [p.strip() for p in args.processor.split(",")] results = [] - - # ---- GemLite processors ---- for proc_name in processor_names: - if proc_name not in PROCESSOR_MAP: - print(f"Unknown processor: {proc_name}, skipping.") - continue - - procesor = PROCESSOR_MAP[proc_name]() - - model = get_model(K, N, repeat=repeat) - if procesor is not None: - patch_model(model, device=device, processor=procesor) - model = torch.compile(model, mode="reduce-overhead", fullgraph=True) - - perf_time_ms = eval_model(model, M, K) / repeat - label = proc_name if procesor is not None else "FP16 (no processor)" - tflops = get_flops(M, K, N, perf_time_ms) - print(f"Processor: {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") - results.append((label, M, K, N, tflops)) - - cleanup(model) - - # ---- PyTorch Native INT8 dynamic reference ---- - if M > 16: - model = get_model(K, N, repeat=repeat) - patch_model_native_int8(model) - model = torch.compile(model, mode="reduce-overhead", fullgraph=True) - - perf_time_ms = eval_model(model, M, K) / repeat - tflops = get_flops(M, K, N, perf_time_ms) - print(f"PyTorch Native INT8 | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") - results.append(("PyTorch Native INT8", M, K, N, tflops)) - - cleanup(model) - else: - print(f"Skipping PyTorch Native INT8 for M={M} (requires M >= 16).") - - # ---- PyTorch Native FP8 dynamic reference ---- - model = get_model(K, N, repeat=repeat) - patch_model_native_fp8(model, fp8_dtype=torch.float8_e4m3fn, use_fast_accum=False) - model = torch.compile(model, mode="reduce-overhead", fullgraph=True) - - perf_time_ms = eval_model(model, M, K) / repeat - tflops = get_flops(M, K, N, perf_time_ms) - print(f"PyTorch Native FP8 | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") - results.append(("PyTorch Native FP8", M, K, N, tflops)) - - cleanup(model) + result = run_benchmark(proc_name, M, K, N) + if result is not None: + results.append(result) # ---- Summary ---- print("\n" + "=" * 70) From dd3c4f75ac2b8ae8234a475fd39136fc252c2b4c Mon Sep 17 00:00:00 2001 From: mobicham Date: Sat, 7 Mar 2026 08:53:44 -0800 Subject: [PATCH 31/63] update mxfp/nvfp activation quant kernels --- examples/bench_act_quant.py | 82 +++++++ examples/bench_act_quant_final.py | 62 ++++++ examples/bench_act_quant_v4.py | 343 +++++++++++++++++++++++++++++ examples/bench_act_quant_v5.py | 353 ++++++++++++++++++++++++++++++ gemlite/quant_utils.py | 289 +++++++++++++++++++++++- 5 files changed, 1127 insertions(+), 2 deletions(-) create mode 100644 examples/bench_act_quant.py create mode 100644 examples/bench_act_quant_final.py create mode 100644 examples/bench_act_quant_v4.py create mode 100644 examples/bench_act_quant_v5.py diff --git a/examples/bench_act_quant.py b/examples/bench_act_quant.py new file mode 100644 index 0000000..084c3a6 --- /dev/null +++ b/examples/bench_act_quant.py @@ -0,0 +1,82 @@ +""" +Benchmark activation quantization kernels: + - gemlite MXFP4 (v1, v2, v3) + - gemlite NVFP4 (v1, v2, v3) + - flashinfer nvfp4_quantize +""" +import torch +import triton + +torch.manual_seed(0) +device = "cuda:0" +dtype = torch.bfloat16 + +# ---- gemlite quant kernels ---- +from gemlite.quant_utils import ( + scale_activations_mxfp4_triton as mxfp4_v1, + scale_activations_mxfp4_triton_v2 as mxfp4_v2, + scale_activations_mxfp4_triton_v3 as mxfp4_v3, + scale_activations_nvfp4_triton as nvfp4_v1, + scale_activations_nvfp4_triton_v2 as nvfp4_v2, + scale_activations_nvfp4_triton_v3 as nvfp4_v3, +) + +# ---- flashinfer ---- +from flashinfer import nvfp4_quantize, SfLayout + +def flashinfer_nvfp4_quant(x): + global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) + return nvfp4_quantize(x, global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) + +def flashinfer_nvfp4_quant_no_scale(x): + """Just the quantize kernel, pre-computed global scale.""" + return nvfp4_quantize(x, x._global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) + +# ---- benchmark ---- +KERNELS = { + "gemlite mxfp4 v1": mxfp4_v1, + "gemlite mxfp4 v2": mxfp4_v2, + "gemlite mxfp4 v3": mxfp4_v3, + "gemlite nvfp4 v1": nvfp4_v1, + "gemlite nvfp4 v2": nvfp4_v2, + "gemlite nvfp4 v3": nvfp4_v3, + "flashinfer nvfp4 (with global_sf)": flashinfer_nvfp4_quant, + "flashinfer nvfp4 (kernel only)": None, # special case +} + +shapes = [ + (1024, 4096), + (1024, 16384), + (4096, 4096), + (4096, 16384), + (8192, 4096), + (8192, 16384), + (16384, 16384), +] + +print(f"{'Kernel':<40} {'Shape':>14} {'Time (us)':>10} {'GB/s':>8}") +print("=" * 76) + +for M, K in shapes: + x = torch.randn(M, K, device=device, dtype=dtype) + # Pre-compute for flashinfer kernel-only variant + global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) + x._global_sf = global_sf + + bytes_read = M * K * x.element_size() # input bytes + + for name, fn in KERNELS.items(): + if name == "flashinfer nvfp4 (kernel only)": + fn_bench = lambda: nvfp4_quantize(x, global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) + else: + fn_bench = lambda fn=fn: fn(x) + + try: + ms = triton.testing.do_bench(fn_bench, warmup=200, rep=200) + us = ms * 1000 + gbps = bytes_read / (ms * 1e-3) / 1e9 + print(f" {name:<38} {str((M,K)):>14} {us:>10.1f} {gbps:>8.1f}") + except Exception as e: + print(f" {name:<38} {str((M,K)):>14} {'FAILED':>10} {str(e)[:30]}") + + print() diff --git a/examples/bench_act_quant_final.py b/examples/bench_act_quant_final.py new file mode 100644 index 0000000..8cf6e57 --- /dev/null +++ b/examples/bench_act_quant_final.py @@ -0,0 +1,62 @@ +""" +Benchmark activation quantization kernels from quant_utils.py (v5 integrated) +vs flashinfer nvfp4_quantize. +""" +import torch +import triton + +torch.manual_seed(0) +device = "cuda:0" +dtype = torch.bfloat16 + +# Import directly from quant_utils (now v5 by default) +from gemlite.quant_utils import ( + scale_activations_mxfp4, # v5 + scale_activations_nvfp4, # v5 + scale_activations_mxfp4_triton_v3 as mxfp4_v3, + scale_activations_nvfp4_triton_v3 as nvfp4_v3, +) + +from flashinfer import nvfp4_quantize, SfLayout + +shapes = [ + (1024, 4096), + (1024, 16384), + (4096, 4096), + (4096, 16384), + (8192, 4096), + (8192, 16384), + (16384, 16384), +] + +KERNELS = { + "gemlite mxfp4 (default=v5)": scale_activations_mxfp4, + "gemlite mxfp4 v3 (old)": mxfp4_v3, + "gemlite nvfp4 (default=v5)": scale_activations_nvfp4, + "gemlite nvfp4 v3 (old)": nvfp4_v3, + "flashinfer nvfp4 (kernel)": None, +} + +print(f"{'Kernel':<40} {'Shape':>14} {'Time (us)':>10} {'GB/s':>8}") +print("=" * 76) + +for M, K in shapes: + x = torch.randn(M, K, device=device, dtype=dtype) + global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) + bytes_read = M * K * x.element_size() + + for name, fn in KERNELS.items(): + if fn is None: + fn_bench = lambda: nvfp4_quantize(x, global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) + else: + fn_bench = lambda fn=fn: fn(x) + + try: + ms = triton.testing.do_bench(fn_bench, warmup=200, rep=200) + us = ms * 1000 + gbps = bytes_read / (ms * 1e-3) / 1e9 + print(f" {name:<38} {str((M,K)):>14} {us:>10.1f} {gbps:>8.1f}") + except Exception as e: + print(f" {name:<38} {str((M,K)):>14} {'FAILED':>10} {str(e)[:40]}") + + print() diff --git a/examples/bench_act_quant_v4.py b/examples/bench_act_quant_v4.py new file mode 100644 index 0000000..f7c0679 --- /dev/null +++ b/examples/bench_act_quant_v4.py @@ -0,0 +1,343 @@ +""" +Test a v4 NVFP4 activation quant kernel: + - persistent 1D grid with K-loop (like v2) for better SM utilization + - scalar threshold comparisons (like v3) to avoid 3D intermediate + - block_ptr for coalesced loads with multi-stage pipelining +""" +import torch +import triton +import triton.language as tl + +torch.manual_seed(0) +device = "cuda:0" +dtype = torch.bfloat16 + +# Import gemlite quant utils for comparison + thr_pos +from gemlite.quant_utils import ( + scale_activations_nvfp4_triton_v3 as nvfp4_v3, + scale_activations_mxfp4_triton_v3 as mxfp4_v3, + thr_pos, + NVFP4_META_SCALE, + get_num_SMs, +) + +NUM_SMS = get_num_SMs(0) + +# flashinfer for comparison +from flashinfer import nvfp4_quantize, SfLayout + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=3), + ], + key=['M', 'K'], +) +@triton.jit +def scale_activations_nvfp4_kernel_v4( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + meta_scales: tl.constexpr = NVFP4_META_SCALE, +): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) + + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + fp8_dtype: tl.constexpr = tl.float8e4nv + max_fp8: tl.constexpr = 448. + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + # Load thresholds as scalars (like v3) + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + + for tile_m in range(pid, num_m_tiles, num_programs): + offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + m_mask = offs_m < M + + tensor_bp = tl.make_block_ptr( + tensor_ptr, (M, K), (stride_m_t, stride_k_t), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0) + ) + out_bp = tl.make_block_ptr( + out_ptr, (M, K // 2), (stride_m_o, stride_k_o), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, HALF_BLOCK_K), order=(1, 0) + ) + + for k_start in range(0, K, BLOCK_SIZE_K): + tensor = tl.load(tensor_bp, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + + # Reshape to [FLAT_M, GROUP_SIZE] for group-wise reduction + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + # Per-group FP8 scale + abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) + scales_raw = abs_max / (6. * meta_scales) + scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype) + scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps) + + # Scalar threshold comparisons (v3 approach, no 3D intermediate) + wq = tensor_flat / scales_full + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + # Reshape back and pack + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + tl.store(out_bp, out, boundary_check=(0, 1)) + + # Store scales + scales_2d = tl.reshape(scales_fp8, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + group_idx = k_start // GROUP_SIZE + offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=m_mask[:, None] & g_mask[None, :] + ) + + tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K)) + out_bp = tl.advance(out_bp, (0, HALF_BLOCK_K)) + + +def scale_activations_nvfp4_v4(tensor: torch.Tensor): + group_size: int = 16 + eps: float = 1e-6 + fp8_dtype = torch.float8_e4m3fn + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + + grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) + device_index = tensor.device.index + + scale_activations_nvfp4_kernel_v4[grid]( + tensor, out, scales, thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps=eps, + GROUP_SIZE=group_size, + ) + return out, scales + + +# Also write MXFP4 v4 with same approach +from gemlite.quant_utils import next_power_of_2_triton + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=3), + ], + key=['M', 'K'], +) +@triton.jit +def scale_activations_mxfp4_kernel_v4( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps_exp: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) + + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + + for tile_m in range(pid, num_m_tiles, num_programs): + offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + m_mask = offs_m < M + + tensor_bp = tl.make_block_ptr( + tensor_ptr, (M, K), (stride_m_t, stride_k_t), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0) + ) + out_bp = tl.make_block_ptr( + out_ptr, (M, K // 2), (stride_m_o, stride_k_o), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, HALF_BLOCK_K), order=(1, 0) + ) + + for k_start in range(0, K, BLOCK_SIZE_K): + tensor = tl.load(tensor_bp, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + # MXFP4 scales: next power of 2 + scales, scales_log2 = next_power_of_2_triton( + tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) / 6., eps_exp + ) + + wq = tensor_flat / scales + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + tl.store(out_bp, out, boundary_check=(0, 1)) + + scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + group_idx = k_start // GROUP_SIZE + offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=m_mask[:, None] & g_mask[None, :] + ) + + tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K)) + out_bp = tl.advance(out_bp, (0, HALF_BLOCK_K)) + + +def scale_activations_mxfp4_v4(tensor: torch.Tensor): + group_size: int = 32 + eps_exp: int = -30 + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + + grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) + device_index = tensor.device.index + + scale_activations_mxfp4_kernel_v4[grid]( + tensor, out, scales, thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps_exp=eps_exp, + GROUP_SIZE=group_size, + ) + return out, scales + + +# ---- Benchmark ---- +def flashinfer_nvfp4_kernel_only(x, global_sf): + return nvfp4_quantize(x, global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) + + +shapes = [ + (1024, 4096), + (1024, 16384), + (4096, 4096), + (4096, 16384), + (8192, 4096), + (8192, 16384), + (16384, 16384), +] + +KERNELS = { + "gemlite mxfp4 v3": mxfp4_v3, + "gemlite mxfp4 v4": scale_activations_mxfp4_v4, + "gemlite nvfp4 v3": nvfp4_v3, + "gemlite nvfp4 v4": scale_activations_nvfp4_v4, + "flashinfer nvfp4 (kernel only)": None, +} + +print(f"{'Kernel':<40} {'Shape':>14} {'Time (us)':>10} {'GB/s':>8}") +print("=" * 76) + +for M, K in shapes: + x = torch.randn(M, K, device=device, dtype=dtype) + global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) + bytes_read = M * K * x.element_size() + + for name, fn in KERNELS.items(): + if name == "flashinfer nvfp4 (kernel only)": + fn_bench = lambda: flashinfer_nvfp4_kernel_only(x, global_sf) + else: + fn_bench = lambda fn=fn: fn(x) + + try: + ms = triton.testing.do_bench(fn_bench, warmup=200, rep=200) + us = ms * 1000 + gbps = bytes_read / (ms * 1e-3) / 1e9 + print(f" {name:<38} {str((M,K)):>14} {us:>10.1f} {gbps:>8.1f}") + except Exception as e: + print(f" {name:<38} {str((M,K)):>14} {'FAILED':>10} {str(e)[:40]}") + + print() diff --git a/examples/bench_act_quant_v5.py b/examples/bench_act_quant_v5.py new file mode 100644 index 0000000..7884f65 --- /dev/null +++ b/examples/bench_act_quant_v5.py @@ -0,0 +1,353 @@ +""" +v5 NVFP4 activation quant: 2D grid like v3, but with BLOCK_SIZE_K processing +multiple groups per block. Keeps the simplicity of v3 while reducing block count. +""" +import torch +import triton +import triton.language as tl +from typing import Tuple + +torch.manual_seed(0) +device = "cuda:0" +dtype = torch.bfloat16 + +from gemlite.quant_utils import ( + scale_activations_nvfp4_triton_v3 as nvfp4_v3, + scale_activations_mxfp4_triton_v3 as mxfp4_v3, + thr_pos, + NVFP4_META_SCALE, + next_power_of_2_triton, +) +from flashinfer import nvfp4_quantize, SfLayout + + +def prune_large_blocks(configs, nargs, **kwargs): + M = nargs['M'] + K = nargs['K'] + for config in configs: + bm = config.kwargs['BLOCK_SIZE_M'] + bk = config.kwargs['BLOCK_SIZE_K'] + if bm > M or bk > K: + continue + yield config + + +# ---- NVFP4 v5: 2D grid, multi-group per block ---- +@triton.autotune( + configs=[ + # BLOCK_SIZE_K must be multiple of GROUP_SIZE=16 + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), + # Multi-stage + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 32}, num_warps=8, num_stages=2), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) +@triton.jit +def scale_activations_nvfp4_kernel_v5( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + meta_scales: tl.constexpr = NVFP4_META_SCALE, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + fp8_dtype: tl.constexpr = tl.float8e4nv + max_fp8: tl.constexpr = 448. + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + # Load BLOCK_SIZE_K elements (multiple groups) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Reshape to [FLAT_M, GROUP_SIZE] for per-group reduction + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + # FP8 scales per group + abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) + scales_raw = abs_max / (6. * meta_scales) + scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype) + scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps) + + # Scalar threshold comparisons + wq = tensor_flat / scales_full + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + # Reshape back to [BLOCK_SIZE_M, BLOCK_SIZE_K] and pack pairs + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + # Store packed output + offs_k_out = pid_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K) + out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1) + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask) + + # Store scales [BLOCK_SIZE_M, GROUPS_PER_BLOCK] + scales_2d = tl.reshape(scales_fp8, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + base_group = pid_k * GROUPS_PER_BLOCK + offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=(offs_m[:, None] < M) & g_mask[None, :] + ) + + +def scale_activations_nvfp4_v5(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + group_size: int = 16 + eps: float = 1e-6 + fp8_dtype = torch.float8_e4m3fn + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) + device_index = tensor.device.index + + scale_activations_nvfp4_kernel_v5[grid]( + tensor, out, scales, thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps=eps, + GROUP_SIZE=group_size, + ) + return out, scales + + +# Same for MXFP4 +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), + # Multi-stage + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 32}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64}, num_warps=8, num_stages=2), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) +@triton.jit +def scale_activations_mxfp4_kernel_v5( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps_exp: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + scales, scales_log2 = next_power_of_2_triton( + tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) / 6., eps_exp + ) + + wq = tensor_flat / scales + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + offs_k_out = pid_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K) + out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1) + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask) + + scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + base_group = pid_k * GROUPS_PER_BLOCK + offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=(offs_m[:, None] < M) & g_mask[None, :] + ) + + +def scale_activations_mxfp4_v5(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + group_size: int = 32 + eps_exp: int = -30 + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) + device_index = tensor.device.index + + scale_activations_mxfp4_kernel_v5[grid]( + tensor, out, scales, thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps_exp=eps_exp, + GROUP_SIZE=group_size, + ) + return out, scales + + +# ---- Benchmark ---- +def flashinfer_nvfp4_kernel_only(x, global_sf): + return nvfp4_quantize(x, global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) + + +shapes = [ + (1024, 4096), + (1024, 16384), + (4096, 4096), + (4096, 16384), + (8192, 4096), + (8192, 16384), + (16384, 16384), +] + +KERNELS = { + "gemlite mxfp4 v3": mxfp4_v3, + "gemlite mxfp4 v5": scale_activations_mxfp4_v5, + "gemlite nvfp4 v3": nvfp4_v3, + "gemlite nvfp4 v5": scale_activations_nvfp4_v5, + "flashinfer nvfp4 (kernel only)": None, +} + +print(f"{'Kernel':<40} {'Shape':>14} {'Time (us)':>10} {'GB/s':>8}") +print("=" * 76) + +for M, K in shapes: + x = torch.randn(M, K, device=device, dtype=dtype) + global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) + bytes_read = M * K * x.element_size() + + for name, fn in KERNELS.items(): + if name == "flashinfer nvfp4 (kernel only)": + fn_bench = lambda: flashinfer_nvfp4_kernel_only(x, global_sf) + else: + fn_bench = lambda fn=fn: fn(x) + + try: + ms = triton.testing.do_bench(fn_bench, warmup=200, rep=200) + us = ms * 1000 + gbps = bytes_read / (ms * 1e-3) / 1e9 + print(f" {name:<38} {str((M,K)):>14} {us:>10.1f} {gbps:>8.1f}") + except Exception as e: + print(f" {name:<38} {str((M,K)):>14} {'FAILED':>10} {str(e)[:40]}") + + print() diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 2f0c97a..2eacde9 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -1969,8 +1969,293 @@ def scale_activations_nvfp4_triton_v3(tensor: torch.Tensor) -> Tuple[torch.Tenso return out, scales + +#################################################################################################################### +# MXFP4 v5: 2D grid with multi-group BLOCK_SIZE_K (fewer blocks, better bandwidth) +#################################################################################################################### +def prune_large_blocks_2d(configs, named_args, **kwargs): + M = named_args['M'] + K = named_args['K'] + + pruned = [] + for config in configs: + bm = config.kwargs['BLOCK_SIZE_M'] + bk = config.kwargs['BLOCK_SIZE_K'] + if bm <= M and bk <= K: + pruned.append(config) + + if not pruned: + pruned.append(configs[0]) + + return pruned + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 32}, num_warps=8, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64}, num_warps=8, num_stages=2), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks_2d}, +) +@triton.jit +def scale_activations_mxfp4_triton_kernel_v5( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps_exp: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + scales, scales_log2 = next_power_of_2_triton( + tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) / 6., eps_exp + ) + + wq = tensor_flat / scales + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + offs_k_out = pid_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K) + out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1) + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask) + + scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + base_group = pid_k * GROUPS_PER_BLOCK + offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=(offs_m[:, None] < M) & g_mask[None, :] + ) + + +def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: + group_size: int = 32 + eps_exp: int = -30 + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) + device_index = tensor.device.index + + scale_activations_mxfp4_triton_kernel_v5[grid]( + tensor, out, scales, thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps_exp=eps_exp, + GROUP_SIZE=group_size, + ) + return out, scales + + +#################################################################################################################### +# NVFP4 v5: 2D grid with multi-group BLOCK_SIZE_K (fewer blocks, better bandwidth) +#################################################################################################################### +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 32}, num_warps=8, num_stages=2), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks_2d}, +) +@triton.jit +def scale_activations_nvfp4_triton_kernel_v5( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + meta_scales: tl.constexpr = NVFP4_META_SCALE, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + fp8_dtype: tl.constexpr = tl.float8e4nv + max_fp8: tl.constexpr = 448. + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) + scales_raw = abs_max / (6. * meta_scales) + scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype) + scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps) + + wq = tensor_flat / scales_full + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + offs_k_out = pid_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K) + out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1) + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask) + + scales_2d = tl.reshape(scales_fp8, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + base_group = pid_k * GROUPS_PER_BLOCK + offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=(offs_m[:, None] < M) & g_mask[None, :] + ) + + +def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + group_size: int = 16 + eps: float = 1e-6 + fp8_dtype = torch.float8_e4m3fn + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) + device_index = tensor.device.index + + scale_activations_nvfp4_triton_kernel_v5[grid]( + tensor, out, scales, thr_pos[device_index], + M, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps=eps, + GROUP_SIZE=group_size, + ) + return out, scales + + + #################################################################################################################### scale_activations_per_token = scale_activations_per_token_triton_v3 scale_activations_mxfp8 = scale_activations_mxfp8_triton_v4 -scale_activations_mxfp4 = scale_activations_mxfp4_triton_v3 -scale_activations_nvfp4 = scale_activations_nvfp4_triton_v3 +scale_activations_mxfp4 = scale_activations_mxfp4_triton_v5 +scale_activations_nvfp4 = scale_activations_nvfp4_triton_v5 From a0bdcdc32a91bc44e19836c9b215105cf5ef5d73 Mon Sep 17 00:00:00 2001 From: mobicham Date: Sat, 7 Mar 2026 09:59:52 -0800 Subject: [PATCH 32/63] add 5d tma attempt --- examples/bench_5d_tma.py | 267 ++++++++++++++++++ gemlite/core.py | 42 ++- gemlite/triton_kernels/gemm_kernels.py | 157 +++++----- gemlite/triton_kernels/gemm_splitK_kernels.py | 2 +- .../gemm_splitK_persistent_kernels.py | 2 +- gemlite/triton_kernels/gemv_kernels.py | 2 +- .../triton_kernels/gemv_revsplitK_kernels.py | 2 +- gemlite/triton_kernels/gemv_splitK_kernels.py | 2 +- tests/test_gemlitelineartriton.py | 2 +- 9 files changed, 392 insertions(+), 86 deletions(-) create mode 100644 examples/bench_5d_tma.py diff --git a/examples/bench_5d_tma.py b/examples/bench_5d_tma.py new file mode 100644 index 0000000..8271b92 --- /dev/null +++ b/examples/bench_5d_tma.py @@ -0,0 +1,267 @@ +""" +Standalone benchmark: compare pointer-based vs 5D TMA scale loading in block-scaled GEMM. +Tests the NVFP4 case (group_size=16, e4m3 scales). +""" +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +device = "cuda:0" +dtype = torch.bfloat16 + +# Required for TMA tensor descriptors +from typing import Optional +def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) +triton.set_allocator(alloc_fn) + + +def preshuffle_scales(scales_2d, N, K_S): + """Convert [N, K_S] scales to 5D preshuffled layout for TMA. + + Follows the Triton tutorial layout: [1, N//128, K_S//4, 2, 256] + Preserves dtype (fp8_e4m3fn for NVFP4, uint8 for MXFP4). + """ + return ( + scales_2d + .reshape(N // 128, 4, 32, K_S // 4, 4) + .permute(0, 3, 2, 1, 4) + .reshape(1, N // 128, K_S // 4, 2, 256) + .contiguous() + ) + + +# Kernel with pointer-based scale loading (current gemlite approach) +@triton.jit +def gemm_fp4_pointer_scales( + a_ptr, b_ptr, c_ptr, + scales_b_ptr, scales_a_ptr, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + group_size: tl.constexpr, + stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bn: tl.constexpr, stride_bk: tl.constexpr, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, + stride_sb_n: tl.constexpr, stride_sb_g: tl.constexpr, + stride_sa_m: tl.constexpr, stride_sa_g: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + NUM_STAGES: tl.constexpr, + meta_scale_norm: tl.constexpr = 0.0025, +): + pid = tl.program_id(0) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + BLOCK_K_A: tl.constexpr = BLOCK_K // 2 # packed FP4 + BLOCK_K_B: tl.constexpr = BLOCK_K // 2 + BLOCK_K_S: tl.constexpr = BLOCK_K // group_size + + # TMA for data + a_desc = tl.make_tensor_descriptor(a_ptr, [M, K // 2], [stride_am, stride_ak], [BLOCK_M, BLOCK_K_A]) + b_desc = tl.make_tensor_descriptor(b_ptr, [N, K // 2], [stride_bn, stride_bk], [BLOCK_N, BLOCK_K_B]) + c_desc = tl.make_tensor_descriptor(c_ptr, [M, N], [stride_cm, stride_cn], [BLOCK_M, BLOCK_N]) + + # Pointer-based scales + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k_s = tl.arange(0, BLOCK_K_S) + scales_b_ptrs = scales_b_ptr + offs_n[:, None] * stride_sb_n + offs_k_s[None, :] * stride_sb_g + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + scales_a_ptrs = scales_a_ptr + offs_m[:, None] * stride_sa_m + offs_k_s[None, :] * stride_sa_g + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + num_k = tl.cdiv(K, BLOCK_K) + for k in tl.range(num_k, num_stages=NUM_STAGES): + a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_M, k * BLOCK_K_A]) + b = tl.load_tensor_descriptor(b_desc, [pid_n * BLOCK_N, k * BLOCK_K_B]).T + + k_m = k * BLOCK_K_S + scales_b = tl.load(scales_b_ptrs + k_m * stride_sb_g) + scales_a = tl.load(scales_a_ptrs + k_m * stride_sa_g) + + acc = tl.dot_scaled(a, scales_a, "e2m1", b, scales_b, "e2m1", acc) + + if group_size == 16: + acc *= meta_scale_norm + + tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_M, pid_n * BLOCK_N], value=acc) + + +# Kernel with 5D TMA scale loading (tutorial approach) +@triton.jit +def gemm_fp4_5d_tma_scales( + a_ptr, b_ptr, c_ptr, + scales_b_5d_ptr, scales_a_5d_ptr, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + group_size: tl.constexpr, + stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bn: tl.constexpr, stride_bk: tl.constexpr, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, + sb_s0: tl.constexpr, sb_s1: tl.constexpr, sb_s2: tl.constexpr, sb_s3: tl.constexpr, sb_s4: tl.constexpr, + sa_s0: tl.constexpr, sa_s1: tl.constexpr, sa_s2: tl.constexpr, sa_s3: tl.constexpr, sa_s4: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + NUM_STAGES: tl.constexpr, + meta_scale_norm: tl.constexpr = 0.0025, +): + pid = tl.program_id(0) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + BLOCK_K_A: tl.constexpr = BLOCK_K // 2 + BLOCK_K_B: tl.constexpr = BLOCK_K // 2 + BLOCK_K_S: tl.constexpr = BLOCK_K // group_size + VEC_SIZE: tl.constexpr = group_size + + rep_m: tl.constexpr = BLOCK_M // 128 + rep_n: tl.constexpr = BLOCK_N // 128 + rep_k: tl.constexpr = BLOCK_K // VEC_SIZE // 4 + + # TMA for data + a_desc = tl.make_tensor_descriptor(a_ptr, [M, K // 2], [stride_am, stride_ak], [BLOCK_M, BLOCK_K_A]) + b_desc = tl.make_tensor_descriptor(b_ptr, [N, K // 2], [stride_bn, stride_bk], [BLOCK_N, BLOCK_K_B]) + c_desc = tl.make_tensor_descriptor(c_ptr, [M, N], [stride_cm, stride_cn], [BLOCK_M, BLOCK_N]) + + # 5D TMA for scales + scales_b_shape1: tl.constexpr = N // 128 + scales_b_shape2: tl.constexpr = K // VEC_SIZE // 4 + scales_b_desc = tl.make_tensor_descriptor( + scales_b_5d_ptr, + [1, scales_b_shape1, scales_b_shape2, 2, 256], + [sb_s0, sb_s1, sb_s2, sb_s3, sb_s4], + [1, rep_n, rep_k, 2, 256], + ) + + scales_a_shape1: tl.constexpr = M // 128 + scales_a_shape2: tl.constexpr = K // VEC_SIZE // 4 + scales_a_desc = tl.make_tensor_descriptor( + scales_a_5d_ptr, + [1, scales_a_shape1, scales_a_shape2, 2, 256], + [sa_s0, sa_s1, sa_s2, sa_s3, sa_s4], + [1, rep_m, rep_k, 2, 256], + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + num_k = tl.cdiv(K, BLOCK_K) + for k in tl.range(num_k, num_stages=NUM_STAGES): + a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_M, k * BLOCK_K_A]) + b = tl.load_tensor_descriptor(b_desc, [pid_n * BLOCK_N, k * BLOCK_K_B]).T + + # 5D TMA scale loads + scale_b_raw = tl.load_tensor_descriptor(scales_b_desc, [0, pid_n * rep_n, k * rep_k, 0, 0]) + scales_b = scale_b_raw.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K_S) + + scale_a_raw = tl.load_tensor_descriptor(scales_a_desc, [0, pid_m * rep_m, k * rep_k, 0, 0]) + scales_a = scale_a_raw.reshape(rep_m, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K_S) + + acc = tl.dot_scaled(a, scales_a, "e2m1", b, scales_b, "e2m1", acc) + + if group_size == 16: + acc *= meta_scale_norm + + tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_M, pid_n * BLOCK_N], value=acc) + + +def bench(M, N, K, group_size=16, BLOCK_M=128, BLOCK_N=128, BLOCK_K=128, NUM_STAGES=4, num_warps=4): + VEC_SIZE = group_size + K_S = K // group_size + + # Create random FP4 data (packed as uint8) + a = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device) + b = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device) + + # 2D scales for pointer-based kernel + scales_b_2d = torch.randn(N, K_S, device=device).to(torch.float8_e4m3fn) # [N, K_S] + scales_a_2d = torch.randn(M, K_S, device=device).to(torch.float8_e4m3fn) # [M, K_S] + + # Transposed view (matching gemlite's current layout) + scales_b_T = scales_b_2d.T # [K_S, N] with strides (1, K_S) + scales_a_T = scales_a_2d.T # not used directly, pointer from original + + # 5D preshuffled scales (keep fp8_e4m3fn dtype for NVFP4) + scales_b_5d = preshuffle_scales(scales_b_2d, N, K_S) + scales_a_5d = preshuffle_scales(scales_a_2d, M, K_S) + + c_ptr = torch.empty((M, N), dtype=torch.bfloat16, device=device) + c_5d = torch.empty((M, N), dtype=torch.bfloat16, device=device) + + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + # Pointer-based kernel + def run_pointer(): + gemm_fp4_pointer_scales[grid]( + a, b, c_ptr, + scales_b_T, scales_a_2d, # scales_b is transposed, scales_a is row-major + M, N, K, group_size, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c_ptr.stride(0), c_ptr.stride(1), + scales_b_T.stride(0), scales_b_T.stride(1), # stride_sb_n=1, stride_sb_g=K_S + scales_a_2d.stride(0), scales_a_2d.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, + NUM_STAGES=NUM_STAGES, + num_warps=num_warps, + ) + + # 5D TMA kernel + def run_5d_tma(): + gemm_fp4_5d_tma_scales[grid]( + a, b, c_5d, + scales_b_5d, scales_a_5d, + M, N, K, group_size, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c_5d.stride(0), c_5d.stride(1), + scales_b_5d.stride(0), scales_b_5d.stride(1), scales_b_5d.stride(2), scales_b_5d.stride(3), scales_b_5d.stride(4), + scales_a_5d.stride(0), scales_a_5d.stride(1), scales_a_5d.stride(2), scales_a_5d.stride(3), scales_a_5d.stride(4), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, + NUM_STAGES=NUM_STAGES, + num_warps=num_warps, + ) + + ms_ptr = triton.testing.do_bench(run_pointer, warmup=200, rep=500) + ms_5d = triton.testing.do_bench(run_5d_tma, warmup=200, rep=500) + + flops = 2.0 * M * N * K + tflops_ptr = flops / (ms_ptr * 1e-3) / 1e12 + tflops_5d = flops / (ms_5d * 1e-3) / 1e12 + + print(f" Pointer scales: {ms_ptr:.3f} ms, {tflops_ptr:.1f} TFLOP/s") + print(f" 5D TMA scales: {ms_5d:.3f} ms, {tflops_5d:.1f} TFLOP/s") + print(f" Speedup: {ms_ptr / ms_5d:.3f}x") + return ms_ptr, ms_5d + + +if __name__ == "__main__": + M, N, K = 8192, 16384, 16384 + group_size = 16 # NVFP4 + + configs = [ + # (BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, num_warps) + # Best so far: 128x128x128, 3 stages, 4 warps = 1217.5 TFLOP/s + (128, 128, 128, 2, 4), + (128, 128, 128, 3, 4), + (128, 128, 128, 3, 8), + (128, 128, 128, 4, 4), + (128, 128, 128, 5, 4), + (128, 128, 256, 2, 4), + (128, 128, 256, 2, 8), + (128, 256, 128, 2, 4), + (128, 256, 128, 2, 8), + (128, 256, 128, 3, 4), + (128, 256, 256, 2, 4), + (128, 256, 256, 2, 8), + ] + + print(f"M={M}, N={N}, K={K}, group_size={group_size}") + for bm, bn, bk, ns, nw in configs: + rep_k = bk // group_size // 4 + if rep_k < 1: + print(f"\n Skipping BLOCK_M={bm}, BLOCK_N={bn}, BLOCK_K={bk} (rep_k < 1)") + continue + print(f"\n BLOCK_M={bm}, BLOCK_N={bn}, BLOCK_K={bk}, stages={ns}, warps={nw}") + try: + bench(M, N, K, group_size, bm, bn, bk, ns, nw) + except Exception as e: + print(f" FAILED: {e}") diff --git a/gemlite/core.py b/gemlite/core.py index 0a917d3..e8ea178 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -332,6 +332,24 @@ def load_state_dict(self, state_dict, strict=True, assign=False): self.compute_dtype = DTYPE_TO_TORCH[self.input_dtype.value] self.scaled_activations = bool(self.scaled_activations) self.data_contiguous = bool(self.data_contiguous) + + # Regenerate scales_5d if not in saved state (backward compat) + if getattr(self, "scales_5d", None) is None: + if is_mx_dtype(self.input_dtype) and self.scales is not None: + # scales is in 2D transposed layout [K//gs, N] as a Parameter + # We need the original [N, K//gs] for preshuffling + s = self.scales.data if isinstance(self.scales, torch.nn.Parameter) else self.scales + # s is transposed view: shape [K//gs, N], original data is [N, K//gs] + s_orig = s.T.contiguous() # [N, K//gs] + N_dim = s_orig.shape[0] + K_S = s_orig.shape[1] + if N_dim % 128 == 0 and K_S % 4 == 0: + self.scales_5d = s_orig.reshape(N_dim // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N_dim // 128, K_S // 4, 2, 256).contiguous() + else: + self.scales_5d = torch.tensor([[]], dtype=torch.int32, device=s.device) + else: + device = self.W_q.device if self.W_q is not None else "cuda" + self.scales_5d = torch.tensor([[]], dtype=torch.int32, device=device) #Make sure to feed UINT8 W_q for packing def pack( @@ -493,7 +511,6 @@ def pack( if(self.input_dtype in [DType.NVFP4]): self.scales = self.scales.to(torch.float8_e4m3fn) if(is_mx_dtype(self.input_dtype)): - self.scales = self.scales.T self.W_group_mode = 2 self.channel_scale_mode = 0 @@ -507,22 +524,33 @@ def pack( else: group_size = self.W_q.numel() // self.scales.numel() - #self.W_q = self.W_q.contiguous().T #Transposed for tma - - #self.scales = self.scales.contiguous().T # Transposed 2D TMA layout - # #self.scales = self.scales.reshape(1, N // 128, K // group_size // 4, 2, 256).contiguous() # 5D TMA layout for the scales: + # Preshuffle weight scales to 5D TMA layout for fast loading + # Original: [N, K//group_size] -> 5D: [1, N//128, K//group_size//4, 2, 256] + K_S = K // group_size + if N % 128 == 0 and K_S % 4 == 0: + self.scales_5d = self.scales.reshape(N // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N // 128, K_S // 4, 2, 256).contiguous() + self.use_5d_scales = True + else: + self.scales_5d = None + self.use_5d_scales = False - # #print(self.scales.stride(), self.scales.shape) + # Keep 2D transposed layout for the kernel (pointer-based fallback) + self.scales = self.scales.T ################################ if(self.scales is not None): self.meta_dtype = TORCH_TO_DTYPE[self.scales.dtype] + # Default scales_5d for non-FP4 dtypes + if not hasattr(self, 'scales_5d') or self.scales_5d is None: + self.scales_5d = torch.tensor([[]], dtype=torch.int32, device=self.device) + #Register tensors as buffers self.W_q = torch.nn.Parameter(self.W_q, requires_grad=False) self.bias = torch.nn.Parameter(self.bias, requires_grad=False) if self.bias is not None else None self.scales = torch.nn.Parameter(self.scales,requires_grad=False) self.zeros = torch.nn.Parameter(self.zeros, requires_grad=False) + self.scales_5d = torch.nn.Parameter(self.scales_5d, requires_grad=False) #Register metadata self.metadata = torch.nn.Parameter( @@ -539,7 +567,7 @@ def pack( #Return the main arguments def get_tensor_args(self): - return [self.W_q, self.scales, self.zeros] + return [self.W_q, self.scales, self.zeros, self.scales_5d] def get_meta_args(self): return [int(self.scaled_activations), diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 6b25b4d..9445696 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -187,6 +187,12 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) + + #MXFP/NVFP 5D TMA optimized (num_stages=3) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=3)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=3)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=3)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=3)) return configs def get_default_config_nvidia(): @@ -629,7 +635,7 @@ def gemm_INT_kernel_persistent_tma( @triton.jit def gemm_MX_kernel( a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, scales_a_ptr, + scales_ptr, zeros_ptr, scales_a_ptr, scales_5d_ptr, #M, N, K, M_CLOSEST, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, M_CLOSEST: tl.constexpr, ######### Quant parms ######### @@ -740,13 +746,14 @@ def gemm_MX_kernel( [BLOCK_SIZE_M, BLOCK_SIZE_N] ) - # 2D TMA - transposed - # scales_a_desc = tl.make_tensor_descriptor( - # scales_a_ptr, - # [M, K // group_size], - # [stride_meta_a_m, stride_meta_a_g], - # [BLOCK_SIZE_M, BLOCK_SIZE_K_S], - # ) + # 2D TMA for activation scales (disabled: last dim < 16 bytes) + # if(channel_scale_mode == 4): + # scales_a_desc = tl.make_tensor_descriptor( + # scales_a_ptr, + # [M, K // group_size], + # [stride_meta_a_m, stride_meta_a_g], + # [BLOCK_SIZE_M, BLOCK_SIZE_K_S], + # ) scales_b_desc = tl.make_tensor_descriptor( scales_ptr, @@ -755,24 +762,25 @@ def gemm_MX_kernel( [BLOCK_SIZE_K_S, BLOCK_SIZE_N], ) - # # 2. 5D TMA Descriptors for Scales: #(8388608, 65536, 512, 256, 1) torch.Size([1, 128, 128, 2, 256]) - # rep_m: tl.constexpr = BLOCK_SIZE_M // 128 - # rep_n: tl.constexpr = BLOCK_SIZE_N // 128 - # rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 - # scales_b_shape1: tl.constexpr = N // 128 - # scales_b_shape2: tl.constexpr = K // group_size // 4 - # stride_b4: tl.constexpr = 1 - # stride_b3: tl.constexpr = 256 - # stride_b2: tl.constexpr = 512 - # stride_b1: tl.constexpr = 512 * scales_b_shape2 - # stride_b0: tl.constexpr = stride_b1 * scales_b_shape1 - # # REQUIRES BLOCK_SIZE_K / BLOCK_SIZE_N to be multiples of 128 - # scales_b_desc = tl.make_tensor_descriptor( - # scales_ptr, - # [1, scales_b_shape1, scales_b_shape2, 2, 256], - # [stride_b0, stride_b1, stride_b2, stride_b3, stride_b4], - # [1, rep_n, rep_k, 2, 256] - # ) + # 5D TMA Descriptors for Scales (preshuffled layout) + USE_5D_SCALES: tl.constexpr = use_tma and (N % 128 == 0) and (BLOCK_SIZE_K // group_size % 4 == 0) and (BLOCK_SIZE_M % 128 == 0) and (BLOCK_SIZE_N % 128 == 0) + if USE_5D_SCALES: + rep_m: tl.constexpr = BLOCK_SIZE_M // 128 + rep_n: tl.constexpr = BLOCK_SIZE_N // 128 + rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 + scales_b_shape1: tl.constexpr = N // 128 + scales_b_shape2: tl.constexpr = K // group_size // 4 + stride_b4: tl.constexpr = 1 + stride_b3: tl.constexpr = 256 + stride_b2: tl.constexpr = 512 + stride_b1: tl.constexpr = 512 * scales_b_shape2 + stride_b0: tl.constexpr = stride_b1 * scales_b_shape1 + scales_b_5d_desc = tl.make_tensor_descriptor( + scales_5d_ptr, + [1, scales_b_shape1, scales_b_shape2, 2, 256], + [stride_b0, stride_b1, stride_b2, stride_b3, stride_b4], + [1, rep_n, rep_k, 2, 256] + ) #B-scales if(channel_scale_mode == 4): @@ -797,24 +805,15 @@ def gemm_MX_kernel( b = tl.load(b_ptrs, eviction_policy=b_evict) #################################################################################### k_m = k * BLOCK_SIZE_K_S - # NO TMA - scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) - - # # 2D TMA - #scales_b = tl.load_tensor_descriptor(scales_b_desc, [k * BLOCK_SIZE_K_S, pid_n * BLOCK_SIZE_N]).T - - # 5D Scale Loads and Unpacking - # offs_scale_m = pid_m * rep_m - # offs_scale_n = pid_n * rep_n - # offs_scale_k = k * rep_k - - #scale_b = tl.load_tensor_descriptor(scales_b_desc, [0, offs_scale_n, offs_scale_k, 0, 0]) - #scales_b = scale_b.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) - #https://github.com/triton-lang/triton/blob/main/python/tutorials/10-block-scaled-matmul.py#L220C1-L221C117 + if USE_5D_SCALES: + # 5D TMA scale loads (preshuffled layout) + scale_b_raw = tl.load_tensor_descriptor(scales_b_5d_desc, [0, pid_n * rep_n, k * rep_k, 0, 0]) + scales_b = scale_b_raw.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) + else: + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) if(channel_scale_mode == 4): scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) - #scales_a = tl.load_tensor_descriptor(scales_a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_S]) else: scales_a = scales_a_1s #################################################################################### @@ -854,7 +853,7 @@ def gemm_MX_kernel( PRINTED = False -def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, +def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, @@ -877,43 +876,55 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x stride_meta_a_m, stride_meta_a_g = None, None if(is_mx_dtype(input_dtype)): - gemm_kernel = gemm_MX_kernel - load_scales_as_block = True + gemm_MX_kernel[grid]( + x, W_q, output, + scales, zeros, scales_x, scales_5d, + M, N, K, M_CLOSEST, + W_nbits, group_size, unpack_mask, elements_per_sample, + type_id, x.dtype.itemsize, W_q.dtype.itemsize, + x.stride(0), x.stride(1), + W_q.stride(0), W_q.stride(1), + output.stride(0), output.stride(1), + stride_meta_a_m, stride_meta_a_g, + scales.stride(0), scales.stride(1), + load_scales_as_block = True, + input_dtype = DTYPE_TO_TRITON[input_dtype], + output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], + acc_dtype = DTYPE_TO_TRITON[acc_dtype], + meta_dtype = DTYPE_TO_TRITON[meta_dtype], + channel_scale_mode = channel_scale_mode, + W_group_mode = W_group_mode, + zero_is_scalar = zeros.numel() == 1, + data_contiguous = data_contiguous, + ) else: - gemm_kernel = gemm_INT_kernel - load_scales_as_block = False - - gemm_kernel[grid]( - x, W_q, output, - scales, zeros, scales_x, - M, N, K, M_CLOSEST, - ############################################# - W_nbits, group_size, unpack_mask, elements_per_sample, - type_id, x.dtype.itemsize, W_q.dtype.itemsize, - ############################################### - x.stride(0), x.stride(1), - W_q.stride(0), W_q.stride(1), - output.stride(0), output.stride(1), - stride_meta_a_m, stride_meta_a_g, - scales.stride(0), scales.stride(1), - ################################################ - load_scales_as_block = load_scales_as_block, - input_dtype = DTYPE_TO_TRITON[input_dtype], - output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], - acc_dtype = DTYPE_TO_TRITON[acc_dtype], - meta_dtype = DTYPE_TO_TRITON[meta_dtype], - ################################################ - channel_scale_mode = channel_scale_mode, - W_group_mode = W_group_mode, - zero_is_scalar = zeros.numel() == 1, - data_contiguous = data_contiguous, - ) + gemm_INT_kernel[grid]( + x, W_q, output, + scales, zeros, scales_x, + M, N, K, M_CLOSEST, + W_nbits, group_size, unpack_mask, elements_per_sample, + type_id, x.dtype.itemsize, W_q.dtype.itemsize, + x.stride(0), x.stride(1), + W_q.stride(0), W_q.stride(1), + output.stride(0), output.stride(1), + stride_meta_a_m, stride_meta_a_g, + scales.stride(0), scales.stride(1), + load_scales_as_block = False, + input_dtype = DTYPE_TO_TRITON[input_dtype], + output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], + acc_dtype = DTYPE_TO_TRITON[acc_dtype], + meta_dtype = DTYPE_TO_TRITON[meta_dtype], + channel_scale_mode = channel_scale_mode, + W_group_mode = W_group_mode, + zero_is_scalar = zeros.numel() == 1, + data_contiguous = data_contiguous, + ) return output # # Persistent version # NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count -# def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, +# def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, # W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, # input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, # channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 43846e4..41ca478 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -696,7 +696,7 @@ def gemm_splitK_MX_kernel( else: tl.store(c_ptrs, acc, mask=mask) -def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, +def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, diff --git a/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py b/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py index 2e9fb54..ae47589 100755 --- a/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py @@ -417,7 +417,7 @@ def gemm_splitK_persistent_kernel( tl.store(c_ptrs, acc, mask=mask) -def gemm_splitK_persistent_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, +def gemm_splitK_persistent_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, diff --git a/gemlite/triton_kernels/gemv_kernels.py b/gemlite/triton_kernels/gemv_kernels.py index 538b4aa..b1e36f3 100755 --- a/gemlite/triton_kernels/gemv_kernels.py +++ b/gemlite/triton_kernels/gemv_kernels.py @@ -575,7 +575,7 @@ def gemv_MX_kernel( tl.atomic_add(c_ptrs, acc, sem=atomic_mode) #TODO: gemv not generating correct reuslts with mxfp dtypes use except for A16W4. -def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, +def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, diff --git a/gemlite/triton_kernels/gemv_revsplitK_kernels.py b/gemlite/triton_kernels/gemv_revsplitK_kernels.py index a667ff9..0aa7acc 100755 --- a/gemlite/triton_kernels/gemv_revsplitK_kernels.py +++ b/gemlite/triton_kernels/gemv_revsplitK_kernels.py @@ -399,7 +399,7 @@ def gemv_INT_revsplitK_kernel( KERNEL_CACHE = {} -def gemv_revsplitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, +def gemv_revsplitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, diff --git a/gemlite/triton_kernels/gemv_splitK_kernels.py b/gemlite/triton_kernels/gemv_splitK_kernels.py index ff89b3d..650d0f4 100755 --- a/gemlite/triton_kernels/gemv_splitK_kernels.py +++ b/gemlite/triton_kernels/gemv_splitK_kernels.py @@ -459,7 +459,7 @@ def gemv_INT_splitK_kernel( tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) -def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, +def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index 0076e61..3c1a8b6 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -109,7 +109,7 @@ def test_serialization(self): ref_args = gemlite_linear.get_tensor_args() loaded_args = gemlite_linear_loaded.get_tensor_args() for i in range(len(ref_args)): - assert (ref_args[i] - loaded_args[i]).float().abs().mean() == 0, "tensor_args mismatch at " + str(i) + if ref_args[i].numel() > 0: assert (ref_args[i] - loaded_args[i]).float().abs().mean() == 0, "tensor_args mismatch at " + str(i) def ref_fn(x): return gemlite_linear.forward_manual(x, matmul_type='GEMM') From 597d99cd27569f2501bbad5a5b61d3314b08e2aa Mon Sep 17 00:00:00 2001 From: mobicham Date: Sun, 8 Mar 2026 10:03:15 -0700 Subject: [PATCH 33/63] remove 5d scales duplicate --- gemlite/core.py | 44 +++------ gemlite/triton_kernels/gemm_kernels.py | 81 ++++++++-------- gemlite/triton_kernels/gemm_splitK_kernels.py | 96 ++++++++++++------- .../gemm_splitK_persistent_kernels.py | 2 +- gemlite/triton_kernels/gemv_kernels.py | 2 +- .../triton_kernels/gemv_revsplitK_kernels.py | 2 +- gemlite/triton_kernels/gemv_splitK_kernels.py | 2 +- tests/test_mxfp.py | 2 +- 8 files changed, 123 insertions(+), 108 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index e8ea178..7deaf26 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -333,23 +333,16 @@ def load_state_dict(self, state_dict, strict=True, assign=False): self.scaled_activations = bool(self.scaled_activations) self.data_contiguous = bool(self.data_contiguous) - # Regenerate scales_5d if not in saved state (backward compat) - if getattr(self, "scales_5d", None) is None: - if is_mx_dtype(self.input_dtype) and self.scales is not None: - # scales is in 2D transposed layout [K//gs, N] as a Parameter - # We need the original [N, K//gs] for preshuffling - s = self.scales.data if isinstance(self.scales, torch.nn.Parameter) else self.scales - # s is transposed view: shape [K//gs, N], original data is [N, K//gs] - s_orig = s.T.contiguous() # [N, K//gs] - N_dim = s_orig.shape[0] - K_S = s_orig.shape[1] + # Backward compat: pop stale scales_5d from old saves + state_dict.pop("scales_5d", None) + # Convert 2D scales to 5D TMA layout for MX dtypes + if is_mx_dtype(self.input_dtype) and self.scales is not None: + s = self.scales.data if isinstance(self.scales, torch.nn.Parameter) else self.scales + if s.ndim == 2: + s_2d = s.T.contiguous() # [K_S, N] contiguous + N_dim, K_S = s_2d.shape[1], s_2d.shape[0] if N_dim % 128 == 0 and K_S % 4 == 0: - self.scales_5d = s_orig.reshape(N_dim // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N_dim // 128, K_S // 4, 2, 256).contiguous() - else: - self.scales_5d = torch.tensor([[]], dtype=torch.int32, device=s.device) - else: - device = self.W_q.device if self.W_q is not None else "cuda" - self.scales_5d = torch.tensor([[]], dtype=torch.int32, device=device) + self.scales = s_2d.reshape(N_dim // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N_dim // 128, K_S // 4, 2, 256).contiguous() #Make sure to feed UINT8 W_q for packing def pack( @@ -525,32 +518,23 @@ def pack( group_size = self.W_q.numel() // self.scales.numel() # Preshuffle weight scales to 5D TMA layout for fast loading - # Original: [N, K//group_size] -> 5D: [1, N//128, K//group_size//4, 2, 256] + # Original: [K_S, N] -> 5D: [1, N//128, K_S//4, 2, 256] K_S = K // group_size if N % 128 == 0 and K_S % 4 == 0: - self.scales_5d = self.scales.reshape(N // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N // 128, K_S // 4, 2, 256).contiguous() - self.use_5d_scales = True + self.scales = self.scales.reshape(N // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N // 128, K_S // 4, 2, 256).contiguous() else: - self.scales_5d = None - self.use_5d_scales = False - - # Keep 2D transposed layout for the kernel (pointer-based fallback) - self.scales = self.scales.T + # Keep 2D transposed layout for pointer-based fallback + self.scales = self.scales.T ################################ if(self.scales is not None): self.meta_dtype = TORCH_TO_DTYPE[self.scales.dtype] - # Default scales_5d for non-FP4 dtypes - if not hasattr(self, 'scales_5d') or self.scales_5d is None: - self.scales_5d = torch.tensor([[]], dtype=torch.int32, device=self.device) - #Register tensors as buffers self.W_q = torch.nn.Parameter(self.W_q, requires_grad=False) self.bias = torch.nn.Parameter(self.bias, requires_grad=False) if self.bias is not None else None self.scales = torch.nn.Parameter(self.scales,requires_grad=False) self.zeros = torch.nn.Parameter(self.zeros, requires_grad=False) - self.scales_5d = torch.nn.Parameter(self.scales_5d, requires_grad=False) #Register metadata self.metadata = torch.nn.Parameter( @@ -567,7 +551,7 @@ def pack( #Return the main arguments def get_tensor_args(self): - return [self.W_q, self.scales, self.zeros, self.scales_5d] + return [self.W_q, self.scales, self.zeros] def get_meta_args(self): return [int(self.scaled_activations), diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 9445696..09ca2c2 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -25,6 +25,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): b_sizeof = nargs['b_sizeof'] #Check cache + load_scales_as_block = kwargs['load_scales_as_block'] if(MATMUL_TYPE in GEMLITE_TRITON_CONFIG_CACHE): signature = str(tuple([get_closest_m(m), n, k, g, e, t])) if(signature in GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE]): @@ -38,16 +39,23 @@ def kernel_config_pruner(configs, nargs, **kwargs): config.pop('reg_dec_producer', None) config.pop('reg_inc_consumer', None) config["NUM_STAGES"] = num_stages - + config['EVEN_M'] = (m % config['BLOCK_SIZE_M'] == 0) config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) + # Adjust 5D TMA compatibility for cached configs + if load_scales_as_block and n % 128 == 0 and (k // g) % 4 == 0: + config['BLOCK_SIZE_N'] = max(config['BLOCK_SIZE_N'], 128) + while (config['BLOCK_SIZE_K'] // g) % 4 != 0: + config['BLOCK_SIZE_K'] *= 2 + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) + yield triton.Config(config, num_stages=num_stages, num_warps=num_warps) return - + gpu_shared_memory = get_gpu_shared_memory() - load_scales_as_block = kwargs['load_scales_as_block'] used = set() for config in configs: group_size_m = config.kwargs['GROUP_SIZE_M'] @@ -71,18 +79,16 @@ def kernel_config_pruner(configs, nargs, **kwargs): block_size_n = next_power_of_2(block_size_n) #Constraints - if(load_scales_as_block): - # FOR TMA - # block_size_k = min(block_size_k, 256) #TODO: tmp MXFP TMA fix - # if block_size_n % 128 > 0: - # block_size_n = 128 - # if block_size_k % 128 > 0: - # block_size_k = 128 + if(load_scales_as_block): if(e > 1): block_size_k = max(block_size_k, 64) #m16n8k64 else: block_size_k = max(block_size_k, 32) #m16n8k32 - #block_size_k = max(block_size_k, 128) #TMA + # 5D TMA scale compatibility: adjust block sizes for 5D TMA descriptor + if n % 128 == 0 and (k // g) % 4 == 0: + block_size_n = max(block_size_n, 128) + while (block_size_k // g) % 4 != 0: + block_size_k *= 2 else: block_size_k = min(block_size_k, g) @@ -635,7 +641,7 @@ def gemm_INT_kernel_persistent_tma( @triton.jit def gemm_MX_kernel( a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, scales_a_ptr, scales_5d_ptr, + scales_ptr, zeros_ptr, scales_a_ptr, #M, N, K, M_CLOSEST, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, M_CLOSEST: tl.constexpr, ######### Quant parms ######### @@ -679,6 +685,7 @@ def gemm_MX_kernel( meta_scale_norm: tl.constexpr = (0.05 ** 2), ################################# use_tma: tl.constexpr = True, + USE_5D_SCALES: tl.constexpr = False, ): pid = tl.program_id(axis=0) @@ -722,8 +729,9 @@ def gemm_MX_kernel( BLOCK_SIZE_K_S: tl.constexpr = BLOCK_SIZE_K // group_size offs_k_scales = tl.arange(0, BLOCK_SIZE_K_S) offs_n_b_scales = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] - + if not USE_5D_SCALES: + scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] + if use_tma: a_desc = tl.make_tensor_descriptor( a_ptr, @@ -746,26 +754,8 @@ def gemm_MX_kernel( [BLOCK_SIZE_M, BLOCK_SIZE_N] ) - # 2D TMA for activation scales (disabled: last dim < 16 bytes) - # if(channel_scale_mode == 4): - # scales_a_desc = tl.make_tensor_descriptor( - # scales_a_ptr, - # [M, K // group_size], - # [stride_meta_a_m, stride_meta_a_g], - # [BLOCK_SIZE_M, BLOCK_SIZE_K_S], - # ) - - scales_b_desc = tl.make_tensor_descriptor( - scales_ptr, - [K // group_size, N], - [stride_meta_g, stride_meta_n], - [BLOCK_SIZE_K_S, BLOCK_SIZE_N], - ) - # 5D TMA Descriptors for Scales (preshuffled layout) - USE_5D_SCALES: tl.constexpr = use_tma and (N % 128 == 0) and (BLOCK_SIZE_K // group_size % 4 == 0) and (BLOCK_SIZE_M % 128 == 0) and (BLOCK_SIZE_N % 128 == 0) if USE_5D_SCALES: - rep_m: tl.constexpr = BLOCK_SIZE_M // 128 rep_n: tl.constexpr = BLOCK_SIZE_N // 128 rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 scales_b_shape1: tl.constexpr = N // 128 @@ -776,7 +766,7 @@ def gemm_MX_kernel( stride_b1: tl.constexpr = 512 * scales_b_shape2 stride_b0: tl.constexpr = stride_b1 * scales_b_shape1 scales_b_5d_desc = tl.make_tensor_descriptor( - scales_5d_ptr, + scales_ptr, [1, scales_b_shape1, scales_b_shape2, 2, 256], [stride_b0, stride_b1, stride_b2, stride_b3, stride_b4], [1, rep_n, rep_k, 2, 256] @@ -853,10 +843,10 @@ def gemm_MX_kernel( PRINTED = False -def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, - W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, - input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, - channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, +def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, + W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, + input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, + channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, ) -> Tensor: @@ -876,9 +866,19 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5 stride_meta_a_m, stride_meta_a_g = None, None if(is_mx_dtype(input_dtype)): + use_5d_scales = (scales.ndim == 5) + + # When autotuner has only 1 config (default mode), pruning is skipped entirely. + # Adjust block sizes directly to satisfy 5D TMA descriptor requirements. + if use_5d_scales and len(gemm_MX_kernel.configs) == 1: + cfg = gemm_MX_kernel.configs[0] + cfg.kwargs['BLOCK_SIZE_N'] = max(cfg.kwargs.get('BLOCK_SIZE_N', 64), 128) + while (cfg.kwargs.get('BLOCK_SIZE_K', 64) // group_size) % 4 != 0: + cfg.kwargs['BLOCK_SIZE_K'] *= 2 + gemm_MX_kernel[grid]( - x, W_q, output, - scales, zeros, scales_x, scales_5d, + x, W_q, output, + scales, zeros, scales_x, M, N, K, M_CLOSEST, W_nbits, group_size, unpack_mask, elements_per_sample, type_id, x.dtype.itemsize, W_q.dtype.itemsize, @@ -886,7 +886,7 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5 W_q.stride(0), W_q.stride(1), output.stride(0), output.stride(1), stride_meta_a_m, stride_meta_a_g, - scales.stride(0), scales.stride(1), + 0 if use_5d_scales else scales.stride(0), 0 if use_5d_scales else scales.stride(1), load_scales_as_block = True, input_dtype = DTYPE_TO_TRITON[input_dtype], output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], @@ -896,6 +896,7 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5 W_group_mode = W_group_mode, zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, + USE_5D_SCALES = use_5d_scales, ) else: gemm_INT_kernel[grid]( diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 41ca478..7a8470f 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -25,6 +25,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): b_sizeof = nargs['b_sizeof'] #Check cache + load_scales_as_block = kwargs['load_scales_as_block'] if(MATMUL_TYPE in GEMLITE_TRITON_CONFIG_CACHE): signature = str(tuple([get_closest_m(m), n, k, g, e, t])) if(signature in GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE]): @@ -38,21 +39,27 @@ def kernel_config_pruner(configs, nargs, **kwargs): config.pop('reg_dec_producer', None) config.pop('reg_inc_consumer', None) config["NUM_STAGES"] = num_stages - + config['EVEN_M'] = (m % config['BLOCK_SIZE_M'] == 0) config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) - + + # Adjust 5D TMA compatibility for cached configs + if load_scales_as_block and n % 128 == 0 and (k // g) % 4 == 0: + config['BLOCK_SIZE_N'] = max(config['BLOCK_SIZE_N'], 128) + while (config['BLOCK_SIZE_K'] // g) % 4 != 0: + config['BLOCK_SIZE_K'] *= 2 + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) + yield triton.Config(config, num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero("c_ptr") if (config['SPLIT_K'] > 1) else None, ) - return - gpu_shared_memory = get_gpu_shared_memory() - load_scales_as_block = kwargs['load_scales_as_block'] + gpu_shared_memory = get_gpu_shared_memory() used = set() for config in configs: group_size_m = config.kwargs['GROUP_SIZE_M'] @@ -75,17 +82,16 @@ def kernel_config_pruner(configs, nargs, **kwargs): if(m >= 32): split_k = min(split_k, 8) #Constraints - if(load_scales_as_block): - # FOR TMA - # block_size_k = min(block_size_k, 256) #TODO: tmp MXFP TMA fix - # if block_size_n % 128 > 0: - # block_size_n = 128 - # if block_size_k % 128 > 0: - # block_size_k = 128 + if(load_scales_as_block): if(e > 1): block_size_k = max(block_size_k, 64) #m16n8k64 else: block_size_k = max(block_size_k, 32) #m16n8k32 + # 5D TMA scale compatibility: adjust block sizes for 5D TMA descriptor + if n % 128 == 0 and (k // g) % 4 == 0: + block_size_n = max(block_size_n, 128) + while (block_size_k // g) % 4 != 0: + block_size_k *= 2 else: block_size_k = min(block_size_k, g) @@ -324,6 +330,9 @@ def gemm_splitK_INT_kernel( atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', b_evict: tl.constexpr = 'evict_first', + USE_5D_SCALES: tl.constexpr = False, + SCALES_5D_SHAPE1: tl.constexpr = 0, + SCALES_5D_SHAPE2: tl.constexpr = 0, ): """ Based on https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py @@ -539,6 +548,9 @@ def gemm_splitK_MX_kernel( meta_scale_norm: tl.constexpr = (0.05 ** 2), ################################# use_tma: tl.constexpr = False, + USE_5D_SCALES: tl.constexpr = False, + SCALES_5D_SHAPE1: tl.constexpr = 0, + SCALES_5D_SHAPE2: tl.constexpr = 0, ): pid = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -585,13 +597,13 @@ def gemm_splitK_MX_kernel( offs_k_scales = tl.arange(0, BLOCK_SIZE_K_S) offs_n_b_scales = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) #B scales: [BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] - scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g + if not USE_5D_SCALES: + scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #A scales if(channel_scale_mode == 4): scales_a_ptrs = scales_a_ptr + offs_am[:, None] * stride_meta_a_m + offs_k_scales[None, :] * stride_meta_a_g - if use_tma: a_desc = tl.make_tensor_descriptor( a_ptr, @@ -599,7 +611,7 @@ def gemm_splitK_MX_kernel( [stride_am, stride_ak], [BLOCK_SIZE_M, BLOCK_SIZE_K_A_E] ) - + b_desc = tl.make_tensor_descriptor( b_ptr, [N, K // elements_per_sample], @@ -607,26 +619,27 @@ def gemm_splitK_MX_kernel( [BLOCK_SIZE_N, BLOCK_SIZE_K_B_E] ) - c_desc = tl.make_tensor_descriptor( + c_desc = tl.make_tensor_descriptor( c_ptr, [M, N], [stride_cm, stride_cn], [BLOCK_SIZE_M, BLOCK_SIZE_N] ) - - # 2D TMA - transposed - # scales_a_desc = tl.make_tensor_descriptor( - # scales_a_ptr, - # [M, K // group_size], - # [stride_meta_a_m, stride_meta_a_g], - # [BLOCK_SIZE_M, BLOCK_SIZE_K_S], - # ) - - scales_b_desc = tl.make_tensor_descriptor( + + # 5D TMA Descriptors for Scales (preshuffled layout) + if USE_5D_SCALES: + rep_n: tl.constexpr = BLOCK_SIZE_N // 128 + rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 + stride_b4: tl.constexpr = 1 + stride_b3: tl.constexpr = 256 + stride_b2: tl.constexpr = 512 + stride_b1: tl.constexpr = 512 * SCALES_5D_SHAPE2 + stride_b0: tl.constexpr = 512 * SCALES_5D_SHAPE2 * SCALES_5D_SHAPE1 + scales_b_5d_desc = tl.make_tensor_descriptor( scales_ptr, - [K // group_size, N], - [stride_meta_g, stride_meta_n], - [BLOCK_SIZE_K_S, BLOCK_SIZE_N], + [1, SCALES_5D_SHAPE1, SCALES_5D_SHAPE2, 2, 256], + [stride_b0, stride_b1, stride_b2, stride_b3, stride_b4], + [1, rep_n, rep_k, 2, 256] ) # Used in channel-wise MXPF8 version @@ -648,7 +661,11 @@ def gemm_splitK_MX_kernel( #k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32) k_m = (k * SPLIT_K + pid_k) * BLOCK_SIZE_K_S #OK for BLOCK_SIZE_K >=group_size - scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + if USE_5D_SCALES: + scale_b_raw = tl.load_tensor_descriptor(scales_b_5d_desc, [0, pid_n * rep_n, (k * SPLIT_K + pid_k) * rep_k, 0, 0]) + scales_b = scale_b_raw.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) + else: + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) if(channel_scale_mode == 4): scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) @@ -696,7 +713,7 @@ def gemm_splitK_MX_kernel( else: tl.store(c_ptrs, acc, mask=mask) -def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, +def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, @@ -721,12 +738,22 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s if(is_mx_dtype(input_dtype)): gemm_splitK_kernel = gemm_splitK_MX_kernel load_scales_as_block = True + use_5d_scales = (scales.ndim == 5) + + # When autotuner has only 1 config (default mode), pruning is skipped entirely. + # Adjust block sizes directly to satisfy 5D TMA descriptor requirements. + if use_5d_scales and len(gemm_splitK_MX_kernel.configs) == 1: + cfg = gemm_splitK_MX_kernel.configs[0] + cfg.kwargs['BLOCK_SIZE_N'] = max(cfg.kwargs.get('BLOCK_SIZE_N', 64), 128) + while (cfg.kwargs.get('BLOCK_SIZE_K', 64) // group_size) % 4 != 0: + cfg.kwargs['BLOCK_SIZE_K'] *= 2 else: gemm_splitK_kernel = gemm_splitK_INT_kernel load_scales_as_block = False + use_5d_scales = False gemm_splitK_kernel[grid]( - x, W_q, output, + x, W_q, output, scales, zeros, scales_x, M, N, K, M_CLOSEST, ############################################# @@ -737,7 +764,7 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s W_q.stride(0), W_q.stride(1), output.stride(0), output.stride(1), stride_meta_a_m, stride_meta_a_g, - scales.stride(0), scales.stride(1), + 0 if use_5d_scales else scales.stride(0), 0 if use_5d_scales else scales.stride(1), ################################################ load_scales_as_block = load_scales_as_block, input_dtype = DTYPE_TO_TRITON[input_dtype], @@ -749,6 +776,9 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s W_group_mode = W_group_mode, zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, + USE_5D_SCALES = use_5d_scales, + SCALES_5D_SHAPE1 = N // 128 if use_5d_scales else 0, + SCALES_5D_SHAPE2 = K // group_size // 4 if use_5d_scales else 0, ) if(not native_atomic): diff --git a/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py b/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py index ae47589..2e9fb54 100755 --- a/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py @@ -417,7 +417,7 @@ def gemm_splitK_persistent_kernel( tl.store(c_ptrs, acc, mask=mask) -def gemm_splitK_persistent_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, +def gemm_splitK_persistent_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, diff --git a/gemlite/triton_kernels/gemv_kernels.py b/gemlite/triton_kernels/gemv_kernels.py index b1e36f3..538b4aa 100755 --- a/gemlite/triton_kernels/gemv_kernels.py +++ b/gemlite/triton_kernels/gemv_kernels.py @@ -575,7 +575,7 @@ def gemv_MX_kernel( tl.atomic_add(c_ptrs, acc, sem=atomic_mode) #TODO: gemv not generating correct reuslts with mxfp dtypes use except for A16W4. -def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, +def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, diff --git a/gemlite/triton_kernels/gemv_revsplitK_kernels.py b/gemlite/triton_kernels/gemv_revsplitK_kernels.py index 0aa7acc..a667ff9 100755 --- a/gemlite/triton_kernels/gemv_revsplitK_kernels.py +++ b/gemlite/triton_kernels/gemv_revsplitK_kernels.py @@ -399,7 +399,7 @@ def gemv_INT_revsplitK_kernel( KERNEL_CACHE = {} -def gemv_revsplitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, +def gemv_revsplitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, diff --git a/gemlite/triton_kernels/gemv_splitK_kernels.py b/gemlite/triton_kernels/gemv_splitK_kernels.py index 650d0f4..ff89b3d 100755 --- a/gemlite/triton_kernels/gemv_splitK_kernels.py +++ b/gemlite/triton_kernels/gemv_splitK_kernels.py @@ -459,7 +459,7 @@ def gemv_INT_splitK_kernel( tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) -def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, +def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index 36bf55f..c55b299 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -89,6 +89,6 @@ def test_A4W4_NVFP_dynamic(self): gemlite_linear = A4W4_NVFP_dynamic(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) self.assertTrue(gemlite_linear.scaled_activations) - self.eval(gemlite_linear, tol = 1e-3) + self.eval(gemlite_linear, tol = 2e-3) From 01227af8af1b26affe452fd289af7128d9177140 Mon Sep 17 00:00:00 2001 From: mobicham Date: Sun, 8 Mar 2026 10:42:51 -0700 Subject: [PATCH 34/63] clean-up --- gemlite/triton_kernels/gemm_kernels.py | 400 ++++++++---------- gemlite/triton_kernels/gemm_splitK_kernels.py | 13 +- gemlite/triton_kernels/utils.py | 6 + 3 files changed, 191 insertions(+), 228 deletions(-) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 09ca2c2..0d78e04 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -90,7 +90,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): while (block_size_k // g) % 4 != 0: block_size_k *= 2 else: - block_size_k = min(block_size_k, g) + block_size_k = max(min(block_size_k, g), 32) #tl.dot minimum K #Hint: skip block_size_n > block_size_k for col-major non-packed data. @@ -202,7 +202,7 @@ def get_fast_autotune_config_nvidia(): return configs def get_default_config_nvidia(): - return [triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':4}, num_warps=4, num_stages=4),] + return [triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2),] ######################################################################################################################################################################## #AMD - Instinct MI300X @@ -323,6 +323,7 @@ def gemm_INT_kernel( meta_evict_policy: tl.constexpr = "evict_last", a_evict: tl.constexpr = "", b_evict: tl.constexpr = "evict_first", + USE_5D_SCALES: tl.constexpr = False, ): """ Based on https://github.com/fpgaminer/GPTQ-triton @@ -441,7 +442,6 @@ def gemm_INT_kernel( if not EVEN_K: a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) - ############################################################################################################# #Channel-wise scaling if(channel_scale_mode == 1): #weight-only @@ -466,171 +466,163 @@ def gemm_INT_kernel( c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) tl.store(c_ptrs, acc, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) - - -# TMA descriptors require a global memory allocation -from typing import Optional -def alloc_fn(size: int, alignment: int, stream: Optional[int]): - return torch.empty(size, device="cuda", dtype=torch.int8) -triton.set_allocator(alloc_fn) - -@triton.autotune( - configs = get_autotune_config(), - key = KEYS, - prune_configs_by = {'early_config_prune': kernel_config_pruner}, - use_cuda_graph = AUTOTUNE.USE_CUDA_GRAPH, -) -@triton.jit -def gemm_INT_kernel_persistent_tma( - a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, M_CLOSEST, - ######### Quant parms ######### - W_nbits: tl.constexpr, - group_size: tl.constexpr, - unpack_mask: tl.constexpr, - elements_per_sample: tl.constexpr, - ################################# - type_id: tl.constexpr, - a_sizeof: tl.constexpr, - b_sizeof: tl.constexpr, - ######### Strides ######### - stride_am: tl.constexpr, stride_ak: tl.constexpr, - stride_bk: tl.constexpr, stride_bn: tl.constexpr, - stride_cm: tl.constexpr, stride_cn: tl.constexpr, - stride_meta_a_m: tl.constexpr, stride_meta_a_g: tl.constexpr, - stride_meta_g: tl.constexpr, stride_meta_n: tl.constexpr, - ######### Dtypes ######### - load_scales_as_block: tl.constexpr, #False - input_dtype: tl.constexpr, - output_dtype: tl.constexpr, - acc_dtype: tl.constexpr, - meta_dtype: tl.constexpr, - ######### Meta-data mode ######### - channel_scale_mode: tl.constexpr, - W_group_mode: tl.constexpr, - zero_is_scalar: tl.constexpr, - ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, - ################################# - EVEN_M: tl.constexpr = False, - EVEN_K: tl.constexpr = False, - EVEN_N: tl.constexpr = False, - ################################# - A_load_order: tl.constexpr = 0, - data_contiguous: tl.constexpr = True, - ################################# - meta_evict_policy: tl.constexpr = '', - a_evict: tl.constexpr = '', - b_evict: tl.constexpr = '', - NUM_SMS: tl.constexpr = 8, -): - """ - Persistent + TMA version. - A: (M, K) fp16/bf16 - B_packed: (K//elements_per_sample, N) int32 - scales/zeros: (num_groups, N) or other depending on W_group_mode - """ - - # --------------------------- - # Persistent tiling setup - # --------------------------- - start_pid = tl.program_id(0).to(tl.int32) +# @triton.autotune( +# configs = get_autotune_config(), +# key = KEYS, +# prune_configs_by = {'early_config_prune': kernel_config_pruner}, +# use_cuda_graph = AUTOTUNE.USE_CUDA_GRAPH, +# ) +# @triton.jit +# def gemm_INT_kernel_persistent_tma( +# a_ptr, b_ptr, c_ptr, +# scales_ptr, zeros_ptr, scales_a_ptr, +# M, N, K, M_CLOSEST, +# ######### Quant parms ######### +# W_nbits: tl.constexpr, +# group_size: tl.constexpr, +# unpack_mask: tl.constexpr, +# elements_per_sample: tl.constexpr, +# ################################# +# type_id: tl.constexpr, +# a_sizeof: tl.constexpr, +# b_sizeof: tl.constexpr, +# ######### Strides ######### +# stride_am: tl.constexpr, stride_ak: tl.constexpr, +# stride_bk: tl.constexpr, stride_bn: tl.constexpr, +# stride_cm: tl.constexpr, stride_cn: tl.constexpr, +# stride_meta_a_m: tl.constexpr, stride_meta_a_g: tl.constexpr, +# stride_meta_g: tl.constexpr, stride_meta_n: tl.constexpr, +# ######### Dtypes ######### +# load_scales_as_block: tl.constexpr, #False +# input_dtype: tl.constexpr, +# output_dtype: tl.constexpr, +# acc_dtype: tl.constexpr, +# meta_dtype: tl.constexpr, +# ######### Meta-data mode ######### +# channel_scale_mode: tl.constexpr, +# W_group_mode: tl.constexpr, +# zero_is_scalar: tl.constexpr, +# ######### tuning params ######### +# BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +# GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, +# ################################# +# EVEN_M: tl.constexpr = False, +# EVEN_K: tl.constexpr = False, +# EVEN_N: tl.constexpr = False, +# ################################# +# A_load_order: tl.constexpr = 0, +# data_contiguous: tl.constexpr = True, +# ################################# +# meta_evict_policy: tl.constexpr = '', +# a_evict: tl.constexpr = '', +# b_evict: tl.constexpr = '', +# NUM_SMS: tl.constexpr = 8, +# ): +# """ +# Persistent + TMA version. +# A: (M, K) fp16/bf16 +# B_packed: (K//elements_per_sample, N) int32 +# scales/zeros: (num_groups, N) or other depending on W_group_mode +# """ + +# # --------------------------- +# # Persistent tiling setup +# # --------------------------- +# start_pid = tl.program_id(0).to(tl.int32) - grid_m = tl.cdiv(M, BLOCK_SIZE_M) - grid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_tiles = grid_m * grid_n - width = GROUP_SIZE_M * grid_n # tiles per "group stripe" - - a_desc = tl.make_tensor_descriptor( - a_ptr, - [M, K], - [stride_am, stride_ak], - [BLOCK_SIZE_M, BLOCK_SIZE_K] - ) +# grid_m = tl.cdiv(M, BLOCK_SIZE_M) +# grid_n = tl.cdiv(N, BLOCK_SIZE_N) +# num_tiles = grid_m * grid_n +# width = GROUP_SIZE_M * grid_n # tiles per "group stripe" + +# a_desc = tl.make_tensor_descriptor( +# a_ptr, +# [M, K], +# [stride_am, stride_ak], +# [BLOCK_SIZE_M, BLOCK_SIZE_K] +# ) - # b_desc = tl.make_tensor_descriptor( - # b_ptr, - # [K, N], - # [stride_bk, stride_bn], - # [BLOCK_SIZE_K, BLOCK_SIZE_N] - # ) +# # b_desc = tl.make_tensor_descriptor( +# # b_ptr, +# # [K, N], +# # [stride_bk, stride_bn], +# # [BLOCK_SIZE_K, BLOCK_SIZE_N] +# # ) - #transposed : use self.W_q = self.W_q.contiguous().t() - b_desc = tl.make_tensor_descriptor( - b_ptr, - [N, K], - [stride_bn, stride_bk], - [BLOCK_SIZE_N, BLOCK_SIZE_K] - ) +# #transposed : use self.W_q = self.W_q.contiguous().t() +# b_desc = tl.make_tensor_descriptor( +# b_ptr, +# [N, K], +# [stride_bn, stride_bk], +# [BLOCK_SIZE_N, BLOCK_SIZE_K] +# ) - # # Precompute unpack shifts (vector length = elements_per_sample) - # # shifts = [0, W_nbits, 2*W_nbits, ...] - # shifts = (tl.arange(0, elements_per_sample) * W_nbits).to(tl.int32) - - # # Optional scalar zero - # if zero_is_scalar: - # zero_scalar = tl.load(zeros_ptr, eviction_policy="evict_last") - - ############################################################################################################# - # Main loop - for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): - group_id = tile_id // width - first_m = group_id * GROUP_SIZE_M - gs = tl.minimum(grid_m - first_m, GROUP_SIZE_M) +# # # Precompute unpack shifts (vector length = elements_per_sample) +# # # shifts = [0, W_nbits, 2*W_nbits, ...] +# # shifts = (tl.arange(0, elements_per_sample) * W_nbits).to(tl.int32) + +# # # Optional scalar zero +# # if zero_is_scalar: +# # zero_scalar = tl.load(zeros_ptr, eviction_policy="evict_last") + +# ############################################################################################################# +# # Main loop +# for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): +# group_id = tile_id // width +# first_m = group_id * GROUP_SIZE_M +# gs = tl.minimum(grid_m - first_m, GROUP_SIZE_M) - pid_m = first_m + (tile_id % gs) - pid_n = (tile_id % width) // gs - - rm = pid_m * BLOCK_SIZE_M - rn = pid_n * BLOCK_SIZE_N - - # Accumulator - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - - # Column indices for this tile (used for metadata + store) - offs_n = rn + tl.arange(0, BLOCK_SIZE_N) - n_mask = offs_n < N - - # K loop - for k in tl.range(0, K, BLOCK_SIZE_K): - a = tl.load_tensor_descriptor(a_desc, [rm, k]) - - k_packed = k // elements_per_sample - #b = tl.load_tensor_descriptor(b_desc, [k_packed, rn]) - b = tl.load_tensor_descriptor(b_desc, [rn, k_packed]).T #Transposed - - acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype) - - ############################################################################################################# - # Channel-wise scaling - offs_m = rm + tl.arange(0, BLOCK_SIZE_M) - m_mask = offs_m < M - if channel_scale_mode == 1: # weight-only - # expects a 1D per-N scale at scales_ptr (same as your original) - scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) - acc = acc.to(meta_dtype) * scales_b[None, :] - - if channel_scale_mode == 2: # activation-only - scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) - acc = acc.to(meta_dtype) * scales_a[:, None] - - if channel_scale_mode == 3: # weight + activation - scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) - scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) - - acc = acc.to(output_dtype) - - ############################################################################################################# - # Store - c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn - mask = (m_mask[:, None] & n_mask[None, :]).to(tl.int1) - if EVEN_M and EVEN_N: - tl.store(c_ptrs, acc) - else: - tl.store(c_ptrs, acc, mask=mask) +# pid_m = first_m + (tile_id % gs) +# pid_n = (tile_id % width) // gs + +# rm = pid_m * BLOCK_SIZE_M +# rn = pid_n * BLOCK_SIZE_N + +# # Accumulator +# acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + +# # Column indices for this tile (used for metadata + store) +# offs_n = rn + tl.arange(0, BLOCK_SIZE_N) +# n_mask = offs_n < N + +# # K loop +# for k in tl.range(0, K, BLOCK_SIZE_K): +# a = tl.load_tensor_descriptor(a_desc, [rm, k]) + +# k_packed = k // elements_per_sample +# #b = tl.load_tensor_descriptor(b_desc, [k_packed, rn]) +# b = tl.load_tensor_descriptor(b_desc, [rn, k_packed]).T #Transposed + +# acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype) + +# ############################################################################################################# +# # Channel-wise scaling +# offs_m = rm + tl.arange(0, BLOCK_SIZE_M) +# m_mask = offs_m < M +# if channel_scale_mode == 1: # weight-only +# # expects a 1D per-N scale at scales_ptr (same as your original) +# scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) +# acc = acc.to(meta_dtype) * scales_b[None, :] + +# if channel_scale_mode == 2: # activation-only +# scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) +# acc = acc.to(meta_dtype) * scales_a[:, None] + +# if channel_scale_mode == 3: # weight + activation +# scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) +# scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) +# acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + +# acc = acc.to(output_dtype) + +# ############################################################################################################# +# # Store +# c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn +# mask = (m_mask[:, None] & n_mask[None, :]).to(tl.int1) +# if EVEN_M and EVEN_N: +# tl.store(c_ptrs, acc) +# else: +# tl.store(c_ptrs, acc, mask=mask) @triton.autotune( configs = get_autotune_config(), @@ -642,7 +634,6 @@ def gemm_INT_kernel_persistent_tma( def gemm_MX_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - #M, N, K, M_CLOSEST, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, M_CLOSEST: tl.constexpr, ######### Quant parms ######### W_nbits: tl.constexpr, @@ -852,7 +843,6 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x global PRINTED M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] # W - #M, K, N = x.shape[0], W_q.shape[1] * elements_per_sample, W_q.shape[0] #W.T M_CLOSEST = get_closest_m(M) #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" @@ -866,60 +856,36 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x stride_meta_a_m, stride_meta_a_g = None, None if(is_mx_dtype(input_dtype)): + gemm_kernel = gemm_MX_kernel + load_scales_as_block = True use_5d_scales = (scales.ndim == 5) - - # When autotuner has only 1 config (default mode), pruning is skipped entirely. - # Adjust block sizes directly to satisfy 5D TMA descriptor requirements. - if use_5d_scales and len(gemm_MX_kernel.configs) == 1: - cfg = gemm_MX_kernel.configs[0] - cfg.kwargs['BLOCK_SIZE_N'] = max(cfg.kwargs.get('BLOCK_SIZE_N', 64), 128) - while (cfg.kwargs.get('BLOCK_SIZE_K', 64) // group_size) % 4 != 0: - cfg.kwargs['BLOCK_SIZE_K'] *= 2 - - gemm_MX_kernel[grid]( - x, W_q, output, - scales, zeros, scales_x, - M, N, K, M_CLOSEST, - W_nbits, group_size, unpack_mask, elements_per_sample, - type_id, x.dtype.itemsize, W_q.dtype.itemsize, - x.stride(0), x.stride(1), - W_q.stride(0), W_q.stride(1), - output.stride(0), output.stride(1), - stride_meta_a_m, stride_meta_a_g, - 0 if use_5d_scales else scales.stride(0), 0 if use_5d_scales else scales.stride(1), - load_scales_as_block = True, - input_dtype = DTYPE_TO_TRITON[input_dtype], - output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], - acc_dtype = DTYPE_TO_TRITON[acc_dtype], - meta_dtype = DTYPE_TO_TRITON[meta_dtype], - channel_scale_mode = channel_scale_mode, - W_group_mode = W_group_mode, - zero_is_scalar = zeros.numel() == 1, - data_contiguous = data_contiguous, - USE_5D_SCALES = use_5d_scales, - ) else: - gemm_INT_kernel[grid]( - x, W_q, output, - scales, zeros, scales_x, - M, N, K, M_CLOSEST, - W_nbits, group_size, unpack_mask, elements_per_sample, - type_id, x.dtype.itemsize, W_q.dtype.itemsize, - x.stride(0), x.stride(1), - W_q.stride(0), W_q.stride(1), - output.stride(0), output.stride(1), - stride_meta_a_m, stride_meta_a_g, - scales.stride(0), scales.stride(1), - load_scales_as_block = False, - input_dtype = DTYPE_TO_TRITON[input_dtype], - output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], - acc_dtype = DTYPE_TO_TRITON[acc_dtype], - meta_dtype = DTYPE_TO_TRITON[meta_dtype], - channel_scale_mode = channel_scale_mode, - W_group_mode = W_group_mode, - zero_is_scalar = zeros.numel() == 1, - data_contiguous = data_contiguous, - ) + gemm_kernel = gemm_INT_kernel + load_scales_as_block = False + use_5d_scales = False + + gemm_kernel[grid]( + x, W_q, output, + scales, zeros, scales_x, + M, N, K, M_CLOSEST, + W_nbits, group_size, unpack_mask, elements_per_sample, + type_id, x.dtype.itemsize, W_q.dtype.itemsize, + x.stride(0), x.stride(1), + W_q.stride(0), W_q.stride(1), + output.stride(0), output.stride(1), + stride_meta_a_m, stride_meta_a_g, + 0 if use_5d_scales else scales.stride(0), 0 if use_5d_scales else scales.stride(1), + load_scales_as_block = load_scales_as_block, + input_dtype = DTYPE_TO_TRITON[input_dtype], + output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], + acc_dtype = DTYPE_TO_TRITON[acc_dtype], + meta_dtype = DTYPE_TO_TRITON[meta_dtype], + channel_scale_mode = channel_scale_mode, + W_group_mode = W_group_mode, + zero_is_scalar = zeros.numel() == 1, + data_contiguous = data_contiguous, + USE_5D_SCALES = use_5d_scales, + ) return output diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 7a8470f..e478d52 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -93,7 +93,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): while (block_size_k // g) % 4 != 0: block_size_k *= 2 else: - block_size_k = min(block_size_k, g) + block_size_k = max(min(block_size_k, g), 32) #tl.dot minimum K block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) @@ -202,7 +202,7 @@ def get_fast_autotune_config_nvidia(): return configs def get_default_config_nvidia(): - return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2)] + return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2),] ######################################################################################################################################################################## #AMD - Instinct MI300X @@ -720,7 +720,6 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s ) -> Tensor: M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] # W - #M, K, N = x.shape[0], W_q.shape[1] * elements_per_sample, W_q.shape[0] #W.T #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" M_CLOSEST = get_closest_m(M) @@ -739,14 +738,6 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s gemm_splitK_kernel = gemm_splitK_MX_kernel load_scales_as_block = True use_5d_scales = (scales.ndim == 5) - - # When autotuner has only 1 config (default mode), pruning is skipped entirely. - # Adjust block sizes directly to satisfy 5D TMA descriptor requirements. - if use_5d_scales and len(gemm_splitK_MX_kernel.configs) == 1: - cfg = gemm_splitK_MX_kernel.configs[0] - cfg.kwargs['BLOCK_SIZE_N'] = max(cfg.kwargs.get('BLOCK_SIZE_N', 64), 128) - while (cfg.kwargs.get('BLOCK_SIZE_K', 64) // group_size) % 4 != 0: - cfg.kwargs['BLOCK_SIZE_K'] *= 2 else: gemm_splitK_kernel = gemm_splitK_INT_kernel load_scales_as_block = False diff --git a/gemlite/triton_kernels/utils.py b/gemlite/triton_kernels/utils.py index 65a98b1..be064e2 100755 --- a/gemlite/triton_kernels/utils.py +++ b/gemlite/triton_kernels/utils.py @@ -6,6 +6,12 @@ from triton.runtime import driver from ..dtypes import * +# TMA descriptors require a global memory allocation +from typing import Optional +def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) +triton.set_allocator(alloc_fn) + @triton.jit def swizzle_tile_v1(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr): grid_m = tl.cdiv(M, BLOCK_SIZE_M) From 67cf3174c91e84df76f8ceaa2aec3d8ab766ae83 Mon Sep 17 00:00:00 2001 From: mobicham Date: Sun, 8 Mar 2026 14:02:42 -0700 Subject: [PATCH 35/63] fix tests with autotune --- gemlite/quant_utils.py | 22 ++--- gemlite/triton_kernels/gemm_kernels.py | 79 +++++++-------- gemlite/triton_kernels/gemm_splitK_kernels.py | 96 ++++++++++--------- tests/test_gemlitelineartriton.py | 2 +- tests/test_mxfp.py | 2 +- 5 files changed, 99 insertions(+), 102 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 2eacde9..e030034 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -874,7 +874,7 @@ def scale_activations_mxfp8_triton_v2( M_padded = M + pad_m out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) BLOCK_SIZE_M = group_size @@ -982,7 +982,7 @@ def scale_activations_mxfp8_triton_v3( M_padded = M + pad_m out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index @@ -1105,7 +1105,7 @@ def scale_activations_mxfp8_triton_v4( M_padded = M + pad_m out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) @@ -1322,7 +1322,7 @@ def scale_activations_mxfp4_triton(tensor: Tensor) -> Tuple[Tensor, Tensor]: M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index @@ -1454,7 +1454,7 @@ def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + scales = torch.zeros((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index @@ -1578,7 +1578,7 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) device_index = tensor.device.index @@ -1704,7 +1704,7 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + scales = torch.zeros((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) device_index = tensor.device.index @@ -1820,7 +1820,7 @@ def scale_activations_mxfp4_triton_v3(tensor: Tensor) -> Tuple[Tensor, Tensor]: M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index @@ -1947,7 +1947,7 @@ def scale_activations_nvfp4_triton_v3(tensor: torch.Tensor) -> Tuple[torch.Tenso M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + scales = torch.zeros((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index @@ -2097,7 +2097,7 @@ def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) device_index = tensor.device.index @@ -2236,7 +2236,7 @@ def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor) -> Tuple[torch.Tenso M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + scales = torch.zeros((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) device_index = tensor.device.index diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 0d78e04..a9c76c5 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -97,15 +97,17 @@ def kernel_config_pruner(configs, nargs, **kwargs): if not IS_HIP: if e == 1 and num_stages == 1: continue - - # Prune configs that exceed shared memory - estimated_smem = estimate_shared_memory_per_block( - block_size_m, block_size_n, block_size_k, - a_sizeof, b_sizeof, num_stages, e, g, - load_scales_as_block - ) - if estimated_smem > gpu_shared_memory: - continue + + # Reduce num_stages until config fits in shared memory + while num_stages > 1: + estimated_smem = estimate_shared_memory_per_block( + block_size_m, block_size_n, block_size_k, + a_sizeof, b_sizeof, num_stages, e, g, + load_scales_as_block + ) + if estimated_smem <= gpu_shared_memory: + break + num_stages -= 1 key = (block_size_m, block_size_n, block_size_k, group_size_m, A_load_order, num_stages, num_warps) @@ -139,7 +141,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): ######################################################################################################################################################################## #Nvidia def get_max_autotune_config_nvidia(): - stages = [1, 4, 5] if gpu_has_more_shared_memory() else [1, 2, 4] + stages = [1, 3, 4, 5] configs = [] for A in [0, 2]: for w in [4, 8]: @@ -158,47 +160,26 @@ def get_max_autotune_config_nvidia(): def get_fast_autotune_config_nvidia(): configs = [] #BLOCK_SIZE_M is automatically adapted in the config pruning. - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':32, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) - - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + #Small tiles (packed INT with small group_size, small N problems) configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=5)) - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=5)) + #Medium N tiles configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - + configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=3)) + configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=3)) + #Large N tiles configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=3)) - - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=3)) + #Large M×N tiles (pruner adapts M for large batch sizes) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - - #MXFP/NVFP - configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':256, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - - #MXFP/NVFP 5D TMA optimized (num_stages=3) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=3)) - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=3)) - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=3)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=3)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=3)) + #Extra coverage + configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) + configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) return configs def get_default_config_nvidia(): @@ -791,10 +772,18 @@ def gemm_MX_kernel( scale_b_raw = tl.load_tensor_descriptor(scales_b_5d_desc, [0, pid_n * rep_n, k * rep_k, 0, 0]) scales_b = scale_b_raw.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) else: - scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + if EVEN_K: + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + else: + _scale_k_mask = ((offs_k_scales[None, :] + k_m) < (K // group_size)) + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, mask=_scale_k_mask, other=0.0, eviction_policy=meta_evict_policy) if(channel_scale_mode == 4): - scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) + if EVEN_K: + scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) + else: + _scale_a_k_mask = ((offs_k_scales[None, :] + k_m) < (K // group_size)) + scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, mask=_scale_a_k_mask, other=0.0, eviction_policy=meta_evict_policy) else: scales_a = scales_a_1s #################################################################################### diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index e478d52..e862194 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -42,7 +42,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): config['EVEN_M'] = (m % config['BLOCK_SIZE_M'] == 0) config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) - config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) + config['EVEN_K'] = (k % (config['BLOCK_SIZE_K'] * config.get('SPLIT_K', 1)) == 0) # Adjust 5D TMA compatibility for cached configs if load_scales_as_block and n % 128 == 0 and (k // g) % 4 == 0: @@ -50,7 +50,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): while (config['BLOCK_SIZE_K'] // g) % 4 != 0: config['BLOCK_SIZE_K'] *= 2 config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) - config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) + config['EVEN_K'] = (k % (config['BLOCK_SIZE_K'] * config.get('SPLIT_K', 1)) == 0) yield triton.Config(config, num_stages=num_stages, @@ -102,21 +102,23 @@ def kernel_config_pruner(configs, nargs, **kwargs): if not IS_HIP: if e == 1 and num_stages == 1: continue - - # Prune configs that exceed shared memory - estimated_smem = estimate_shared_memory_per_block( - block_size_m, block_size_n, block_size_k, - a_sizeof, b_sizeof, num_stages, e, g, - load_scales_as_block - ) - if estimated_smem > gpu_shared_memory: - continue + + # Reduce num_stages until config fits in shared memory + while num_stages > 1: + estimated_smem = estimate_shared_memory_per_block( + block_size_m, block_size_n, block_size_k, + a_sizeof, b_sizeof, num_stages, e, g, + load_scales_as_block + ) + if estimated_smem <= gpu_shared_memory: + break + num_stages -= 1 key = (block_size_m, block_size_n, block_size_k, group_size_m, split_k, A_load_order, num_stages, num_warps) even_m = (m % block_size_m == 0) even_n = (n % block_size_n == 0) - even_k = (k % block_size_k == 0) + even_k = (k % (block_size_k * split_k) == 0) new_config = { "BLOCK_SIZE_M": block_size_m, @@ -151,7 +153,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): #These autotunes are optimized for batch-size 1 to 64 (!) def get_max_autotune_config_nvidia(): - stages = [1, 2, 4, 5] if gpu_has_more_shared_memory() else [1, 2, 4] + stages = [1, 2, 3, 4, 5] configs = [] for A in [0, 2]: for w in [4, 8]: @@ -172,33 +174,26 @@ def get_max_autotune_config_nvidia(): #Faster autotuner def get_fast_autotune_config_nvidia(): configs = [] - - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':512, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) - - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'SPLIT_K':8, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=5)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':512, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'SPLIT_K':8, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=2)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':512, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':32, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=2)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=2)) + #Small N tiles + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=5)) + #Medium N tiles (N=128 — workhorse for MX/INT types) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + #Large N tiles + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=2)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':32, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + #High split_k with wide N + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'SPLIT_K':8, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) + #Extra coverage + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=5)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) return configs def get_default_config_nvidia(): @@ -590,6 +585,7 @@ def gemm_splitK_MX_kernel( offs_bk = pid_k * BLOCK_SIZE_K_B_E + tl.arange(0, BLOCK_SIZE_K_B_E) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) b_ptrs = b_ptr + offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn + b_mask = (offs_bk[:, None] < (K // elements_per_sample)) #Scales stride_mul: tl.constexpr = BLOCK_SIZE_K / group_size @@ -657,7 +653,10 @@ def gemm_splitK_MX_kernel( else: a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) + if EVEN_K: + b = tl.load(b_ptrs, eviction_policy=b_evict) + else: + b = tl.load(b_ptrs, mask=b_mask, other=0.0, eviction_policy=b_evict) #k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32) k_m = (k * SPLIT_K + pid_k) * BLOCK_SIZE_K_S #OK for BLOCK_SIZE_K >=group_size @@ -665,10 +664,18 @@ def gemm_splitK_MX_kernel( scale_b_raw = tl.load_tensor_descriptor(scales_b_5d_desc, [0, pid_n * rep_n, (k * SPLIT_K + pid_k) * rep_k, 0, 0]) scales_b = scale_b_raw.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) else: - scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + if EVEN_K: + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + else: + _scale_k_mask = ((offs_k_scales[None, :] + k_m) < (K // group_size)) + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, mask=_scale_k_mask, other=0.0, eviction_policy=meta_evict_policy) if(channel_scale_mode == 4): - scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) + if EVEN_K: + scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) + else: + _scale_a_k_mask = ((offs_k_scales[None, :] + k_m) < (K // group_size)) + scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, mask=_scale_a_k_mask, other=0.0, eviction_policy=meta_evict_policy) else: scales_a = scales_a_1s @@ -679,7 +686,8 @@ def gemm_splitK_MX_kernel( if not use_tma: if not EVEN_K: - a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K_A) < (K // elements_per_sample_a))).to(tl.int1) + b_mask = ((offs_bk[:, None] + (k + 1) * BLOCK_SIZE_K_B) < (K // elements_per_sample)) #NVFP4 meta-scale if(group_size == 16): diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index 3c1a8b6..c9a9cb5 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -19,7 +19,7 @@ def is_fp8_supported(): gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] reset_config() -set_autotune(False) +#set_autotune(False) KERNEL.ENABLE_CACHING = False in_features, out_features = 4032, 2032 diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index c55b299..0912eb6 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -16,7 +16,7 @@ def is_fp8_supported(device_index=0): compute_dtype = torch.bfloat16 #float16, bfloat16 matmul_types = ['GEMM', 'GEMM_SPLITK'] #TODO: improve GEMV mxfp accuracy. reset_config() -set_autotune(False) +#set_autotune(False) KERNEL.ENABLE_CACHING = False torch.random.manual_seed(0) From 8d87c4521e4befd9c42d0f40e5157fd14edfdb81 Mon Sep 17 00:00:00 2001 From: mobicham Date: Sun, 8 Mar 2026 14:11:32 -0700 Subject: [PATCH 36/63] enable tma for splitK --- gemlite/triton_kernels/gemm_splitK_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index e862194..57e630e 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -542,7 +542,7 @@ def gemm_splitK_MX_kernel( b_evict: tl.constexpr = 'evict_first', meta_scale_norm: tl.constexpr = (0.05 ** 2), ################################# - use_tma: tl.constexpr = False, + use_tma: tl.constexpr = True, USE_5D_SCALES: tl.constexpr = False, SCALES_5D_SHAPE1: tl.constexpr = 0, SCALES_5D_SHAPE2: tl.constexpr = 0, From f3a2f3e529aa5a0df2985ccc7c3cfbb4a0279cd9 Mon Sep 17 00:00:00 2001 From: mobicham Date: Sun, 8 Mar 2026 15:20:09 -0700 Subject: [PATCH 37/63] fix mx autotune config test --- gemlite/core.py | 4 ++-- gemlite/triton_kernels/gemm_kernels.py | 4 +++- gemlite/triton_kernels/gemm_splitK_kernels.py | 11 ++++++++--- tests/test_gemlitelineartriton.py | 2 +- tests/test_mxfp.py | 5 +++-- 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 7deaf26..23e0766 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -518,10 +518,10 @@ def pack( group_size = self.W_q.numel() // self.scales.numel() # Preshuffle weight scales to 5D TMA layout for fast loading - # Original: [K_S, N] -> 5D: [1, N//128, K_S//4, 2, 256] + # Original: [K_S, N] -> transpose to [N, K_S] -> 5D: [1, N//128, K_S//4, 2, 256] K_S = K // group_size if N % 128 == 0 and K_S % 4 == 0: - self.scales = self.scales.reshape(N // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N // 128, K_S // 4, 2, 256).contiguous() + self.scales = self.scales.T.contiguous().reshape(N // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N // 128, K_S // 4, 2, 256).contiguous() else: # Keep 2D transposed layout for pointer-based fallback self.scales = self.scales.T diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index a9c76c5..b27364b 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -183,7 +183,9 @@ def get_fast_autotune_config_nvidia(): return configs def get_default_config_nvidia(): - return [triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2),] + return [triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2), + ] ######################################################################################################################################################################## #AMD - Instinct MI300X diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 57e630e..7181309 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -197,7 +197,7 @@ def get_fast_autotune_config_nvidia(): return configs def get_default_config_nvidia(): - return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2),] + return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2), triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2),] ######################################################################################################################################################################## #AMD - Instinct MI300X @@ -370,7 +370,8 @@ def gemm_splitK_INT_kernel( offs_ak = offs_k offs_bk = offs_k - b_ptrs = b_ptr + ((offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) + b_ptrs = b_ptr + ((offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) + b_mask = (offs_bk[:, None] < K).to(tl.int1) q_shift = ((offs_bk % elements_per_sample) * W_nbits).to(tl.int32)[:, None] #Inputs @@ -399,7 +400,10 @@ def gemm_splitK_INT_kernel( else: a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) + if EVEN_K: + b = tl.load(b_ptrs, eviction_policy=b_evict) + else: + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): #Early load if EVEN_M and EVEN_K: @@ -448,6 +452,7 @@ def gemm_splitK_INT_kernel( if not EVEN_K: a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K * SPLIT_K) < K)).to(tl.int1) + b_mask = ((offs_bk[:, None] + (k + 1) * BLOCK_SIZE_K_U) < K).to(tl.int1) ############################################################################################################# #Channel-wise scaling diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index c9a9cb5..3c1a8b6 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -19,7 +19,7 @@ def is_fp8_supported(): gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] reset_config() -#set_autotune(False) +set_autotune(False) KERNEL.ENABLE_CACHING = False in_features, out_features = 4032, 2032 diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index 0912eb6..ab41e08 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -16,11 +16,12 @@ def is_fp8_supported(device_index=0): compute_dtype = torch.bfloat16 #float16, bfloat16 matmul_types = ['GEMM', 'GEMM_SPLITK'] #TODO: improve GEMV mxfp accuracy. reset_config() -#set_autotune(False) +set_autotune(False) KERNEL.ENABLE_CACHING = False torch.random.manual_seed(0) -in_features, out_features = 4032, 2048 +in_features, out_features = 4224, 2048 # test 5D TMA +#in_features, out_features = 4032, 2048 # test 2D scales fall-back batch_sizes = [1, 3, 16, 30, 32, 60, 100, 128] linear_layer = torch.nn.Linear(in_features=in_features, out_features=out_features, device=device, dtype=compute_dtype, bias=False) linear_layer.weight.data /= 10. From 1f113627ef44f4f3b2c66ed0c54d9c8ad42b3e17 Mon Sep 17 00:00:00 2001 From: mobicham Date: Sun, 8 Mar 2026 15:49:34 -0700 Subject: [PATCH 38/63] prune activation quant configs --- gemlite/quant_utils.py | 70 +----------------------------------------- 1 file changed, 1 insertion(+), 69 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index e030034..6c92e42 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -533,13 +533,10 @@ def scale_activations_per_token_triton_v3( @triton.autotune( configs=[ - triton.Config({"BLOCK_SIZE_M": 1, "BLOCK_SIZE_K": 1024}, num_warps=4, num_stages=2), triton.Config({"BLOCK_SIZE_M": 1, "BLOCK_SIZE_K": 2048}, num_warps=4, num_stages=2), triton.Config({"BLOCK_SIZE_M": 1, "BLOCK_SIZE_K": 4096}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_SIZE_M": 2, "BLOCK_SIZE_K": 1024}, num_warps=4, num_stages=2), triton.Config({"BLOCK_SIZE_M": 2, "BLOCK_SIZE_K": 2048}, num_warps=4, num_stages=2), triton.Config({"BLOCK_SIZE_M": 2, "BLOCK_SIZE_K": 4096}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_SIZE_M": 4, "BLOCK_SIZE_K": 1024}, num_warps=4, num_stages=2), triton.Config({"BLOCK_SIZE_M": 4, "BLOCK_SIZE_K": 2048}, num_warps=8, num_stages=2), triton.Config({"BLOCK_SIZE_M": 4, "BLOCK_SIZE_K": 4096}, num_warps=8, num_stages=2), ], @@ -906,9 +903,7 @@ def scale_activations_mxfp8_triton_v2( triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3), ], @@ -1006,15 +1001,9 @@ def scale_activations_mxfp8_triton_v3( @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 512}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), ], key=['M', 'K'], @@ -1224,11 +1213,9 @@ def scale_activations_nvfp4_torch(tensor: Tensor) -> Tuple[Tensor, Tensor]: triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3), ], key=['M', 'K'], prune_configs_by={'early_config_prune': prune_large_blocks}, @@ -1347,12 +1334,10 @@ def scale_activations_mxfp4_triton(tensor: Tensor) -> Tuple[Tensor, Tensor]: @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=3), triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3), ], key=['M', 'K'], @@ -1480,13 +1465,9 @@ def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, #################################################################################################################### @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), ], key=['M', 'K'], @@ -1601,13 +1582,9 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: #################################################################################################################### @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), ], key=['M', 'K'], @@ -1732,10 +1709,7 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 256}, num_warps=8, num_stages=3), ], key=['M', 'K'], prune_configs_by={'early_config_prune': prune_large_blocks}, @@ -1848,14 +1822,11 @@ def scale_activations_mxfp4_triton_v3(tensor: Tensor) -> Tuple[Tensor, Tensor]: @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 256}, num_warps=8, num_stages=3), ], key=['M', 'K'], prune_configs_by={'early_config_prune': prune_large_blocks}, @@ -1992,28 +1963,11 @@ def prune_large_blocks_2d(configs, named_args, **kwargs): @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 32}, num_warps=8, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64}, num_warps=8, num_stages=2), ], key=['M', 'K'], prune_configs_by={'early_config_prune': prune_large_blocks_2d}, @@ -2120,33 +2074,11 @@ def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 32}, num_warps=8, num_stages=2), ], key=['M', 'K'], prune_configs_by={'early_config_prune': prune_large_blocks_2d}, From bcf689f78a67d2da6c252bfed97d2d9a8e05e803 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 9 Mar 2026 03:12:12 -0700 Subject: [PATCH 39/63] update --- examples/eval_flops.py | 15 +++++++++--- gemlite/triton_kernels/gemm_kernels.py | 23 +++++++++++++++---- gemlite/triton_kernels/gemm_splitK_kernels.py | 10 +++++--- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/examples/eval_flops.py b/examples/eval_flops.py index a62e71a..ba6fb27 100644 --- a/examples/eval_flops.py +++ b/examples/eval_flops.py @@ -11,7 +11,7 @@ device, dtype = 'cuda:0', torch.bfloat16 repeat = 32 -#gemlite.reset_cache() +gemlite.reset_config() #gemlite.set_autotune("max") #gemlite.core.enable_activation_scaling(2) @@ -320,14 +320,23 @@ def run_benchmark(proc_name, M, K, N): old_cudagraph = _inductor_config.triton.cudagraph_trees _inductor_config.triton.cudagraph_trees = False + # NOTE: flashinfer's CUTLASS NVFP4 kernel requires M to be a multiple of 128. + # When M < 128, we pad M up to 128 so the kernel doesn't crash. The TFLOP/s + # are computed using the padded M to keep the comparison fair (same actual work). + M_padded = max(M, 128) + M_padded = ((M_padded + 127) // 128) * 128 + model = get_model(K, N, repeat=repeat) patch_model_flashinfer_nvfp4(model) model = torch.compile(model, mode="reduce-overhead", fullgraph=True) - perf_time_ms = eval_model(model, M, K) / repeat + perf_time_ms = eval_model(model, M_padded, K) / repeat tflops = get_flops(M, K, N, perf_time_ms) label = "flashinfer NVFP4 (dynamic)" - print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + if M_padded != M: + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s (M padded to {M_padded} internally)") + else: + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") cleanup(model) _inductor_config.triton.cudagraph_trees = old_cudagraph diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index b27364b..8b265fd 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -174,8 +174,15 @@ def get_fast_autotune_config_nvidia(): #Large M×N tiles (pruner adapts M for large batch sizes) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=3)) + # NVFP4-friendly configs: BSK=128 allows stages=3/4 within 99KB smem (NVFP4 g=16 cant fit BSK=256 stages=3) + configs.append(triton.Config({"BLOCK_SIZE_M":128, "BLOCK_SIZE_N":128, "BLOCK_SIZE_K":128, "GROUP_SIZE_M":8, "A_load_order":0}, num_warps=8, num_stages=3)) + configs.append(triton.Config({"BLOCK_SIZE_M":128, "BLOCK_SIZE_N":128, "BLOCK_SIZE_K":128, "GROUP_SIZE_M":8, "A_load_order":0}, num_warps=8, num_stages=4)) + configs.append(triton.Config({"BLOCK_SIZE_M":128, "BLOCK_SIZE_N":128, "BLOCK_SIZE_K":128, "GROUP_SIZE_M":8, "A_load_order":2}, num_warps=8, num_stages=3)) + configs.append(triton.Config({"BLOCK_SIZE_M":128, "BLOCK_SIZE_N":128, "BLOCK_SIZE_K":128, "GROUP_SIZE_M":8, "A_load_order":2}, num_warps=8, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=3)) - configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=3)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=2)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) #Extra coverage configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) @@ -750,9 +757,13 @@ def gemm_MX_kernel( if(channel_scale_mode == 4): scales_a_ptrs = scales_a_ptr + offs_am[:, None] * stride_meta_a_m + offs_k_scales[None, :] * stride_meta_a_g - # Used in channel-wise MXPF8 version - scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) - scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + # _1s dtype must match actual scale dtype: uint8 for MXFP (E8M0), float8e4nv for NVFP4 (E4M3) + if group_size == 16: + scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) + scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) + else: + scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k, num_stages=NUM_STAGES): @@ -788,6 +799,10 @@ def gemm_MX_kernel( scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, mask=_scale_a_k_mask, other=0.0, eviction_policy=meta_evict_policy) else: scales_a = scales_a_1s + + #scales_b = scales_b_1s + #scales_a = scales_a_1s + #################################################################################### acc = tl.dot_scaled(a, scales_a, a_dtype, b, scales_b, b_dtype, acc) diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 7181309..f689458 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -643,9 +643,13 @@ def gemm_splitK_MX_kernel( [1, rep_n, rep_k, 2, 256] ) - # Used in channel-wise MXPF8 version - scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) - scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + + if group_size == 16: + scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) + scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) + else: + scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k): From 8aaff4ba40073bd25cbed1591897d683a2ce8911 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 9 Mar 2026 04:18:17 -0700 Subject: [PATCH 40/63] fix mxfp8 activation quant spill over --- gemlite/quant_utils.py | 56 +++++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 6c92e42..8f64277 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -871,7 +871,7 @@ def scale_activations_mxfp8_triton_v2( M_padded = M + pad_m out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) - scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) BLOCK_SIZE_M = group_size @@ -977,7 +977,7 @@ def scale_activations_mxfp8_triton_v3( M_padded = M + pad_m out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) - scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index @@ -1012,7 +1012,7 @@ def scale_activations_mxfp8_triton_v3( @triton.jit def scale_activations_mxfp8_triton_kernel_v4( tensor_ptr, out_ptr, scales_ptr, - M, K, + M, M_padded, K, stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, @@ -1024,7 +1024,7 @@ def scale_activations_mxfp8_triton_kernel_v4( ): pid = tl.program_id(0) num_programs = tl.num_programs(0) - num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) + num_m_tiles = tl.cdiv(M_padded, BLOCK_SIZE_M) GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK @@ -1067,12 +1067,14 @@ def scale_activations_mxfp8_triton_kernel_v4( # Store scales: [FLAT_M] → [BLOCK_M, GROUPS_PER_BLOCK] scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + # For padding rows (M <= row < M_padded), store identity scale (127 = 2^0 in E8M0) + scales_2d = tl.where(m_mask[:, None], scales_2d, tl.full(scales_2d.shape, 127, dtype=tl.uint8)) group_idx = k_start // GROUP_SIZE offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK) g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) tl.store( scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, - scales_2d, mask=m_mask[:, None] & g_mask[None, :] + scales_2d, mask=(offs_m[:, None] < M_padded) & g_mask[None, :] ) tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K)) @@ -1094,13 +1096,13 @@ def scale_activations_mxfp8_triton_v4( M_padded = M + pad_m out = torch.empty((M, K), device=tensor.device, dtype=w_dtype) - scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) scale_activations_mxfp8_triton_kernel_v4[grid]( tensor, out, scales, - M, K, + M, M_padded, K, tensor.stride(0), tensor.stride(1), out.stride(0), out.stride(1), scales.stride(0), scales.stride(1), @@ -1309,7 +1311,7 @@ def scale_activations_mxfp4_triton(tensor: Tensor) -> Tuple[Tensor, Tensor]: M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index @@ -1439,7 +1441,7 @@ def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.zeros((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index @@ -1487,7 +1489,7 @@ def scale_activations_mxfp4_triton_kernel_v2( ): pid = tl.program_id(0) num_programs = tl.num_programs(0) - num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) + num_m_tiles = tl.cdiv(M_padded, BLOCK_SIZE_M) GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 @@ -1559,7 +1561,7 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) device_index = tensor.device.index @@ -1605,7 +1607,7 @@ def scale_activations_nvfp4_triton_kernel_v2( ): pid = tl.program_id(0) num_programs = tl.num_programs(0) - num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) + num_m_tiles = tl.cdiv(M_padded, BLOCK_SIZE_M) GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 @@ -1681,7 +1683,7 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.zeros((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) device_index = tensor.device.index @@ -1794,7 +1796,7 @@ def scale_activations_mxfp4_triton_v3(tensor: Tensor) -> Tuple[Tensor, Tensor]: M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index @@ -1918,7 +1920,7 @@ def scale_activations_nvfp4_triton_v3(tensor: torch.Tensor) -> Tuple[torch.Tenso M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.zeros((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index @@ -1975,7 +1977,7 @@ def prune_large_blocks_2d(configs, named_args, **kwargs): @triton.jit def scale_activations_mxfp4_triton_kernel_v5( tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, - M, K, + M, M_padded, K, stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, @@ -2030,12 +2032,14 @@ def scale_activations_mxfp4_triton_kernel_v5( tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask) scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + # For padding rows (M <= row < M_padded), store identity scale (127 = 2^0 in E8M0) + scales_2d = tl.where(offs_m[:, None] < M, scales_2d, tl.full(scales_2d.shape, 127, dtype=tl.uint8)) base_group = pid_k * GROUPS_PER_BLOCK offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK) g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) tl.store( scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, - scales_2d, mask=(offs_m[:, None] < M) & g_mask[None, :] + scales_2d, mask=(offs_m[:, None] < M_padded) & g_mask[None, :] ) @@ -2051,14 +2055,14 @@ def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.full((M_padded, K // group_size), fill_value=127, device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) - grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) + grid = lambda meta: (triton.cdiv(M_padded, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) device_index = tensor.device.index scale_activations_mxfp4_triton_kernel_v5[grid]( tensor, out, scales, thr_pos[device_index], - M, K, + M, M_padded, K, tensor.stride(0), tensor.stride(1), scales.stride(0), scales.stride(1), out.stride(0), out.stride(1), @@ -2086,7 +2090,7 @@ def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: @triton.jit def scale_activations_nvfp4_triton_kernel_v5( tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, - M, K, + M, M_padded, K, stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, @@ -2146,12 +2150,14 @@ def scale_activations_nvfp4_triton_kernel_v5( tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask) scales_2d = tl.reshape(scales_fp8, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + # For padding rows (M <= row < M_padded), store identity scale (1.0 in fp8) + scales_2d = tl.where(offs_m[:, None] < M, scales_2d, tl.full(scales_2d.shape, 1.0, dtype=tl.float32).to(fp8_dtype)) base_group = pid_k * GROUPS_PER_BLOCK offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK) g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) tl.store( scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, - scales_2d, mask=(offs_m[:, None] < M) & g_mask[None, :] + scales_2d, mask=(offs_m[:, None] < M_padded) & g_mask[None, :] ) @@ -2168,14 +2174,14 @@ def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor) -> Tuple[torch.Tenso M_padded = M + pad_m out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.zeros((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) - grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) + grid = lambda meta: (triton.cdiv(M_padded, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) device_index = tensor.device.index scale_activations_nvfp4_triton_kernel_v5[grid]( tensor, out, scales, thr_pos[device_index], - M, K, + M, M_padded, K, tensor.stride(0), tensor.stride(1), scales.stride(0), scales.stride(1), out.stride(0), out.stride(1), From a0d98ac1ff619eac5e63429e7d58c4e079dd5346 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 9 Mar 2026 09:14:34 -0700 Subject: [PATCH 41/63] set tma flag --- gemlite/__init__.py | 1 + gemlite/core.py | 10 ++++++++-- gemlite/triton_kernels/gemm_kernels.py | 3 +++ gemlite/triton_kernels/gemm_splitK_kernels.py | 3 +++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/gemlite/__init__.py b/gemlite/__init__.py index ca6d3b5..ce5f154 100755 --- a/gemlite/__init__.py +++ b/gemlite/__init__.py @@ -12,6 +12,7 @@ set_acc_dtype, set_autotune, set_kernel_caching, + enable_tma, forward_functional, ) diff --git a/gemlite/core.py b/gemlite/core.py index 23e0766..8e699b1 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -66,6 +66,7 @@ GEMLITE_MATMUL_TYPES_MAPPING = {GEMLITE_MATMUL_TYPES[i]: i for i in range(len(GEMLITE_MATMUL_TYPES))} GEMLITE_TRITON_CONFIG_CACHE = {} #Global config cache for all the kernels _GROUP_SIZE_WARNED = False +GEMLITE_USE_TMA = True ################################################################################### #Utils @@ -96,6 +97,11 @@ def set_acc_dtype(dtype): assert dtype in [DType.FP16, DType.FP32], "Invalid dtype (should be DType.FP16 or DType.FP32)." GEMLITE_ACC_DTYPE[DType.FP16] = dtype +#Enable/disable TMA for MX kernel data loading +def enable_tma(enabled: bool = True): + global GEMLITE_USE_TMA + GEMLITE_USE_TMA = enabled + #Return the default gemv kernel to use for M==1 def get_default_gemv(W_nbits: int, mx_dtype: bool = False) -> str: #TODO: adapt mx for IS_HIP = True @@ -341,7 +347,7 @@ def load_state_dict(self, state_dict, strict=True, assign=False): if s.ndim == 2: s_2d = s.T.contiguous() # [K_S, N] contiguous N_dim, K_S = s_2d.shape[1], s_2d.shape[0] - if N_dim % 128 == 0 and K_S % 4 == 0: + if GEMLITE_USE_TMA and N_dim % 128 == 0 and K_S % 4 == 0: self.scales = s_2d.reshape(N_dim // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N_dim // 128, K_S // 4, 2, 256).contiguous() #Make sure to feed UINT8 W_q for packing @@ -520,7 +526,7 @@ def pack( # Preshuffle weight scales to 5D TMA layout for fast loading # Original: [K_S, N] -> transpose to [N, K_S] -> 5D: [1, N//128, K_S//4, 2, 256] K_S = K // group_size - if N % 128 == 0 and K_S % 4 == 0: + if GEMLITE_USE_TMA and N % 128 == 0 and K_S % 4 == 0: self.scales = self.scales.T.contiguous().reshape(N // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N // 128, K_S // 4, 2, 256).contiguous() else: # Keep 2D transposed layout for pointer-based fallback diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 8b265fd..cce98fb 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -314,6 +314,7 @@ def gemm_INT_kernel( a_evict: tl.constexpr = "", b_evict: tl.constexpr = "evict_first", USE_5D_SCALES: tl.constexpr = False, + use_tma: tl.constexpr = True, ): """ Based on https://github.com/fpgaminer/GPTQ-triton @@ -848,6 +849,7 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x global PRINTED + from ..core import GEMLITE_USE_TMA M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] # W M_CLOSEST = get_closest_m(M) @@ -891,6 +893,7 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, USE_5D_SCALES = use_5d_scales, + use_tma = GEMLITE_USE_TMA, ) return output diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index f689458..e5c3287 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -328,6 +328,7 @@ def gemm_splitK_INT_kernel( USE_5D_SCALES: tl.constexpr = False, SCALES_5D_SHAPE1: tl.constexpr = 0, SCALES_5D_SHAPE2: tl.constexpr = 0, + use_tma: tl.constexpr = True, ): """ Based on https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py @@ -736,6 +737,7 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, ) -> Tensor: + from ..core import GEMLITE_USE_TMA M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] # W #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" @@ -787,6 +789,7 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s USE_5D_SCALES = use_5d_scales, SCALES_5D_SHAPE1 = N // 128 if use_5d_scales else 0, SCALES_5D_SHAPE2 = K // group_size // 4 if use_5d_scales else 0, + use_tma = GEMLITE_USE_TMA, ) if(not native_atomic): From ff7114f2d7c7961d496a366cebdac967050df6b0 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 9 Mar 2026 13:43:47 -0700 Subject: [PATCH 42/63] gemv fixes and tests --- gemlite/core.py | 2 +- gemlite/helper.py | 2 +- gemlite/quant_utils.py | 12 ++++--- gemlite/triton_kernels/gemv_kernels.py | 36 ++++++++++++------- .../triton_kernels/gemv_revsplitK_kernels.py | 31 +++++++++++----- gemlite/triton_kernels/gemv_splitK_kernels.py | 9 +++-- tests/test_gemlitelineartriton.py | 12 +++++-- tests/test_mxfp.py | 10 +++++- 8 files changed, 82 insertions(+), 32 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 8e699b1..5216d31 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -66,7 +66,7 @@ GEMLITE_MATMUL_TYPES_MAPPING = {GEMLITE_MATMUL_TYPES[i]: i for i in range(len(GEMLITE_MATMUL_TYPES))} GEMLITE_TRITON_CONFIG_CACHE = {} #Global config cache for all the kernels _GROUP_SIZE_WARNED = False -GEMLITE_USE_TMA = True +GEMLITE_USE_TMA = True # Set to False for faster MXFP8 on sm_120 ################################################################################### #Utils diff --git a/gemlite/helper.py b/gemlite/helper.py index 35a4054..41f40e0 100755 --- a/gemlite/helper.py +++ b/gemlite/helper.py @@ -836,7 +836,7 @@ def from_weights(self, weight, bias=None, scales=None): assert weight.dtype in [torch.uint8], f"Invalid weight.dtype, should be an MXPF8 valid dtype, got {weight.dtype}." assert scales.dtype in [torch.float8_e8m0fnu, torch.uint8], f"Invalid scales.dtype, should be e8m0 / view(uint8), got {scales.dtype}." assert self.dtype is not None, f"Input dtype should be either torch.float16 or torch.bfloat16, not None." - assert self.group_size == 32, f"Only group_size=16 is supported for MXFP4, got {self.group_size}" + assert self.group_size == 32, f"Only group_size=32 is supported for MXFP4, got {self.group_size}" dtype = self.dtype gemlite_dtype = TORCH_TO_DTYPE[dtype] diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 8f64277..3a7e484 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -214,13 +214,15 @@ def quantize_nvfp4( return W_q, scales - def dequantize(self, W_q, scales, shape = None, dtype = None): + def dequantize(self, W_q, scales, shape = None, dtype = None, meta_scales = None): if(W_q.dtype == torch.uint8): #from indices device_index = W_q.device.index W_q = fp4_values[device_index][W_q.int()] group_size = W_q.numel() // scales.numel() out = (W_q.view([-1, group_size]).float() * scales.float()) + if meta_scales is not None: + out = out * meta_scales if(shape is not None): out = out.view(shape) return out.to(self.compute_dtype if dtype is None else dtype) @@ -1478,7 +1480,7 @@ def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, @triton.jit def scale_activations_mxfp4_triton_kernel_v2( tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, - M, K, + M, M_padded, K, stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, @@ -1568,7 +1570,7 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: scale_activations_mxfp4_triton_kernel_v2[grid]( tensor, out, scales, thr_pos[device_index], - M, K, + M, M_padded, K, tensor.stride(0), tensor.stride(1), scales.stride(0), scales.stride(1), out.stride(0), out.stride(1), @@ -1595,7 +1597,7 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: @triton.jit def scale_activations_nvfp4_triton_kernel_v2( tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, - M, K, + M, M_padded, K, stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, @@ -1690,7 +1692,7 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso scale_activations_nvfp4_triton_kernel_v2[grid]( tensor, out, scales, thr_pos[device_index], - M, K, + M, M_padded, K, tensor.stride(0), tensor.stride(1), scales.stride(0), scales.stride(1), out.stride(0), out.stride(1), diff --git a/gemlite/triton_kernels/gemv_kernels.py b/gemlite/triton_kernels/gemv_kernels.py index 538b4aa..a30a7a1 100755 --- a/gemlite/triton_kernels/gemv_kernels.py +++ b/gemlite/triton_kernels/gemv_kernels.py @@ -51,6 +51,8 @@ def kernel_config_pruner(configs, nargs, **kwargs): config.pop('reg_inc_consumer', None) config['NUM_STAGES'] = num_stages + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + yield triton.Config(config, num_stages=num_stages, num_warps=num_warps, pre_hook=pre_hook) return @@ -79,6 +81,8 @@ def kernel_config_pruner(configs, nargs, **kwargs): num_stages = config.num_stages num_warps = config.num_warps + even_n = (n % block_size_n == 0) + key = (block_size_m, block_size_n, block_size_k, A_load_order, dot_prod_mode, num_stages, num_warps) new_config = { @@ -88,6 +92,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): 'A_load_order': A_load_order, 'dot_prod_mode': dot_prod_mode, 'NUM_STAGES': num_stages, + 'EVEN_N': even_n, } if IS_HIP: @@ -278,6 +283,7 @@ def gemv_INT_kernel( join_version: tl.constexpr = False, ################################# load_scales_as_block: tl.constexpr = False, + EVEN_N: tl.constexpr = False, ): """ GEMV for C = matmul(A, dequantize(B, scales, zeros)). This is optimized for M==1 @@ -322,14 +328,12 @@ def gemv_INT_kernel( ################################################################### #Load if(A_load_order == 0): - a = tl.load(a_ptrs, eviction_policy=a_evict) - #a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) - #b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): - a = tl.load(a_ptrs, eviction_policy=a_evict) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) if(W_group_mode > 0): k_m = (pid_k * (BLOCK_SIZE_K / group_size)).to(tl.int32) @@ -348,7 +352,7 @@ def gemv_INT_kernel( zeros = None if(A_load_order == 2): - a = tl.load(a_ptrs, eviction_policy=a_evict) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) #tl.join() version if(join_version): @@ -365,7 +369,7 @@ def gemv_INT_kernel( b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) if(A_load_order == 3): - a = tl.load(a_ptrs, eviction_policy=a_evict) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) if(dump_b_val > 0): b = b.to(tl.float32) * dump_b_val @@ -398,7 +402,11 @@ def gemv_INT_kernel( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + if EVEN_N: + tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + else: + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) @triton.autotune( @@ -459,6 +467,7 @@ def gemv_MX_kernel( join_version: tl.constexpr = False, ################################# load_scales_as_block: tl.constexpr = False, + EVEN_N: tl.constexpr = False, ): """ GEMV for C = matmul(A, dequantize(B, scales, zeros)). This is optimized for M==1 @@ -507,8 +516,7 @@ def gemv_MX_kernel( if(A_load_order == 0): a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) - #b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) @@ -572,7 +580,11 @@ def gemv_MX_kernel( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + if EVEN_N: + tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + else: + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) #TODO: gemv not generating correct reuslts with mxfp dtypes use except for A16W4. def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, @@ -654,7 +666,7 @@ def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x W_group_mode = W_group_mode, zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, - dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype == DType.FP16.value and W_nbits == 8) else 0, #Warning: Only use with INT8 + dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype == tl.float16 and W_nbits == 8) else 0, #Warning: Only use with INT8 ) if(not native_atomic): diff --git a/gemlite/triton_kernels/gemv_revsplitK_kernels.py b/gemlite/triton_kernels/gemv_revsplitK_kernels.py index a667ff9..82cc989 100755 --- a/gemlite/triton_kernels/gemv_revsplitK_kernels.py +++ b/gemlite/triton_kernels/gemv_revsplitK_kernels.py @@ -39,6 +39,8 @@ def kernel_config_pruner(configs, nargs, **kwargs): config.pop('reg_dec_producer', None) config.pop('reg_inc_consumer', None) + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + yield triton.Config(config, num_stages=num_stages, num_warps=num_warps, pre_hook=pre_hook) return @@ -76,6 +78,9 @@ def kernel_config_pruner(configs, nargs, **kwargs): if(block_size_k < e): continue + + even_n = (n % block_size_n == 0) + key = (block_size_m, block_size_n, block_size_k, A_load_order, dot_prod_mode, num_stages, num_warps) new_config = { @@ -84,6 +89,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): 'BLOCK_SIZE_K': block_size_k, 'A_load_order': A_load_order, 'dot_prod_mode': dot_prod_mode, + 'EVEN_N': even_n, } if IS_HIP: @@ -272,6 +278,7 @@ def gemv_INT_revsplitK_kernel( atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', b_evict: tl.constexpr = 'evict_first', + EVEN_N: tl.constexpr = False, ): """ GEMV for C = matmul(A, dequantize(B, scales, zeros)). This is optimized for M==1 @@ -303,6 +310,8 @@ def gemv_INT_revsplitK_kernel( a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) q_shift = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None] + a_mask = ((offs_am[:, None] < M) & (offs_ak[None, :] < K)).to(tl.int1) + b_mask = ((offs_bk[:, None] < K) & (offs_bn[None, :] < N)).to(tl.int1) #Stage 0: Load scales/zeros #----------------------------------------------------------------------------------------------------------- @@ -328,12 +337,12 @@ def gemv_INT_revsplitK_kernel( #----------------------------------------------------------------------------------------------------------- #Load if(A_load_order == 0): - a = tl.load(a_ptrs, eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) - b = tl.load(b_ptrs, eviction_policy=b_evict) + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): - a = tl.load(a_ptrs, eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) # Unpack and dequantize b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) @@ -350,16 +359,18 @@ def gemv_INT_revsplitK_kernel( #Advance and load next chunk a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += (BLOCK_SIZE_K // elements_per_sample) * stride_bk + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + BLOCK_SIZE_K) < K)).to(tl.int1) + b_mask = (((offs_bk[:, None] + BLOCK_SIZE_K) < K) & (offs_bn[None, :] < N)).to(tl.int1) #Stage 2 #----------------------------------------------------------------------------------------------------------- if(A_load_order == 0): - a = tl.load(a_ptrs, eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) - b = tl.load(b_ptrs, eviction_policy=b_evict) + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): - a = tl.load(a_ptrs, eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) # Unpack and dequantize b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) @@ -395,7 +406,11 @@ def gemv_INT_revsplitK_kernel( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + if EVEN_N: + tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + else: + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) KERNEL_CACHE = {} @@ -464,7 +479,7 @@ def gemv_revsplitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor W_group_mode = W_group_mode, zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, - dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype in [DType.FP16.value] and W_nbits == 8) else 0, #Warning: Only use with INT8 + dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype == tl.float16 and W_nbits == 8) else 0, #Warning: Only use with INT8 ) if(not native_atomic): diff --git a/gemlite/triton_kernels/gemv_splitK_kernels.py b/gemlite/triton_kernels/gemv_splitK_kernels.py index ff89b3d..74e9ab1 100755 --- a/gemlite/triton_kernels/gemv_splitK_kernels.py +++ b/gemlite/triton_kernels/gemv_splitK_kernels.py @@ -340,6 +340,7 @@ def gemv_INT_splitK_kernel( a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak) b_ptrs = b_ptr + ((offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) a_mask = ((offs_am[:, None] < M) & (offs_ak[None, :] < K)).to(tl.int1) + b_mask = ((offs_bk[:, None] < K) & (offs_bn[None, :] < N)).to(tl.int1) #Meta data stuff q_shift = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None] @@ -369,7 +370,10 @@ def gemv_INT_splitK_kernel( else: a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) + if EVEN_K and EVEN_N: + b = tl.load(b_ptrs, eviction_policy=b_evict) + else: + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): #Early load if EVEN_M and EVEN_K: @@ -422,6 +426,7 @@ def gemv_INT_splitK_kernel( #Update mask if not EVEN_K: a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K_U) < K)).to(tl.int1) + b_mask = ((offs_bk[:, None] + (k + 1) * BLOCK_SIZE_K_U < K) & (offs_bn[None, :] < N)).to(tl.int1) if(dot_prod_mode == 0): acc = tl.sum(acc, axis=0, keep_dims=True) @@ -514,7 +519,7 @@ def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s W_group_mode = W_group_mode, zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, - dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype == DType.FP16.value and W_nbits == 8) else 0, #Warning: Only use with INT8 + dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype == tl.float16 and W_nbits == 8) else 0, #Warning: Only use with INT8 ) if(not native_atomic): diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index 3c1a8b6..28d9bf6 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -1,4 +1,9 @@ #python -m unittest test_gemlitelineartriton.py +# Usage: python3 test_file.py [--autotune] +import sys +_autotune = '--autotune' in sys.argv +if _autotune: sys.argv.remove('--autotune') + import unittest import torch @@ -18,8 +23,9 @@ def is_fp8_supported(): fp8_dtype = torch.float8_e4m3fn #float8_e4m3fn / torch.float8_e5m2 (Nvidia) gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] + reset_config() -set_autotune(False) +if _autotune is False: set_autotune(False) KERNEL.ENABLE_CACHING = False in_features, out_features = 4032, 2032 @@ -357,4 +363,6 @@ def input_fn(batch_size): def ref_fn(x): return torch.matmul(x.to(compute_dtype), W.T) - self.eval(gemlite_linear, ref_fn, tol=5e-3, input_fn=input_fn) #needs higher tolerance with fp8 \ No newline at end of file + self.eval(gemlite_linear, ref_fn, tol=5e-3, input_fn=input_fn) #needs higher tolerance with fp8 +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index ab41e08..b44a1ed 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -1,4 +1,8 @@ #python -m unittest test_mxfp.py +# Usage: python3 test_file.py [--autotune] +import sys +_autotune = '--autotune' in sys.argv +if _autotune: sys.argv.remove('--autotune') import unittest import torch @@ -15,8 +19,9 @@ def is_fp8_supported(device_index=0): device = 'cuda:0' compute_dtype = torch.bfloat16 #float16, bfloat16 matmul_types = ['GEMM', 'GEMM_SPLITK'] #TODO: improve GEMV mxfp accuracy. + reset_config() -set_autotune(False) +if _autotune is False: set_autotune(False) KERNEL.ENABLE_CACHING = False torch.random.manual_seed(0) @@ -93,3 +98,6 @@ def test_A4W4_NVFP_dynamic(self): self.eval(gemlite_linear, tol = 2e-3) + +if __name__ == '__main__': + unittest.main() From 525e655b2f086ef7883219113cc6291001b0bf29 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 10 Mar 2026 04:03:59 -0700 Subject: [PATCH 43/63] update shared memory estimate --- gemlite/triton_kernels/utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/gemlite/triton_kernels/utils.py b/gemlite/triton_kernels/utils.py index be064e2..05fc97b 100755 --- a/gemlite/triton_kernels/utils.py +++ b/gemlite/triton_kernels/utils.py @@ -131,10 +131,17 @@ def estimate_shared_memory_per_block(block_size_m, block_size_n, block_size_k, a a_smem = block_size_m * block_size_k * a_sizeof if load_scales_as_block: # MX kernels: dot_scaled handles scaling natively, no dequant buffer + # A tile: packed elements (e.g. NVFP4 packs 2 per byte, so K_A = K // e) + a_smem = block_size_m * (block_size_k // e) * a_sizeof b_smem = (block_size_k // e) * block_size_n * b_sizeof - # scales: (BLOCK_N, BLOCK_K // group_size) × meta_sizeof - s_smem = block_size_n * (block_size_k // g) * 1 # uint8 or e4m3 = 1 byte - estimated_smem = (a_smem + b_smem + s_smem) * max(num_stages - 1, 1) + # scales_b: (BLOCK_N, BLOCK_K // group_size), scales_a: (BLOCK_M, BLOCK_K // group_size) + sb_smem = block_size_n * (block_size_k // g) * 1 + sa_smem = block_size_m * (block_size_k // g) * 1 + # 1.25x margin accounts for Triton alignment padding, barriers, and TMA descriptors + loop_smem = int((a_smem + b_smem + sb_smem + sa_smem) * max(num_stages - 1, 1) * 1.25) + # Triton overlaps output buffer with loop data (reuses same SMEM) + output_smem = block_size_m * block_size_n * 2 # bf16 output via TMA store + estimated_smem = max(loop_smem, output_smem) elif e > 1: # INT packed: need packed B + dequantized B for MMA b_smem = (block_size_k // e) * block_size_n * b_sizeof From 3059416a24d353dde11609afc5119fb85e99f6c1 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 10 Mar 2026 11:10:26 -0700 Subject: [PATCH 44/63] fix bugs --- examples/eval_flops.py | 1 + gemlite/__init__.py | 1 + gemlite/core.py | 7 ++++++- gemlite/triton_kernels/gemm_kernels.py | 2 +- gemlite/triton_kernels/gemm_splitK_kernels.py | 13 +++++++++++-- .../gemm_splitK_persistent_kernels.py | 2 +- gemlite/triton_kernels/gemv_kernels.py | 2 +- gemlite/triton_kernels/gemv_revsplitK_kernels.py | 2 +- gemlite/triton_kernels/gemv_splitK_kernels.py | 2 +- gemlite/triton_kernels/utils.py | 3 +-- 10 files changed, 25 insertions(+), 10 deletions(-) diff --git a/examples/eval_flops.py b/examples/eval_flops.py index ba6fb27..4351bc4 100644 --- a/examples/eval_flops.py +++ b/examples/eval_flops.py @@ -12,6 +12,7 @@ repeat = 32 gemlite.reset_config() +#gemlite.enable_cudagraph_autotune(True) #gemlite.set_autotune("max") #gemlite.core.enable_activation_scaling(2) diff --git a/gemlite/__init__.py b/gemlite/__init__.py index ce5f154..0e5e38b 100755 --- a/gemlite/__init__.py +++ b/gemlite/__init__.py @@ -13,6 +13,7 @@ set_autotune, set_kernel_caching, enable_tma, + enable_cudagraph_autotune, forward_functional, ) diff --git a/gemlite/core.py b/gemlite/core.py index 5216d31..a889331 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -102,6 +102,10 @@ def enable_tma(enabled: bool = True): global GEMLITE_USE_TMA GEMLITE_USE_TMA = enabled +#Enable/disable CUDA graph-based autotuning (more accurate but slower) +def enable_cudagraph_autotune(enabled: bool = True): + set_autotune("fast", use_cuda_graph=enabled) + #Return the default gemv kernel to use for M==1 def get_default_gemv(W_nbits: int, mx_dtype: bool = False) -> str: #TODO: adapt mx for IS_HIP = True @@ -112,7 +116,8 @@ def get_default_gemv(W_nbits: int, mx_dtype: bool = False) -> str: #matmul type selection logic def get_matmul_type(batch_size: int, W_nbits: int, mx_dtype: bool = False): - if batch_size > 64: + gemm_limit = 64 + if batch_size >= gemm_limit: return "GEMM" gemv_limit = 4 if (W_nbits < 8 and not mx_dtype) else 2 # previous 1 if batch_size > gemv_limit: diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index cce98fb..242956e 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -32,7 +32,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index e5c3287..692ba08 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -32,7 +32,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) @@ -181,7 +181,6 @@ def get_fast_autotune_config_nvidia(): #Medium N tiles (N=128 — workhorse for MX/INT types) configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) #Large N tiles @@ -194,6 +193,16 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=5)) configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) + #Additional M=16 configs for MX kernel coverage + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=3)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=3)) + #M=32 tiles (for M=32..64 batch sizes) + configs.append(triton.Config({'BLOCK_SIZE_M':32, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':32, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':32, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) + configs.append(triton.Config({'BLOCK_SIZE_M':32, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':32, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=3)) return configs def get_default_config_nvidia(): diff --git a/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py b/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py index 2e9fb54..9808b1c 100755 --- a/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py @@ -32,7 +32,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) diff --git a/gemlite/triton_kernels/gemv_kernels.py b/gemlite/triton_kernels/gemv_kernels.py index a30a7a1..7468144 100755 --- a/gemlite/triton_kernels/gemv_kernels.py +++ b/gemlite/triton_kernels/gemv_kernels.py @@ -43,7 +43,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) diff --git a/gemlite/triton_kernels/gemv_revsplitK_kernels.py b/gemlite/triton_kernels/gemv_revsplitK_kernels.py index 82cc989..cfa52c8 100755 --- a/gemlite/triton_kernels/gemv_revsplitK_kernels.py +++ b/gemlite/triton_kernels/gemv_revsplitK_kernels.py @@ -32,7 +32,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) diff --git a/gemlite/triton_kernels/gemv_splitK_kernels.py b/gemlite/triton_kernels/gemv_splitK_kernels.py index 74e9ab1..362fcf3 100755 --- a/gemlite/triton_kernels/gemv_splitK_kernels.py +++ b/gemlite/triton_kernels/gemv_splitK_kernels.py @@ -30,7 +30,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) diff --git a/gemlite/triton_kernels/utils.py b/gemlite/triton_kernels/utils.py index 05fc97b..c4c84be 100755 --- a/gemlite/triton_kernels/utils.py +++ b/gemlite/triton_kernels/utils.py @@ -137,8 +137,7 @@ def estimate_shared_memory_per_block(block_size_m, block_size_n, block_size_k, a # scales_b: (BLOCK_N, BLOCK_K // group_size), scales_a: (BLOCK_M, BLOCK_K // group_size) sb_smem = block_size_n * (block_size_k // g) * 1 sa_smem = block_size_m * (block_size_k // g) * 1 - # 1.25x margin accounts for Triton alignment padding, barriers, and TMA descriptors - loop_smem = int((a_smem + b_smem + sb_smem + sa_smem) * max(num_stages - 1, 1) * 1.25) + loop_smem = (a_smem + b_smem + sb_smem + sa_smem) * max(num_stages - 1, 1) # Triton overlaps output buffer with loop data (reuses same SMEM) output_smem = block_size_m * block_size_n * 2 # bf16 output via TMA store estimated_smem = max(loop_smem, output_smem) From 50ff07c525829505e1b567430c507c4996d824bf Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 10 Mar 2026 11:11:56 -0700 Subject: [PATCH 45/63] clean-up --- examples/bench_5d_tma.py | 267 ---------------------- examples/bench_act_quant.py | 82 ------- examples/bench_act_quant_final.py | 62 ------ examples/bench_act_quant_v4.py | 343 ----------------------------- examples/bench_act_quant_v5.py | 353 ------------------------------ 5 files changed, 1107 deletions(-) delete mode 100644 examples/bench_5d_tma.py delete mode 100644 examples/bench_act_quant.py delete mode 100644 examples/bench_act_quant_final.py delete mode 100644 examples/bench_act_quant_v4.py delete mode 100644 examples/bench_act_quant_v5.py diff --git a/examples/bench_5d_tma.py b/examples/bench_5d_tma.py deleted file mode 100644 index 8271b92..0000000 --- a/examples/bench_5d_tma.py +++ /dev/null @@ -1,267 +0,0 @@ -""" -Standalone benchmark: compare pointer-based vs 5D TMA scale loading in block-scaled GEMM. -Tests the NVFP4 case (group_size=16, e4m3 scales). -""" -import torch -import triton -import triton.language as tl -from triton.tools.tensor_descriptor import TensorDescriptor - -device = "cuda:0" -dtype = torch.bfloat16 - -# Required for TMA tensor descriptors -from typing import Optional -def alloc_fn(size: int, alignment: int, stream: Optional[int]): - return torch.empty(size, device="cuda", dtype=torch.int8) -triton.set_allocator(alloc_fn) - - -def preshuffle_scales(scales_2d, N, K_S): - """Convert [N, K_S] scales to 5D preshuffled layout for TMA. - - Follows the Triton tutorial layout: [1, N//128, K_S//4, 2, 256] - Preserves dtype (fp8_e4m3fn for NVFP4, uint8 for MXFP4). - """ - return ( - scales_2d - .reshape(N // 128, 4, 32, K_S // 4, 4) - .permute(0, 3, 2, 1, 4) - .reshape(1, N // 128, K_S // 4, 2, 256) - .contiguous() - ) - - -# Kernel with pointer-based scale loading (current gemlite approach) -@triton.jit -def gemm_fp4_pointer_scales( - a_ptr, b_ptr, c_ptr, - scales_b_ptr, scales_a_ptr, - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - group_size: tl.constexpr, - stride_am: tl.constexpr, stride_ak: tl.constexpr, - stride_bn: tl.constexpr, stride_bk: tl.constexpr, - stride_cm: tl.constexpr, stride_cn: tl.constexpr, - stride_sb_n: tl.constexpr, stride_sb_g: tl.constexpr, - stride_sa_m: tl.constexpr, stride_sa_g: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NUM_STAGES: tl.constexpr, - meta_scale_norm: tl.constexpr = 0.0025, -): - pid = tl.program_id(0) - num_pid_n = tl.cdiv(N, BLOCK_N) - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - BLOCK_K_A: tl.constexpr = BLOCK_K // 2 # packed FP4 - BLOCK_K_B: tl.constexpr = BLOCK_K // 2 - BLOCK_K_S: tl.constexpr = BLOCK_K // group_size - - # TMA for data - a_desc = tl.make_tensor_descriptor(a_ptr, [M, K // 2], [stride_am, stride_ak], [BLOCK_M, BLOCK_K_A]) - b_desc = tl.make_tensor_descriptor(b_ptr, [N, K // 2], [stride_bn, stride_bk], [BLOCK_N, BLOCK_K_B]) - c_desc = tl.make_tensor_descriptor(c_ptr, [M, N], [stride_cm, stride_cn], [BLOCK_M, BLOCK_N]) - - # Pointer-based scales - offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_k_s = tl.arange(0, BLOCK_K_S) - scales_b_ptrs = scales_b_ptr + offs_n[:, None] * stride_sb_n + offs_k_s[None, :] * stride_sb_g - - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - scales_a_ptrs = scales_a_ptr + offs_m[:, None] * stride_sa_m + offs_k_s[None, :] * stride_sa_g - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - num_k = tl.cdiv(K, BLOCK_K) - for k in tl.range(num_k, num_stages=NUM_STAGES): - a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_M, k * BLOCK_K_A]) - b = tl.load_tensor_descriptor(b_desc, [pid_n * BLOCK_N, k * BLOCK_K_B]).T - - k_m = k * BLOCK_K_S - scales_b = tl.load(scales_b_ptrs + k_m * stride_sb_g) - scales_a = tl.load(scales_a_ptrs + k_m * stride_sa_g) - - acc = tl.dot_scaled(a, scales_a, "e2m1", b, scales_b, "e2m1", acc) - - if group_size == 16: - acc *= meta_scale_norm - - tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_M, pid_n * BLOCK_N], value=acc) - - -# Kernel with 5D TMA scale loading (tutorial approach) -@triton.jit -def gemm_fp4_5d_tma_scales( - a_ptr, b_ptr, c_ptr, - scales_b_5d_ptr, scales_a_5d_ptr, - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - group_size: tl.constexpr, - stride_am: tl.constexpr, stride_ak: tl.constexpr, - stride_bn: tl.constexpr, stride_bk: tl.constexpr, - stride_cm: tl.constexpr, stride_cn: tl.constexpr, - sb_s0: tl.constexpr, sb_s1: tl.constexpr, sb_s2: tl.constexpr, sb_s3: tl.constexpr, sb_s4: tl.constexpr, - sa_s0: tl.constexpr, sa_s1: tl.constexpr, sa_s2: tl.constexpr, sa_s3: tl.constexpr, sa_s4: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NUM_STAGES: tl.constexpr, - meta_scale_norm: tl.constexpr = 0.0025, -): - pid = tl.program_id(0) - num_pid_n = tl.cdiv(N, BLOCK_N) - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - BLOCK_K_A: tl.constexpr = BLOCK_K // 2 - BLOCK_K_B: tl.constexpr = BLOCK_K // 2 - BLOCK_K_S: tl.constexpr = BLOCK_K // group_size - VEC_SIZE: tl.constexpr = group_size - - rep_m: tl.constexpr = BLOCK_M // 128 - rep_n: tl.constexpr = BLOCK_N // 128 - rep_k: tl.constexpr = BLOCK_K // VEC_SIZE // 4 - - # TMA for data - a_desc = tl.make_tensor_descriptor(a_ptr, [M, K // 2], [stride_am, stride_ak], [BLOCK_M, BLOCK_K_A]) - b_desc = tl.make_tensor_descriptor(b_ptr, [N, K // 2], [stride_bn, stride_bk], [BLOCK_N, BLOCK_K_B]) - c_desc = tl.make_tensor_descriptor(c_ptr, [M, N], [stride_cm, stride_cn], [BLOCK_M, BLOCK_N]) - - # 5D TMA for scales - scales_b_shape1: tl.constexpr = N // 128 - scales_b_shape2: tl.constexpr = K // VEC_SIZE // 4 - scales_b_desc = tl.make_tensor_descriptor( - scales_b_5d_ptr, - [1, scales_b_shape1, scales_b_shape2, 2, 256], - [sb_s0, sb_s1, sb_s2, sb_s3, sb_s4], - [1, rep_n, rep_k, 2, 256], - ) - - scales_a_shape1: tl.constexpr = M // 128 - scales_a_shape2: tl.constexpr = K // VEC_SIZE // 4 - scales_a_desc = tl.make_tensor_descriptor( - scales_a_5d_ptr, - [1, scales_a_shape1, scales_a_shape2, 2, 256], - [sa_s0, sa_s1, sa_s2, sa_s3, sa_s4], - [1, rep_m, rep_k, 2, 256], - ) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - num_k = tl.cdiv(K, BLOCK_K) - for k in tl.range(num_k, num_stages=NUM_STAGES): - a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_M, k * BLOCK_K_A]) - b = tl.load_tensor_descriptor(b_desc, [pid_n * BLOCK_N, k * BLOCK_K_B]).T - - # 5D TMA scale loads - scale_b_raw = tl.load_tensor_descriptor(scales_b_desc, [0, pid_n * rep_n, k * rep_k, 0, 0]) - scales_b = scale_b_raw.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K_S) - - scale_a_raw = tl.load_tensor_descriptor(scales_a_desc, [0, pid_m * rep_m, k * rep_k, 0, 0]) - scales_a = scale_a_raw.reshape(rep_m, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K_S) - - acc = tl.dot_scaled(a, scales_a, "e2m1", b, scales_b, "e2m1", acc) - - if group_size == 16: - acc *= meta_scale_norm - - tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_M, pid_n * BLOCK_N], value=acc) - - -def bench(M, N, K, group_size=16, BLOCK_M=128, BLOCK_N=128, BLOCK_K=128, NUM_STAGES=4, num_warps=4): - VEC_SIZE = group_size - K_S = K // group_size - - # Create random FP4 data (packed as uint8) - a = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device) - b = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device) - - # 2D scales for pointer-based kernel - scales_b_2d = torch.randn(N, K_S, device=device).to(torch.float8_e4m3fn) # [N, K_S] - scales_a_2d = torch.randn(M, K_S, device=device).to(torch.float8_e4m3fn) # [M, K_S] - - # Transposed view (matching gemlite's current layout) - scales_b_T = scales_b_2d.T # [K_S, N] with strides (1, K_S) - scales_a_T = scales_a_2d.T # not used directly, pointer from original - - # 5D preshuffled scales (keep fp8_e4m3fn dtype for NVFP4) - scales_b_5d = preshuffle_scales(scales_b_2d, N, K_S) - scales_a_5d = preshuffle_scales(scales_a_2d, M, K_S) - - c_ptr = torch.empty((M, N), dtype=torch.bfloat16, device=device) - c_5d = torch.empty((M, N), dtype=torch.bfloat16, device=device) - - grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) - - # Pointer-based kernel - def run_pointer(): - gemm_fp4_pointer_scales[grid]( - a, b, c_ptr, - scales_b_T, scales_a_2d, # scales_b is transposed, scales_a is row-major - M, N, K, group_size, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c_ptr.stride(0), c_ptr.stride(1), - scales_b_T.stride(0), scales_b_T.stride(1), # stride_sb_n=1, stride_sb_g=K_S - scales_a_2d.stride(0), scales_a_2d.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, - NUM_STAGES=NUM_STAGES, - num_warps=num_warps, - ) - - # 5D TMA kernel - def run_5d_tma(): - gemm_fp4_5d_tma_scales[grid]( - a, b, c_5d, - scales_b_5d, scales_a_5d, - M, N, K, group_size, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c_5d.stride(0), c_5d.stride(1), - scales_b_5d.stride(0), scales_b_5d.stride(1), scales_b_5d.stride(2), scales_b_5d.stride(3), scales_b_5d.stride(4), - scales_a_5d.stride(0), scales_a_5d.stride(1), scales_a_5d.stride(2), scales_a_5d.stride(3), scales_a_5d.stride(4), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, - NUM_STAGES=NUM_STAGES, - num_warps=num_warps, - ) - - ms_ptr = triton.testing.do_bench(run_pointer, warmup=200, rep=500) - ms_5d = triton.testing.do_bench(run_5d_tma, warmup=200, rep=500) - - flops = 2.0 * M * N * K - tflops_ptr = flops / (ms_ptr * 1e-3) / 1e12 - tflops_5d = flops / (ms_5d * 1e-3) / 1e12 - - print(f" Pointer scales: {ms_ptr:.3f} ms, {tflops_ptr:.1f} TFLOP/s") - print(f" 5D TMA scales: {ms_5d:.3f} ms, {tflops_5d:.1f} TFLOP/s") - print(f" Speedup: {ms_ptr / ms_5d:.3f}x") - return ms_ptr, ms_5d - - -if __name__ == "__main__": - M, N, K = 8192, 16384, 16384 - group_size = 16 # NVFP4 - - configs = [ - # (BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, num_warps) - # Best so far: 128x128x128, 3 stages, 4 warps = 1217.5 TFLOP/s - (128, 128, 128, 2, 4), - (128, 128, 128, 3, 4), - (128, 128, 128, 3, 8), - (128, 128, 128, 4, 4), - (128, 128, 128, 5, 4), - (128, 128, 256, 2, 4), - (128, 128, 256, 2, 8), - (128, 256, 128, 2, 4), - (128, 256, 128, 2, 8), - (128, 256, 128, 3, 4), - (128, 256, 256, 2, 4), - (128, 256, 256, 2, 8), - ] - - print(f"M={M}, N={N}, K={K}, group_size={group_size}") - for bm, bn, bk, ns, nw in configs: - rep_k = bk // group_size // 4 - if rep_k < 1: - print(f"\n Skipping BLOCK_M={bm}, BLOCK_N={bn}, BLOCK_K={bk} (rep_k < 1)") - continue - print(f"\n BLOCK_M={bm}, BLOCK_N={bn}, BLOCK_K={bk}, stages={ns}, warps={nw}") - try: - bench(M, N, K, group_size, bm, bn, bk, ns, nw) - except Exception as e: - print(f" FAILED: {e}") diff --git a/examples/bench_act_quant.py b/examples/bench_act_quant.py deleted file mode 100644 index 084c3a6..0000000 --- a/examples/bench_act_quant.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -Benchmark activation quantization kernels: - - gemlite MXFP4 (v1, v2, v3) - - gemlite NVFP4 (v1, v2, v3) - - flashinfer nvfp4_quantize -""" -import torch -import triton - -torch.manual_seed(0) -device = "cuda:0" -dtype = torch.bfloat16 - -# ---- gemlite quant kernels ---- -from gemlite.quant_utils import ( - scale_activations_mxfp4_triton as mxfp4_v1, - scale_activations_mxfp4_triton_v2 as mxfp4_v2, - scale_activations_mxfp4_triton_v3 as mxfp4_v3, - scale_activations_nvfp4_triton as nvfp4_v1, - scale_activations_nvfp4_triton_v2 as nvfp4_v2, - scale_activations_nvfp4_triton_v3 as nvfp4_v3, -) - -# ---- flashinfer ---- -from flashinfer import nvfp4_quantize, SfLayout - -def flashinfer_nvfp4_quant(x): - global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) - return nvfp4_quantize(x, global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) - -def flashinfer_nvfp4_quant_no_scale(x): - """Just the quantize kernel, pre-computed global scale.""" - return nvfp4_quantize(x, x._global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) - -# ---- benchmark ---- -KERNELS = { - "gemlite mxfp4 v1": mxfp4_v1, - "gemlite mxfp4 v2": mxfp4_v2, - "gemlite mxfp4 v3": mxfp4_v3, - "gemlite nvfp4 v1": nvfp4_v1, - "gemlite nvfp4 v2": nvfp4_v2, - "gemlite nvfp4 v3": nvfp4_v3, - "flashinfer nvfp4 (with global_sf)": flashinfer_nvfp4_quant, - "flashinfer nvfp4 (kernel only)": None, # special case -} - -shapes = [ - (1024, 4096), - (1024, 16384), - (4096, 4096), - (4096, 16384), - (8192, 4096), - (8192, 16384), - (16384, 16384), -] - -print(f"{'Kernel':<40} {'Shape':>14} {'Time (us)':>10} {'GB/s':>8}") -print("=" * 76) - -for M, K in shapes: - x = torch.randn(M, K, device=device, dtype=dtype) - # Pre-compute for flashinfer kernel-only variant - global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) - x._global_sf = global_sf - - bytes_read = M * K * x.element_size() # input bytes - - for name, fn in KERNELS.items(): - if name == "flashinfer nvfp4 (kernel only)": - fn_bench = lambda: nvfp4_quantize(x, global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) - else: - fn_bench = lambda fn=fn: fn(x) - - try: - ms = triton.testing.do_bench(fn_bench, warmup=200, rep=200) - us = ms * 1000 - gbps = bytes_read / (ms * 1e-3) / 1e9 - print(f" {name:<38} {str((M,K)):>14} {us:>10.1f} {gbps:>8.1f}") - except Exception as e: - print(f" {name:<38} {str((M,K)):>14} {'FAILED':>10} {str(e)[:30]}") - - print() diff --git a/examples/bench_act_quant_final.py b/examples/bench_act_quant_final.py deleted file mode 100644 index 8cf6e57..0000000 --- a/examples/bench_act_quant_final.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Benchmark activation quantization kernels from quant_utils.py (v5 integrated) -vs flashinfer nvfp4_quantize. -""" -import torch -import triton - -torch.manual_seed(0) -device = "cuda:0" -dtype = torch.bfloat16 - -# Import directly from quant_utils (now v5 by default) -from gemlite.quant_utils import ( - scale_activations_mxfp4, # v5 - scale_activations_nvfp4, # v5 - scale_activations_mxfp4_triton_v3 as mxfp4_v3, - scale_activations_nvfp4_triton_v3 as nvfp4_v3, -) - -from flashinfer import nvfp4_quantize, SfLayout - -shapes = [ - (1024, 4096), - (1024, 16384), - (4096, 4096), - (4096, 16384), - (8192, 4096), - (8192, 16384), - (16384, 16384), -] - -KERNELS = { - "gemlite mxfp4 (default=v5)": scale_activations_mxfp4, - "gemlite mxfp4 v3 (old)": mxfp4_v3, - "gemlite nvfp4 (default=v5)": scale_activations_nvfp4, - "gemlite nvfp4 v3 (old)": nvfp4_v3, - "flashinfer nvfp4 (kernel)": None, -} - -print(f"{'Kernel':<40} {'Shape':>14} {'Time (us)':>10} {'GB/s':>8}") -print("=" * 76) - -for M, K in shapes: - x = torch.randn(M, K, device=device, dtype=dtype) - global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) - bytes_read = M * K * x.element_size() - - for name, fn in KERNELS.items(): - if fn is None: - fn_bench = lambda: nvfp4_quantize(x, global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) - else: - fn_bench = lambda fn=fn: fn(x) - - try: - ms = triton.testing.do_bench(fn_bench, warmup=200, rep=200) - us = ms * 1000 - gbps = bytes_read / (ms * 1e-3) / 1e9 - print(f" {name:<38} {str((M,K)):>14} {us:>10.1f} {gbps:>8.1f}") - except Exception as e: - print(f" {name:<38} {str((M,K)):>14} {'FAILED':>10} {str(e)[:40]}") - - print() diff --git a/examples/bench_act_quant_v4.py b/examples/bench_act_quant_v4.py deleted file mode 100644 index f7c0679..0000000 --- a/examples/bench_act_quant_v4.py +++ /dev/null @@ -1,343 +0,0 @@ -""" -Test a v4 NVFP4 activation quant kernel: - - persistent 1D grid with K-loop (like v2) for better SM utilization - - scalar threshold comparisons (like v3) to avoid 3D intermediate - - block_ptr for coalesced loads with multi-stage pipelining -""" -import torch -import triton -import triton.language as tl - -torch.manual_seed(0) -device = "cuda:0" -dtype = torch.bfloat16 - -# Import gemlite quant utils for comparison + thr_pos -from gemlite.quant_utils import ( - scale_activations_nvfp4_triton_v3 as nvfp4_v3, - scale_activations_mxfp4_triton_v3 as mxfp4_v3, - thr_pos, - NVFP4_META_SCALE, - get_num_SMs, -) - -NUM_SMS = get_num_SMs(0) - -# flashinfer for comparison -from flashinfer import nvfp4_quantize, SfLayout - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=3), - ], - key=['M', 'K'], -) -@triton.jit -def scale_activations_nvfp4_kernel_v4( - tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, - M, K, - stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, - stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, - stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, - eps: tl.constexpr, - GROUP_SIZE: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - meta_scales: tl.constexpr = NVFP4_META_SCALE, -): - pid = tl.program_id(0) - num_programs = tl.num_programs(0) - num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) - - GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE - HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 - FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK - fp8_dtype: tl.constexpr = tl.float8e4nv - max_fp8: tl.constexpr = 448. - out_dtype: tl.constexpr = out_ptr.dtype.element_ty - - # Load thresholds as scalars (like v3) - thr0 = tl.load(thr_pos_ptr + 0) - thr1 = tl.load(thr_pos_ptr + 1) - thr2 = tl.load(thr_pos_ptr + 2) - thr3 = tl.load(thr_pos_ptr + 3) - thr4 = tl.load(thr_pos_ptr + 4) - thr5 = tl.load(thr_pos_ptr + 5) - thr6 = tl.load(thr_pos_ptr + 6) - thr7 = tl.load(thr_pos_ptr + 7) - - for tile_m in range(pid, num_m_tiles, num_programs): - offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - m_mask = offs_m < M - - tensor_bp = tl.make_block_ptr( - tensor_ptr, (M, K), (stride_m_t, stride_k_t), - (tile_m * BLOCK_SIZE_M, 0), - (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0) - ) - out_bp = tl.make_block_ptr( - out_ptr, (M, K // 2), (stride_m_o, stride_k_o), - (tile_m * BLOCK_SIZE_M, 0), - (BLOCK_SIZE_M, HALF_BLOCK_K), order=(1, 0) - ) - - for k_start in range(0, K, BLOCK_SIZE_K): - tensor = tl.load(tensor_bp, boundary_check=(0, 1), padding_option="zero").to(tl.float32) - - # Reshape to [FLAT_M, GROUP_SIZE] for group-wise reduction - tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) - - # Per-group FP8 scale - abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) - scales_raw = abs_max / (6. * meta_scales) - scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype) - scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps) - - # Scalar threshold comparisons (v3 approach, no 3D intermediate) - wq = tensor_flat / scales_full - abs_wq = tl.abs(wq) - idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + - (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + - (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + - (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) - out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) - - # Reshape back and pack - out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) - lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) - out = lo | (hi << 4) - - tl.store(out_bp, out, boundary_check=(0, 1)) - - # Store scales - scales_2d = tl.reshape(scales_fp8, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) - group_idx = k_start // GROUP_SIZE - offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK) - g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) - tl.store( - scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, - scales_2d, mask=m_mask[:, None] & g_mask[None, :] - ) - - tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K)) - out_bp = tl.advance(out_bp, (0, HALF_BLOCK_K)) - - -def scale_activations_nvfp4_v4(tensor: torch.Tensor): - group_size: int = 16 - eps: float = 1e-6 - fp8_dtype = torch.float8_e4m3fn - - tensor = tensor.contiguous() - tensor = tensor.view(-1, tensor.shape[-1]) - M, K = tensor.shape - - pad_m = (group_size - M % group_size) % group_size - M_padded = M + pad_m - - out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) - - grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) - device_index = tensor.device.index - - scale_activations_nvfp4_kernel_v4[grid]( - tensor, out, scales, thr_pos[device_index], - M, K, - tensor.stride(0), tensor.stride(1), - scales.stride(0), scales.stride(1), - out.stride(0), out.stride(1), - eps=eps, - GROUP_SIZE=group_size, - ) - return out, scales - - -# Also write MXFP4 v4 with same approach -from gemlite.quant_utils import next_power_of_2_triton - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=3), - ], - key=['M', 'K'], -) -@triton.jit -def scale_activations_mxfp4_kernel_v4( - tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, - M, K, - stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, - stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, - stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, - eps_exp: tl.constexpr, - GROUP_SIZE: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid = tl.program_id(0) - num_programs = tl.num_programs(0) - num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) - - GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE - HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 - FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK - out_dtype: tl.constexpr = out_ptr.dtype.element_ty - - thr0 = tl.load(thr_pos_ptr + 0) - thr1 = tl.load(thr_pos_ptr + 1) - thr2 = tl.load(thr_pos_ptr + 2) - thr3 = tl.load(thr_pos_ptr + 3) - thr4 = tl.load(thr_pos_ptr + 4) - thr5 = tl.load(thr_pos_ptr + 5) - thr6 = tl.load(thr_pos_ptr + 6) - thr7 = tl.load(thr_pos_ptr + 7) - - for tile_m in range(pid, num_m_tiles, num_programs): - offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - m_mask = offs_m < M - - tensor_bp = tl.make_block_ptr( - tensor_ptr, (M, K), (stride_m_t, stride_k_t), - (tile_m * BLOCK_SIZE_M, 0), - (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0) - ) - out_bp = tl.make_block_ptr( - out_ptr, (M, K // 2), (stride_m_o, stride_k_o), - (tile_m * BLOCK_SIZE_M, 0), - (BLOCK_SIZE_M, HALF_BLOCK_K), order=(1, 0) - ) - - for k_start in range(0, K, BLOCK_SIZE_K): - tensor = tl.load(tensor_bp, boundary_check=(0, 1), padding_option="zero").to(tl.float32) - tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) - - # MXFP4 scales: next power of 2 - scales, scales_log2 = next_power_of_2_triton( - tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) / 6., eps_exp - ) - - wq = tensor_flat / scales - abs_wq = tl.abs(wq) - idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + - (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + - (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + - (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) - out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) - - out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) - lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) - out = lo | (hi << 4) - - tl.store(out_bp, out, boundary_check=(0, 1)) - - scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) - group_idx = k_start // GROUP_SIZE - offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK) - g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) - tl.store( - scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, - scales_2d, mask=m_mask[:, None] & g_mask[None, :] - ) - - tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K)) - out_bp = tl.advance(out_bp, (0, HALF_BLOCK_K)) - - -def scale_activations_mxfp4_v4(tensor: torch.Tensor): - group_size: int = 32 - eps_exp: int = -30 - - tensor = tensor.contiguous() - tensor = tensor.view(-1, tensor.shape[-1]) - M, K = tensor.shape - - pad_m = (group_size - M % group_size) % group_size - M_padded = M + pad_m - - out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) - - grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) - device_index = tensor.device.index - - scale_activations_mxfp4_kernel_v4[grid]( - tensor, out, scales, thr_pos[device_index], - M, K, - tensor.stride(0), tensor.stride(1), - scales.stride(0), scales.stride(1), - out.stride(0), out.stride(1), - eps_exp=eps_exp, - GROUP_SIZE=group_size, - ) - return out, scales - - -# ---- Benchmark ---- -def flashinfer_nvfp4_kernel_only(x, global_sf): - return nvfp4_quantize(x, global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) - - -shapes = [ - (1024, 4096), - (1024, 16384), - (4096, 4096), - (4096, 16384), - (8192, 4096), - (8192, 16384), - (16384, 16384), -] - -KERNELS = { - "gemlite mxfp4 v3": mxfp4_v3, - "gemlite mxfp4 v4": scale_activations_mxfp4_v4, - "gemlite nvfp4 v3": nvfp4_v3, - "gemlite nvfp4 v4": scale_activations_nvfp4_v4, - "flashinfer nvfp4 (kernel only)": None, -} - -print(f"{'Kernel':<40} {'Shape':>14} {'Time (us)':>10} {'GB/s':>8}") -print("=" * 76) - -for M, K in shapes: - x = torch.randn(M, K, device=device, dtype=dtype) - global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) - bytes_read = M * K * x.element_size() - - for name, fn in KERNELS.items(): - if name == "flashinfer nvfp4 (kernel only)": - fn_bench = lambda: flashinfer_nvfp4_kernel_only(x, global_sf) - else: - fn_bench = lambda fn=fn: fn(x) - - try: - ms = triton.testing.do_bench(fn_bench, warmup=200, rep=200) - us = ms * 1000 - gbps = bytes_read / (ms * 1e-3) / 1e9 - print(f" {name:<38} {str((M,K)):>14} {us:>10.1f} {gbps:>8.1f}") - except Exception as e: - print(f" {name:<38} {str((M,K)):>14} {'FAILED':>10} {str(e)[:40]}") - - print() diff --git a/examples/bench_act_quant_v5.py b/examples/bench_act_quant_v5.py deleted file mode 100644 index 7884f65..0000000 --- a/examples/bench_act_quant_v5.py +++ /dev/null @@ -1,353 +0,0 @@ -""" -v5 NVFP4 activation quant: 2D grid like v3, but with BLOCK_SIZE_K processing -multiple groups per block. Keeps the simplicity of v3 while reducing block count. -""" -import torch -import triton -import triton.language as tl -from typing import Tuple - -torch.manual_seed(0) -device = "cuda:0" -dtype = torch.bfloat16 - -from gemlite.quant_utils import ( - scale_activations_nvfp4_triton_v3 as nvfp4_v3, - scale_activations_mxfp4_triton_v3 as mxfp4_v3, - thr_pos, - NVFP4_META_SCALE, - next_power_of_2_triton, -) -from flashinfer import nvfp4_quantize, SfLayout - - -def prune_large_blocks(configs, nargs, **kwargs): - M = nargs['M'] - K = nargs['K'] - for config in configs: - bm = config.kwargs['BLOCK_SIZE_M'] - bk = config.kwargs['BLOCK_SIZE_K'] - if bm > M or bk > K: - continue - yield config - - -# ---- NVFP4 v5: 2D grid, multi-group per block ---- -@triton.autotune( - configs=[ - # BLOCK_SIZE_K must be multiple of GROUP_SIZE=16 - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), - # Multi-stage - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 32}, num_warps=8, num_stages=2), - ], - key=['M', 'K'], - prune_configs_by={'early_config_prune': prune_large_blocks}, -) -@triton.jit -def scale_activations_nvfp4_kernel_v5( - tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, - M, K, - stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, - stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, - stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, - eps: tl.constexpr, - GROUP_SIZE: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - meta_scales: tl.constexpr = NVFP4_META_SCALE, -): - pid_m = tl.program_id(axis=0) - pid_k = tl.program_id(axis=1) - - fp8_dtype: tl.constexpr = tl.float8e4nv - max_fp8: tl.constexpr = 448. - HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 - GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE - FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK - out_dtype: tl.constexpr = out_ptr.dtype.element_ty - - thr0 = tl.load(thr_pos_ptr + 0) - thr1 = tl.load(thr_pos_ptr + 1) - thr2 = tl.load(thr_pos_ptr + 2) - thr3 = tl.load(thr_pos_ptr + 3) - thr4 = tl.load(thr_pos_ptr + 4) - thr5 = tl.load(thr_pos_ptr + 5) - thr6 = tl.load(thr_pos_ptr + 6) - thr7 = tl.load(thr_pos_ptr + 7) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - - # Load BLOCK_SIZE_K elements (multiple groups) - offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) - tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) - tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) - - # Reshape to [FLAT_M, GROUP_SIZE] for per-group reduction - tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) - - # FP8 scales per group - abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) - scales_raw = abs_max / (6. * meta_scales) - scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype) - scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps) - - # Scalar threshold comparisons - wq = tensor_flat / scales_full - abs_wq = tl.abs(wq) - idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + - (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + - (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + - (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) - out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) - - # Reshape back to [BLOCK_SIZE_M, BLOCK_SIZE_K] and pack pairs - out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) - lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) - out = lo | (hi << 4) - - # Store packed output - offs_k_out = pid_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K) - out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1) - tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask) - - # Store scales [BLOCK_SIZE_M, GROUPS_PER_BLOCK] - scales_2d = tl.reshape(scales_fp8, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) - base_group = pid_k * GROUPS_PER_BLOCK - offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK) - g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) - tl.store( - scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, - scales_2d, mask=(offs_m[:, None] < M) & g_mask[None, :] - ) - - -def scale_activations_nvfp4_v5(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - group_size: int = 16 - eps: float = 1e-6 - fp8_dtype = torch.float8_e4m3fn - - tensor = tensor.contiguous() - tensor = tensor.view(-1, tensor.shape[-1]) - M, K = tensor.shape - - pad_m = (group_size - M % group_size) % group_size - M_padded = M + pad_m - - out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) - - grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) - device_index = tensor.device.index - - scale_activations_nvfp4_kernel_v5[grid]( - tensor, out, scales, thr_pos[device_index], - M, K, - tensor.stride(0), tensor.stride(1), - scales.stride(0), scales.stride(1), - out.stride(0), out.stride(1), - eps=eps, - GROUP_SIZE=group_size, - ) - return out, scales - - -# Same for MXFP4 -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), - # Multi-stage - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=3), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 32}, num_warps=8, num_stages=2), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64}, num_warps=8, num_stages=2), - ], - key=['M', 'K'], - prune_configs_by={'early_config_prune': prune_large_blocks}, -) -@triton.jit -def scale_activations_mxfp4_kernel_v5( - tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, - M, K, - stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, - stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, - stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, - eps_exp: tl.constexpr, - GROUP_SIZE: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_k = tl.program_id(axis=1) - - HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 - GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE - FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK - out_dtype: tl.constexpr = out_ptr.dtype.element_ty - - thr0 = tl.load(thr_pos_ptr + 0) - thr1 = tl.load(thr_pos_ptr + 1) - thr2 = tl.load(thr_pos_ptr + 2) - thr3 = tl.load(thr_pos_ptr + 3) - thr4 = tl.load(thr_pos_ptr + 4) - thr5 = tl.load(thr_pos_ptr + 5) - thr6 = tl.load(thr_pos_ptr + 6) - thr7 = tl.load(thr_pos_ptr + 7) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) - tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) - tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) - - tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) - - scales, scales_log2 = next_power_of_2_triton( - tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) / 6., eps_exp - ) - - wq = tensor_flat / scales - abs_wq = tl.abs(wq) - idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + - (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + - (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + - (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) - out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) - - out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) - lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) - out = lo | (hi << 4) - - offs_k_out = pid_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K) - out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1) - tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask) - - scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) - base_group = pid_k * GROUPS_PER_BLOCK - offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK) - g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) - tl.store( - scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, - scales_2d, mask=(offs_m[:, None] < M) & g_mask[None, :] - ) - - -def scale_activations_mxfp4_v5(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - group_size: int = 32 - eps_exp: int = -30 - - tensor = tensor.contiguous() - tensor = tensor.view(-1, tensor.shape[-1]) - M, K = tensor.shape - - pad_m = (group_size - M % group_size) % group_size - M_padded = M + pad_m - - out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) - scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) - - grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) - device_index = tensor.device.index - - scale_activations_mxfp4_kernel_v5[grid]( - tensor, out, scales, thr_pos[device_index], - M, K, - tensor.stride(0), tensor.stride(1), - scales.stride(0), scales.stride(1), - out.stride(0), out.stride(1), - eps_exp=eps_exp, - GROUP_SIZE=group_size, - ) - return out, scales - - -# ---- Benchmark ---- -def flashinfer_nvfp4_kernel_only(x, global_sf): - return nvfp4_quantize(x, global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) - - -shapes = [ - (1024, 4096), - (1024, 16384), - (4096, 4096), - (4096, 16384), - (8192, 4096), - (8192, 16384), - (16384, 16384), -] - -KERNELS = { - "gemlite mxfp4 v3": mxfp4_v3, - "gemlite mxfp4 v5": scale_activations_mxfp4_v5, - "gemlite nvfp4 v3": nvfp4_v3, - "gemlite nvfp4 v5": scale_activations_nvfp4_v5, - "flashinfer nvfp4 (kernel only)": None, -} - -print(f"{'Kernel':<40} {'Shape':>14} {'Time (us)':>10} {'GB/s':>8}") -print("=" * 76) - -for M, K in shapes: - x = torch.randn(M, K, device=device, dtype=dtype) - global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) - bytes_read = M * K * x.element_size() - - for name, fn in KERNELS.items(): - if name == "flashinfer nvfp4 (kernel only)": - fn_bench = lambda: flashinfer_nvfp4_kernel_only(x, global_sf) - else: - fn_bench = lambda fn=fn: fn(x) - - try: - ms = triton.testing.do_bench(fn_bench, warmup=200, rep=200) - us = ms * 1000 - gbps = bytes_read / (ms * 1e-3) / 1e9 - print(f" {name:<38} {str((M,K)):>14} {us:>10.1f} {gbps:>8.1f}") - except Exception as e: - print(f" {name:<38} {str((M,K)):>14} {'FAILED':>10} {str(e)[:40]}") - - print() From dd7abeab71b64134e7e86951a74f26b83170cf6b Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 10 Mar 2026 11:13:22 -0700 Subject: [PATCH 46/63] update version --- gemlite/__init__.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gemlite/__init__.py b/gemlite/__init__.py index 0e5e38b..fcc31b3 100755 --- a/gemlite/__init__.py +++ b/gemlite/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.1.post1" +__version__ = "0.6.0" __author__ = 'Dr. Hicham Badri' __credits__ = 'Mobius Labs GmbH' diff --git a/setup.py b/setup.py index 817eaf7..78c2d65 100755 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup, find_packages setup( name='gemlite', - version="0.5.1.post1", + version="0.6.0", url="https://github.com/mobiusml/gemlite/", author="Dr. Hicham Badri", author_email="hicham@mobiuslabs.com", From 909f6a316a20ed9bef994ee674cbe7e58c86516c Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 11 Mar 2026 01:40:11 -0700 Subject: [PATCH 47/63] update --- gemlite/triton_kernels/utils.py | 12 +++++++++--- tests/test_gemlitelineartriton.py | 3 +-- tests/test_mxfp.py | 3 +-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/gemlite/triton_kernels/utils.py b/gemlite/triton_kernels/utils.py index c4c84be..8e2ef73 100755 --- a/gemlite/triton_kernels/utils.py +++ b/gemlite/triton_kernels/utils.py @@ -59,7 +59,6 @@ def linear_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexp pid_n = pid // tl.cdiv(M, BLOCK_SIZE_M) return pid_m, pid_n -################################################################################################################# @triton.jit def dequantize( b, @@ -114,9 +113,16 @@ def is_divisible(dividend, divisor): def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" -def gpu_has_more_shared_memory(ref_gpus = ["a100", "h100", "h200", "h20", "h800", "b100", "b200", "b300", "6000"]): +def gpu_has_more_shared_memory( + ref_gpus=( + "a100", + "h100", "h200", "h20", "h800", + "b100", "b200", "b300", + "6000", + ), +): gpu_name = torch.cuda.get_device_properties(0).name.lower() - return True in [g in gpu_name for g in ref_gpus] + return any(g in gpu_name for g in ref_gpus) def gpu_supports_float16_acc( ref_gpus=["5090", "5080", "5070", "5060", diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py index 28d9bf6..1416fdc 100755 --- a/tests/test_gemlitelineartriton.py +++ b/tests/test_gemlitelineartriton.py @@ -1,5 +1,4 @@ -#python -m unittest test_gemlitelineartriton.py -# Usage: python3 test_file.py [--autotune] +# Usage: python3 test_gemlitelineartriton.py [--autotune] import sys _autotune = '--autotune' in sys.argv if _autotune: sys.argv.remove('--autotune') diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index b44a1ed..2ca3250 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -1,5 +1,4 @@ -#python -m unittest test_mxfp.py -# Usage: python3 test_file.py [--autotune] +# Usage: python3 test_mxfp.py [--autotune] import sys _autotune = '--autotune' in sys.argv if _autotune: sys.argv.remove('--autotune') From d70b9f470fdea9393471b66203b1560782cdea31 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 11 Mar 2026 11:40:05 -0700 Subject: [PATCH 48/63] cleanup --- examples/eval_flops.py | 53 +++---------------- gemlite/core.py | 5 +- gemlite/triton_kernels/gemm_kernels.py | 15 +++--- gemlite/triton_kernels/gemm_splitK_kernels.py | 27 ++++------ 4 files changed, 28 insertions(+), 72 deletions(-) diff --git a/examples/eval_flops.py b/examples/eval_flops.py index 4351bc4..a6d5640 100644 --- a/examples/eval_flops.py +++ b/examples/eval_flops.py @@ -12,7 +12,8 @@ repeat = 32 gemlite.reset_config() -#gemlite.enable_cudagraph_autotune(True) +gemlite.enable_cudagraph_autotune(True) +gemlite.enable_tma(True) #gemlite.set_autotune("max") #gemlite.core.enable_activation_scaling(2) @@ -262,29 +263,6 @@ def patch_model_flashinfer_nvfp4(model): model[i] = FlashinferNVFP4Dynamic(layer) -def bench_flashinfer_nvfp4(M, N, K): - """ - Benchmark flashinfer NVFP4 matmul (CUTLASS backend) - raw single matmul, no activation quant. - """ - from flashinfer import nvfp4_quantize, mm_fp4, SfLayout - - a_bf16 = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) - b_bf16 = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) - - a_global_sf = (448.0 * 6.0) / a_bf16.float().abs().nan_to_num().max() - b_global_sf = (448.0 * 6.0) / b_bf16.float().abs().nan_to_num().max() - - a_fp4, a_sf = nvfp4_quantize(a_bf16, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) - b_fp4, b_sf = nvfp4_quantize(b_bf16, b_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=True) - - alpha = 1.0 / (a_global_sf * b_global_sf) - - ms = triton.testing.do_bench( - lambda: mm_fp4(a_fp4, b_fp4.T, a_sf, b_sf.T, alpha, torch.bfloat16), - warmup=500, rep=500, - ) - return ms - ########################################################################################################################### def run_benchmark(proc_name, M, K, N): @@ -294,24 +272,6 @@ def run_benchmark(proc_name, M, K, N): """ has_flashinfer, fi_err = _get_flashinfer() - # ---- flashinfer NVFP4 raw (single matmul, no activation quant, triton.do_bench) ---- - if proc_name == "flashinfer_nvfp4_raw": - if not has_flashinfer: - print(f" Skipping {proc_name}: {fi_err}") - return None - M_a = ((M + 127) // 128) * 128 - N_a = ((N + 127) // 128) * 128 - K_a = ((K + 127) // 128) * 128 - try: - ms = bench_flashinfer_nvfp4(M_a, N_a, K_a) - tflops = get_flops(M_a, K_a, N_a, ms) - label = "flashinfer NVFP4 (raw)" - print(f" {label} | {M_a}, {K_a}, {N_a} | {tflops:.2f} TFLOP/s ({ms:.3f} ms)") - return (label, M_a, K_a, N_a, tflops) - except Exception as e: - print(f" flashinfer NVFP4 raw failed: {e}") - return None - # ---- flashinfer NVFP4 dynamic (torch.compile + activation quant) ---- if proc_name == "flashinfer_nvfp4_dynamic": if not has_flashinfer: @@ -423,7 +383,6 @@ def run_benchmark(proc_name, M, K, N): "native_int8", "native_fp8", "flashinfer_nvfp4_dynamic", - "flashinfer_nvfp4_raw", ] @@ -436,7 +395,7 @@ def main(): python eval_flops.py # Run with specific dimensions: - python eval_flops.py --M 128 --K 4096 --N 4096 + python eval_flops.py --M 8192 --K 8192 --N 8192 # Run only specific processors (comma-separated): python eval_flops.py --processor A4W4_MXFP_dynamic,flashinfer_nvfp4_dynamic,native_fp8 @@ -450,15 +409,15 @@ def main(): # A8W8_MXFP_dynamic_post_scale, A8W8_MXFP_dynamic, # A4W4_MXFP_dynamic, A4W4_NVFP_dynamic # PyTorch: native_int8, native_fp8 - # flashinfer: flashinfer_nvfp4_dynamic, flashinfer_nvfp4_raw + # flashinfer: flashinfer_nvfp4_dynamic # Baseline: none / fp16 (BF16, no quantization) # Use "all" to run every processor. """, formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("--M", type=int, default=8192, help="Batch/sequence dimension") - parser.add_argument("--K", type=int, default=16384, help="Input feature dimension") - parser.add_argument("--N", type=int, default=16384, help="Output feature dimension") + parser.add_argument("--K", type=int, default=8192, help="Input feature dimension") + parser.add_argument("--N", type=int, default=8192, help="Output feature dimension") parser.add_argument("--processor", type=str, default="all", help='Comma-separated processor names or "all" (default: all)') args = parser.parse_args() diff --git a/gemlite/core.py b/gemlite/core.py index a889331..9579cab 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -352,7 +352,7 @@ def load_state_dict(self, state_dict, strict=True, assign=False): if s.ndim == 2: s_2d = s.T.contiguous() # [K_S, N] contiguous N_dim, K_S = s_2d.shape[1], s_2d.shape[0] - if GEMLITE_USE_TMA and N_dim % 128 == 0 and K_S % 4 == 0: + if GEMLITE_USE_TMA and self.elements_per_sample > 1 and N_dim % 128 == 0 and K_S % 4 == 0: self.scales = s_2d.reshape(N_dim // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N_dim // 128, K_S // 4, 2, 256).contiguous() #Make sure to feed UINT8 W_q for packing @@ -531,7 +531,8 @@ def pack( # Preshuffle weight scales to 5D TMA layout for fast loading # Original: [K_S, N] -> transpose to [N, K_S] -> 5D: [1, N//128, K_S//4, 2, 256] K_S = K // group_size - if GEMLITE_USE_TMA and N % 128 == 0 and K_S % 4 == 0: + if GEMLITE_USE_TMA and self.elements_per_sample > 1 and N % 128 == 0 and K_S % 4 == 0: + # Currently TMA only enabled for MXFP4/NVFP4 NOT for MXFP8 because of poor performance on sm_120 (self.elements_per_sample > 1 check) self.scales = self.scales.T.contiguous().reshape(N // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N // 128, K_S // 4, 2, 256).contiguous() else: # Keep 2D transposed layout for pointer-based fallback diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 242956e..3feccf7 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -313,8 +313,9 @@ def gemm_INT_kernel( meta_evict_policy: tl.constexpr = "evict_last", a_evict: tl.constexpr = "", b_evict: tl.constexpr = "evict_first", - USE_5D_SCALES: tl.constexpr = False, + ################################# use_tma: tl.constexpr = True, + use_5d_scales: tl.constexpr = False, ): """ Based on https://github.com/fpgaminer/GPTQ-triton @@ -667,7 +668,7 @@ def gemm_MX_kernel( meta_scale_norm: tl.constexpr = (0.05 ** 2), ################################# use_tma: tl.constexpr = True, - USE_5D_SCALES: tl.constexpr = False, + use_5d_scales: tl.constexpr = False, ): pid = tl.program_id(axis=0) @@ -711,7 +712,7 @@ def gemm_MX_kernel( BLOCK_SIZE_K_S: tl.constexpr = BLOCK_SIZE_K // group_size offs_k_scales = tl.arange(0, BLOCK_SIZE_K_S) offs_n_b_scales = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if not USE_5D_SCALES: + if not use_5d_scales: scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] if use_tma: @@ -737,7 +738,7 @@ def gemm_MX_kernel( ) # 5D TMA Descriptors for Scales (preshuffled layout) - if USE_5D_SCALES: + if use_5d_scales: rep_n: tl.constexpr = BLOCK_SIZE_N // 128 rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 scales_b_shape1: tl.constexpr = N // 128 @@ -781,7 +782,7 @@ def gemm_MX_kernel( b = tl.load(b_ptrs, eviction_policy=b_evict) #################################################################################### k_m = k * BLOCK_SIZE_K_S - if USE_5D_SCALES: + if use_5d_scales: # 5D TMA scale loads (preshuffled layout) scale_b_raw = tl.load_tensor_descriptor(scales_b_5d_desc, [0, pid_n * rep_n, k * rep_k, 0, 0]) scales_b = scale_b_raw.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) @@ -892,8 +893,8 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x W_group_mode = W_group_mode, zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, - USE_5D_SCALES = use_5d_scales, - use_tma = GEMLITE_USE_TMA, + use_tma = use_5d_scales, + use_5d_scales = use_5d_scales, ) return output diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 692ba08..0de7112 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -334,10 +334,9 @@ def gemm_splitK_INT_kernel( atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', b_evict: tl.constexpr = 'evict_first', - USE_5D_SCALES: tl.constexpr = False, - SCALES_5D_SHAPE1: tl.constexpr = 0, - SCALES_5D_SHAPE2: tl.constexpr = 0, + ################################# dmmy use_tma: tl.constexpr = True, + use_5d_scales: tl.constexpr = False, ): """ Based on https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py @@ -558,9 +557,7 @@ def gemm_splitK_MX_kernel( meta_scale_norm: tl.constexpr = (0.05 ** 2), ################################# use_tma: tl.constexpr = True, - USE_5D_SCALES: tl.constexpr = False, - SCALES_5D_SHAPE1: tl.constexpr = 0, - SCALES_5D_SHAPE2: tl.constexpr = 0, + use_5d_scales: tl.constexpr = False, ): pid = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -608,7 +605,7 @@ def gemm_splitK_MX_kernel( offs_k_scales = tl.arange(0, BLOCK_SIZE_K_S) offs_n_b_scales = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) #B scales: [BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] - if not USE_5D_SCALES: + if not use_5d_scales: scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #A scales @@ -638,17 +635,17 @@ def gemm_splitK_MX_kernel( ) # 5D TMA Descriptors for Scales (preshuffled layout) - if USE_5D_SCALES: + if use_5d_scales: rep_n: tl.constexpr = BLOCK_SIZE_N // 128 rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 stride_b4: tl.constexpr = 1 stride_b3: tl.constexpr = 256 stride_b2: tl.constexpr = 512 - stride_b1: tl.constexpr = 512 * SCALES_5D_SHAPE2 - stride_b0: tl.constexpr = 512 * SCALES_5D_SHAPE2 * SCALES_5D_SHAPE1 + stride_b1: tl.constexpr = 512 * (K // group_size // 4) + stride_b0: tl.constexpr = stride_b1 * (N // 128) scales_b_5d_desc = tl.make_tensor_descriptor( scales_ptr, - [1, SCALES_5D_SHAPE1, SCALES_5D_SHAPE2, 2, 256], + [1, N // 128, K // group_size // 4, 2, 256], [stride_b0, stride_b1, stride_b2, stride_b3, stride_b4], [1, rep_n, rep_k, 2, 256] ) @@ -679,7 +676,7 @@ def gemm_splitK_MX_kernel( #k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32) k_m = (k * SPLIT_K + pid_k) * BLOCK_SIZE_K_S #OK for BLOCK_SIZE_K >=group_size - if USE_5D_SCALES: + if use_5d_scales: scale_b_raw = tl.load_tensor_descriptor(scales_b_5d_desc, [0, pid_n * rep_n, (k * SPLIT_K + pid_k) * rep_k, 0, 0]) scales_b = scale_b_raw.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) else: @@ -795,10 +792,8 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s W_group_mode = W_group_mode, zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, - USE_5D_SCALES = use_5d_scales, - SCALES_5D_SHAPE1 = N // 128 if use_5d_scales else 0, - SCALES_5D_SHAPE2 = K // group_size // 4 if use_5d_scales else 0, - use_tma = GEMLITE_USE_TMA, + use_tma = use_5d_scales, + use_5d_scales = use_5d_scales, ) if(not native_atomic): From ef1cdd7ab33f3cbcfa8d0d393c0b9dd75a1f6f0e Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 11 Mar 2026 15:00:23 -0700 Subject: [PATCH 49/63] add ptx packing for mxfp4/nvfp4 --- gemlite/quant_utils.py | 133 ++++++++++++++++++++++++++++++----------- 1 file changed, 97 insertions(+), 36 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 3a7e484..9f9f47f 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -1987,6 +1987,7 @@ def scale_activations_mxfp4_triton_kernel_v5( GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + ptx_pack: tl.constexpr = True, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -1996,15 +1997,6 @@ def scale_activations_mxfp4_triton_kernel_v5( FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK out_dtype: tl.constexpr = out_ptr.dtype.element_ty - thr0 = tl.load(thr_pos_ptr + 0) - thr1 = tl.load(thr_pos_ptr + 1) - thr2 = tl.load(thr_pos_ptr + 2) - thr3 = tl.load(thr_pos_ptr + 3) - thr4 = tl.load(thr_pos_ptr + 4) - thr5 = tl.load(thr_pos_ptr + 5) - thr6 = tl.load(thr_pos_ptr + 6) - thr7 = tl.load(thr_pos_ptr + 7) - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) @@ -2018,16 +2010,53 @@ def scale_activations_mxfp4_triton_kernel_v5( ) wq = tensor_flat / scales - abs_wq = tl.abs(wq) - idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + - (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + - (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + - (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) - out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) - out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) - lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) - out = lo | (hi << 4) + if ptx_pack: + # PTX path: hardware e2m1x2 quantization + nibble packing + wq_2d = tl.reshape(wq, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + wq_pairs = wq_2d.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False) + lo_val, hi_val = tl.split(wq_pairs) + lo_f16 = lo_val.to(tl.float16) + hi_f16 = hi_val.to(tl.float16) + lo_bits = lo_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF + hi_bits = (hi_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF) << 16 + packed_f16x2 = lo_bits | hi_bits + packed_e2m1 = tl.inline_asm_elementwise( + asm=""" + { + .reg .b8 tmp_out; + .reg .f16x2 tmp_in; + mov.b32 tmp_in, $1; + cvt.rn.satfinite.e2m1x2.f16x2 tmp_out, tmp_in; + cvt.u32.u8 $0, tmp_out; + } + """, + constraints="=r,r", + args=[packed_f16x2], + dtype=tl.int32, + is_pure=True, + pack=1, + ) + out = packed_e2m1.to(tl.uint8) + else: + # Threshold path: 8 comparisons + manual nibble packing + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) offs_k_out = pid_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K) out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1) @@ -2074,6 +2103,8 @@ def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: return out, scales + + #################################################################################################################### # NVFP4 v5: 2D grid with multi-group BLOCK_SIZE_K (fewer blocks, better bandwidth) #################################################################################################################### @@ -2101,6 +2132,7 @@ def scale_activations_nvfp4_triton_kernel_v5( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, meta_scales: tl.constexpr = NVFP4_META_SCALE, + ptx_pack: tl.constexpr = True, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -2112,15 +2144,6 @@ def scale_activations_nvfp4_triton_kernel_v5( FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK out_dtype: tl.constexpr = out_ptr.dtype.element_ty - thr0 = tl.load(thr_pos_ptr + 0) - thr1 = tl.load(thr_pos_ptr + 1) - thr2 = tl.load(thr_pos_ptr + 2) - thr3 = tl.load(thr_pos_ptr + 3) - thr4 = tl.load(thr_pos_ptr + 4) - thr5 = tl.load(thr_pos_ptr + 5) - thr6 = tl.load(thr_pos_ptr + 6) - thr7 = tl.load(thr_pos_ptr + 7) - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) @@ -2136,16 +2159,53 @@ def scale_activations_nvfp4_triton_kernel_v5( scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps) wq = tensor_flat / scales_full - abs_wq = tl.abs(wq) - idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + - (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + - (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + - (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) - out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) - out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) - lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) - out = lo | (hi << 4) + if ptx_pack: + # PTX path: hardware e2m1x2 quantization + nibble packing + wq_2d = tl.reshape(wq, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + wq_pairs = wq_2d.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False) + lo_val, hi_val = tl.split(wq_pairs) + lo_f16 = lo_val.to(tl.float16) + hi_f16 = hi_val.to(tl.float16) + lo_bits = lo_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF + hi_bits = (hi_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF) << 16 + packed_f16x2 = lo_bits | hi_bits + packed_e2m1 = tl.inline_asm_elementwise( + asm=""" + { + .reg .b8 tmp_out; + .reg .f16x2 tmp_in; + mov.b32 tmp_in, $1; + cvt.rn.satfinite.e2m1x2.f16x2 tmp_out, tmp_in; + cvt.u32.u8 $0, tmp_out; + } + """, + constraints="=r,r", + args=[packed_f16x2], + dtype=tl.int32, + is_pure=True, + pack=1, + ) + out = packed_e2m1.to(tl.uint8) + else: + # Threshold path: 8 comparisons + manual nibble packing + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) offs_k_out = pid_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K) out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1) @@ -2194,6 +2254,7 @@ def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor) -> Tuple[torch.Tenso + #################################################################################################################### scale_activations_per_token = scale_activations_per_token_triton_v3 scale_activations_mxfp8 = scale_activations_mxfp8_triton_v4 From 405e2016b65274f3447bdc252f148dbd858418d4 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 11 Mar 2026 15:19:20 -0700 Subject: [PATCH 50/63] use default ptx false --- gemlite/quant_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 9f9f47f..3dea80e 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -1987,7 +1987,10 @@ def scale_activations_mxfp4_triton_kernel_v5( GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - ptx_pack: tl.constexpr = True, + # Requires CUDA 13.0+ ptxas (Triton bundles 12.9 as of v3.3). To enable, replace + # the bundled ptxas-blackwell with the system one: cp /usr/local/cuda/bin/ptxas + # /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas-blackwell + ptx_pack: tl.constexpr = False, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -2132,7 +2135,10 @@ def scale_activations_nvfp4_triton_kernel_v5( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, meta_scales: tl.constexpr = NVFP4_META_SCALE, - ptx_pack: tl.constexpr = True, + # Requires CUDA 13.0+ ptxas (Triton bundles 12.9 as of v3.3). To enable, replace + # the bundled ptxas-blackwell with the system one: cp /usr/local/cuda/bin/ptxas + # /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas-blackwell + ptx_pack: tl.constexpr = False, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) From ebc0b1b172d7d8e21b8a142bbca023a8551fb884 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 11 Mar 2026 15:24:00 -0700 Subject: [PATCH 51/63] add todo --- gemlite/quant_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 3dea80e..5fc658b 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -1990,6 +1990,8 @@ def scale_activations_mxfp4_triton_kernel_v5( # Requires CUDA 13.0+ ptxas (Triton bundles 12.9 as of v3.3). To enable, replace # the bundled ptxas-blackwell with the system one: cp /usr/local/cuda/bin/ptxas # /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas-blackwell + # TODO: once Triton ships CUDA 13.0+ ptxas, set default to True and add ptx_pack + # to the autotuner configs so it can pick the best path per shape. ptx_pack: tl.constexpr = False, ): pid_m = tl.program_id(axis=0) @@ -2138,6 +2140,8 @@ def scale_activations_nvfp4_triton_kernel_v5( # Requires CUDA 13.0+ ptxas (Triton bundles 12.9 as of v3.3). To enable, replace # the bundled ptxas-blackwell with the system one: cp /usr/local/cuda/bin/ptxas # /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas-blackwell + # TODO: once Triton ships CUDA 13.0+ ptxas, set default to True and add ptx_pack + # to the autotuner configs so it can pick the best path per shape. ptx_pack: tl.constexpr = False, ): pid_m = tl.program_id(axis=0) From bf3b514b89d769e0d75000d622b920a74b34ab37 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 19 Mar 2026 08:02:05 -0700 Subject: [PATCH 52/63] improve M=64 perf --- gemlite/triton_kernels/gemm_kernels.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 3feccf7..f987943 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -187,6 +187,9 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) configs.append(triton.Config({'BLOCK_SIZE_M':64, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) configs.append(triton.Config({'BLOCK_SIZE_M':128, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) + #Small M tiles (for M=32..64 where more tiles improve SM utilization) + configs.append(triton.Config({"BLOCK_SIZE_M":32, "BLOCK_SIZE_N":128, "BLOCK_SIZE_K":128, "GROUP_SIZE_M":8, "A_load_order":0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({"BLOCK_SIZE_M":32, "BLOCK_SIZE_N":128, "BLOCK_SIZE_K":256, "GROUP_SIZE_M":8, "A_load_order":0}, num_warps=4, num_stages=3)) return configs def get_default_config_nvidia(): From 846d0ef7a797fa42c66acda9533e731884aff86c Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 19 Mar 2026 11:04:25 -0700 Subject: [PATCH 53/63] add hints to mx non-tma path --- gemlite/triton_kernels/gemm_kernels.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index f987943..7981e21 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -701,6 +701,8 @@ def gemm_MX_kernel( BLOCK_SIZE_K_A: tl.constexpr = BLOCK_SIZE_K // elements_per_sample_a offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_ak = tl.arange(0, BLOCK_SIZE_K_A) + if not use_tma: + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak) a_mask = ((offs_am[:, None] < M) & (offs_ak[None, :] < K // elements_per_sample_a)).to(tl.int1) @@ -708,7 +710,14 @@ def gemm_MX_kernel( BLOCK_SIZE_K_B: tl.constexpr = BLOCK_SIZE_K // elements_per_sample offs_bk = tl.arange(0, BLOCK_SIZE_K_B) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - b_ptrs = b_ptr + offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn + if not use_tma: + if data_contiguous: + offs_bn_load = offs_bn + else: + offs_bn_load = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + else: + offs_bn_load = offs_bn + b_ptrs = b_ptr + offs_bk[:, None] * stride_bk + offs_bn_load[None, :] * stride_bn #Scales stride_mul: tl.constexpr = BLOCK_SIZE_K / group_size @@ -836,6 +845,7 @@ def gemm_MX_kernel( else: offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) if EVEN_M and EVEN_N: From 4631e2f89e5796e31a36107423729dd5b78d4e94 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 19 Mar 2026 12:44:08 -0700 Subject: [PATCH 54/63] add ptx_pack=True quant configs for MXFP4/NVFP4 (hardware e2m1x2 conversion) Requires CUDA 13.0+ ptxas. Gives ~10% speedup on activation quantization kernels, translating to 3-5% end-to-end improvement at M=1024. Co-Authored-By: Claude Opus 4.6 (1M context) --- gemlite/quant_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 5fc658b..b236e31 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -1972,6 +1972,11 @@ def prune_large_blocks_2d(configs, named_args, **kwargs): triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), + #ptx_pack=True configs (hardware e2m1x2 conversion, requires CUDA 13.0+ ptxas) + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64, 'ptx_pack': True}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128, 'ptx_pack': True}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256, 'ptx_pack': True}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256, 'ptx_pack': True}, num_warps=8, num_stages=1), ], key=['M', 'K'], prune_configs_by={'early_config_prune': prune_large_blocks_2d}, @@ -2121,6 +2126,11 @@ def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), + #ptx_pack=True configs (hardware e2m1x2 conversion, requires CUDA 13.0+ ptxas) + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64, 'ptx_pack': True}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128, 'ptx_pack': True}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256, 'ptx_pack': True}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256, 'ptx_pack': True}, num_warps=8, num_stages=1), ], key=['M', 'K'], prune_configs_by={'early_config_prune': prune_large_blocks_2d}, From 9ea1036cb42df361ea88a365d9cf5b44804b2a7d Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 19 Mar 2026 12:51:41 -0700 Subject: [PATCH 55/63] Revert "add ptx_pack=True quant configs for MXFP4/NVFP4 (hardware e2m1x2 conversion)" This reverts commit 4631e2f89e5796e31a36107423729dd5b78d4e94. --- gemlite/quant_utils.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index b236e31..5fc658b 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -1972,11 +1972,6 @@ def prune_large_blocks_2d(configs, named_args, **kwargs): triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), - #ptx_pack=True configs (hardware e2m1x2 conversion, requires CUDA 13.0+ ptxas) - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64, 'ptx_pack': True}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128, 'ptx_pack': True}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256, 'ptx_pack': True}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256, 'ptx_pack': True}, num_warps=8, num_stages=1), ], key=['M', 'K'], prune_configs_by={'early_config_prune': prune_large_blocks_2d}, @@ -2126,11 +2121,6 @@ def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), - #ptx_pack=True configs (hardware e2m1x2 conversion, requires CUDA 13.0+ ptxas) - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64, 'ptx_pack': True}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128, 'ptx_pack': True}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256, 'ptx_pack': True}, num_warps=4, num_stages=1), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256, 'ptx_pack': True}, num_warps=8, num_stages=1), ], key=['M', 'K'], prune_configs_by={'early_config_prune': prune_large_blocks_2d}, From 7111bad5fefa1327afba371457e5bf5c1385f769 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 19 Mar 2026 13:03:25 -0700 Subject: [PATCH 56/63] add gemlite.set_ptx_pack() for hardware FP4 packing Adds GEMLITE_ENABLE_PTX_PACK global (default False) and gemlite.set_ptx_pack(True/False) API. When enabled, MXFP4/NVFP4 activation quantization kernels use hardware cvt.rn.satfinite.e2m1x2 PTX instruction instead of threshold comparisons. Requires CUDA 13.0+ ptxas to be installed in the Triton backends directory. Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/eval_flops.py | 1 + gemlite/__init__.py | 1 + gemlite/core.py | 8 ++++++++ gemlite/quant_utils.py | 17 ++++++++++------- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/examples/eval_flops.py b/examples/eval_flops.py index a6d5640..78a52bc 100644 --- a/examples/eval_flops.py +++ b/examples/eval_flops.py @@ -14,6 +14,7 @@ gemlite.reset_config() gemlite.enable_cudagraph_autotune(True) gemlite.enable_tma(True) +#gemlite.set_ptx_fp4_pack(True) #gemlite.set_autotune("max") #gemlite.core.enable_activation_scaling(2) diff --git a/gemlite/__init__.py b/gemlite/__init__.py index fcc31b3..3c2e654 100755 --- a/gemlite/__init__.py +++ b/gemlite/__init__.py @@ -13,6 +13,7 @@ set_autotune, set_kernel_caching, enable_tma, + set_ptx_fp4_pack, enable_cudagraph_autotune, forward_functional, ) diff --git a/gemlite/core.py b/gemlite/core.py index 9579cab..70f367a 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -67,6 +67,7 @@ GEMLITE_TRITON_CONFIG_CACHE = {} #Global config cache for all the kernels _GROUP_SIZE_WARNED = False GEMLITE_USE_TMA = True # Set to False for faster MXFP8 on sm_120 +GEMLITE_ENABLE_PTX_FP4_PACK = False # Set to True for hardware e2m1x2 FP4 packing (requires CUDA 13.0+ ptxas) ################################################################################### #Utils @@ -102,6 +103,13 @@ def enable_tma(enabled: bool = True): global GEMLITE_USE_TMA GEMLITE_USE_TMA = enabled +#Enable/disable hardware PTX FP4 packing in activation quantization (requires CUDA 13.0+ ptxas) +def set_ptx_fp4_pack(enabled: bool = True): + global GEMLITE_ENABLE_PTX_FP4_PACK + GEMLITE_ENABLE_PTX_FP4_PACK = enabled + from .quant_utils import set_ptx_fp4_pack_flag + set_ptx_fp4_pack_flag(enabled) + #Enable/disable CUDA graph-based autotuning (more accurate but slower) def enable_cudagraph_autotune(enabled: bool = True): set_autotune("fast", use_cuda_graph=enabled) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 5fc658b..ba56ea4 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -10,6 +10,11 @@ from .triton_kernels.utils import IS_HIP, get_num_SMs, next_power_of_2 from .dtypes import * +GEMLITE_ENABLE_PTX_FP4_PACK = False # Enable with CUDA13+ ptxas +def set_ptx_fp4_pack_flag(enabled: bool): + global GEMLITE_ENABLE_PTX_FP4_PACK + GEMLITE_ENABLE_PTX_FP4_PACK = enabled + #Get dtype min/max range based on compute dtype def get_dtype_range(compute_dtype: torch.dtype) -> float: if(compute_dtype.is_floating_point): @@ -1992,7 +1997,7 @@ def scale_activations_mxfp4_triton_kernel_v5( # /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas-blackwell # TODO: once Triton ships CUDA 13.0+ ptxas, set default to True and add ptx_pack # to the autotuner configs so it can pick the best path per shape. - ptx_pack: tl.constexpr = False, + ptx_pack: tl.constexpr = GEMLITE_ENABLE_PTX_FP4_PACK, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -2137,12 +2142,10 @@ def scale_activations_nvfp4_triton_kernel_v5( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, meta_scales: tl.constexpr = NVFP4_META_SCALE, - # Requires CUDA 13.0+ ptxas (Triton bundles 12.9 as of v3.3). To enable, replace - # the bundled ptxas-blackwell with the system one: cp /usr/local/cuda/bin/ptxas - # /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas-blackwell - # TODO: once Triton ships CUDA 13.0+ ptxas, set default to True and add ptx_pack - # to the autotuner configs so it can pick the best path per shape. - ptx_pack: tl.constexpr = False, + # Requires CUDA 13.0+ ptxas (Triton bundles 12.9 as of v3.3). To enable, set + # the environment variable TRITON_CUDA_ARCH_LIST to include CUDA 13.0+ ptxas, + # and override the bundled ptxas-blackwell. + ptx_pack: tl.constexpr = GEMLITE_ENABLE_PTX_FP4_PACK, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) From a29c4c470d0b32fb9324daeda6ed8b4a1e16d77a Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 20 Mar 2026 10:40:06 -0700 Subject: [PATCH 57/63] update tl.constexpr --- gemlite/triton_kernels/gemm_kernels.py | 4 ++-- gemlite/triton_kernels/gemm_splitK_kernels.py | 4 ++-- gemlite/triton_kernels/gemv_kernels.py | 4 ++-- gemlite/triton_kernels/gemv_revsplitK_kernels.py | 2 +- gemlite/triton_kernels/gemv_splitK_kernels.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 7981e21..a807eaa 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -274,7 +274,7 @@ def get_default_config_amd(): def gemm_INT_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, M_CLOSEST, + M, N: tl.constexpr, K: tl.constexpr, M_CLOSEST, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, @@ -629,7 +629,7 @@ def gemm_INT_kernel( def gemm_MX_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, M_CLOSEST: tl.constexpr, + M, N: tl.constexpr, K: tl.constexpr, M_CLOSEST, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 0de7112..e251257 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -289,7 +289,7 @@ def get_default_config_amd(): def gemm_splitK_INT_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, M_CLOSEST, + M, N: tl.constexpr, K: tl.constexpr, M_CLOSEST, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, @@ -507,7 +507,7 @@ def gemm_splitK_INT_kernel( def gemm_splitK_MX_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, M_CLOSEST, + M, N: tl.constexpr, K: tl.constexpr, M_CLOSEST, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, diff --git a/gemlite/triton_kernels/gemv_kernels.py b/gemlite/triton_kernels/gemv_kernels.py index 7468144..e8ae9e1 100755 --- a/gemlite/triton_kernels/gemv_kernels.py +++ b/gemlite/triton_kernels/gemv_kernels.py @@ -238,7 +238,7 @@ def gemv_INT_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, mapping_ptr, - M, N, K, + M, N: tl.constexpr, K: tl.constexpr, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, @@ -422,7 +422,7 @@ def gemv_MX_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, mapping_ptr, - M, N, K, + M, N: tl.constexpr, K: tl.constexpr, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, diff --git a/gemlite/triton_kernels/gemv_revsplitK_kernels.py b/gemlite/triton_kernels/gemv_revsplitK_kernels.py index cfa52c8..a51a948 100755 --- a/gemlite/triton_kernels/gemv_revsplitK_kernels.py +++ b/gemlite/triton_kernels/gemv_revsplitK_kernels.py @@ -239,7 +239,7 @@ def get_default_config_amd(): def gemv_INT_revsplitK_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, + M, N: tl.constexpr, K: tl.constexpr, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, diff --git a/gemlite/triton_kernels/gemv_splitK_kernels.py b/gemlite/triton_kernels/gemv_splitK_kernels.py index 362fcf3..cede2a4 100755 --- a/gemlite/triton_kernels/gemv_splitK_kernels.py +++ b/gemlite/triton_kernels/gemv_splitK_kernels.py @@ -251,7 +251,7 @@ def get_default_config_amd(): def gemv_INT_splitK_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, + M, N: tl.constexpr, K: tl.constexpr, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, From 7356faf16dc1e15b243a574a09c2de37366f25bc Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 20 Mar 2026 11:01:53 -0700 Subject: [PATCH 58/63] add save/load guard --- gemlite/core.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/gemlite/core.py b/gemlite/core.py index 70f367a..0d4dba4 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -318,6 +318,16 @@ def __init__( #Default forward self.forward = self.forward_auto_no_warmup + def _save_to_state_dict(self, destination, prefix, keep_vars): + # Rebuild metadata from live attributes to ensure consistency + # (helpers may override channel_scale_mode/W_group_mode after pack()) + if hasattr(self, 'metadata') and self.metadata is not None and hasattr(self, 'W_nbits'): + self.metadata = torch.nn.Parameter( + torch.tensor(self.get_meta_args(), device=self.metadata.device, dtype=torch.int32), + requires_grad=False, + ) + super()._save_to_state_dict(destination, prefix, keep_vars) + def load_state_dict(self, state_dict, strict=True, assign=False): self.W_q = state_dict.pop("W_q", None) self.bias = state_dict.pop("bias", None) From cd6b8d60e5b03724b673ada083ee8a6526050f42 Mon Sep 17 00:00:00 2001 From: mobicham Date: Sat, 21 Mar 2026 02:38:32 -0700 Subject: [PATCH 59/63] refactor tests --- ...est_gemlitelineartriton.py => test_int.py} | 0 tests/test_serialization.py | 134 ++++++++++++++++++ 2 files changed, 134 insertions(+) rename tests/{test_gemlitelineartriton.py => test_int.py} (100%) create mode 100644 tests/test_serialization.py diff --git a/tests/test_gemlitelineartriton.py b/tests/test_int.py similarity index 100% rename from tests/test_gemlitelineartriton.py rename to tests/test_int.py diff --git a/tests/test_serialization.py b/tests/test_serialization.py new file mode 100644 index 0000000..ebcf4f5 --- /dev/null +++ b/tests/test_serialization.py @@ -0,0 +1,134 @@ +# Usage: python3 test_serialization.py [--autotune] +import sys +_autotune = '--autotune' in sys.argv +if _autotune: sys.argv.remove('--autotune') + +import unittest +import torch +from gemlite import reset_config, set_autotune +from gemlite.core import GemLiteLinearTriton, DType, TORCH_TO_DTYPE +from gemlite.triton_kernels.config import KERNEL +from gemlite.helper import A4W4_MXFP_dynamic, A4W4_NVFP_dynamic, patch_model + +def is_fp8_supported(): + if not torch.cuda.is_available(): + return False + capability = torch.cuda.get_device_capability(0) + return capability >= (8, 9) + +device = 'cuda:0' +compute_dtype = torch.bfloat16 +gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] + +reset_config() +if _autotune is False: set_autotune(False) +KERNEL.ENABLE_CACHING = False + +def _check_serialization(test_case, gemlite_linear, matmul_type='GEMM', batch_size=32, tol=1e-7): + """Shared serialization round-trip check.""" + in_features = gemlite_linear.in_features + + torch.save(gemlite_linear.state_dict(), '/tmp/_test_serial.pt') + + loaded = GemLiteLinearTriton() + loaded.load_state_dict(torch.load('/tmp/_test_serial.pt')) + + # Check meta_args match + ref_meta = gemlite_linear.get_meta_args() + loaded_meta = loaded.get_meta_args() + for i in range(len(ref_meta)): + test_case.assertEqual(ref_meta[i], loaded_meta[i], f"meta_args mismatch at {i}: {ref_meta[i]} != {loaded_meta[i]}") + + # Check tensor_args match + ref_tensors = gemlite_linear.get_tensor_args() + loaded_tensors = loaded.get_tensor_args() + for i in range(len(ref_tensors)): + if ref_tensors[i].numel() > 0: + diff = (ref_tensors[i].float() - loaded_tensors[i].float()).abs().mean().item() + test_case.assertEqual(diff, 0, f"tensor_args mismatch at {i}: mean diff = {diff}") + + # Check inference matches + x = torch.randn(batch_size, in_features, dtype=compute_dtype, device=device) / 10. + y_ref = gemlite_linear.forward_manual(x, matmul_type=matmul_type) + y_loaded = loaded.forward_manual(x, matmul_type=matmul_type) + diff = (y_ref - y_loaded).abs().mean().item() + test_case.assertTrue(diff < tol, f"Inference mismatch: mean diff = {diff}, expected < {tol}") + + +class TestSerializationINT(unittest.TestCase): + """Serialization tests for INT quantized layers.""" + + def test_A16W4(self): + in_features, out_features = 4096, 2048 + W_nbits, group_size = 4, 128 + + W_q = torch.randint(0, 2**W_nbits - 1, (out_features, in_features), device=device).to(torch.uint8) + gs = W_q.numel() // group_size + scales = torch.ones((gs, 1), device=device, dtype=compute_dtype) * 0.001 + zeros = torch.zeros((gs, 1), device=device, dtype=compute_dtype) * ((2**W_nbits - 1)//2) + + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=gemlite_dtype, + output_dtype=gemlite_dtype) + gemlite_linear.pack(W_q, scales, zeros, None) + + _check_serialization(self, gemlite_linear) + + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_A8W8(self): + in_features, out_features = 4096, 2048 + fp8_dtype = torch.float8_e4m3fn + + W = torch.randn((out_features, in_features), dtype=compute_dtype, device=device) / 10. + _scales = torch.randn((1, out_features), dtype=compute_dtype, device=device) * 1e-4 + + gemlite_linear = GemLiteLinearTriton(W_nbits=8, + group_size=in_features, + in_features=in_features, + out_features=out_features, + input_dtype=TORCH_TO_DTYPE[fp8_dtype], + output_dtype=gemlite_dtype, + scaled_activations=True) + gemlite_linear.pack(W.to(fp8_dtype), scales=_scales, zeros=None, bias=None) + + _check_serialization(self, gemlite_linear) + + +class TestSerializationMX(unittest.TestCase): + """Serialization tests for MXFP/NVFP quantized layers.""" + + def setUp(self): + self.in_features, self.out_features = 4224, 2048 + torch.manual_seed(42) + self.linear_layer = torch.nn.Linear( + self.in_features, self.out_features, dtype=compute_dtype, device=device, bias=False + ) + self.linear_layer.weight.data /= 10. + self.linear_layer.weight.requires_grad = False + + def _quantize(self, processor_fn): + model = torch.nn.Sequential( + torch.nn.Linear(self.in_features, self.out_features, dtype=compute_dtype, device=device, bias=False) + ) + model.requires_grad_(False) + model[0].weight.data = self.linear_layer.weight.data.clone() + processor = processor_fn(dtype=compute_dtype) + patch_model(model, device=device, processor=processor) + return model[0] + + def test_A4W4_MXFP(self): + gemlite_linear = self._quantize(A4W4_MXFP_dynamic) + _check_serialization(self, gemlite_linear, matmul_type='GEMM') + _check_serialization(self, gemlite_linear, matmul_type='GEMM_SPLITK', batch_size=2) + + def test_A4W4_NVFP(self): + gemlite_linear = self._quantize(A4W4_NVFP_dynamic) + _check_serialization(self, gemlite_linear, matmul_type='GEMM') + _check_serialization(self, gemlite_linear, matmul_type='GEMM_SPLITK', batch_size=2) + + +if __name__ == '__main__': + unittest.main() From 3d98df7ece7fa00259dce4779e69f3846e208eff Mon Sep 17 00:00:00 2001 From: mobicham Date: Sat, 21 Mar 2026 11:00:30 -0700 Subject: [PATCH 60/63] improve nvfp4 --- gemlite/core.py | 28 +- gemlite/helper.py | 11 +- gemlite/quant_utils.py | 53 +-- gemlite/triton_kernels/gemm_kernels.py | 352 +++++++++--------- gemlite/triton_kernels/gemm_splitK_kernels.py | 26 +- gemlite/triton_kernels/gemv_kernels.py | 1 + .../triton_kernels/gemv_revsplitK_kernels.py | 1 + gemlite/triton_kernels/gemv_splitK_kernels.py | 1 + 8 files changed, 259 insertions(+), 214 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 0d4dba4..6b7f605 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -175,6 +175,7 @@ def forward_functional( scaled_activations = bool(meta_args[0]) and enable_activation_scaling(batch_size) #Dynamic activation quantization scales_x = None + meta_scale = 0.0 if(scaled_activations): input_dtype = DType(meta_args[5]) channel_scale_mode = meta_args[9] @@ -192,7 +193,9 @@ def forward_functional( x, scales_x = scale_activations_mxfp4(x) elif(input_dtype in [DType.NVFP4] and channel_scale_mode == 4): #NVPF4: TODO - x, scales_x = scale_activations_nvfp4(x) + meta_scale = tensor_args[3] + x, scales_x, meta_scale_a = scale_activations_nvfp4(x) + meta_scale = meta_scale * meta_scale_a # combine weight and activation meta_scales x = x.view(-1, x.shape[-1]) @@ -204,7 +207,7 @@ def forward_functional( out = ( GEMLITE_TRITON_MAPPING[matmul_type_str] .forward( - x, *tensor_args, scales_x, *meta_args[1:-1], data_contiguous, type_id + x, *tensor_args[:3], scales_x, *meta_args[1:-1], data_contiguous, type_id, meta_scale=meta_scale ) .view(out_shape) ) @@ -318,6 +321,9 @@ def __init__( #Default forward self.forward = self.forward_auto_no_warmup + #Meta-scale for NVFP4 (0.0 = not used) + self.meta_scale = 0.0 + def _save_to_state_dict(self, destination, prefix, keep_vars): # Rebuild metadata from live attributes to ensure consistency # (helpers may override channel_scale_mode/W_group_mode after pack()) @@ -335,6 +341,7 @@ def load_state_dict(self, state_dict, strict=True, assign=False): self.zeros = state_dict.pop("zeros", None) self.metadata = state_dict.pop("metadata", None) self.orig_shape = state_dict.pop("orig_shape", None) + _meta_scale = state_dict.pop("meta_scale", None) self.metadata = [v.item() for v in self.metadata] self.orig_shape = (v.item() for v in self.orig_shape) @@ -357,6 +364,12 @@ def load_state_dict(self, state_dict, strict=True, assign=False): self.acc_dtype = DType(self.acc_dtype) self.meta_dtype = DType(self.meta_dtype) + # Restore meta_scale with backward compat for old checkpoints + if _meta_scale is not None: + self.meta_scale = _meta_scale.float() + else: + self.meta_scale = 0.05 if self.input_dtype == DType.NVFP4 else 0.0 # backward compat default for old checkpoints + self.out_features, self.in_features = self.orig_shape self.compute_dtype = DTYPE_TO_TORCH[self.input_dtype.value] self.scaled_activations = bool(self.scaled_activations) @@ -577,11 +590,20 @@ def pack( requires_grad=False, ) + + self.meta_scale = torch.nn.Parameter( + torch.tensor(self.meta_scale, device=self.device, dtype=torch.float32), + requires_grad=False, + ) + return self #Return the main arguments def get_tensor_args(self): - return [self.W_q, self.scales, self.zeros] + meta_scale = self.meta_scale + if not isinstance(meta_scale, torch.Tensor): + meta_scale = torch.tensor(meta_scale, dtype=torch.float32, device=self.W_q.device) + return [self.W_q, self.scales, self.zeros, meta_scale] def get_meta_args(self): return [int(self.scaled_activations), diff --git a/gemlite/helper.py b/gemlite/helper.py index 41f40e0..213ab18 100755 --- a/gemlite/helper.py +++ b/gemlite/helper.py @@ -888,7 +888,7 @@ def __init__(self, device='cuda:0', dtype=None): self.group_size = 16 self.input_dtype = DType.NVFP4 - def from_weights(self, weight, bias=None, scales=None): + def from_weights(self, weight, bias=None, scales=None, meta_scale=None): if(isinstance(weight, torch.nn.Parameter)): weight = weight.data if(isinstance(bias, torch.nn.Parameter)): @@ -923,6 +923,11 @@ def from_weights(self, weight, bias=None, scales=None): gemlite_linear.pack(W_q, scales, zeros=None, bias=bias) gemlite_linear.W_group_mode = 0 gemlite_linear.channel_scale_mode = 4 + if meta_scale is not None: + gemlite_linear.meta_scale = torch.nn.Parameter( + meta_scale.to(dtype=torch.float32, device=gemlite_linear.W_q.device).reshape(()), + requires_grad=False, + ) return gemlite_linear @@ -933,11 +938,11 @@ def from_linear(self, linear_layer, del_orig=True): W = linear_layer.weight.data bias = linear_layer.bias.clone() if (linear_layer.bias is not None) else None N, K = W.shape - W_q, scales = self.quantizer_mx.quantize_nvfp4(W, index=True) + W_q, scales, _meta_scale = self.quantizer_mx.quantize_nvfp4(W, index=True) W_q, scales = W_q.view([N, K]), scales.view(N, K // self.group_size) cleanup_linear(linear_layer, del_orig) - out_layer = self.from_weights(weight=W_q, scales=scales, bias=bias) + out_layer = self.from_weights(weight=W_q, scales=scales, bias=bias, meta_scale=_meta_scale) #Clean-uo del W_q diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index ba56ea4..6d7579e 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -24,7 +24,6 @@ def get_dtype_range(compute_dtype: torch.dtype) -> float: return dtype_info.min, dtype_info.max NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count -NVFP4_META_SCALE = 0.05 #Temporary NVFP logic #################################################################################################################### #MXFP4 / NVFP4 weight quantizer #################################################################################################################### @@ -170,8 +169,8 @@ def quantize_mxfp4( @torch.compile(fullgraph=True) def quantize_nvfp4( - self, W: torch.Tensor, window_size: int = 0, index: bool = False - ) -> (torch.Tensor, torch.Tensor): + self, W: torch.Tensor, window_size: int = 0, index: bool = False, + ) -> (torch.Tensor, torch.Tensor, torch.Tensor): group_size: int = 16 eps: float = 1e-6 @@ -183,7 +182,7 @@ def quantize_nvfp4( W_flat = W.view(-1, group_size).float() ideal_scale = W_flat.abs().amax(dim=1, keepdim=True) ideal_scale /= max_val - meta_scales = NVFP4_META_SCALE #ideal_scale.max().clamp_(min=eps) - TODO: use max() + meta_scales = ideal_scale.max().clamp_(min=eps).float() ideal_scale /= meta_scales ideal_scale = ideal_scale.clamp_(max=max_fp8).to(fp8_dtype) @@ -217,7 +216,7 @@ def quantize_nvfp4( if(index): W_q = self.to_index(W_q) - return W_q, scales + return W_q, scales, meta_scales def dequantize(self, W_q, scales, shape = None, dtype = None, meta_scales = None): if(W_q.dtype == torch.uint8): #from indices @@ -1188,7 +1187,7 @@ def scale_activations_nvfp4_torch(tensor: Tensor) -> Tuple[Tensor, Tensor]: W_flat = tensor.view(-1, group_size).float() scales = W_flat.abs().amax(dim=1, keepdim=True) scales /= max_val - meta_scales = NVFP4_META_SCALE #scales.max().clamp_(min=eps) - TODO: use max() + meta_scales = scales.max().clamp_(min=eps) scales /= meta_scales scales = scales.clamp(max=max_fp8).to(fp8_dtype).to(W_flat.dtype) @@ -1215,7 +1214,7 @@ def scale_activations_nvfp4_torch(tensor: Tensor) -> Tuple[Tensor, Tensor]: .to(fp8_dtype) .view(post_pad_shape[0], post_pad_shape[1] // group_size) ) - return W_q, scales + return W_q, scales, meta_scales.float() @triton.autotune( configs=[ @@ -1370,7 +1369,7 @@ def scale_activations_nvfp4_triton_kernel( eps: tl.constexpr, GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - meta_scales: tl.constexpr = NVFP4_META_SCALE, + meta_scales_ptr, use_tma: tl.constexpr = False, ): pid_m = tl.program_id(axis=0) @@ -1410,6 +1409,7 @@ def scale_activations_nvfp4_triton_kernel( tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) #FP8 scales + meta_scales = tl.load(meta_scales_ptr, eviction_policy='evict_last') scales = tl.max(tl.abs(tensor), axis=1, keep_dims=True) / (6. * meta_scales) scales = tl.minimum(scales, max_fp8).to(fp8_dtype) @@ -1435,10 +1435,11 @@ def scale_activations_nvfp4_triton_kernel( tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales) -def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: group_size: int = 16 eps: float = 1e-6 fp8_dtype = torch.float8_e4m3fn #Nvidia only + meta_scale = (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) @@ -1465,9 +1466,10 @@ def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, ######################### eps=eps, GROUP_SIZE=group_size, + meta_scales_ptr=meta_scale, ) - return out, scales + return out, scales, meta_scale #################################################################################################################### # MXFP4 v2: persistent 1D grid, processes multiple K-groups per iteration @@ -1610,7 +1612,7 @@ def scale_activations_nvfp4_triton_kernel_v2( GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - meta_scales: tl.constexpr = NVFP4_META_SCALE, + meta_scales_ptr, ): pid = tl.program_id(0) num_programs = tl.num_programs(0) @@ -1624,6 +1626,7 @@ def scale_activations_nvfp4_triton_kernel_v2( out_dtype: tl.constexpr = out_ptr.dtype.element_ty thr_pos = tl.load(thr_pos_ptr + tl.arange(0, 8), eviction_policy='evict_last')[None, :] + meta_scales = tl.load(meta_scales_ptr, eviction_policy='evict_last') for tile_m in range(pid, num_m_tiles, num_programs): offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) m_mask = offs_m < M @@ -1677,10 +1680,11 @@ def scale_activations_nvfp4_triton_kernel_v2( out_bp = tl.advance(out_bp, (0, HALF_BLOCK_K)) -def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: group_size: int = 16 eps: float = 1e-6 fp8_dtype = torch.float8_e4m3fn + meta_scale = (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) @@ -1703,9 +1707,10 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso out.stride(0), out.stride(1), eps=eps, GROUP_SIZE=group_size, + meta_scales_ptr=meta_scale, ) - return out, scales + return out, scales, meta_scale #################################################################################################################### @@ -1858,7 +1863,7 @@ def scale_activations_nvfp4_triton_kernel_v3( eps: tl.constexpr, GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - meta_scales: tl.constexpr = NVFP4_META_SCALE, + meta_scales_ptr, use_tma: tl.constexpr = False, ): pid_m = tl.program_id(axis=0) @@ -1888,6 +1893,7 @@ def scale_activations_nvfp4_triton_kernel_v3( tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) #FP8 scales + meta_scales = tl.load(meta_scales_ptr, eviction_policy='evict_last') scales = tl.max(tl.abs(tensor), axis=1, keep_dims=True) / (6. * meta_scales) scales = tl.minimum(scales, max_fp8).to(fp8_dtype) @@ -1914,10 +1920,11 @@ def scale_activations_nvfp4_triton_kernel_v3( tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales) -def scale_activations_nvfp4_triton_v3(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def scale_activations_nvfp4_triton_v3(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: group_size: int = 16 eps: float = 1e-6 fp8_dtype = torch.float8_e4m3fn + meta_scale = (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) @@ -1944,9 +1951,10 @@ def scale_activations_nvfp4_triton_v3(tensor: torch.Tensor) -> Tuple[torch.Tenso ######################### eps=eps, GROUP_SIZE=group_size, + meta_scales_ptr=meta_scale, ) - return out, scales + return out, scales, meta_scale @@ -2113,8 +2121,6 @@ def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: return out, scales - - #################################################################################################################### # NVFP4 v5: 2D grid with multi-group BLOCK_SIZE_K (fewer blocks, better bandwidth) #################################################################################################################### @@ -2141,7 +2147,7 @@ def scale_activations_nvfp4_triton_kernel_v5( GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - meta_scales: tl.constexpr = NVFP4_META_SCALE, + meta_scales_ptr, # Requires CUDA 13.0+ ptxas (Triton bundles 12.9 as of v3.3). To enable, set # the environment variable TRITON_CUDA_ARCH_LIST to include CUDA 13.0+ ptxas, # and override the bundled ptxas-blackwell. @@ -2163,9 +2169,9 @@ def scale_activations_nvfp4_triton_kernel_v5( mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + meta_scales = tl.load(meta_scales_ptr, eviction_policy='evict_last') tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) - abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) scales_raw = abs_max / (6. * meta_scales) scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype) @@ -2236,10 +2242,12 @@ def scale_activations_nvfp4_triton_kernel_v5( ) -def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: group_size: int = 16 eps: float = 1e-6 fp8_dtype = torch.float8_e4m3fn + # Compute per-tensor meta_scale from activation data + meta_scale = (tensor.view(-1, group_size).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) @@ -2262,8 +2270,9 @@ def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor) -> Tuple[torch.Tenso out.stride(0), out.stride(1), eps=eps, GROUP_SIZE=group_size, + meta_scales_ptr=meta_scale, ) - return out, scales + return out, scales, meta_scale diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index a807eaa..3abdefe 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -316,6 +316,7 @@ def gemm_INT_kernel( meta_evict_policy: tl.constexpr = "evict_last", a_evict: tl.constexpr = "", b_evict: tl.constexpr = "evict_first", + meta_scale_norm_ptr = None, ################################# use_tma: tl.constexpr = True, use_5d_scales: tl.constexpr = False, @@ -451,174 +452,16 @@ def gemm_INT_kernel( scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) - - acc = acc.to(output_dtype) ############################################################################################################# + #Output + acc = acc.to(output_dtype) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) tl.store(c_ptrs, acc, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) -# @triton.autotune( -# configs = get_autotune_config(), -# key = KEYS, -# prune_configs_by = {'early_config_prune': kernel_config_pruner}, -# use_cuda_graph = AUTOTUNE.USE_CUDA_GRAPH, -# ) -# @triton.jit -# def gemm_INT_kernel_persistent_tma( -# a_ptr, b_ptr, c_ptr, -# scales_ptr, zeros_ptr, scales_a_ptr, -# M, N, K, M_CLOSEST, -# ######### Quant parms ######### -# W_nbits: tl.constexpr, -# group_size: tl.constexpr, -# unpack_mask: tl.constexpr, -# elements_per_sample: tl.constexpr, -# ################################# -# type_id: tl.constexpr, -# a_sizeof: tl.constexpr, -# b_sizeof: tl.constexpr, -# ######### Strides ######### -# stride_am: tl.constexpr, stride_ak: tl.constexpr, -# stride_bk: tl.constexpr, stride_bn: tl.constexpr, -# stride_cm: tl.constexpr, stride_cn: tl.constexpr, -# stride_meta_a_m: tl.constexpr, stride_meta_a_g: tl.constexpr, -# stride_meta_g: tl.constexpr, stride_meta_n: tl.constexpr, -# ######### Dtypes ######### -# load_scales_as_block: tl.constexpr, #False -# input_dtype: tl.constexpr, -# output_dtype: tl.constexpr, -# acc_dtype: tl.constexpr, -# meta_dtype: tl.constexpr, -# ######### Meta-data mode ######### -# channel_scale_mode: tl.constexpr, -# W_group_mode: tl.constexpr, -# zero_is_scalar: tl.constexpr, -# ######### tuning params ######### -# BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -# GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, -# ################################# -# EVEN_M: tl.constexpr = False, -# EVEN_K: tl.constexpr = False, -# EVEN_N: tl.constexpr = False, -# ################################# -# A_load_order: tl.constexpr = 0, -# data_contiguous: tl.constexpr = True, -# ################################# -# meta_evict_policy: tl.constexpr = '', -# a_evict: tl.constexpr = '', -# b_evict: tl.constexpr = '', -# NUM_SMS: tl.constexpr = 8, -# ): -# """ -# Persistent + TMA version. -# A: (M, K) fp16/bf16 -# B_packed: (K//elements_per_sample, N) int32 -# scales/zeros: (num_groups, N) or other depending on W_group_mode -# """ - -# # --------------------------- -# # Persistent tiling setup -# # --------------------------- -# start_pid = tl.program_id(0).to(tl.int32) - -# grid_m = tl.cdiv(M, BLOCK_SIZE_M) -# grid_n = tl.cdiv(N, BLOCK_SIZE_N) -# num_tiles = grid_m * grid_n -# width = GROUP_SIZE_M * grid_n # tiles per "group stripe" - -# a_desc = tl.make_tensor_descriptor( -# a_ptr, -# [M, K], -# [stride_am, stride_ak], -# [BLOCK_SIZE_M, BLOCK_SIZE_K] -# ) - -# # b_desc = tl.make_tensor_descriptor( -# # b_ptr, -# # [K, N], -# # [stride_bk, stride_bn], -# # [BLOCK_SIZE_K, BLOCK_SIZE_N] -# # ) - -# #transposed : use self.W_q = self.W_q.contiguous().t() -# b_desc = tl.make_tensor_descriptor( -# b_ptr, -# [N, K], -# [stride_bn, stride_bk], -# [BLOCK_SIZE_N, BLOCK_SIZE_K] -# ) - -# # # Precompute unpack shifts (vector length = elements_per_sample) -# # # shifts = [0, W_nbits, 2*W_nbits, ...] -# # shifts = (tl.arange(0, elements_per_sample) * W_nbits).to(tl.int32) - -# # # Optional scalar zero -# # if zero_is_scalar: -# # zero_scalar = tl.load(zeros_ptr, eviction_policy="evict_last") - -# ############################################################################################################# -# # Main loop -# for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): -# group_id = tile_id // width -# first_m = group_id * GROUP_SIZE_M -# gs = tl.minimum(grid_m - first_m, GROUP_SIZE_M) - -# pid_m = first_m + (tile_id % gs) -# pid_n = (tile_id % width) // gs - -# rm = pid_m * BLOCK_SIZE_M -# rn = pid_n * BLOCK_SIZE_N - -# # Accumulator -# acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - -# # Column indices for this tile (used for metadata + store) -# offs_n = rn + tl.arange(0, BLOCK_SIZE_N) -# n_mask = offs_n < N - -# # K loop -# for k in tl.range(0, K, BLOCK_SIZE_K): -# a = tl.load_tensor_descriptor(a_desc, [rm, k]) - -# k_packed = k // elements_per_sample -# #b = tl.load_tensor_descriptor(b_desc, [k_packed, rn]) -# b = tl.load_tensor_descriptor(b_desc, [rn, k_packed]).T #Transposed - -# acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype) - -# ############################################################################################################# -# # Channel-wise scaling -# offs_m = rm + tl.arange(0, BLOCK_SIZE_M) -# m_mask = offs_m < M -# if channel_scale_mode == 1: # weight-only -# # expects a 1D per-N scale at scales_ptr (same as your original) -# scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) -# acc = acc.to(meta_dtype) * scales_b[None, :] - -# if channel_scale_mode == 2: # activation-only -# scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) -# acc = acc.to(meta_dtype) * scales_a[:, None] - -# if channel_scale_mode == 3: # weight + activation -# scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) -# scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) -# acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) - -# acc = acc.to(output_dtype) - -# ############################################################################################################# -# # Store -# c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn -# mask = (m_mask[:, None] & n_mask[None, :]).to(tl.int1) -# if EVEN_M and EVEN_N: -# tl.store(c_ptrs, acc) -# else: -# tl.store(c_ptrs, acc, mask=mask) - @triton.autotune( configs = get_autotune_config(), key = KEYS, @@ -668,7 +511,7 @@ def gemm_MX_kernel( meta_evict_policy: tl.constexpr = "evict_last", a_evict: tl.constexpr = "", b_evict: tl.constexpr = "", - meta_scale_norm: tl.constexpr = (0.05 ** 2), + meta_scale_norm_ptr = None, ################################# use_tma: tl.constexpr = True, use_5d_scales: tl.constexpr = False, @@ -774,10 +617,10 @@ def gemm_MX_kernel( # _1s dtype must match actual scale dtype: uint8 for MXFP (E8M0), float8e4nv for NVFP4 (E4M3) if group_size == 16: scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) - scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) + #scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) else: scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) - scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + #scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k, num_stages=NUM_STAGES): @@ -814,9 +657,6 @@ def gemm_MX_kernel( else: scales_a = scales_a_1s - #scales_b = scales_b_1s - #scales_a = scales_a_1s - #################################################################################### acc = tl.dot_scaled(a, scales_a, a_dtype, b, scales_b, b_dtype, acc) @@ -829,17 +669,17 @@ def gemm_MX_kernel( #NVFP4 meta-scale if(group_size == 16): - acc *= meta_scale_norm + acc = acc.to(tl.float32) * tl.load(meta_scale_norm_ptr, eviction_policy='evict_last') ############################################################################################################# #Channel-wise scaling if channel_scale_mode == 2: # activation-only - dtype: tl.constexpr = c_ptr.dtype.element_ty scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1.0, eviction_policy=meta_evict_policy) - acc = acc.to(dtype) * scales_a[:, None] + acc = acc * scales_a[:, None] ############################################################################################################# #Output + acc = acc.to(output_dtype) if use_tma: tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], value=acc) else: @@ -859,6 +699,7 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, + meta_scale: Tensor = None, ) -> Tensor: @@ -902,16 +743,177 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], acc_dtype = DTYPE_TO_TRITON[acc_dtype], meta_dtype = DTYPE_TO_TRITON[meta_dtype], - channel_scale_mode = channel_scale_mode, - W_group_mode = W_group_mode, - zero_is_scalar = zeros.numel() == 1, - data_contiguous = data_contiguous, - use_tma = use_5d_scales, - use_5d_scales = use_5d_scales, + channel_scale_mode = channel_scale_mode, + W_group_mode = W_group_mode, + zero_is_scalar = zeros.numel() == 1, + data_contiguous = data_contiguous, + use_tma = use_5d_scales, + use_5d_scales = use_5d_scales, + meta_scale_norm_ptr = meta_scale, ) return output + + +# @triton.autotune( +# configs = get_autotune_config(), +# key = KEYS, +# prune_configs_by = {'early_config_prune': kernel_config_pruner}, +# use_cuda_graph = AUTOTUNE.USE_CUDA_GRAPH, +# ) +# @triton.jit +# def gemm_INT_kernel_persistent_tma( +# a_ptr, b_ptr, c_ptr, +# scales_ptr, zeros_ptr, scales_a_ptr, +# M, N, K, M_CLOSEST, +# ######### Quant parms ######### +# W_nbits: tl.constexpr, +# group_size: tl.constexpr, +# unpack_mask: tl.constexpr, +# elements_per_sample: tl.constexpr, +# ################################# +# type_id: tl.constexpr, +# a_sizeof: tl.constexpr, +# b_sizeof: tl.constexpr, +# ######### Strides ######### +# stride_am: tl.constexpr, stride_ak: tl.constexpr, +# stride_bk: tl.constexpr, stride_bn: tl.constexpr, +# stride_cm: tl.constexpr, stride_cn: tl.constexpr, +# stride_meta_a_m: tl.constexpr, stride_meta_a_g: tl.constexpr, +# stride_meta_g: tl.constexpr, stride_meta_n: tl.constexpr, +# ######### Dtypes ######### +# load_scales_as_block: tl.constexpr, #False +# input_dtype: tl.constexpr, +# output_dtype: tl.constexpr, +# acc_dtype: tl.constexpr, +# meta_dtype: tl.constexpr, +# ######### Meta-data mode ######### +# channel_scale_mode: tl.constexpr, +# W_group_mode: tl.constexpr, +# zero_is_scalar: tl.constexpr, +# ######### tuning params ######### +# BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +# GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, +# ################################# +# EVEN_M: tl.constexpr = False, +# EVEN_K: tl.constexpr = False, +# EVEN_N: tl.constexpr = False, +# ################################# +# A_load_order: tl.constexpr = 0, +# data_contiguous: tl.constexpr = True, +# ################################# +# meta_evict_policy: tl.constexpr = '', +# a_evict: tl.constexpr = '', +# b_evict: tl.constexpr = '', +# NUM_SMS: tl.constexpr = 8, +# ): +# """ +# Persistent + TMA version. +# A: (M, K) fp16/bf16 +# B_packed: (K//elements_per_sample, N) int32 +# scales/zeros: (num_groups, N) or other depending on W_group_mode +# """ + +# # --------------------------- +# # Persistent tiling setup +# # --------------------------- +# start_pid = tl.program_id(0).to(tl.int32) + +# grid_m = tl.cdiv(M, BLOCK_SIZE_M) +# grid_n = tl.cdiv(N, BLOCK_SIZE_N) +# num_tiles = grid_m * grid_n +# width = GROUP_SIZE_M * grid_n # tiles per "group stripe" + +# a_desc = tl.make_tensor_descriptor( +# a_ptr, +# [M, K], +# [stride_am, stride_ak], +# [BLOCK_SIZE_M, BLOCK_SIZE_K] +# ) + +# # b_desc = tl.make_tensor_descriptor( +# # b_ptr, +# # [K, N], +# # [stride_bk, stride_bn], +# # [BLOCK_SIZE_K, BLOCK_SIZE_N] +# # ) + +# #transposed : use self.W_q = self.W_q.contiguous().t() +# b_desc = tl.make_tensor_descriptor( +# b_ptr, +# [N, K], +# [stride_bn, stride_bk], +# [BLOCK_SIZE_N, BLOCK_SIZE_K] +# ) + +# # # Precompute unpack shifts (vector length = elements_per_sample) +# # # shifts = [0, W_nbits, 2*W_nbits, ...] +# # shifts = (tl.arange(0, elements_per_sample) * W_nbits).to(tl.int32) + +# # # Optional scalar zero +# # if zero_is_scalar: +# # zero_scalar = tl.load(zeros_ptr, eviction_policy="evict_last") + +# ############################################################################################################# +# # Main loop +# for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): +# group_id = tile_id // width +# first_m = group_id * GROUP_SIZE_M +# gs = tl.minimum(grid_m - first_m, GROUP_SIZE_M) + +# pid_m = first_m + (tile_id % gs) +# pid_n = (tile_id % width) // gs + +# rm = pid_m * BLOCK_SIZE_M +# rn = pid_n * BLOCK_SIZE_N + +# # Accumulator +# acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + +# # Column indices for this tile (used for metadata + store) +# offs_n = rn + tl.arange(0, BLOCK_SIZE_N) +# n_mask = offs_n < N + +# # K loop +# for k in tl.range(0, K, BLOCK_SIZE_K): +# a = tl.load_tensor_descriptor(a_desc, [rm, k]) + +# k_packed = k // elements_per_sample +# #b = tl.load_tensor_descriptor(b_desc, [k_packed, rn]) +# b = tl.load_tensor_descriptor(b_desc, [rn, k_packed]).T #Transposed + +# acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype) + +# ############################################################################################################# +# # Channel-wise scaling +# offs_m = rm + tl.arange(0, BLOCK_SIZE_M) +# m_mask = offs_m < M +# if channel_scale_mode == 1: # weight-only +# # expects a 1D per-N scale at scales_ptr (same as your original) +# scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) +# acc = acc.to(meta_dtype) * scales_b[None, :] + +# if channel_scale_mode == 2: # activation-only +# scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) +# acc = acc.to(meta_dtype) * scales_a[:, None] + +# if channel_scale_mode == 3: # weight + activation +# scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) +# scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) +# acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + +# acc = acc.to(output_dtype) + +# ############################################################################################################# +# # Store +# c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn +# mask = (m_mask[:, None] & n_mask[None, :]).to(tl.int1) +# if EVEN_M and EVEN_N: +# tl.store(c_ptrs, acc) +# else: +# tl.store(c_ptrs, acc, mask=mask) + # # Persistent version # NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count # def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index e251257..93cc4cc 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -334,6 +334,7 @@ def gemm_splitK_INT_kernel( atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', b_evict: tl.constexpr = 'evict_first', + meta_scale_norm_ptr = None, ################################# dmmy use_tma: tl.constexpr = True, use_5d_scales: tl.constexpr = False, @@ -480,6 +481,7 @@ def gemm_splitK_INT_kernel( ############################################################################################################# #Output + acc = acc.to(output_dtype) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) @@ -554,7 +556,7 @@ def gemm_splitK_MX_kernel( atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', b_evict: tl.constexpr = 'evict_first', - meta_scale_norm: tl.constexpr = (0.05 ** 2), + meta_scale_norm_ptr = None, ################################# use_tma: tl.constexpr = True, use_5d_scales: tl.constexpr = False, @@ -707,17 +709,17 @@ def gemm_splitK_MX_kernel( #NVFP4 meta-scale if(group_size == 16): - acc *= meta_scale_norm + acc = acc.to(tl.float32) * tl.load(meta_scale_norm_ptr, eviction_policy='evict_last') ############################################################################################################# #Channel-wise scaling if channel_scale_mode == 2: # activation-only - dtype: tl.constexpr = c_ptr.dtype.element_ty scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1.0, eviction_policy=meta_evict_policy) - acc = acc.to(dtype) * scales_a[:, None] - + acc = acc * scales_a[:, None] + ############################################################################################################# #Output + acc = acc.to(output_dtype) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) @@ -741,6 +743,7 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, + meta_scale: Tensor = None, ) -> Tensor: from ..core import GEMLITE_USE_TMA @@ -788,12 +791,13 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s acc_dtype = DTYPE_TO_TRITON[acc_dtype], meta_dtype = DTYPE_TO_TRITON[meta_dtype], ################################################ - channel_scale_mode = channel_scale_mode, - W_group_mode = W_group_mode, - zero_is_scalar = zeros.numel() == 1, - data_contiguous = data_contiguous, - use_tma = use_5d_scales, - use_5d_scales = use_5d_scales, + channel_scale_mode = channel_scale_mode, + W_group_mode = W_group_mode, + zero_is_scalar = zeros.numel() == 1, + data_contiguous = data_contiguous, + use_tma = use_5d_scales, + use_5d_scales = use_5d_scales, + meta_scale_norm_ptr = meta_scale, ) if(not native_atomic): diff --git a/gemlite/triton_kernels/gemv_kernels.py b/gemlite/triton_kernels/gemv_kernels.py index e8ae9e1..89cd7a9 100755 --- a/gemlite/triton_kernels/gemv_kernels.py +++ b/gemlite/triton_kernels/gemv_kernels.py @@ -591,6 +591,7 @@ def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, + meta_scale: float = 0.0, ) -> Tensor: global KERNEL_CACHE diff --git a/gemlite/triton_kernels/gemv_revsplitK_kernels.py b/gemlite/triton_kernels/gemv_revsplitK_kernels.py index a51a948..628f35e 100755 --- a/gemlite/triton_kernels/gemv_revsplitK_kernels.py +++ b/gemlite/triton_kernels/gemv_revsplitK_kernels.py @@ -418,6 +418,7 @@ def gemv_revsplitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, + meta_scale: float = 0.0, ) -> Tensor: global KERNEL_CACHE diff --git a/gemlite/triton_kernels/gemv_splitK_kernels.py b/gemlite/triton_kernels/gemv_splitK_kernels.py index cede2a4..bec539c 100755 --- a/gemlite/triton_kernels/gemv_splitK_kernels.py +++ b/gemlite/triton_kernels/gemv_splitK_kernels.py @@ -468,6 +468,7 @@ def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, + meta_scale: float = 0.0, ) -> Tensor: M, K, N = x.shape[0], x.shape[1], W_q.shape[1] From 9cab0dcd1d4c1d990ef439224e6d5931d8b6b0b5 Mon Sep 17 00:00:00 2001 From: mobicham Date: Sat, 21 Mar 2026 12:04:11 -0700 Subject: [PATCH 61/63] cleanup nvfp4 --- gemlite/core.py | 4 ++-- gemlite/quant_utils.py | 20 +++++++++---------- gemlite/triton_kernels/gemm_kernels.py | 3 ++- gemlite/triton_kernels/gemm_splitK_kernels.py | 3 ++- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/gemlite/core.py b/gemlite/core.py index 6b7f605..def0740 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -193,9 +193,9 @@ def forward_functional( x, scales_x = scale_activations_mxfp4(x) elif(input_dtype in [DType.NVFP4] and channel_scale_mode == 4): #NVPF4: TODO - meta_scale = tensor_args[3] + meta_scale_w = tensor_args[3] x, scales_x, meta_scale_a = scale_activations_nvfp4(x) - meta_scale = meta_scale * meta_scale_a # combine weight and activation meta_scales + meta_scale = meta_scale_w * meta_scale_a # combine weight and activation meta_scales x = x.view(-1, x.shape[-1]) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 6d7579e..50e01ce 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -1168,7 +1168,7 @@ def scale_activations_mxfp4_torch(tensor: Tensor) -> Tuple[Tensor, Tensor]: return W_q, scales @torch.compile(fullgraph=True) -def scale_activations_nvfp4_torch(tensor: Tensor) -> Tuple[Tensor, Tensor]: +def scale_activations_nvfp4_torch(tensor: Tensor, meta_scale=None) -> Tuple[Tensor, Tensor]: group_size: int = 16 eps: float = 1e-6 max_val: float = 6 @@ -1187,7 +1187,7 @@ def scale_activations_nvfp4_torch(tensor: Tensor) -> Tuple[Tensor, Tensor]: W_flat = tensor.view(-1, group_size).float() scales = W_flat.abs().amax(dim=1, keepdim=True) scales /= max_val - meta_scales = scales.max().clamp_(min=eps) + meta_scales = meta_scale if meta_scale is not None else scales.max().clamp_(min=eps) scales /= meta_scales scales = scales.clamp(max=max_fp8).to(fp8_dtype).to(W_flat.dtype) @@ -1435,11 +1435,11 @@ def scale_activations_nvfp4_triton_kernel( tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales) -def scale_activations_nvfp4_triton(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def scale_activations_nvfp4_triton(tensor: torch.Tensor, meta_scale=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: group_size: int = 16 eps: float = 1e-6 fp8_dtype = torch.float8_e4m3fn #Nvidia only - meta_scale = (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) + meta_scale = meta_scale if meta_scale is not None else (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) @@ -1680,11 +1680,11 @@ def scale_activations_nvfp4_triton_kernel_v2( out_bp = tl.advance(out_bp, (0, HALF_BLOCK_K)) -def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor, meta_scale=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: group_size: int = 16 eps: float = 1e-6 fp8_dtype = torch.float8_e4m3fn - meta_scale = (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) + meta_scale = meta_scale if meta_scale is not None else (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) @@ -1920,11 +1920,11 @@ def scale_activations_nvfp4_triton_kernel_v3( tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales) -def scale_activations_nvfp4_triton_v3(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def scale_activations_nvfp4_triton_v3(tensor: torch.Tensor, meta_scale=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: group_size: int = 16 eps: float = 1e-6 fp8_dtype = torch.float8_e4m3fn - meta_scale = (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) + meta_scale = meta_scale if meta_scale is not None else (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) @@ -2242,12 +2242,12 @@ def scale_activations_nvfp4_triton_kernel_v5( ) -def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor, meta_scale=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: group_size: int = 16 eps: float = 1e-6 fp8_dtype = torch.float8_e4m3fn # Compute per-tensor meta_scale from activation data - meta_scale = (tensor.view(-1, group_size).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) + meta_scale = meta_scale if meta_scale is not None else (tensor.view(-1, group_size).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) diff --git a/gemlite/triton_kernels/gemm_kernels.py b/gemlite/triton_kernels/gemm_kernels.py index 3abdefe..d0e4ab1 100755 --- a/gemlite/triton_kernels/gemm_kernels.py +++ b/gemlite/triton_kernels/gemm_kernels.py @@ -622,6 +622,7 @@ def gemm_MX_kernel( scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) #scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + _meta_scale_norm = tl.load(meta_scale_norm_ptr, eviction_policy='evict_last') if group_size == 16 else 1.0 acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k, num_stages=NUM_STAGES): # Load A and B tiles @@ -669,7 +670,7 @@ def gemm_MX_kernel( #NVFP4 meta-scale if(group_size == 16): - acc = acc.to(tl.float32) * tl.load(meta_scale_norm_ptr, eviction_policy='evict_last') + acc = acc.to(tl.float32) * _meta_scale_norm ############################################################################################################# #Channel-wise scaling diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 93cc4cc..125fa57 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -660,6 +660,7 @@ def gemm_splitK_MX_kernel( scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + _meta_scale_norm = tl.load(meta_scale_norm_ptr, eviction_policy='evict_last') if group_size == 16 else 1.0 acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k): if use_tma: @@ -709,7 +710,7 @@ def gemm_splitK_MX_kernel( #NVFP4 meta-scale if(group_size == 16): - acc = acc.to(tl.float32) * tl.load(meta_scale_norm_ptr, eviction_policy='evict_last') + acc = acc.to(tl.float32) * _meta_scale_norm ############################################################################################################# #Channel-wise scaling From 49872c3fa2c417230cd0782fcd140d40f9aa24c0 Mon Sep 17 00:00:00 2001 From: mobicham Date: Sat, 21 Mar 2026 12:42:14 -0700 Subject: [PATCH 62/63] add fast nvfp4 mode function --- gemlite/__init__.py | 1 + gemlite/core.py | 22 ++++++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/gemlite/__init__.py b/gemlite/__init__.py index 3c2e654..046a3a4 100755 --- a/gemlite/__init__.py +++ b/gemlite/__init__.py @@ -15,6 +15,7 @@ enable_tma, set_ptx_fp4_pack, enable_cudagraph_autotune, + set_fast_nvfp4, forward_functional, ) diff --git a/gemlite/core.py b/gemlite/core.py index def0740..d777a29 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -67,7 +67,9 @@ GEMLITE_TRITON_CONFIG_CACHE = {} #Global config cache for all the kernels _GROUP_SIZE_WARNED = False GEMLITE_USE_TMA = True # Set to False for faster MXFP8 on sm_120 -GEMLITE_ENABLE_PTX_FP4_PACK = False # Set to True for hardware e2m1x2 FP4 packing (requires CUDA 13.0+ ptxas) +GEMLITE_ENABLE_PTX_FP4_PACK = False # Set to True for hardware e2m1x2 FP4 packing (requires CUDA 13.0+ ptxas) +GEMLITE_FAST_NVFP4 = False +GEMLITE_NVFP4_META_SCALES = [] # Pre-allocated per-GPU meta_scale tensors ################################################################################### #Utils @@ -111,6 +113,17 @@ def set_ptx_fp4_pack(enabled: bool = True): set_ptx_fp4_pack_flag(enabled) #Enable/disable CUDA graph-based autotuning (more accurate but slower) +#Enable/disable fast NVFP4 mode (pre-allocated static meta_scale, skips dynamic computation) +def set_fast_nvfp4(enabled: bool = True, default_value: float = 0.05): + global GEMLITE_FAST_NVFP4, GEMLITE_NVFP4_META_SCALES + GEMLITE_FAST_NVFP4 = enabled + if enabled and len(GEMLITE_NVFP4_META_SCALES) == 0: + num_gpus = torch.cuda.device_count() + GEMLITE_NVFP4_META_SCALES = [ + torch.full((1,), fill_value=default_value, device=f"cuda:{i}", dtype=torch.float32) + for i in range(num_gpus) + ] + def enable_cudagraph_autotune(enabled: bool = True): set_autotune("fast", use_cuda_graph=enabled) @@ -193,9 +206,10 @@ def forward_functional( x, scales_x = scale_activations_mxfp4(x) elif(input_dtype in [DType.NVFP4] and channel_scale_mode == 4): #NVPF4: TODO - meta_scale_w = tensor_args[3] - x, scales_x, meta_scale_a = scale_activations_nvfp4(x) - meta_scale = meta_scale_w * meta_scale_a # combine weight and activation meta_scales + meta_scale = tensor_args[3] + _static_meta = GEMLITE_NVFP4_META_SCALES[x.device.index] if GEMLITE_FAST_NVFP4 else None + x, scales_x, meta_scale_a = scale_activations_nvfp4(x, meta_scale=_static_meta) + meta_scale = meta_scale * meta_scale_a # combine weight and activation meta_scales x = x.view(-1, x.shape[-1]) From 2e597ae8b706ff8a797f267e9b23d1257cd33dc5 Mon Sep 17 00:00:00 2001 From: mobicham Date: Sat, 21 Mar 2026 13:36:11 -0700 Subject: [PATCH 63/63] nvfp4 fused kernel --- gemlite/quant_utils.py | 232 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 218 insertions(+), 14 deletions(-) diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 50e01ce..a2b230a 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -2121,6 +2121,183 @@ def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: return out, scales +#################################################################################################################### +# Pre-allocated per-device buffers for dynamic NVFP4 meta_scale computation +_nvfp4_meta_scale_bufs = [] # meta_scale output (float32 scalar) +_nvfp4_amax_bufs = [] # atomic max scratch (float32 scalar) +_nvfp4_counter_bufs = [] # grid sync counter (int32 scalar) + +def _get_nvfp4_bufs(device_index): + """Get or create pre-allocated buffers for the given device.""" + global _nvfp4_meta_scale_bufs, _nvfp4_amax_bufs, _nvfp4_counter_bufs + for buf_list in [_nvfp4_meta_scale_bufs, _nvfp4_amax_bufs, _nvfp4_counter_bufs]: + while len(buf_list) <= device_index: + buf_list.append(None) + if _nvfp4_meta_scale_bufs[device_index] is None: + dev = f"cuda:{device_index}" + _nvfp4_meta_scale_bufs[device_index] = torch.zeros(1, device=dev, dtype=torch.float32) + _nvfp4_amax_bufs[device_index] = torch.zeros(1, device=dev, dtype=torch.float32) + _nvfp4_counter_bufs[device_index] = torch.zeros(1, device=dev, dtype=torch.int32) + return _nvfp4_meta_scale_bufs[device_index], _nvfp4_amax_bufs[device_index], _nvfp4_counter_bufs[device_index] + +#################################################################################################################### +# Fused persistent NVFP4 v6: Single-kernel amax + quantize +# Phase 1: all blocks compute tile amax, atomicMax to global, grid barrier +# Phase 2: all blocks quantize tiles using computed meta_scale +# Grid limited to num_SMs so all blocks run concurrently (spin-wait safe) +#################################################################################################################### +@triton.jit +def scale_activations_nvfp4_fused_kernel_v6( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, M_padded, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps: tl.constexpr, + GROUP_SIZE: tl.constexpr, + meta_scales_ptr, # output: computed meta_scale + amax_ptr, # scratch: atomic max accumulator + counter_ptr, # scratch: grid sync counter + num_tiles_m, num_tiles_k, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + ptx_pack: tl.constexpr = False, +): + pid = tl.program_id(0) + num_pids = tl.num_programs(0) + total_tiles = num_tiles_m * num_tiles_k + + fp8_dtype: tl.constexpr = tl.float8e4nv + max_fp8: tl.constexpr = 448. + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + # Load thresholds once + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + + # ---- Phase 1: Compute amax across all tiles ---- + local_amax = tl.full((1,), value=0.0, dtype=tl.float32) + for tile_idx in range(pid, total_tiles, num_pids): + tile_m = tile_idx // num_tiles_k + tile_k = tile_idx % num_tiles_k + + offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tile_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + tile_max = tl.max(tl.abs(tensor)) + local_amax = tl.maximum(local_amax, tile_max) + + # Atomic max to global (release: ensures atomicMax is visible before counter increment) + tl.atomic_max(amax_ptr, tl.max(local_amax, axis=0), sem='relaxed') + + # Grid barrier: last block computes meta_scale and signals + # acq_rel: acquires all prior releases (sees all other blocks' atomicMax) + old_count = tl.atomic_add(counter_ptr, 1, sem='relaxed') + if old_count == num_pids - 1: + final_amax = tl.load(amax_ptr) + tl.store(meta_scales_ptr, tl.maximum(final_amax / 6.0, eps)) + # Reset scratch for next call + tl.store(amax_ptr, 0.0) + # Signal ready by setting counter to -num_pids (distinguishable from 0..num_pids-1) + tl.store(counter_ptr, -1) + + # Spin-wait for ready signal (safe: grid <= num_SMs, all blocks run concurrently) + while tl.atomic_add(counter_ptr, 0, sem='relaxed') >= 0: + pass + + # ---- Phase 2: Quantize using computed meta_scale ---- + meta_scales = tl.load(meta_scales_ptr) + + for tile_idx in range(pid, total_tiles, num_pids): + tile_m = tile_idx // num_tiles_k + tile_k = tile_idx % num_tiles_k + + offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tile_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + # Reload tile (L2 cached from Phase 1) + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) + scales_raw = abs_max / (6. * meta_scales) + scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype) + scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps) + + wq = tensor_flat / scales_full + + if ptx_pack: + wq_2d = tl.reshape(wq, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + wq_pairs = wq_2d.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False) + lo_val, hi_val = tl.split(wq_pairs) + lo_f16 = lo_val.to(tl.float16) + hi_f16 = hi_val.to(tl.float16) + lo_bits = lo_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF + hi_bits = (hi_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF) << 16 + packed_f16x2 = lo_bits | hi_bits + packed_e2m1 = tl.inline_asm_elementwise( + asm=""" + { + .reg .b8 tmp_out; + .reg .f16x2 tmp_in; + mov.b32 tmp_in, $1; + cvt.rn.satfinite.e2m1x2.f16x2 tmp_out, tmp_in; + cvt.u32.u8 $0, tmp_out; + } + """, + constraints="=r,r", + args=[packed_f16x2], + dtype=tl.int32, + is_pure=True, + pack=1, + ) + out = packed_e2m1.to(tl.uint8) + else: + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + # Store quantized output + offs_k_out = tile_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K) + out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1) + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask) + + # Store scales + scales_2d = tl.reshape(scales_fp8, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + scales_2d = tl.where(offs_m[:, None] < M, scales_2d, tl.full(scales_2d.shape, 1.0, dtype=tl.float32).to(fp8_dtype)) + base_group = tile_k * GROUPS_PER_BLOCK + offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=(offs_m[:, None] < M_padded) & g_mask[None, :] + ) + + # Last block resets counter for next call + if old_count == num_pids - 1: + tl.store(counter_ptr, 0) + #################################################################################################################### # NVFP4 v5: 2D grid with multi-group BLOCK_SIZE_K (fewer blocks, better bandwidth) #################################################################################################################### @@ -2246,8 +2423,6 @@ def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor, meta_scale=None) -> group_size: int = 16 eps: float = 1e-6 fp8_dtype = torch.float8_e4m3fn - # Compute per-tensor meta_scale from activation data - meta_scale = meta_scale if meta_scale is not None else (tensor.view(-1, group_size).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) @@ -2258,20 +2433,49 @@ def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor, meta_scale=None) -> out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) - - grid = lambda meta: (triton.cdiv(M_padded, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) device_index = tensor.device.index - scale_activations_nvfp4_triton_kernel_v5[grid]( - tensor, out, scales, thr_pos[device_index], - M, M_padded, K, - tensor.stride(0), tensor.stride(1), - scales.stride(0), scales.stride(1), - out.stride(0), out.stride(1), - eps=eps, - GROUP_SIZE=group_size, - meta_scales_ptr=meta_scale, - ) + if meta_scale is None: + # Fused path: single kernel computes amax + quantizes + meta_scale, amax_buf, counter_buf = _get_nvfp4_bufs(device_index) + BLOCK_M = 16 + BLOCK_K = 256 + num_tiles_m = triton.cdiv(M_padded, BLOCK_M) + num_tiles_k = triton.cdiv(K, BLOCK_K) + total_tiles = num_tiles_m * num_tiles_k + num_SMs = torch.cuda.get_device_properties(device_index).multi_processor_count + num_blocks = min(total_tiles, num_SMs) + + scale_activations_nvfp4_fused_kernel_v6[(num_blocks,)]( + tensor, out, scales, thr_pos[device_index], + M, M_padded, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps=eps, + GROUP_SIZE=group_size, + meta_scales_ptr=meta_scale, + amax_ptr=amax_buf, + counter_ptr=counter_buf, + num_tiles_m=num_tiles_m, + num_tiles_k=num_tiles_k, + BLOCK_SIZE_M=BLOCK_M, + BLOCK_SIZE_K=BLOCK_K, + ) + else: + # Static path: meta_scale already provided, use v5 kernel directly + grid = lambda meta: (triton.cdiv(M_padded, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) + scale_activations_nvfp4_triton_kernel_v5[grid]( + tensor, out, scales, thr_pos[device_index], + M, M_padded, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps=eps, + GROUP_SIZE=group_size, + meta_scales_ptr=meta_scale, + ) + return out, scales, meta_scale