Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 108 additions & 82 deletions tests/pytorch/test_fusible_ops.py

Large diffs are not rendered by default.

22 changes: 13 additions & 9 deletions tests/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def quantization_tols(name: str) -> dict[str, float]:
"mxfp8_block_scaling",
):
return dtype_tols(tex.DType.kFloat8E4M3)
if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"):
if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"):
return dtype_tols(tex.DType.kFloat4E2M1)
raise ValueError(f"Unsupported quantization scheme ({name})")

Expand All @@ -145,10 +145,10 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]:
)
if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs)
if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"):
if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"):
use_4over6 = name == "nvfp4_4over6"
kwargs = {
"disable_rht": True,
"disable_rht": name != "nvfp4_rht",
"disable_stochastic_rounding": True,
"disable_2d_quantization": not use_4over6,
"row_scaled_activation": name == "nvfp4_row_scaled",
Expand All @@ -163,12 +163,16 @@ def recipe_id(recipe: Optional[Recipe]) -> str:
"""Readable pytest id for a quantization recipe."""
if not isinstance(recipe, Recipe):
return "None"
if recipe.nvfp4() and recipe.row_scaled_activation and recipe.nvfp4_4over6 != "none":
return "NVFP4RowScaled4Over6BlockScaling"
if recipe.nvfp4() and recipe.nvfp4_4over6 != "none":
return "NVFP44Over6BlockScaling"
if recipe.nvfp4() and recipe.row_scaled_activation:
return "NVFP4RowScaledBlockScaling"
if recipe.nvfp4():
nvfp4_features = []
if recipe.row_scaled_activation:
nvfp4_features.append("RowScaled")
if recipe.nvfp4_4over6 != "none":
nvfp4_features.append("4Over6")
if not recipe.disable_rht:
nvfp4_features.append("RHT")
if nvfp4_features:
return f"NVFP4{''.join(nvfp4_features)}BlockScaling"
return type(recipe).__name__


Expand Down
149 changes: 148 additions & 1 deletion transformer_engine/pytorch/ops/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""Helper functions used in fusible operations."""

from __future__ import annotations
from collections.abc import Iterable
import functools
import math
from importlib.metadata import PackageNotFoundError, version as get_pkg_version
Expand All @@ -13,10 +14,13 @@
import torch
from packaging.version import Version as PkgVersion

import transformer_engine_torch as tex
from transformer_engine_torch import FP8TensorMeta
from ..torch_version import torch_version
from ..quantization import FP8GlobalStateManager
from ..tensor import NVFP4Quantizer, NVFP4Tensor, NVFP4TensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Tensor
from ..tensor.grouped_tensor import GroupedTensor
from ..quantized_tensor import QuantizedTensorStorage
from ..utils import canonicalize_dtype

Expand Down Expand Up @@ -57,6 +61,146 @@ def _nvidia_cudnn_frontend_supports_wgrad() -> bool:
return _cudnn_frontend_version_supported()


def _group_quantize_for_grouped_mlp(
tensor: torch.Tensor,
quantizer: Quantizer,
num_groups: int,
split_sizes: Optional[torch.Tensor],
*,
tensor_offsets: Optional[torch.Tensor] = None,
) -> GroupedTensor:
"""Quantize into grouped storage."""

# Typical case: group-quantize
if num_groups != 1 or not isinstance(quantizer, NVFP4Quantizer):
return tex.group_quantize(tensor, quantizer, num_groups, split_sizes)

# --------------------------------------------------
# Special case: single-tensor NVFP4 quantize
# --------------------------------------------------

quantized = tex.quantize(tensor, quantizer)
Comment thread
sraman-rgb marked this conversation as resolved.
with_gemm_swizzled_scales = quantized._with_gemm_swizzled_scales
if quantizer.optimize_for_gemm:
tex.swizzle_scales_for_gemm_(quantized)
with_gemm_swizzled_scales = True

rowwise_data = quantized._rowwise_data
rowwise_scale = quantized._rowwise_scale_inv
columnwise_data = quantized._columnwise_data
columnwise_scale = quantized._columnwise_scale_inv
amax = quantized._amax_rowwise
columnwise_amax = quantized._amax_columnwise

if split_sizes is None:
split_sizes = torch.full((1,), tensor.shape[0], dtype=torch.int64, device=tensor.device)
else:
split_sizes = split_sizes.to(dtype=torch.int64, device=tensor.device)

m_dim = tensor.shape[0]
if rowwise_data is not None:
k_dim = rowwise_data.shape[-1] * 2
elif columnwise_data is not None:
k_dim = columnwise_data.shape[0]
else:
k_dim = tensor.shape[-1]

if tensor_offsets is None:
tensor_offsets = torch.cat(
[
torch.zeros(1, dtype=torch.int64, device=tensor.device),
torch.cumsum(split_sizes * k_dim, dim=0),
],
)

return GroupedTensor(
shape=(m_dim, k_dim),
dtype=tensor.dtype,
quantizer=quantizer,
num_tensors=1,
data=rowwise_data.reshape(-1) if rowwise_data is not None else None,
columnwise_data=columnwise_data.reshape(-1) if columnwise_data is not None else None,
scale_inv=rowwise_scale.reshape(-1) if rowwise_scale is not None else None,
columnwise_scale_inv=columnwise_scale.reshape(-1) if columnwise_scale is not None else None,
amax=amax,
columnwise_amax=columnwise_amax,
first_dims=split_sizes,
tensor_offsets=tensor_offsets,
with_gemm_swizzled_scales=with_gemm_swizzled_scales,
)


def _nvfp4_amax(
tensors: GroupedTensor | Iterable[NVFP4TensorStorage],
*,
columnwise: bool,
) -> torch.Tensor:
"""Get one NVFP4 amax value per group."""
grouped_attr = "columnwise_amax" if columnwise else "amax"
tensor_attr = "_amax_columnwise" if columnwise else "_amax_rowwise"

if hasattr(tensors, grouped_attr):
amax = getattr(tensors, grouped_attr)
if amax is None:
raise RuntimeError(f"NVFP4 GroupedTensor is missing {grouped_attr}.")
return amax.view(-1)

amaxes = [getattr(tensor, tensor_attr) for tensor in tensors]
if any(amax is None for amax in amaxes):
raise RuntimeError(f"NVFP4 tensor list is missing {tensor_attr}.")
return torch.cat([amax.view(-1) for amax in amaxes], dim=0)


def _nvfp4_single_tensor_from_grouped(
grouped: GroupedTensor,
quantizer: Optional[NVFP4Quantizer] = None,
*,
fp4_dtype: Optional[torch.dtype] = None,
) -> NVFP4Tensor:
"""Build a single NVFP4Tensor view over a one-member grouped storage."""
if quantizer is None:
quantizer = grouped.quantizer
if not isinstance(quantizer, NVFP4Quantizer):
raise TypeError("Expected an NVFP4 GroupedTensor.")

shape = tuple(grouped.logical_shape)
rowwise_data = None
if grouped.rowwise_data is not None:
rowwise_data = grouped.rowwise_data.view(quantizer.convert_shape_for_fp4(shape))

rowwise_scale_inv = None
if grouped.scale_inv is not None:
rowwise_scale_inv = grouped.scale_inv.view(quantizer.get_scale_shape(shape, False))

columnwise_data = None
if grouped.columnwise_data is not None:
columnwise_shape = quantizer.get_columnwise_shape(shape)
columnwise_data = grouped.columnwise_data.view(
quantizer.convert_shape_for_fp4(columnwise_shape)
)

columnwise_scale_inv = None
if grouped.columnwise_scale_inv is not None:
columnwise_scale_inv = grouped.columnwise_scale_inv.view(
quantizer.get_scale_shape(shape, True)
)

return NVFP4Tensor(
shape=shape,
dtype=grouped.get_dtype(),
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
amax_rowwise=grouped.amax,
amax_columnwise=grouped.columnwise_amax,
fp4_dtype=fp4_dtype or quantizer.dtype,
quantizer=quantizer,
requires_grad=False,
with_gemm_swizzled_scales=grouped._with_gemm_swizzled_scales,
)


def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool:
"""Check if tensor is a quantized tensor"""
return isinstance(tensor, QuantizedTensorStorage)
Expand Down Expand Up @@ -285,7 +429,10 @@ def fuse_grouped_mlp_ops(

if not fused_op_cls.is_supported():
return ops
if recipe is None or not recipe.mxfp8():
if recipe is None or not (recipe.mxfp8() or recipe.nvfp4()):
Comment thread
timmoon10 marked this conversation as resolved.
return ops
# NVFP4 fused grouped MLP uses graph-safe grouped quantize, which currently requires RHT.
if recipe.nvfp4() and recipe.disable_rht:
return ops
if activation_op_types is None:
activation_op_types = (ScaledSwiGLU, ScaledClampedQGeGLU)
Expand Down
21 changes: 20 additions & 1 deletion transformer_engine/pytorch/ops/basic/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...quantization import FP8GlobalStateManager, Recipe
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
from ...quantized_tensor import QuantizedTensorStorage
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
from ...utils import (
Expand Down Expand Up @@ -291,6 +291,25 @@ def num_quantizers(self, mode: str) -> int:
return self.num_groups
return 0

def get_quantizer_roles(self, mode: str) -> Optional[list[QuantizerRole]]:
name = getattr(self, "name", "") or ""
if mode == "forward":
roles = []
for _ in range(self.num_groups):
roles.extend(
[
QuantizerRole(module_type="linear", tensor_type="input", name=name),
QuantizerRole(module_type="linear", tensor_type="weight", name=name),
]
)
return roles
if mode == "backward":
return [
QuantizerRole(module_type="linear", tensor_type="grad_output", name=name)
for _ in range(self.num_groups)
]
return None

@property
def has_bias(self) -> bool:
"""Whether an additive bias is being applied"""
Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/ops/fused/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
# Import experimental fusions
# Note: Registration logic is non-trivial, so submodule handles it internally.
from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position
ForwardGroupedMLP_CuTeGEMMGLU_MXFP8,
ForwardGroupedMLP_CuTeGEMMUnary_MXFP8,
ForwardGroupedMLP_CuTeGEMMGLU,
ForwardGroupedMLP_CuTeGEMMUnary,
)
from .backward_grouped_mlp import ( # pylint: disable=wrong-import-position
BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8,
BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8,
BackwardGroupedMLP_CuTeGEMMDGLU,
BackwardGroupedMLP_CuTeGEMMDUnary,
)
Loading
Loading