From 969885ebf7d4c9475c878355ab6bef0212664fdd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C5=A9=20Ho=C3=A0ng=20Nh=E1=BA=ADt=20Tr=C6=B0=E1=BB=9Dng?=
Date: Thu, 2 Apr 2026 06:36:49 +0000
Subject: [PATCH] Truong Handle div_scale in CPUAdam Optimizer

---
 colossalai/nn/optimizer/cpu_adam.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index f10945763ee0..e1a350ac31cd 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -190,17 +190,21 @@ def step(self, closure=None, div_scale: float = -1):
                     )
                     self._post_update(p, "exp_avg", "exp_avg_sq")
                 elif target_device.type == "cuda":
-                    assert div_scale == -1, "div_scale should remain default"
                     assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda"
                     assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda"
 
                     bias_correction1 = 1 - beta1 ** state["step"]
                     bias_correction2 = 1 - beta2 ** state["step"]
 
+                    # scale gradient if div_scale is provided
+                    grad = p.grad.data
+                    if div_scale != -1:
+                        grad = grad / div_scale
+
                     # adam on cuda
                     self.torch_adam_update(
                         p.data,
-                        p.grad.data,
+                        grad,
                         state["exp_avg"],
                         state["exp_avg_sq"],
                         group["lr"],
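
For context, a minimal, self-contained sketch of what the patched CUDA branch now computes. The function name scaled_adam_update and its flat argument list are illustrative only (the actual code calls self.torch_adam_update inside CPUAdam.step, including an AdamW mode not shown here); the sketch assumes div_scale carries a loss scale that the gradient should be divided by before the Adam update.

import torch

def scaled_adam_update(param, exp_avg, exp_avg_sq, lr, beta1, beta2, eps,
                       weight_decay, step, div_scale=-1.0):
    # Mirror of the patched CUDA branch: optionally unscale the gradient,
    # then apply a plain (non-AdamW) Adam update with bias correction.
    grad = param.grad.data
    if div_scale != -1:
        grad = grad / div_scale  # assumed: div_scale is the loss scale

    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    if weight_decay != 0:
        grad = grad.add(param.data, alpha=weight_decay)

    # Standard Adam moment updates.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
    param.data.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)

With this in place, a caller such as a mixed-precision training wrapper could pass its loss scale as div_scale to step() and have gradient unscaling folded into the optimizer update, which appears to be the motivation for removing the assert div_scale == -1 guard on the CUDA branch.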