diff --git a/examples/eval_flops.py b/examples/eval_flops.py new file mode 100644 index 0000000..78a52bc --- /dev/null +++ b/examples/eval_flops.py @@ -0,0 +1,451 @@ +import torch +import time, gc +import gemlite +from gemlite.helper import * +import argparse +import torch._dynamo +torch._dynamo.config.recompile_limit = 256 +import torch._inductor.config as _inductor_config +import triton + +device, dtype = 'cuda:0', torch.bfloat16 +repeat = 32 + +gemlite.reset_config() +gemlite.enable_cudagraph_autotune(True) +gemlite.enable_tma(True) +#gemlite.set_ptx_fp4_pack(True) +#gemlite.set_autotune("max") +#gemlite.core.enable_activation_scaling(2) + +def get_model(K, N, repeat=repeat): + torch.manual_seed(0) + model = torch.nn.Sequential(*[ + torch.nn.Linear(N, K, dtype=dtype, device=device, bias=False) + for _ in range(repeat) + ]) + model.requires_grad_(False) + return model + + +@torch.no_grad() +def eval_model(model, M, K, iters=50, verbose=False): + torch.manual_seed(0) + t = [] + for i in range(iters): + x = torch.randn(M, K, dtype=dtype, device=device) + torch.cuda.synchronize() + t1 = time.perf_counter() + out = model(x) + torch.cuda.synchronize() + t2 = time.perf_counter() + _time = (t2 - t1) * 1000 + t.append(_time) + if verbose: + print(f"Took: {_time} ms") + t = t[-(iters // 2):] + time_torch = (sum(t) / len(t)) + return time_torch + + +def get_flops(M, K, N, perf_time_ms): + flops_per_linear = 2 * M * N * K + tflops = flops_per_linear / (perf_time_ms * 1e-3) / 1e12 + return tflops + + +def cleanup(model): + del model + torch.cuda.empty_cache() + gc.collect() + torch.cuda.empty_cache() + + +########################################################################################################################### +# Pytorch INT8 dynamic reference +########################################################################################################################### +class NativePyTorchINT8Dynamic(torch.nn.Module): + def __init__(self, linear_layer): + 
super().__init__() + w_fp16 = linear_layer.weight.data + self.w_scales = w_fp16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-5) / 127.0 + w_int8 = torch.round(w_fp16 / self.w_scales).to(torch.int8) + self.w_int8 = w_int8.contiguous() + self.w_scales = self.w_scales.view(1, -1) + + def forward(self, x): + x_scales = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5) / 127.0 + x_int8 = torch.round(x / x_scales).to(torch.int8) + out_int32 = torch._int_mm(x_int8, self.w_int8.t()) + return out_int32.to(x.dtype) * (x_scales * self.w_scales) + + +def patch_model_native_int8(model): + for i, layer in enumerate(model): + if isinstance(layer, torch.nn.Linear): + model[i] = NativePyTorchINT8Dynamic(layer) + + +########################################################################################################################### +# Pytorch FP8 dynamic reference +########################################################################################################################### +def _to_fp8_and_inv_scale( + x: torch.Tensor, + fp8_dtype: torch.dtype, + dim: int | tuple[int, ...] 
| None, + keepdim: bool, + clamp_min: float = 1e-12, +): + finfo = torch.finfo(fp8_dtype) + x_fp32 = x.float() + if dim is None: + amax = x_fp32.abs().amax().clamp(min=clamp_min) + else: + amax = x_fp32.abs().amax(dim=dim, keepdim=keepdim).clamp(min=clamp_min) + + scale_gain = (finfo.max / amax) + x_scaled_sat = (x_fp32 * scale_gain).clamp(min=finfo.min, max=finfo.max) + x_fp8 = x_scaled_sat.to(fp8_dtype) + inv_scale = scale_gain.reciprocal().to(torch.float32) + return x_fp8, inv_scale + + +class NativePyTorchFP8Dynamic(torch.nn.Module): + def __init__( + self, + linear_layer: torch.nn.Linear, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + use_fast_accum: bool = False, + ): + super().__init__() + self.fp8_dtype = fp8_dtype + self.use_fast_accum = use_fast_accum + + w_hp = linear_layer.weight.data + w_fp8, w_inv_scale_row = _to_fp8_and_inv_scale(w_hp, fp8_dtype=fp8_dtype, dim=1, keepdim=True) + self.register_buffer("w_fp8", w_fp8.contiguous().t()) + self.register_buffer("w_inv_scale", w_inv_scale_row.view(1, -1).contiguous()) + + if linear_layer.bias is not None: + self.register_buffer("bias", linear_layer.bias.data.contiguous()) + else: + self.bias = None + + def forward(self, x: torch.Tensor): + x_fp8, x_inv_scale = _to_fp8_and_inv_scale(x, fp8_dtype=self.fp8_dtype, dim=-1, keepdim=True) + out = torch._scaled_mm( + x_fp8, + self.w_fp8, + scale_a=x_inv_scale, + scale_b=self.w_inv_scale, + bias=self.bias, + out_dtype=x.dtype, + use_fast_accum=self.use_fast_accum, + ) + if isinstance(out, tuple): + out = out[0] + return out + + +def patch_model_native_fp8(model, fp8_dtype=torch.float8_e4m3fn, use_fast_accum=False): + for i, layer in enumerate(model): + if isinstance(layer, torch.nn.Linear): + model[i] = NativePyTorchFP8Dynamic( + layer, fp8_dtype=fp8_dtype, use_fast_accum=use_fast_accum, + ) + + +########################################################################################################################### +# flashinfer NVFP4 reference (CUTLASS-based, 
supports sm_120) +########################################################################################################################### +def _get_flashinfer(): + """Check if flashinfer with NVFP4 support is available.""" + try: + from flashinfer import nvfp4_quantize, mm_fp4, SfLayout + return True, None + except ImportError: + return False, "flashinfer not installed (pip install flashinfer)" + + +# ---- custom_op wrappers for torch.compile compatibility ---- +@torch.library.custom_op("flashinfer_bench::nvfp4_quantize", mutates_args=()) +def _nvfp4_quantize_op( + a: torch.Tensor, a_global_sf: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + from flashinfer import nvfp4_quantize, SfLayout + a_fp4, a_sf = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) + return a_fp4, a_sf + + +@torch.library.register_fake("flashinfer_bench::nvfp4_quantize") +def _nvfp4_quantize_fake( + a: torch.Tensor, a_global_sf: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + M, K = a.shape + a_fp4 = torch.empty((M, K // 2), dtype=torch.uint8, device=a.device) + a_sf = torch.empty((M, K // 16), dtype=torch.uint8, device=a.device) + return a_fp4, a_sf + + +@torch.library.custom_op("flashinfer_bench::mm_fp4", mutates_args=()) +def _mm_fp4_op( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: torch.Tensor, + out_N: int, +) -> torch.Tensor: + from flashinfer import mm_fp4 + return mm_fp4(a, b, a_descale, b_descale, alpha, torch.bfloat16, backend="cutlass") + + +@torch.library.register_fake("flashinfer_bench::mm_fp4") +def _mm_fp4_fake( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: torch.Tensor, + out_N: int, +) -> torch.Tensor: + M = a.shape[0] + return torch.empty((M, out_N), dtype=torch.bfloat16, device=a.device) + + +class FlashinferNVFP4Dynamic(torch.nn.Module): + """ + NVFP4 dynamic quantization using flashinfer CUTLASS backend. 
+ Weights quantized offline in __init__; activations quantized on-the-fly in forward. + Compatible with torch.compile via custom_op wrappers. + """ + + def __init__(self, linear_layer: torch.nn.Linear): + super().__init__() + from flashinfer import nvfp4_quantize, SfLayout + + w_bf16 = linear_layer.weight.data # [N, K] + N, K = w_bf16.shape + + # Quantize weights offline + w_global_sf = (448.0 * 6.0) / w_bf16.float().abs().nan_to_num().amax().clamp(min=1e-12) + w_fp4, w_sf = nvfp4_quantize( + w_bf16, w_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False + ) + + # Store pre-transposed for mm_fp4: b=[K//2, N], b_descale=[K//16, N] + self.register_buffer("w_fp4_t", w_fp4.T.contiguous()) + self.register_buffer("w_sf_t", w_sf.T.contiguous()) + self.register_buffer( + "w_global_sf_inv", + (1.0 / w_global_sf).to(torch.float32).contiguous(), + ) + self.N = N + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Activation quantization: compute global scale with pytorch ops + x_global_sf = (448.0 * 6.0) / x.float().abs().nan_to_num().amax().clamp(min=1e-12) + + # Quantize activation via custom_op (flashinfer CUDA kernel) + x_fp4, x_sf = torch.ops.flashinfer_bench.nvfp4_quantize(x, x_global_sf) + + # alpha = 1 / (x_global_sf * w_global_sf) + alpha = self.w_global_sf_inv / x_global_sf + + # CUTLASS FP4 matmul via custom_op + return torch.ops.flashinfer_bench.mm_fp4( + x_fp4, self.w_fp4_t, x_sf, self.w_sf_t, alpha, self.N + ) + + +def patch_model_flashinfer_nvfp4(model): + for i, layer in enumerate(model): + if isinstance(layer, torch.nn.Linear): + model[i] = FlashinferNVFP4Dynamic(layer) + + + +########################################################################################################################### +def run_benchmark(proc_name, M, K, N): + """ + Unified benchmark runner. Returns (label, M, K, N, tflops) or None on skip. + Handles gemlite processors, native PyTorch INT8/FP8, and flashinfer NVFP4. 
+ """ + has_flashinfer, fi_err = _get_flashinfer() + + # ---- flashinfer NVFP4 dynamic (torch.compile + activation quant) ---- + if proc_name == "flashinfer_nvfp4_dynamic": + if not has_flashinfer: + print(f" Skipping {proc_name}: {fi_err}") + return None + # Disable cudagraph trees: flashinfer CUTLASS does internal workspace allocs + old_cudagraph = _inductor_config.triton.cudagraph_trees + _inductor_config.triton.cudagraph_trees = False + + # NOTE: flashinfer's CUTLASS NVFP4 kernel requires M to be a multiple of 128. + # When M < 128, we pad M up to 128 so the kernel doesn't crash. The TFLOP/s + # are computed using the padded M to keep the comparison fair (same actual work). + M_padded = max(M, 128) + M_padded = ((M_padded + 127) // 128) * 128 + + model = get_model(K, N, repeat=repeat) + patch_model_flashinfer_nvfp4(model) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M_padded, K) / repeat + tflops = get_flops(M, K, N, perf_time_ms) + label = "flashinfer NVFP4 (dynamic)" + if M_padded != M: + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s (M padded to {M_padded} internally)") + else: + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + + cleanup(model) + _inductor_config.triton.cudagraph_trees = old_cudagraph + return (label, M, K, N, tflops) + + # ---- Native PyTorch INT8 dynamic ---- + if proc_name == "native_int8": + if M <= 16: + print(f" Skipping native_int8 for M={M} (requires M > 16)") + return None + model = get_model(K, N, repeat=repeat) + patch_model_native_int8(model) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M, K) / repeat + tflops = get_flops(M, K, N, perf_time_ms) + label = "PyTorch Native INT8" + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + + cleanup(model) + return (label, M, K, N, tflops) + + # ---- Native PyTorch FP8 dynamic ---- + if proc_name == "native_fp8": + model = get_model(K, N, 
repeat=repeat) + patch_model_native_fp8(model, fp8_dtype=torch.float8_e4m3fn, use_fast_accum=False) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M, K) / repeat + tflops = get_flops(M, K, N, perf_time_ms) + label = "PyTorch Native FP8" + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + + cleanup(model) + return (label, M, K, N, tflops) + + # ---- GemLite processors + BF16 baseline ---- + GEMLITE_MAP = { + "A16W8_INT8": lambda: A16W8_INT8(), + "A16W8_FP8": lambda: A16W8_FP8(), + "A16W4_HQQ_INT": lambda: A16W4_HQQ_INT(), + "A8W8_INT8_dynamic": lambda: A8W8_INT8_dynamic(), + "A8W8_FP8_dynamic": lambda: A8W8_FP8_dynamic(), + "A8W8_MXFP_dynamic_post_scale": lambda: A8W8_MXFP_dynamic(dtype=dtype, post_scale=True), + "A8W8_MXFP_dynamic": lambda: A8W8_MXFP_dynamic(dtype=dtype, post_scale=False), + "A4W4_MXFP_dynamic": lambda: A4W4_MXFP_dynamic(dtype=dtype), + "A4W4_NVFP_dynamic": lambda: A4W4_NVFP_dynamic(dtype=dtype), + "none": lambda: None, + "fp16": lambda: None, + } + + if proc_name not in GEMLITE_MAP: + print(f" Unknown processor: {proc_name}, skipping.") + return None + + processor = GEMLITE_MAP[proc_name]() + + model = get_model(K, N, repeat=repeat) + if processor is not None: + patch_model(model, device=device, processor=processor) + model = torch.compile(model, mode="reduce-overhead", fullgraph=True) + + perf_time_ms = eval_model(model, M, K) / repeat + label = proc_name if processor is not None else "BF16 (no processor)" + tflops = get_flops(M, K, N, perf_time_ms) + print(f" {label} | {M}, {K}, {N} | {tflops:.2f} TFLOP/s") + + cleanup(model) + return (label, M, K, N, tflops) + + +ALL_PROCESSORS = [ + "none", + "A16W8_INT8", + "A16W8_FP8", + "A16W4_HQQ_INT", + "A8W8_INT8_dynamic", + "A8W8_FP8_dynamic", + "A8W8_MXFP_dynamic_post_scale", + "A8W8_MXFP_dynamic", + "A4W4_MXFP_dynamic", + "A4W4_NVFP_dynamic", + "native_int8", + "native_fp8", + "flashinfer_nvfp4_dynamic", +] + + +def main(): + parser = 
argparse.ArgumentParser( + description="Evaluate TFLOP/s for various quantized matmul processors.", + epilog=""" +Examples: + # Run with default parameters (all processors) + python eval_flops.py + + # Run with specific dimensions: + python eval_flops.py --M 8192 --K 8192 --N 8192 + + # Run only specific processors (comma-separated): + python eval_flops.py --processor A4W4_MXFP_dynamic,flashinfer_nvfp4_dynamic,native_fp8 + + # Run only BF16 baseline (no quantization): + python eval_flops.py --processor none + + # Available processors: + # GemLite: A16W8_INT8, A16W8_FP8, A16W4_HQQ_INT, + # A8W8_INT8_dynamic, A8W8_FP8_dynamic, + # A8W8_MXFP_dynamic_post_scale, A8W8_MXFP_dynamic, + # A4W4_MXFP_dynamic, A4W4_NVFP_dynamic + # PyTorch: native_int8, native_fp8 + # flashinfer: flashinfer_nvfp4_dynamic + # Baseline: none / fp16 (BF16, no quantization) + # Use "all" to run every processor. + """, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--M", type=int, default=8192, help="Batch/sequence dimension") + parser.add_argument("--K", type=int, default=8192, help="Input feature dimension") + parser.add_argument("--N", type=int, default=8192, help="Output feature dimension") + parser.add_argument("--processor", type=str, default="all", + help='Comma-separated processor names or "all" (default: all)') + args = parser.parse_args() + + M, K, N = args.M, args.K, args.N + + if args.processor == "all": + processor_names = list(ALL_PROCESSORS) + else: + processor_names = [p.strip() for p in args.processor.split(",")] + + results = [] + for proc_name in processor_names: + result = run_benchmark(proc_name, M, K, N) + if result is not None: + results.append(result) + + # ---- Summary ---- + print("\n" + "=" * 70) + gpu_name = torch.cuda.get_device_name(device) + print(f"SUMMARY (GPU: {gpu_name})") + print("=" * 70) + max_label_len = max(len(r[0]) for r in results) if results else 0 + for label, m, k, n, tflops in results: + print(f" 
{label:<{max_label_len}} | {m}, {k}, {n} | {tflops:.2f} TFLOP/s") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/gemlite/__init__.py b/gemlite/__init__.py index ca6d3b5..046a3a4 100755 --- a/gemlite/__init__.py +++ b/gemlite/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.1.post1" +__version__ = "0.6.0" __author__ = 'Dr. Hicham Badri' __credits__ = 'Mobius Labs GmbH' @@ -12,6 +12,10 @@ set_acc_dtype, set_autotune, set_kernel_caching, + enable_tma, + set_ptx_fp4_pack, + enable_cudagraph_autotune, + set_fast_nvfp4, forward_functional, ) diff --git a/gemlite/core.py b/gemlite/core.py index 30fa424..d777a29 100755 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -66,6 +66,10 @@ GEMLITE_MATMUL_TYPES_MAPPING = {GEMLITE_MATMUL_TYPES[i]: i for i in range(len(GEMLITE_MATMUL_TYPES))} GEMLITE_TRITON_CONFIG_CACHE = {} #Global config cache for all the kernels _GROUP_SIZE_WARNED = False +GEMLITE_USE_TMA = True # Set to False for faster MXFP8 on sm_120 +GEMLITE_ENABLE_PTX_FP4_PACK = False # Set to True for hardware e2m1x2 FP4 packing (requires CUDA 13.0+ ptxas) +GEMLITE_FAST_NVFP4 = False +GEMLITE_NVFP4_META_SCALES = [] # Pre-allocated per-GPU meta_scale tensors ################################################################################### #Utils @@ -96,19 +100,48 @@ def set_acc_dtype(dtype): assert dtype in [DType.FP16, DType.FP32], "Invalid dtype (should be DType.FP16 or DType.FP32)." 
GEMLITE_ACC_DTYPE[DType.FP16] = dtype +#Enable/disable TMA for MX kernel data loading +def enable_tma(enabled: bool = True): + global GEMLITE_USE_TMA + GEMLITE_USE_TMA = enabled + +#Enable/disable hardware PTX FP4 packing in activation quantization (requires CUDA 13.0+ ptxas) +def set_ptx_fp4_pack(enabled: bool = True): + global GEMLITE_ENABLE_PTX_FP4_PACK + GEMLITE_ENABLE_PTX_FP4_PACK = enabled + from .quant_utils import set_ptx_fp4_pack_flag + set_ptx_fp4_pack_flag(enabled) + +#Enable/disable CUDA graph-based autotuning (more accurate but slower) +#Enable/disable fast NVFP4 mode (pre-allocated static meta_scale, skips dynamic computation) +def set_fast_nvfp4(enabled: bool = True, default_value: float = 0.05): + global GEMLITE_FAST_NVFP4, GEMLITE_NVFP4_META_SCALES + GEMLITE_FAST_NVFP4 = enabled + if enabled and len(GEMLITE_NVFP4_META_SCALES) == 0: + num_gpus = torch.cuda.device_count() + GEMLITE_NVFP4_META_SCALES = [ + torch.full((1,), fill_value=default_value, device=f"cuda:{i}", dtype=torch.float32) + for i in range(num_gpus) + ] + +def enable_cudagraph_autotune(enabled: bool = True): + set_autotune("fast", use_cuda_graph=enabled) + #Return the default gemv kernel to use for M==1 def get_default_gemv(W_nbits: int, mx_dtype: bool = False) -> str: #TODO: adapt mx for IS_HIP = True if mx_dtype: - return 'GEMM_SPLITK' #TODO: fix mxf bugs in GEMV outputs garbage. + return 'GEMM_SPLITK' #TODO:'GEMV' if (W_nbits < 8) else 'GEMM_SPLITK' -> Revisit NVFP4 failing test. 
else: return 'GEMV_REVSPLITK' if (W_nbits < 8) else 'GEMV_SPLITK' #matmul type selection logic def get_matmul_type(batch_size: int, W_nbits: int, mx_dtype: bool = False): - if batch_size > 64: + gemm_limit = 64 + if batch_size >= gemm_limit: return "GEMM" - if batch_size > 1: + gemv_limit = 4 if (W_nbits < 8 and not mx_dtype) else 2 # previous 1 + if batch_size > gemv_limit: return "GEMM_SPLITK" else: return get_default_gemv(W_nbits, mx_dtype) @@ -121,7 +154,7 @@ def enable_activation_scaling(batch_size): Only works with the MXFP format - use with A8W4_MXFP/A4W4_MXFP. """ return True - #return batch_size >= 32 + #return batch_size >= 2 #TODO: Needs Triton fix https://github.com/triton-lang/triton/pull/9577 #Main functional forward call @@ -155,6 +188,7 @@ def forward_functional( scaled_activations = bool(meta_args[0]) and enable_activation_scaling(batch_size) #Dynamic activation quantization scales_x = None + meta_scale = 0.0 if(scaled_activations): input_dtype = DType(meta_args[5]) channel_scale_mode = meta_args[9] @@ -172,7 +206,10 @@ def forward_functional( x, scales_x = scale_activations_mxfp4(x) elif(input_dtype in [DType.NVFP4] and channel_scale_mode == 4): #NVPF4: TODO - x, scales_x = scale_activations_nvfp4(x) + meta_scale = tensor_args[3] + _static_meta = GEMLITE_NVFP4_META_SCALES[x.device.index] if GEMLITE_FAST_NVFP4 else None + x, scales_x, meta_scale_a = scale_activations_nvfp4(x, meta_scale=_static_meta) + meta_scale = meta_scale * meta_scale_a # combine weight and activation meta_scales x = x.view(-1, x.shape[-1]) @@ -184,7 +221,7 @@ def forward_functional( out = ( GEMLITE_TRITON_MAPPING[matmul_type_str] .forward( - x, *tensor_args, scales_x, *meta_args[1:-1], data_contiguous, type_id + x, *tensor_args[:3], scales_x, *meta_args[1:-1], data_contiguous, type_id, meta_scale=meta_scale ) .view(out_shape) ) @@ -252,7 +289,7 @@ def __init__( if in_features is not None and out_features is not None: if (in_features % GemLiteLinearTriton.MIN_SIZE != 0) or ( - 
in_features % group_size != 0 if (group_size is not None) else False + (in_features % group_size != 0) if (group_size is not None and W_nbits < 16) else False ): raise NotImplementedError( "Invalid input shapes: " @@ -298,6 +335,19 @@ def __init__( #Default forward self.forward = self.forward_auto_no_warmup + #Meta-scale for NVFP4 (0.0 = not used) + self.meta_scale = 0.0 + + def _save_to_state_dict(self, destination, prefix, keep_vars): + # Rebuild metadata from live attributes to ensure consistency + # (helpers may override channel_scale_mode/W_group_mode after pack()) + if hasattr(self, 'metadata') and self.metadata is not None and hasattr(self, 'W_nbits'): + self.metadata = torch.nn.Parameter( + torch.tensor(self.get_meta_args(), device=self.metadata.device, dtype=torch.int32), + requires_grad=False, + ) + super()._save_to_state_dict(destination, prefix, keep_vars) + def load_state_dict(self, state_dict, strict=True, assign=False): self.W_q = state_dict.pop("W_q", None) self.bias = state_dict.pop("bias", None) @@ -305,6 +355,7 @@ def load_state_dict(self, state_dict, strict=True, assign=False): self.zeros = state_dict.pop("zeros", None) self.metadata = state_dict.pop("metadata", None) self.orig_shape = state_dict.pop("orig_shape", None) + _meta_scale = state_dict.pop("meta_scale", None) self.metadata = [v.item() for v in self.metadata] self.orig_shape = (v.item() for v in self.orig_shape) @@ -327,10 +378,27 @@ def load_state_dict(self, state_dict, strict=True, assign=False): self.acc_dtype = DType(self.acc_dtype) self.meta_dtype = DType(self.meta_dtype) + # Restore meta_scale with backward compat for old checkpoints + if _meta_scale is not None: + self.meta_scale = _meta_scale.float() + else: + self.meta_scale = 0.05 if self.input_dtype == DType.NVFP4 else 0.0 # backward compat default for old checkpoints + self.out_features, self.in_features = self.orig_shape self.compute_dtype = DTYPE_TO_TORCH[self.input_dtype.value] self.scaled_activations = 
bool(self.scaled_activations) self.data_contiguous = bool(self.data_contiguous) + + # Backward compat: pop stale scales_5d from old saves + state_dict.pop("scales_5d", None) + # Convert 2D scales to 5D TMA layout for MX dtypes + if is_mx_dtype(self.input_dtype) and self.scales is not None: + s = self.scales.data if isinstance(self.scales, torch.nn.Parameter) else self.scales + if s.ndim == 2: + s_2d = s.T.contiguous() # [K_S, N] contiguous + N_dim, K_S = s_2d.shape[1], s_2d.shape[0] + if GEMLITE_USE_TMA and self.elements_per_sample > 1 and N_dim % 128 == 0 and K_S % 4 == 0: + self.scales = s_2d.reshape(N_dim // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N_dim // 128, K_S // 4, 2, 256).contiguous() #Make sure to feed UINT8 W_q for packing def pack( @@ -399,7 +467,7 @@ def pack( if(self.W_q is None): raise Exception('Weights were not packed, please check your W_q.dtype') - + #Bias / device self.device = self.W_q.device self.bias = None if (bias is None) else bias.to(device=self.device) @@ -492,9 +560,29 @@ def pack( if(self.input_dtype in [DType.NVFP4]): self.scales = self.scales.to(torch.float8_e4m3fn) if(is_mx_dtype(self.input_dtype)): - self.scales = self.scales.T self.W_group_mode = 2 self.channel_scale_mode = 0 + + ################################ + # TMA + K, N = self.W_q.shape + + if(self.input_dtype in [DType.MXFP4, DType.NVFP4]): + K *= 2 + group_size = 2 * self.W_q.numel() // self.scales.numel() + else: + group_size = self.W_q.numel() // self.scales.numel() + + # Preshuffle weight scales to 5D TMA layout for fast loading + # Original: [K_S, N] -> transpose to [N, K_S] -> 5D: [1, N//128, K_S//4, 2, 256] + K_S = K // group_size + if GEMLITE_USE_TMA and self.elements_per_sample > 1 and N % 128 == 0 and K_S % 4 == 0: + # Currently TMA only enabled for MXFP4/NVFP4 NOT for MXFP8 because of poor performance on sm_120 (self.elements_per_sample > 1 check) + self.scales = self.scales.T.contiguous().reshape(N // 128, 4, 32, K_S // 4, 4).permute(0, 3, 
2, 1, 4).reshape(1, N // 128, K_S // 4, 2, 256).contiguous() + else: + # Keep 2D transposed layout for pointer-based fallback + self.scales = self.scales.T + ################################ if(self.scales is not None): self.meta_dtype = TORCH_TO_DTYPE[self.scales.dtype] @@ -516,11 +604,20 @@ def pack( requires_grad=False, ) + + self.meta_scale = torch.nn.Parameter( + torch.tensor(self.meta_scale, device=self.device, dtype=torch.float32), + requires_grad=False, + ) + return self #Return the main arguments def get_tensor_args(self): - return [self.W_q, self.scales, self.zeros] + meta_scale = self.meta_scale + if not isinstance(meta_scale, torch.Tensor): + meta_scale = torch.tensor(meta_scale, dtype=torch.float32, device=self.W_q.device) + return [self.W_q, self.scales, self.zeros, meta_scale] def get_meta_args(self): return [int(self.scaled_activations), diff --git a/gemlite/helper.py b/gemlite/helper.py index 35a4054..213ab18 100755 --- a/gemlite/helper.py +++ b/gemlite/helper.py @@ -836,7 +836,7 @@ def from_weights(self, weight, bias=None, scales=None): assert weight.dtype in [torch.uint8], f"Invalid weight.dtype, should be an MXPF8 valid dtype, got {weight.dtype}." assert scales.dtype in [torch.float8_e8m0fnu, torch.uint8], f"Invalid scales.dtype, should be e8m0 / view(uint8), got {scales.dtype}." assert self.dtype is not None, f"Input dtype should be either torch.float16 or torch.bfloat16, not None." 
- assert self.group_size == 32, f"Only group_size=16 is supported for MXFP4, got {self.group_size}" + assert self.group_size == 32, f"Only group_size=32 is supported for MXFP4, got {self.group_size}" dtype = self.dtype gemlite_dtype = TORCH_TO_DTYPE[dtype] @@ -888,7 +888,7 @@ def __init__(self, device='cuda:0', dtype=None): self.group_size = 16 self.input_dtype = DType.NVFP4 - def from_weights(self, weight, bias=None, scales=None): + def from_weights(self, weight, bias=None, scales=None, meta_scale=None): if(isinstance(weight, torch.nn.Parameter)): weight = weight.data if(isinstance(bias, torch.nn.Parameter)): @@ -923,6 +923,11 @@ def from_weights(self, weight, bias=None, scales=None): gemlite_linear.pack(W_q, scales, zeros=None, bias=bias) gemlite_linear.W_group_mode = 0 gemlite_linear.channel_scale_mode = 4 + if meta_scale is not None: + gemlite_linear.meta_scale = torch.nn.Parameter( + meta_scale.to(dtype=torch.float32, device=gemlite_linear.W_q.device).reshape(()), + requires_grad=False, + ) return gemlite_linear @@ -933,11 +938,11 @@ def from_linear(self, linear_layer, del_orig=True): W = linear_layer.weight.data bias = linear_layer.bias.clone() if (linear_layer.bias is not None) else None N, K = W.shape - W_q, scales = self.quantizer_mx.quantize_nvfp4(W, index=True) + W_q, scales, _meta_scale = self.quantizer_mx.quantize_nvfp4(W, index=True) W_q, scales = W_q.view([N, K]), scales.view(N, K // self.group_size) cleanup_linear(linear_layer, del_orig) - out_layer = self.from_weights(weight=W_q, scales=scales, bias=bias) + out_layer = self.from_weights(weight=W_q, scales=scales, bias=bias, meta_scale=_meta_scale) #Clean-uo del W_q diff --git a/gemlite/quant_utils.py b/gemlite/quant_utils.py index 96ec3ad..a2b230a 100644 --- a/gemlite/quant_utils.py +++ b/gemlite/quant_utils.py @@ -10,6 +10,11 @@ from .triton_kernels.utils import IS_HIP, get_num_SMs, next_power_of_2 from .dtypes import * +GEMLITE_ENABLE_PTX_FP4_PACK = False # Enable with CUDA13+ ptxas +def 
set_ptx_fp4_pack_flag(enabled: bool): + global GEMLITE_ENABLE_PTX_FP4_PACK + GEMLITE_ENABLE_PTX_FP4_PACK = enabled + #Get dtype min/max range based on compute dtype def get_dtype_range(compute_dtype: torch.dtype) -> float: if(compute_dtype.is_floating_point): @@ -18,7 +23,7 @@ def get_dtype_range(compute_dtype: torch.dtype) -> float: dtype_info = torch.iinfo(compute_dtype) return dtype_info.min, dtype_info.max -NVFP4_META_SCALE = 0.05 #Temporary NVFP logic +NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count #################################################################################################################### #MXFP4 / NVFP4 weight quantizer #################################################################################################################### @@ -164,8 +169,8 @@ def quantize_mxfp4( @torch.compile(fullgraph=True) def quantize_nvfp4( - self, W: torch.Tensor, window_size: int = 0, index: bool = False - ) -> (torch.Tensor, torch.Tensor): + self, W: torch.Tensor, window_size: int = 0, index: bool = False, + ) -> (torch.Tensor, torch.Tensor, torch.Tensor): group_size: int = 16 eps: float = 1e-6 @@ -177,7 +182,7 @@ def quantize_nvfp4( W_flat = W.view(-1, group_size).float() ideal_scale = W_flat.abs().amax(dim=1, keepdim=True) ideal_scale /= max_val - meta_scales = NVFP4_META_SCALE #ideal_scale.max().clamp_(min=eps) - TODO: use max() + meta_scales = ideal_scale.max().clamp_(min=eps).float() ideal_scale /= meta_scales ideal_scale = ideal_scale.clamp_(max=max_fp8).to(fp8_dtype) @@ -211,15 +216,17 @@ def quantize_nvfp4( if(index): W_q = self.to_index(W_q) - return W_q, scales + return W_q, scales, meta_scales - def dequantize(self, W_q, scales, shape = None, dtype = None): + def dequantize(self, W_q, scales, shape = None, dtype = None, meta_scales = None): if(W_q.dtype == torch.uint8): #from indices device_index = W_q.device.index W_q = fp4_values[device_index][W_q.int()] group_size = W_q.numel() // scales.numel() out = 
(W_q.view([-1, group_size]).float() * scales.float()) + if meta_scales is not None: + out = out * meta_scales if(shape is not None): out = out.view(shape) return out.to(self.compute_dtype if dtype is None else dtype) @@ -227,6 +234,29 @@ def dequantize(self, W_q, scales, shape = None, dtype = None): #################################################################################################################### #INT8 / FP8 activations #################################################################################################################### +def prune_large_blocks(configs, named_args, **kwargs): + M = named_args['M'] + + pruned = [] + for config in configs: + if config.kwargs['BLOCK_SIZE_M'] <= M: + pruned.append(config) + + if not pruned: + for config in configs: + new_kwargs = config.kwargs.copy() + new_kwargs['BLOCK_SIZE_M'] = 16 + + pruned.append( + triton.Config( + new_kwargs, + num_warps=config.num_warps, + num_stages=config.num_stages + ) + ) + + return pruned + # Main activation scaling functions @torch.compile(fullgraph=True) def scale_activations_per_token_torch( @@ -249,7 +279,7 @@ def scale_activations_per_token_torch( if not w_dtype.is_floating_point: out.round_() - out = out.to(dtype=w_dtype) + out = out.to(dtype=w_dtype) return out.view(out_shape), scales @triton.jit @@ -266,23 +296,25 @@ def round_triton_amd(tensor): round_triton = round_triton_nvidia @triton.jit -def scale_activations_per_token_kernel( +def scale_activations_per_token_triton_v1_kernel( tensor_ptr, scale_ptr, y_ptr, M, K, - stride_m, stride_k, stride_sm, + stride_m: tl.constexpr, + stride_k: tl.constexpr, + stride_sm: tl.constexpr, ROUND: tl.constexpr, UNROLL: tl.constexpr, min_val: tl.constexpr, max_val: tl.constexpr, fp32_scale: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): pid_m = tl.program_id(0) * UNROLL pid_k = tl.program_id(1) - offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) - 
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) for m in range(UNROLL): mask = ((offs_m < M)[:, None] & (offs_k < K)[None, :]).to(tl.int1) @@ -302,10 +334,9 @@ def scale_activations_per_token_kernel( tl.store(scale_ptr + offs_m[:, None] * stride_sm, scales_x) tl.store(y_ptr + in_ptrs, tensor, mask=mask) - offs_m += BLOCK_M + offs_m += BLOCK_SIZE_M - -def scale_activations_per_token_triton( +def scale_activations_per_token_triton_v1( tensor: Tensor, w_dtype: torch.dtype, fp32_scale: bool = True ) -> Tuple[Tensor, Tensor]: min_val, max_val = get_dtype_range(w_dtype) @@ -318,13 +349,13 @@ def scale_activations_per_token_triton( y = torch.empty((M, K), dtype=w_dtype, device=tensor.device) UNROLL = 1 # max(1, M // 128) - BLOCK_M = 1 - BLOCK_K = triton.next_power_of_2(K) - grid = (triton.cdiv(M, BLOCK_M * UNROLL), triton.cdiv(K, BLOCK_K)) + BLOCK_SIZE_M = 1 + BLOCK_SIZE_K = triton.next_power_of_2(K) + grid = (triton.cdiv(M, BLOCK_SIZE_M * UNROLL), triton.cdiv(K, BLOCK_SIZE_K)) ROUND = not w_dtype.is_floating_point - scale_activations_per_token_kernel[grid]( + scale_activations_per_token_triton_v1_kernel[grid]( tensor, scales, y, @@ -338,44 +369,325 @@ def scale_activations_per_token_triton( fp32_scale=fp32_scale, ROUND=ROUND, UNROLL=UNROLL, - BLOCK_M=BLOCK_M, - BLOCK_K=BLOCK_K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_K=BLOCK_SIZE_K, num_stages=1, num_warps=4, ) return y.view(x_shape), scales +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 1}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 2}, num_warps=8, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 4}, num_warps=8, num_stages=1), + ], + key=['M', 'K'] +) +@triton.jit +def scale_activations_per_token_triton_v2_kernel( + tensor_ptr, scale_ptr, y_ptr, + M, K, + stride_m: tl.constexpr, + stride_k: tl.constexpr, + stride_sm: tl.constexpr, + min_val: tl.constexpr, + 
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 1}, num_warps=8, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 2}, num_warps=8, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 4}, num_warps=8, num_stages=1),
    ],
    key=['M', 'K']
)
@triton.jit
def scale_activations_per_token_triton_v2_kernel(
    tensor_ptr, scale_ptr, y_ptr,
    M, K,
    stride_m: tl.constexpr,
    stride_k: tl.constexpr,
    stride_sm: tl.constexpr,
    min_val: tl.constexpr,
    max_val: tl.constexpr,
    fp32_scale: tl.constexpr,
    ROUND: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Per-token (per-row) dynamic quantization: one program handles BLOCK_SIZE_M
    # full rows (BLOCK_SIZE_K is the next power of 2 >= K, so the whole row fits).
    pid_m = tl.program_id(0)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    m_mask = offs_m < M
    k_mask = offs_k < K
    mask = m_mask[:, None] & k_mask[None, :]

    offsets = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k

    tensor = tl.load(tensor_ptr + offsets, mask=mask, other=0.0)

    if fp32_scale:
        tensor = tensor.to(tl.float32)

    # Row-wise abs-max scale, floored at 1e-6 to avoid division by zero.
    scales_x = tl.max(tl.abs(tensor), axis=1) / max_val
    scales_x = tl.maximum(scales_x, 1e-6)
    tensor = tensor / scales_x[:, None]
    tensor = tl.minimum(tl.maximum(tensor, min_val), max_val)

    if ROUND:
        # Integer target dtype: round to nearest before the cast on store.
        tensor = round_triton(tensor)

    tl.store(scale_ptr + offs_m * stride_sm, scales_x, mask=m_mask)
    tl.store(y_ptr + offsets, tensor, mask=mask)

def scale_activations_per_token_triton_v2(
    tensor: torch.Tensor, w_dtype: torch.dtype, fp32_scale: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize activations per token (row) to `w_dtype`, one kernel launch.

    Args:
        tensor: input activations; flattened to (M, K) over the last dim.
        w_dtype: target dtype (int8 or fp8); integer targets are rounded.
        fp32_scale: compute/store scales in float32 (else in input dtype).

    Returns:
        (quantized tensor in the original shape, per-row scales of shape (M, 1)).
    """
    min_val, max_val = get_dtype_range(w_dtype)

    x_shape = tensor.shape
    tensor = tensor.view(-1, tensor.shape[-1])
    M, K = tensor.shape

    scales = torch.empty((M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device)
    y = torch.empty((M, K), dtype=w_dtype, device=tensor.device)

    # 1D grid over rows; BLOCK_SIZE_M picked by the autotuner.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), )

    BLOCK_SIZE_K = triton.next_power_of_2(K)
    ROUND = not w_dtype.is_floating_point

    scale_activations_per_token_triton_v2_kernel[grid](
        tensor, scales, y,
        M, K,
        tensor.stride(0), tensor.stride(1),
        scales.stride(0),
        min_val=min_val, max_val=max_val,
        fp32_scale=fp32_scale, ROUND=ROUND,
        BLOCK_SIZE_K=BLOCK_SIZE_K
    )

    return y.view(x_shape), scales
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 1}, num_warps=8, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 2}, num_warps=8, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 4}, num_warps=8, num_stages=1),
    ],
    key=['M', 'K']
)
@triton.jit
def scale_activations_per_token_triton_v3_kernel(
    tensor_ptr, scale_ptr, y_ptr,
    M, K,
    stride_m: tl.constexpr,
    stride_k: tl.constexpr,
    stride_sm: tl.constexpr,
    min_val: tl.constexpr,
    max_val: tl.constexpr,
    fp32_scale: tl.constexpr,
    ROUND: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Persistent variant of v2: a fixed number of programs loops over the row
    # tiles in a grid-stride fashion (tile_m advances by num_programs).
    start_pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    num_tiles = tl.cdiv(M, BLOCK_SIZE_M)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    k_mask = offs_k < K

    for pid_m in range(start_pid, num_tiles, num_programs):
        offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        m_mask = offs_m < M
        mask = m_mask[:, None] & k_mask[None, :]

        offsets = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k
        tensor = tl.load(tensor_ptr + offsets, mask=mask, other=0.0)

        if fp32_scale:
            tensor = tensor.to(tl.float32)

        # Row-wise abs-max scale, floored at 1e-6 to avoid division by zero.
        scales_x = tl.max(tl.abs(tensor), axis=1) / max_val
        scales_x = tl.maximum(scales_x, 1e-6)
        tensor = tensor / scales_x[:, None]
        tensor = tl.minimum(tl.maximum(tensor, min_val), max_val)

        if ROUND:
            tensor = round_triton(tensor)

        tl.store(y_ptr + offsets, tensor, mask=mask)
        tl.store(scale_ptr + offs_m * stride_sm, scales_x, mask=m_mask)

def scale_activations_per_token_triton_v3(
    tensor: torch.Tensor, w_dtype: torch.dtype, fp32_scale: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token quantization, persistent-kernel variant of v2.

    Launches at most NUM_SMS programs; each loops over its share of row tiles.

    Returns:
        (quantized tensor in the original shape, per-row scales of shape (M, 1)).
    """
    min_val, max_val = get_dtype_range(w_dtype)

    x_shape = tensor.shape
    tensor = tensor.view(-1, tensor.shape[-1])
    M, K = tensor.shape

    scales = torch.empty((M, 1), dtype=torch.float32 if fp32_scale else tensor.dtype, device=tensor.device)
    y = torch.empty((M, K), dtype=w_dtype, device=tensor.device)

    # Cap the grid at the SM count: the kernel is persistent.
    grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META['BLOCK_SIZE_M'])), )

    BLOCK_SIZE_K = triton.next_power_of_2(K)
    ROUND = not w_dtype.is_floating_point

    scale_activations_per_token_triton_v3_kernel[grid](
        tensor, scales, y,
        M, K,
        tensor.stride(0), tensor.stride(1),
        scales.stride(0),
        min_val=min_val, max_val=max_val,
        fp32_scale=fp32_scale, ROUND=ROUND,
        BLOCK_SIZE_K=BLOCK_SIZE_K
    )

    return y.view(x_shape), scales
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_M": 1, "BLOCK_SIZE_K": 2048}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE_M": 1, "BLOCK_SIZE_K": 4096}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE_M": 2, "BLOCK_SIZE_K": 2048}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE_M": 2, "BLOCK_SIZE_K": 4096}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE_M": 4, "BLOCK_SIZE_K": 2048}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE_M": 4, "BLOCK_SIZE_K": 4096}, num_warps=8, num_stages=2),
    ],
    key=["M", "K"],
)
@triton.jit
def scale_activations_per_token_triton_v4_kernel(
    tensor_ptr,
    scale_ptr,
    y_ptr,
    M,
    K,
    stride_m: tl.constexpr,
    stride_k: tl.constexpr,
    stride_sm: tl.constexpr,
    min_val: tl.constexpr,
    max_val: tl.constexpr,
    fp32_scale: tl.constexpr,
    ROUND: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Persistent two-pass variant: rows longer than BLOCK_SIZE_K are processed
    # in K-chunks, so a full row never has to fit in registers at once.
    start_pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    num_tiles = tl.cdiv(M, BLOCK_SIZE_M)

    for pid_m in range(start_pid, num_tiles, num_programs):
        offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        m_mask = offs_m < M

        # Pass 1: streaming amax over K chunks
        row_max = tl.zeros([BLOCK_SIZE_M], dtype=tl.float32)
        for k_start in range(0, K, BLOCK_SIZE_K):
            offs_k = k_start + tl.arange(0, BLOCK_SIZE_K)
            k_mask = offs_k < K
            mask = m_mask[:, None] & k_mask[None, :]
            chunk = tl.load(
                tensor_ptr + offs_m[:, None] * stride_m + offs_k[None, :] * stride_k,
                mask=mask,
                other=0.0,
            )
            if fp32_scale:
                chunk = chunk.to(tl.float32)
            row_max = tl.maximum(row_max, tl.max(tl.abs(chunk), axis=1))

        # Scale floored at 1e-6 to avoid division by zero on all-zero rows.
        scales_x = row_max / max_val
        scales_x = tl.maximum(scales_x, 1e-6)
        tl.store(scale_ptr + offs_m * stride_sm, scales_x, mask=m_mask)

        # Pass 2: scale, clamp, store
        inv_scales = 1.0 / scales_x
        for k_start in range(0, K, BLOCK_SIZE_K):
            offs_k = k_start + tl.arange(0, BLOCK_SIZE_K)
            k_mask = offs_k < K
            mask = m_mask[:, None] & k_mask[None, :]
            offsets = offs_m[:, None] * stride_m + offs_k[None, :] * stride_k
            chunk = tl.load(tensor_ptr + offsets, mask=mask, other=0.0)
            if fp32_scale:
                chunk = chunk.to(tl.float32)
            chunk = chunk * inv_scales[:, None]
            chunk = tl.minimum(tl.maximum(chunk, min_val), max_val)
            if ROUND:
                chunk = round_triton(chunk)
            tl.store(y_ptr + offsets, chunk, mask=mask)


def scale_activations_per_token_triton_v4(
    tensor: torch.Tensor,
    w_dtype: torch.dtype,
    fp32_scale: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token quantization for long rows: persistent, two-pass, K-chunked.

    Unlike v2/v3, K is not required to fit in a single block, so this variant
    handles arbitrarily large hidden dimensions.

    Returns:
        (quantized tensor in the original shape, per-row scales of shape (M, 1)).
    """
    min_val, max_val = get_dtype_range(w_dtype)

    x_shape = tensor.shape
    tensor = tensor.view(-1, tensor.shape[-1])
    M, K = tensor.shape

    scales = torch.empty(
        (M, 1),
        dtype=torch.float32 if fp32_scale else tensor.dtype,
        device=tensor.device,
    )
    y = torch.empty((M, K), dtype=w_dtype, device=tensor.device)

    # Persistent kernel: at most NUM_SMS programs.
    grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"])),)

    ROUND = not w_dtype.is_floating_point

    scale_activations_per_token_triton_v4_kernel[grid](
        tensor,
        scales,
        y,
        M,
        K,
        tensor.stride(0),
        tensor.stride(1),
        scales.stride(0),
        min_val=min_val,
        max_val=max_val,
        fp32_scale=fp32_scale,
        ROUND=ROUND,
    )

    return y.view(x_shape), scales
@triton.jit
def next_power_of_2_log_triton(val, eps_exp: tl.constexpr):
    # Round `val` up to the next power of two via ceil(log2).
    # Returns (scale as float32, biased IEEE-754 exponent in [127+eps_exp, 254]).
    exp = tl.ceil(tl.log2(val)).to(tl.int32)
    exp = exp + 127  # bias to the float32/E8M0 exponent encoding
    # Clamp: 254 avoids inf; 127 + eps_exp floors the scale at 2^eps_exp.
    exp = tl.maximum(tl.minimum(exp, 254), 127 + eps_exp)
    # Build 2^(exp-127) directly from the exponent bits (mantissa = 0).
    scales = tl.cast(exp << 23, tl.float32, bitcast=True)
    return scales, exp

@triton.jit
def next_power_of_2_ptx_triton(val, eps_exp: tl.constexpr):
    # Inline-PTX variant: lg2.approx + round-up-to-int, clamp the (unbiased)
    # exponent to [eps_exp, 127], then ex2 rebuilds the power-of-two scale.
    # Returns (scale as float32, biased exponent as int32) like the other variants.
    scales, biased_exp = tl.inline_asm_elementwise(
        f"""
        {{
            .reg .f32 f_log;
            .reg .f32 f_ceil;
            .reg .s32 r_exp;
            .reg .f32 f_clamped;

            lg2.approx.f32 f_log, $2;
            cvt.rpi.f32.f32 f_ceil, f_log;
            cvt.rzi.s32.f32 r_exp, f_ceil;

            max.s32 r_exp, r_exp, {eps_exp};
            min.s32 r_exp, r_exp, 127;

            add.s32 $1, r_exp, 127;
            cvt.rn.f32.s32 f_clamped, r_exp;
            ex2.approx.f32 $0, f_clamped;
        }}
        """,
        "=f,=r,f",
        [val],
        dtype=(tl.float32, tl.int32),
        is_pure=True,
        pack=1
    )

    return scales, biased_exp
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3),
    ],
    key=['M', 'K'],
    prune_configs_by={'early_config_prune': prune_large_blocks},
)
@triton.jit
def scale_activations_mxfp8_triton_kernel_v3(
    tensor_ptr,
    out_ptr,
    scales_ptr,
    M, K,
    #########################
    stride_m_t: tl.constexpr,
    stride_k_t: tl.constexpr,
    stride_m_s: tl.constexpr,
    stride_k_s: tl.constexpr,
    stride_m_o: tl.constexpr,
    stride_k_o: tl.constexpr,
    #########################
    min_val: tl.constexpr,
    max_val: tl.constexpr,
    eps_exp: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    # MXFP8 quantization: one program quantizes a (BLOCK_SIZE_M x GROUP_SIZE)
    # tile; each row of the tile gets one shared power-of-two (E8M0) scale.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    tensor_block_ptr = tl.make_block_ptr(
        base=tensor_ptr, shape=(M, K), strides=(stride_m_t, stride_k_t),
        offsets=(pid_m * BLOCK_SIZE_M, pid_k * GROUP_SIZE),
        block_shape=(BLOCK_SIZE_M, GROUP_SIZE), order=(1, 0)
    )

    out_block_ptr = tl.make_block_ptr(
        base=out_ptr, shape=(M, K), strides=(stride_m_o, stride_k_o),
        offsets=(pid_m * BLOCK_SIZE_M, pid_k * GROUP_SIZE),
        block_shape=(BLOCK_SIZE_M, GROUP_SIZE), order=(1, 0)
    )

    tensor = tl.load(tensor_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32)

    # Per-row abs-max -> power-of-two scale (value + biased exponent).
    abs_max = tl.max(tl.abs(tensor), axis=1, keep_dims=True)
    scales, scales_log2 = next_power_of_2_triton(abs_max / max_val, eps_exp)

    out = tensor * (1.0 / scales)
    out = tl.clamp(out, min=min_val, max=max_val)
    out = out.to(out_ptr.dtype.element_ty)

    tl.store(out_block_ptr, out, boundary_check=(0, 1))

    # Store the biased exponent (E8M0) per (row, K-group).
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    mask_m = offs_m < M
    scales_ptrs = scales_ptr + (offs_m * stride_m_s + pid_k * stride_k_s)
    tl.store(scales_ptrs, tl.reshape(scales_log2, (BLOCK_SIZE_M, )), mask=mask_m)

def scale_activations_mxfp8_triton_v3(
    tensor: torch.Tensor, w_dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """MXFP8 activation quantization (2D autotuned grid).

    Args:
        tensor: input activations; flattened to (M, K) over the last dim.
        w_dtype: FP8 target dtype.

    Returns:
        (FP8 tensor of shape (M, K),
         uint8 E8M0 scales of shape (M_padded, K // 32) — rows padded to a
         multiple of the group size; padding rows are left uninitialized here).
    """
    group_size: int = 32
    eps_exp: int = -30
    min_val, max_val = get_dtype_range(w_dtype)

    tensor = tensor.contiguous()
    tensor = tensor.view(-1, tensor.shape[-1])
    M, K = tensor.shape

    pad_m = (group_size - M % group_size) % group_size
    M_padded = M + pad_m

    out = torch.empty((M, K), device=tensor.device, dtype=w_dtype)
    scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8)

    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size))

    # NOTE: unused locals `eps` and `device_index` from the previous revision
    # were removed; the kernel only needs eps_exp and no threshold table.
    scale_activations_mxfp8_triton_kernel_v3[grid](
        tensor,
        out,
        scales,
        M, K,
        tensor.stride(0), tensor.stride(1),
        scales.stride(0), scales.stride(1),
        out.stride(0), out.stride(1),
        #########################
        min_val=min_val,
        max_val=max_val,
        eps_exp=eps_exp,
        GROUP_SIZE=group_size,
    )

    return out, scales
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 512}, num_warps=8, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1),
    ],
    key=['M', 'K'],
    prune_configs_by={'early_config_prune': prune_large_blocks},
)
@triton.jit
def scale_activations_mxfp8_triton_kernel_v4(
    tensor_ptr, out_ptr, scales_ptr,
    M, M_padded, K,
    stride_m_t: tl.constexpr, stride_k_t: tl.constexpr,
    stride_m_o: tl.constexpr, stride_k_o: tl.constexpr,
    stride_m_s: tl.constexpr, stride_k_s: tl.constexpr,
    min_val: tl.constexpr, max_val: tl.constexpr,
    eps_exp: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # MXFP8 quantization, persistent 1D grid: each program loops over row
    # tiles and, per tile, walks K in BLOCK_SIZE_K chunks via block pointers.
    pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    num_m_tiles = tl.cdiv(M_padded, BLOCK_SIZE_M)

    # Each K-chunk holds BLOCK_SIZE_K // GROUP_SIZE scale groups per row.
    GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE
    FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK
    out_dtype: tl.constexpr = out_ptr.dtype.element_ty

    for tile_m in range(pid, num_m_tiles, num_programs):
        offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        m_mask = offs_m < M

        tensor_bp = tl.make_block_ptr(
            tensor_ptr, (M, K), (stride_m_t, stride_k_t),
            (tile_m * BLOCK_SIZE_M, 0),
            (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0)
        )
        out_bp = tl.make_block_ptr(
            out_ptr, (M, K), (stride_m_o, stride_k_o),
            (tile_m * BLOCK_SIZE_M, 0),
            (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0)
        )

        for k_start in range(0, K, BLOCK_SIZE_K):
            # Load [BLOCK_M, BLOCK_K]
            tensor = tl.load(tensor_bp, boundary_check=(0, 1), padding_option="zero").to(tl.float32)

            # Reshape to [BLOCK_M * GROUPS_PER_BLOCK, GROUP_SIZE] for group-wise reduction
            tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE))

            # Per-group abs_max → power-of-2 scale
            abs_max = tl.max(tl.abs(tensor_flat), axis=1)
            scales, scales_log2 = next_power_of_2_bitwise_triton(abs_max / max_val, eps_exp)

            # Quantize: multiply by reciprocal, clamp, cast
            out = tensor_flat * (1.0 / scales[:, None])
            out = tl.clamp(out, min=min_val, max=max_val)
            out = out.to(out_dtype)

            # Reshape back to [BLOCK_M, BLOCK_K] and store
            out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K))
            tl.store(out_bp, out, boundary_check=(0, 1))

            # Store scales: [FLAT_M] → [BLOCK_M, GROUPS_PER_BLOCK]
            scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK))
            # For padding rows (M <= row < M_padded), store identity scale (127 = 2^0 in E8M0)
            scales_2d = tl.where(m_mask[:, None], scales_2d, tl.full(scales_2d.shape, 127, dtype=tl.uint8))
            group_idx = k_start // GROUP_SIZE
            offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK)
            g_mask = offs_g < tl.cdiv(K, GROUP_SIZE)
            tl.store(
                scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s,
                scales_2d, mask=(offs_m[:, None] < M_padded) & g_mask[None, :]
            )

            tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K))
            out_bp = tl.advance(out_bp, (0, BLOCK_SIZE_K))

# Persistent 1D grid, processes multiple K-groups per iteration via reshape
def scale_activations_mxfp8_triton_v4(
    tensor: torch.Tensor, w_dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """MXFP8 quantization, persistent variant of v3 (handles large K in chunks).

    Returns:
        (FP8 tensor of shape (M, K),
         uint8 E8M0 scales of shape (M_padded, K // 32); padding rows are
         written with the identity scale 127).
    """
    group_size: int = 32
    eps_exp: int = -30
    min_val, max_val = get_dtype_range(w_dtype)

    tensor = tensor.contiguous()
    tensor = tensor.view(-1, tensor.shape[-1])
    M, K = tensor.shape

    pad_m = (group_size - M % group_size) % group_size
    M_padded = M + pad_m

    out = torch.empty((M, K), device=tensor.device, dtype=w_dtype)
    scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8)

    # Persistent kernel: at most NUM_SMS programs.
    grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),)
    scale_activations_mxfp8_triton_kernel_v4[grid](
        tensor, out, scales,
        M, M_padded, K,
        tensor.stride(0), tensor.stride(1),
        out.stride(0), out.stride(1),
        scales.stride(0), scales.stride(1),
        min_val=min_val, max_val=max_val,
        eps_exp=eps_exp,
        GROUP_SIZE=group_size,
    )

    return out, scales
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3),
    ],
    key=['M', 'K'],
    prune_configs_by={'early_config_prune': prune_large_blocks},
)
@triton.jit
def scale_activations_mxfp4_triton_kernel(
    tensor_ptr,
    out_ptr,
    scales_ptr,
    thr_pos_ptr,
    M, K,
    #########################
    stride_m_t: tl.constexpr,
    stride_k_t: tl.constexpr,
    stride_m_s: tl.constexpr,
    stride_k_s: tl.constexpr,
    stride_m_o: tl.constexpr,
    stride_k_o: tl.constexpr,
    #########################
    eps_exp: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    use_tma: tl.constexpr = False,
):
    # MXFP4 quantization: each (row, K-group) gets a power-of-two scale; values
    # are mapped to 4-bit codes via the thr_pos threshold table and packed 2/byte.
    pid_m = tl.program_id(axis=0)
    pid_k = tl.program_id(axis=1)

    # FIX: these constexprs were defined after the `if use_tma:` block below,
    # but HALF_GROUP_SIZE is used inside it — a NameError when traced with
    # use_tma=True. Hoisted above the branch.
    HALF_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2
    out_dtype: tl.constexpr = out_ptr.dtype.element_ty

    if use_tma:
        tensor_desc = tl.make_tensor_descriptor(
            tensor_ptr,
            [M, K],
            [stride_m_t, stride_k_t],
            [BLOCK_SIZE_M, GROUP_SIZE]
        )
        out_desc = tl.make_tensor_descriptor(
            out_ptr,
            [M, K // 2],
            [stride_m_o, stride_k_o],
            [BLOCK_SIZE_M, HALF_GROUP_SIZE]
        )

    thr_pos = tl.load(thr_pos_ptr + tl.arange(0, 8), eviction_policy='evict_last')[None, :]

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = pid_k * GROUP_SIZE + tl.arange(0, GROUP_SIZE)

    #Load
    mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1)
    tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t)

    if use_tma:
        tensor = tl.load_tensor_descriptor(tensor_desc, [pid_m * BLOCK_SIZE_M, pid_k * GROUP_SIZE]).to(tl.float32)
    else:
        tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32)

    #next power of 2 scale per row-group (6.0 = FP4 max magnitude)
    scales, scales_log2 = next_power_of_2_triton(tl.max(tl.abs(tensor), axis=1, keep_dims=True) / 6., eps_exp)

    #Map to index
    wq = tensor / scales
    idx_abs = tl.sum(tl.abs(wq[:, :, None]) > thr_pos[None, :, :], axis=2)
    out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype)

    #Pack two 4-bit codes per byte
    lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_GROUP_SIZE, 2), can_reorder=False))
    out = lo | (hi << 4)

    #Store
    offs_k = pid_k * HALF_GROUP_SIZE + tl.arange(0, HALF_GROUP_SIZE)
    out_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < (K // 2))).to(tl.int1)
    if use_tma:
        tl.store_tensor_descriptor(out_desc, [pid_m * BLOCK_SIZE_M, pid_k * HALF_GROUP_SIZE], out)
    else:
        tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask)

    # FIX: bounds mask added — with BLOCK_SIZE_M up to 64, the last row tile can
    # extend past the (group_size-padded) scales buffer and write out of bounds.
    # NOTE(review): padding rows in [M, M_padded) are left uninitialized, as the
    # grid never covers them; confirm downstream consumers tolerate that.
    offs_k = pid_k * 1 + tl.arange(0, 1)
    tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales_log2,
             mask=offs_m[:, None] < M)

def scale_activations_mxfp4_triton(tensor: Tensor) -> Tuple[Tensor, Tensor]:
    """Quantize activations to MXFP4 (packed uint8) with E8M0 group scales.

    Args:
        tensor: input activations; flattened to (M, K) over the last dim.

    Returns:
        (packed uint8 tensor of shape (M, K // 2),
         uint8 E8M0 scales of shape (M_padded, K // 32)).
    """
    group_size: int = 32
    eps_exp: int = -30  # floor the scale at 2^-30

    tensor = tensor.contiguous()
    tensor = tensor.view(-1, tensor.shape[-1])
    M, K = tensor.shape

    pad_m = (group_size - M % group_size) % group_size
    M_padded = M + pad_m

    out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8)
    scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8)

    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size))
    device_index = tensor.device.index

    scale_activations_mxfp4_triton_kernel[grid](
        tensor,
        out,
        scales,
        thr_pos[device_index],
        M, K,
        tensor.stride(0), tensor.stride(1),
        scales.stride(0), scales.stride(1),
        out.stride(0), out.stride(1),
        #########################
        eps_exp=eps_exp,
        GROUP_SIZE=group_size,
    )

    return out, scales
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=3),
    ],
    key=['M', 'K'],
    prune_configs_by={'early_config_prune': prune_large_blocks},
)
@triton.jit
def scale_activations_nvfp4_triton_kernel(
    tensor_ptr,
    out_ptr,
    scales_ptr,
    thr_pos_ptr,
    M, K,
    #########################
    stride_m_t: tl.constexpr,
    stride_k_t: tl.constexpr,
    stride_m_s: tl.constexpr,
    stride_k_s: tl.constexpr,
    stride_m_o: tl.constexpr,
    stride_k_o: tl.constexpr,
    #########################
    eps: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    meta_scales_ptr,
    use_tma: tl.constexpr = False,
):
    # NVFP4 quantization: per-group FP8 (e4m3) scales relative to one global
    # meta scale; values mapped to 4-bit codes via thr_pos and packed 2/byte.
    pid_m = tl.program_id(axis=0)
    pid_k = tl.program_id(axis=1)

    # FIX: constexpr constants hoisted above the TMA branch — HALF_GROUP_SIZE
    # was previously defined *after* its use in out_desc, which fails to trace
    # when use_tma=True.
    fp8_dtype: tl.constexpr = tl.float8e4nv
    max_fp8: tl.constexpr = 448.
    HALF_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2
    out_dtype: tl.constexpr = out_ptr.dtype.element_ty

    if use_tma:
        tensor_desc = tl.make_tensor_descriptor(
            tensor_ptr,
            [M, K],
            [stride_m_t, stride_k_t],
            [BLOCK_SIZE_M, GROUP_SIZE]
        )
        out_desc = tl.make_tensor_descriptor(
            out_ptr,
            [M, K // 2],
            [stride_m_o, stride_k_o],
            [BLOCK_SIZE_M, HALF_GROUP_SIZE]
        )

    thr_pos = tl.load(thr_pos_ptr + tl.arange(0, 8), eviction_policy='evict_last')[None, :]
    #thr_pos += 1e-6

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = pid_k * GROUP_SIZE + tl.arange(0, GROUP_SIZE)

    #Load
    mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1)
    tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t)

    if use_tma:
        tensor = tl.load_tensor_descriptor(tensor_desc, [pid_m * BLOCK_SIZE_M, pid_k * GROUP_SIZE]).to(tl.float32)
    else:
        tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32)

    #FP8 scales (per group, relative to the global meta scale; 6.0 = FP4 max)
    meta_scales = tl.load(meta_scales_ptr, eviction_policy='evict_last')
    scales = tl.max(tl.abs(tensor), axis=1, keep_dims=True) / (6. * meta_scales)
    scales = tl.minimum(scales, max_fp8).to(fp8_dtype)

    #Map to index (scale floored at eps to avoid division by zero)
    scales_full = tl.maximum(scales.to(tl.float32) * meta_scales, eps)
    wq = tensor / scales_full
    idx_abs = tl.sum(tl.abs(wq[:, :, None]) > thr_pos[None, :, :], axis=2)
    out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype)

    #Pack two 4-bit codes per byte
    lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_GROUP_SIZE, 2), can_reorder=False))
    out = lo | (hi << 4)

    #Store
    offs_k = pid_k * HALF_GROUP_SIZE + tl.arange(0, HALF_GROUP_SIZE)
    out_mask = ((offs_m[:, None] < M) & (offs_k[None, :] < (K // 2))).to(tl.int1)
    if use_tma:
        tl.store_tensor_descriptor(out_desc, [pid_m * BLOCK_SIZE_M, pid_k * HALF_GROUP_SIZE], out)
    else:
        tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k[None, :] * stride_k_o), out, mask=out_mask)

    # FIX: bounds mask added — with BLOCK_SIZE_M up to 128, the last row tile
    # can extend past the (group_size-padded) scales buffer and write OOB.
    offs_k = pid_k + tl.arange(0, 1)
    tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales,
             mask=offs_m[:, None] < M)


def scale_activations_nvfp4_triton(tensor: torch.Tensor, meta_scale=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Quantize activations to NVFP4 (packed uint8) with FP8 group scales.

    Args:
        tensor: input activations; flattened to (M, K) over the last dim.
        meta_scale: optional precomputed global scale; computed from the input
            (global abs-max of 16-element groups / 6) when None.

    Returns:
        (packed uint8 tensor (M, K // 2),
         float8_e4m3fn scales (M_padded, K // 16),
         the meta scale used).
    """
    group_size: int = 16
    eps: float = 1e-6
    fp8_dtype = torch.float8_e4m3fn #Nvidia only
    meta_scale = meta_scale if meta_scale is not None else (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps)

    tensor = tensor.contiguous()
    tensor = tensor.view(-1, tensor.shape[-1])
    M, K = tensor.shape

    pad_m = (group_size - M % group_size) % group_size
    M_padded = M + pad_m

    out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8)
    scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype)

    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size))
    device_index = tensor.device.index

    scale_activations_nvfp4_triton_kernel[grid](
        tensor,
        out,
        scales,
        thr_pos[device_index],
        M, K,
        tensor.stride(0), tensor.stride(1),
        scales.stride(0), scales.stride(1),
        out.stride(0), out.stride(1),
        #########################
        eps=eps,
        GROUP_SIZE=group_size,
        meta_scales_ptr=meta_scale,
    )

    return out, scales, meta_scale
####################################################################################################################
# MXFP4 v2: persistent 1D grid, processes multiple K-groups per iteration
####################################################################################################################
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 4, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 128}, num_warps=8, num_stages=1),
    ],
    key=['M', 'K'],
    prune_configs_by={'early_config_prune': prune_large_blocks},
)
@triton.jit
def scale_activations_mxfp4_triton_kernel_v2(
    tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr,
    M, M_padded, K,
    stride_m_t: tl.constexpr, stride_k_t: tl.constexpr,
    stride_m_s: tl.constexpr, stride_k_s: tl.constexpr,
    stride_m_o: tl.constexpr, stride_k_o: tl.constexpr,
    eps_exp: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # MXFP4 quantization, persistent variant: each program loops over row
    # tiles; per tile, K is walked in BLOCK_SIZE_K chunks via block pointers.
    pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    num_m_tiles = tl.cdiv(M_padded, BLOCK_SIZE_M)

    # Each K-chunk holds BLOCK_SIZE_K // GROUP_SIZE scale groups per row.
    GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE
    HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2
    FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK
    out_dtype: tl.constexpr = out_ptr.dtype.element_ty
    thr_pos = tl.load(thr_pos_ptr + tl.arange(0, 8), eviction_policy='evict_last')[None, :]

    for tile_m in range(pid, num_m_tiles, num_programs):
        offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        m_mask = offs_m < M

        tensor_bp = tl.make_block_ptr(
            tensor_ptr, (M, K), (stride_m_t, stride_k_t),
            (tile_m * BLOCK_SIZE_M, 0),
            (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0)
        )
        out_bp = tl.make_block_ptr(
            out_ptr, (M, K // 2), (stride_m_o, stride_k_o),
            (tile_m * BLOCK_SIZE_M, 0),
            (BLOCK_SIZE_M, HALF_BLOCK_K), order=(1, 0)
        )

        for k_start in range(0, K, BLOCK_SIZE_K):
            tensor = tl.load(tensor_bp, boundary_check=(0, 1), padding_option="zero").to(tl.float32)

            # Reshape to [FLAT_M, GROUP_SIZE] for group-wise reduction
            tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE))

            # Per-group power-of-2 scale (6.0 = FP4 max magnitude)
            scales, scales_log2 = next_power_of_2_bitwise_triton(
                tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) / 6., eps_exp
            )

            # Map to FP4 index via threshold comparison
            wq = tensor_flat / scales
            idx_abs = tl.sum(tl.abs(wq[:, :, None]) > thr_pos[None, :, :], axis=2)
            out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype)

            # Reshape to [BLOCK_M, BLOCK_K] then pack pairs
            out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K))
            lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False))
            out = lo | (hi << 4)

            tl.store(out_bp, out, boundary_check=(0, 1))

            # Store scales: [FLAT_M, 1] → [BLOCK_M, GROUPS_PER_BLOCK]
            scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK))
            group_idx = k_start // GROUP_SIZE
            offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK)
            g_mask = offs_g < tl.cdiv(K, GROUP_SIZE)
            tl.store(
                scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s,
                scales_2d, mask=m_mask[:, None] & g_mask[None, :]
            )

            tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K))
            out_bp = tl.advance(out_bp, (0, HALF_BLOCK_K))


def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]:
    """MXFP4 quantization, persistent-kernel variant (large-K friendly).

    Returns:
        (packed uint8 tensor (M, K // 2),
         uint8 E8M0 scales (M_padded, K // 32); rows in [M, M_padded) are not
         written by the kernel — scales stores are masked to valid rows).
    """
    group_size: int = 32
    eps_exp: int = -30  # floor the scale at 2^-30

    tensor = tensor.contiguous()
    tensor = tensor.view(-1, tensor.shape[-1])
    M, K = tensor.shape

    pad_m = (group_size - M % group_size) % group_size
    M_padded = M + pad_m

    out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8)
    scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8)

    # Persistent kernel: at most NUM_SMS programs.
    grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),)
    device_index = tensor.device.index

    scale_activations_mxfp4_triton_kernel_v2[grid](
        tensor, out, scales, thr_pos[device_index],
        M, M_padded, K,
        tensor.stride(0), tensor.stride(1),
        scales.stride(0), scales.stride(1),
        out.stride(0), out.stride(1),
        eps_exp=eps_exp,
        GROUP_SIZE=group_size,
    )

    return out, scales
prune_configs_by={'early_config_prune': prune_large_blocks}, +) @triton.jit -def scale_activations_mxfp4_triton_kernel_v2( +def scale_activations_nvfp4_triton_kernel_v2( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, M_padded, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + meta_scales_ptr, +): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + num_m_tiles = tl.cdiv(M_padded, BLOCK_SIZE_M) + + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + fp8_dtype: tl.constexpr = tl.float8e4nv + max_fp8: tl.constexpr = 448. + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + thr_pos = tl.load(thr_pos_ptr + tl.arange(0, 8), eviction_policy='evict_last')[None, :] + + meta_scales = tl.load(meta_scales_ptr, eviction_policy='evict_last') + for tile_m in range(pid, num_m_tiles, num_programs): + offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + m_mask = offs_m < M + + tensor_bp = tl.make_block_ptr( + tensor_ptr, (M, K), (stride_m_t, stride_k_t), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0) + ) + out_bp = tl.make_block_ptr( + out_ptr, (M, K // 2), (stride_m_o, stride_k_o), + (tile_m * BLOCK_SIZE_M, 0), + (BLOCK_SIZE_M, HALF_BLOCK_K), order=(1, 0) + ) + + for k_start in range(0, K, BLOCK_SIZE_K): + tensor = tl.load(tensor_bp, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + + # Reshape to [FLAT_M, GROUP_SIZE] for group-wise reduction + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + # Per-group FP8 scale + abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) + scales_raw = abs_max / (6. 
* meta_scales) + scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype) + scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps) + + # Map to FP4 index via threshold comparison + wq = tensor_flat / scales_full + idx_abs = tl.sum(tl.abs(wq[:, :, None]) > thr_pos[None, :, :], axis=2) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + + # Reshape to [BLOCK_M, BLOCK_K] then pack pairs + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + tl.store(out_bp, out, boundary_check=(0, 1)) + + # Store scales: [FLAT_M, 1] → [BLOCK_M, GROUPS_PER_BLOCK] + scales_2d = tl.reshape(scales_fp8, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + group_idx = k_start // GROUP_SIZE + offs_g = group_idx + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=m_mask[:, None] & g_mask[None, :] + ) + + tensor_bp = tl.advance(tensor_bp, (0, BLOCK_SIZE_K)) + out_bp = tl.advance(out_bp, (0, HALF_BLOCK_K)) + + +def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor, meta_scale=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + group_size: int = 16 + eps: float = 1e-6 + fp8_dtype = torch.float8_e4m3fn + meta_scale = meta_scale if meta_scale is not None else (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) + + grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta['BLOCK_SIZE_M'])),) + device_index = tensor.device.index + + scale_activations_nvfp4_triton_kernel_v2[grid]( 
+ tensor, out, scales, thr_pos[device_index], + M, M_padded, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps=eps, + GROUP_SIZE=group_size, + meta_scales_ptr=meta_scale, + ) + + return out, scales, meta_scale + + +#################################################################################################################### +# MXFP4 v3: 2D grid like v1, but scalar threshold loop to avoid 3D tensor +#################################################################################################################### +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) +@triton.jit +def scale_activations_mxfp4_triton_kernel_v3( tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, M, K, - stride_m_t, stride_k_t, - stride_m_s, stride_k_s, - stride_m_o, stride_k_o, + ######################### + stride_m_t: tl.constexpr, + stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, + stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, + stride_k_o: tl.constexpr, ######################### eps_exp: tl.constexpr, GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, + use_tma: tl.constexpr = False, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) HALF_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 out_dtype: tl.constexpr = out_ptr.dtype.element_ty - thr_pos = tl.load(thr_pos_ptr + tl.arange(0, 8), eviction_policy='evict_last')[None, :] + + # Load 8 thresholds as individual scalars + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + 
thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_k = pid_k * GROUP_SIZE + tl.arange(0, GROUP_SIZE) @@ -795,13 +1771,17 @@ def scale_activations_mxfp4_triton_kernel_v2( mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) - + #next power of 2 via log scales, scales_log2 = next_power_of_2_triton(tl.max(tl.abs(tensor), axis=1, keep_dims=True) / 6., eps_exp) - #Map to index + #Map to index via scalar threshold comparisons (avoids 3D intermediate) wq = tensor / scales - idx_abs = tl.sum(tl.abs(wq[:, :, None]) > thr_pos[None, :, :], axis=2) + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) #Pack @@ -816,10 +1796,9 @@ def scale_activations_mxfp4_triton_kernel_v2( offs_k = pid_k * 1 + tl.arange(0, 1) tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales_log2) -def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: +def scale_activations_mxfp4_triton_v3(tensor: Tensor) -> Tuple[Tensor, Tensor]: group_size: int = 32 eps_exp: int = -30 - eps: float = 2 ** eps_exp tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) @@ -831,48 +1810,62 @@ def scale_activations_mxfp4_triton_v2(tensor: Tensor) -> Tuple[Tensor, Tensor]: out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) scales = torch.empty((M_padded, K // 
group_size), device=tensor.device, dtype=torch.uint8) - #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) - BLOCK_SIZE_M = group_size - - grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(K, group_size)) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index - scale_activations_mxfp4_triton_kernel_v2[grid]( + scale_activations_mxfp4_triton_kernel_v3[grid]( tensor, out, scales, thr_pos[device_index], - M, K, + M, K, tensor.stride(0), tensor.stride(1), scales.stride(0), scales.stride(1), out.stride(0), out.stride(1), ######################### eps_exp=eps_exp, GROUP_SIZE=group_size, - BLOCK_SIZE_M=BLOCK_SIZE_M, - num_stages=2, - num_warps=4, ) return out, scales + +#################################################################################################################### +# NVFP4 v3: 2D grid like v1, but scalar threshold loop to avoid 3D tensor +#################################################################################################################### +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=2), + triton.Config({'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=3), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks}, +) @triton.jit -def scale_activations_nvfp4_triton_kernel_v2( +def scale_activations_nvfp4_triton_kernel_v3( tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, M, K, - stride_m_t, stride_k_t, - stride_m_s, stride_k_s, - stride_m_o, stride_k_o, + ######################### + stride_m_t: tl.constexpr, + stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, + stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, + stride_k_o: 
tl.constexpr, ######################### eps: tl.constexpr, GROUP_SIZE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - meta_scales: tl.constexpr = NVFP4_META_SCALE, + meta_scales_ptr, + use_tma: tl.constexpr = False, ): - pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -880,8 +1873,16 @@ def scale_activations_nvfp4_triton_kernel_v2( max_fp8: tl.constexpr = 448. HALF_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 out_dtype: tl.constexpr = out_ptr.dtype.element_ty - thr_pos = tl.load(thr_pos_ptr + tl.arange(0, 8), eviction_policy='evict_last')[None, :] - #thr_pos += 1e-6 + + # Load 8 thresholds as individual scalars + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_k = pid_k * GROUP_SIZE + tl.arange(0, GROUP_SIZE) @@ -890,15 +1891,20 @@ def scale_activations_nvfp4_triton_kernel_v2( mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) - + #FP8 scales + meta_scales = tl.load(meta_scales_ptr, eviction_policy='evict_last') scales = tl.max(tl.abs(tensor), axis=1, keep_dims=True) / (6. 
* meta_scales) scales = tl.minimum(scales, max_fp8).to(fp8_dtype) - #Map to index + #Map to index via scalar threshold comparisons (avoids 3D intermediate) scales_full = tl.maximum(scales.to(tl.float32) * meta_scales, eps) wq = tensor / scales_full - idx_abs = tl.sum(tl.abs(wq[:, :, None]) > thr_pos[None, :, :], axis=2) + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) #Pack @@ -914,10 +1920,11 @@ def scale_activations_nvfp4_triton_kernel_v2( tl.store(scales_ptr + (offs_m[:, None] * stride_m_s + offs_k[None, :] * stride_k_s), scales) -def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def scale_activations_nvfp4_triton_v3(tensor: torch.Tensor, meta_scale=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: group_size: int = 16 eps: float = 1e-6 - fp8_dtype = torch.float8_e4m3fn #Nvidia only + fp8_dtype = torch.float8_e4m3fn + meta_scale = meta_scale if meta_scale is not None else (tensor.view(-1, 16).abs().amax(dim=1) / 6.0).max().float().clamp_(min=eps) tensor = tensor.contiguous() tensor = tensor.view(-1, tensor.shape[-1]) @@ -929,32 +1936,553 @@ def scale_activations_nvfp4_triton_v2(tensor: torch.Tensor) -> Tuple[torch.Tenso out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype) - #BLOCK_SIZE_M = min(max(next_power_of_2(M), group_size), 128) - BLOCK_SIZE_M = group_size - grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(K, group_size)) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, group_size)) device_index = tensor.device.index - scale_activations_nvfp4_triton_kernel_v2[grid]( + 
scale_activations_nvfp4_triton_kernel_v3[grid]( tensor, out, scales, thr_pos[device_index], - M, K, + M, K, tensor.stride(0), tensor.stride(1), scales.stride(0), scales.stride(1), out.stride(0), out.stride(1), ######################### eps=eps, GROUP_SIZE=group_size, - BLOCK_SIZE_M=BLOCK_SIZE_M, - num_stages=2, - num_warps=4, + meta_scales_ptr=meta_scale, ) + return out, scales, meta_scale + + + +#################################################################################################################### +# MXFP4 v5: 2D grid with multi-group BLOCK_SIZE_K (fewer blocks, better bandwidth) +#################################################################################################################### +def prune_large_blocks_2d(configs, named_args, **kwargs): + M = named_args['M'] + K = named_args['K'] + + pruned = [] + for config in configs: + bm = config.kwargs['BLOCK_SIZE_M'] + bk = config.kwargs['BLOCK_SIZE_K'] + if bm <= M and bk <= K: + pruned.append(config) + + if not pruned: + pruned.append(configs[0]) + + return pruned + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1), + ], + key=['M', 'K'], + prune_configs_by={'early_config_prune': prune_large_blocks_2d}, +) +@triton.jit +def scale_activations_mxfp4_triton_kernel_v5( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, M_padded, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps_exp: 
tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + # Requires CUDA 13.0+ ptxas (Triton bundles 12.9 as of v3.3). To enable, replace + # the bundled ptxas-blackwell with the system one: cp /usr/local/cuda/bin/ptxas + # /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas-blackwell + # TODO: once Triton ships CUDA 13.0+ ptxas, set default to True and add ptx_pack + # to the autotuner configs so it can pick the best path per shape. + ptx_pack: tl.constexpr = GEMLITE_ENABLE_PTX_FP4_PACK, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + + scales, scales_log2 = next_power_of_2_triton( + tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) / 6., eps_exp + ) + + wq = tensor_flat / scales + + if ptx_pack: + # PTX path: hardware e2m1x2 quantization + nibble packing + wq_2d = tl.reshape(wq, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + wq_pairs = wq_2d.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False) + lo_val, hi_val = tl.split(wq_pairs) + lo_f16 = lo_val.to(tl.float16) + hi_f16 = hi_val.to(tl.float16) + lo_bits = lo_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF + hi_bits = (hi_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF) << 16 + packed_f16x2 = lo_bits | hi_bits + packed_e2m1 = tl.inline_asm_elementwise( + asm=""" + { + .reg .b8 tmp_out; + .reg 
.f16x2 tmp_in; + mov.b32 tmp_in, $1; + cvt.rn.satfinite.e2m1x2.f16x2 tmp_out, tmp_in; + cvt.u32.u8 $0, tmp_out; + } + """, + constraints="=r,r", + args=[packed_f16x2], + dtype=tl.int32, + is_pure=True, + pack=1, + ) + out = packed_e2m1.to(tl.uint8) + else: + # Threshold path: 8 comparisons + manual nibble packing + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32)) + out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype) + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)) + out = lo | (hi << 4) + + offs_k_out = pid_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K) + out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1) + tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask) + + scales_2d = tl.reshape(scales_log2, (BLOCK_SIZE_M, GROUPS_PER_BLOCK)) + # For padding rows (M <= row < M_padded), store identity scale (127 = 2^0 in E8M0) + scales_2d = tl.where(offs_m[:, None] < M, scales_2d, tl.full(scales_2d.shape, 127, dtype=tl.uint8)) + base_group = pid_k * GROUPS_PER_BLOCK + offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK) + g_mask = offs_g < tl.cdiv(K, GROUP_SIZE) + tl.store( + scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s, + scales_2d, mask=(offs_m[:, None] < M_padded) & g_mask[None, :] + ) + + +def scale_activations_mxfp4_triton_v5(tensor: Tensor) -> Tuple[Tensor, Tensor]: 
+ group_size: int = 32 + eps_exp: int = -30 + + tensor = tensor.contiguous() + tensor = tensor.view(-1, tensor.shape[-1]) + M, K = tensor.shape + + pad_m = (group_size - M % group_size) % group_size + M_padded = M + pad_m + + out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8) + scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=torch.uint8) + + grid = lambda meta: (triton.cdiv(M_padded, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) + device_index = tensor.device.index + + scale_activations_mxfp4_triton_kernel_v5[grid]( + tensor, out, scales, thr_pos[device_index], + M, M_padded, K, + tensor.stride(0), tensor.stride(1), + scales.stride(0), scales.stride(1), + out.stride(0), out.stride(1), + eps_exp=eps_exp, + GROUP_SIZE=group_size, + ) return out, scales + +#################################################################################################################### +# Pre-allocated per-device buffers for dynamic NVFP4 meta_scale computation +_nvfp4_meta_scale_bufs = [] # meta_scale output (float32 scalar) +_nvfp4_amax_bufs = [] # atomic max scratch (float32 scalar) +_nvfp4_counter_bufs = [] # grid sync counter (int32 scalar) + +def _get_nvfp4_bufs(device_index): + """Get or create pre-allocated buffers for the given device.""" + global _nvfp4_meta_scale_bufs, _nvfp4_amax_bufs, _nvfp4_counter_bufs + for buf_list in [_nvfp4_meta_scale_bufs, _nvfp4_amax_bufs, _nvfp4_counter_bufs]: + while len(buf_list) <= device_index: + buf_list.append(None) + if _nvfp4_meta_scale_bufs[device_index] is None: + dev = f"cuda:{device_index}" + _nvfp4_meta_scale_bufs[device_index] = torch.zeros(1, device=dev, dtype=torch.float32) + _nvfp4_amax_bufs[device_index] = torch.zeros(1, device=dev, dtype=torch.float32) + _nvfp4_counter_bufs[device_index] = torch.zeros(1, device=dev, dtype=torch.int32) + return _nvfp4_meta_scale_bufs[device_index], _nvfp4_amax_bufs[device_index], _nvfp4_counter_bufs[device_index] + 
+#################################################################################################################### +# Fused persistent NVFP4 v6: Single-kernel amax + quantize +# Phase 1: all blocks compute tile amax, atomicMax to global, grid barrier +# Phase 2: all blocks quantize tiles using computed meta_scale +# Grid limited to num_SMs so all blocks run concurrently (spin-wait safe) +#################################################################################################################### +@triton.jit +def scale_activations_nvfp4_fused_kernel_v6( + tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr, + M, M_padded, K, + stride_m_t: tl.constexpr, stride_k_t: tl.constexpr, + stride_m_s: tl.constexpr, stride_k_s: tl.constexpr, + stride_m_o: tl.constexpr, stride_k_o: tl.constexpr, + eps: tl.constexpr, + GROUP_SIZE: tl.constexpr, + meta_scales_ptr, # output: computed meta_scale + amax_ptr, # scratch: atomic max accumulator + counter_ptr, # scratch: grid sync counter + num_tiles_m, num_tiles_k, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + ptx_pack: tl.constexpr = False, +): + pid = tl.program_id(0) + num_pids = tl.num_programs(0) + total_tiles = num_tiles_m * num_tiles_k + + fp8_dtype: tl.constexpr = tl.float8e4nv + max_fp8: tl.constexpr = 448. 
+ HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2 + GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE + FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK + out_dtype: tl.constexpr = out_ptr.dtype.element_ty + + # Load thresholds once + thr0 = tl.load(thr_pos_ptr + 0) + thr1 = tl.load(thr_pos_ptr + 1) + thr2 = tl.load(thr_pos_ptr + 2) + thr3 = tl.load(thr_pos_ptr + 3) + thr4 = tl.load(thr_pos_ptr + 4) + thr5 = tl.load(thr_pos_ptr + 5) + thr6 = tl.load(thr_pos_ptr + 6) + thr7 = tl.load(thr_pos_ptr + 7) + + # ---- Phase 1: Compute amax across all tiles ---- + local_amax = tl.full((1,), value=0.0, dtype=tl.float32) + for tile_idx in range(pid, total_tiles, num_pids): + tile_m = tile_idx // num_tiles_k + tile_k = tile_idx % num_tiles_k + + offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tile_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + tile_max = tl.max(tl.abs(tensor)) + local_amax = tl.maximum(local_amax, tile_max) + + # Atomic max to global (release: ensures atomicMax is visible before counter increment) + tl.atomic_max(amax_ptr, tl.max(local_amax, axis=0), sem='relaxed') + + # Grid barrier: last block computes meta_scale and signals + # acq_rel: acquires all prior releases (sees all other blocks' atomicMax) + old_count = tl.atomic_add(counter_ptr, 1, sem='relaxed') + if old_count == num_pids - 1: + final_amax = tl.load(amax_ptr) + tl.store(meta_scales_ptr, tl.maximum(final_amax / 6.0, eps)) + # Reset scratch for next call + tl.store(amax_ptr, 0.0) + # Signal ready by setting counter to -num_pids (distinguishable from 0..num_pids-1) + tl.store(counter_ptr, -1) + + # Spin-wait for ready signal (safe: grid <= num_SMs, all blocks run concurrently) + while tl.atomic_add(counter_ptr, 0, sem='relaxed') >= 
0: + pass + + # ---- Phase 2: Quantize using computed meta_scale ---- + meta_scales = tl.load(meta_scales_ptr) + + for tile_idx in range(pid, total_tiles, num_pids): + tile_m = tile_idx // num_tiles_k + tile_k = tile_idx % num_tiles_k + + offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tile_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + # Reload tile (L2 cached from Phase 1) + mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t) + tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32) + + tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE)) + abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True) + scales_raw = abs_max / (6. * meta_scales) + scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype) + scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps) + + wq = tensor_flat / scales_full + + if ptx_pack: + wq_2d = tl.reshape(wq, (BLOCK_SIZE_M, BLOCK_SIZE_K)) + wq_pairs = wq_2d.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False) + lo_val, hi_val = tl.split(wq_pairs) + lo_f16 = lo_val.to(tl.float16) + hi_f16 = hi_val.to(tl.float16) + lo_bits = lo_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF + hi_bits = (hi_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF) << 16 + packed_f16x2 = lo_bits | hi_bits + packed_e2m1 = tl.inline_asm_elementwise( + asm=""" + { + .reg .b8 tmp_out; + .reg .f16x2 tmp_in; + mov.b32 tmp_in, $1; + cvt.rn.satfinite.e2m1x2.f16x2 tmp_out, tmp_in; + cvt.u32.u8 $0, tmp_out; + } + """, + constraints="=r,r", + args=[packed_f16x2], + dtype=tl.int32, + is_pure=True, + pack=1, + ) + out = packed_e2m1.to(tl.uint8) + else: + abs_wq = tl.abs(wq) + idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) + + (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) + + (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) + + (abs_wq > thr6).to(tl.int32) 
####################################################################################################################
# NVFP4 v5: 2D grid with multi-group BLOCK_SIZE_K (fewer blocks, better bandwidth)
####################################################################################################################
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 16}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 128}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 256}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 256}, num_warps=8, num_stages=1),
    ],
    key=['M', 'K'],
    prune_configs_by={'early_config_prune': prune_large_blocks_2d},
)
@triton.jit
def scale_activations_nvfp4_triton_kernel_v5(
    tensor_ptr, out_ptr, scales_ptr, thr_pos_ptr,
    M, M_padded, K,
    stride_m_t: tl.constexpr, stride_k_t: tl.constexpr,
    stride_m_s: tl.constexpr, stride_k_s: tl.constexpr,
    stride_m_o: tl.constexpr, stride_k_o: tl.constexpr,
    eps: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    meta_scales_ptr,
    # Requires CUDA 13.0+ ptxas (Triton bundles 12.9 as of v3.3). To enable, set
    # the environment variable TRITON_CUDA_ARCH_LIST to include CUDA 13.0+ ptxas,
    # and override the bundled ptxas-blackwell.
    ptx_pack: tl.constexpr = GEMLITE_ENABLE_PTX_FP4_PACK,
):
    # Quantize one (BLOCK_SIZE_M x BLOCK_SIZE_K) activation tile to NVFP4:
    # 4-bit codes packed two-per-uint8 along K, plus one FP8 (E4M3) scale per
    # GROUP_SIZE elements.  2D launch grid: axis 0 tiles M, axis 1 tiles K.
    # thr_pos_ptr holds the 8 positive decision thresholds used by the non-PTX
    # code-selection path; meta_scales_ptr is a 1-element global meta-scale.
    pid_m = tl.program_id(axis=0)
    pid_k = tl.program_id(axis=1)

    fp8_dtype: tl.constexpr = tl.float8e4nv
    max_fp8: tl.constexpr = 448.
    HALF_BLOCK_K: tl.constexpr = BLOCK_SIZE_K // 2
    GROUPS_PER_BLOCK: tl.constexpr = BLOCK_SIZE_K // GROUP_SIZE
    FLAT_M: tl.constexpr = BLOCK_SIZE_M * GROUPS_PER_BLOCK
    out_dtype: tl.constexpr = out_ptr.dtype.element_ty

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

    mask = ((offs_m[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1)
    tensor_ptrs = tensor_ptr + (offs_m[:, None] * stride_m_t + offs_k[None, :] * stride_k_t)
    tensor = tl.load(tensor_ptrs, mask=mask, other=0.0).to(tl.float32)
    meta_scales = tl.load(meta_scales_ptr, eviction_policy='evict_last')

    # One row per quantization group: abs-max over each GROUP_SIZE-element group.
    tensor_flat = tl.reshape(tensor, (FLAT_M, GROUP_SIZE))
    abs_max = tl.max(tl.abs(tensor_flat), axis=1, keep_dims=True)
    # 6.0 maps the group abs-max onto the e2m1 range (max |e2m1| value); the
    # per-group scale is stored in FP8 E4M3 (clamped to its 448 max).
    scales_raw = abs_max / (6. * meta_scales)
    scales_fp8 = tl.minimum(scales_raw, max_fp8).to(fp8_dtype)
    scales_full = tl.maximum(scales_fp8.to(tl.float32) * meta_scales, eps)

    wq = tensor_flat / scales_full

    if ptx_pack:
        # PTX path: hardware e2m1x2 quantization + nibble packing
        wq_2d = tl.reshape(wq, (BLOCK_SIZE_M, BLOCK_SIZE_K))
        wq_pairs = wq_2d.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False)
        lo_val, hi_val = tl.split(wq_pairs)
        lo_f16 = lo_val.to(tl.float16)
        hi_f16 = hi_val.to(tl.float16)
        # Pack two fp16 values into one .f16x2 register for the cvt instruction.
        lo_bits = lo_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF
        hi_bits = (hi_f16.to(tl.int16, bitcast=True).to(tl.int32) & 0xFFFF) << 16
        packed_f16x2 = lo_bits | hi_bits
        packed_e2m1 = tl.inline_asm_elementwise(
            asm="""
            {
            .reg .b8 tmp_out;
            .reg .f16x2 tmp_in;
            mov.b32 tmp_in, $1;
            cvt.rn.satfinite.e2m1x2.f16x2 tmp_out, tmp_in;
            cvt.u32.u8 $0, tmp_out;
            }
            """,
            constraints="=r,r",
            args=[packed_f16x2],
            dtype=tl.int32,
            is_pure=True,
            pack=1,
        )
        out = packed_e2m1.to(tl.uint8)
    else:
        # Threshold path: 8 comparisons + manual nibble packing
        thr0 = tl.load(thr_pos_ptr + 0)
        thr1 = tl.load(thr_pos_ptr + 1)
        thr2 = tl.load(thr_pos_ptr + 2)
        thr3 = tl.load(thr_pos_ptr + 3)
        thr4 = tl.load(thr_pos_ptr + 4)
        thr5 = tl.load(thr_pos_ptr + 5)
        thr6 = tl.load(thr_pos_ptr + 6)
        thr7 = tl.load(thr_pos_ptr + 7)
        abs_wq = tl.abs(wq)
        # Count of thresholds exceeded = magnitude code index (0..7).
        idx_abs = ((abs_wq > thr0).to(tl.int32) + (abs_wq > thr1).to(tl.int32) +
                   (abs_wq > thr2).to(tl.int32) + (abs_wq > thr3).to(tl.int32) +
                   (abs_wq > thr4).to(tl.int32) + (abs_wq > thr5).to(tl.int32) +
                   (abs_wq > thr6).to(tl.int32) + (abs_wq > thr7).to(tl.int32))
        # Negative values get the sign bit of the 4-bit code (+8).
        out = tl.where(wq >= 0, idx_abs, idx_abs + 8).to(out_dtype)
        out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_K))
        lo, hi = tl.split(out.reshape((BLOCK_SIZE_M, HALF_BLOCK_K, 2), can_reorder=False))
        out = lo | (hi << 4)

    # Store quantized output (two 4-bit codes per byte -> K//2 output columns).
    offs_k_out = pid_k * HALF_BLOCK_K + tl.arange(0, HALF_BLOCK_K)
    out_mask = ((offs_m[:, None] < M) & (offs_k_out[None, :] < (K // 2))).to(tl.int1)
    tl.store(out_ptr + (offs_m[:, None] * stride_m_o + offs_k_out[None, :] * stride_k_o), out, mask=out_mask)

    # Store per-group FP8 scales.
    scales_2d = tl.reshape(scales_fp8, (BLOCK_SIZE_M, GROUPS_PER_BLOCK))
    # For padding rows (M <= row < M_padded), store identity scale (1.0 in fp8)
    scales_2d = tl.where(offs_m[:, None] < M, scales_2d, tl.full(scales_2d.shape, 1.0, dtype=tl.float32).to(fp8_dtype))
    base_group = pid_k * GROUPS_PER_BLOCK
    offs_g = base_group + tl.arange(0, GROUPS_PER_BLOCK)
    g_mask = offs_g < tl.cdiv(K, GROUP_SIZE)
    tl.store(
        scales_ptr + offs_m[:, None] * stride_m_s + offs_g[None, :] * stride_k_s,
        scales_2d, mask=(offs_m[:, None] < M_padded) & g_mask[None, :]
    )


def scale_activations_nvfp4_triton_v5(tensor: torch.Tensor, meta_scale=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Quantize activations to packed NVFP4 with per-16-element FP8 group scales.

    The input is flattened to 2D (rows = all leading dims, cols = last dim).
    Rows are padded up to a multiple of 16 for the scales buffer only; the
    packed output keeps the original M rows.  Assumes the last dim K is even
    (two 4-bit codes are packed per output byte) — TODO confirm callers
    guarantee K % group_size == 0.

    Returns (packed_uint8 [M, K//2], scales_fp8e4m3 [M_padded, K//16], meta_scale).
    """
    group_size: int = 16
    eps: float = 1e-6
    fp8_dtype = torch.float8_e4m3fn

    tensor = tensor.contiguous()
    tensor = tensor.view(-1, tensor.shape[-1])
    M, K = tensor.shape

    pad_m = (group_size - M % group_size) % group_size
    M_padded = M + pad_m

    out = torch.empty((M, K // 2), device=tensor.device, dtype=torch.uint8)
    scales = torch.empty((M_padded, K // group_size), device=tensor.device, dtype=fp8_dtype)
    device_index = tensor.device.index

    if meta_scale is None:
        # Fused path: single kernel computes amax + quantizes
        meta_scale, amax_buf, counter_buf = _get_nvfp4_bufs(device_index)
        BLOCK_M = 16
        BLOCK_K = 256
        num_tiles_m = triton.cdiv(M_padded, BLOCK_M)
        num_tiles_k = triton.cdiv(K, BLOCK_K)
        total_tiles = num_tiles_m * num_tiles_k
        # Persistent launch: one block per SM (or fewer if the problem is small).
        num_SMs = torch.cuda.get_device_properties(device_index).multi_processor_count
        num_blocks = min(total_tiles, num_SMs)

        scale_activations_nvfp4_fused_kernel_v6[(num_blocks,)](
            tensor, out, scales, thr_pos[device_index],
            M, M_padded, K,
            tensor.stride(0), tensor.stride(1),
            scales.stride(0), scales.stride(1),
            out.stride(0), out.stride(1),
            eps=eps,
            GROUP_SIZE=group_size,
            meta_scales_ptr=meta_scale,
            amax_ptr=amax_buf,
            counter_ptr=counter_buf,
            num_tiles_m=num_tiles_m,
            num_tiles_k=num_tiles_k,
            BLOCK_SIZE_M=BLOCK_M,
            BLOCK_SIZE_K=BLOCK_K,
        )
    else:
        # Static path: meta_scale already provided, use v5 kernel directly
        grid = lambda meta: (triton.cdiv(M_padded, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K']))
        scale_activations_nvfp4_triton_kernel_v5[grid](
            tensor, out, scales, thr_pos[device_index],
            M, M_padded, K,
            tensor.stride(0), tensor.stride(1),
            scales.stride(0), scales.stride(1),
            out.stride(0), out.stride(1),
            eps=eps,
            GROUP_SIZE=group_size,
            meta_scales_ptr=meta_scale,
        )

    return out, scales, meta_scale


####################################################################################################################
# Public entry points: latest kernel revision for each activation quantization format.
scale_activations_per_token = scale_activations_per_token_triton_v3
scale_activations_mxfp8 = scale_activations_mxfp8_triton_v4
scale_activations_mxfp4 = scale_activations_mxfp4_triton_v5
scale_activations_nvfp4 = scale_activations_nvfp4_triton_v5
num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) @@ -39,11 +40,22 @@ def kernel_config_pruner(configs, nargs, **kwargs): config.pop('reg_inc_consumer', None) config["NUM_STAGES"] = num_stages + config['EVEN_M'] = (m % config['BLOCK_SIZE_M'] == 0) + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) + + # Adjust 5D TMA compatibility for cached configs + if load_scales_as_block and n % 128 == 0 and (k // g) % 4 == 0: + config['BLOCK_SIZE_N'] = max(config['BLOCK_SIZE_N'], 128) + while (config['BLOCK_SIZE_K'] // g) % 4 != 0: + config['BLOCK_SIZE_K'] *= 2 + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) + yield triton.Config(config, num_stages=num_stages, num_warps=num_warps) return - + gpu_shared_memory = get_gpu_shared_memory() - load_scales_as_block = kwargs['load_scales_as_block'] used = set() for config in configs: group_size_m = config.kwargs['GROUP_SIZE_M'] @@ -62,55 +74,55 @@ def kernel_config_pruner(configs, nargs, **kwargs): elif m <= 128: block_size_m = min(max(block_size_m, 64), 128) #m: [64...128] elif m <= 256: block_size_m = min(max(block_size_m, 64), 256) #m: [128...256] elif m > 256: block_size_m = min(max(block_size_m, 64), 256) #m > 256 - - #Constraint: BLOCK_SIZE_K >= group_size, only for load_as_block = False + + block_size_k = next_power_of_2(block_size_k) + block_size_n = next_power_of_2(block_size_n) + + #Constraints if(load_scales_as_block): - num_stages = max(num_stages, 2) #for dot_scaled kernels with pipelined loads if(e > 1): block_size_k = max(block_size_k, 64) #m16n8k64 else: block_size_k = max(block_size_k, 32) #m16n8k32 + # 5D TMA scale compatibility: adjust block sizes for 5D TMA descriptor + if n % 128 == 0 and (k // g) % 4 == 0: + 
block_size_n = max(block_size_n, 128) + while (block_size_k // g) % 4 != 0: + block_size_k *= 2 else: - block_size_k = min(block_size_k, g) - - block_size_k = next_power_of_2(block_size_k) - block_size_n = next_power_of_2(block_size_n) + block_size_k = max(min(block_size_k, g), 32) #tl.dot minimum K #Hint: skip block_size_n > block_size_k for col-major non-packed data. - #Nvidia if not IS_HIP: - if e > 1 and not load_scales_as_block: - #Limit num stages when data is packed - num_stages = min(num_stages, 4) - if(e == 1 and num_stages == 1): - #skip num_stages=1 for non-packed weights + if e == 1 and num_stages == 1: continue - #Avoid OOM - while num_stages > 0 and not load_scales_as_block: #TODO: revisit MXFP case - shared_mem = (block_size_m * block_size_k * a_sizeof + block_size_k * block_size_n * b_sizeof) - if(e > 1): - shared_mem += block_size_k * block_size_n * a_sizeof - shared_mem *= num_stages - if int(shared_mem) <= gpu_shared_memory: + # Reduce num_stages until config fits in shared memory + while num_stages > 1: + estimated_smem = estimate_shared_memory_per_block( + block_size_m, block_size_n, block_size_k, + a_sizeof, b_sizeof, num_stages, e, g, + load_scales_as_block + ) + if estimated_smem <= gpu_shared_memory: break num_stages -= 1 - if(num_stages == 0): continue #config too large - - ########################################### - if(load_scales_as_block):#tmp MXFP fix - block_size_k = min(block_size_k, 256) - ########################################### - key = (block_size_m, block_size_n, block_size_k, group_size_m, A_load_order, num_stages, num_warps) + + even_m = (m % block_size_m == 0) + even_n = (n % block_size_n == 0) + even_k = (k % block_size_k == 0) new_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": group_size_m, + "EVEN_M": even_m, + "EVEN_N": even_n, + "EVEN_K": even_k, "A_load_order": A_load_order, "NUM_STAGES": num_stages, } @@ -129,7 +141,7 @@ def 
def get_fast_autotune_config_nvidia():
    """Return the curated "fast" autotune candidate list for NVIDIA GPUs.

    BLOCK_SIZE_M is automatically adapted in the config pruning, so each
    entry is just a candidate tile shape plus launch parameters.
    """
    # (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, A_load_order, num_warps, num_stages)
    tile_params = [
        # Small tiles (packed INT with small group_size, small N problems)
        (64,  64,  64,  8, 0, 4, 4),
        # Medium N tiles
        (64,  128, 64,  8, 0, 4, 5),
        (64,  128, 128, 8, 2, 4, 4),
        (64,  128, 128, 8, 2, 4, 3),
        (64,  128, 256, 8, 0, 4, 3),
        # Large N tiles
        (64,  256, 64,  8, 2, 8, 4),
        (64,  256, 128, 8, 0, 8, 3),
        # Large MxN tiles (pruner adapts M for large batch sizes)
        (128, 128, 128, 8, 2, 4, 4),
        (128, 128, 128, 8, 2, 4, 3),
        # NVFP4-friendly: BSK=128 allows stages=3/4 within 99KB smem
        # (NVFP4 g=16 can't fit BSK=256 stages=3)
        (128, 128, 128, 8, 0, 8, 3),
        (128, 128, 128, 8, 0, 8, 4),
        (128, 128, 128, 8, 2, 8, 3),
        (128, 128, 128, 8, 2, 8, 4),
        (128, 256, 128, 8, 2, 8, 3),
        (128, 128, 256, 8, 2, 4, 2),
        (128, 128, 256, 8, 0, 8, 4),
        (128, 128, 256, 8, 2, 4, 4),
        # Extra coverage
        (64,  32,  256, 8, 0, 4, 5),
        (64,  128, 256, 8, 2, 4, 4),
        (128, 256, 128, 8, 0, 8, 4),
        # Small M tiles (for M=32..64 where more tiles improve SM utilization)
        (32,  128, 128, 8, 0, 4, 4),
        (32,  128, 256, 8, 0, 4, 3),
    ]
    return [
        triton.Config(
            {'BLOCK_SIZE_M': bm, 'BLOCK_SIZE_N': bn, 'BLOCK_SIZE_K': bk,
             'GROUP_SIZE_M': gm, 'A_load_order': al},
            num_warps=nw, num_stages=ns,
        )
        for bm, bn, bk, gm, al, nw, ns in tile_params
    ]

def get_default_config_nvidia():
    """Return the default (no-autotune) NVIDIA configs."""
    defaults = [
        # (BLOCK_SIZE_K, ) is the only varying field between the two defaults.
        {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64,  'GROUP_SIZE_M': 8, 'A_load_order': 0, 'NUM_STAGES': 2},
        {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'A_load_order': 0, 'NUM_STAGES': 2},
    ]
    return [triton.Config(cfg, num_warps=4, num_stages=2) for cfg in defaults]

def get_default_config_amd():
    """Return the default (no-autotune) AMD config."""
    cfg = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64,
           'GROUP_SIZE_M': 8, 'A_load_order': 0, 'NUM_STAGES': 2}
    return [triton.Config(cfg, num_warps=4, num_stages=2)]
GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_STAGES: tl.constexpr, A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# - meta_evict_policy: tl.constexpr = '', - a_evict: tl.constexpr = '', - b_evict: tl.constexpr = '', + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = False, + EVEN_N: tl.constexpr = False, + ################################# + meta_evict_policy: tl.constexpr = "evict_last", + a_evict: tl.constexpr = "", + b_evict: tl.constexpr = "evict_first", + meta_scale_norm_ptr = None, + ################################# + use_tma: tl.constexpr = True, + use_5d_scales: tl.constexpr = False, ): """ Based on https://github.com/fpgaminer/GPTQ-triton @@ -313,7 +349,7 @@ def gemm_INT_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - + #Offsets ############################################################################################################# if data_contiguous: @@ -347,12 +383,18 @@ def gemm_INT_kernel( for k in range(num_pid_k): if(A_load_order == 0): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) b = tl.load(b_ptrs, eviction_policy=b_evict) if(A_load_order == 1): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) #Meta-data loading policy if(W_group_mode > 0): @@ -372,13 +414,19 @@ def gemm_INT_kernel( zeros = None if(A_load_order == 2): #Mid load - a = tl.load(a_ptrs, mask=a_mask, other=0., 
eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) # Unpack and dequantize b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) if(A_load_order == 3): #Late load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) #Dot acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype) @@ -386,6 +434,9 @@ def gemm_INT_kernel( #Advance a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K_P * stride_bk + + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) ############################################################################################################# #Channel-wise scaling @@ -395,24 +446,22 @@ def gemm_INT_kernel( if(channel_scale_mode == 2): #activation-only scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + acc = acc.to(meta_dtype) * scales_a[:, None] if(channel_scale_mode == 3): #weight + activation scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) - - acc = acc.to(output_dtype) ############################################################################################################# + #Output + acc = acc.to(output_dtype) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, 
BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) tl.store(c_ptrs, acc, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) - @triton.autotune( configs = get_autotune_config(), key = KEYS, @@ -423,7 +472,7 @@ def gemm_INT_kernel( def gemm_MX_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, M_CLOSEST, + M, N: tl.constexpr, K: tl.constexpr, M_CLOSEST, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, @@ -434,9 +483,9 @@ def gemm_MX_kernel( a_sizeof: tl.constexpr, b_sizeof: tl.constexpr, ######### Strides ######### - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, + stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bk: tl.constexpr, stride_bn: tl.constexpr, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, stride_meta_a_m: tl.constexpr, stride_meta_a_g: tl.constexpr, stride_meta_n: tl.constexpr, stride_meta_g: tl.constexpr, ######### Dtypes ######### @@ -455,10 +504,17 @@ def gemm_MX_kernel( A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# - meta_evict_policy: tl.constexpr = '', - a_evict: tl.constexpr = '', - b_evict: tl.constexpr = '', - meta_scale_norm: tl.constexpr = (0.05 ** 2), + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = False, + EVEN_N: tl.constexpr = False, + ################################# + meta_evict_policy: tl.constexpr = "evict_last", + a_evict: tl.constexpr = "", + b_evict: tl.constexpr = "", + meta_scale_norm_ptr = None, + ################################# + use_tma: tl.constexpr = True, + use_5d_scales: tl.constexpr = False, ): pid = tl.program_id(axis=0) @@ -488,6 +544,8 @@ def gemm_MX_kernel( BLOCK_SIZE_K_A: tl.constexpr = BLOCK_SIZE_K // elements_per_sample_a offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_ak = tl.arange(0, BLOCK_SIZE_K_A) + if not use_tma: + 
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak) a_mask = ((offs_am[:, None] < M) & (offs_ak[None, :] < K // elements_per_sample_a)).to(tl.int1) @@ -495,66 +553,162 @@ def gemm_MX_kernel( BLOCK_SIZE_K_B: tl.constexpr = BLOCK_SIZE_K // elements_per_sample offs_bk = tl.arange(0, BLOCK_SIZE_K_B) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - b_ptrs = b_ptr + offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn + if not use_tma: + if data_contiguous: + offs_bn_load = offs_bn + else: + offs_bn_load = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + else: + offs_bn_load = offs_bn + b_ptrs = b_ptr + offs_bk[:, None] * stride_bk + offs_bn_load[None, :] * stride_bn #Scales stride_mul: tl.constexpr = BLOCK_SIZE_K / group_size BLOCK_SIZE_K_S: tl.constexpr = BLOCK_SIZE_K // group_size offs_k_scales = tl.arange(0, BLOCK_SIZE_K_S) offs_n_b_scales = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] - + if not use_5d_scales: + scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] + + if use_tma: + a_desc = tl.make_tensor_descriptor( + a_ptr, + [M, K // elements_per_sample_a], + [stride_am, stride_ak], + [BLOCK_SIZE_M, BLOCK_SIZE_K_A] + ) + + b_desc = tl.make_tensor_descriptor( + b_ptr, + [N, K // elements_per_sample], + [stride_bn, stride_bk], + [BLOCK_SIZE_N, BLOCK_SIZE_K_B] + ) + + c_desc = tl.make_tensor_descriptor( + c_ptr, + [M, N], + [stride_cm, stride_cn], + [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + # 5D TMA Descriptors for Scales (preshuffled layout) + if use_5d_scales: + rep_n: tl.constexpr = BLOCK_SIZE_N // 128 + rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 + 
scales_b_shape1: tl.constexpr = N // 128 + scales_b_shape2: tl.constexpr = K // group_size // 4 + stride_b4: tl.constexpr = 1 + stride_b3: tl.constexpr = 256 + stride_b2: tl.constexpr = 512 + stride_b1: tl.constexpr = 512 * scales_b_shape2 + stride_b0: tl.constexpr = stride_b1 * scales_b_shape1 + scales_b_5d_desc = tl.make_tensor_descriptor( + scales_ptr, + [1, scales_b_shape1, scales_b_shape2, 2, 256], + [stride_b0, stride_b1, stride_b2, stride_b3, stride_b4], + [1, rep_n, rep_k, 2, 256] + ) + #B-scales if(channel_scale_mode == 4): scales_a_ptrs = scales_a_ptr + offs_am[:, None] * stride_meta_a_m + offs_k_scales[None, :] * stride_meta_a_g + + # _1s dtype must match actual scale dtype: uint8 for MXFP (E8M0), float8e4nv for NVFP4 (E4M3) + if group_size == 16: + scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) + #scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) + else: + scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + #scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + _meta_scale_norm = tl.load(meta_scale_norm_ptr, eviction_policy='evict_last') if group_size == 16 else 1.0 acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in tl.range(num_pid_k, num_stages=NUM_STAGES): - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) + # Load A and B tiles + if use_tma: + a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K_A]) + b = tl.load_tensor_descriptor(b_desc, [pid_n * BLOCK_SIZE_N, k * BLOCK_SIZE_K_B]).T + else: + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + b = tl.load(b_ptrs, eviction_policy=b_evict) + #################################################################################### k_m = k * 
BLOCK_SIZE_K_S - scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) - + if use_5d_scales: + # 5D TMA scale loads (preshuffled layout) + scale_b_raw = tl.load_tensor_descriptor(scales_b_5d_desc, [0, pid_n * rep_n, k * rep_k, 0, 0]) + scales_b = scale_b_raw.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) + else: + if EVEN_K: + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + else: + _scale_k_mask = ((offs_k_scales[None, :] + k_m) < (K // group_size)) + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, mask=_scale_k_mask, other=0.0, eviction_policy=meta_evict_policy) + if(channel_scale_mode == 4): - scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) + if EVEN_K: + scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) + else: + _scale_a_k_mask = ((offs_k_scales[None, :] + k_m) < (K // group_size)) + scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, mask=_scale_a_k_mask, other=0.0, eviction_policy=meta_evict_policy) else: - scales_a = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + scales_a = scales_a_1s + + #################################################################################### acc = tl.dot_scaled(a, scales_a, a_dtype, b, scales_b, b_dtype, acc) - a_ptrs += BLOCK_SIZE_K_A * stride_ak - b_ptrs += BLOCK_SIZE_K_B * stride_bk + if not use_tma: + a_ptrs += BLOCK_SIZE_K_A * stride_ak + b_ptrs += BLOCK_SIZE_K_B * stride_bk + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K) < K)).to(tl.int1) #NVFP4 meta-scale if(group_size == 16): - acc *= meta_scale_norm + acc = acc.to(tl.float32) * _meta_scale_norm ############################################################################################################# - #Channel-wise scaling - if(channel_scale_mode == 2): #activation-only 
- dtype: tl.constexpr = c_ptr.dtype.element_ty - scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=dtype) - acc = acc.to(dtype) * (scales_a[:, None] * scales_b[None, :]) - + #Channel-wise scaling + if channel_scale_mode == 2: # activation-only + scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1.0, eviction_policy=meta_evict_policy) + acc = acc * scales_a[:, None] + ############################################################################################################# #Output - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) - tl.store(c_ptrs, acc, mask=mask) + acc = acc.to(output_dtype) + if use_tma: + tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], value=acc) + else: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) + c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) + mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) + if EVEN_M and EVEN_N: + tl.store(c_ptrs, acc) + else: + tl.store(c_ptrs, acc, mask=mask) + +PRINTED = False def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, - W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, - input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, - channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, + W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, + input_dtype: int, output_dtype: int, acc_dtype: int, 
meta_dtype:int, + channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, + meta_scale: Tensor = None, ) -> Tensor: - M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] + + global PRINTED + from ..core import GEMLITE_USE_TMA + M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] # W M_CLOSEST = get_closest_m(M) - + #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype]) @@ -568,38 +722,252 @@ def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x if(is_mx_dtype(input_dtype)): gemm_kernel = gemm_MX_kernel load_scales_as_block = True + use_5d_scales = (scales.ndim == 5) else: gemm_kernel = gemm_INT_kernel load_scales_as_block = False + use_5d_scales = False gemm_kernel[grid]( - x, W_q, output, + x, W_q, output, scales, zeros, scales_x, M, N, K, M_CLOSEST, - ############################################# W_nbits, group_size, unpack_mask, elements_per_sample, type_id, x.dtype.itemsize, W_q.dtype.itemsize, - ############################################### x.stride(0), x.stride(1), W_q.stride(0), W_q.stride(1), output.stride(0), output.stride(1), stride_meta_a_m, stride_meta_a_g, - scales.stride(0), scales.stride(1), - ################################################ + 0 if use_5d_scales else scales.stride(0), 0 if use_5d_scales else scales.stride(1), load_scales_as_block = load_scales_as_block, input_dtype = DTYPE_TO_TRITON[input_dtype], output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], acc_dtype = DTYPE_TO_TRITON[acc_dtype], meta_dtype = DTYPE_TO_TRITON[meta_dtype], - ################################################ - channel_scale_mode = channel_scale_mode, - W_group_mode = W_group_mode, - zero_is_scalar = zeros.numel() == 1, - data_contiguous = data_contiguous, + channel_scale_mode = channel_scale_mode, + W_group_mode = W_group_mode, + zero_is_scalar = zeros.numel() == 1, + 
data_contiguous = data_contiguous, + use_tma = use_5d_scales, + use_5d_scales = use_5d_scales, + meta_scale_norm_ptr = meta_scale, ) - + return output + + +# @triton.autotune( +# configs = get_autotune_config(), +# key = KEYS, +# prune_configs_by = {'early_config_prune': kernel_config_pruner}, +# use_cuda_graph = AUTOTUNE.USE_CUDA_GRAPH, +# ) +# @triton.jit +# def gemm_INT_kernel_persistent_tma( +# a_ptr, b_ptr, c_ptr, +# scales_ptr, zeros_ptr, scales_a_ptr, +# M, N, K, M_CLOSEST, +# ######### Quant parms ######### +# W_nbits: tl.constexpr, +# group_size: tl.constexpr, +# unpack_mask: tl.constexpr, +# elements_per_sample: tl.constexpr, +# ################################# +# type_id: tl.constexpr, +# a_sizeof: tl.constexpr, +# b_sizeof: tl.constexpr, +# ######### Strides ######### +# stride_am: tl.constexpr, stride_ak: tl.constexpr, +# stride_bk: tl.constexpr, stride_bn: tl.constexpr, +# stride_cm: tl.constexpr, stride_cn: tl.constexpr, +# stride_meta_a_m: tl.constexpr, stride_meta_a_g: tl.constexpr, +# stride_meta_g: tl.constexpr, stride_meta_n: tl.constexpr, +# ######### Dtypes ######### +# load_scales_as_block: tl.constexpr, #False +# input_dtype: tl.constexpr, +# output_dtype: tl.constexpr, +# acc_dtype: tl.constexpr, +# meta_dtype: tl.constexpr, +# ######### Meta-data mode ######### +# channel_scale_mode: tl.constexpr, +# W_group_mode: tl.constexpr, +# zero_is_scalar: tl.constexpr, +# ######### tuning params ######### +# BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +# GROUP_SIZE_M: tl.constexpr, NUM_STAGES: tl.constexpr, +# ################################# +# EVEN_M: tl.constexpr = False, +# EVEN_K: tl.constexpr = False, +# EVEN_N: tl.constexpr = False, +# ################################# +# A_load_order: tl.constexpr = 0, +# data_contiguous: tl.constexpr = True, +# ################################# +# meta_evict_policy: tl.constexpr = '', +# a_evict: tl.constexpr = '', +# b_evict: tl.constexpr = '', +# NUM_SMS: 
tl.constexpr = 8, +# ): +# """ +# Persistent + TMA version. +# A: (M, K) fp16/bf16 +# B_packed: (K//elements_per_sample, N) int32 +# scales/zeros: (num_groups, N) or other depending on W_group_mode +# """ + +# # --------------------------- +# # Persistent tiling setup +# # --------------------------- +# start_pid = tl.program_id(0).to(tl.int32) + +# grid_m = tl.cdiv(M, BLOCK_SIZE_M) +# grid_n = tl.cdiv(N, BLOCK_SIZE_N) +# num_tiles = grid_m * grid_n +# width = GROUP_SIZE_M * grid_n # tiles per "group stripe" + +# a_desc = tl.make_tensor_descriptor( +# a_ptr, +# [M, K], +# [stride_am, stride_ak], +# [BLOCK_SIZE_M, BLOCK_SIZE_K] +# ) + +# # b_desc = tl.make_tensor_descriptor( +# # b_ptr, +# # [K, N], +# # [stride_bk, stride_bn], +# # [BLOCK_SIZE_K, BLOCK_SIZE_N] +# # ) + +# #transposed : use self.W_q = self.W_q.contiguous().t() +# b_desc = tl.make_tensor_descriptor( +# b_ptr, +# [N, K], +# [stride_bn, stride_bk], +# [BLOCK_SIZE_N, BLOCK_SIZE_K] +# ) + +# # # Precompute unpack shifts (vector length = elements_per_sample) +# # # shifts = [0, W_nbits, 2*W_nbits, ...] 
+# # shifts = (tl.arange(0, elements_per_sample) * W_nbits).to(tl.int32) + +# # # Optional scalar zero +# # if zero_is_scalar: +# # zero_scalar = tl.load(zeros_ptr, eviction_policy="evict_last") + +# ############################################################################################################# +# # Main loop +# for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): +# group_id = tile_id // width +# first_m = group_id * GROUP_SIZE_M +# gs = tl.minimum(grid_m - first_m, GROUP_SIZE_M) + +# pid_m = first_m + (tile_id % gs) +# pid_n = (tile_id % width) // gs + +# rm = pid_m * BLOCK_SIZE_M +# rn = pid_n * BLOCK_SIZE_N + +# # Accumulator +# acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + +# # Column indices for this tile (used for metadata + store) +# offs_n = rn + tl.arange(0, BLOCK_SIZE_N) +# n_mask = offs_n < N + +# # K loop +# for k in tl.range(0, K, BLOCK_SIZE_K): +# a = tl.load_tensor_descriptor(a_desc, [rm, k]) + +# k_packed = k // elements_per_sample +# #b = tl.load_tensor_descriptor(b_desc, [k_packed, rn]) +# b = tl.load_tensor_descriptor(b_desc, [rn, k_packed]).T #Transposed + +# acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype) + +# ############################################################################################################# +# # Channel-wise scaling +# offs_m = rm + tl.arange(0, BLOCK_SIZE_M) +# m_mask = offs_m < M +# if channel_scale_mode == 1: # weight-only +# # expects a 1D per-N scale at scales_ptr (same as your original) +# scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) +# acc = acc.to(meta_dtype) * scales_b[None, :] + +# if channel_scale_mode == 2: # activation-only +# scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, eviction_policy=meta_evict_policy) +# acc = acc.to(meta_dtype) * scales_a[:, None] + +# if channel_scale_mode == 3: # weight + activation +# scales_a = tl.load(scales_a_ptr + offs_m, mask=m_mask, other=1.0, 
eviction_policy=meta_evict_policy) +# scales_b = tl.load(scales_ptr + offs_n, mask=n_mask, other=1.0, eviction_policy=meta_evict_policy) +# acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + +# acc = acc.to(output_dtype) + +# ############################################################################################################# +# # Store +# c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn +# mask = (m_mask[:, None] & n_mask[None, :]).to(tl.int1) +# if EVEN_M and EVEN_N: +# tl.store(c_ptrs, acc) +# else: +# tl.store(c_ptrs, acc, mask=mask) + +# # Persistent version +# NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count +# def gemm_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_5d: Tensor, scales_x: Tensor, +# W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, +# input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, +# channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, +# ) -> Tensor: + +# M = x.shape[0] +# K = W_q.shape[0] * elements_per_sample +# N = W_q.shape[1] +# M_CLOSEST = get_closest_m(M) +# load_scales_as_block = False + +# output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype]) + +# if scales_x is not None: +# stride_meta_a_m, stride_meta_a_g = scales_x.stride(0), scales_x.stride(1) +# else: +# stride_meta_a_m, stride_meta_a_g = 0, 0 + +# grid = (NUM_SMS,) + +# gemm_INT_kernel_persistent_tma[grid]( +# x, W_q, output, +# scales, zeros, scales_x, +# M, N, K, M_CLOSEST, +# ############################################# +# W_nbits, group_size, unpack_mask, elements_per_sample, +# type_id, x.dtype.itemsize, W_q.dtype.itemsize, +# ############################################### +# x.stride(0), x.stride(1), +# W_q.stride(0), W_q.stride(1), +# output.stride(0), output.stride(1), +# stride_meta_a_m, stride_meta_a_g, +# scales.stride(0), scales.stride(1), +# 
################################################ +# load_scales_as_block = load_scales_as_block, +# input_dtype = DTYPE_TO_TRITON[input_dtype], +# output_dtype = TORCH_DTYPE_TO_TRITON[output.dtype], +# acc_dtype = DTYPE_TO_TRITON[acc_dtype], +# meta_dtype = DTYPE_TO_TRITON[meta_dtype], +# ################################################ +# channel_scale_mode = channel_scale_mode, +# W_group_mode = W_group_mode, +# zero_is_scalar = zeros.numel() == 1, +# data_contiguous = data_contiguous, +# NUM_SMS = NUM_SMS, +# ) + + +# return output + class gemm: kernel = [gemm_INT_kernel, gemm_MX_kernel] diff --git a/gemlite/triton_kernels/gemm_splitK_kernels.py b/gemlite/triton_kernels/gemm_splitK_kernels.py index 1576b92..125fa57 100755 --- a/gemlite/triton_kernels/gemm_splitK_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_kernels.py @@ -25,13 +25,14 @@ def kernel_config_pruner(configs, nargs, **kwargs): b_sizeof = nargs['b_sizeof'] #Check cache + load_scales_as_block = kwargs['load_scales_as_block'] if(MATMUL_TYPE in GEMLITE_TRITON_CONFIG_CACHE): signature = str(tuple([get_closest_m(m), n, k, g, e, t])) if(signature in GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE]): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) @@ -39,16 +40,26 @@ def kernel_config_pruner(configs, nargs, **kwargs): config.pop('reg_inc_consumer', None) config["NUM_STAGES"] = num_stages + config['EVEN_M'] = (m % config['BLOCK_SIZE_M'] == 0) + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + config['EVEN_K'] = (k % (config['BLOCK_SIZE_K'] * config.get('SPLIT_K', 1)) == 0) + + # Adjust 5D TMA compatibility for cached configs + if load_scales_as_block and n % 128 == 0 and (k // g) % 4 == 0: + config['BLOCK_SIZE_N'] = max(config['BLOCK_SIZE_N'], 128) + 
while (config['BLOCK_SIZE_K'] // g) % 4 != 0: + config['BLOCK_SIZE_K'] *= 2 + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + config['EVEN_K'] = (k % (config['BLOCK_SIZE_K'] * config.get('SPLIT_K', 1)) == 0) + yield triton.Config(config, num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero("c_ptr") if (config['SPLIT_K'] > 1) else None, ) - return - gpu_shared_memory = get_gpu_shared_memory() - load_scales_as_block = kwargs['load_scales_as_block'] + gpu_shared_memory = get_gpu_shared_memory() used = set() for config in configs: group_size_m = config.kwargs['GROUP_SIZE_M'] @@ -70,58 +81,54 @@ def kernel_config_pruner(configs, nargs, **kwargs): #Only use higher split_k values for smaller m if(m >= 32): split_k = min(split_k, 8) - #Constraint: BLOCK_SIZE_K >= group_size, only for load_as_block = False + #Constraints if(load_scales_as_block): - num_stages = max(num_stages, 2) #for dot_scaled kernels with pipelined loads if(e > 1): block_size_k = max(block_size_k, 64) #m16n8k64 else: block_size_k = max(block_size_k, 32) #m16n8k32 + # 5D TMA scale compatibility: adjust block sizes for 5D TMA descriptor + if n % 128 == 0 and (k // g) % 4 == 0: + block_size_n = max(block_size_n, 128) + while (block_size_k // g) % 4 != 0: + block_size_k *= 2 else: - block_size_k = min(block_size_k, g) + block_size_k = max(min(block_size_k, g), 32) #tl.dot minimum K block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) - - #Constraint: K needs to be divisible by BLOCK_SIZE_K * SPLIT_K - while split_k > 1 and not is_divisible(k, block_size_k * split_k): - #while split_k > 1 and k > block_size_k * split_k: - split_k //= 2 - - #Nvidia + split_k = max(split_k, 1) + if not IS_HIP: - if e > 1 and not load_scales_as_block: - #Limit num stages when data is packed - num_stages = min(num_stages, 4) - if(e == 1 and num_stages == 1): - #skip num_stages=1 for non-packed weights + if e == 1 and num_stages == 1: continue - #Avoid OOM - while num_stages 
> 0: #TODO: revisit MXFP case - shared_mem = (block_size_m * block_size_k * a_sizeof + block_size_k * block_size_n * b_sizeof) - if(e > 1 and not load_scales_as_block): - shared_mem += block_size_k * block_size_n * a_sizeof - shared_mem *= num_stages - if int(shared_mem) <= gpu_shared_memory: + # Reduce num_stages until config fits in shared memory + while num_stages > 1: + estimated_smem = estimate_shared_memory_per_block( + block_size_m, block_size_n, block_size_k, + a_sizeof, b_sizeof, num_stages, e, g, + load_scales_as_block + ) + if estimated_smem <= gpu_shared_memory: break num_stages -= 1 - if(num_stages == 0): continue #config too large - - ########################################### - if(load_scales_as_block):#tmp MXFP fix - block_size_k = min(block_size_k, 256) - ########################################### - key = (block_size_m, block_size_n, block_size_k, group_size_m, split_k, A_load_order, num_stages, num_warps) + even_m = (m % block_size_m == 0) + even_n = (n % block_size_n == 0) + even_k = (k % (block_size_k * split_k) == 0) + new_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": group_size_m, "SPLIT_K": split_k, + "EVEN_M": even_m, + "EVEN_N": even_n, + "EVEN_K": even_k, "A_load_order": A_load_order, "NUM_STAGES": num_stages, } @@ -146,7 +153,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): #These autotunes are optimized for batch-size 1 to 64 (!) 
def get_max_autotune_config_nvidia(): - stages = [1, 2, 4, 5] if gpu_has_more_shared_memory() else [1, 2, 4] + stages = [1, 2, 3, 4, 5] configs = [] for A in [0, 2]: for w in [4, 8]: @@ -167,35 +174,39 @@ def get_max_autotune_config_nvidia(): #Faster autotuner def get_fast_autotune_config_nvidia(): configs = [] - - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':512, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) - - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'SPLIT_K':8, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=5)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':512, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'SPLIT_K':8, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - 
configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=2)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':512, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) - - configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':32, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + #Small N tiles + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 
'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=5)) + #Medium N tiles (N=128 — workhorse for MX/INT types) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + #Large N tiles + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=2)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':256, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=8, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':32, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + #High split_k with wide N + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':32, 'SPLIT_K':8, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) + #Extra coverage + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, 
num_warps=4, num_stages=5)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=4)) + #Additional M=16 configs for MX kernel coverage + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=3)) + configs.append(triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=3)) + #M=32 tiles (for M=32..64 batch sizes) + configs.append(triton.Config({'BLOCK_SIZE_M':32, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':32, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':32, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=5)) + configs.append(triton.Config({'BLOCK_SIZE_M':32, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':256, 'SPLIT_K':4, 'GROUP_SIZE_M':8, 'A_load_order':0}, num_warps=4, num_stages=4)) + configs.append(triton.Config({'BLOCK_SIZE_M':32, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':128, 'SPLIT_K':2, 'GROUP_SIZE_M':8, 'A_load_order':2}, num_warps=8, num_stages=3)) return configs def get_default_config_nvidia(): - return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2)] + return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':64, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 
'NUM_STAGES':2}, num_warps=4, num_stages=2), triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':128, 'BLOCK_SIZE_K':128, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2),] ######################################################################################################################################################################## #AMD - Instinct MI300X @@ -247,7 +258,7 @@ def get_fast_autotune_config_amd(): return configs def get_default_config_amd(): - return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2)] + return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':64, 'SPLIT_K':1, 'GROUP_SIZE_M':8, 'A_load_order':0, 'NUM_STAGES':2}, num_warps=4, num_stages=2)] ######################################################################################################################################################################## if IS_HIP: @@ -278,7 +289,7 @@ def get_default_config_amd(): def gemm_splitK_INT_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, M_CLOSEST, + M, N: tl.constexpr, K: tl.constexpr, M_CLOSEST, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, @@ -305,15 +316,28 @@ def gemm_splitK_INT_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, NUM_STAGES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + ################################# + NUM_STAGES: tl.constexpr, A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = 
False, + EVEN_N: tl.constexpr = False, + ################################# meta_evict_policy: tl.constexpr = '', atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', b_evict: tl.constexpr = 'evict_first', + meta_scale_norm_ptr = None, + ################################# dmmy + use_tma: tl.constexpr = True, + use_5d_scales: tl.constexpr = False, ): """ Based on https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py @@ -356,7 +380,8 @@ def gemm_splitK_INT_kernel( offs_ak = offs_k offs_bk = offs_k - b_ptrs = b_ptr + ((offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) + b_ptrs = b_ptr + ((offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) + b_mask = (offs_bk[:, None] < K).to(tl.int1) q_shift = ((offs_bk % elements_per_sample) * W_nbits).to(tl.int32)[:, None] #Inputs @@ -379,13 +404,22 @@ def gemm_splitK_INT_kernel( for k in range(num_pid_k): - if(A_load_order == 0): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if(A_load_order == 0): #Early load + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) + if EVEN_K: + b = tl.load(b_ptrs, eviction_policy=b_evict) + else: + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) #Meta-data loading policy if(W_group_mode > 0): @@ -405,13 +439,19 @@ def gemm_splitK_INT_kernel( zeros = None if(A_load_order == 2): #Mid load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, 
eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) # Unpack and dequantize b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) if(A_load_order == 3): #Late load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) #Dot acc = tl.dot(a, b.to(input_dtype), acc=acc, out_dtype=acc_dtype) @@ -419,35 +459,45 @@ def gemm_splitK_INT_kernel( #Advance a_ptrs += BLOCK_SIZE_K_U * stride_ak b_ptrs += BLOCK_SIZE_K_P * stride_bk + + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K * SPLIT_K) < K)).to(tl.int1) + b_mask = ((offs_bk[:, None] + (k + 1) * BLOCK_SIZE_K_U) < K).to(tl.int1) ############################################################################################################# #Channel-wise scaling - if(channel_scale_mode == 1): #weight-only + if channel_scale_mode == 1: #weight-only scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) - acc = acc.to(meta_dtype) * scales_b[None, :] + acc = acc.to(meta_dtype) * scales_b[None, :] - if(channel_scale_mode == 2): #activation-only + if channel_scale_mode == 2: #activation-only scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + acc = acc.to(meta_dtype) * scales_a[:, None] - if(channel_scale_mode == 3): #weight + activation + if channel_scale_mode == 3: #weight + activation scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, 
eviction_policy=meta_evict_policy) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) ############################################################################################################# #Output + acc = acc.to(output_dtype) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) if(SPLIT_K > 1): - tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) + if EVEN_M and EVEN_N: + tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + else: + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) else: - tl.store(c_ptrs, acc, mask=mask) + if EVEN_M and EVEN_N: + tl.store(c_ptrs, acc) + else: + tl.store(c_ptrs, acc, mask=mask) @triton.autotune( configs=get_autotune_config(), @@ -459,7 +509,7 @@ def gemm_splitK_INT_kernel( def gemm_splitK_MX_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, M_CLOSEST, + M, N: tl.constexpr, K: tl.constexpr, M_CLOSEST, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, @@ -473,8 +523,10 @@ def gemm_splitK_MX_kernel( stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, - stride_meta_a_m: tl.constexpr, stride_meta_a_g: tl.constexpr, - stride_meta_n: tl.constexpr, stride_meta_g: tl.constexpr, + stride_meta_a_m: tl.constexpr, + stride_meta_a_g: tl.constexpr, + stride_meta_n: tl.constexpr, + stride_meta_g: tl.constexpr, ######### Dtypes ######### load_scales_as_block, #True | IF FALSE, RESTRICT BLOCK_SIZE_K <= 32 input_dtype: tl.constexpr, @@ -486,17 +538,28 @@ def gemm_splitK_MX_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, 
######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, NUM_STAGES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + NUM_STAGES: tl.constexpr, + ################################# A_load_order: tl.constexpr, data_contiguous: tl.constexpr, ################################# + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = False, + EVEN_N: tl.constexpr = False, + ################################# meta_evict_policy: tl.constexpr = 'evict_first', atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', b_evict: tl.constexpr = 'evict_first', - meta_scale_norm: tl.constexpr = (0.05 ** 2), + meta_scale_norm_ptr = None, ################################# + use_tma: tl.constexpr = True, + use_5d_scales: tl.constexpr = False, ): pid = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -536,69 +599,156 @@ def gemm_splitK_MX_kernel( offs_bk = pid_k * BLOCK_SIZE_K_B_E + tl.arange(0, BLOCK_SIZE_K_B_E) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) b_ptrs = b_ptr + offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn + b_mask = (offs_bk[:, None] < (K // elements_per_sample)) #Scales stride_mul: tl.constexpr = BLOCK_SIZE_K / group_size BLOCK_SIZE_K_S: tl.constexpr = BLOCK_SIZE_K // group_size offs_k_scales = tl.arange(0, BLOCK_SIZE_K_S) offs_n_b_scales = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - #B scales - scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #[BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] + #B scales: [BLOCK_SIZE_N, BLOCK_SIZE_K // group_size] + if not use_5d_scales: + scales_b_ptrs = scales_ptr + offs_n_b_scales[:, None] * stride_meta_n + offs_k_scales[None, :] * stride_meta_g #A scales if(channel_scale_mode == 4): 
scales_a_ptrs = scales_a_ptr + offs_am[:, None] * stride_meta_a_m + offs_k_scales[None, :] * stride_meta_a_g + if use_tma: + a_desc = tl.make_tensor_descriptor( + a_ptr, + [M, K // elements_per_sample_a], + [stride_am, stride_ak], + [BLOCK_SIZE_M, BLOCK_SIZE_K_A_E] + ) + + b_desc = tl.make_tensor_descriptor( + b_ptr, + [N, K // elements_per_sample], + [stride_bn, stride_bk], + [BLOCK_SIZE_N, BLOCK_SIZE_K_B_E] + ) + + c_desc = tl.make_tensor_descriptor( + c_ptr, + [M, N], + [stride_cm, stride_cn], + [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + # 5D TMA Descriptors for Scales (preshuffled layout) + if use_5d_scales: + rep_n: tl.constexpr = BLOCK_SIZE_N // 128 + rep_k: tl.constexpr = BLOCK_SIZE_K // group_size // 4 + stride_b4: tl.constexpr = 1 + stride_b3: tl.constexpr = 256 + stride_b2: tl.constexpr = 512 + stride_b1: tl.constexpr = 512 * (K // group_size // 4) + stride_b0: tl.constexpr = stride_b1 * (N // 128) + scales_b_5d_desc = tl.make_tensor_descriptor( + scales_ptr, + [1, N // 128, K // group_size // 4, 2, 256], + [stride_b0, stride_b1, stride_b2, stride_b3, stride_b4], + [1, rep_n, rep_k, 2, 256] + ) + + + if group_size == 16: + scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) + scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=1, dtype=tl.float32).to(tl.float8e4nv) + else: + scales_a_1s = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + scales_b_1s = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + + _meta_scale_norm = tl.load(meta_scale_norm_ptr, eviction_policy='evict_last') if group_size == 16 else 1.0 acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - for k in tl.range(num_pid_k, num_stages=NUM_STAGES): - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) + for k in tl.range(num_pid_k): + if use_tma: + a = tl.load_tensor_descriptor(a_desc, [pid_m * BLOCK_SIZE_M, (k * SPLIT_K + pid_k) * 
BLOCK_SIZE_K_A_E]) + b = tl.load_tensor_descriptor(b_desc, [pid_n * BLOCK_SIZE_N, (k * SPLIT_K + pid_k) * BLOCK_SIZE_K_B_E]).T + else: + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + + if EVEN_K: + b = tl.load(b_ptrs, eviction_policy=b_evict) + else: + b = tl.load(b_ptrs, mask=b_mask, other=0.0, eviction_policy=b_evict) #k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32) k_m = (k * SPLIT_K + pid_k) * BLOCK_SIZE_K_S #OK for BLOCK_SIZE_K >=group_size - scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + if use_5d_scales: + scale_b_raw = tl.load_tensor_descriptor(scales_b_5d_desc, [0, pid_n * rep_n, (k * SPLIT_K + pid_k) * rep_k, 0, 0]) + scales_b = scale_b_raw.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K_S) + else: + if EVEN_K: + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + else: + _scale_k_mask = ((offs_k_scales[None, :] + k_m) < (K // group_size)) + scales_b = tl.load(scales_b_ptrs + k_m * stride_meta_g, mask=_scale_k_mask, other=0.0, eviction_policy=meta_evict_policy) if(channel_scale_mode == 4): - scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) + if EVEN_K: + scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, eviction_policy=meta_evict_policy) + else: + _scale_a_k_mask = ((offs_k_scales[None, :] + k_m) < (K // group_size)) + scales_a = tl.load(scales_a_ptrs + k_m * stride_meta_a_g, mask=_scale_a_k_mask, other=0.0, eviction_policy=meta_evict_policy) else: - scales_a = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K_S), value=127, dtype=tl.uint8) + scales_a = scales_a_1s acc = tl.dot_scaled(a, scales_a, a_dtype, b, scales_b, b_dtype, acc) a_ptrs += BLOCK_SIZE_K_A * stride_ak b_ptrs += BLOCK_SIZE_K_B * stride_bk + + if not use_tma: + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & 
((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K_A) < (K // elements_per_sample_a))).to(tl.int1) + b_mask = ((offs_bk[:, None] + (k + 1) * BLOCK_SIZE_K_B) < (K // elements_per_sample)) #NVFP4 meta-scale if(group_size == 16): - acc *= meta_scale_norm + acc = acc.to(tl.float32) * _meta_scale_norm ############################################################################################################# - #Channel-wise scaling - if(channel_scale_mode == 2): #activation-only - dtype: tl.constexpr = c_ptr.dtype.element_ty - scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=dtype) - acc = acc.to(dtype) * (scales_a[:, None] * scales_b[None, :]) - + #Channel-wise scaling + if channel_scale_mode == 2: # activation-only + scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1.0, eviction_policy=meta_evict_policy) + acc = acc * scales_a[:, None] + ############################################################################################################# #Output + acc = acc.to(output_dtype) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) mask = ((offs_cm[:, None] < M) & (offs_cn[None, :] < N)).to(tl.int1) if(SPLIT_K > 1): - tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) + if EVEN_M and EVEN_N: + tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + else: + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) else: - tl.store(c_ptrs, acc, mask=mask) + if use_tma: + tl.store_tensor_descriptor(c_desc, [pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], value=acc) + else: + if EVEN_M and EVEN_N: + tl.store(c_ptrs, acc) + else: + tl.store(c_ptrs, acc, mask=mask) def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, 
elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id:int, + meta_scale: Tensor = None, ) -> Tensor: - M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] + from ..core import GEMLITE_USE_TMA + M, K, N = x.shape[0], W_q.shape[0] * elements_per_sample, W_q.shape[1] # W #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" M_CLOSEST = get_closest_m(M) @@ -616,12 +766,14 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s if(is_mx_dtype(input_dtype)): gemm_splitK_kernel = gemm_splitK_MX_kernel load_scales_as_block = True + use_5d_scales = (scales.ndim == 5) else: gemm_splitK_kernel = gemm_splitK_INT_kernel load_scales_as_block = False + use_5d_scales = False gemm_splitK_kernel[grid]( - x, W_q, output, + x, W_q, output, scales, zeros, scales_x, M, N, K, M_CLOSEST, ############################################# @@ -632,7 +784,7 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s W_q.stride(0), W_q.stride(1), output.stride(0), output.stride(1), stride_meta_a_m, stride_meta_a_g, - scales.stride(0), scales.stride(1), + 0 if use_5d_scales else scales.stride(0), 0 if use_5d_scales else scales.stride(1), ################################################ load_scales_as_block = load_scales_as_block, input_dtype = DTYPE_TO_TRITON[input_dtype], @@ -640,10 +792,13 @@ def gemm_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s acc_dtype = DTYPE_TO_TRITON[acc_dtype], meta_dtype = DTYPE_TO_TRITON[meta_dtype], ################################################ - channel_scale_mode = channel_scale_mode, - W_group_mode = W_group_mode, - zero_is_scalar = zeros.numel() == 1, - data_contiguous = data_contiguous, + channel_scale_mode = channel_scale_mode, + W_group_mode = W_group_mode, + zero_is_scalar = zeros.numel() == 1, + data_contiguous = data_contiguous, 
+ use_tma = use_5d_scales, + use_5d_scales = use_5d_scales, + meta_scale_norm_ptr = meta_scale, ) if(not native_atomic): diff --git a/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py b/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py index 2e9fb54..9808b1c 100755 --- a/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py +++ b/gemlite/triton_kernels/gemm_splitK_persistent_kernels.py @@ -32,7 +32,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) diff --git a/gemlite/triton_kernels/gemv_kernels.py b/gemlite/triton_kernels/gemv_kernels.py index ccbf306..89cd7a9 100755 --- a/gemlite/triton_kernels/gemv_kernels.py +++ b/gemlite/triton_kernels/gemv_kernels.py @@ -43,13 +43,15 @@ def kernel_config_pruner(configs, nargs, **kwargs): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) config.pop('reg_dec_producer', None) config.pop('reg_inc_consumer', None) - configs['NUM_STAGES'] = num_stages + config['NUM_STAGES'] = num_stages + + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) yield triton.Config(config, num_stages=num_stages, num_warps=num_warps, pre_hook=pre_hook) return @@ -62,7 +64,6 @@ def kernel_config_pruner(configs, nargs, **kwargs): #Constraints: BLOCK_SIZE_K <= group_size -> load_scales_as_block is always False for gemvs block_size_k = min(g, block_size_k) #Makes BLOCK_SIZE_K compatible with the group_size - block_size_k = next_power_of_2(block_size_k) block_size_n = 
next_power_of_2(block_size_n) @@ -80,6 +81,8 @@ def kernel_config_pruner(configs, nargs, **kwargs): num_stages = config.num_stages num_warps = config.num_warps + even_n = (n % block_size_n == 0) + key = (block_size_m, block_size_n, block_size_k, A_load_order, dot_prod_mode, num_stages, num_warps) new_config = { @@ -89,6 +92,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): 'A_load_order': A_load_order, 'dot_prod_mode': dot_prod_mode, 'NUM_STAGES': num_stages, + 'EVEN_N': even_n, } if IS_HIP: @@ -141,6 +145,8 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':64, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=2)) configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':64, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=2, num_stages=1)) + + configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':1024,'BLOCK_SIZE_K':32, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=1)) return configs @@ -232,7 +238,7 @@ def gemv_INT_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, mapping_ptr, - M, N, K, + M, N: tl.constexpr, K: tl.constexpr, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, @@ -241,11 +247,16 @@ def gemv_INT_kernel( type_id: tl.constexpr, use_prehook: tl.constexpr, ######### Strides ######### - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_meta_a_m, stride_meta_a_g, - stride_meta_g, stride_meta_n, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_meta_a_m: tl.constexpr, + stride_meta_a_g: tl.constexpr, + stride_meta_g: tl.constexpr, + stride_meta_n: tl.constexpr, ######### Dtypes ######### input_dtype: tl.constexpr, output_dtype: tl.constexpr, @@ -256,11 +267,14 @@ def gemv_INT_kernel( W_group_mode: tl.constexpr, 
zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - A_load_order: tl.constexpr, NUM_STAGES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + A_load_order: tl.constexpr, + NUM_STAGES: tl.constexpr, dot_prod_mode:tl.constexpr, data_contiguous: tl.constexpr, - dump_b_val: tl.constexpr = 0, #Improve accuracy mainly for A16W8 with post looop scaling + dump_b_val: tl.constexpr = 0, #Improve accuracy mainly for A16W8 with post loop scaling ##################################### meta_evict_policy: tl.constexpr = '', atomic_mode: tl.constexpr = 'relaxed', @@ -269,6 +283,7 @@ def gemv_INT_kernel( join_version: tl.constexpr = False, ################################# load_scales_as_block: tl.constexpr = False, + EVEN_N: tl.constexpr = False, ): """ GEMV for C = matmul(A, dequantize(B, scales, zeros)). This is optimized for M==1 @@ -307,15 +322,18 @@ def gemv_INT_kernel( #orig version b_ptrs = b_ptr + (offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn + a_mask = (offs_am[:, None] < M) & (offs_ak[None, :] < K).to(tl.int1) + b_mask = (offs_bk[:, None] < K) & (offs_bn[None, :] < N).to(tl.int1) + #TODO: add EVEN_K / EVEN_N check ################################################################### #Load if(A_load_order == 0): - a = tl.load(a_ptrs, eviction_policy=a_evict) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): - a = tl.load(a_ptrs, eviction_policy=a_evict) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) if(W_group_mode > 0): k_m = (pid_k * (BLOCK_SIZE_K / group_size)).to(tl.int32) @@ -334,7 +352,7 @@ def gemv_INT_kernel( zeros = None if(A_load_order == 2): - a = tl.load(a_ptrs, eviction_policy=a_evict) + a = 
tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) #tl.join() version if(join_version): @@ -351,13 +369,13 @@ def gemv_INT_kernel( b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) if(A_load_order == 3): - a = tl.load(a_ptrs, eviction_policy=a_evict) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) if(dump_b_val > 0): b = b.to(tl.float32) * dump_b_val - + #Dot product if(dot_prod_mode == 0): - acc = tl.sum(a.reshape((BLOCK_SIZE_K, 1), can_reorder=False).to(acc_dtype) * b.to(acc_dtype), axis=0, keep_dims=True) + acc = tl.sum((a.reshape((BLOCK_SIZE_K, 1), can_reorder=False).to(acc_dtype)) * b.to(acc_dtype), axis=0, keep_dims=True) if(dot_prod_mode == 1): acc = tl.sum(a.reshape((BLOCK_SIZE_K, 1), can_reorder=False) * b.to(input_dtype), axis=0, keep_dims=True) @@ -365,27 +383,32 @@ def gemv_INT_kernel( ################################################################## #Channel-wise scaling - if(channel_scale_mode == 1): #weight-only + if channel_scale_mode == 1: #weight-only scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * scales_b[None, :] - if(channel_scale_mode == 2): #activation-only + if channel_scale_mode == 2: #activation-only scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) + acc = acc.to(meta_dtype) * scales_a[:, None] - if(channel_scale_mode == 3): #weight + activation + if channel_scale_mode == 3: #weight + activation scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) 
#################################################################### - #Output: tl.atomic_add only supports 1D fp16 arrays, bfp16 would crash + #Output offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + if EVEN_N: + tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + else: + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) + + @triton.autotune( configs=get_autotune_config(), key = KEYS, @@ -399,7 +422,7 @@ def gemv_MX_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, mapping_ptr, - M, N, K, + M, N: tl.constexpr, K: tl.constexpr, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, @@ -408,11 +431,16 @@ def gemv_MX_kernel( type_id: tl.constexpr, use_prehook: tl.constexpr, ######### Strides ######### - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_meta_a_m, stride_meta_a_g, - stride_meta_g, stride_meta_n, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_meta_a_m: tl.constexpr, + stride_meta_a_g: tl.constexpr, + stride_meta_g: tl.constexpr, + stride_meta_n: tl.constexpr, ######### Dtypes ######### input_dtype: tl.constexpr, output_dtype: tl.constexpr, @@ -423,8 +451,11 @@ def gemv_MX_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - A_load_order: tl.constexpr, NUM_STAGES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + A_load_order: tl.constexpr, + 
NUM_STAGES: tl.constexpr, dot_prod_mode:tl.constexpr, data_contiguous: tl.constexpr, dump_b_val: tl.constexpr = 0, #Improve accuracy mainly for A16W8 with post looop scaling @@ -436,6 +467,7 @@ def gemv_MX_kernel( join_version: tl.constexpr = False, ################################# load_scales_as_block: tl.constexpr = False, + EVEN_N: tl.constexpr = False, ): """ GEMV for C = matmul(A, dequantize(B, scales, zeros)). This is optimized for M==1 @@ -474,6 +506,7 @@ def gemv_MX_kernel( a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak #[1, BLOCK_SIZE_K] b_ptrs = b_ptr + offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn #[BLOCK_SIZE_K, BLOCK_SIZE_N] a_mask = ((offs_am[:, None] < M) & (offs_ak[None, :] < (K // elements_per_sample_a))).to(tl.int1) + b_mask = ((offs_bk[:, None] < (K // elements_per_sample)) & (offs_bn[None, :] < N)).to(tl.int1) if(W_nbits == 4): #mxpf4 mapping mapping = tl.load(mapping_ptr + tl.arange(0, 16), eviction_policy='evict_last')[None, :].broadcast_to((BLOCK_SIZE_K, 16)) @@ -483,7 +516,7 @@ def gemv_MX_kernel( if(A_load_order == 0): a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) @@ -547,13 +580,18 @@ def gemv_MX_kernel( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + if EVEN_N: + tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + else: + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) #TODO: gemv not generating correct reuslts with mxfp dtypes use except for A16W4. 
def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x: Tensor, W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, + meta_scale: float = 0.0, ) -> Tensor: global KERNEL_CACHE @@ -591,8 +629,8 @@ def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x stride_meta_a_m, stride_meta_a_g = scales_x.stride(0), scales_x.stride(1) else: stride_meta_a_m, stride_meta_a_g = None, None - channel_scale_mode = 0 - + #channel_scale_mode = 0 + dtype = DTYPE_TO_TRITON[input_dtype] if(dtype in [tl.float16, tl.bfloat16, tl.float32]): acc_dtype = dtype @@ -629,7 +667,7 @@ def gemv_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, scales_x W_group_mode = W_group_mode, zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, - dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype == DType.FP16.value and W_nbits == 8) else 0, #Warning: Only use with INT8 + dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype == tl.float16 and W_nbits == 8) else 0, #Warning: Only use with INT8 ) if(not native_atomic): diff --git a/gemlite/triton_kernels/gemv_revsplitK_kernels.py b/gemlite/triton_kernels/gemv_revsplitK_kernels.py index bad0d28..628f35e 100755 --- a/gemlite/triton_kernels/gemv_revsplitK_kernels.py +++ b/gemlite/triton_kernels/gemv_revsplitK_kernels.py @@ -32,19 +32,21 @@ def kernel_config_pruner(configs, nargs, **kwargs): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) config.pop('reg_dec_producer', None) config.pop('reg_inc_consumer', None) + config['EVEN_N'] = (n % 
config['BLOCK_SIZE_N'] == 0) + yield triton.Config(config, num_stages=num_stages, num_warps=num_warps, pre_hook=pre_hook) return used = set() for config in configs: - block_size_m = 1 #Only 1 allowed here + block_size_m = 1 #next_power_of_2(m) #Only 1 allowed here block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) split_k = 2 @@ -59,7 +61,6 @@ def kernel_config_pruner(configs, nargs, **kwargs): block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) - #tmp fix autotune getting stuck on the MI300X if IS_HIP: if block_size_n * block_size_k >= 65536: @@ -68,14 +69,18 @@ def kernel_config_pruner(configs, nargs, **kwargs): #Since we load the scales / zeros once per split_k pass, we need this while block_size_k >= 8 and (block_size_k * split_k > g): block_size_k //= 2 + block_size_k = max(block_size_k, 8) - if(not (block_size_k * split_k <= g)): - continue + # if(not (block_size_k * split_k <= g)): + # continue #Block size should be compatible with minimum-packing if(block_size_k < e): continue + + even_n = (n % block_size_n == 0) + key = (block_size_m, block_size_n, block_size_k, A_load_order, dot_prod_mode, num_stages, num_warps) new_config = { @@ -84,6 +89,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): 'BLOCK_SIZE_K': block_size_k, 'A_load_order': A_load_order, 'dot_prod_mode': dot_prod_mode, + 'EVEN_N': even_n, } if IS_HIP: @@ -92,7 +98,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): if key in used: continue - + used.add(key) yield triton.Config(new_config, num_stages=num_stages, num_warps=num_warps, pre_hook=pre_hook) @@ -120,6 +126,10 @@ def get_max_autotune_config_nvidia(): #~20 sec/shape def get_fast_autotune_config_nvidia(): configs = [] + #Default + configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':16, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=1, num_stages=1)) + + #Extra 
configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':16, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=1, num_stages=1)) configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=2, num_stages=2)) @@ -138,6 +148,8 @@ def get_fast_autotune_config_nvidia(): configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':16, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=2)) configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':32, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=1)) configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':64, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=2)) + + configs.append(triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':1024, 'BLOCK_SIZE_K':32, 'A_load_order':0, 'dot_prod_mode':0}, num_warps=4, num_stages=1)) return configs @@ -227,7 +239,7 @@ def get_default_config_amd(): def gemv_INT_revsplitK_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, + M, N: tl.constexpr, K: tl.constexpr, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: tl.constexpr, @@ -236,10 +248,14 @@ def gemv_INT_revsplitK_kernel( type_id: tl.constexpr, use_prehook: tl.constexpr, ######### Strides ######### - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_meta_g, stride_meta_n, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_meta_g: tl.constexpr, + stride_meta_n: tl.constexpr, ######### Dtypes ######### input_dtype: tl.constexpr, output_dtype: tl.constexpr, @@ -250,7 +266,9 @@ def gemv_INT_revsplitK_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: 
tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, A_load_order: tl.constexpr, dot_prod_mode: tl.constexpr, data_contiguous: tl.constexpr, @@ -260,6 +278,7 @@ def gemv_INT_revsplitK_kernel( atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', b_evict: tl.constexpr = 'evict_first', + EVEN_N: tl.constexpr = False, ): """ GEMV for C = matmul(A, dequantize(B, scales, zeros)). This is optimized for M==1 @@ -291,6 +310,8 @@ def gemv_INT_revsplitK_kernel( a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) q_shift = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None] + a_mask = ((offs_am[:, None] < M) & (offs_k[None, :] < K)).to(tl.int1) + b_mask = ((offs_k[:, None] < K) & (offs_bn[None, :] < N)).to(tl.int1) #Stage 0: Load scales/zeros #----------------------------------------------------------------------------------------------------------- @@ -316,12 +337,12 @@ def gemv_INT_revsplitK_kernel( #----------------------------------------------------------------------------------------------------------- #Load if(A_load_order == 0): - a = tl.load(a_ptrs, eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) - b = tl.load(b_ptrs, eviction_policy=b_evict) + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): - a = tl.load(a_ptrs, eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) # Unpack and dequantize b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) @@ -338,16 
+359,18 @@ def gemv_INT_revsplitK_kernel( #Advance and load next chunk a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += (BLOCK_SIZE_K // elements_per_sample) * stride_bk + a_mask = ((offs_am[:, None] < M) & ((offs_k[None, :] + BLOCK_SIZE_K) < K)).to(tl.int1) + b_mask = (((offs_k[:, None] + BLOCK_SIZE_K) < K) & (offs_bn[None, :] < N)).to(tl.int1) #Stage 2 #----------------------------------------------------------------------------------------------------------- if(A_load_order == 0): - a = tl.load(a_ptrs, eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) - b = tl.load(b_ptrs, eviction_policy=b_evict) + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): - a = tl.load(a_ptrs, eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict).reshape((BLOCK_SIZE_K, 1), can_reorder=False) # Unpack and dequantize b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) @@ -364,16 +387,15 @@ def gemv_INT_revsplitK_kernel( if(dump_b_val > 0): acc /= dump_b_val ############################################################################################################ #Channel-wise scaling - if(channel_scale_mode == 1): #weight-only + if channel_scale_mode == 1: #weight-only scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * scales_b[None, :] - if(channel_scale_mode == 2): #activation-only + if channel_scale_mode == 2: #activation-only scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) - scales_b = tl.full((BLOCK_SIZE_N,), value=1, dtype=meta_dtype) - acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) - -
if(channel_scale_mode == 3): #weight + activation + acc = acc.to(meta_dtype) * scales_a[:, None] + + if channel_scale_mode == 3: #weight + activation scales_a = tl.load(scales_a_ptr + offs_am, mask=offs_am < M, other=1, eviction_policy=meta_evict_policy) scales_b = tl.load(scales_ptr + offs_bn, mask=offs_bn < N, other=1, eviction_policy=meta_evict_policy) acc = acc.to(meta_dtype) * (scales_a[:, None] * scales_b[None, :]) @@ -384,7 +406,11 @@ def gemv_INT_revsplitK_kernel( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + if EVEN_N: + tl.atomic_add(c_ptrs, acc, sem=atomic_mode) + else: + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode) KERNEL_CACHE = {} @@ -392,6 +418,7 @@ def gemv_revsplitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, + meta_scale: float = 0.0, ) -> Tensor: global KERNEL_CACHE @@ -453,7 +480,7 @@ def gemv_revsplitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor W_group_mode = W_group_mode, zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, - dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype in [DType.FP16.value] and W_nbits == 8) else 0, #Warning: Only use with INT8 + dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype == tl.float16 and W_nbits == 8) else 0, #Warning: Only use with INT8 ) if(not native_atomic): diff --git a/gemlite/triton_kernels/gemv_splitK_kernels.py b/gemlite/triton_kernels/gemv_splitK_kernels.py index 9e5ed40..bec539c 100755 --- 
a/gemlite/triton_kernels/gemv_splitK_kernels.py +++ b/gemlite/triton_kernels/gemv_splitK_kernels.py @@ -30,13 +30,17 @@ def kernel_config_pruner(configs, nargs, **kwargs): config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][signature]) num_stages = config.pop('num_stages') num_warps = config.pop('num_warps') - num_ctas = config.pop('num_ctas') + num_ctas = config.pop('num_ctas', 1) config.pop('num_buffers_warp_spec', None) config.pop('num_consumer_groups', None) config.pop('reg_dec_producer', None) config.pop('reg_inc_consumer', None) + config['EVEN_M'] = (m % config['BLOCK_SIZE_M'] == 0) + config['EVEN_N'] = (n % config['BLOCK_SIZE_N'] == 0) + config['EVEN_K'] = (k % config['BLOCK_SIZE_K'] == 0) + yield triton.Config(config, num_stages=num_stages, num_warps=num_warps, @@ -63,19 +67,24 @@ def kernel_config_pruner(configs, nargs, **kwargs): block_size_k = next_power_of_2(block_size_k) block_size_n = next_power_of_2(block_size_n) - #K needs to be divisible by BLOCK_SIZE_K * SPLIT_K: TODO: without this, cuda-graphs breaks. - while block_size_k > 16 and not is_divisible(k, block_size_k * split_k): - block_size_k //=2 + # #K needs to be divisible by BLOCK_SIZE_K * SPLIT_K: TODO: without this, cuda-graphs breaks. 
+ # while block_size_k > 16 and not is_divisible(k, block_size_k * split_k): + # block_size_k //=2 + # block_size_k = min(block_size_k, 16) - #Skip blocks that are either too large or too small - block_area = (block_size_k // split_k) * block_size_n - if(block_area < 1024 or block_area > 4096 * 8): #128 * 8 * num_warps - continue + # #Skip blocks that are either too large or too small + # block_area = (block_size_k // split_k) * block_size_n + # if(block_area < 1024 or block_area > 4096 * 8): #128 * 8 * num_warps + # continue #Block size should be compatible with minimum-packing if(block_size_k < e): continue + even_m = (m % block_size_m == 0) + even_n = (n % block_size_n == 0) + even_k = (k % block_size_k == 0) + key = (block_size_m, block_size_n, block_size_k, group_size_m, split_k, A_load_order, dot_prod_mode, num_stages, num_warps) new_config = { @@ -86,6 +95,9 @@ def kernel_config_pruner(configs, nargs, **kwargs): 'SPLIT_K' : split_k, 'A_load_order' : A_load_order, 'dot_prod_mode' : dot_prod_mode, + 'EVEN_M': even_m, + 'EVEN_N': even_n, + 'EVEN_K': even_k, } if IS_HIP: @@ -149,9 +161,7 @@ def get_fast_autotune_config_nvidia(): return configs def get_default_config_nvidia(): - config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':2, 'BLOCK_SIZE_K':2048, 'GROUP_SIZE_M':8, 'SPLIT_K': 1, - 'A_load_order':1, 'dot_prod_mode':0}, num_warps=4, num_stages=2) - + config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':2, 'BLOCK_SIZE_K':2048, 'GROUP_SIZE_M':8, 'SPLIT_K': 1, 'A_load_order':1, 'dot_prod_mode':0}, num_warps=4, num_stages=2) return [config] ######################################################################################################################################################################## @@ -241,7 +251,7 @@ def get_default_config_amd(): def gemv_INT_splitK_kernel( a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, scales_a_ptr, - M, N, K, + M, N: tl.constexpr, K: tl.constexpr, ######### Quant parms ######### W_nbits: tl.constexpr, group_size: 
tl.constexpr, @@ -249,10 +259,16 @@ def gemv_INT_splitK_kernel( elements_per_sample: tl.constexpr, type_id: tl.constexpr, ######### Strides ######### - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_meta_g, stride_meta_n, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_meta_a_m: tl.constexpr, + stride_meta_a_g: tl.constexpr, + stride_meta_g: tl.constexpr, + stride_meta_n: tl.constexpr, ######### Dtypes ######### input_dtype: tl.constexpr, output_dtype: tl.constexpr, @@ -263,13 +279,20 @@ def gemv_INT_splitK_kernel( W_group_mode: tl.constexpr, zero_is_scalar: tl.constexpr, ######### tuning params ######### - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, A_load_order: tl.constexpr, dot_prod_mode: tl.constexpr, data_contiguous: tl.constexpr, dump_b_val: tl.constexpr = 0, #Improve accuracy mainly for A16W8 with post looop scaling ################################# + EVEN_M: tl.constexpr = False, + EVEN_K: tl.constexpr = False, + EVEN_N: tl.constexpr = False, + ################################# meta_evict_policy: tl.constexpr = '', atomic_mode: tl.constexpr = 'relaxed', a_evict: tl.constexpr = 'evict_last', @@ -315,9 +338,10 @@ def gemv_INT_splitK_kernel( #Inputs a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak) - a_mask = ((offs_am[:, None] < M) & (offs_ak[None, :] < K)).to(tl.int1) b_ptrs = b_ptr + ((offs_bk[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) - + a_mask = ((offs_am[:, None] < M) & (offs_ak[None, :] < K)).to(tl.int1) + b_mask = ((offs_bk[:, None] < K) & (offs_bn[None, :] < N)).to(tl.int1) + #Meta data stuff q_shift = 
((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None] @@ -341,12 +365,21 @@ def gemv_INT_splitK_kernel( for k in range(num_pid_k): if(A_load_order == 0): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) - b = tl.load(b_ptrs, eviction_policy=b_evict) + if EVEN_K and EVEN_N: + b = tl.load(b_ptrs, eviction_policy=b_evict) + else: + b = tl.load(b_ptrs, mask=b_mask, other=0., eviction_policy=b_evict) if(A_load_order == 1): #Early load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) if(W_group_mode > 0): k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32) @@ -365,13 +398,19 @@ def gemv_INT_splitK_kernel( zeros = None if(A_load_order == 2): #Mid load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) # Unpack and dequantize b = dequantize(b, scales, zeros, q_shift, meta_dtype, unpack_mask, elements_per_sample, W_group_mode, zero_is_scalar) if(A_load_order == 3): #Late load - a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) + if EVEN_M and EVEN_K: + a = tl.load(a_ptrs, eviction_policy=a_evict) + else: + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy=a_evict) if(dump_b_val > 0): b = b.to(tl.float32) * dump_b_val @@ -383,6 +422,11 @@ def gemv_INT_splitK_kernel( #Advance a_ptrs += BLOCK_SIZE_K_U * stride_ak b_ptrs += BLOCK_SIZE_K_P * stride_bk + + #Update mask + if not EVEN_K: + a_mask = ((offs_am[:, None] < M) & ((offs_ak[None, :] + (k + 1) * BLOCK_SIZE_K_U) < K)).to(tl.int1) + b_mask = ((offs_bk[:, None] + (k + 1) * 
BLOCK_SIZE_K_U < K) & (offs_bn[None, :] < N)).to(tl.int1) if(dot_prod_mode == 0): acc = tl.sum(acc, axis=0, keep_dims=True) @@ -424,6 +468,7 @@ def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, input_dtype: int, output_dtype: int, acc_dtype: int, meta_dtype:int, channel_scale_mode: int, W_group_mode: int, data_contiguous: bool, type_id: int, + meta_scale: float = 0.0, ) -> Tensor: M, K, N = x.shape[0], x.shape[1], W_q.shape[1] @@ -437,6 +482,13 @@ def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K']) + device_index = W_q.device.index + + if(scales_x is not None): + stride_meta_a_m, stride_meta_a_g = scales_x.stride(0), scales_x.stride(1) + else: + stride_meta_a_m, stride_meta_a_g = None, None + dtype = DTYPE_TO_TRITON[input_dtype] if(dtype in [tl.float16, tl.bfloat16, tl.float32]): acc_dtype = dtype @@ -456,6 +508,7 @@ def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s x.stride(0), x.stride(1), W_q.stride(0), W_q.stride(1), output.stride(0), output.stride(1), + stride_meta_a_m, stride_meta_a_g, scales.stride(0), scales.stride(1), ################################################ input_dtype = DTYPE_TO_TRITON[input_dtype], @@ -467,7 +520,7 @@ def gemv_splitK_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, s W_group_mode = W_group_mode, zero_is_scalar = zeros.numel() == 1, data_contiguous = data_contiguous, - dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype == DType.FP16.value and W_nbits == 8) else 0, #Warning: Only use with INT8 + dump_b_val = 0.001 if(W_group_mode in [0, 1] and acc_dtype == tl.float16 and W_nbits == 8) else 0, #Warning: Only use with INT8 ) if(not native_atomic): @@ -481,4 +534,3 @@ class gemv_splitK: matmul_type = MATMUL_TYPE __all__ = 
["gemv_splitK"] - diff --git a/gemlite/triton_kernels/utils.py b/gemlite/triton_kernels/utils.py index 16b51ab..8e2ef73 100755 --- a/gemlite/triton_kernels/utils.py +++ b/gemlite/triton_kernels/utils.py @@ -6,6 +6,12 @@ from triton.runtime import driver from ..dtypes import * +# TMA descriptors require a global memory allocation +from typing import Optional +def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) +triton.set_allocator(alloc_fn) + @triton.jit def swizzle_tile_v1(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr): grid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -53,7 +59,6 @@ def linear_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexp pid_n = pid // tl.cdiv(M, BLOCK_SIZE_M) return pid_m, pid_n -################################################################################################################# @triton.jit def dequantize( b, @@ -108,9 +113,16 @@ def is_divisible(dividend, divisor): def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" -def gpu_has_more_shared_memory(ref_gpus = ["a100", "h100", "h200", "h20", "h800", "b100", "b200"]): +def gpu_has_more_shared_memory( + ref_gpus=( + "a100", + "h100", "h200", "h20", "h800", + "b100", "b200", "b300", + "6000", + ), +): gpu_name = torch.cuda.get_device_properties(0).name.lower() - return True in [g in gpu_name for g in ref_gpus] + return any(g in gpu_name for g in ref_gpus) def gpu_supports_float16_acc( ref_gpus=["5090", "5080", "5070", "5060", @@ -121,11 +133,34 @@ def gpu_supports_float16_acc( gpu_name = torch.cuda.get_device_properties(0).name.lower() return True in [g in gpu_name for g in ref_gpus] +def estimate_shared_memory_per_block(block_size_m, block_size_n, block_size_k, a_sizeof, b_sizeof, num_stages, e, g, load_scales_as_block): + a_smem = block_size_m * block_size_k * a_sizeof + if load_scales_as_block: + # MX kernels: 
dot_scaled handles scaling natively, no dequant buffer + # A tile: packed elements (e.g. NVFP4 packs 2 per byte, so K_A = K // e) + a_smem = block_size_m * (block_size_k // e) * a_sizeof + b_smem = (block_size_k // e) * block_size_n * b_sizeof + # scales_b: (BLOCK_N, BLOCK_K // group_size), scales_a: (BLOCK_M, BLOCK_K // group_size) + sb_smem = block_size_n * (block_size_k // g) * 1 + sa_smem = block_size_m * (block_size_k // g) * 1 + loop_smem = (a_smem + b_smem + sb_smem + sa_smem) * max(num_stages - 1, 1) + # Triton overlaps output buffer with loop data (reuses same SMEM) + output_smem = block_size_m * block_size_n * 2 # bf16 output via TMA store + estimated_smem = max(loop_smem, output_smem) + elif e > 1: + # INT packed: need packed B + dequantized B for MMA + b_smem = (block_size_k // e) * block_size_n * b_sizeof + b_smem += block_size_k * block_size_n * a_sizeof + estimated_smem = int((a_smem + b_smem) * num_stages * 1.20) + else: + # INT unpacked (8-bit): exact formula + b_smem = block_size_k * block_size_n * b_sizeof + estimated_smem = (a_smem + b_smem) * max(num_stages - 1, 1) + return estimated_smem def gpu_supports_bfloat16_atomicadd(): - #Triton tl.atomic_add doens't support bfloat16 even for Hopper and above. - #return torch.cuda.get_device_capability()[0] >= 9 #Hopper and above - return False + #Triton tl.atomic_add doens't support bfloat16 on older GPUs. + return torch.cuda.get_device_capability()[0] >= 9 #Hopper and above NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count def get_num_SMs(device): diff --git a/setup.py b/setup.py index 817eaf7..78c2d65 100755 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup, find_packages setup( name='gemlite', - version="0.5.1.post1", + version="0.6.0", url="https://github.com/mobiusml/gemlite/", author="Dr. 
Hicham Badri", author_email="hicham@mobiuslabs.com", diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py deleted file mode 100755 index bc0d942..0000000 --- a/tests/test_gemlitelineartriton.py +++ /dev/null @@ -1,383 +0,0 @@ -#python -m unittest test_gemlitelineartriton.py - -import unittest -import torch -from gemlite import reset_config, set_autotune -from gemlite.core import GemLiteLinearTriton, DType, TORCH_TO_DTYPE, forward_functional -from gemlite.triton_kernels.config import KERNEL -from gemlite.quant_utils import scale_activations_per_token_torch as scale_activations - -def is_fp8_supported(): - if not torch.cuda.is_available(): - return False - capability = torch.cuda.get_device_capability(0) - return capability >= (8, 9) - -device = 'cuda:0' -compute_dtype = torch.float16 #float16, bfloat16 -fp8_dtype = torch.float8_e4m3fn #float8_e4m3fn / torch.float8_e5m2 (Nvidia) -gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] -matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] -reset_config() -set_autotune(False) -KERNEL.ENABLE_CACHING = False - - -def gen_data(in_features, out_features, W_nbits, group_size, dtype=compute_dtype): - - W_q = torch.randint(0, 2**W_nbits - 1, (out_features, in_features), device=device).to(torch.uint8) - - shape = (out_features, in_features) - gs = W_q.numel() // group_size - scales = torch.ones((gs, 1), device=device, dtype=dtype) * 0.001 - zeros = torch.zeros((gs, 1), device=device, dtype=dtype) * ((2**W_nbits - 1)//2) - W = ((W_q.reshape([-1, group_size]) - zeros) * scales).to(fp8_dtype).to(dtype) - - zeros = torch.mean(W_q.reshape([-1, group_size]).float() - (W / scales).float(), axis=1, keepdim=True).to(dtype) - W = ((W_q.reshape([-1, group_size]).to(dtype) - zeros) * scales) - W = W.reshape(shape) - - return W, W_q, scales, zeros - - -in_features, out_features = 4096, 1024 -batch_sizes = [1, 4] -W_nbits, group_size = 4, 128 #128 / in_features -W, W_q, scales, zeros = 
gen_data(in_features, out_features, W_nbits=W_nbits, group_size=group_size) - -class TestGemLiteLinearTriton(unittest.TestCase): - - def test_serialization(self): - gemlite_linear = GemLiteLinearTriton(W_nbits, - group_size=group_size, - in_features=in_features, - out_features=out_features, - input_dtype=gemlite_dtype, - output_dtype=gemlite_dtype) - - - gemlite_linear.pack(W_q, scales, zeros, None) - - torch.save(gemlite_linear.state_dict(), 'tmp.pt') - - gemlite_linear_loaded = GemLiteLinearTriton() - gemlite_linear_loaded.load_state_dict(torch.load('tmp.pt')) - - ref_args = gemlite_linear.get_meta_args() - loaded_args = gemlite_linear_loaded.get_meta_args() - for i in range(len(ref_args)): - assert ref_args[i] == loaded_args[i], "meta_args mismatch at " + str(i) - - ref_args = gemlite_linear.get_tensor_args() - loaded_args = gemlite_linear_loaded.get_tensor_args() - for i in range(len(ref_args)): - assert (ref_args[i] - loaded_args[i]).float().abs().mean() == 0, "tensor_args mismatch at " + str(i) - - tol = 1e-7 - for batch_size in batch_sizes: - x = torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10. 
- for matmul_type in ['GEMM']: - if(batch_size>1 and 'GEMV' in matmul_type): continue - - y_ref = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - y_gem = gemlite_linear_loaded.forward_manual(x, matmul_type=matmul_type) - - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) - - def test_fp16xfp16(self): - gemlite_linear = GemLiteLinearTriton(W_nbits=16, - group_size=None, - in_features=in_features, - out_features=out_features, - input_dtype=gemlite_dtype, - output_dtype=gemlite_dtype, - scaled_activations=False) - - gemlite_linear.pack(W, None, None, None); - - #No weight unpacking / dequant - self.assertTrue(gemlite_linear.W_group_mode == 0 and gemlite_linear.channel_scale_mode == 0) - #Use non-contiguous when data is not packed - self.assertTrue(gemlite_linear.data_contiguous == False) - - tol = 1e-3 - - for batch_size in batch_sizes: - x = (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.) 
- y_ref = torch.matmul(x.to(compute_dtype), W.T) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) - - - def test_fp16xWn_asymmetric(self): - #FP16 x Wn / asymmetric - gemlite_linear = GemLiteLinearTriton(W_nbits, - group_size=group_size, - in_features=in_features, - out_features=out_features, - input_dtype=gemlite_dtype, - output_dtype=gemlite_dtype) - - - gemlite_linear.pack(W_q, scales, zeros, None); - - if(group_size == in_features): - #Weights are unpacked() then shift only if group_size == in_features (1) otherwise (3) - self.assertTrue((gemlite_linear.W_group_mode == 1 and gemlite_linear.channel_scale_mode == 1) or - (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 0)) - else: - self.assertTrue(gemlite_linear.W_group_mode in [3, 4] and gemlite_linear.channel_scale_mode == 0) - - #Use-contiguous when data is packed - self.assertTrue(gemlite_linear.data_contiguous == True) - - tol = 1e-3 - - for batch_size in batch_sizes: - x = torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10. 
- y_ref = torch.matmul(x.to(compute_dtype), W.T) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) - - - def test_int8xWn_symmetric_no_activation_scaling(self): - #INT8 x Wn - symmetric / no scaling activation scaling - - gemlite_linear = GemLiteLinearTriton(W_nbits, - group_size=group_size, - in_features=in_features, #only channelwise is supported - out_features=out_features, - input_dtype=DType.INT8, - output_dtype=DType.FP32, - scaled_activations=False) - - - _scales = torch.randn((out_features, 1), dtype=compute_dtype, device=device) * 1e-4 - gemlite_linear.pack(W_q, scales=_scales, zeros=7, bias=None); - - #Weights are unpacked() then shifted by 7 - self.assertTrue(gemlite_linear.W_group_mode == 1) - #Since the scales are channel-wise, we perform scaling post K-sum - self.assertTrue(gemlite_linear.channel_scale_mode == 1) - - tol = 1e-3 - - for batch_size in batch_sizes: - x = (torch.randint(-10, 10, (batch_size, in_features), device=device)).to(torch.int8) - y_ref = torch.matmul(x.to(compute_dtype), ((W_q.to(compute_dtype) - 7) * _scales).T) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) - - - def test_int8xWn_scaled_activations(self): - #INT8 x Wn - activation scaling only - - gemlite_linear = GemLiteLinearTriton(W_nbits=W_nbits, - group_size=group_size, - in_features=in_features, - out_features=out_features, - input_dtype=DType.INT8, - output_dtype=DType.FP32, - scaled_activations=True) - - - gemlite_linear.pack(W_q, scales=None, zeros=7, bias=None) - gemlite_linear.meta_dtype = DType.FP32 - - #Weights are unpacked() then 
shifted by 7 - self.assertTrue(gemlite_linear.W_group_mode == 1) - #Activations only are scaled - self.assertTrue(gemlite_linear.channel_scale_mode == 2) - - tol = 5e-3 - - for batch_size in batch_sizes: - x = torch.randn((batch_size, in_features), dtype=torch.float16, device=device) / 20. - _x, _x_scaled = scale_activations(x, w_dtype=torch.int8) - y_ref = torch.matmul(_x.to(torch.float16), (W_q.to(torch.float16) - 7).T) * _x_scaled - - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type) - - def test_int8Wn_scaled_weights_scaled_activations(self): - #INT8 x Wn - activation scaling only - - gemlite_linear = GemLiteLinearTriton(W_nbits=8, - group_size=in_features, #only channel-wise supported - in_features=in_features, - out_features=out_features, - input_dtype=DType.INT8, - output_dtype=DType.FP32, - scaled_activations=True) - - _scales = torch.randn((out_features, 1), dtype=compute_dtype, device=device) * 1e-4 - gemlite_linear.pack(W_q, scales=_scales, zeros=7, bias=None); - - #Weights are unpacked() then shifted by 7 if group_size == in_features (1), otherwise (3) - self.assertTrue(gemlite_linear.W_group_mode == 1) - #Activations only are scaled if group_size != in_features (2) otherwise bot are scales merged (3) - self.assertTrue(gemlite_linear.channel_scale_mode == 3) - - tol = 1e-3 - - for batch_size in batch_sizes: - shape = W_q.shape - x = torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10. 
- _x, _x_scaled = scale_activations(x, w_dtype=torch.int8) - y_ref = torch.matmul(_x.to(compute_dtype), ((W_q.to(compute_dtype) - 7) * _scales).T) * _x_scaled - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type) - - - @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_fp8xfp8(self): - #FP8 x FP8 - no scaling - - gemlite_linear = GemLiteLinearTriton(W_nbits=8, - group_size=None, - in_features=in_features, - out_features=out_features, - input_dtype=TORCH_TO_DTYPE[fp8_dtype], - output_dtype=gemlite_dtype, - scaled_activations=False) - - - gemlite_linear.pack(W.to(fp8_dtype), None, None, None) - - #No weight unpacking / dequant - self.assertTrue(gemlite_linear.W_group_mode == 0) - #No channel-wise scaling - self.assertTrue(gemlite_linear.channel_scale_mode == 0) - - tol = 5e-3 #needs higher tolerance with fp8 - - for batch_size in batch_sizes: - x = (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype) - y_ref = torch.matmul(x.to(compute_dtype), W.T) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) - - @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_fp8xfp8_scaled_weights_scaled_activations(self): - #FP8 x FP8 - both activations and weights are scaled - - gemlite_linear = GemLiteLinearTriton(W_nbits=8, - group_size=in_features, - in_features=in_features, - out_features=out_features, - input_dtype=TORCH_TO_DTYPE[fp8_dtype], - output_dtype=gemlite_dtype, - scaled_activations=True) - - - 
_scales = torch.randn((1, out_features), dtype=compute_dtype, device=device) * 1e-4 - gemlite_linear.pack(W.to(fp8_dtype), scales=_scales, zeros=None, bias=None); - - #No weight unpacking / dequant - self.assertTrue(gemlite_linear.W_group_mode == 0) - #Both activations and weights are scales - self.assertTrue(gemlite_linear.channel_scale_mode == 3) - - tol = 5e-3 #needs higher tolerance with fp8 - - for batch_size in batch_sizes: - shape = W.shape - x = torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10. - _x, scales_x = scale_activations(x, w_dtype=fp8_dtype) - - y_ref = torch.matmul(_x.to(compute_dtype), W.T) * (_scales * scales_x) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) - - @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_fp8xWn_scaled_activations(self): - #FP8 x Wn - asymmetric, with activation scaling - - gemlite_linear = GemLiteLinearTriton(W_nbits, - group_size=group_size, - in_features=in_features, - out_features=out_features, - input_dtype=TORCH_TO_DTYPE[fp8_dtype], - output_dtype=gemlite_dtype, - scaled_activations=True) - - - gemlite_linear.pack(W_q, scales, zeros, None); - - if(group_size == in_features): - #weight unpacking and shift if group_size == in_features else (3) - self.assertTrue((gemlite_linear.W_group_mode == 1) and (gemlite_linear.channel_scale_mode == 3) or - (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 2)) - else: - #activations and weights are scaled psot accumulation if group_size==in_features else (2) - self.assertTrue(gemlite_linear.W_group_mode in [3, 4]) - self.assertTrue(gemlite_linear.channel_scale_mode == 2) - - - tol = 5e-3 #needs higher tolerance with fp8 - - for batch_size in batch_sizes: - x = 
(torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype).to(compute_dtype) - _x, _scaled_x = scale_activations(x, w_dtype=fp8_dtype) - y_ref = torch.matmul(_x.to(compute_dtype), W.T) * _scaled_x - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) - - @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_fp8xWn_no_activation_scaling(self): - #FP8 x Wn - asymmetric, no activation scaling - - gemlite_linear = GemLiteLinearTriton(W_nbits, - group_size=group_size, - in_features=in_features, - out_features=out_features, - input_dtype=TORCH_TO_DTYPE[fp8_dtype], - output_dtype=gemlite_dtype, - scaled_activations=False) - - gemlite_linear.pack(W_q, scales, zeros, None) - - if(group_size == in_features): - #Weight shift only if group_size==in_features else (3) - self.assertTrue((gemlite_linear.W_group_mode == 1 and gemlite_linear.channel_scale_mode == 1) or - (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 0)) - else: - #weight scaling only - post accumulator if group_size==in_features else (0) - self.assertTrue(gemlite_linear.W_group_mode in [3, 4]) - self.assertTrue(gemlite_linear.channel_scale_mode == 0) - - tol = 5e-3 #needs higher tolerance with fp8 - - for batch_size in batch_sizes: - x = (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype) - y_ref = torch.matmul(x.to(compute_dtype), W.T) - for matmul_type in matmul_types: - if(batch_size>1 and 'GEMV' in matmul_type): continue - y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) - err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) diff --git a/tests/test_int.py b/tests/test_int.py 
new file mode 100755 index 0000000..1416fdc --- /dev/null +++ b/tests/test_int.py @@ -0,0 +1,367 @@ +# Usage: python3 test_gemlitelineartriton.py [--autotune] +import sys +_autotune = '--autotune' in sys.argv +if _autotune: sys.argv.remove('--autotune') + + +import unittest +import torch +from gemlite import reset_config, set_autotune +from gemlite.core import GemLiteLinearTriton, DType, TORCH_TO_DTYPE, forward_functional +from gemlite.triton_kernels.config import KERNEL +from gemlite.quant_utils import scale_activations_per_token_torch as scale_activations + +def is_fp8_supported(): + if not torch.cuda.is_available(): + return False + capability = torch.cuda.get_device_capability(0) + return capability >= (8, 9) + +device = 'cuda:0' +compute_dtype = torch.bfloat16 #float16, bfloat16 +fp8_dtype = torch.float8_e4m3fn #float8_e4m3fn / torch.float8_e5m2 (Nvidia) +gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] +matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMV_SPLITK', 'GEMM_SPLITK', 'GEMM'] + +reset_config() +if _autotune is False: set_autotune(False) +KERNEL.ENABLE_CACHING = False + +in_features, out_features = 4032, 2032 +batch_sizes = [1, 3, 5, 16, 30, 65, 100, 250] +W_nbits, group_size = 4, 128 #128 / in_features + +if group_size is None: + group_size = in_features +if group_size < in_features: + in_features = (in_features // group_size) * group_size #ensure divisibility for current implementation + +def gen_data(in_features, out_features, W_nbits, group_size, dtype=compute_dtype): + + W_q = torch.randint(0, 2**W_nbits - 1, (out_features, in_features), device=device).to(torch.uint8) + + shape = (out_features, in_features) + gs = W_q.numel() // group_size + scales = torch.ones((gs, 1), device=device, dtype=dtype) * 0.001 + zeros = torch.zeros((gs, 1), device=device, dtype=dtype) * ((2**W_nbits - 1)//2) + W = ((W_q.reshape([-1, group_size]) - zeros) * scales).to(fp8_dtype).to(dtype) + + zeros = torch.mean(W_q.reshape([-1, group_size]).float() - (W / scales).float(), 
axis=1, keepdim=True).to(dtype) + W = ((W_q.reshape([-1, group_size]).to(dtype) - zeros) * scales) + W = W.reshape(shape) + + return W, W_q, scales, zeros + +W, W_q, scales, zeros = gen_data(in_features, out_features, W_nbits=W_nbits, group_size=group_size) + +#Pre-cache data for faster processing +input_data = {} +for batch_size in batch_sizes: + torch.random.manual_seed(0) + input_data[batch_size] = torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10. + +class TestGemLiteLinearTriton(unittest.TestCase): + + def eval(self, gemlite_linear, ref_fn, tol: float = 1e-3, input_fn=None, _matmul_types=None): + """ + Shared evaluation method. + Args: + gemlite_linear: the quantized linear layer to test + ref_fn: callable(x) -> y_ref, computes the reference output + tol: error tolerance + input_fn: optional callable(batch_size) -> x, custom input generator. + If None, uses pre-cached input_data. + _matmul_types: optional list of matmul types to test. If None, uses global matmul_types. 
+ """ + if _matmul_types is None: + _matmul_types = matmul_types + + for batch_size in batch_sizes: + if input_fn is not None: + x = input_fn(batch_size) + else: + x = input_data[batch_size] + + y_ref = ref_fn(x) + + for matmul_type in _matmul_types: + if batch_size > 1 and 'GEMV' in matmul_type: + continue + y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) + err = (y_ref - y_gem).abs().mean().item() + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) + + def test_serialization(self): + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=gemlite_dtype, + output_dtype=gemlite_dtype) + + gemlite_linear.pack(W_q, scales, zeros, None) + + torch.save(gemlite_linear.state_dict(), 'tmp.pt') + + gemlite_linear_loaded = GemLiteLinearTriton() + gemlite_linear_loaded.load_state_dict(torch.load('tmp.pt')) + + ref_args = gemlite_linear.get_meta_args() + loaded_args = gemlite_linear_loaded.get_meta_args() + for i in range(len(ref_args)): + assert ref_args[i] == loaded_args[i], "meta_args mismatch at " + str(i) + + ref_args = gemlite_linear.get_tensor_args() + loaded_args = gemlite_linear_loaded.get_tensor_args() + for i in range(len(ref_args)): + if ref_args[i].numel() > 0: assert (ref_args[i] - loaded_args[i]).float().abs().mean() == 0, "tensor_args mismatch at " + str(i) + + def ref_fn(x): + return gemlite_linear.forward_manual(x, matmul_type='GEMM') + + self.eval(gemlite_linear_loaded, ref_fn, tol=1e-7, _matmul_types=['GEMM']) + + def test_fp16xfp16(self): + gemlite_linear = GemLiteLinearTriton(W_nbits=16, + group_size=None, + in_features=in_features, + out_features=out_features, + input_dtype=gemlite_dtype, + output_dtype=gemlite_dtype, + scaled_activations=False) + + gemlite_linear.pack(W, None, None, None) + + #No weight unpacking / dequant + self.assertTrue(gemlite_linear.W_group_mode == 0 and 
gemlite_linear.channel_scale_mode == 0) + #Use non-contiguous when data is not packed + self.assertTrue(gemlite_linear.data_contiguous == False) + + def ref_fn(x): + return torch.matmul(x.to(compute_dtype), W.T) + + self.eval(gemlite_linear, ref_fn, tol=2.5e-3) #higher tol for gemv kernels, otherwise 1e-3 is fine + + def test_fp16xWn_asymmetric(self): + #FP16 x Wn / asymmetric + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=gemlite_dtype, + output_dtype=gemlite_dtype) + + gemlite_linear.pack(W_q, scales, zeros, None) + + if(group_size == in_features): + #Weights are unpacked() then shift only if group_size == in_features (1) otherwise (3) + self.assertTrue((gemlite_linear.W_group_mode == 1 and gemlite_linear.channel_scale_mode == 1) or + (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 0)) + else: + self.assertTrue(gemlite_linear.W_group_mode in [3, 4] and gemlite_linear.channel_scale_mode == 0) + + #Use-contiguous when data is packed + self.assertTrue(gemlite_linear.data_contiguous == True) + + def ref_fn(x): + return torch.matmul(x.to(compute_dtype), W.T) + + self.eval(gemlite_linear, ref_fn, tol=1e-3) + + def test_int8xWn_symmetric_no_activation_scaling(self): + #INT8 x Wn - symmetric / no scaling activation scaling + + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, #only channelwise is supported + out_features=out_features, + input_dtype=DType.INT8, + output_dtype=DType.FP32, + scaled_activations=False) + + _scales = torch.randn((out_features, 1), dtype=compute_dtype, device=device) * 1e-4 + gemlite_linear.pack(W_q, scales=_scales, zeros=7, bias=None) + + #Weights are unpacked() then shifted by 7 + self.assertTrue(gemlite_linear.W_group_mode == 1) + #Since the scales are channel-wise, we perform scaling post K-sum + self.assertTrue(gemlite_linear.channel_scale_mode == 1) + + def 
input_fn(batch_size): + return (torch.randint(-10, 10, (batch_size, in_features), device=device)).to(torch.int8) + + def ref_fn(x): + return torch.matmul(x.to(compute_dtype), ((W_q.to(compute_dtype) - 7) * _scales).T) + + self.eval(gemlite_linear, ref_fn, tol=1e-3, input_fn=input_fn) + + def test_int8xWn_scaled_activations(self): + #INT8 x Wn - activation scaling only + + gemlite_linear = GemLiteLinearTriton(W_nbits=W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=DType.INT8, + output_dtype=DType.FP32, + scaled_activations=True) + + gemlite_linear.pack(W_q, scales=None, zeros=7, bias=None) + gemlite_linear.meta_dtype = DType.FP32 + + #Weights are unpacked() then shifted by 7 + self.assertTrue(gemlite_linear.W_group_mode == 1) + #Activations only are scaled + self.assertTrue(gemlite_linear.channel_scale_mode == 2) + + def input_fn(batch_size): + return torch.randn((batch_size, in_features), dtype=torch.float16, device=device) / 20. + + def ref_fn(x): + _x, _x_scaled = scale_activations(x, w_dtype=torch.int8) + return torch.matmul(_x.to(torch.float16), (W_q.to(torch.float16) - 7).T) * _x_scaled + + self.eval(gemlite_linear, ref_fn, tol=5e-3, input_fn=input_fn) + + def test_int8Wn_scaled_weights_scaled_activations(self): + #INT8 x Wn - activation scaling only + + gemlite_linear = GemLiteLinearTriton(W_nbits=8, + group_size=in_features, #only channel-wise supported + in_features=in_features, + out_features=out_features, + input_dtype=DType.INT8, + output_dtype=DType.FP32, + scaled_activations=True) + + _scales = torch.randn((out_features, 1), dtype=compute_dtype, device=device) * 1e-4 + gemlite_linear.pack(W_q, scales=_scales, zeros=7, bias=None) + + #Weights are unpacked() then shifted by 7 if group_size == in_features (1), otherwise (3) + self.assertTrue(gemlite_linear.W_group_mode == 1) + #Activations only are scaled if group_size != in_features (2) otherwise bot are scales merged (3) + 
self.assertTrue(gemlite_linear.channel_scale_mode == 3) + + def ref_fn(x): + _x, _x_scaled = scale_activations(x, w_dtype=torch.int8) + return torch.matmul(_x.to(compute_dtype), ((W_q.to(compute_dtype) - 7) * _scales).T) * _x_scaled + + self.eval(gemlite_linear, ref_fn, tol=1e-3) + + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_fp8xfp8(self): + #FP8 x FP8 - no scaling + + gemlite_linear = GemLiteLinearTriton(W_nbits=8, + group_size=None, + in_features=in_features, + out_features=out_features, + input_dtype=TORCH_TO_DTYPE[fp8_dtype], + output_dtype=gemlite_dtype, + scaled_activations=False) + + gemlite_linear.pack(W.to(fp8_dtype), None, None, None) + + #No weight unpacking / dequant + self.assertTrue(gemlite_linear.W_group_mode == 0) + #No channel-wise scaling + self.assertTrue(gemlite_linear.channel_scale_mode == 0) + + def input_fn(batch_size): + return (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype) + + def ref_fn(x): + return torch.matmul(x.to(compute_dtype), W.T) + + self.eval(gemlite_linear, ref_fn, tol=5e-3, input_fn=input_fn) #needs higher tolerance with fp8 + + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_fp8xfp8_scaled_weights_scaled_activations(self): + #FP8 x FP8 - both activations and weights are scaled + + gemlite_linear = GemLiteLinearTriton(W_nbits=8, + group_size=in_features, + in_features=in_features, + out_features=out_features, + input_dtype=TORCH_TO_DTYPE[fp8_dtype], + output_dtype=gemlite_dtype, + scaled_activations=True) + + _scales = torch.randn((1, out_features), dtype=compute_dtype, device=device) * 1e-4 + gemlite_linear.pack(W.to(fp8_dtype), scales=_scales, zeros=None, bias=None) + + #No weight unpacking / dequant + self.assertTrue(gemlite_linear.W_group_mode == 0) + #Both activations and weights are scaled + self.assertTrue(gemlite_linear.channel_scale_mode == 3) + + def ref_fn(x): + _x, 
scales_x = scale_activations(x, w_dtype=fp8_dtype) + return torch.matmul(_x.to(compute_dtype), W.T) * (_scales * scales_x) + + self.eval(gemlite_linear, ref_fn, tol=5e-3) #needs higher tolerance with fp8 + + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_fp8xWn_scaled_activations(self): + #FP8 x Wn - asymmetric, with activation scaling + + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=TORCH_TO_DTYPE[fp8_dtype], + output_dtype=gemlite_dtype, + scaled_activations=True) + + gemlite_linear.pack(W_q, scales, zeros, None) + + if(group_size == in_features): + #weight unpacking and shift if group_size == in_features else (3) + self.assertTrue((gemlite_linear.W_group_mode == 1) and (gemlite_linear.channel_scale_mode == 3) or + (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 2)) + else: + #activations and weights are scaled post accumulation if group_size==in_features else (2) + self.assertTrue(gemlite_linear.W_group_mode in [3, 4]) + self.assertTrue(gemlite_linear.channel_scale_mode == 2) + + def input_fn(batch_size): + return (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype).to(compute_dtype) + + def ref_fn(x): + _x, _scaled_x = scale_activations(x, w_dtype=fp8_dtype) + return torch.matmul(_x.to(compute_dtype), W.T) * _scaled_x + + self.eval(gemlite_linear, ref_fn, tol=5e-3, input_fn=input_fn) #needs higher tolerance with fp8 + + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_fp8xWn_no_activation_scaling(self): + #FP8 x Wn - asymmetric, no activation scaling + + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=TORCH_TO_DTYPE[fp8_dtype], + output_dtype=gemlite_dtype, + scaled_activations=False) + + gemlite_linear.pack(W_q, 
scales, zeros, None) + + if(group_size == in_features): + #Weight shift only if group_size==in_features else (3) + self.assertTrue((gemlite_linear.W_group_mode == 1 and gemlite_linear.channel_scale_mode == 1) or + (gemlite_linear.W_group_mode == 3 and gemlite_linear.channel_scale_mode == 0)) + else: + #weight scaling only - post accumulator if group_size==in_features else (0) + self.assertTrue(gemlite_linear.W_group_mode in [3, 4]) + self.assertTrue(gemlite_linear.channel_scale_mode == 0) + + def input_fn(batch_size): + return (torch.randn((batch_size, in_features), dtype=compute_dtype, device=device) / 10.).to(fp8_dtype) + + def ref_fn(x): + return torch.matmul(x.to(compute_dtype), W.T) + + self.eval(gemlite_linear, ref_fn, tol=5e-3, input_fn=input_fn) #needs higher tolerance with fp8 +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py index 950009d..2ca3250 100644 --- a/tests/test_mxfp.py +++ b/tests/test_mxfp.py @@ -1,4 +1,7 @@ -#python -m unittest test_mxfp.py +# Usage: python3 test_mxfp.py [--autotune] +import sys +_autotune = '--autotune' in sys.argv +if _autotune: sys.argv.remove('--autotune') import unittest import torch @@ -14,18 +17,22 @@ def is_fp8_supported(device_index=0): device = 'cuda:0' compute_dtype = torch.bfloat16 #float16, bfloat16 -matmul_types = ['GEMM_SPLITK', 'GEMM'] #TODO: add GEMV use-cases +matmul_types = ['GEMM', 'GEMM_SPLITK'] #TODO: improve GEMV mxfp accuracy. + reset_config() -set_autotune(False) +if _autotune is False: set_autotune(False) KERNEL.ENABLE_CACHING = False torch.random.manual_seed(0) -in_features, out_features = 4096, 2048 -batch_sizes = [1, 4, 16] +in_features, out_features = 4224, 2048 # test 5D TMA +#in_features, out_features = 4032, 2048 # test 2D scales fall-back +batch_sizes = [1, 3, 16, 30, 32, 60, 100, 128] linear_layer = torch.nn.Linear(in_features=in_features, out_features=out_features, device=device, dtype=compute_dtype, bias=False) linear_layer.weight.data /= 10. 
linear_layer.weight.requires_grad = False +assert in_features % 32 == 0, "in_features must be divisible by 32 for the current implementation" + #Pre-cache data for faster processing input_data = {} for batch_size in batch_sizes: @@ -38,17 +45,25 @@ def eval(self, gemlite_linear, tol: float = 1e-3): x = input_data[batch_size] y_ref = linear_layer(x) for matmul_type in matmul_types: + if(batch_size>1 and 'GEMV' in matmul_type): continue y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) err = (y_ref - y_gem).abs().mean().item() - self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol) + ' | ' + matmul_type + ' | batch_size: ' + str(batch_size)) + + # @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + # def test_A16W8_MXFP(self): + # gemlite_linear = A16W8_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + # self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features)) + # self.assertTrue(not gemlite_linear.scaled_activations) + # self.eval(gemlite_linear, tol = 2e-4) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") - def test_A16W8_MXFP(self): - gemlite_linear = A16W8_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + def test_A8W8_MXFP_post_scale_dynamic(self): + gemlite_linear = A8W8_MXFP_dynamic(device=device, dtype=compute_dtype, post_scale=True).from_linear(linear_layer, del_orig=False) self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features)) - self.assertTrue(not gemlite_linear.scaled_activations) + self.assertTrue(gemlite_linear.scaled_activations) self.eval(gemlite_linear, tol = 2e-4) - + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_A8W8_MXFP_dynamic(self): gemlite_linear = 
A8W8_MXFP_dynamic(device=device, dtype=compute_dtype, post_scale=False).from_linear(linear_layer, del_orig=False) @@ -56,11 +71,11 @@ def test_A8W8_MXFP_dynamic(self): self.assertTrue(gemlite_linear.scaled_activations) self.eval(gemlite_linear, tol = 2e-4) - def test_A16W4_MXFP(self): - gemlite_linear = A16W4_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) - self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) - self.assertTrue(not gemlite_linear.scaled_activations) - self.eval(gemlite_linear, tol = 7e-4) + # def test_A16W4_MXFP(self): + # gemlite_linear = A16W4_MXFP(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + # self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) + # self.assertTrue(not gemlite_linear.scaled_activations) + # self.eval(gemlite_linear, tol = 7e-4) @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") def test_A8W4_MXFP_dynamic(self): @@ -76,9 +91,12 @@ def test_A4W4_MXFP_dynamic(self): self.eval(gemlite_linear, tol = 1e-3) def test_A4W4_NVFP_dynamic(self): - gemlite_linear = A4W4_MXFP_dynamic(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) + gemlite_linear = A4W4_NVFP_dynamic(device=device, dtype=compute_dtype).from_linear(linear_layer, del_orig=False) self.assertTrue(gemlite_linear.W_q.numel() * gemlite_linear.W_q.itemsize == (in_features * out_features // 2)) self.assertTrue(gemlite_linear.scaled_activations) - self.eval(gemlite_linear, tol = 1e-3) + self.eval(gemlite_linear, tol = 2e-3) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_serialization.py b/tests/test_serialization.py new file mode 100644 index 0000000..ebcf4f5 --- /dev/null +++ b/tests/test_serialization.py @@ -0,0 +1,134 @@ +# Usage: python3 test_serialization.py [--autotune] +import sys +_autotune = '--autotune' in 
sys.argv +if _autotune: sys.argv.remove('--autotune') + +import unittest +import torch +from gemlite import reset_config, set_autotune +from gemlite.core import GemLiteLinearTriton, DType, TORCH_TO_DTYPE +from gemlite.triton_kernels.config import KERNEL +from gemlite.helper import A4W4_MXFP_dynamic, A4W4_NVFP_dynamic, patch_model + +def is_fp8_supported(): + if not torch.cuda.is_available(): + return False + capability = torch.cuda.get_device_capability(0) + return capability >= (8, 9) + +device = 'cuda:0' +compute_dtype = torch.bfloat16 +gemlite_dtype = TORCH_TO_DTYPE[compute_dtype] + +reset_config() +if _autotune is False: set_autotune(False) +KERNEL.ENABLE_CACHING = False + +def _check_serialization(test_case, gemlite_linear, matmul_type='GEMM', batch_size=32, tol=1e-7): + """Shared serialization round-trip check.""" + in_features = gemlite_linear.in_features + + torch.save(gemlite_linear.state_dict(), '/tmp/_test_serial.pt') + + loaded = GemLiteLinearTriton() + loaded.load_state_dict(torch.load('/tmp/_test_serial.pt')) + + # Check meta_args match + ref_meta = gemlite_linear.get_meta_args() + loaded_meta = loaded.get_meta_args() + for i in range(len(ref_meta)): + test_case.assertEqual(ref_meta[i], loaded_meta[i], f"meta_args mismatch at {i}: {ref_meta[i]} != {loaded_meta[i]}") + + # Check tensor_args match + ref_tensors = gemlite_linear.get_tensor_args() + loaded_tensors = loaded.get_tensor_args() + for i in range(len(ref_tensors)): + if ref_tensors[i].numel() > 0: + diff = (ref_tensors[i].float() - loaded_tensors[i].float()).abs().mean().item() + test_case.assertEqual(diff, 0, f"tensor_args mismatch at {i}: mean diff = {diff}") + + # Check inference matches + x = torch.randn(batch_size, in_features, dtype=compute_dtype, device=device) / 10. 
+ y_ref = gemlite_linear.forward_manual(x, matmul_type=matmul_type) + y_loaded = loaded.forward_manual(x, matmul_type=matmul_type) + diff = (y_ref - y_loaded).abs().mean().item() + test_case.assertTrue(diff < tol, f"Inference mismatch: mean diff = {diff}, expected < {tol}") + + +class TestSerializationINT(unittest.TestCase): + """Serialization tests for INT quantized layers.""" + + def test_A16W4(self): + in_features, out_features = 4096, 2048 + W_nbits, group_size = 4, 128 + + W_q = torch.randint(0, 2**W_nbits - 1, (out_features, in_features), device=device).to(torch.uint8) + gs = W_q.numel() // group_size + scales = torch.ones((gs, 1), device=device, dtype=compute_dtype) * 0.001 + zeros = torch.zeros((gs, 1), device=device, dtype=compute_dtype) * ((2**W_nbits - 1)//2) + + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=gemlite_dtype, + output_dtype=gemlite_dtype) + gemlite_linear.pack(W_q, scales, zeros, None) + + _check_serialization(self, gemlite_linear) + + @unittest.skipIf(not is_fp8_supported(), "Skipping test: GPU does not support FP8") + def test_A8W8(self): + in_features, out_features = 4096, 2048 + fp8_dtype = torch.float8_e4m3fn + + W = torch.randn((out_features, in_features), dtype=compute_dtype, device=device) / 10. 
+ _scales = torch.randn((1, out_features), dtype=compute_dtype, device=device) * 1e-4 + + gemlite_linear = GemLiteLinearTriton(W_nbits=8, + group_size=in_features, + in_features=in_features, + out_features=out_features, + input_dtype=TORCH_TO_DTYPE[fp8_dtype], + output_dtype=gemlite_dtype, + scaled_activations=True) + gemlite_linear.pack(W.to(fp8_dtype), scales=_scales, zeros=None, bias=None) + + _check_serialization(self, gemlite_linear) + + +class TestSerializationMX(unittest.TestCase): + """Serialization tests for MXFP/NVFP quantized layers.""" + + def setUp(self): + self.in_features, self.out_features = 4224, 2048 + torch.manual_seed(42) + self.linear_layer = torch.nn.Linear( + self.in_features, self.out_features, dtype=compute_dtype, device=device, bias=False + ) + self.linear_layer.weight.data /= 10. + self.linear_layer.weight.requires_grad = False + + def _quantize(self, processor_fn): + model = torch.nn.Sequential( + torch.nn.Linear(self.in_features, self.out_features, dtype=compute_dtype, device=device, bias=False) + ) + model.requires_grad_(False) + model[0].weight.data = self.linear_layer.weight.data.clone() + processor = processor_fn(dtype=compute_dtype) + patch_model(model, device=device, processor=processor) + return model[0] + + def test_A4W4_MXFP(self): + gemlite_linear = self._quantize(A4W4_MXFP_dynamic) + _check_serialization(self, gemlite_linear, matmul_type='GEMM') + _check_serialization(self, gemlite_linear, matmul_type='GEMM_SPLITK', batch_size=2) + + def test_A4W4_NVFP(self): + gemlite_linear = self._quantize(A4W4_NVFP_dynamic) + _check_serialization(self, gemlite_linear, matmul_type='GEMM') + _check_serialization(self, gemlite_linear, matmul_type='GEMM_SPLITK', batch_size=2) + + +if __name__ == '__main__': + unittest.main()