diff --git a/docs/source/apis.rst b/docs/source/apis.rst index 5bee094ac894..46662d8ef3f4 100644 --- a/docs/source/apis.rst +++ b/docs/source/apis.rst @@ -1,3 +1,5 @@ +:orphan: + ========= NeMo APIs ========= @@ -43,4 +45,3 @@ Alternatively, you can jump straight to the documentation for the individual col * :doc:`Audio Processing <../audio/intro>` * :doc:`SpeechLM2 <../speechlm2/intro>` - diff --git a/docs/source/collections.rst b/docs/source/collections.rst index 304478fdabc0..fae87e693809 100644 --- a/docs/source/collections.rst +++ b/docs/source/collections.rst @@ -1,3 +1,5 @@ +:orphan: + ================ NeMo Collections ================ diff --git a/docs/source/speechlm2/training_and_scaling.rst b/docs/source/speechlm2/training_and_scaling.rst index 4e213319538a..1e2bdb6a7607 100644 --- a/docs/source/speechlm2/training_and_scaling.rst +++ b/docs/source/speechlm2/training_and_scaling.rst @@ -183,6 +183,44 @@ For distributed inference, launch with ``torchrun``: inputs=path/to/manifest \ ep_size=2 +FP8 Training (SALMAutomodel) +"""""""""""""""""""""""""""" + +``SALMAutomodel`` supports two FP8 modes through NeMo Automodel. Configure only +one mode at a time. + +For Hopper MoE backbones, prefer Transformer Engine FP8 because it applies to +TE Linear / GroupedLinear kernels used by the MoE path: + +.. code-block:: yaml + + model: + automodel_backend: + linear: te + experts: te + dispatcher: deepep + te_fp8: + recipe: block # or "current" + +TorchAO FP8 is a separate dense-linear path. It should be paired with +``torch.compile`` for speedup and requires Hopper-class GPUs unless +``emulate=true`` is set for testing: + +.. code-block:: yaml + + model: + compile: + enabled: true + dynamic: true + fp8: + enabled: true + recipe_name: tensorwise + enable_fsdp_float8_all_gather: true + precompute_float8_dynamic_scale_for_fsdp: true + force_recompute_fp8_weight_in_bwd: true + filter_fqns: ["lm_head"] + emulate: false + Packed Sequences (THD) """""""""""""""""""""" diff --git a/docs/source/tools/intro.rst b/docs/source/tools/intro.rst index 5a08d05f3405..3a1c8eb376da 100644 --- a/docs/source/tools/intro.rst +++ b/docs/source/tools/intro.rst @@ -1,3 +1,5 @@ +:orphan: + Speech AI Tools =============== diff --git a/examples/speechlm2/conf/salm_automodel.yaml b/examples/speechlm2/conf/salm_automodel.yaml index 0d08c1e82113..0f5f30a6a0e0 100644 --- a/examples/speechlm2/conf/salm_automodel.yaml +++ b/examples/speechlm2/conf/salm_automodel.yaml @@ -73,16 +73,30 @@ model: # backend: null # Compilation backend (null = inductor) # dynamo_cache_size_limit: 256 # Triton compilation cache limit + # TorchAO FP8 training. This is separate from Transformer Engine FP8 below; + # configure only one FP8 mode at a time. TorchAO FP8 should be paired with + # torch.compile for speedup and requires Hopper-class GPUs unless emulate=true. + # fp8: + # enabled: true + # recipe_name: tensorwise # "tensorwise" | "rowwise" | "rowwise_with_gw_hp" + # enable_fsdp_float8_all_gather: true + # precompute_float8_dynamic_scale_for_fsdp: true + # force_recompute_fp8_weight_in_bwd: true + # filter_fqns: ["lm_head"] + # emulate: false + # Automodel backend dispatch. Selects the kernel/backend for each major module # in the LLM (attention, linear, rms_norm, MoE experts/dispatcher). Defaults # come from Automodel's BackendConfig and auto-select TE/DeepEP when available; # override here to pin a specific backend (e.g. attn=sdpa to bypass TE). + # For Hopper MoE FP8 training, prefer Transformer Engine FP8 by using TE + # linear/expert backends and setting te_fp8. Do not combine with model.fp8. # automodel_backend: # attn: te # "te" | "sdpa" | "flex" # linear: te # "torch" | "te" # rms_norm: torch_fp32 # "torch" | "torch_fp32" | "te" # rope_fusion: true # Fused RoPE (requires TE) - # experts: torch_mm # MoE expert GEMM: "torch" | "te" | "gmm" | "torch_mm" + # experts: te # MoE expert GEMM: "torch" | "te" | "gmm" | "torch_mm" # dispatcher: deepep # MoE token dispatcher: "torch" | "deepep" | "hybridep" | "uccl_ep" # dispatcher_num_sms: 20 # SM count for DeepEP/UCCL-EP kernels # fake_balanced_gate: false # Replace learned Gate with balanced fake gate (debug/bench) @@ -90,8 +104,8 @@ model: # enable_hf_state_dict_adapter: true # enable_fsdp_optimizations: false # gate_precision: null # e.g. "float32" to force fp32 gate compute - # te_fp8: null # {recipe: "current"} or {recipe: "block"} to enable TE FP8 - # # (requires linear=te or experts=te) + # te_fp8: + # recipe: block # "current" or "block"; requires linear=te or experts=te # Pin the SDPA kernel list used when automodel_backend.attn=sdpa. Accepts # strings from: "flash_attention", "efficient_attention", "math", "cudnn_attention". diff --git a/nemo/collections/speechlm2/models/salm_automodel.py b/nemo/collections/speechlm2/models/salm_automodel.py index 516134849946..3dde3c8def6d 100644 --- a/nemo/collections/speechlm2/models/salm_automodel.py +++ b/nemo/collections/speechlm2/models/salm_automodel.py @@ -31,6 +31,15 @@ from nemo.collections.speechlm2.models.salm import _resolve_audios_in_prompt, replace_placeholders_and_build_targets from nemo.collections.speechlm2.parts.automodel_lora import ensure_lora_trainable, make_peft_config, maybe_install_lora from nemo.collections.speechlm2.parts.encoder_chunking import encode_audio_with_optional_chunking +from nemo.collections.speechlm2.parts.fp8 import ( + make_fp8_config, + maybe_apply_te_patches, + maybe_pad_bshd_inputs_for_te_fp8, + maybe_precompute_float8_dynamic_scale_for_fsdp, + te_fp8_context, + trim_fp8_padded_logits, + validate_fp8_config, +) from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen from nemo.collections.speechlm2.parts.pretrained import ( @@ -166,19 +175,32 @@ def forward( # (the THD shape mirrors Automodel's _shard_thd_chunk_for_te output — # the model squeezes 3D inputs internally when qkv_format=="thd", so # passing 2D directly skips that hop) - out = self.llm( - inputs_embeds=input_embeds, - attention_mask=attention_mask, - past_key_values=cache, - use_cache=cache is not None, - return_dict=True, - **llm_kwargs, - ) + automodel_backend_config = self.cfg.get("automodel_backend", None) + te_fp8_config = (automodel_backend_config or {}).get("te_fp8", None) + original_seq_len = input_embeds.shape[1] if input_embeds.dim() == 3 else input_embeds.shape[0] + if cache is None and llm_kwargs.get("qkv_format", None) != "thd": + tp_size = self.device_mesh["tp"].size() if self._use_tp else 1 + input_embeds, attention_mask, llm_kwargs, original_seq_len = maybe_pad_bshd_inputs_for_te_fp8( + te_fp8_config, + input_embeds, + attention_mask, + llm_kwargs, + tp_size=tp_size, + ) + with te_fp8_context(automodel_backend_config): + out = self.llm( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + past_key_values=cache, + use_cache=cache is not None, + return_dict=True, + **llm_kwargs, + ) if not isinstance(out, dict): # NeMo Automodel doesn't respect return_dict=True yet - ans = {"logits": out} + ans = {"logits": trim_fp8_padded_logits(out, original_seq_len)} else: - ans = {"logits": out['logits']} # (B, T, text_vocab_size) + ans = {"logits": trim_fp8_padded_logits(out['logits'], original_seq_len)} # (B, T, text_vocab_size) if cache is not None: ans["cache"] = out["past_key_values"] return ans @@ -218,6 +240,7 @@ def prepare_inputs(self, batch: dict): if self.cfg.get("packed_sequences", False): from nemo.collections.speechlm2.parts.packed_sequences import prepare_packed_llm_inputs + automodel_backend_config = self.cfg.get("automodel_backend", None) return prepare_packed_llm_inputs( input_ids=batch["input_ids"], text_embs=text_embs, @@ -226,6 +249,7 @@ def prepare_inputs(self, batch: dict): padding_id=self.text_pad_id, placeholder_id=self.audio_locator_tag_id, device_mesh=getattr(self, "_device_mesh", None), + te_fp8_config=(automodel_backend_config or {}).get("te_fp8", None), ) input_embs, target_ids, attention_mask = replace_placeholders_and_build_targets( @@ -443,6 +467,14 @@ def backward(self, *args, **kwargs): with loss_parallel(): super().backward(*args, **kwargs) + def on_before_zero_grad(self, optimizer) -> None: + maybe_precompute_float8_dynamic_scale_for_fsdp( + self.cfg, + self.llm, + getattr(self, "_device_mesh", None), + self._use_fsdp, + ) + def _setup_moe_fsdp_sync(self): """Configure MoE FSDP gradient sync for gradient accumulation. @@ -751,6 +783,10 @@ def configure_model( activation_checkpointing_llm: bool | None = None, activation_checkpointing_perception: bool | None = None, ) -> None: + validate_fp8_config(self.cfg) + automodel_backend_config = self.cfg.get("automodel_backend", None) + maybe_apply_te_patches(automodel_backend_config) + # Use provided device_mesh, or fall back to LightningModule property if device_mesh is not None: self._device_mesh = device_mesh @@ -828,9 +864,13 @@ def configure_model( compile_dict = dict(compile_cfg) automodel_kwargs["compile_config"] = CompileConfig(**compile_dict) + fp8_config = make_fp8_config(self.cfg) + if fp8_config is not None: + automodel_kwargs["fp8_config"] = fp8_config + # Pass backend through to automodel — lets YAML pick attn/linear/rms_norm/MoE # dispatcher backends (e.g. set attn=sdpa to bypass TransformerEngine). - backend_cfg = self.cfg.get("automodel_backend", None) + backend_cfg = automodel_backend_config if backend_cfg is not None: from nemo_automodel.components.models.common import BackendConfig diff --git a/nemo/collections/speechlm2/parts/fp8.py b/nemo/collections/speechlm2/parts/fp8.py new file mode 100644 index 000000000000..cadd1f18229a --- /dev/null +++ b/nemo/collections/speechlm2/parts/fp8.py @@ -0,0 +1,285 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Mapping +from contextlib import nullcontext +from math import gcd, lcm +from typing import Any + +import torch +from omegaconf import DictConfig, ListConfig, OmegaConf + + +def get_config_value(cfg: Any, path: str, default: Any = None) -> Any: + """Read a dotted config path from ``dict``/OmegaConf/object configs.""" + if cfg is None: + return default + if isinstance(cfg, (DictConfig, ListConfig)): + return OmegaConf.select(cfg, path, default=default) + + current = cfg + for key in path.split("."): + if current is None: + return default + if isinstance(current, Mapping): + if key not in current: + return default + current = current[key] + else: + current = getattr(current, key, default) + if current is default: + return default + return current + + +def as_plain_container(cfg: Any) -> Any: + """Convert OmegaConf containers to plain Python containers.""" + if isinstance(cfg, (DictConfig, ListConfig)): + return OmegaConf.to_container(cfg, resolve=True) + return cfg + + +def has_torchao_fp8(cfg: Any) -> bool: + """Return whether TorchAO FP8 training is enabled in the model config.""" + return bool(get_config_value(cfg, "fp8.enabled", False)) + + +def has_te_fp8(automodel_backend_config: Any) -> bool: + """Return whether Transformer Engine FP8 is configured in an Automodel backend config.""" + return get_config_value(automodel_backend_config, "te_fp8", None) is not None + + +def is_te_fp8_enabled(te_fp8_config: Any) -> bool: + """Return whether a direct ``automodel_backend.te_fp8`` config is present.""" + return te_fp8_config is not None + + +def validate_fp8_config(cfg: Any) -> None: + """Validate model FP8 config combinations.""" + if has_torchao_fp8(cfg) and has_te_fp8(get_config_value(cfg, "automodel_backend", None)): + raise ValueError( + "only one FP8 mode may be configured at a time. Configure either " + "fp8 for TorchAO FP8 or automodel_backend.te_fp8 for " + "Transformer Engine FP8, but not both." + ) + + +def maybe_apply_te_patches(automodel_backend_config: Any) -> None: + """Apply Automodel's Transformer Engine runtime patches when TE FP8 is configured.""" + if not has_te_fp8(automodel_backend_config): + return + + from nemo_automodel.shared.te_patches import apply_te_patches + + apply_te_patches() + + +def make_fp8_config(cfg: Any) -> Any: + """Build Automodel's TorchAO FP8Config from SALMAutomodel config, or return ``None``.""" + if not has_torchao_fp8(cfg): + return None + + from nemo_automodel.components.quantization.fp8 import build_fp8_config + + return build_fp8_config(as_plain_container(get_config_value(cfg, "fp8", None))) + + +def te_fp8_context(automodel_backend_config: Any): + """Return a Transformer Engine FP8 autocast context for an Automodel backend config.""" + te_fp8_config = get_config_value(automodel_backend_config, "te_fp8", None) + if te_fp8_config is None: + return nullcontext() + + te_fp8_config = as_plain_container(te_fp8_config) + if hasattr(te_fp8_config, "maybe_te_autocast"): + return te_fp8_config.maybe_te_autocast() + + from nemo_automodel.components.models.common.utils import TEFp8Config + + if isinstance(te_fp8_config, Mapping): + te_fp8_kwargs = dict(te_fp8_config) + te_fp8_kwargs.pop("_target_", None) + return TEFp8Config(**te_fp8_kwargs).maybe_te_autocast() + if isinstance(te_fp8_config, str): + return TEFp8Config(recipe=te_fp8_config).maybe_te_autocast() + + raise TypeError( + "automodel_backend.te_fp8 must be null, a mapping, a recipe string, " + "or a TEFp8Config-like object with maybe_te_autocast()." + ) + + +def validate_te_fp8_hidden_size(te_fp8_config: Any, hidden_size: int) -> None: + """Validate TE FP8's GEMM alignment requirement for activation hidden size.""" + if is_te_fp8_enabled(te_fp8_config) and hidden_size % 16 != 0: + raise ValueError( + "Transformer Engine FP8 requires input hidden size to be divisible by 16; " + f"got hidden_size={hidden_size}." + ) + + +def get_te_fp8_bshd_sequence_multiple(batch_size: int, tp_size: int = 1) -> int: + """Return the minimal BSHD sequence multiple for local TE FP8 Linear inputs.""" + if batch_size <= 0: + raise ValueError(f"batch_size must be positive; got {batch_size}.") + if tp_size <= 0: + raise ValueError(f"tp_size must be positive; got {tp_size}.") + + fp8_multiple = (8 * tp_size) // gcd(batch_size, 8 * tp_size) + return lcm(tp_size, fp8_multiple) + + +def maybe_pad_bshd_inputs_for_te_fp8( + te_fp8_config: Any, + input_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + llm_kwargs: Mapping[str, Any] | None = None, + *, + tp_size: int = 1, +) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, Any], int]: + """Pad BSHD LLM inputs for TE FP8 and return the original sequence length. + + TE FP8 Linear requires the product of all input dimensions except the last + to be divisible by 8 and the last dimension to be divisible by 16. With + BSHD sequence parallelism, local TE Linear inputs see ``B * T / TP`` rows, + so padding must keep ``T`` divisible by ``TP`` and ``B * T / TP`` divisible + by 8. Padding is appended on the sequence dimension and can be trimmed from + logits after the LLM. + """ + llm_kwargs = dict(llm_kwargs or {}) + if input_embeds.dim() != 3: + return input_embeds, attention_mask, llm_kwargs, input_embeds.shape[0] + original_seq_len = input_embeds.shape[1] + if not is_te_fp8_enabled(te_fp8_config): + return input_embeds, attention_mask, llm_kwargs, original_seq_len + + batch_size, seq_len, hidden_size = input_embeds.shape + validate_te_fp8_hidden_size(te_fp8_config, hidden_size) + + seq_multiple = get_te_fp8_bshd_sequence_multiple(batch_size, tp_size=tp_size) + pad = (-seq_len) % seq_multiple + if pad == 0: + return input_embeds, attention_mask, llm_kwargs, original_seq_len + + pad_embeds = torch.zeros( + batch_size, + pad, + hidden_size, + dtype=input_embeds.dtype, + device=input_embeds.device, + ) + input_embeds = torch.cat([input_embeds, pad_embeds], dim=1) + + if attention_mask is not None: + # These are appended causal dummy tokens, not data-loader padding. + # Mark them valid so their query rows are finite; real tokens cannot + # attend to future dummy tokens through the causal mask. + pad_mask = torch.ones( + batch_size, + pad, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, pad_mask], dim=1) + + for key in ("position_ids", "cache_position"): + value = llm_kwargs.get(key, None) + if isinstance(value, torch.Tensor): + llm_kwargs[key] = pad_sequence_tensor(value, seq_len, pad) + + return input_embeds, attention_mask, llm_kwargs, original_seq_len + + +def pad_sequence_tensor(tensor: torch.Tensor, seq_len: int, pad: int, pad_value: int = 0) -> torch.Tensor: + """Right-pad a tensor when one of its sequence dimensions matches ``seq_len``.""" + if pad <= 0: + return tensor + if tensor.dim() == 1 and tensor.shape[0] == seq_len: + padding = torch.full((pad,), pad_value, dtype=tensor.dtype, device=tensor.device) + return torch.cat([tensor, padding], dim=0) + if tensor.dim() >= 2 and tensor.shape[1] == seq_len: + pad_shape = list(tensor.shape) + pad_shape[1] = pad + padding = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device) + return torch.cat([tensor, padding], dim=1) + return tensor + + +def trim_fp8_padded_logits(logits: torch.Tensor, original_seq_len: int) -> torch.Tensor: + """Trim sequence padding introduced by ``maybe_pad_bshd_inputs_for_te_fp8``.""" + if logits.dim() >= 3 and logits.shape[1] > original_seq_len: + return logits[:, :original_seq_len] + if logits.dim() == 2 and logits.shape[0] > original_seq_len: + return logits[:original_seq_len] + return logits + + +def maybe_pad_thd_padded_lengths_for_te_fp8( + te_fp8_config: Any, + padded_lens: list[int], + *, + cp_size: int = 1, + tp_size: int = 1, +) -> list[int]: + """Pad THD packed sequence lengths so local TE FP8 Linear inputs are aligned. + + This must run before context-parallel THD partitioning because ``cu_seqlens`` + is global metadata and CP partitioning derives local token indices from it. + """ + if not is_te_fp8_enabled(te_fp8_config): + return padded_lens + if not padded_lens: + return padded_lens + + cp_size = max(int(cp_size), 1) + tp_size = max(int(tp_size), 1) + total_multiple = 8 * cp_size * tp_size + total_len = sum(padded_lens) + pad = (-total_len) % total_multiple + if pad == 0: + return padded_lens + + padded_lens = list(padded_lens) + padded_lens[-1] += pad + if cp_size > 1: + cp_multiple = 2 * cp_size + if padded_lens[-1] % cp_multiple != 0: + raise AssertionError( + "Internal error: TE FP8 THD padding did not preserve context-parallel " f"alignment to {cp_multiple}." + ) + return padded_lens + + +def maybe_precompute_float8_dynamic_scale_for_fsdp(cfg: Any, llm: Any, device_mesh: Any, use_fsdp: bool) -> None: + """Run TorchAO's FSDP FP8 scale precompute hook when the config and mesh require it.""" + if not has_torchao_fp8(cfg): + return + if not bool(get_config_value(cfg, "fp8.precompute_float8_dynamic_scale_for_fsdp", False)): + return + if llm is None or not use_fsdp or device_mesh is None: + return + + mesh_dim_names = getattr(device_mesh, "mesh_dim_names", ()) or () + if "dp_shard" not in mesh_dim_names: + return + try: + dp_shard_size = device_mesh["dp_shard"].size() + except (AttributeError, KeyError, RuntimeError, TypeError, ValueError): + return + if dp_shard_size <= 1: + return + + from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp + + precompute_float8_dynamic_scale_for_fsdp(llm) diff --git a/nemo/collections/speechlm2/parts/packed_sequences.py b/nemo/collections/speechlm2/parts/packed_sequences.py index c6ef01564697..9d553cca9fbe 100644 --- a/nemo/collections/speechlm2/parts/packed_sequences.py +++ b/nemo/collections/speechlm2/parts/packed_sequences.py @@ -30,6 +30,7 @@ import torch from torch import Tensor +from nemo.collections.speechlm2.parts.fp8 import maybe_pad_thd_padded_lengths_for_te_fp8, validate_te_fp8_hidden_size from nemo.collections.speechlm2.parts.input_utils import _unpad_inputs @@ -43,6 +44,7 @@ def pack_audio_into_text_embeds( cp_size: int = 1, tp_size: int = 1, ignore_index: int = -100, + te_fp8_config: Any = None, ) -> dict[str, Tensor]: """Splice audio frames into per-utterance text embeddings and pack into THD. @@ -70,6 +72,11 @@ def pack_audio_into_text_embeds( ``2 * cp_size`` per-utterance alignment. ignore_index: label fill for audio-frame slots, padding slots, and the last position of every utterance. + te_fp8_config: Optional ``automodel_backend.te_fp8`` config. When + Transformer Engine FP8 is enabled, the final THD + lengths are additionally padded before CP sharding so + every rank's local TE Linear input satisfies FP8 GEMM + alignment. Returns a dict with: @@ -88,6 +95,7 @@ def pack_audio_into_text_embeds( H = embeds.shape[-1] device = embeds.device dtype = embeds.dtype + validate_te_fp8_hidden_size(te_fp8_config, H) # Strip left-padding so per-utt sequences are tight before splicing. ids_unpad, embs_unpad, tgts_unpad = _unpad_inputs(input_ids, embeds, target_ids, padding_id) @@ -163,6 +171,13 @@ def pack_audio_into_text_embeds( if rem != 0: padded_lens[-1] += tp_size - rem + padded_lens = maybe_pad_thd_padded_lengths_for_te_fp8( + te_fp8_config, + padded_lens, + cp_size=cp_size, + tp_size=tp_size, + ) + # Materialize the flat THD batch. flat_emb_segs: list[Tensor] = [] flat_lab_segs: list[Tensor] = [] @@ -243,6 +258,7 @@ def prepare_packed_llm_inputs( padding_id: int, placeholder_id: int, device_mesh: Optional[Any] = None, + te_fp8_config: Any = None, ) -> dict[str, Any]: """Pack a SALM minibatch and (optionally) shard it across CP ranks. @@ -282,6 +298,7 @@ def prepare_packed_llm_inputs( placeholder_id=placeholder_id, cp_size=cp_size, tp_size=tp_size, + te_fp8_config=te_fp8_config, ) if cp_mesh is not None: diff --git a/pyproject.toml b/pyproject.toml index fe88e5b79d77..897c60f677aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -465,7 +465,7 @@ no-build-isolation-package = [ # indexes when resolving the speechlm2 extra (which pulls nemo_automodel # from git as a source dependency — uv treats these as workspace members). [tool.uv.sources] -nemo_automodel = { git = "https://github.com/NVIDIA-NeMo/Automodel.git", rev = "9eccbb6102a260efd7cbdffa890fc57b94f94528" } +nemo_automodel = { git = "https://github.com/NVIDIA-NeMo/Automodel.git", rev = "2f7c7e5ad39601c8cd20cd6b747950ebe8355b12" } torch = [ { index = "pytorch-cpu", marker = "sys_platform != 'linux' and sys_platform != 'darwin'" }, { index = "pytorch-cu132", marker = "sys_platform == 'linux'" }, diff --git a/tests/collections/speechlm2/test_fp8.py b/tests/collections/speechlm2/test_fp8.py new file mode 100644 index 000000000000..02edc73209ca --- /dev/null +++ b/tests/collections/speechlm2/test_fp8.py @@ -0,0 +1,308 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import types +from contextlib import contextmanager +from importlib import import_module + +import pytest +import torch +from omegaconf import DictConfig + +from nemo.collections.speechlm2.parts import fp8 + + +def install_fake_module(monkeypatch, name, module): + parts = name.split(".") + for idx in range(1, len(parts) + 1): + full_name = ".".join(parts[:idx]) + if idx == len(parts): + current_module = module + else: + current_module = sys.modules.get(full_name) + if current_module is None: + current_module = types.ModuleType(full_name) + current_module.__path__ = [] + monkeypatch.setitem(sys.modules, full_name, current_module) + + for idx in range(1, len(parts)): + parent_name = ".".join(parts[:idx]) + child_name = parts[idx] + child_full_name = ".".join(parts[: idx + 1]) + monkeypatch.setattr(sys.modules[parent_name], child_name, sys.modules[child_full_name], raising=False) + + +def test_install_fake_module_preserves_real_package_paths(): + nemo_automodel = pytest.importorskip("nemo_automodel") + original_path = list(nemo_automodel.__path__) + + with pytest.MonkeyPatch.context() as monkeypatch: + fake_te_patches = types.ModuleType("nemo_automodel.shared.te_patches") + fake_te_patches.apply_te_patches = lambda: None + install_fake_module(monkeypatch, "nemo_automodel.shared.te_patches", fake_te_patches) + + assert list(nemo_automodel.__path__) == original_path + import_module("nemo_automodel.components.distributed.config") + + assert list(nemo_automodel.__path__) == original_path + import_module("nemo_automodel.components.distributed.config") + + +@contextmanager +def recording_context(events, name): + events.append(f"{name}:enter") + try: + yield + finally: + events.append(f"{name}:exit") + + +def test_fp8_config_detection_and_validation(): + assert fp8.has_torchao_fp8(DictConfig({"fp8": {"enabled": True}})) + assert not fp8.has_torchao_fp8(DictConfig({"fp8": {"enabled": False}})) + assert fp8.has_te_fp8(DictConfig({"te_fp8": {"recipe": "block"}})) + assert not fp8.has_te_fp8(DictConfig({"te_fp8": None})) + + with pytest.raises(ValueError, match="only one FP8 mode"): + fp8.validate_fp8_config( + DictConfig({"fp8": {"enabled": True}, "automodel_backend": {"te_fp8": {"recipe": "block"}}}) + ) + + +def test_maybe_apply_te_patches_only_when_te_fp8_is_configured(monkeypatch): + calls = [] + te_patches_module = types.ModuleType("nemo_automodel.shared.te_patches") + te_patches_module.apply_te_patches = lambda: calls.append("patched") + install_fake_module(monkeypatch, "nemo_automodel.shared.te_patches", te_patches_module) + + fp8.maybe_apply_te_patches(DictConfig({})) + assert calls == [] + + fp8.maybe_apply_te_patches(DictConfig({"te_fp8": {"recipe": "block"}})) + assert calls == ["patched"] + + +def test_make_fp8_config_builds_automodel_fp8_config(monkeypatch): + sentinel = object() + seen = {} + + fp8_module = types.ModuleType("nemo_automodel.components.quantization.fp8") + + def fake_build_fp8_config(cfg): + seen["cfg"] = cfg + return sentinel + + fp8_module.build_fp8_config = fake_build_fp8_config + install_fake_module(monkeypatch, "nemo_automodel.components.quantization.fp8", fp8_module) + + cfg = DictConfig( + { + "fp8": { + "enabled": True, + "recipe_name": "tensorwise", + "filter_fqns": ["lm_head"], + } + } + ) + + assert fp8.make_fp8_config(cfg) is sentinel + assert seen["cfg"] == { + "enabled": True, + "recipe_name": "tensorwise", + "filter_fqns": ["lm_head"], + } + assert fp8.make_fp8_config(DictConfig({})) is None + + +def test_te_fp8_context_builds_te_config_and_strips_target(monkeypatch): + events = [] + seen = [] + + common_utils_module = types.ModuleType("nemo_automodel.components.models.common.utils") + + class FakeTEFp8Config: + def __init__(self, **kwargs): + seen.append(kwargs) + + def maybe_te_autocast(self): + return recording_context(events, "te_fp8") + + common_utils_module.TEFp8Config = FakeTEFp8Config + install_fake_module(monkeypatch, "nemo_automodel.components.models.common.utils", common_utils_module) + + automodel_backend_config = DictConfig( + { + "te_fp8": { + "_target_": "nemo_automodel.components.models.common.utils.TEFp8Config", + "recipe": "block", + } + } + ) + + with fp8.te_fp8_context(automodel_backend_config): + events.append("body") + + assert seen == [{"recipe": "block"}] + assert events == ["te_fp8:enter", "body", "te_fp8:exit"] + + with fp8.te_fp8_context(DictConfig({})): + events.append("no_te") + assert events[-1] == "no_te" + + +def test_maybe_pad_bshd_inputs_for_te_fp8_noops_without_te_fp8(): + input_embeds = torch.ones(1, 5, 16) + attention_mask = torch.ones(1, 5, dtype=torch.bool) + + padded, padded_mask, llm_kwargs, original_seq_len = fp8.maybe_pad_bshd_inputs_for_te_fp8( + None, + input_embeds, + attention_mask, + ) + + assert padded is input_embeds + assert padded_mask is attention_mask + assert llm_kwargs == {} + assert original_seq_len == 5 + + +@pytest.mark.parametrize( + ("batch_size", "tp_size", "expected_multiple"), + [ + (1, 1, 8), + (2, 1, 4), + (16, 4, 4), + (1, 4, 32), + (2, 4, 16), + (8, 4, 4), + ], +) +def test_get_te_fp8_bshd_sequence_multiple_accounts_for_tp(batch_size, tp_size, expected_multiple): + multiple = fp8.get_te_fp8_bshd_sequence_multiple(batch_size, tp_size=tp_size) + + assert multiple == expected_multiple + assert multiple % tp_size == 0 + assert (batch_size * multiple // tp_size) % 8 == 0 + + +def test_maybe_pad_bshd_inputs_for_te_fp8_pads_sequence_tensors(): + input_embeds = torch.ones(2, 5, 16) + attention_mask = torch.ones(2, 5, dtype=torch.bool) + position_ids = torch.arange(5).expand(2, -1) + + padded, padded_mask, llm_kwargs, original_seq_len = fp8.maybe_pad_bshd_inputs_for_te_fp8( + DictConfig({"recipe": "block"}), + input_embeds, + attention_mask, + {"position_ids": position_ids}, + ) + + assert original_seq_len == 5 + assert padded.shape == (2, 8, 16) + assert padded_mask.shape == (2, 8) + assert llm_kwargs["position_ids"].shape == (2, 8) + assert torch.equal(padded[:, :5], input_embeds) + assert (padded[:, 5:] == 0).all() + assert padded_mask.all() + assert (llm_kwargs["position_ids"][:, 5:] == 0).all() + + +def test_maybe_pad_bshd_inputs_for_te_fp8_accounts_for_tp(): + input_embeds = torch.ones(16, 5, 16) + attention_mask = torch.ones(16, 5, dtype=torch.bool) + + padded, padded_mask, llm_kwargs, original_seq_len = fp8.maybe_pad_bshd_inputs_for_te_fp8( + DictConfig({"recipe": "block"}), + input_embeds, + attention_mask, + tp_size=4, + ) + + assert original_seq_len == 5 + assert padded.shape == (16, 8, 16) + assert padded.shape[1] % 4 == 0 + assert (padded.shape[0] * padded.shape[1] // 4) % 8 == 0 + assert padded_mask.shape == (16, 8) + assert llm_kwargs == {} + + +def test_te_fp8_hidden_size_validation(): + te_fp8_config = DictConfig({"recipe": "block"}) + + with pytest.raises(ValueError, match="hidden size"): + fp8.maybe_pad_bshd_inputs_for_te_fp8(te_fp8_config, torch.ones(1, 5, 15), None) + + +def test_maybe_pad_thd_padded_lengths_for_te_fp8_preserves_cp_alignment(): + te_fp8_config = DictConfig({"recipe": "block"}) + + padded_lens = fp8.maybe_pad_thd_padded_lengths_for_te_fp8(te_fp8_config, [8, 4], cp_size=2, tp_size=1) + + assert padded_lens == [8, 8] + assert sum(padded_lens) % (8 * 2) == 0 + assert all(length % 4 == 0 for length in padded_lens) + assert (sum(padded_lens) // 2) % 8 == 0 + + +def test_maybe_pad_thd_padded_lengths_for_te_fp8_accounts_for_cp_and_tp(): + te_fp8_config = DictConfig({"recipe": "block"}) + + padded_lens = fp8.maybe_pad_thd_padded_lengths_for_te_fp8(te_fp8_config, [8, 4], cp_size=2, tp_size=3) + + assert padded_lens == [8, 40] + assert sum(padded_lens) % (8 * 2 * 3) == 0 + assert all(length % 4 == 0 for length in padded_lens) + + +def test_maybe_precompute_float8_dynamic_scale_for_fsdp_guards(monkeypatch): + calls = [] + torchao_float8_module = types.ModuleType("torchao.float8") + torchao_float8_module.precompute_float8_dynamic_scale_for_fsdp = lambda llm: calls.append(llm) + install_fake_module(monkeypatch, "torchao.float8", torchao_float8_module) + + class MeshDim: + def __init__(self, size): + self.value = size + + def size(self): + return self.value + + class DeviceMesh: + mesh_dim_names = ("dp_shard",) + + def __init__(self, dp_shard_size): + self.dp_shard_size = dp_shard_size + + def __getitem__(self, name): + assert name == "dp_shard" + return MeshDim(self.dp_shard_size) + + cfg = DictConfig( + { + "fp8": { + "enabled": True, + "precompute_float8_dynamic_scale_for_fsdp": True, + } + } + ) + llm = object() + + fp8.maybe_precompute_float8_dynamic_scale_for_fsdp(DictConfig({}), llm, DeviceMesh(2), True) + fp8.maybe_precompute_float8_dynamic_scale_for_fsdp(cfg, llm, DeviceMesh(2), False) + fp8.maybe_precompute_float8_dynamic_scale_for_fsdp(cfg, llm, DeviceMesh(1), True) + assert calls == [] + + fp8.maybe_precompute_float8_dynamic_scale_for_fsdp(cfg, llm, DeviceMesh(2), True) + assert calls == [llm] diff --git a/tests/collections/speechlm2/test_salm_automodel.py b/tests/collections/speechlm2/test_salm_automodel.py index c7fab4a2a059..952b20c82af3 100644 --- a/tests/collections/speechlm2/test_salm_automodel.py +++ b/tests/collections/speechlm2/test_salm_automodel.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from contextlib import contextmanager import pytest import torch from lhotse import CutSet, SupervisionSegment from lhotse.testing.dummies import dummy_cut, dummy_recording +from omegaconf import DictConfig from transformers import GenerationConfig from nemo.collections.common.data.lhotse import NeMoMultimodalConversation @@ -37,6 +39,21 @@ torch.set_default_device('cuda') +def is_te_fp8_runtime_available(): + if not torch.cuda.is_available(): + return False + try: + import nemo_automodel # noqa: F401 + import transformer_engine.pytorch # noqa: F401 + from transformer_engine.common.recipe import Float8BlockScaling # noqa: F401 + from transformer_engine.pytorch.quantization import autocast # noqa: F401 + except Exception: + return False + + capability = torch.cuda.get_device_capability() + return capability[0] >= 9 + + def resolve_pretrained_models(): if os.path.exists("/home/TestData/speechlm/pretrained_models"): # CI pre-cached paths: @@ -57,17 +74,46 @@ def resolve_pretrained_models(): @pytest.fixture(scope="session") -def model(): +def tiny_nemotronh_model_dir(tmp_path_factory): + from transformers import NemotronHConfig + + model_dir = tmp_path_factory.mktemp("tiny_nemotronh") + config = NemotronHConfig( + vocab_size=200000, + hidden_size=256, + num_hidden_layers=2, + layers_block_type=["attention", "attention"], + num_attention_heads=8, + num_key_value_heads=2, + head_dim=32, + max_position_embeddings=2048, + intermediate_size=512, + mlp_hidden_act="relu2", + mlp_bias=False, + layer_norm_epsilon=1e-5, + tie_word_embeddings=False, + use_mamba_kernels=False, + ) + config.architectures = ["NemotronHForCausalLM"] + config.save_pretrained(model_dir) + return str(model_dir) + + +@pytest.fixture(scope="session") +def model(tiny_nemotronh_model_dir): if not torch.cuda.is_available(): pytest.skip("SALMAutomodel requires CUDA") + pretrained_models = resolve_pretrained_models() cfg = { - **resolve_pretrained_models(), + **pretrained_models, + "pretrained_llm": tiny_nemotronh_model_dir, + "tokenizer_path": pretrained_models["pretrained_llm"], "pretrained_weights": False, "prompt_format": PROMPT, "audio_locator_tag": AUDIO_LOCATOR_TAG, "perception": { "target": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule", - "output_dim": 2048, + "output_dim": 256, "encoder": { "_target_": "nemo.collections.asr.modules.ConformerEncoder", "att_context_size": [-1, -1], @@ -114,6 +160,13 @@ def model(): "optimizer": {"_target_": "torch.optim.AdamW"}, "torch_dtype": "bfloat16", } + if is_te_fp8_runtime_available(): + cfg["automodel_backend"] = { + "attn": "sdpa", + "linear": "te", + "rms_norm": "torch_fp32", + "rope_fusion": False, + } model = SALMAutomodel(cfg) model.configure_model() model.to("cuda") @@ -218,6 +271,148 @@ def fake_log(name, value, **kwargs): assert not model._partial_val_num_frames +def skip_unless_te_fp8_runtime_available(): + if not torch.cuda.is_available(): + pytest.skip("TE FP8 SALMAutomodel runtime tests require CUDA") + try: + import nemo_automodel # noqa: F401 + import transformer_engine.pytorch # noqa: F401 + from transformer_engine.common.recipe import Float8BlockScaling # noqa: F401 + from transformer_engine.pytorch.quantization import autocast # noqa: F401 + except Exception as exc: + pytest.skip(f"Automodel/Transformer Engine FP8 runtime is unavailable: {exc}") + + capability = torch.cuda.get_device_capability() + if capability[0] < 9: + pytest.skip(f"TE FP8 runtime tests require Hopper or newer GPUs; got sm_{capability[0]}{capability[1]}") + + +def skip_unless_model_uses_te_modules(model): + has_te_modules = any(type(module).__module__.startswith("transformer_engine") for module in model.llm.modules()) + if not has_te_modules: + pytest.skip("The SALMAutomodel fixture LLM was not constructed with Transformer Engine modules") + + +@contextmanager +def capture_te_fp8_linear_forwards(model): + from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + from transformer_engine.pytorch.module.linear import Linear as TELinear + + fp8_enabled_during_forward = [] + + def record_fp8_state(module, args): + fp8_enabled_during_forward.append(FP8GlobalStateManager.is_fp8_enabled()) + + handles = [ + module.register_forward_pre_hook(record_fp8_state) + for module in model.llm.modules() + if isinstance(module, TELinear) + ] + try: + yield fp8_enabled_during_forward + finally: + for handle in handles: + handle.remove() + + +@contextmanager +def te_fp8_config_enabled(model): + had_backend = "automodel_backend" in model.cfg + previous_backend = model.cfg.get("automodel_backend", None) + backend = DictConfig(previous_backend if previous_backend is not None else {}) + backend.te_fp8 = {"recipe": "block"} + model.cfg.automodel_backend = backend + try: + yield + finally: + if had_backend: + model.cfg.automodel_backend = previous_backend + else: + del model.cfg["automodel_backend"] + + +@requires_cuda +def test_salm_automodel_te_fp8_training_step_backward(request): + skip_unless_te_fp8_runtime_available() + model = request.getfixturevalue("model") + skip_unless_model_uses_te_modules(model) + dataset = request.getfixturevalue("dataset") + prompt_formatter = request.getfixturevalue("prompt_formatter") + training_cutset_batch = request.getfixturevalue("training_cutset_batch") + + was_training = model.training + try: + model.train() + model.zero_grad(set_to_none=True) + training_cutset_batch = training_cutset_batch.map( + lambda c: c.apply_prompt_format(prompt_formatter), apply_fn=None + ) + batch = dataset[training_cutset_batch] + batch = move_data_to_device(batch, device=model.device) + + with te_fp8_config_enabled(model), capture_te_fp8_linear_forwards(model) as fp8_enabled_during_forward: + results = model.training_step(batch, batch_idx=0) + loss = results["loss"] + assert torch.isfinite(loss) + + model.backward(loss) + grads = [param.grad for param in model.parameters() if param.grad is not None] + assert grads + model.zero_grad(set_to_none=True) + assert any(fp8_enabled_during_forward) + finally: + model.train(was_training) + + +@requires_cuda +def test_salm_automodel_te_fp8_validation_step(request): + skip_unless_te_fp8_runtime_available() + model = request.getfixturevalue("model") + skip_unless_model_uses_te_modules(model) + dataset = request.getfixturevalue("dataset") + prompt_formatter = request.getfixturevalue("prompt_formatter") + training_cutset_batch = request.getfixturevalue("training_cutset_batch") + + was_training = model.training + try: + model.eval() + model.on_validation_epoch_start() + training_cutset_batch = training_cutset_batch.map( + lambda c: c.apply_prompt_format(prompt_formatter), apply_fn=None + ) + batch = dataset[training_cutset_batch] + batch = move_data_to_device(batch, device=model.device) + + with te_fp8_config_enabled(model), capture_te_fp8_linear_forwards(model) as fp8_enabled_during_forward: + with torch.no_grad(): + results = model.validation_step({"dummy_val_set": batch}, batch_idx=0) + + assert results is None + assert model._partial_val_loss_sums["dummy_val_set"] + assert torch.isfinite(model._partial_val_loss_sums["dummy_val_set"][0]) + assert any(fp8_enabled_during_forward) + finally: + model.train(was_training) + + +@requires_cuda +def test_salm_automodel_te_fp8_forward_pads_and_trims_unaligned_bshd(request): + skip_unless_te_fp8_runtime_available() + model = request.getfixturevalue("model") + skip_unless_model_uses_te_modules(model) + + hidden_size = model.llm.config.hidden_size + dtype = next(model.llm.parameters()).dtype + input_embeds = torch.randn(1, 31, hidden_size, device=model.device, dtype=dtype) + attention_mask = torch.ones(1, 31, device=model.device, dtype=torch.bool) + + with te_fp8_config_enabled(model), capture_te_fp8_linear_forwards(model) as fp8_enabled_during_forward: + outputs = model(input_embeds, attention_mask=attention_mask) + + assert outputs["logits"].shape[:2] == (1, 31) + assert any(fp8_enabled_during_forward) + + @requires_cuda def test_salm_automodel_generation(model): answer = model.generate( diff --git a/tests/collections/speechlm2/test_salm_packed_sequences.py b/tests/collections/speechlm2/test_salm_packed_sequences.py index 4a5743f9213e..8a05406ca604 100644 --- a/tests/collections/speechlm2/test_salm_packed_sequences.py +++ b/tests/collections/speechlm2/test_salm_packed_sequences.py @@ -16,12 +16,24 @@ import pytest import torch +from omegaconf import DictConfig from nemo.collections.speechlm2.parts.packed_sequences import pack_audio_into_text_embeds, prepare_packed_llm_inputs PAD = 0 AUDIO = 100 REPO_ROOT = Path(__file__).parents[3] +TE_FP8_CONFIG = DictConfig({"recipe": "block"}) + + +@pytest.fixture(autouse=True) +def use_cpu_default_device_for_unit_tests(): + previous_device = torch.get_default_device() + torch.set_default_device("cpu") + try: + yield + finally: + torch.set_default_device(previous_device) def test_packed_sequences_does_not_import_speechlm2_models_globally(): @@ -318,6 +330,103 @@ def test_cu_seqlens_matches_padded_cumsum(): assert out["max_seqlen"].item() == max(out["seq_lens_padded"].squeeze(-1).tolist()) +def test_te_fp8_thd_padding_updates_metadata_and_labels(): + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + loss_mask = torch.ones_like(input_ids, dtype=torch.bool) + embeds = torch.ones(1, 5, 16) + target_ids = input_ids.where(loss_mask, -100) + + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=[], + padding_id=PAD, + placeholder_id=AUDIO, + te_fp8_config=TE_FP8_CONFIG, + ) + + assert out["seq_lens"].squeeze(-1).tolist() == [5] + assert out["seq_lens_padded"].squeeze(-1).tolist() == [8] + assert out["inputs_embeds"].shape == (8, 16) + assert out["labels"].shape == (8,) + assert out["labels"][-3:].tolist() == [-100, -100, -100] + assert out["position_ids"].tolist() == list(range(8)) + assert out["cu_seqlens"].tolist() == [0, 8] + assert out["max_seqlen"].item() == 8 + + +def test_te_fp8_thd_padding_preserves_cp_partition_alignment(): + input_ids = torch.tensor([[1, 2, 3, 4, 5], [PAD, PAD, 6, 7, 8]]) + loss_mask = input_ids != PAD + embeds = torch.ones(2, 5, 16) + target_ids = input_ids.where(loss_mask, -100) + + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=[], + padding_id=PAD, + placeholder_id=AUDIO, + cp_size=2, + te_fp8_config=TE_FP8_CONFIG, + ) + + padded = out["seq_lens_padded"].squeeze(-1).tolist() + assert out["seq_lens"].squeeze(-1).tolist() == [5, 3] + assert padded == [8, 8] + assert all(length % 4 == 0 for length in padded) + assert sum(padded) % (8 * 2) == 0 + assert (sum(padded) // 2) % 8 == 0 + assert out["cu_seqlens"].tolist() == [0, 8, 16] + assert out["max_seqlen"].item() == 8 + + +def test_te_fp8_thd_padding_accounts_for_cp_and_tp(): + input_ids = torch.tensor([[1, 2, 3, 4, 5], [PAD, PAD, 6, 7, 8]]) + loss_mask = input_ids != PAD + embeds = torch.ones(2, 5, 16) + target_ids = input_ids.where(loss_mask, -100) + + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=[], + padding_id=PAD, + placeholder_id=AUDIO, + cp_size=2, + tp_size=3, + te_fp8_config=TE_FP8_CONFIG, + ) + + padded = out["seq_lens_padded"].squeeze(-1).tolist() + assert padded == [8, 40] + assert all(length % 4 == 0 for length in padded) + assert sum(padded) % (8 * 2 * 3) == 0 + assert out["cu_seqlens"].tolist() == [0, 8, 48] + assert out["max_seqlen"].item() == 40 + + +def test_te_fp8_thd_padding_requires_hidden_dim_multiple_of_16(): + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + loss_mask = torch.ones_like(input_ids, dtype=torch.bool) + embeds = torch.ones(1, 5, 15) + target_ids = input_ids.where(loss_mask, -100) + + with pytest.raises(ValueError, match="hidden size"): + pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=[], + padding_id=PAD, + placeholder_id=AUDIO, + te_fp8_config=TE_FP8_CONFIG, + ) + + def test_loss_mask_propagates_to_minus_100(): """Positions where loss_mask=False end up as -100 in the shifted labels.""" input_ids = torch.tensor([[1, 2, 3, 4]]) diff --git a/uv.lock b/uv.lock index c720b1ac8e22..dbb555ec3504 100644 --- a/uv.lock +++ b/uv.lock @@ -3817,8 +3817,8 @@ wheels = [ [[package]] name = "nemo-automodel" -version = "0.4.0+9eccbb61" -source = { git = "https://github.com/NVIDIA-NeMo/Automodel.git?rev=9eccbb6102a260efd7cbdffa890fc57b94f94528#9eccbb6102a260efd7cbdffa890fc57b94f94528" } +version = "0.4.0+2f7c7e5a" +source = { git = "https://github.com/NVIDIA-NeMo/Automodel.git?rev=2f7c7e5ad39601c8cd20cd6b747950ebe8355b12#2f7c7e5ad39601c8cd20cd6b747950ebe8355b12" } dependencies = [ { name = "datasets" }, { name = "flashoptim" }, @@ -4266,9 +4266,9 @@ requires-dist = [ { name = "matplotlib", marker = "extra == 'audio'" }, { name = "matplotlib", marker = "extra == 'speechlm2'" }, { name = "matplotlib", marker = "extra == 'tts'" }, - { name = "nemo-automodel", marker = "extra == 'all'", git = "https://github.com/NVIDIA-NeMo/Automodel.git?rev=9eccbb6102a260efd7cbdffa890fc57b94f94528" }, - { name = "nemo-automodel", marker = "extra == 'speechlm2'", git = "https://github.com/NVIDIA-NeMo/Automodel.git?rev=9eccbb6102a260efd7cbdffa890fc57b94f94528" }, - { name = "nemo-automodel", marker = "extra == 'speechlm2-only'", git = "https://github.com/NVIDIA-NeMo/Automodel.git?rev=9eccbb6102a260efd7cbdffa890fc57b94f94528" }, + { name = "nemo-automodel", marker = "extra == 'all'", git = "https://github.com/NVIDIA-NeMo/Automodel.git?rev=2f7c7e5ad39601c8cd20cd6b747950ebe8355b12" }, + { name = "nemo-automodel", marker = "extra == 'speechlm2'", git = "https://github.com/NVIDIA-NeMo/Automodel.git?rev=2f7c7e5ad39601c8cd20cd6b747950ebe8355b12" }, + { name = "nemo-automodel", marker = "extra == 'speechlm2-only'", git = "https://github.com/NVIDIA-NeMo/Automodel.git?rev=2f7c7e5ad39601c8cd20cd6b747950ebe8355b12" }, { name = "nemo-text-processing", marker = "'aarch' not in platform_machine and 'arm' not in platform_machine and sys_platform != 'darwin' and extra == 'all'" }, { name = "nemo-text-processing", marker = "'aarch' not in platform_machine and 'arm' not in platform_machine and sys_platform != 'darwin' and extra == 'speechlm2'" }, { name = "nemo-text-processing", marker = "'aarch' not in platform_machine and 'arm' not in platform_machine and sys_platform != 'darwin' and extra == 'tts'" },