In `FP-Quant/inference_lib/src/fp_quant/module/linear_fns.py`, the `forward_quantize` function receives a `global_scale` argument, but in the MXFP4 branch it is never passed to `fused_quantize_mx_op`, so the global scale is not applied during quantization:
```python
def forward_quantize(
    x: torch.Tensor,
    hadamard_matrix: torch.Tensor,
    global_scale: torch.Tensor,
    dtype: FPQuantDtype,
    forward_method: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if dtype == FPQuantDtype.MXFP4:
        qweight, scales, mask = fused_quantize_mx_op(
            x.to(torch.bfloat16),
            hadamard_matrix.to(torch.bfloat16),
            forward_method,
            forward_method == "quest" and x.requires_grad,
        )
        return qweight, scales, mask
```
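For reference, this is roughly what I would have expected in the MXFP4 branch: the input pre-divided by `global_scale` before quantization so that the GEMM's later multiplication by `alpha` restores the original magnitude. This is just a sketch of my assumption, not a patch, and the direction of the scaling obviously depends on the convention you intend:

```python
# Sketch only (my assumption, not the library's code): fold global_scale into
# the input before MXFP4 quantization so it cancels with alpha in forward_gemm.
if dtype == FPQuantDtype.MXFP4:
    qweight, scales, mask = fused_quantize_mx_op(
        (x / global_scale).to(torch.bfloat16),  # hypothetical pre-scaling
        hadamard_matrix.to(torch.bfloat16),
        forward_method,
        forward_method == "quest" and x.requires_grad,
    )
    return qweight, scales, mask
```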
Later, in the `forward_gemm` function, the same `global_scale` is passed as the `alpha` scaling factor for the GEMM operation:
```python
def forward_gemm(x_q, w_q, x_scales, w_scales, alpha, dtype: FPQuantDtype):
    if dtype == FPQuantDtype.MXFP4:
        if False and x_q.shape[0] <= 64:  # TODO: remove when ada alpha is fixed
            return matmul_ada_mxf4_bf16_tn_op(
                x_q, w_q, x_scales, w_scales, alpha.float()
            )
        else:
            return matmul_mxf4_bf16_tn_op(x_q, w_q, x_scales, w_scales, alpha.float())
```
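To illustrate why this looks inconsistent to me, here is a tiny plain-PyTorch toy (no real MXFP4 quantization, just the scaling algebra): if the input is not pre-divided by `global_scale` but the GEMM still multiplies by `alpha = global_scale`, the output ends up scaled by `global_scale` relative to the unquantized reference.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
w = torch.randn(16, 8)
global_scale = torch.tensor(0.5)

y_ref = x @ w.T                                         # unquantized reference
y_alpha_only = global_scale * (x @ w.T)                 # alpha applied, no pre-scaling
y_balanced = global_scale * ((x / global_scale) @ w.T)  # pre-divide + alpha cancel out

print(torch.allclose(y_ref, y_balanced))    # True: the two factors cancel
print(torch.allclose(y_ref, y_alpha_only))  # False: off by a factor of global_scale
```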
Shouldn't `global_scale` also be applied during quantization so that it is consistent with the GEMM's use of `alpha`? Or is this intentional behavior?
Looking forward to your clarification :)