Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/apis.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
:orphan:

=========
NeMo APIs
=========
Expand Down Expand Up @@ -43,4 +45,3 @@ Alternatively, you can jump straight to the documentation for the individual col
* :doc:`Audio Processing <../audio/intro>`

* :doc:`SpeechLM2 <../speechlm2/intro>`

2 changes: 2 additions & 0 deletions docs/source/collections.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
:orphan:

================
NeMo Collections
================
Expand Down
38 changes: 38 additions & 0 deletions docs/source/speechlm2/training_and_scaling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,44 @@ For distributed inference, launch with ``torchrun``:
inputs=path/to/manifest \
ep_size=2

FP8 Training (SALMAutomodel)
""""""""""""""""""""""""""""

``SALMAutomodel`` supports two FP8 modes through NeMo Automodel: Transformer
Engine FP8 (``model.automodel_backend.te_fp8``) and TorchAO FP8 (``model.fp8``).
Configure only one mode at a time.

For Hopper MoE backbones, prefer Transformer Engine FP8 because it applies to
TE Linear / GroupedLinear kernels used by the MoE path:

.. code-block:: yaml

    model:
      automodel_backend:
        linear: te
        experts: te
        dispatcher: deepep
        te_fp8:
          recipe: block  # or "current"

TorchAO FP8 is a separate dense-linear path. It should be paired with
``torch.compile`` for speedup and requires Hopper-class GPUs unless
``emulate=true`` is set for testing:

.. code-block:: yaml

    model:
      compile:
        enabled: true
        dynamic: true
      fp8:
        enabled: true
        recipe_name: tensorwise
        enable_fsdp_float8_all_gather: true
        precompute_float8_dynamic_scale_for_fsdp: true
        force_recompute_fp8_weight_in_bwd: true
        filter_fqns: ["lm_head"]
        emulate: false

Packed Sequences (THD)
""""""""""""""""""""""

Expand Down
2 changes: 2 additions & 0 deletions docs/source/tools/intro.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
:orphan:

Speech AI Tools
===============

Expand Down
20 changes: 17 additions & 3 deletions examples/speechlm2/conf/salm_automodel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,39 @@ model:
# backend: null # Compilation backend (null = inductor)
# dynamo_cache_size_limit: 256 # Triton compilation cache limit

# TorchAO FP8 training. This is separate from Transformer Engine FP8 below;
# configure only one FP8 mode at a time. TorchAO FP8 should be paired with
# torch.compile for speedup and requires Hopper-class GPUs unless emulate=true.
# fp8:
# enabled: true
# recipe_name: tensorwise # "tensorwise" | "rowwise" | "rowwise_with_gw_hp"
# enable_fsdp_float8_all_gather: true
# precompute_float8_dynamic_scale_for_fsdp: true
# force_recompute_fp8_weight_in_bwd: true
# filter_fqns: ["lm_head"]
# emulate: false

# Automodel backend dispatch. Selects the kernel/backend for each major module
# in the LLM (attention, linear, rms_norm, MoE experts/dispatcher). Defaults
# come from Automodel's BackendConfig and auto-select TE/DeepEP when available;
# override here to pin a specific backend (e.g. attn=sdpa to bypass TE).
# For Hopper MoE FP8 training, prefer Transformer Engine FP8 by using TE
# linear/expert backends and setting te_fp8. Do not combine with model.fp8.
# automodel_backend:
# attn: te # "te" | "sdpa" | "flex"
# linear: te # "torch" | "te"
# rms_norm: torch_fp32 # "torch" | "torch_fp32" | "te"
# rope_fusion: true # Fused RoPE (requires TE)
# experts: torch_mm # MoE expert GEMM: "torch" | "te" | "gmm" | "torch_mm"
# experts: te # MoE expert GEMM: "torch" | "te" | "gmm" | "torch_mm"
# dispatcher: deepep # MoE token dispatcher: "torch" | "deepep" | "hybridep" | "uccl_ep"
# dispatcher_num_sms: 20 # SM count for DeepEP/UCCL-EP kernels
# fake_balanced_gate: false # Replace learned Gate with balanced fake gate (debug/bench)
# fake_gate_noise: 0.0 # [0, 1] — noise for FakeBalancedGate routing
# enable_hf_state_dict_adapter: true
# enable_fsdp_optimizations: false
# gate_precision: null # e.g. "float32" to force fp32 gate compute
# te_fp8: null # {recipe: "current"} or {recipe: "block"} to enable TE FP8
# # (requires linear=te or experts=te)
# te_fp8:
# recipe: block # "current" or "block"; requires linear=te or experts=te

# Pin the SDPA kernel list used when automodel_backend.attn=sdpa. Accepts
# strings from: "flash_attention", "efficient_attention", "math", "cudnn_attention".
Expand Down
62 changes: 51 additions & 11 deletions nemo/collections/speechlm2/models/salm_automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@
from nemo.collections.speechlm2.models.salm import _resolve_audios_in_prompt, replace_placeholders_and_build_targets
from nemo.collections.speechlm2.parts.automodel_lora import ensure_lora_trainable, make_peft_config, maybe_install_lora
from nemo.collections.speechlm2.parts.encoder_chunking import encode_audio_with_optional_chunking
from nemo.collections.speechlm2.parts.fp8 import (
make_fp8_config,
maybe_apply_te_patches,
maybe_pad_bshd_inputs_for_te_fp8,
maybe_precompute_float8_dynamic_scale_for_fsdp,
te_fp8_context,
trim_fp8_padded_logits,
validate_fp8_config,
)
from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin
from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen
from nemo.collections.speechlm2.parts.pretrained import (
Expand Down Expand Up @@ -166,19 +175,32 @@ def forward(
# (the THD shape mirrors Automodel's _shard_thd_chunk_for_te output —
# the model squeezes 3D inputs internally when qkv_format=="thd", so
# passing 2D directly skips that hop)
out = self.llm(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
past_key_values=cache,
use_cache=cache is not None,
return_dict=True,
**llm_kwargs,
)
automodel_backend_config = self.cfg.get("automodel_backend", None)
te_fp8_config = (automodel_backend_config or {}).get("te_fp8", None)
original_seq_len = input_embeds.shape[1] if input_embeds.dim() == 3 else input_embeds.shape[0]
if cache is None and llm_kwargs.get("qkv_format", None) != "thd":
tp_size = self.device_mesh["tp"].size() if self._use_tp else 1
input_embeds, attention_mask, llm_kwargs, original_seq_len = maybe_pad_bshd_inputs_for_te_fp8(
te_fp8_config,
input_embeds,
attention_mask,
llm_kwargs,
tp_size=tp_size,
)
with te_fp8_context(automodel_backend_config):
out = self.llm(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
past_key_values=cache,
use_cache=cache is not None,
return_dict=True,
**llm_kwargs,
)
if not isinstance(out, dict):
# NeMo Automodel doesn't respect return_dict=True yet
ans = {"logits": out}
ans = {"logits": trim_fp8_padded_logits(out, original_seq_len)}
else:
ans = {"logits": out['logits']} # (B, T, text_vocab_size)
ans = {"logits": trim_fp8_padded_logits(out['logits'], original_seq_len)} # (B, T, text_vocab_size)
if cache is not None:
ans["cache"] = out["past_key_values"]
return ans
Expand Down Expand Up @@ -218,6 +240,7 @@ def prepare_inputs(self, batch: dict):
if self.cfg.get("packed_sequences", False):
from nemo.collections.speechlm2.parts.packed_sequences import prepare_packed_llm_inputs

automodel_backend_config = self.cfg.get("automodel_backend", None)
return prepare_packed_llm_inputs(
input_ids=batch["input_ids"],
text_embs=text_embs,
Expand All @@ -226,6 +249,7 @@ def prepare_inputs(self, batch: dict):
padding_id=self.text_pad_id,
placeholder_id=self.audio_locator_tag_id,
device_mesh=getattr(self, "_device_mesh", None),
te_fp8_config=(automodel_backend_config or {}).get("te_fp8", None),
)

input_embs, target_ids, attention_mask = replace_placeholders_and_build_targets(
Expand Down Expand Up @@ -443,6 +467,14 @@ def backward(self, *args, **kwargs):
with loss_parallel():
super().backward(*args, **kwargs)

def on_before_zero_grad(self, optimizer) -> None:
    """Delegate to the FP8 dynamic-scale precompute helper before gradients are zeroed.

    Passes the model config, the LLM module, the device mesh and the FSDP flag
    to ``maybe_precompute_float8_dynamic_scale_for_fsdp``; whether any work is
    actually done is decided entirely by that helper.

    Args:
        optimizer: Supplied by the caller of this hook; not used here.
    """
    # NOTE(review): presumably this is the Lightning ``on_before_zero_grad``
    # hook, i.e. it runs after the optimizer step — confirm against the base class.
    model_cfg = self.cfg
    llm_module = self.llm
    # ``_device_mesh`` may not have been assigned yet, so fall back to None.
    mesh = getattr(self, "_device_mesh", None)
    fsdp_enabled = self._use_fsdp
    maybe_precompute_float8_dynamic_scale_for_fsdp(model_cfg, llm_module, mesh, fsdp_enabled)

def _setup_moe_fsdp_sync(self):
"""Configure MoE FSDP gradient sync for gradient accumulation.

Expand Down Expand Up @@ -751,6 +783,10 @@ def configure_model(
activation_checkpointing_llm: bool | None = None,
activation_checkpointing_perception: bool | None = None,
) -> None:
validate_fp8_config(self.cfg)
automodel_backend_config = self.cfg.get("automodel_backend", None)
maybe_apply_te_patches(automodel_backend_config)

# Use provided device_mesh, or fall back to LightningModule property
if device_mesh is not None:
self._device_mesh = device_mesh
Expand Down Expand Up @@ -828,9 +864,13 @@ def configure_model(
compile_dict = dict(compile_cfg)
automodel_kwargs["compile_config"] = CompileConfig(**compile_dict)

fp8_config = make_fp8_config(self.cfg)
if fp8_config is not None:
automodel_kwargs["fp8_config"] = fp8_config

# Pass backend through to automodel — lets YAML pick attn/linear/rms_norm/MoE
# dispatcher backends (e.g. set attn=sdpa to bypass TransformerEngine).
backend_cfg = self.cfg.get("automodel_backend", None)
backend_cfg = automodel_backend_config
if backend_cfg is not None:
from nemo_automodel.components.models.common import BackendConfig

Expand Down
Loading
Loading