Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
aaf2d08
feat: configure leading dimensions explicitly
sanjibansg Mar 2, 2026
9aa4b88
fix: set stream while initializing handle
sanjibansg Mar 2, 2026
2498197
feat: print requested workspace
sanjibansg Mar 2, 2026
901c5ae
fix: provide layout order during adding
sanjibansg Mar 2, 2026
8a27ffe
fix: layout order configuration
sanjibansg Mar 3, 2026
54dce63
fix: remove look for requested workspace
sanjibansg Mar 3, 2026
c5128e6
feat: matmul method for without bias matrix multiplication condition
sanjibansg Apr 17, 2026
9282ba9
feat: option to pass gpu pointers directly
sanjibansg Apr 17, 2026
fc90737
fix: function signature for matmul method with void pointers
sanjibansg Apr 17, 2026
b591c4a
fix: template types for arguments to function with pointer signatures
sanjibansg Apr 17, 2026
19b5b3f
fix: use float signatures for pointer arguments for blascuda
sanjibansg Apr 17, 2026
c995464
fix: use explicit data type for signatures with pointer arguments
sanjibansg Apr 17, 2026
75f75a4
fix: non transpose axis for matmul method
sanjibansg Apr 17, 2026
ed303fb
fix (experimental): layout shape order
sanjibansg Apr 17, 2026
54a9401
feat: cuda cleanup and cpu blas api
sanjibansg May 4, 2026
066fdae
feat: matmul method on raw cuda pointers
sanjibansg May 4, 2026
1f1c4c7
fix: remove extra template parameter in cuda matmul method on raw poi…
sanjibansg May 4, 2026
a867939
fix: method signature for matmul on cuda raw pointers
sanjibansg May 4, 2026
ad41e99
fix: dimension order for checkAddLayout method in cuda
sanjibansg May 4, 2026
f86d299
fix: correct transpose values while adding layouts for cuda gemm methods
sanjibansg May 4, 2026
7b2133c
feat: gemm apis for alpaka views
sanjibansg May 11, 2026
6f2a9e1
feat: gemm functions for cublas using raw pointers
sanjibansg May 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,7 @@
*.dwo

# build files
**/build/
**/build/

# vscode settings
.vscode/
8 changes: 0 additions & 8 deletions include/.vscode/settings.json

This file was deleted.

190 changes: 180 additions & 10 deletions include/sofieBLAS/backends/cpu/sofieBLAS_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#endif

#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>

class BlasCpu {
public:
Expand All @@ -36,18 +38,186 @@ class BlasCpu {
}
}

// Bias-free GEMM on alpaka CPU buffers: C = alpha * op(A) * op(B) + beta * C.
// The data is passed to cblas_sgemm as column-major. Leading dimensions are
// not supplied by the caller; they are derived from m, n, k and the transpose
// flags, and the leading dimension of C is m.
template <typename T, typename TIdx>
inline void matmul(char transa, char transb, unsigned int m, unsigned int n,
                   unsigned int k, float alpha,
                   alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> const &A,
                   alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> const &B,
                   float beta,
                   alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> &C) {
  const int rows = static_cast<int>(m);
  const int cols = static_cast<int>(n);
  const int inner = static_cast<int>(k);

  // A is stored as m x k when not transposed, k x m otherwise.
  int lda = inner;
  if (transa == 'N' || transa == 'n') {
    lda = rows;
  }
  // B is stored as k x n when not transposed, n x k otherwise.
  int ldb = cols;
  if (transb == 'N' || transb == 'n') {
    ldb = inner;
  }

  const auto opA = charToTranspose(transa);
  const auto opB = charToTranspose(transb);
  cblas_sgemm(CblasColMajor, opA, opB, rows, cols, inner, alpha,
              alpaka::getPtrNative(A), lda, alpaka::getPtrNative(B), ldb, beta,
              alpaka::getPtrNative(C), rows);
}

// Bias-free GEMM on alpaka CPU views: C = alpha * op(A) * op(B) + beta * C.
// Column-major layout is passed to cblas_sgemm; lda/ldb are inferred from the
// transpose flags and the leading dimension of C is m.
template <typename T, typename TIdx>
inline void matmul(char transa, char transb, unsigned int m, unsigned int n,
                   unsigned int k, float alpha,
                   alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> const &A,
                   alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> const &B,
                   float beta,
                   alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> &C) {
  // True when the flag selects the untransposed operand.
  const auto keepsLayout = [](char flag) { return flag == 'N' || flag == 'n'; };

  const int mi = static_cast<int>(m);
  const int ni = static_cast<int>(n);
  const int ki = static_cast<int>(k);
  const int lda = keepsLayout(transa) ? mi : ki;
  const int ldb = keepsLayout(transb) ? ki : ni;

  cblas_sgemm(CblasColMajor, charToTranspose(transa), charToTranspose(transb),
              mi, ni, ki, alpha, alpaka::getPtrNative(A), lda,
              alpaka::getPtrNative(B), ldb, beta, alpaka::getPtrNative(C), mi);
}

// GEMM with bias on alpaka CPU buffers:
//   C = alpha * op(A) * op(B) + beta * bias + bias_vec
// where bias_vec is the first m elements of `bias`, broadcast along each row
// (i.e. the same value is added to every column of row i).
// Matches the cuBLASLt EPILOGUE_BIAS semantics: the bias buffer serves as both
// the beta-scaled accumulator (read as an m x n column-major matrix) and the
// per-row bias vector.
// Leading dimensions are inferred from m, n, k and the transpose flags.
template <typename T, typename TIdx>
inline void
gemm(char transa, char transb, unsigned int m, unsigned int n, unsigned int k,
     float alpha, alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> const &A,
     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> const &B, float beta,
     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> &bias,
     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> &C) {
  int lda = (transa == 'N' || transa == 'n') ? static_cast<int>(m)
                                             : static_cast<int>(k);
  int ldb = (transb == 'N' || transb == 'n') ? static_cast<int>(k)
                                             : static_cast<int>(n);
  // Step 1: C = alpha * op(A) * op(B) (beta=0 so C is fully overwritten)
  cblas_sgemm(CblasColMajor, charToTranspose(transa), charToTranspose(transb),
              static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
              alpha, alpaka::getPtrNative(A), lda, alpaka::getPtrNative(B),
              ldb, 0.0f, alpaka::getPtrNative(C), static_cast<int>(m));
  // Step 2: C += beta * bias_matrix + bias_vec (per-row broadcast)
  float *c = alpaka::getPtrNative(C);
  const float *b = alpaka::getPtrNative(bias);
  // Index in std::size_t: j * m + i computed in unsigned int would wrap once
  // the matrix holds more than 2^32 elements.
  const std::size_t rows = m;
  const std::size_t cols = n;
  for (std::size_t j = 0; j < cols; ++j) {
    const std::size_t colStart = j * rows;
    for (std::size_t i = 0; i < rows; ++i)
      c[colStart + i] += beta * b[colStart + i] + b[i];
  }
}

template <typename T, typename TIdx>
inline void
gemm(char transa, char transb, const unsigned int m, const unsigned int n,
const unsigned int k, const float alpha,
alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> const &A, const int lda,
alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> const &B, const int ldb,
const float beta, alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> &C,
const int ldc) {
CBLAS_TRANSPOSE TransA = charToTranspose(transa);
CBLAS_TRANSPOSE TransB = charToTranspose(transb);
cblas_sgemm(CblasColMajor, TransA, TransB, m, n, k, alpha, A.data(), lda,
B.data(), ldb, beta, C.data(), ldc);
gemm(char transa, char transb, unsigned int m, unsigned int n, unsigned int k,
float alpha,
alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> const &A,
alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> const &B,
float beta,
alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> &bias,
alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> &C) {
int lda = (transa == 'N' || transa == 'n') ? static_cast<int>(m)
: static_cast<int>(k);
int ldb = (transb == 'N' || transb == 'n') ? static_cast<int>(k)
: static_cast<int>(n);
cblas_sgemm(CblasColMajor, charToTranspose(transa), charToTranspose(transb),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
alpha, alpaka::getPtrNative(A), lda, alpaka::getPtrNative(B),
ldb, 0.0f, alpaka::getPtrNative(C), static_cast<int>(m));
float *c = alpaka::getPtrNative(C);
const float *b = alpaka::getPtrNative(bias);
for (unsigned int j = 0; j < n; ++j)
for (unsigned int i = 0; i < m; ++i)
c[j * m + i] += beta * b[j * m + i] + b[i];
}

// C = relu(alpha * op(A) * op(B) + beta * bias + bias_vec)
// Runs the bias GEMM on alpaka CPU buffers, then clamps every one of the
// m * n output elements to be non-negative in place.
template <typename T, typename TIdx>
inline void gemmrelu(char transa, char transb, unsigned int m, unsigned int n,
                     unsigned int k, float alpha,
                     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> const &A,
                     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> const &B,
                     float beta,
                     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> &bias,
                     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> &C) {
  gemm(transa, transb, m, n, k, alpha, A, B, beta, bias, C);
  float *c = alpaka::getPtrNative(C);
  // Element count in std::size_t: m * n in unsigned int would wrap past 2^32.
  const std::size_t count = static_cast<std::size_t>(m) * n;
  for (std::size_t i = 0; i < count; ++i)
    c[i] = c[i] > 0.0f ? c[i] : 0.0f;
}

// C = relu(alpha * op(A) * op(B) + beta * bias + bias_vec) on alpaka CPU
// views. Runs the bias GEMM, then clamps the m * n outputs to be non-negative
// in place.
template <typename T, typename TIdx>
inline void gemmrelu(char transa, char transb, unsigned int m, unsigned int n,
                     unsigned int k, float alpha,
                     alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> const &A,
                     alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> const &B,
                     float beta,
                     alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> &bias,
                     alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> &C) {
  gemm(transa, transb, m, n, k, alpha, A, B, beta, bias, C);
  float *c = alpaka::getPtrNative(C);
  // Element count in std::size_t: m * n in unsigned int would wrap past 2^32.
  const std::size_t count = static_cast<std::size_t>(m) * n;
  for (std::size_t i = 0; i < count; ++i)
    c[i] = c[i] > 0.0f ? c[i] : 0.0f;
}

// C = gelu(alpha * op(A) * op(B) + beta * bias + bias_vec)
// Uses the standard GELU: x * 0.5 * (1 + erf(x / sqrt(2)))
// Runs the bias GEMM on alpaka CPU buffers, then applies GELU in place to
// each of the m * n output elements.
template <typename T, typename TIdx>
inline void gemmgelu(char transa, char transb, unsigned int m, unsigned int n,
                     unsigned int k, float alpha,
                     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> const &A,
                     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> const &B,
                     float beta,
                     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> &bias,
                     alpaka::BufCpu<T, alpaka::DimInt<1u>, TIdx> &C) {
  gemm(transa, transb, m, n, k, alpha, A, B, beta, bias, C);
  float *c = alpaka::getPtrNative(C);
  constexpr float kInvSqrt2 = 0.7071067811865476f;
  // Element count in std::size_t: m * n in unsigned int would wrap past 2^32.
  const std::size_t count = static_cast<std::size_t>(m) * n;
  // std::erf's float overload is used instead of std::erff, which some
  // standard library releases do not declare in namespace std.
  for (std::size_t i = 0; i < count; ++i)
    c[i] *= 0.5f * (1.0f + std::erf(c[i] * kInvSqrt2));
}

// C = gelu(alpha * op(A) * op(B) + beta * bias + bias_vec) on alpaka CPU
// views, with GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2))) applied in place to
// each of the m * n output elements after the bias GEMM.
template <typename T, typename TIdx>
inline void gemmgelu(char transa, char transb, unsigned int m, unsigned int n,
                     unsigned int k, float alpha,
                     alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> const &A,
                     alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> const &B,
                     float beta,
                     alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> &bias,
                     alpaka::ViewPlainPtr<alpaka::DevCpu, T, alpaka::DimInt<1u>, TIdx> &C) {
  gemm(transa, transb, m, n, k, alpha, A, B, beta, bias, C);
  float *c = alpaka::getPtrNative(C);
  constexpr float kInvSqrt2 = 0.7071067811865476f;
  // Element count in std::size_t: m * n in unsigned int would wrap past 2^32.
  const std::size_t count = static_cast<std::size_t>(m) * n;
  // std::erf's float overload is used instead of std::erff, which some
  // standard library releases do not declare in namespace std.
  for (std::size_t i = 0; i < count; ++i)
    c[i] *= 0.5f * (1.0f + std::erf(c[i] * kInvSqrt2));
}

// Raw-pointer overloads: accept T const*/T* from any BufXxx or ViewPlainPtr via getPtrNative()
//
// GEMM with bias on raw pointers:
//   C = alpha * op(A) * op(B) + beta * bias + bias_vec
// `bias` is read both as an m x n column-major matrix (scaled by beta) and as
// a length-m vector (its first m elements) added to every column of C.
// The pointers are forwarded to cblas_sgemm; leading dimensions are inferred
// from m, n, k and the transpose flags, and ldc is m.
template <typename T>
inline void gemm(char transa, char transb, unsigned int m, unsigned int n,
                 unsigned int k, float alpha, T const *A, T const *B,
                 float beta, T *bias, T *C) {
  int lda = (transa == 'N' || transa == 'n') ? static_cast<int>(m) : static_cast<int>(k);
  int ldb = (transb == 'N' || transb == 'n') ? static_cast<int>(k) : static_cast<int>(n);
  // beta = 0 here so C is fully overwritten by the product; the bias terms are
  // added in the loop below.
  cblas_sgemm(CblasColMajor, charToTranspose(transa), charToTranspose(transb),
              static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
              alpha, A, lda, B, ldb, 0.0f, C, static_cast<int>(m));
  // Index in std::size_t so j * m + i cannot wrap for very large matrices.
  const std::size_t rows = m;
  const std::size_t cols = n;
  for (std::size_t j = 0; j < cols; ++j) {
    const std::size_t colStart = j * rows;
    for (std::size_t i = 0; i < rows; ++i)
      C[colStart + i] += beta * bias[colStart + i] + bias[i];
  }
}

// C = relu(alpha * op(A) * op(B) + beta * bias + bias_vec) on raw pointers.
// Runs the raw-pointer bias GEMM, then clamps the m * n outputs to be
// non-negative in place.
template <typename T>
inline void gemmrelu(char transa, char transb, unsigned int m, unsigned int n,
                     unsigned int k, float alpha, T const *A, T const *B,
                     float beta, T *bias, T *C) {
  gemm(transa, transb, m, n, k, alpha, A, B, beta, bias, C);
  // Element count in std::size_t: m * n in unsigned int would wrap past 2^32.
  const std::size_t count = static_cast<std::size_t>(m) * n;
  for (std::size_t i = 0; i < count; ++i)
    C[i] = C[i] > 0.0f ? C[i] : 0.0f;
}

// C = gelu(alpha * op(A) * op(B) + beta * bias + bias_vec) on raw pointers,
// with GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2))) applied in place to each of
// the m * n output elements after the raw-pointer bias GEMM.
template <typename T>
inline void gemmgelu(char transa, char transb, unsigned int m, unsigned int n,
                     unsigned int k, float alpha, T const *A, T const *B,
                     float beta, T *bias, T *C) {
  gemm(transa, transb, m, n, k, alpha, A, B, beta, bias, C);
  constexpr float kInvSqrt2 = 0.7071067811865476f;
  // Element count in std::size_t: m * n in unsigned int would wrap past 2^32.
  const std::size_t count = static_cast<std::size_t>(m) * n;
  // std::erf's overload set is used instead of std::erff, which some standard
  // library releases do not declare in namespace std.
  for (std::size_t i = 0; i < count; ++i)
    C[i] *= 0.5f * (1.0f + std::erf(C[i] * kInvSqrt2));
}
};

Expand Down
Loading