Global_scale is passed but not applied in MXFP quantization, yet used in GEMM: could this cause numerical inequivalence? #21

@manyizhang

Description

In FP-Quant/inference_lib/src/fp_quant/module/linear_fns.py, the forward_quantize function receives a global_scale argument. However, fused_quantize_mx_op does not actually apply this global_scale during quantization.

def forward_quantize(
    x: torch.Tensor,
    hadamard_matrix: torch.Tensor,
    global_scale: torch.Tensor,
    dtype: FPQuantDtype,
    forward_method: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if dtype == FPQuantDtype.MXFP4:
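        # NOTE: global_scale is received above but is not among the arguments below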
        qweight, scales, mask = fused_quantize_mx_op(
            x.to(torch.bfloat16),
            hadamard_matrix.to(torch.bfloat16),
            forward_method,
            forward_method == "quest" and x.requires_grad,
        )
        return qweight, scales, mask

Later, in the forward_gemm function, the same global_scale is passed as the alpha scaling factor for the GEMM operation.

def forward_gemm(x_q, w_q, x_scales, w_scales, alpha, dtype: FPQuantDtype):
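    # alpha is the same global_scale from forward_quantize; the fused kernel
    # presumably applies it as an output scale, i.e. out = alpha * (x_q @ w_q^T)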
    if dtype == FPQuantDtype.MXFP4:
        if False and x_q.shape[0] <= 64:  # TODO: remove when ada alpha is fixed
            return matmul_ada_mxf4_bf16_tn_op(
                x_q, w_q, x_scales, w_scales, alpha.float()
            )
        else:
            return matmul_mxf4_bf16_tn_op(x_q, w_q, x_scales, w_scales, alpha.float())

Shouldn't the global_scale be applied during quantization to ensure equivalence with the GEMM's alpha usage? Or is this intentional behavior?
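To make the concern concrete, here is a minimal sketch (toy_quantize is a hypothetical stand-in, not the actual fused_quantize_mx_op kernel). Because MX scales are rounded to powers of two, quantizing x and quantizing x / global_scale are not related by a simple rescale unless global_scale is itself a power of two, so the two orderings can produce different outputs:

import torch

def toy_quantize(x: torch.Tensor, levels: int = 8) -> torch.Tensor:
    # Hypothetical stand-in for fused_quantize_mx_op: round-to-nearest onto a
    # uniform grid whose scale is a power of two, mimicking MX (E8M0) scales.
    # Returns the dequantized tensor directly for simplicity.
    scale = 2.0 ** torch.floor(torch.log2(x.abs().max() / levels))
    return (x / scale).round().clamp(-levels, levels) * scale

torch.manual_seed(0)
x = torch.randn(4, 64)
w = torch.randn(16, 64)
global_scale = torch.tensor(0.37)  # deliberately not a power of two

# What the code does now: quantize without global_scale, then scale by alpha.
out_now = global_scale * (toy_quantize(x) @ toy_quantize(w).T)

# What folding global_scale into quantization would give instead.
out_folded = global_scale * (toy_quantize(x / global_scale) @ toy_quantize(w).T)

print((out_now - out_folded).abs().max())  # nonzero: the two orderings differ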
Looking forward to your clarification :)
