1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
@@ -143,6 +143,7 @@ save_quantized_params_path: ""
# accepted values are "inference"
model_call_mode: ""
use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix.
use_manual_quantization: false # whether to use manual quantization for batch split. Only used if use_batch_split_schedule is True.
# quantization calibration method used for weights and activations. supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#l70-l80
weight_quantization_calibration_method: "absmax"
act_quantization_calibration_method: "absmax"
12 changes: 6 additions & 6 deletions src/maxtext/configs/types.py
@@ -422,6 +422,10 @@ class Quantization(BaseModel):
kv_quant_dtype: Literal["int8", "int4"] = Field("int8", description="Data type for KV cache quantization.")
quantization_local_shard_count: int = Field(-1, description="Shards the range finding operation for quantization.")
use_qwix_quantization: bool = Field(False, description="Whether to use qwix for quantization.")
use_manual_quantization: bool = Field(
False,
description="Whether to use manual quantization for batch split. Only used if use_batch_split_schedule is True.",
)
weight_quantization_calibration_method: str = Field(
"absmax",
description="Quantization calibration method used for weights.",
@@ -2727,8 +2731,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
f"Decoder '{self.decoder_block.value}' is not supported with 'explicit' sharding. "
f"Supported options are: {list(supported_decoders)}."
)
if self.quantization:
raise ValueError("Quantization is not supported with 'explicit' sharding.")
if self.context_sharding not in ("context", "expert"):
raise ValueError(f"Assigned context_sharding f{self.context_sharding} is not supported.")
if (
@@ -2835,10 +2837,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.use_grpo = False

if self.use_batch_split_schedule:
if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"):
raise ValueError(
"Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
)
if self.quantization and not self.quantization == "fp8_full":
raise ValueError("Batch split quantization only supports `quantization=fp8_full`")

if self.opt_type == "muon" and self.decoder_block not in [
DecoderBlockType.DEEPSEEK,
39 changes: 29 additions & 10 deletions src/maxtext/kernels/megablox/ops.py
@@ -22,6 +22,7 @@
import jax
import jax.numpy as jnp
from maxtext.kernels.megablox import backend
from maxtext.layers import quantizations
import qwix
import qwix.pallas as qpl
import tokamax
@@ -61,6 +62,7 @@ def gmm(
weight_gather_axes: List[Tuple[str, int]] | None = None,
# TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
qwix_rule: qwix.QtRule | None = None,
use_manual_quantization: bool = False,
):
"""Grouped matrix multiplication operation."""
quantization_rule = None
@@ -80,7 +82,7 @@
)

gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0] # pylint: disable=C3001
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11))
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11, 12))
gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
return gmm_fwd_bwd(
lhs,
Expand All @@ -95,6 +97,7 @@ def gmm(
quantization_rule,
use_tokamax_backend,
weight_gather_axes,
use_manual_quantization,
)


@@ -121,6 +124,7 @@ def _gmm_fwd(
quantization_rule: qwix.QtRule | None = None,
use_tokamax_backend: bool = False,
weight_gather_axes: List[Tuple[str, int]] | None = None,
use_manual_quantization: bool = False,
) -> tuple[
jnp.ndarray,
tuple[
@@ -140,15 +144,20 @@
calibration_method=quantization_rule.act_calibration_method,
)
if quantization_rule.weight_qtype and not isinstance(rhs, qpl.QArray):
rhs = qpl.quantize(
rhs,
quantization_rule.weight_qtype,
# If only considering the fwd pass, we could also enable channelwise
# axes for the group axis, i.e., [0, 1 or 2]. However, this makes the
# bwd pass unable to reuse the scale easily.
channelwise_axes=[] if quantization_rule.disable_channelwise_axes else ([1] if transpose_rhs else [2]),
calibration_method=quantization_rule.weight_calibration_method,
)
if not use_manual_quantization:
rhs = qpl.quantize(
rhs,
quantization_rule.weight_qtype,
# If only considering the fwd pass, we could also enable channelwise
# axes for the group axis, i.e., [0, 1 or 2]. However, this makes the
# bwd pass unable to reuse the scale easily.
channelwise_axes=([] if quantization_rule.disable_channelwise_axes else ([1] if transpose_rhs else [2])),
calibration_method=quantization_rule.weight_calibration_method,
)
else:
rhs = quantizations.manual_quantize(
rhs, quantization_rule.weight_calibration_method, quantization_rule.weight_qtype
)
# QAG is only supported for following conditions
if use_tokamax_backend:
if quantization_rule and quantization_rule.bwd_qtype:
@@ -169,6 +178,9 @@
preferred_element_type=preferred_element_type,
group_offset=group_offset,
implementation="mosaic",
manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"]))
if use_manual_quantization
else None,
)
else:
out = backend.gmm(
@@ -195,6 +207,7 @@ def _gmm_bwd(
quantization_rule: qwix.QtRule | None,
use_tokamax_backend: bool,
weight_gather_axes: List[Tuple[str, int]] | None,
use_manual_quantization: bool,
residual: tuple[
jnp.ndarray | qpl.QArray,
jnp.ndarray | qpl.QArray,
@@ -257,6 +270,9 @@
preferred_element_type=lhs_dtype,
group_offset=group_offset,
implementation="mosaic",
manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"]))
if use_manual_quantization
else None,
)
drhs = tokamax.ragged_dot_general(
lhs=lhs,
Expand All @@ -267,6 +283,9 @@ def _gmm_bwd(
preferred_element_type=rhs_dtype,
group_offset=group_offset,
implementation="mosaic",
manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["expert"]), unreduced=frozenset(["data", "fsdp"]))
if use_manual_quantization
else None,
)
if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
# Scatter back in reverse order of gather
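Side note on the `custom_vjp` change above: the new `use_manual_quantization` argument is a static Python bool threaded through `gmm`, so it has to be listed in `nondiff_argnums` and arrives as a leading argument of the backward function. A minimal toy sketch of that pattern (not the gmm kernel; the names and numbers are illustrative only):

```python
import jax
import jax.numpy as jnp

def f(x, use_flag):
  # Toy stand-in for gmm: the trailing bool only selects a code path.
  return x * (2.0 if use_flag else 1.0)

f_vjp = jax.custom_vjp(f, nondiff_argnums=(1,))

def f_fwd(x, use_flag):
  return f(x, use_flag), x  # output plus residual for the backward pass

def f_bwd(use_flag, residual, g):
  # Non-diff args are passed first; return cotangents only for the diff args.
  return (g * (2.0 if use_flag else 1.0),)

f_vjp.defvjp(f_fwd, f_bwd)
print(jax.grad(f_vjp)(jnp.float32(3.0), True))  # 2.0
```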
7 changes: 5 additions & 2 deletions src/maxtext/layers/decoders.py
@@ -922,7 +922,8 @@ def __call__(
# as detected by immutable params, use deepseek_batchsplit custom
# scan with initialized parameters.
if cfg.use_batch_split_schedule and not self.is_mutable_collection("params"):
if cfg.use_qwix_quantization:
# The older version of batch-split that fully uses qwix quantization.
if cfg.use_qwix_quantization and not cfg.use_manual_quantization:
y = deepseek_batchsplit_fp8.scan_batch_split_layers(
y,
self.variables["params"]["moe_layers"],
@@ -935,7 +936,9 @@
policy=policy,
)
else:
# bf16 code path
# bf16 and fp8 code path for the pure-JAX batch-split.
# The fp8 path supports both manual quantization and qwix quantization.
y = deepseek_batchsplit.scan_batch_split_layers(
y,
self.variables["params"]["moe_layers"],
97 changes: 80 additions & 17 deletions src/maxtext/layers/quantizations.py
@@ -16,6 +16,7 @@

import functools
import json
import qwix.pallas as qpl
import re
from typing import Tuple, Sequence, Callable
from dataclasses import dataclass
@@ -629,13 +630,15 @@ def get_quant_mode(quant_mode_str: str = "train"):
def configure_quantization(config: Config, quant_mode_str: str = "train"):
"""Configure quantization based on user config and quant mode."""
if config.use_batch_split_schedule and config.quantization:
if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
return QwixQuantization(
weight_calibration_method=config.weight_quantization_calibration_method,
act_calibration_method=config.act_quantization_calibration_method,
bwd_calibration_method=config.bwd_quantization_calibration_method,
)
# The older version of batch-split that fully uses qwix quantization.
if config.quantization == "fp8_full" and not config.use_manual_quantization:
return QwixQuantization(
weight_calibration_method=config.weight_quantization_calibration_method,
act_calibration_method=config.act_quantization_calibration_method,
bwd_calibration_method=config.bwd_quantization_calibration_method,
)
# The pure JAX version of batch-split that uses manual quantization.
return None

if config.use_qwix_quantization:
return None
@@ -764,8 +767,7 @@ def get_quantization_rule(config: Config):
weight_qtype=jnp.int4,
act_qtype=jnp.int4,
bwd_qtype=jnp.int4,
bwd_weight_grad_tile_size=1
/ config.quantization_local_shard_count,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
]
@@ -776,8 +778,7 @@
weight_qtype=jnp.int8,
act_qtype=jnp.int8,
bwd_qtype=jnp.int8,
bwd_weight_grad_tile_size=1
/ config.quantization_local_shard_count,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
]
@@ -788,8 +789,7 @@
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_weight_grad_tile_size=1
/ config.quantization_local_shard_count,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
]
@@ -802,8 +802,7 @@
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_weight_grad_tile_size=1
/ config.quantization_local_shard_count,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
]
@@ -814,8 +813,7 @@
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_weight_grad_tile_size=1
/ config.quantization_local_shard_count,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
]
@@ -851,6 +849,71 @@ def maybe_quantize_model(model, config):
return model


def _cast_reduced_from(arr, reduced_arr):
aval = jax.typeof(reduced_arr)
# In shard map
if aval.sharding.mesh.axis_types[0] == jax.sharding.AxisType.Manual:
for axis in aval.mat.reduced:
arr = jax.lax.pcast(arr, axis, to="reduced")
return arr
# Outside shard map
return jax.reshard(arr, aval.sharding)


def _make_scale_tensor(scale, arr):
scale_tensor = jnp.full_like(arr, scale, dtype=jnp.bfloat16)
return _cast_reduced_from(scale_tensor, arr)


def _get_max_min(target_dtype):
if target_dtype in (jnp.int4, jnp.int8):
return jnp.iinfo(target_dtype).max, jnp.iinfo(target_dtype).min
else:
return jnp.finfo(target_dtype).max.astype(jnp.bfloat16), jnp.finfo(target_dtype).min.astype(jnp.bfloat16)


def manual_quantize(tensor, calibration_method, dtype=jnp.float8_e4m3fn):
"""Manually quantizes a tensor based on a fixed calibration method.

Args:
tensor: The tensor to quantize.
calibration_method: A string specifying the calibration method. Expected
format is "fixed,{scale},{max_val}".
dtype: Target quantization dtype (defaults to jnp.float8_e4m3fn).

Returns:
A qwix.QArray containing the quantized value and the scale.

Raises:
ValueError: If calibration_method is None or has an unexpected format.
"""
calib_method = calibration_method
if calib_method is None:
raise ValueError("calibration_method cannot be None for manual quantization")
if not calib_method.startswith("fixed"):
raise ValueError("Only static weight/activation quantization is supported, but got" f" {calib_method}")

parts = calib_method.split(",")
if len(parts) != 3:
raise ValueError(f"Unexpected format for weight calibration method: {calib_method}")

dtype_max, dtype_min = _get_max_min(dtype)
max_val = float(parts[2])
scale = max_val / dtype_max
scale = jnp.where(scale == 0, 1.0, scale)
# scale must be converted to a tensor because grad has reduced axes.
scale_tensor = _make_scale_tensor(scale, tensor)
min_bound = _make_scale_tensor(dtype_min, tensor)
max_bound = _make_scale_tensor(dtype_max, tensor)
q_tensor = jnp.clip(tensor / scale_tensor, min_bound, max_bound).astype(dtype)

# get scale for QArray
scale_shape = [1] * tensor.ndim
# It must stay fully replicated for the backward pass and Pallas.
scale_tensor_qpl = jnp.full(scale_shape, scale, dtype=tensor.dtype)
# wrap in QArray
return qpl.QArray(qvalue=q_tensor, scale=scale_tensor_qpl)


class TransformerEngineQuantization(Quantization):
"""Class for TransformerEngine quantization recipes."""

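For reference, the arithmetic inside the new `manual_quantize` helper can be checked in isolation. The sketch below mirrors its scale/clip/cast steps for a fixed calibration string, using an assumed `"fixed,1.0,448.0"` setting and toy input values (illustrative only, not taken from this PR):

```python
import jax.numpy as jnp

dtype = jnp.float8_e4m3fn
calibration = "fixed,1.0,448.0"          # assumed example "fixed,{scale},{max_val}" string
max_val = float(calibration.split(",")[2])
dtype_max = float(jnp.finfo(dtype).max)  # 448.0 for e4m3fn
scale = max_val / dtype_max or 1.0       # guard against a zero scale

x = jnp.array([0.5, -900.0, 3.25], dtype=jnp.bfloat16)
q = jnp.clip(x / scale, -dtype_max, dtype_max).astype(dtype)  # quantize
deq = q.astype(jnp.bfloat16) * scale                          # approximate round trip
print(q, deq)
```

The helper itself additionally broadcasts the scale to the tensor's sharding (via `_make_scale_tensor`) and wraps the result in a `qwix.pallas` `QArray` with a fully replicated `[1] * ndim` scale, which the sketch above omits.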
7 changes: 5 additions & 2 deletions src/maxtext/models/deepseek.py
@@ -451,7 +451,8 @@ def __call__(
# That is also why we can split/merge activations here as well as
# in `Decoder`, since they will never be executed together.
if self.config.use_batch_split_schedule:
if self.config.use_qwix_quantization:
# The older version of batch-split that fully uses qwix quantization.
if self.config.use_qwix_quantization and not self.config.use_manual_quantization:
activation_pspec = jax.sharding.PartitionSpec(
("data", "fsdp", "fsdp_transpose", "expert", "context"),
None,
@@ -490,7 +491,9 @@
)(outputs)
return outputs, None

# bf16 code path
# bf16 and fp8 code path for the pure-JAX batch-split.
# The fp8 path supports both manual quantization and qwix quantization.
input_sharding = jax.typeof(inputs).sharding
activation_pspec = jax.sharding.PartitionSpec(
("data", "fsdp", "expert"),
Expand Down