diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index 07da2cbc87..94dcaea5f3 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -30,6 +30,12 @@ #include "conversion_macros.h" +#ifdef BGEMM +#define C_TO_F32 TO_F32 +#else +#define C_TO_F32 +#endif + int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc #ifdef TRMMKERNEL ,BLASLONG offset @@ -108,13 +114,13 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+2; } res0 = res0*ALPHA; - C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); + C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0); res1 = res1*ALPHA; - C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1); + C0[1] = TO_OUTPUT(C_TO_F32(C0[1])+res1); res2 = res2*ALPHA; - C1[0] = TO_OUTPUT(TO_F32(C1[0])+res2); + C1[0] = TO_OUTPUT(C_TO_F32(C1[0])+res2); res3 = res3*ALPHA; - C1[1] = TO_OUTPUT(TO_F32(C1[1])+res3); + C1[1] = TO_OUTPUT(C_TO_F32(C1[1])+res3); C0 = C0+2; C1 = C1+2; } @@ -134,9 +140,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+2; } res0 = res0*ALPHA; - C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); + C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0); res1 = res1*ALPHA; - C1[0] = TO_OUTPUT(TO_F32(C1[0])+res1); + C1[0] = TO_OUTPUT(C_TO_F32(C1[0])+res1); C0 = C0+1; C1 = C1+1; } @@ -165,9 +171,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+1; } res0 = res0*ALPHA; - C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); + C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0); res1 = res1*ALPHA; - C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1); + C0[1] = TO_OUTPUT(C_TO_F32(C0[1])+res1); C0 = C0+2; } for (i=0; i<(bm&1); i+=1) @@ -183,7 +189,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+1; } res0 = res0*ALPHA; - C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); + C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0); C0 = C0+1; } k = (bk<<0);