From f6d4fe703b0148c0ffe45eea7f9290be2ce32257 Mon Sep 17 00:00:00 2001 From: Murray Steele Date: Thu, 26 Mar 2026 12:10:52 +0000 Subject: [PATCH] Fix incorrect cast from BF16 to FP32 in SBGEMM This change fixes a regression in SBGEMM where C is assumed to be BF16, and so unconditionally casts the output to FP32 resulting in incorrect outputs when beta=1. --- kernel/generic/gemmkernel_2x2.c | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index 07da2cbc87..94dcaea5f3 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -30,6 +30,12 @@ #include "conversion_macros.h" +#ifdef BGEMM +#define C_TO_F32 TO_F32 +#else +#define C_TO_F32 +#endif + int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc #ifdef TRMMKERNEL ,BLASLONG offset @@ -108,13 +114,13 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+2; } res0 = res0*ALPHA; - C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); + C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0); res1 = res1*ALPHA; - C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1); + C0[1] = TO_OUTPUT(C_TO_F32(C0[1])+res1); res2 = res2*ALPHA; - C1[0] = TO_OUTPUT(TO_F32(C1[0])+res2); + C1[0] = TO_OUTPUT(C_TO_F32(C1[0])+res2); res3 = res3*ALPHA; - C1[1] = TO_OUTPUT(TO_F32(C1[1])+res3); + C1[1] = TO_OUTPUT(C_TO_F32(C1[1])+res3); C0 = C0+2; C1 = C1+2; } @@ -134,9 +140,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+2; } res0 = res0*ALPHA; - C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); + C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0); res1 = res1*ALPHA; - C1[0] = TO_OUTPUT(TO_F32(C1[0])+res1); + C1[0] = TO_OUTPUT(C_TO_F32(C1[0])+res1); C0 = C0+1; C1 = C1+1; } @@ -165,9 +171,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+1; } res0 = res0*ALPHA; - C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); + C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0); res1 = res1*ALPHA; - C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1); + C0[1] = TO_OUTPUT(C_TO_F32(C0[1])+res1); C0 = C0+2; } for (i=0; i<(bm&1); i+=1) @@ -183,7 +189,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+1; } res0 = res0*ALPHA; - C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); + C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0); C0 = C0+1; } k = (bk<<0);