Add FP8 support for SALMAutomodel#15754
Conversation
|
/ok to test 614e4b4 |
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
6c20005 to
23761ed
Compare
|
/ok to test 23761ed |
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
|
/ok to test 329f23a |
KunalDhawan
left a comment
There was a problem hiding this comment.
Great work @pzelasko, added minor comments below, other than that LGTM
| """Return the minimal sequence-length multiple so B*T is divisible by 8.""" | ||
| if batch_size <= 0: | ||
| raise ValueError(f"batch_size must be positive; got {batch_size}.") | ||
| return 8 // gcd(batch_size, 8) |
There was a problem hiding this comment.
This only ensures B*T % 8 == 0 and ignores tp_size, which is inconsistent with the THD helper (8 * cp_size * tp_size). Under BSHD + TP + TE-FP8 I think that breaks in two ways:
prepare_inputstruncates the seq dim to a multiple oftp_sizeso sequence parallelism doesn't silently reshape the input (salm_automodel.py ~L269), but thenforwardappendspad = (-T) % seq_multipletokens, so the padded length is no longer guaranteed divisible bytp_size→ SP shape break.- With SP the local TE Linear sees
M = B*T/tp_size, so FP8 actually needsB*T % (8*tp_size) == 0, not just% 8.
Could we either thread tp_size through here (note 8*tp_size alone isn't enough — e.g. B=16, tp=4 → multiple of 2, still not divisible by 4 — so probably needs an explicit lcm(tp_size, ...)), or add a validate_fp8_config rejection for BSHD + TP + TE-FP8 pointing folks at the THD packed path? A BSHD analogue of test_maybe_pad_thd_..._accounts_for_cp_and_tp would lock it down. This combo wasn't in the 2-GPU run (dp=2 ep=2, no TP), so it's currently untested.
| def backward(self, *args, **kwargs): | ||
| self._setup_moe_fsdp_sync() | ||
| with loss_parallel(): | ||
| with loss_parallel(), te_fp8_context(self.cfg.get("automodel_backend", None)): |
There was a problem hiding this comment.
Quick question on wrapping backward in te_fp8_context too: standard TE usage only wraps the forward, and the backward consumes the FP8 metadata captured during the forward's fp8_autocast. Re-entering fp8_autocast here can, for history/delayed-scaling recipes, trigger an extra amax/scale update (and a second amax all-reduce) at context exit. Probably harmless for block/current (which is what the run used), but it's an easy source of subtle scale drift on other recipes. Is it deliberate / needed for something specific? If not, dropping it from backward seems safer.
|
/ok to test 17e203f |
|
Have you gotten any speedups from this? I tried FP8 with Fastconformer and it was slower than BF16 |
|
I'm still debugging a few issues before this is ready, I'll report when I have the numbers. The expected speedup from TransformerEngine's FP8 is about 2x on Hopper, but you have to make sure the training is compute bound to get that speedup. If your matmuls are on very small problem sizes, the overhead of scaling etc can be greater than the speedup of an already very tiny kernel. Two easiest ways to get the speedup are using larger models or larger batch sizes. |
Summary
Testing