diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 5e59a0f4be..8e0ba28379 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -143,6 +143,7 @@ save_quantized_params_path: ""
 # accepted values are "inference"
 model_call_mode: ""
 use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix.
+use_manual_quantization: false # whether to use manual quantization for batch split. Only used if use_batch_split_schedule is True.
 # quantization calibration method used for weights and activations. supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#l70-l80
 weight_quantization_calibration_method: "absmax"
 act_quantization_calibration_method: "absmax"
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index b463516945..ac1736ce08 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -422,6 +422,10 @@ class Quantization(BaseModel):
   kv_quant_dtype: Literal["int8", "int4"] = Field("int8", description="Data type for KV cache quantization.")
   quantization_local_shard_count: int = Field(-1, description="Shards the range finding operation for quantization.")
   use_qwix_quantization: bool = Field(False, description="Whether to use qwix for quantization.")
+  use_manual_quantization: bool = Field(
+      False,
+      description="Whether to use manual quantization for batch split. Only used if use_batch_split_schedule is True.",
+  )
   weight_quantization_calibration_method: str = Field(
       "absmax",
       description="Quantization calibration method used for weights.",
@@ -2727,8 +2731,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
             f"Decoder '{self.decoder_block.value}' is not supported with 'explicit' sharding. "
             f"Supported options are: {list(supported_decoders)}."
         )
-    if self.quantization:
-      raise ValueError("Quantization is not supported with 'explicit' sharding.")
     if self.context_sharding not in ("context", "expert"):
       raise ValueError(f"Assigned context_sharding f{self.context_sharding} is not supported.")
     if (
@@ -2835,10 +2837,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       self.use_grpo = False

     if self.use_batch_split_schedule:
-      if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"):
-        raise ValueError(
-            "Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
-        )
+      if self.quantization and not self.quantization == "fp8_full":
+        raise ValueError("Batch split quantization only supports `quantization=fp8_full`")

     if self.opt_type == "muon" and self.decoder_block not in [
         DecoderBlockType.DEEPSEEK,
diff --git a/src/maxtext/kernels/megablox/ops.py b/src/maxtext/kernels/megablox/ops.py
index d65db0648b..0250e713a5 100644
--- a/src/maxtext/kernels/megablox/ops.py
+++ b/src/maxtext/kernels/megablox/ops.py
@@ -22,6 +22,7 @@ import jax
 import jax.numpy as jnp

 from maxtext.kernels.megablox import backend
+from maxtext.layers import quantizations
 import qwix
 import qwix.pallas as qpl
 import tokamax
@@ -61,6 +62,7 @@ def gmm(
     weight_gather_axes: List[Tuple[str, int]] | None = None,
     # TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
     qwix_rule: qwix.QtRule | None = None,
+    use_manual_quantization: bool = False,
 ):
   """Grouped matrix multiplication operation."""
   quantization_rule = None
@@ -80,7 +82,7 @@ def gmm(
   )

   gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0]  # pylint: disable=C3001
-  gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11))
+  gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11, 12))
   gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
   return gmm_fwd_bwd(
       lhs,
@@ -95,6 +97,7 @@ def gmm(
       quantization_rule,
       use_tokamax_backend,
       weight_gather_axes,
+      use_manual_quantization,
   )


@@ -121,6 +124,7 @@ def _gmm_fwd(
     quantization_rule: qwix.QtRule | None = None,
     use_tokamax_backend: bool = False,
     weight_gather_axes: List[Tuple[str, int]] | None = None,
+    use_manual_quantization: bool = False,
 ) -> tuple[
     jnp.ndarray,
     tuple[
@@ -140,15 +144,20 @@ def _gmm_fwd(
           calibration_method=quantization_rule.act_calibration_method,
       )
     if quantization_rule.weight_qtype and not isinstance(rhs, qpl.QArray):
-      rhs = qpl.quantize(
-          rhs,
-          quantization_rule.weight_qtype,
-          # If only considering the fwd pass, we could also enable channelwise
-          # axes for the group axis, i.e., [0, 1 or 2]. However, this makes the
-          # bwd pass unable to reuse the scale easily.
-          channelwise_axes=[] if quantization_rule.disable_channelwise_axes else ([1] if transpose_rhs else [2]),
-          calibration_method=quantization_rule.weight_calibration_method,
-      )
+      if not use_manual_quantization:
+        rhs = qpl.quantize(
+            rhs,
+            quantization_rule.weight_qtype,
+            # If only considering the fwd pass, we could also enable channelwise
+            # axes for the group axis, i.e., [0, 1 or 2]. However, this makes the
+            # bwd pass unable to reuse the scale easily.
+            channelwise_axes=([] if quantization_rule.disable_channelwise_axes else ([1] if transpose_rhs else [2])),
+            calibration_method=quantization_rule.weight_calibration_method,
+        )
+      else:
+        rhs = quantizations.manual_quantize(
+            rhs, quantization_rule.weight_calibration_method, quantization_rule.weight_qtype
+        )
   # QAG is only supported for following conditions
   if use_tokamax_backend:
     if quantization_rule and quantization_rule.bwd_qtype:
@@ -169,6 +178,9 @@ def _gmm_fwd(
           preferred_element_type=preferred_element_type,
           group_offset=group_offset,
           implementation="mosaic",
+          manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"]))
+          if use_manual_quantization
+          else None,
       )
     else:
       out = backend.gmm(
@@ -195,6 +207,7 @@ def _gmm_bwd(
     quantization_rule: qwix.QtRule | None,
     use_tokamax_backend: bool,
     weight_gather_axes: List[Tuple[str, int]] | None,
+    use_manual_quantization: bool,
     residual: tuple[
         jnp.ndarray | qpl.QArray,
         jnp.ndarray | qpl.QArray,
@@ -257,6 +270,9 @@ def _gmm_bwd(
         preferred_element_type=lhs_dtype,
         group_offset=group_offset,
         implementation="mosaic",
+        manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"]))
+        if use_manual_quantization
+        else None,
     )
     drhs = tokamax.ragged_dot_general(
         lhs=lhs,
@@ -267,6 +283,9 @@ def _gmm_bwd(
         preferred_element_type=rhs_dtype,
         group_offset=group_offset,
         implementation="mosaic",
+        manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["expert"]), unreduced=frozenset(["data", "fsdp"]))
+        if use_manual_quantization
+        else None,
     )
   if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
     # Scatter back in reverse order of gather
diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py
index 42644ab262..608f86fa79 100644
--- a/src/maxtext/layers/decoders.py
+++ b/src/maxtext/layers/decoders.py
@@ -922,7 +922,8 @@ def __call__(
         # as detected by immutable params, use deepseek_batchsplit custom
         # scan with initialized parameters.
         if cfg.use_batch_split_schedule and not self.is_mutable_collection("params"):
-          if cfg.use_qwix_quantization:
+          # The older version of batch-split that fully uses qwix quantization.
+          if cfg.use_qwix_quantization and not cfg.use_manual_quantization:
            y = deepseek_batchsplit_fp8.scan_batch_split_layers(
                 y,
                 self.variables["params"]["moe_layers"],
@@ -935,7 +936,9 @@ def __call__(
                 policy=policy,
             )
           else:
-            # bf16 code path
+            # bf16 and fp8 code path for pure-JAX batch-split.
+            # fp8 code path supports both manual quantization and qwix
+            # quantization.
            y = deepseek_batchsplit.scan_batch_split_layers(
                 y,
                 self.variables["params"]["moe_layers"],
diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py
index dad0f6179e..248333fa26 100644
--- a/src/maxtext/layers/quantizations.py
+++ b/src/maxtext/layers/quantizations.py
@@ -16,6 +16,7 @@
 import functools
 import json
+import qwix.pallas as qpl
 import re
 from typing import Tuple, Sequence, Callable
 from dataclasses import dataclass

@@ -629,13 +630,15 @@ def get_quant_mode(quant_mode_str: str = "train"):
 def configure_quantization(config: Config, quant_mode_str: str = "train"):
   """Configure quantization based on user config and quant mode."""
   if config.use_batch_split_schedule and config.quantization:
-    if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
-      raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
-    return QwixQuantization(
-        weight_calibration_method=config.weight_quantization_calibration_method,
-        act_calibration_method=config.act_quantization_calibration_method,
-        bwd_calibration_method=config.bwd_quantization_calibration_method,
-    )
+    # The older version of batch-split that fully uses qwix quantization.
+    if config.quantization == "fp8_full" and not config.use_manual_quantization:
+      return QwixQuantization(
+          weight_calibration_method=config.weight_quantization_calibration_method,
+          act_calibration_method=config.act_quantization_calibration_method,
+          bwd_calibration_method=config.bwd_quantization_calibration_method,
+      )
+    # The pure JAX version of batch-split that uses manual quantization.
+    return None

   if config.use_qwix_quantization:
     return None
@@ -764,8 +767,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.int4,
             act_qtype=jnp.int4,
             bwd_qtype=jnp.int4,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]
@@ -776,8 +778,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.int8,
             act_qtype=jnp.int8,
             bwd_qtype=jnp.int8,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]
@@ -788,8 +789,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]
@@ -802,8 +802,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]
@@ -814,8 +813,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]
@@ -851,6 +849,71 @@ def maybe_quantize_model(model, config):
   return model


+def _cast_reduced_from(arr, reduced_arr):
+  aval = jax.typeof(reduced_arr)
+  # In shard map
+  if aval.sharding.mesh.axis_types[0] == jax.sharding.AxisType.Manual:
+    for axis in aval.mat.reduced:
+      arr = jax.lax.pcast(arr, axis, to="reduced")
+    return arr
+  # Outside shard map
+  return jax.reshard(arr, aval.sharding)
+
+
+def _make_scale_tensor(scale, arr):
+  scale_tensor = jnp.full_like(arr, scale, dtype=jnp.bfloat16)
+  return _cast_reduced_from(scale_tensor, arr)
+
+
+def _get_max_min(target_dtype):
+  if target_dtype in (jnp.int4, jnp.int8):
+    return jnp.iinfo(target_dtype).max, jnp.iinfo(target_dtype).min
+  else:
+    return jnp.finfo(target_dtype).max.astype(jnp.bfloat16), jnp.finfo(target_dtype).min.astype(jnp.bfloat16)
+
+
+def manual_quantize(tensor, calibration_method, dtype=jnp.float8_e4m3fn):
+  """Manually quantizes a tensor based on a fixed calibration method.
+
+  Args:
+    tensor: The tensor to quantize.
+    calibration_method: A string specifying the calibration method. Expected
+      format is "fixed,{scale},{max_val}".
+    dtype: The target quantization dtype (defaults to jnp.float8_e4m3fn).
+
+  Returns:
+    A qpl.QArray containing the quantized values and the scale.
+
+  Raises:
+    ValueError: If calibration_method is None or has an unexpected format.
+  """
+  calib_method = calibration_method
+  if calib_method is None:
+    raise ValueError("calibration_method cannot be None for manual quantization")
+  if not calib_method.startswith("fixed"):
+    raise ValueError(f"Only static weight/activation quantization is supported, but got {calib_method}")
+
+  parts = calib_method.split(",")
+  if len(parts) != 3:
+    raise ValueError(f"Unexpected format for weight calibration method: {calib_method}")
+
+  dtype_max, dtype_min = _get_max_min(dtype)
+  max_val = float(parts[2])
+  scale = max_val / dtype_max
+  scale = jnp.where(scale == 0, 1.0, scale)
+  # The scale must be converted to a tensor because the gradient has reduced axes.
+  scale_tensor = _make_scale_tensor(scale, tensor)
+  min_bound = _make_scale_tensor(dtype_min, tensor)
+  max_bound = _make_scale_tensor(dtype_max, tensor)
+  q_tensor = jnp.clip(tensor / scale_tensor, min_bound, max_bound).astype(dtype)
+
+  # Get the scale for the QArray.
+  scale_shape = [1] * tensor.ndim
+  # It must stay fully replicated for the backward pass and Pallas.
+  scale_tensor_qpl = jnp.full(scale_shape, scale, dtype=tensor.dtype)
+  # Wrap in a QArray.
+  return qpl.QArray(qvalue=q_tensor, scale=scale_tensor_qpl)
+
+
 class TransformerEngineQuantization(Quantization):
   """Class for TransformerEngine quantization recipes."""
diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py
index 7ae5c46b19..0980b78599 100644
--- a/src/maxtext/models/deepseek.py
+++ b/src/maxtext/models/deepseek.py
@@ -451,7 +451,8 @@ def __call__(
     # That is also why we can split/merge activations here as well as
     # in `Decoder`, since they will never be executed together.
     if self.config.use_batch_split_schedule:
-      if self.config.use_qwix_quantization:
+      # The older version of batch-split that fully uses qwix quantization.
+      if self.config.use_qwix_quantization and not self.config.use_manual_quantization:
         activation_pspec = jax.sharding.PartitionSpec(
             ("data", "fsdp", "fsdp_transpose", "expert", "context"),
             None,
@@ -490,7 +491,9 @@ def __call__(
         )(outputs)
         return outputs, None

-      # bf16 code path
+      # bf16 and fp8 code path for pure-JAX batch-split.
+      # fp8 code path supports both manual quantization and qwix
+      # quantization.
      input_sharding = jax.typeof(inputs).sharding
       activation_pspec = jax.sharding.PartitionSpec(
           ("data", "fsdp", "expert"),
diff --git a/src/maxtext/models/deepseek_batchsplit.py b/src/maxtext/models/deepseek_batchsplit.py
index e6c3bbd5ef..6b7b92310d 100644
--- a/src/maxtext/models/deepseek_batchsplit.py
+++ b/src/maxtext/models/deepseek_batchsplit.py
@@ -33,10 +33,11 @@
 import jax
 import jax.numpy as jnp

-from maxtext.kernels import attention, sort_activations
-
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask

+from maxtext.kernels import attention, sort_activations, megablox
+from maxtext.layers import quantizations
+

 def scheduling_group(group_id) -> contextlib.AbstractContextManager[None]:
   return jax.experimental.xla_metadata.set_xla_metadata(_scheduling_group_id=group_id)
@@ -991,6 +992,7 @@ def batch_split_schedule(
           norm_mla_ws,
           positions,
           mesh=mesh,
+          config=cfg,
           splash_kernel=splash_kernel,
           normalization_layer_epsilon=cfg.normalization_layer_epsilon,
           kv_lora_rank=cfg.kv_lora_rank,
@@ -1041,6 +1043,7 @@ def batch_split_schedule_bwd(
           norm_mla_ws,
           positions,
           mesh=mesh,
+          config=cfg,
           splash_kernel=splash_kernel,
           normalization_layer_epsilon=cfg.normalization_layer_epsilon,
           kv_lora_rank=cfg.kv_lora_rank,
@@ -1099,6 +1102,7 @@ def mla_with_norms(
     yarn_freqs,
     *,
     mesh,
+    config,
     splash_kernel,
     normalization_layer_epsilon,
     kv_lora_rank,
@@ -1141,6 +1145,7 @@ def fn(args):
         pairwise_swap_and_negate_mask=pairwise_swap_and_negate_mask,
         dtype=dtype,
         mscale=mscale,
+        config=config,
         splash_kernel=splash_kernel,
         mesh=mesh,
         activation_pspec=activation_pspec,
@@ -1159,6 +1164,7 @@ def mla_with_norms_remat(
     yarn_freqs,
     *,
     mesh,
+    config,
     splash_kernel,
     normalization_layer_epsilon,
     kv_lora_rank,
@@ -1205,6 +1211,7 @@ def remat_fn(args):
         pairwise_swap_and_negate_mask=pairwise_swap_and_negate_mask,
         dtype=dtype,
         mscale=mscale,
+        config=config,
         splash_kernel=splash_kernel,
         mesh=mesh,
         activation_pspec=activation_pspec,
@@ -1257,6 +1264,7 @@ def mla(
     original_max_position_embeddings,
     rope_factor,
     mscale,
+    config,
     splash_kernel,
     pairwise_swap_and_negate_mask,
     dtype,
@@ -1288,6 +1296,7 @@ def mla(
       dtype=dtype,
       qk_nope_head_dim=qk_nope_head_dim,
       mscale=mscale,
+      config=config,
       mesh=mesh,
       activation_pspec=activation_pspec,
   )
@@ -1303,6 +1312,7 @@ def mla(
       dtype=dtype,
       qk_nope_head_dim=qk_nope_head_dim,
       num_query_heads=num_query_heads,
+      config=config,
       mesh=mesh,
       activation_pspec=activation_pspec,
   )
@@ -1333,6 +1343,7 @@ def mla_remat(
     original_max_position_embeddings,
     rope_factor,
     mscale,
+    config,
     splash_kernel,
     pairwise_swap_and_negate_mask,
     dtype,
@@ -1362,6 +1373,7 @@ def mla_remat(
           dtype=dtype,
           qk_nope_head_dim=qk_nope_head_dim,
           mscale=mscale,
+          config=config,
           mesh=mesh,
           activation_pspec=activation_pspec,
       ),
@@ -1380,6 +1392,7 @@ def mla_remat(
           dtype=dtype,
           qk_nope_head_dim=qk_nope_head_dim,
           num_query_heads=num_query_heads,
+          config=config,
           mesh=mesh,
           activation_pspec=activation_pspec,
       ),
@@ -1422,7 +1435,6 @@ def mla_bwd(
   """Performs the backward pass for the mla function."""
   query_projection_bwd, kv_projection_bwd, attn_op_bwd, out_projection_bwd = bwds
   attn_out_grad, out_weights_grad = out_projection_bwd(out_grad)
-  # query_grad, key_grad, value_grad, _ = attention_op_bwd(attn_out_grad)
   query_grad, key_grad, value_grad = attn_op_bwd(attn_out_grad)
   inputs_grad_from_kv, _, wkv_a_weights_grad, wkv_b_weights_grad, kv_norm_scale_weights_grad = kv_projection_bwd(
       (key_grad, value_grad)
@@ -1457,6 +1469,7 @@ def query_projection(
     pairwise_swap_and_negate_mask,
     dtype,
     mscale,
+    config,
    mesh,
     activation_pspec,
 ):
@@ -1504,6 +1517,7 @@ def kv_projection(
     dtype,
     qk_nope_head_dim,
     num_query_heads,
+    config,
     mesh,
     activation_pspec,
 ):
@@ -1644,6 +1658,7 @@ def shared_expert_and_route(
     routed_scaling_factor,
     expert_axis_name,
     use_gather_mosaic_kernel,
+    config,
     normalization_layer_epsilon,
     dtype,
 ):
@@ -1820,35 +1835,83 @@ def unroute_impl_bwd(expert_axis_name, use_gather_mosaic_kernel, res, grad):
 unroute_impl.defvjp(unroute_impl_fwd, unroute_impl_bwd)


-def compute_gating(x, w0, w1, group_sizes, *, dtype):
+def gmm(
+    inputs,
+    kernel,
+    group_sizes,
+    preferred_element_type,
+    config,
+):
+  """Performs a Grouped Matrix Multiplication (GMM).
+
+  This function can use either a quantized Megablox kernel or a standard
+  jax.lax.ragged_dot for the GMM operation, based on the configuration.
+
+  Args:
+    inputs: The left-hand side operand of the GMM.
+    kernel: The right-hand side operand (kernel) of the GMM.
+    group_sizes: An array indicating the size of each group.
+    preferred_element_type: The preferred element type for the computation.
+    config: Configuration object containing model settings, including
+      `quantization`, which selects the quantized Megablox path.
+
+  Returns:
+    The result of the grouped matrix multiplication.
+  """
+  if config.quantization:
+    output = megablox.gmm(
+        lhs=inputs,
+        rhs=kernel,
+        group_sizes=group_sizes,
+        preferred_element_type=preferred_element_type,
+        use_qwix_quantization=True,
+        use_tokamax_backend=True,
+        qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config)[0],
+        use_manual_quantization=True,
+    )
+  else:
+    output = jax.lax.ragged_dot(
+        lhs=inputs,
+        rhs=kernel,
+        group_sizes=group_sizes,
+        precision=jax.lax.Precision.DEFAULT,
+        preferred_element_type=preferred_element_type,
+    )
+  return output
+
+
+def compute_gating(x, w0, w1, group_sizes, *, dtype, config):
   """Computes the gating GMMs."""
-  layer_w0 = jax.lax.ragged_dot(
-      x,
-      w0,
+  gmm_fn = functools.partial(
+      gmm,
       group_sizes=group_sizes,
-      precision=jax.lax.Precision.DEFAULT,
-      preferred_element_type=dtype,
-  )
-  layer_w1 = jax.lax.ragged_dot(
-      x,
-      w1,
-      group_sizes=group_sizes,
-      precision=jax.lax.Precision.DEFAULT,
       preferred_element_type=dtype,
+      config=config,
   )
+  if config.merge_gating_gmm:
+    w01 = jnp.concatenate([w0, w1], axis=-1)
+    layer_w01 = gmm_fn(x, w01)
+    layer_w0, layer_w1 = jnp.split(layer_w01, 2, axis=-1)
+  else:
+    layer_w0 = gmm_fn(x, w0)
+    layer_w1 = gmm_fn(x, w1)
   return layer_w0, layer_w1


-def compute_linear(layer_w0, layer_w1, wo, group_sizes, weights, *, dtype):
+def compute_linear(layer_w0, layer_w1, wo, group_sizes, weights, *, dtype, config):
   """Combines the outputs of the gating GMMs and computes the final GMM."""
   intermediate_layer = jax.nn.silu(layer_w0) * layer_w1
   intermediate_layer *= weights[:, None]
-  layer_wo = jax.lax.ragged_dot(
-      intermediate_layer,
-      wo,
+  gmm_fn = functools.partial(
+      gmm,
       group_sizes=group_sizes,
-      precision=jax.lax.Precision.DEFAULT,
       preferred_element_type=dtype,
+      config=config,
+  )
+  layer_wo = gmm_fn(
+      intermediate_layer,
+      wo,
   )
   return layer_wo
@@ -1864,6 +1927,7 @@ def route_compute_unroute(
     use_gather_mosaic_kernel,
     normalization_layer_epsilon,
     dtype,
+    config,
 ):
   """Routes, processes, and unroutes activations."""
   target_length = xs[0].shape[1]
@@ -1888,6 +1952,7 @@ def route_fn(inputs):
         routed_scaling_factor=routed_scaling_factor,
         expert_axis_name=expert_axis_name,
         use_gather_mosaic_kernel=use_gather_mosaic_kernel,
+        config=config,
         normalization_layer_epsilon=normalization_layer_epsilon,
         dtype=dtype,
     )
@@ -1907,6 +1972,7 @@ def compute_gating_fn(inputs):
        routed_w1,
         group_sizes,
         dtype=dtype,
+        config=config,
     )
     return layer_w0, layer_w1
@@ -1919,6 +1985,7 @@ def compute_linear_fn(inputs):
         group_sizes,
         weights,
         dtype=dtype,
+        config=config,
     )
     return x
@@ -2095,6 +2162,7 @@ def route_compute_unroute_bwd(
     use_gather_mosaic_kernel,
     normalization_layer_epsilon,
     dtype,
+    config,
 ):
   """Performs the backward pass for route_compute_unroute."""
   xs = residuals.pop("mla_out")
@@ -2114,6 +2182,7 @@ def route_fn_remat(inputs):
             routed_scaling_factor=routed_scaling_factor,
             expert_axis_name=expert_axis_name,
             use_gather_mosaic_kernel=use_gather_mosaic_kernel,
+            config=config,
             normalization_layer_epsilon=normalization_layer_epsilon,
             dtype=dtype,
         ),
@@ -2176,6 +2245,7 @@ def compute_gating_fn_remat(inputs):
         functools.partial(
             compute_gating,
             dtype=dtype,
+            config=config,
         ),
         x,
         routed_w0,
@@ -2190,6 +2260,7 @@ def compute_linear_fn_remat(inputs):
         functools.partial(
             compute_linear,
             dtype=dtype,
+            config=config,
         ),
         layer_w0,
         layer_w1,
@@ -2267,6 +2338,7 @@ def moe(
             use_gather_mosaic_kernel=use_gather_mosaic_kernel,
             normalization_layer_epsilon=normalization_layer_epsilon,
             dtype=dtype,
+            config=config,
         ),
         mesh=mesh,
         in_specs=(
@@ -2335,6 +2407,7 @@ def moe_bwd(
             use_gather_mosaic_kernel=use_gather_mosaic_kernel,
             normalization_layer_epsilon=normalization_layer_epsilon,
             dtype=dtype,
+            config=config,
         ),
         mesh=mesh,
         in_specs=(
diff --git a/src/maxtext/models/deepseek_batchsplit_fp8.py b/src/maxtext/models/deepseek_batchsplit_fp8.py
index 2d55536440..cef7c0646f 100644
--- a/src/maxtext/models/deepseek_batchsplit_fp8.py
+++ b/src/maxtext/models/deepseek_batchsplit_fp8.py
@@ -959,7 +959,7 @@ def gmm(
         use_qwix_quantization=config.use_qwix_quantization,
         use_tokamax_backend=config.use_tokamax_gmm,
         weight_gather_axes=weight_gather_axes,
-        qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config),
+        qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config)[0],
     )
   else:
     output = tokamax.ragged_dot(
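
Note for reviewers: a minimal, self-contained sketch of the fixed-scale quantization scheme that the new `manual_quantize` helper in quantizations.py implements. It uses only public `jax.numpy` APIs and deliberately omits the shard_map reduced-axis handling (`_cast_reduced_from`), which depends on experimental sharding features. `fixed_scale_quantize` and the sample calibration string are illustrative, not part of the patch.

```python
import jax.numpy as jnp

def fixed_scale_quantize(tensor, calibration_method, dtype=jnp.float8_e4m3fn):
  # "fixed,{scale},{max_val}": as in manual_quantize above, only the max_val
  # field is read; the effective scale is max_val / finfo(dtype).max.
  kind, _, max_val = calibration_method.split(",")
  assert kind == "fixed"
  dtype_max = float(jnp.finfo(dtype).max)  # 448.0 for float8_e4m3fn
  scale = float(max_val) / dtype_max or 1.0  # guard against a zero scale
  q = jnp.clip(tensor / scale, -dtype_max, dtype_max).astype(dtype)
  return q, scale  # dequantize as q.astype(jnp.bfloat16) * scale

x = 3.0 * jnp.ones((4, 8), jnp.bfloat16)
q, s = fixed_scale_quantize(x, "fixed,1.0,448.0")
print(q.dtype, s)  # float8_e4m3fn 1.0
```

Because the scale is fixed at calibration time rather than derived from the tensor (as with "absmax"), quantization is a pure elementwise op, which is what lets the backward pass reuse the same scale.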
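Similarly, a small sketch of the `merge_gating_gmm` fast path added to `compute_gating`: the two gating projections are fused into one grouped matmul by concatenating the expert weights along the output feature axis, halving the number of GMM launches. Plain `jax.lax.ragged_dot` stands in for the Megablox/tokamax kernel, and all shapes are illustrative.

```python
import jax
import jax.numpy as jnp

tokens, d_model, d_ff, n_experts = 16, 32, 64, 4
x = jnp.ones((tokens, d_model), jnp.bfloat16)     # tokens already sorted by expert
w0 = jnp.ones((n_experts, d_model, d_ff), jnp.bfloat16)
w1 = jnp.ones((n_experts, d_model, d_ff), jnp.bfloat16)
group_sizes = jnp.array([4, 4, 4, 4], jnp.int32)  # tokens per expert

# One ragged_dot over the concatenated kernel instead of two.
w01 = jnp.concatenate([w0, w1], axis=-1)          # [experts, d_model, 2 * d_ff]
out = jax.lax.ragged_dot(x, w01, group_sizes, preferred_element_type=jnp.bfloat16)
layer_w0, layer_w1 = jnp.split(out, 2, axis=-1)   # matches the two-GMM result
```

This is equivalent to two separate grouped matmuls because ragged_dot selects the same expert block of `w01` for each token group, so splitting the output along the feature axis recovers the two projections exactly.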