From 969885ebf7d4c9475c878355ab6bef0212664fdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C5=A9=20Ho=C3=A0ng=20Nh=E1=BA=ADt=20Tr=C6=B0=E1=BB=9Dng?= Date: Thu, 2 Apr 2026 06:36:49 +0000 Subject: [PATCH 1/3] Truong Handle div_scale in CPUAdam Optimizer --- colossalai/nn/optimizer/cpu_adam.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index f10945763ee0..e1a350ac31cd 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -190,17 +190,21 @@ def step(self, closure=None, div_scale: float = -1): ) self._post_update(p, "exp_avg", "exp_avg_sq") elif target_device.type == "cuda": - assert div_scale == -1, "div_scale should remain default" assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda" assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda" bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] + # scale gradient if div_scale is provided + grad = p.grad.data + if div_scale != -1: + grad = grad / div_scale + # adam on cuda self.torch_adam_update( p.data, - p.grad.data, + grad, state["exp_avg"], state["exp_avg_sq"], group["lr"], From c033826c7f600636cbec1dbd4b443ca6696387a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C5=A9=20Ho=C3=A0ng=20Nh=E1=BA=ADt=20Tr=C6=B0=E1=BB=9Dng?= Date: Wed, 15 Apr 2026 06:37:26 +0000 Subject: [PATCH 2/3] fix: Handle TorchDynamo incompatible with torch._scaled_mm --- colossalai/quantization/fp8.py | 45 +++++++++++++++------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index e23da5cccd4d..17834a120528 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -797,37 +797,32 @@ def forward( ctx.w_fp8_t = w_fp8.t() ctx.inv_scale_x = inv_scale_x ctx.inv_scale_w = inv_scale_w - out = torch._scaled_mm( - x_fp8, - ctx.w_fp8_t, - bias=bias, - out_dtype=ctx.out_dtype, - scale_a=inv_scale_x, - scale_b=inv_scale_w, - use_fast_accum=True, - )[0] + + # Dequantize and compute matrix multiplication (compatible with TorchDynamo) + x_deq = x_fp8.to(ctx.out_dtype) * inv_scale_x + w_t_deq = ctx.w_fp8_t.to(ctx.out_dtype) * inv_scale_w + + out = x_deq @ w_t_deq + if bias is not None: + out = out + bias.to(ctx.out_dtype) + + out = out.to(ctx.out_dtype) return out.reshape(*ctx.x_shape[:-1], w.shape[0]) @staticmethod def backward(ctx: Any, out_grad) -> Any: out_grad = out_grad.reshape(-1, out_grad.shape[-1]) out_grad_fp8, out_grad_scale = cast_to_fp8(out_grad, fp8_format="e5m2") - x_grad = torch._scaled_mm( - out_grad_fp8, - ctx.w_fp8_t.contiguous().t(), - out_dtype=ctx.out_dtype, - scale_a=out_grad_scale, - scale_b=ctx.inv_scale_w, - use_fast_accum=True, - )[0] - w_grad = torch._scaled_mm( - out_grad_fp8.t().contiguous(), - ctx.x_fp8.t().contiguous().t(), - out_dtype=ctx.out_dtype, - scale_a=out_grad_scale, - scale_b=ctx.inv_scale_x, - use_fast_accum=True, - )[0] + + # Dequantize (force contiguous after cast) + out_grad_deq = (out_grad_fp8.to(ctx.out_dtype) * out_grad_scale).contiguous() + w_t_deq = (ctx.w_fp8_t.to(ctx.out_dtype) * ctx.inv_scale_w).contiguous() + x_deq = (ctx.x_fp8.to(ctx.out_dtype) * ctx.inv_scale_x).contiguous() + + # Compute gradients + x_grad = out_grad_deq @ w_t_deq.t() + w_grad = out_grad_deq.t() @ x_deq + bias_grad = None if ctx.has_bias: bias_grad = out_grad.sum(0) From 506f800c312d0528a3e5c47075b086507856481e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 06:38:32 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 17834a120528..0fbc5a850144 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -797,7 +797,7 @@ def forward( ctx.w_fp8_t = w_fp8.t() ctx.inv_scale_x = inv_scale_x ctx.inv_scale_w = inv_scale_w - + # Dequantize and compute matrix multiplication (compatible with TorchDynamo) x_deq = x_fp8.to(ctx.out_dtype) * inv_scale_x w_t_deq = ctx.w_fp8_t.to(ctx.out_dtype) * inv_scale_w