4 changes: 2 additions & 2 deletions src/maxtext/configs/base.yml
@@ -1158,8 +1158,8 @@ position_id_per_seconds: 25
 subslice_shape: ""

 # NNX
-enable_nnx: False
-pure_nnx_decoder: False
+enable_nnx: True
+pure_nnx_decoder: True
 pure_nnx: False

 ################################## Qwen3-Next Specific Configs ##################################
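The two flipped defaults above turn the NNX decoder path on by default. A minimal sketch, assuming the flags gate decoder selection roughly as follows (the helper name and returned labels are illustrative, not MaxText's actual wiring):

```python
def select_decoder_path(cfg: dict) -> str:
  # Hypothetical decision logic keyed on the NNX flags from base.yml.
  if cfg.get("enable_nnx") and cfg.get("pure_nnx_decoder"):
    return "pure NNX decoder"
  if cfg.get("enable_nnx"):
    return "NNX modules behind Linen wrappers"
  return "legacy Linen decoder"

# With the new defaults from this change:
print(select_decoder_path({"enable_nnx": True, "pure_nnx_decoder": True, "pure_nnx": False}))
# -> pure NNX decoder
```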
10 changes: 10 additions & 0 deletions src/maxtext/layers/initializers.py
@@ -94,6 +94,16 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
   out_sharding = metadata["sharding"]

   if out_sharding is not None:
+    if nnx.PARTITION_NAME in metadata:
+      partition_name = metadata[nnx.PARTITION_NAME]
+      scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0
+
+      sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
+      if partition_name not in sharding_list:
+        sharding_list.insert(scan_axis, partition_name)
+
+      out_sharding = tuple(sharding_list)
+
     return nn.LogicallyPartitioned(  # type: ignore[wrong-keyword-args]
         variable.value,
         out_sharding,  # type: ignore[arg-type]
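The new block above splices the scanned-layer axis name into a variable's logical sharding when it is missing. A standalone sketch of that list manipulation with the Flax metadata plumbing stripped out (`insert_scan_axis_name` is an illustrative helper; `partition_name` stands in for `metadata[nnx.PARTITION_NAME]` and `scan_axis` for `metadata.get("param_scan_axis", 0)`):

```python
def insert_scan_axis_name(out_sharding, partition_name, scan_axis=0):
  # Normalize a single axis name to a list so it can be spliced into.
  sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
  # Only insert the scanned-layer axis if the logical sharding doesn't already name it.
  if partition_name not in sharding_list:
    sharding_list.insert(scan_axis, partition_name)
  return tuple(sharding_list)

print(insert_scan_axis_name(("embed", "mlp"), "layers"))            # ('layers', 'embed', 'mlp')
print(insert_scan_axis_name(("layers", "embed", "mlp"), "layers"))  # already present, unchanged
```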
572 changes: 385 additions & 187 deletions src/maxtext/layers/nnx_decoders.py

Large diffs are not rendered by default.

93 changes: 67 additions & 26 deletions src/maxtext/layers/quantizations.py
@@ -712,13 +712,37 @@ def dot_general(self, *args, **kwargs):
     rule, op_id = self._get_current_rule_and_op_id("dot_general")
     if rule is None:
       return jax.lax.dot_general(*args, **kwargs)
-    return nn.Fp8DirectDotGeneralOp(name=op_id)(*args, **kwargs)
+
+    from qwix._src import flax_util
+    module = flax_util.get_current_module()
+    from flax import nnx
+    from src.maxtext.layers.nnx_wrappers import ToLinen, ToNNX
+    if isinstance(module, nnx.Module) and not isinstance(module, ToLinen):
+      if not hasattr(module, op_id):
+        op = ToNNX(nn.Fp8DirectDotGeneralOp(name=op_id))
+        op.lazy_init(*args, **kwargs)
+        setattr(module, op_id, op)
+      return getattr(module, op_id)(*args, **kwargs)
+    else:
+      return nn.Fp8DirectDotGeneralOp(name=op_id)(*args, **kwargs)

   def einsum(self, *args, **kwargs):
     rule, op_id = self._get_current_rule_and_op_id("einsum")
     if rule is None:
       return jnp.einsum(*args, **kwargs)
-    return nn.Fp8Einsum(name=op_id)(*args, **kwargs)
+
+    from qwix._src import flax_util
+    module = flax_util.get_current_module()
+    from flax import nnx
+    from src.maxtext.layers.nnx_wrappers import ToLinen, ToNNX
+    if isinstance(module, nnx.Module) and not isinstance(module, ToLinen):
+      if not hasattr(module, op_id):
+        op = ToNNX(nn.Fp8Einsum(name=op_id))
+        op.lazy_init(*args, **kwargs)
+        setattr(module, op_id, op)
+      return getattr(module, op_id)(*args, **kwargs)
+    else:
+      return nn.Fp8Einsum(name=op_id)(*args, **kwargs)


 class NANOOFp8Provider(qwix.QtProvider):
@@ -728,31 +752,37 @@ def dot_general(self, *args, **kwargs):
     rule, op_id = self._get_current_rule_and_op_id("dot_general")
     if rule is None:
       return jax.lax.dot_general(*args, **kwargs)
-    return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs)

+    from qwix._src import flax_util
+    module = flax_util.get_current_module()
+    from flax import nnx
+    from src.maxtext.layers.nnx_wrappers import ToLinen, ToNNX
+    if isinstance(module, nnx.Module) and not isinstance(module, ToLinen):
+      if not hasattr(module, op_id):
+        op = ToNNX(nn.NANOOFp8DotGeneralOp(name=op_id))
+        op.lazy_init(*args, **kwargs)
+        setattr(module, op_id, op)
+      return getattr(module, op_id)(*args, **kwargs)
+    else:
+      return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs)

-def get_fp8_full_qwix_rule_w_sparsity(config: Config):
-  sparsity_rule = None
-  if config.weight_sparsity_n and config.weight_sparsity_m:
-    sparsity_rule = sparsity.SparsityRule(
-        weight_sparsity_n=config.weight_sparsity_n,
-        weight_sparsity_m=config.weight_sparsity_m,
-        weight_sparsity_update_step=config.weight_sparsity_update_step,
-        weight_sparsity_start_step=config.weight_sparsity_start_step,
-    )
-  return [
-      qwix.QtRule(
-          module_path="decoder/.*layers.*",
-          weight_qtype=jnp.float8_e4m3fn,
-          act_qtype=jnp.float8_e4m3fn,
-          bwd_qtype=jnp.float8_e5m2,
-          weight_calibration_method=config.weight_quantization_calibration_method,
-          act_calibration_method=config.act_quantization_calibration_method,
-          bwd_calibration_method=config.bwd_quantization_calibration_method,
-          additional_qt_config={"sparsity_rule": sparsity_rule},
-          op_names=("dot_general", "gmm", "ragged_dot"),
-      ),
-  ]
+  def einsum(self, *args, **kwargs):
+    rule, op_id = self._get_current_rule_and_op_id("einsum")
+    if rule is None:
+      return jnp.einsum(*args, **kwargs)
+    # NANOOFp8 doesn't have an Einsum op, so we fall back to Fp8Einsum
+    from qwix._src import flax_util
+    module = flax_util.get_current_module()
+    from flax import nnx
+    from src.maxtext.layers.nnx_wrappers import ToLinen, ToNNX
+    if isinstance(module, nnx.Module) and not isinstance(module, ToLinen):
+      if not hasattr(module, op_id):
+        op = ToNNX(nn.Fp8Einsum(name=op_id))
+        op.lazy_init(*args, **kwargs)
+        setattr(module, op_id, op)
+      return getattr(module, op_id)(*args, **kwargs)
+    else:
+      return nn.Fp8Einsum(name=op_id)(*args, **kwargs)


 def get_quantization_rule(config: Config):
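Both providers now follow the same pattern when the caller is an NNX module: build the Linen FP8 op once, wrap it with ToNNX, initialize it against the current call arguments, cache it on the module under op_id, and reuse it on later calls. A condensed sketch of that caching step (the helper below is illustrative; the real code inlines this inside dot_general/einsum):

```python
def get_or_create_op(module, op_id, make_op, *args, **kwargs):
  if not hasattr(module, op_id):
    op = make_op()                 # e.g. ToNNX(nn.Fp8Einsum(name=op_id))
    op.lazy_init(*args, **kwargs)  # create the wrapped Linen variables once
    setattr(module, op_id, op)     # cache on the calling NNX module
  return getattr(module, op_id)(*args, **kwargs)
```

Caching the op on the module presumably keeps its FP8 scale/amax state as part of the module's state rather than recreating it on every call.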
@@ -847,7 +877,18 @@ def maybe_quantize_model(model, config):
   if config.use_qwix_quantization and not config.use_batch_split_schedule:
     quantization_provider = get_qt_provider(config)
     if quantization_provider:
-      model = qwix.quantize_model(model, quantization_provider)
+      from flax import nnx
+      from src.maxtext.layers.nnx_wrappers import ToLinen
+      if isinstance(model, nnx.Module) and not isinstance(model, ToLinen):
+        import jax.numpy as jnp
+        batch_size = config.global_batch_size_to_train_on
+        seq_len = config.max_target_length
+        ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+        decoder_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
+        decoder_positions = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32), (batch_size, 1))
+        model = qwix.quantize_model(model, quantization_provider, ids, decoder_positions, decoder_segment_ids, enable_dropout=False)
+      else:
+        model = qwix.quantize_model(model, quantization_provider)
   return model
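For an NNX model, quantize_model is now given example inputs, presumably so qwix can trace the model before swapping in quantized ops. A small, self-contained sketch of the dummy batch built above (the concrete sizes stand in for config.global_batch_size_to_train_on and config.max_target_length):

```python
import jax.numpy as jnp

batch_size, seq_len = 4, 16
ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)                  # dummy token ids
decoder_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)   # single segment per row
decoder_positions = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32), (batch_size, 1))  # 0..seq_len-1 per row
print(ids.shape, decoder_segment_ids.shape, decoder_positions.shape)     # (4, 16) (4, 16) (4, 16)
```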
17 changes: 2 additions & 15 deletions src/maxtext/models/models.py
@@ -33,7 +33,7 @@
 from maxtext.layers.decoders import Decoder
 from maxtext.layers.embeddings import Embed, embed_as_linen
 from maxtext.layers.encoders import AudioEncoder, VisionEncoder, audio_encoder_as_linen, vision_encoder_as_linen
-from maxtext.layers.multi_token_prediction import multi_token_prediction_block_as_linen
+from maxtext.layers.multi_token_prediction import MultiTokenPredictionBlock, multi_token_prediction_block_as_linen
 from maxtext.layers.quantizations import AqtQuantization as Quant
 from maxtext.multimodal import processor as mm_processor
 from maxtext.utils import max_utils
@@ -386,25 +386,12 @@ def __init__(
       # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
       # By convention, this is the last layer in the list.
       mtp_layer = layer_types[-1]
-      mtp_block_linen = multi_token_prediction_block_as_linen(
+      self.mtp_block = MultiTokenPredictionBlock(
           config=self.config,
           mesh=self.mesh,
           transformer_layer_module=mtp_layer,
           decoder=self.decoder,
           rngs=rngs,
-          name="mtp_block",
       )
-      self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs)
-
-      self.mtp_block.lazy_init(
-          shared_embedding=self.token_embedder,
-          main_hidden_state=jnp.ones((1, 1, self.config.emb_dim), dtype=self.config.dtype),
-          input_ids=jnp.ones((1, 1), dtype=jnp.int32),
-          target_ids=jnp.ones((1, 1), dtype=jnp.int32),
-          target_mask=jnp.ones((1, 1), dtype=jnp.int32),
-          position_ids=jnp.ones((1, 1), dtype=jnp.int32),
-          decoder_segment_ids=jnp.ones((1, 1), dtype=jnp.int32),
-          deterministic=True,
-      )

   def no_op(self, *args, **kwargs):
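The MTP block is now constructed directly as an NNX module, so the ToNNX wrapper and the lazy_init call with placeholder tensors are gone. A toy illustration (assumed shapes and module names, not MaxText code) of why eager NNX construction removes the need for that warm-up call:

```python
from flax import nnx
import jax.numpy as jnp

class TinyBlock(nnx.Module):
  def __init__(self, features: int, rngs: nnx.Rngs):
    self.proj = nnx.Linear(features, features, rngs=rngs)  # parameters exist right away

  def __call__(self, x):
    return self.proj(x)

block = TinyBlock(8, rngs=nnx.Rngs(0))
y = block(jnp.ones((1, 8)))  # no lazy_init with dummy inputs needed
print(y.shape)               # (1, 8)
```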