Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
e588542
update
mobicham Feb 23, 2026
2b9cd59
update
mobicham Feb 23, 2026
baeda23
update
mobicham Feb 23, 2026
0dc7d0e
update
mobicham Feb 23, 2026
da98055
update
mobicham Feb 24, 2026
fc18161
update
mobicham Feb 24, 2026
0d02f97
update
mobicham Feb 24, 2026
1a66408
update
mobicham Feb 24, 2026
590ee0a
update
mobicham Feb 24, 2026
cf124c6
update
mobicham Feb 24, 2026
0715390
update
mobicham Mar 2, 2026
5a6ce1d
update
mobicham Mar 2, 2026
3095cdf
Fix tests
mobicham Mar 2, 2026
f625b98
fix more test
mobicham Mar 3, 2026
23d88d9
add eval flops script
mobicham Mar 3, 2026
31c8c40
update store masking
mobicham Mar 3, 2026
488e494
update configs
mobicham Mar 6, 2026
afe4105
update configs
mobicham Mar 6, 2026
7aee361
update configs
mobicham Mar 6, 2026
1869027
fix
mobicham Mar 6, 2026
48d03e8
fix
mobicham Mar 6, 2026
dcd6c79
fix
mobicham Mar 6, 2026
16a5078
fix
mobicham Mar 6, 2026
ffbb01b
use tma for a,b,c mx
mobicham Mar 6, 2026
a4a753a
use tma for a,b,c mx
mobicham Mar 6, 2026
2b6ce21
update scales
mobicham Mar 6, 2026
edb6fc3
tma stable
mobicham Mar 6, 2026
6dfcd1b
add mxfp8 v4 activation quant
mobicham Mar 6, 2026
22b074b
add mxfp4/nvfp4 v3 activation quant
mobicham Mar 6, 2026
f686e4e
add flashinfer nvfp4 benchmark
mobicham Mar 7, 2026
dd3c4f7
update mxfp/nvfp activation quant kernels
mobicham Mar 7, 2026
a0bdcdc
add 5d tma attempt
mobicham Mar 7, 2026
597d99c
remove 5d scales duplicate
mobicham Mar 8, 2026
01227af
clean-up
mobicham Mar 8, 2026
67cf317
fix tests with autotune
mobicham Mar 8, 2026
8d87c45
enable tma for splitK
mobicham Mar 8, 2026
f3a2f3e
fix mx autotune config test
mobicham Mar 8, 2026
1f11362
prune activation quant configs
mobicham Mar 8, 2026
bcf689f
update
mobicham Mar 9, 2026
8aaff4b
fix mxfp8 activation quant spill over
mobicham Mar 9, 2026
a0d98ac
set tma flag
mobicham Mar 9, 2026
ff7114f
gemv fixes and tests
mobicham Mar 9, 2026
525e655
update shared memory estimate
mobicham Mar 10, 2026
3059416
fix bugs
mobicham Mar 10, 2026
50ff07c
clean-up
mobicham Mar 10, 2026
dd7abea
update version
mobicham Mar 10, 2026
909f6a3
update
mobicham Mar 11, 2026
d70b9f4
cleanup
mobicham Mar 11, 2026
ef1cdd7
add ptx packing for mxfp4/nvfp4
mobicham Mar 11, 2026
405e201
use default ptx false
mobicham Mar 11, 2026
ebc0b1b
add todo
mobicham Mar 11, 2026
bf3b514
improve M=64 perf
mobicham Mar 19, 2026
846d0ef
add hints to mx non-tma path
mobicham Mar 19, 2026
4631e2f
add ptx_pack=True quant configs for MXFP4/NVFP4 (hardware e2m1x2 conv…
mobicham Mar 19, 2026
9ea1036
Revert "add ptx_pack=True quant configs for MXFP4/NVFP4 (hardware e2m…
mobicham Mar 19, 2026
7111bad
add gemlite.set_ptx_pack() for hardware FP4 packing
mobicham Mar 19, 2026
a29c4c4
update tl.constexpr
mobicham Mar 20, 2026
7356faf
add save/load guard
mobicham Mar 20, 2026
cd6b8d6
refactor tests
mobicham Mar 21, 2026
3d98df7
improve nvfp4
mobicham Mar 21, 2026
9cab0dc
cleanup nvfp4
mobicham Mar 21, 2026
49872c3
add fast nvfp4 mode function
mobicham Mar 21, 2026
2e597ae
nvfp4 fused kernel
mobicham Mar 21, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
451 changes: 451 additions & 0 deletions examples/eval_flops.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion gemlite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.5.1.post1"
__version__ = "0.6.0"
__author__ = 'Dr. Hicham Badri'
__credits__ = 'Mobius Labs GmbH'

Expand All @@ -12,6 +12,10 @@
set_acc_dtype,
set_autotune,
set_kernel_caching,
enable_tma,
set_ptx_fp4_pack,
enable_cudagraph_autotune,
set_fast_nvfp4,
forward_functional,
)

Expand Down
117 changes: 107 additions & 10 deletions gemlite/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@
GEMLITE_MATMUL_TYPES_MAPPING = {GEMLITE_MATMUL_TYPES[i]: i for i in range(len(GEMLITE_MATMUL_TYPES))}
GEMLITE_TRITON_CONFIG_CACHE = {} #Global config cache for all the kernels
_GROUP_SIZE_WARNED = False
GEMLITE_USE_TMA = True # Set to False for faster MXFP8 on sm_120
GEMLITE_ENABLE_PTX_FP4_PACK = False # Set to True for hardware e2m1x2 FP4 packing (requires CUDA 13.0+ ptxas)
GEMLITE_FAST_NVFP4 = False
GEMLITE_NVFP4_META_SCALES = [] # Pre-allocated per-GPU meta_scale tensors

###################################################################################
#Utils
Expand Down Expand Up @@ -96,19 +100,48 @@ def set_acc_dtype(dtype):
assert dtype in [DType.FP16, DType.FP32], "Invalid dtype (should be DType.FP16 or DType.FP32)."
GEMLITE_ACC_DTYPE[DType.FP16] = dtype

#Enable/disable TMA for MX kernel data loading
def enable_tma(enabled: bool = True):
global GEMLITE_USE_TMA
GEMLITE_USE_TMA = enabled

#Enable/disable hardware PTX FP4 packing in activation quantization (requires CUDA 13.0+ ptxas)
def set_ptx_fp4_pack(enabled: bool = True):
global GEMLITE_ENABLE_PTX_FP4_PACK
GEMLITE_ENABLE_PTX_FP4_PACK = enabled
from .quant_utils import set_ptx_fp4_pack_flag
set_ptx_fp4_pack_flag(enabled)

#Enable/disable CUDA graph-based autotuning (more accurate but slower)
#Enable/disable fast NVFP4 mode (pre-allocated static meta_scale, skips dynamic computation)
def set_fast_nvfp4(enabled: bool = True, default_value: float = 0.05):
global GEMLITE_FAST_NVFP4, GEMLITE_NVFP4_META_SCALES
GEMLITE_FAST_NVFP4 = enabled
if enabled and len(GEMLITE_NVFP4_META_SCALES) == 0:
num_gpus = torch.cuda.device_count()
GEMLITE_NVFP4_META_SCALES = [
torch.full((1,), fill_value=default_value, device=f"cuda:{i}", dtype=torch.float32)
for i in range(num_gpus)
]

def enable_cudagraph_autotune(enabled: bool = True):
set_autotune("fast", use_cuda_graph=enabled)

#Return the default gemv kernel to use for M==1
def get_default_gemv(W_nbits: int, mx_dtype: bool = False) -> str:
#TODO: adapt mx for IS_HIP = True
if mx_dtype:
return 'GEMM_SPLITK' #TODO: fix mxf bugs in GEMV outputs garbage.
return 'GEMM_SPLITK' #TODO:'GEMV' if (W_nbits < 8) else 'GEMM_SPLITK' -> Revisit NVFP4 failing test.
else:
return 'GEMV_REVSPLITK' if (W_nbits < 8) else 'GEMV_SPLITK'

#matmul type selection logic
def get_matmul_type(batch_size: int, W_nbits: int, mx_dtype: bool = False):
if batch_size > 64:
gemm_limit = 64
if batch_size >= gemm_limit:
return "GEMM"
if batch_size > 1:
gemv_limit = 4 if (W_nbits < 8 and not mx_dtype) else 2 # previous 1
if batch_size > gemv_limit:
return "GEMM_SPLITK"
else:
return get_default_gemv(W_nbits, mx_dtype)
Expand All @@ -121,7 +154,7 @@ def enable_activation_scaling(batch_size):
Only works with the MXFP format - use with A8W4_MXFP/A4W4_MXFP.
"""
return True
#return batch_size >= 32
#return batch_size >= 2 #TODO: Needs Triton fix https://github.com/triton-lang/triton/pull/9577


#Main functional forward call
Expand Down Expand Up @@ -155,6 +188,7 @@ def forward_functional(
scaled_activations = bool(meta_args[0]) and enable_activation_scaling(batch_size)
#Dynamic activation quantization
scales_x = None
meta_scale = 0.0
if(scaled_activations):
input_dtype = DType(meta_args[5])
channel_scale_mode = meta_args[9]
Expand All @@ -172,7 +206,10 @@ def forward_functional(
x, scales_x = scale_activations_mxfp4(x)

elif(input_dtype in [DType.NVFP4] and channel_scale_mode == 4): #NVFP4: TODO
x, scales_x = scale_activations_nvfp4(x)
meta_scale = tensor_args[3]
_static_meta = GEMLITE_NVFP4_META_SCALES[x.device.index] if GEMLITE_FAST_NVFP4 else None
x, scales_x, meta_scale_a = scale_activations_nvfp4(x, meta_scale=_static_meta)
meta_scale = meta_scale * meta_scale_a # combine weight and activation meta_scales

x = x.view(-1, x.shape[-1])

Expand All @@ -184,7 +221,7 @@ def forward_functional(
out = (
GEMLITE_TRITON_MAPPING[matmul_type_str]
.forward(
x, *tensor_args, scales_x, *meta_args[1:-1], data_contiguous, type_id
x, *tensor_args[:3], scales_x, *meta_args[1:-1], data_contiguous, type_id, meta_scale=meta_scale
)
.view(out_shape)
)
Expand Down Expand Up @@ -252,7 +289,7 @@ def __init__(

if in_features is not None and out_features is not None:
if (in_features % GemLiteLinearTriton.MIN_SIZE != 0) or (
in_features % group_size != 0 if (group_size is not None) else False
(in_features % group_size != 0) if (group_size is not None and W_nbits < 16) else False
):
raise NotImplementedError(
"Invalid input shapes: "
Expand Down Expand Up @@ -298,13 +335,27 @@ def __init__(
#Default forward
self.forward = self.forward_auto_no_warmup

#Meta-scale for NVFP4 (0.0 = not used)
self.meta_scale = 0.0

def _save_to_state_dict(self, destination, prefix, keep_vars):
# Rebuild metadata from live attributes to ensure consistency
# (helpers may override channel_scale_mode/W_group_mode after pack())
if hasattr(self, 'metadata') and self.metadata is not None and hasattr(self, 'W_nbits'):
self.metadata = torch.nn.Parameter(
torch.tensor(self.get_meta_args(), device=self.metadata.device, dtype=torch.int32),
requires_grad=False,
)
super()._save_to_state_dict(destination, prefix, keep_vars)

def load_state_dict(self, state_dict, strict=True, assign=False):
self.W_q = state_dict.pop("W_q", None)
self.bias = state_dict.pop("bias", None)
self.scales = state_dict.pop("scales", None)
self.zeros = state_dict.pop("zeros", None)
self.metadata = state_dict.pop("metadata", None)
self.orig_shape = state_dict.pop("orig_shape", None)
_meta_scale = state_dict.pop("meta_scale", None)

self.metadata = [v.item() for v in self.metadata]
self.orig_shape = (v.item() for v in self.orig_shape)
Expand All @@ -327,10 +378,27 @@ def load_state_dict(self, state_dict, strict=True, assign=False):
self.acc_dtype = DType(self.acc_dtype)
self.meta_dtype = DType(self.meta_dtype)

# Restore meta_scale with backward compat for old checkpoints
if _meta_scale is not None:
self.meta_scale = _meta_scale.float()
else:
self.meta_scale = 0.05 if self.input_dtype == DType.NVFP4 else 0.0 # backward compat default for old checkpoints

self.out_features, self.in_features = self.orig_shape
self.compute_dtype = DTYPE_TO_TORCH[self.input_dtype.value]
self.scaled_activations = bool(self.scaled_activations)
self.data_contiguous = bool(self.data_contiguous)

# Backward compat: pop stale scales_5d from old saves
state_dict.pop("scales_5d", None)
# Convert 2D scales to 5D TMA layout for MX dtypes
if is_mx_dtype(self.input_dtype) and self.scales is not None:
s = self.scales.data if isinstance(self.scales, torch.nn.Parameter) else self.scales
if s.ndim == 2:
s_2d = s.T.contiguous() # [K_S, N] contiguous
N_dim, K_S = s_2d.shape[1], s_2d.shape[0]
if GEMLITE_USE_TMA and self.elements_per_sample > 1 and N_dim % 128 == 0 and K_S % 4 == 0:
self.scales = s_2d.reshape(N_dim // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N_dim // 128, K_S // 4, 2, 256).contiguous()

#Make sure to feed UINT8 W_q for packing
def pack(
Expand Down Expand Up @@ -399,7 +467,7 @@ def pack(

if(self.W_q is None):
raise Exception('Weights were not packed, please check your W_q.dtype')

#Bias / device
self.device = self.W_q.device
self.bias = None if (bias is None) else bias.to(device=self.device)
Expand Down Expand Up @@ -492,9 +560,29 @@ def pack(
if(self.input_dtype in [DType.NVFP4]):
self.scales = self.scales.to(torch.float8_e4m3fn)
if(is_mx_dtype(self.input_dtype)):
self.scales = self.scales.T
self.W_group_mode = 2
self.channel_scale_mode = 0

################################
# TMA
K, N = self.W_q.shape

if(self.input_dtype in [DType.MXFP4, DType.NVFP4]):
K *= 2
group_size = 2 * self.W_q.numel() // self.scales.numel()
else:
group_size = self.W_q.numel() // self.scales.numel()

# Preshuffle weight scales to 5D TMA layout for fast loading
# Original: [K_S, N] -> transpose to [N, K_S] -> 5D: [1, N//128, K_S//4, 2, 256]
K_S = K // group_size
if GEMLITE_USE_TMA and self.elements_per_sample > 1 and N % 128 == 0 and K_S % 4 == 0:
# Currently TMA only enabled for MXFP4/NVFP4 NOT for MXFP8 because of poor performance on sm_120 (self.elements_per_sample > 1 check)
self.scales = self.scales.T.contiguous().reshape(N // 128, 4, 32, K_S // 4, 4).permute(0, 3, 2, 1, 4).reshape(1, N // 128, K_S // 4, 2, 256).contiguous()
else:
# Keep 2D transposed layout for pointer-based fallback
self.scales = self.scales.T
################################

if(self.scales is not None):
self.meta_dtype = TORCH_TO_DTYPE[self.scales.dtype]
Expand All @@ -516,11 +604,20 @@ def pack(
requires_grad=False,
)


self.meta_scale = torch.nn.Parameter(
torch.tensor(self.meta_scale, device=self.device, dtype=torch.float32),
requires_grad=False,
)

return self

#Return the main arguments
def get_tensor_args(self):
return [self.W_q, self.scales, self.zeros]
meta_scale = self.meta_scale
if not isinstance(meta_scale, torch.Tensor):
meta_scale = torch.tensor(meta_scale, dtype=torch.float32, device=self.W_q.device)
return [self.W_q, self.scales, self.zeros, meta_scale]

def get_meta_args(self):
return [int(self.scaled_activations),
Expand Down
13 changes: 9 additions & 4 deletions gemlite/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def from_weights(self, weight, bias=None, scales=None):
assert weight.dtype in [torch.uint8], f"Invalid weight.dtype, should be an MXFP8 valid dtype, got {weight.dtype}."
assert scales.dtype in [torch.float8_e8m0fnu, torch.uint8], f"Invalid scales.dtype, should be e8m0 / view(uint8), got {scales.dtype}."
assert self.dtype is not None, f"Input dtype should be either torch.float16 or torch.bfloat16, not None."
assert self.group_size == 32, f"Only group_size=16 is supported for MXFP4, got {self.group_size}"
assert self.group_size == 32, f"Only group_size=32 is supported for MXFP4, got {self.group_size}"

dtype = self.dtype
gemlite_dtype = TORCH_TO_DTYPE[dtype]
Expand Down Expand Up @@ -888,7 +888,7 @@ def __init__(self, device='cuda:0', dtype=None):
self.group_size = 16
self.input_dtype = DType.NVFP4

def from_weights(self, weight, bias=None, scales=None):
def from_weights(self, weight, bias=None, scales=None, meta_scale=None):
if(isinstance(weight, torch.nn.Parameter)):
weight = weight.data
if(isinstance(bias, torch.nn.Parameter)):
Expand Down Expand Up @@ -923,6 +923,11 @@ def from_weights(self, weight, bias=None, scales=None):
gemlite_linear.pack(W_q, scales, zeros=None, bias=bias)
gemlite_linear.W_group_mode = 0
gemlite_linear.channel_scale_mode = 4
if meta_scale is not None:
gemlite_linear.meta_scale = torch.nn.Parameter(
meta_scale.to(dtype=torch.float32, device=gemlite_linear.W_q.device).reshape(()),
requires_grad=False,
)
return gemlite_linear


Expand All @@ -933,11 +938,11 @@ def from_linear(self, linear_layer, del_orig=True):
W = linear_layer.weight.data
bias = linear_layer.bias.clone() if (linear_layer.bias is not None) else None
N, K = W.shape
W_q, scales = self.quantizer_mx.quantize_nvfp4(W, index=True)
W_q, scales, _meta_scale = self.quantizer_mx.quantize_nvfp4(W, index=True)
W_q, scales = W_q.view([N, K]), scales.view(N, K // self.group_size)
cleanup_linear(linear_layer, del_orig)

out_layer = self.from_weights(weight=W_q, scales=scales, bias=bias)
out_layer = self.from_weights(weight=W_q, scales=scales, bias=bias, meta_scale=_meta_scale)

#Clean-up
del W_q
Expand Down
Loading