diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index e23da5cccd4d..0fbc5a850144 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -797,37 +797,32 @@ def forward( ctx.w_fp8_t = w_fp8.t() ctx.inv_scale_x = inv_scale_x ctx.inv_scale_w = inv_scale_w - out = torch._scaled_mm( - x_fp8, - ctx.w_fp8_t, - bias=bias, - out_dtype=ctx.out_dtype, - scale_a=inv_scale_x, - scale_b=inv_scale_w, - use_fast_accum=True, - )[0] + + # Dequantize and compute matrix multiplication (compatible with TorchDynamo) + x_deq = x_fp8.to(ctx.out_dtype) * inv_scale_x + w_t_deq = ctx.w_fp8_t.to(ctx.out_dtype) * inv_scale_w + + out = x_deq @ w_t_deq + if bias is not None: + out = out + bias.to(ctx.out_dtype) + + out = out.to(ctx.out_dtype) return out.reshape(*ctx.x_shape[:-1], w.shape[0]) @staticmethod def backward(ctx: Any, out_grad) -> Any: out_grad = out_grad.reshape(-1, out_grad.shape[-1]) out_grad_fp8, out_grad_scale = cast_to_fp8(out_grad, fp8_format="e5m2") - x_grad = torch._scaled_mm( - out_grad_fp8, - ctx.w_fp8_t.contiguous().t(), - out_dtype=ctx.out_dtype, - scale_a=out_grad_scale, - scale_b=ctx.inv_scale_w, - use_fast_accum=True, - )[0] - w_grad = torch._scaled_mm( - out_grad_fp8.t().contiguous(), - ctx.x_fp8.t().contiguous().t(), - out_dtype=ctx.out_dtype, - scale_a=out_grad_scale, - scale_b=ctx.inv_scale_x, - use_fast_accum=True, - )[0] + + # Dequantize (force contiguous after cast) + out_grad_deq = (out_grad_fp8.to(ctx.out_dtype) * out_grad_scale).contiguous() + w_t_deq = (ctx.w_fp8_t.to(ctx.out_dtype) * ctx.inv_scale_w).contiguous() + x_deq = (ctx.x_fp8.to(ctx.out_dtype) * ctx.inv_scale_x).contiguous() + + # Compute gradients + x_grad = out_grad_deq @ w_t_deq.t() + w_grad = out_grad_deq.t() @ x_deq + bias_grad = None if ctx.has_bias: bias_grad = out_grad.sum(0)