diff --git a/docs/issues/native-macos-accelerate-simd.md b/docs/issues/native-macos-accelerate-simd.md
new file mode 100644
index 00000000..b0317c92
--- /dev/null
+++ b/docs/issues/native-macos-accelerate-simd.md
@@ -0,0 +1,138 @@
+# Native macOS SIMD acceleration via Apple Accelerate framework
+
+## Problem
+
+The `skainet-backend-cpu` module on Kotlin/Native macOS (macosArm64) uses plain scalar loops
+for all tensor operations (`DefaultCpuOps`). On the JVM, the same module uses the JDK Vector API
+for SIMD-accelerated matmul, elementwise ops, and reductions (`DefaultCpuOpsJvm`), which gives
+it a significant performance advantage.
+
+When running LLM inference benchmarks via the `llm-performance` native binary, the CPU backend
+is 5-10x slower than it needs to be because every matmul is a triple-nested scalar loop
+(`DefaultCpuOps.kt:264-272`).
+
+## Proposed solution
+
+Add an Accelerate-backed `TensorOps` implementation for the Apple native targets, mirroring
+how the JVM target has `DefaultCpuOpsJvm`. Apple's Accelerate framework provides
+hardware-optimized BLAS and vector DSP routines that leverage ARM NEON and AMX under the hood.
+
+### Architecture
+
+```
+PlatformCpuOpsFactory
+  ├── jvmMain    → DefaultCpuOpsJvm (Vector API + optional BLAS)      ← exists
+  ├── nativeMain → DefaultCpuOps (scalar fallback)                    ← exists
+  ├── appleMain  → AccelerateCpuOps (Accelerate, shared macOS/iOS)    ← NEW
+  └── linuxMain  → DefaultCpuOps (scalar, or OpenBLAS in future)      ← unchanged
+```
+
+### Key changes
+
+**1. Cinterop definition** — `src/nativeInterop/cinterop/accelerate.def`
+
+```def
+package = platform.accelerate
+language = C
+headers = Accelerate/Accelerate.h
+compilerOpts = -framework Accelerate
+linkerOpts = -framework Accelerate
+```
+
+Note: Kotlin/Native already ships a prebuilt `platform.Accelerate` platform library for
+Apple targets, so this custom cinterop may turn out to be unnecessary; in that case only
+the linker opts are required.
+
+**2. 
New class** — `src/appleMain/kotlin/.../AccelerateCpuOps.kt`
+
+Extends `DefaultCpuOpsBase` (the scalar implementation) and overrides hot-path operations
+with Accelerate calls:
+
+| Priority | Operation | Accelerate function | Impact |
+|----------|-----------|---------------------|--------|
+| P0 | `matmul` | `cblas_sgemm` | Dominant cost in LLM inference (~90% of forward pass) |
+| P1 | `add` | `vDSP_vadd` | Elementwise add (residual connections) |
+| P1 | `multiply` | `vDSP_vmul` | Elementwise multiply (gates, scaling) |
+| P1 | `subtract` | `vDSP_vsub` | Elementwise subtract |
+| P1 | `divide` | `vDSP_vdiv` | Elementwise divide |
+| P2 | `sum` (global) | `vDSP_sve` | Reduction for normalization |
+| P2 | `mean` (global) | `vDSP_meanv` | Reduction for normalization |
+| P2 | `softmax` | `vDSP_sve` + manual exp | Attention weights |
+| P3 | `relu` | `vDSP_vthres` / `vDSP_vthr` | Activation function |
+| P3 | `silu` | manual vectorized loop | Activation function (SiLU = x * sigmoid(x)) |
+| P3 | `transpose` | `vDSP_mtrans` | Matrix transpose |
+
+**3. Platform factory** — update `PlatformCpuOpsFactory` for Apple targets
+
+```kotlin
+// src/appleMain/kotlin/.../PlatformCpuOpsFactory.apple.kt
+internal actual fun platformDefaultCpuOpsFactory(): (TensorDataFactory) -> TensorOps {
+    println("[SKaiNET] Using Accelerate-backed CPU operations (ARM NEON + AMX)")
+    return { factory -> AccelerateCpuOps(factory) }
+}
+```
+
+This requires splitting the current `nativeMain` expect/actual into separate
+`appleMain` and `linuxMain` actuals (an `appleMain` source set, shared by macOS and iOS,
+is wired up in `build.gradle.kts`).
+
+**4. Build changes** — `build.gradle.kts`
+
+Add cinterop configuration for macosArm64 (and optionally iosArm64/iosSimulatorArm64):
+
+```kotlin
+macosArm64 {
+    compilations["main"].cinterops {
+        val accelerate by creating {
+            defFile("src/nativeInterop/cinterop/accelerate.def")
+        }
+    }
+}
+```
+
+Add linker opts for the Accelerate framework to all macOS/iOS binaries.
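The linker opts can be applied once for every Apple-family binary rather than per target. A hedged Gradle sketch — the `withType` filter and the `isAppleFamily` check are assumptions about how this project's build is organized, not the actual configuration:

```kotlin
// build.gradle.kts (sketch) — link every Apple-family native binary
// against Accelerate; non-Apple targets are left untouched.
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget

kotlin {
    targets.withType<KotlinNativeTarget>().configureEach {
        if (konanTarget.family.isAppleFamily) {
            binaries.all {
                linkerOpts("-framework", "Accelerate")
            }
        }
    }
}
```

Centralizing the flag here avoids repeating it for each of macosArm64, iosArm64, and iosSimulatorArm64 as targets are added.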
+
+### Implementation notes
+
+- `AccelerateCpuOps` should extend `DefaultCpuOpsBase` and override only the operations above.
+  Non-accelerated operations fall through to the scalar implementation.
+- The `matmul` override should handle 2D FP32 tensors with `cblas_sgemm` and delegate
+  batched/non-float cases to `super.matmul()`.
+- `vDSP_*` functions operate on contiguous `FloatArray` buffers. Tensors backed by
+  `FloatArrayTensorData` can be passed directly; others need a `toFloatArray()` copy.
+- Broadcasting logic (e.g., bias add, scalar multiply) should remain in the Kotlin layer
+  and only dispatch the contiguous inner loop to Accelerate.
+- The same approach works for iOS targets (`iosArm64`, `iosSimulatorArm64`) since
+  Accelerate is available on all Apple platforms.
+
+### Testing
+
+- Existing `DefaultCpuOps` tests in `commonTest` should pass unchanged (numerical equivalence
+  within FP32 tolerance — BLAS kernels may sum in a different order than the scalar loop).
+- Add macOS-specific tests verifying that Accelerate dispatch actually occurs (e.g., check log
+  output or add a query method).
+- Benchmark comparison: run the `llm-performance` native benchmark with the current scalar
+  backend vs the Accelerate backend on the same model.
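The tolerance-based equivalence check can be illustrated in plain Kotlin without the project's tensor types. Here `matmulReordered` is only a stand-in for an accelerated kernel — it computes the same product with a different loop order, changing the summation order the way a SIMD/BLAS kernel might:

```kotlin
import kotlin.math.abs
import kotlin.random.Random

// Naive scalar matmul over row-major buffers: the reference implementation.
fun matmulScalar(a: FloatArray, b: FloatArray, m: Int, k: Int, n: Int): FloatArray {
    val c = FloatArray(m * n)
    for (i in 0 until m) for (j in 0 until n) {
        var acc = 0f
        for (p in 0 until k) acc += a[i * k + p] * b[p * n + j]
        c[i * n + j] = acc
    }
    return c
}

// Same product, ikj loop order: accumulates partial sums in a different
// order, mimicking how a vectorized kernel can diverge in the last bits.
fun matmulReordered(a: FloatArray, b: FloatArray, m: Int, k: Int, n: Int): FloatArray {
    val c = FloatArray(m * n)
    for (i in 0 until m) for (p in 0 until k) {
        val av = a[i * k + p]
        for (j in 0 until n) c[i * n + j] += av * b[p * n + j]
    }
    return c
}

fun main() {
    val m = 7; val k = 13; val n = 5
    val rng = Random(42)
    val a = FloatArray(m * k) { rng.nextFloat() - 0.5f }
    val b = FloatArray(k * n) { rng.nextFloat() - 0.5f }
    val expected = matmulScalar(a, b, m, k, n)
    val actual = matmulReordered(a, b, m, k, n)
    // Exact FP32 equality is too strict; compare with an absolute tolerance.
    val maxDiff = expected.indices.maxOf { abs(expected[it] - actual[it]) }
    check(maxDiff < 1e-4f) { "matmul mismatch: $maxDiff" }
    println("equivalent within tolerance")
}
```

The real test would replace `matmulReordered` with a call into `AccelerateCpuOps` and keep the same tolerance-based comparison.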
+
+### Expected impact
+
+Based on JVM BLAS vs scalar measurements and Apple's published Accelerate performance data:
+
+- **matmul**: 10-50x speedup (NEON + AMX vs scalar loop)
+- **elementwise**: 4-8x speedup (NEON vectorization)
+- **reductions**: 4-8x speedup (NEON vectorization)
+- **overall LLM inference**: 5-20x speedup on native macOS CPU backend
+
+### Files to create/modify
+
+```
+skainet-backends/skainet-backend-cpu/
+├── build.gradle.kts                                          # add appleMain source set (+ cinterop if needed)
+├── src/nativeInterop/cinterop/accelerate.def                 # NEW
+├── src/appleMain/kotlin/.../AccelerateCpuOps.kt              # NEW
+├── src/appleMain/kotlin/.../PlatformCpuOpsFactory.apple.kt   # NEW
+├── src/linuxMain/kotlin/.../PlatformCpuOpsFactory.linux.kt   # NEW (move from nativeMain)
+└── src/nativeMain/kotlin/.../PlatformCpuOpsFactory.native.kt # REMOVE (split to platform-specific)
+```
+
+### References
+
+- JVM SIMD implementation: `src/jvmMain/kotlin/.../DefaultCpuOpsJvm.kt`
+- JVM BLAS integration: `src/jvmMain/kotlin/.../JvmBlas.kt`
+- Apple Accelerate docs: https://developer.apple.com/documentation/accelerate
+- CBLAS reference: https://developer.apple.com/documentation/accelerate/blas
+- vDSP reference: https://developer.apple.com/documentation/accelerate/vdsp
diff --git a/skainet-backends/skainet-backend-cpu/build.gradle.kts b/skainet-backends/skainet-backend-cpu/build.gradle.kts
index 54d14a7c..3864033b 100644
--- a/skainet-backends/skainet-backend-cpu/build.gradle.kts
+++ b/skainet-backends/skainet-backend-cpu/build.gradle.kts
@@ -70,16 +70,20 @@ kotlin {
             dependsOn(commonMain)
         }
 
+        val appleMain by creating {
+            dependsOn(nativeMain)
+        }
+
         val linuxMain by creating {
             dependsOn(nativeMain)
         }
 
         val iosMain by creating {
-            dependsOn(nativeMain)
+            dependsOn(appleMain)
         }
 
         val macosMain by creating {
-            dependsOn(nativeMain)
+            dependsOn(appleMain)
         }
 
         val iosArm64Main by getting {
diff --git a/skainet-backends/skainet-backend-cpu/src/appleMain/kotlin/sk/ainet/exec/tensor/ops/AccelerateCpuOps.kt 
b/skainet-backends/skainet-backend-cpu/src/appleMain/kotlin/sk/ainet/exec/tensor/ops/AccelerateCpuOps.kt new file mode 100644 index 00000000..e8e5540b --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/appleMain/kotlin/sk/ainet/exec/tensor/ops/AccelerateCpuOps.kt @@ -0,0 +1,344 @@ +@file:OptIn(kotlinx.cinterop.ExperimentalForeignApi::class) + +package sk.ainet.exec.tensor.ops + +import kotlinx.cinterop.addressOf +import kotlinx.cinterop.usePinned +import platform.Accelerate.CblasNoTrans +import platform.Accelerate.CblasRowMajor +import platform.Accelerate.cblas_sgemm +import platform.Accelerate.vDSP_vadd +import platform.Accelerate.vDSP_vsub +import platform.Accelerate.vDSP_vmul +import platform.Accelerate.vDSP_vdiv +import platform.Accelerate.vDSP_sve +import platform.Accelerate.vDSP_meanv +import platform.Accelerate.vDSP_mtrans +import platform.Accelerate.vDSP_vthres +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.FloatArrayTensorData +import sk.ainet.lang.tensor.data.TensorDataFactory +import sk.ainet.lang.types.DType +import sk.ainet.lang.types.FP32 + +/** + * CPU operations accelerated by Apple's Accelerate framework. + * Overrides hot-path operations (matmul, elementwise, reductions) with + * hardware-optimized routines that leverage ARM NEON and AMX. + * + * Falls through to [DefaultCpuOpsBase] for non-FP32, non-contiguous, + * or complex broadcasting cases. 
+ */ +public class AccelerateCpuOps( + dataFactory: TensorDataFactory, +) : DefaultCpuOpsBase(dataFactory) { + + // ── matmul ────────────────────────────────────────────────────────── + + override fun matmul(a: Tensor, b: Tensor): Tensor { + if (a.rank == 2 && b.rank == 2 + && a.dtype == FP32::class + && a.data is FloatArrayTensorData<*> + && b.data is FloatArrayTensorData<*> + ) { + val aBuf = (a.data as FloatArrayTensorData<*>).buffer + val bBuf = (b.data as FloatArrayTensorData<*>).buffer + val m = a.shape[0] + val k = a.shape[1] + val n = b.shape[1] + require(k == b.shape[0]) { "matmul shape mismatch: ${a.shape} vs ${b.shape}" } + + val out = FloatArray(m * n) + // cblas_sgemm: C = alpha * A * B + beta * C + aBuf.usePinned { aPin -> + bBuf.usePinned { bPin -> + out.usePinned { cPin -> + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, CblasNoTrans, + m, n, k, + 1.0f, // alpha + aPin.addressOf(0), k, // A, lda + bPin.addressOf(0), n, // B, ldb + 0.0f, // beta + cPin.addressOf(0), n, // C, ldc + ) + } + } + } + + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(Shape(m, n), a.dtype, out) + as sk.ainet.lang.tensor.data.TensorData + return newTensor(outData, a.dtype, a, b) + } + + return super.matmul(a, b) + } + + // ── elementwise binary ops ────────────────────────────────────────── + + override fun add(a: Tensor, b: Tensor): Tensor { + val result = tryVdspBinary(a, b, ::vdspAdd) + return result ?: super.add(a, b) + } + + override fun subtract(a: Tensor, b: Tensor): Tensor { + val result = tryVdspBinary(a, b, ::vdspSub) + return result ?: super.subtract(a, b) + } + + override fun multiply(a: Tensor, b: Tensor): Tensor { + val result = tryVdspBinary(a, b, ::vdspMul) + return result ?: super.multiply(a, b) + } + + override fun divide(a: Tensor, b: Tensor): Tensor { + val result = tryVdspBinary(a, b, ::vdspDiv) + return result ?: super.divide(a, b) + } + + // ── reductions ────────────────────────────────────────────────────── + + override fun 
sum(tensor: Tensor, dim: Int?): Tensor { + if (dim == null + && tensor.dtype == FP32::class + && tensor.data is FloatArrayTensorData<*> + ) { + val buf = (tensor.data as FloatArrayTensorData<*>).buffer + val n = buf.size + if (n > 0) { + val result = FloatArray(1) + buf.usePinned { pin -> + result.usePinned { rPin -> + vDSP_sve(pin.addressOf(0), 1, rPin.addressOf(0), n.toULong()) + } + } + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(Shape(), tensor.dtype, floatArrayOf(result[0])) + as sk.ainet.lang.tensor.data.TensorData + return newTensor(outData, tensor.dtype, tensor) + } + } + return super.sum(tensor, dim) + } + + override fun mean(tensor: Tensor, dim: Int?): Tensor { + if (dim == null + && tensor.dtype == FP32::class + && tensor.data is FloatArrayTensorData<*> + ) { + val buf = (tensor.data as FloatArrayTensorData<*>).buffer + val n = buf.size + if (n > 0) { + val result = FloatArray(1) + buf.usePinned { pin -> + result.usePinned { rPin -> + vDSP_meanv(pin.addressOf(0), 1, rPin.addressOf(0), n.toULong()) + } + } + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(Shape(), tensor.dtype, floatArrayOf(result[0])) + as sk.ainet.lang.tensor.data.TensorData + return newTensor(outData, tensor.dtype, tensor) + } + } + return super.mean(tensor, dim) + } + + // ── activations ───────────────────────────────────────────────────── + + override fun relu(tensor: Tensor): Tensor { + if (tensor.dtype == FP32::class && tensor.data is FloatArrayTensorData<*>) { + val buf = (tensor.data as FloatArrayTensorData<*>).buffer + val n = buf.size + val out = FloatArray(n) + buf.usePinned { pin -> + out.usePinned { oPin -> + val threshold = FloatArray(1) { 0.0f } + threshold.usePinned { tPin -> + vDSP_vthres(pin.addressOf(0), 1, tPin.addressOf(0), oPin.addressOf(0), 1, n.toULong()) + } + } + } + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(tensor.shape, tensor.dtype, out) + as sk.ainet.lang.tensor.data.TensorData + 
return newTensor(outData, tensor.dtype, tensor) + } + return super.relu(tensor) + } + + override fun silu(tensor: Tensor): Tensor { + if (tensor.dtype == FP32::class && tensor.data is FloatArrayTensorData<*>) { + val buf = (tensor.data as FloatArrayTensorData<*>).buffer + val n = buf.size + val out = FloatArray(n) + for (i in 0 until n) { + val x = buf[i] + out[i] = x / (1.0f + kotlin.math.exp(-x)) + } + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(tensor.shape, tensor.dtype, out) + as sk.ainet.lang.tensor.data.TensorData + return newTensor(outData, tensor.dtype, tensor) + } + return super.silu(tensor) + } + + // ── transpose ─────────────────────────────────────────────────────── + + override fun transpose(tensor: Tensor): Tensor { + if (tensor.rank == 2 + && tensor.dtype == FP32::class + && tensor.data is FloatArrayTensorData<*> + ) { + val buf = (tensor.data as FloatArrayTensorData<*>).buffer + val rows = tensor.shape[0] + val cols = tensor.shape[1] + val out = FloatArray(rows * cols) + buf.usePinned { pin -> + out.usePinned { oPin -> + vDSP_mtrans( + pin.addressOf(0), 1, + oPin.addressOf(0), 1, + cols.toULong(), rows.toULong(), + ) + } + } + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(Shape(cols, rows), tensor.dtype, out) + as sk.ainet.lang.tensor.data.TensorData + return newTensor(outData, tensor.dtype, tensor) + } + return super.transpose(tensor) + } + + // ── vDSP binary helpers ───────────────────────────────────────────── + + /** + * Attempt to dispatch a binary elementwise op to vDSP. + * Returns null if the tensors are not eligible (non-FP32, non-contiguous, + * complex broadcasting). + */ + private fun tryVdspBinary( + a: Tensor, + b: Tensor, + op: (FloatArray, FloatArray, FloatArray, Int) -> Unit, + ): Tensor? 
{ + if (a.dtype != FP32::class) return null + if (a.data !is FloatArrayTensorData<*> || b.data !is FloatArrayTensorData<*>) return null + + val aBuf = (a.data as FloatArrayTensorData<*>).buffer + val bBuf = (b.data as FloatArrayTensorData<*>).buffer + + // Same shape: straightforward vectorized op + if (a.shape == b.shape) { + val n = aBuf.size + val out = FloatArray(n) + op(aBuf, bBuf, out, n) + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(a.shape, a.dtype, out) + as sk.ainet.lang.tensor.data.TensorData + return newTensor(outData, a.dtype, a, b) + } + + // Scalar broadcast: b is a single element + if (bBuf.size == 1) { + val n = aBuf.size + val expanded = FloatArray(n) { bBuf[0] } + val out = FloatArray(n) + op(aBuf, expanded, out, n) + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(a.shape, a.dtype, out) + as sk.ainet.lang.tensor.data.TensorData + return newTensor(outData, a.dtype, a, b) + } + + // Scalar broadcast: a is a single element + if (aBuf.size == 1) { + val n = bBuf.size + val expanded = FloatArray(n) { aBuf[0] } + val out = FloatArray(n) + op(expanded, bBuf, out, n) + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(b.shape, a.dtype, out) + as sk.ainet.lang.tensor.data.TensorData + return newTensor(outData, a.dtype, a, b) + } + + // Last-dim broadcast: b has shape [1, ..., 1, N] matching a's last dim + // Common for bias add: [batch, features] + [features] + if (b.rank <= a.rank) { + val bDims = b.shape.dimensions + val aDims = a.shape.dimensions + val offset = aDims.size - bDims.size + var isBiasBroadcast = true + for (i in bDims.indices) { + if (i < bDims.size - 1 && bDims[i] != 1) { isBiasBroadcast = false; break } + if (i == bDims.size - 1 && bDims[i] != aDims[offset + i]) { isBiasBroadcast = false; break } + } + if (isBiasBroadcast && bDims.last() > 1) { + val lastDim = bDims.last() + val batches = aBuf.size / lastDim + val out = FloatArray(aBuf.size) + for (batch in 0 until 
batches) { + val aSlice = FloatArray(lastDim) + aBuf.copyInto(aSlice, 0, batch * lastDim, (batch + 1) * lastDim) + val oSlice = FloatArray(lastDim) + op(aSlice, bBuf, oSlice, lastDim) + oSlice.copyInto(out, batch * lastDim) + } + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(a.shape, a.dtype, out) + as sk.ainet.lang.tensor.data.TensorData + return newTensor(outData, a.dtype, a, b) + } + } + + return null // fall through to scalar + } + + private fun vdspAdd(a: FloatArray, b: FloatArray, out: FloatArray, n: Int) { + a.usePinned { aPin -> + b.usePinned { bPin -> + out.usePinned { oPin -> + vDSP_vadd(aPin.addressOf(0), 1, bPin.addressOf(0), 1, oPin.addressOf(0), 1, n.toULong()) + } + } + } + } + + private fun vdspSub(a: FloatArray, b: FloatArray, out: FloatArray, n: Int) { + // vDSP_vsub computes out = B - A (reversed!), so swap args + a.usePinned { aPin -> + b.usePinned { bPin -> + out.usePinned { oPin -> + vDSP_vsub(bPin.addressOf(0), 1, aPin.addressOf(0), 1, oPin.addressOf(0), 1, n.toULong()) + } + } + } + } + + private fun vdspMul(a: FloatArray, b: FloatArray, out: FloatArray, n: Int) { + a.usePinned { aPin -> + b.usePinned { bPin -> + out.usePinned { oPin -> + vDSP_vmul(aPin.addressOf(0), 1, bPin.addressOf(0), 1, oPin.addressOf(0), 1, n.toULong()) + } + } + } + } + + private fun vdspDiv(a: FloatArray, b: FloatArray, out: FloatArray, n: Int) { + // vDSP_vdiv computes out = B / A (reversed!), so swap args + a.usePinned { aPin -> + b.usePinned { bPin -> + out.usePinned { oPin -> + vDSP_vdiv(bPin.addressOf(0), 1, aPin.addressOf(0), 1, oPin.addressOf(0), 1, n.toULong()) + } + } + } + } +} diff --git a/skainet-backends/skainet-backend-cpu/src/appleMain/kotlin/sk/ainet/exec/tensor/ops/PlatformCpuOpsFactory.apple.kt b/skainet-backends/skainet-backend-cpu/src/appleMain/kotlin/sk/ainet/exec/tensor/ops/PlatformCpuOpsFactory.apple.kt new file mode 100644 index 00000000..6db5adc9 --- /dev/null +++ 
b/skainet-backends/skainet-backend-cpu/src/appleMain/kotlin/sk/ainet/exec/tensor/ops/PlatformCpuOpsFactory.apple.kt @@ -0,0 +1,9 @@ +package sk.ainet.exec.tensor.ops + +import sk.ainet.lang.tensor.data.TensorDataFactory +import sk.ainet.lang.tensor.ops.TensorOps + +internal actual fun platformDefaultCpuOpsFactory(): (TensorDataFactory) -> TensorOps { + println("[SKaiNET] Using Accelerate-backed CPU operations (ARM NEON + AMX)") + return { factory -> AccelerateCpuOps(factory) } +} diff --git a/skainet-backends/skainet-backend-cpu/src/nativeMain/kotlin/sk/ainet/exec/tensor/ops/PlatformCpuOpsFactory.native.kt b/skainet-backends/skainet-backend-cpu/src/linuxMain/kotlin/sk/ainet/exec/tensor/ops/PlatformCpuOpsFactory.linux.kt similarity index 100% rename from skainet-backends/skainet-backend-cpu/src/nativeMain/kotlin/sk/ainet/exec/tensor/ops/PlatformCpuOpsFactory.native.kt rename to skainet-backends/skainet-backend-cpu/src/linuxMain/kotlin/sk/ainet/exec/tensor/ops/PlatformCpuOpsFactory.linux.kt