Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Training-oriented kernels and schemes include:

- **[Blockwise FP8](alto/kernels/blockwise_fp8)** — linear, grouped GEMM, and FlashAttention.
- **[MXFP4](alto/kernels/fp4/mxfp4)** — linear, grouped GEMM, and FlashAttention.
- **[MXFP8](alto/kernels/mxfp8)** — linear and grouped GEMM (block-scaled E4M3, with E5M2 reserved for gradients).

Techniques used to narrow the gap versus BF16 include:

Expand Down
32 changes: 29 additions & 3 deletions alto/kernels/dispatch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_quantize_then_nvfp4_scaled_grouped_mm,
)
from alto.kernels.mxfp8.mxfp8_linear import _to_mxfp8_then_scaled_mm
from alto.kernels.mxfp8.mxfp8_grouped_gemm import _quantize_then_mxfp8_scaled_grouped_mm
from .config import TrainingOpConfig

aten = torch.ops.aten
Expand Down Expand Up @@ -403,9 +404,34 @@ class MXFP8TrainingWeightWrapperTensor(TrainingWeightWrapperBaseTensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
if func.__name__ == "_grouped_mm":
raise NotImplementedError(
"MXFP8 _grouped_mm is not supported by this dispatch path; "
"restrict MXFP8 schemes to Linear targets."
# Routed-expert MoE path: 2d activations x 3d weights with offsets.
A, B = args[0], args[1]
bias = kwargs.get("bias", None)
offs = kwargs.get("offs", None)

assert not isinstance(A, cls), f"A should not be a {cls.__name__}"
assert isinstance(B, cls), f"B should be a {cls.__name__}"
assert A.ndim == 2 and B.ndim == 3 and offs is not None, (
"Only 2d x 3d with offsets is supported for MXFP8 grouped_mm"
)
assert bias is None, "Bias is not supported for grouped_mm"

config = B.config
assert config.precision == "mxfp8_e4m3", (
"MXFP8 grouped_mm V1 supports only mxfp8_e4m3; "
f"got {config.precision} (e5m2 grouped path is not yet validated)"
)
assert not config.use_hadamard and not config.use_dge, (
"MXFP8 grouped_mm V1 does not support Hadamard or DGE options."
)

return _quantize_then_mxfp8_scaled_grouped_mm(
A,
B,
offs=offs,
use_2dblock_x=config.use_2dblock_x,
use_2dblock_w=config.use_2dblock_w,
use_sr_grad=config.use_sr_grad,
)

if func.__name__ in gemm_ops:
Expand Down
413 changes: 413 additions & 0 deletions alto/kernels/mxfp8/MXFP8_GROUPED_GEMM_PLAN.md

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions alto/kernels/mxfp8/mxfp8_grouped_gemm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) 2026 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

from alto.kernels.mxfp8.mxfp8_grouped_gemm.functional import (
mxfp8_grouped_gemm,
_quantize_then_mxfp8_scaled_grouped_mm,
)

__all__ = ["mxfp8_grouped_gemm", "_quantize_then_mxfp8_scaled_grouped_mm"]
51 changes: 51 additions & 0 deletions alto/kernels/mxfp8/mxfp8_grouped_gemm/autotune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2026 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT
"""Autotune configs for mxfp8 grouped GEMM.

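v1 keeps a single conservative config per list (forward, dgrad, wgrad):
- BLOCK_SIZE_K == QUANT_BLOCK_SIZE (=32) in every list, so each tl.dot_scaled
  call covers exactly one mx scale group; this matches the numerical contract
  validated by alto/kernels/mxfp8/mxfp8_linear.py.
- STANDARD_CONFIGS uses BSM=BSN=128, which matches mxfp4 grouped GEMM's default
  tile. DGRAD_CONFIGS narrows BSN to 32 and WGRAD_CONFIGS narrows BSM to 32,
  because those passes reduce over N and M respectively and must likewise keep
  one MX scale group per dot_scaled.
Wider autotune is deferred to v2.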
"""

import triton

# Token-routing alignment: tokens routed to the same expert must form
# contiguous blocks of this size.
ALIGN_SIZE_M = 128

# Single autotune candidate for the standard grouped GEMM: 128x128 output tile
# with a 32-wide K step (see the module docstring for the rationale).
STANDARD_CONFIGS = [
    triton.Config(
        dict(BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32),
        num_warps=4,
        num_stages=2,
    ),
]

# Single autotune candidate for dgrad.
# dgrad reduces over N; keep one MX scale group per dot_scaled, hence
# BLOCK_SIZE_N is 32 here.
DGRAD_CONFIGS = [
    triton.Config(
        dict(BLOCK_SIZE_M=128, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32),
        num_warps=4,
        num_stages=2,
    ),
]

# Single autotune candidate for wgrad.
# wgrad reduces over M; keep one MX scale group per dot_scaled, hence
# BLOCK_SIZE_M is 32 here.
WGRAD_CONFIGS = [
    triton.Config(
        dict(BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32),
        num_warps=4,
        num_stages=2,
    ),
]
Loading