diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index f10945763ee0..e1a350ac31cd 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -190,17 +190,21 @@ def step(self, closure=None, div_scale: float = -1): ) self._post_update(p, "exp_avg", "exp_avg_sq") elif target_device.type == "cuda": - assert div_scale == -1, "div_scale should remain default" assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda" assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda" bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] + # divide the gradient by div_scale when one is given (-1 is the sentinel for "no scaling") + grad = p.grad.data + if div_scale != -1: + grad = grad / div_scale + # adam on cuda self.torch_adam_update( p.data, - p.grad.data, + grad, state["exp_avg"], state["exp_avg_sq"], group["lr"],