diff --git a/apps/bench/src/main.rs b/apps/bench/src/main.rs index 714c4b0..e5906a7 100644 --- a/apps/bench/src/main.rs +++ b/apps/bench/src/main.rs @@ -211,6 +211,19 @@ fn main() { let _c = yscv_kernels::conv2d_nhwc(&conv_input, &conv_kernel, None, 1, 1).unwrap(); }); + let yolo_stem_input = Tensor::zeros(vec![1, 640, 640, 3]).unwrap(); + let yolo_stem_kernel = Tensor::zeros(vec![3, 3, 3, 32]).unwrap(); + bench_n("conv2d_nhwc_yolo_640x640_3x3_32", 100, || { + let _c = + yscv_kernels::conv2d_nhwc(&yolo_stem_input, &yolo_stem_kernel, None, 1, 1).unwrap(); + }); + + let yolo_p3_input = Tensor::zeros(vec![1, 80, 80, 128]).unwrap(); + let yolo_p3_kernel = Tensor::zeros(vec![3, 3, 128, 256]).unwrap(); + bench_n("conv2d_nhwc_yolo_p3_80x80_3x3_256", 100, || { + let _c = yscv_kernels::conv2d_nhwc(&yolo_p3_input, &yolo_p3_kernel, None, 1, 1).unwrap(); + }); + // --- Activations (vs PyTorch) --- bench_n("sigmoid_1M", 100, || { let _s = yscv_kernels::sigmoid(&a1m); diff --git a/crates/yscv-kernels/benches/kernels_cpu_ops.rs b/crates/yscv-kernels/benches/kernels_cpu_ops.rs index 5a0ef59..94cd420 100644 --- a/crates/yscv-kernels/benches/kernels_cpu_ops.rs +++ b/crates/yscv-kernels/benches/kernels_cpu_ops.rs @@ -4,9 +4,9 @@ use criterion::{Criterion, black_box, criterion_group, criterion_main}; use yscv_kernels::{ Backend, BatchNorm2dParams, LayerNormLastDimParams, ParallelElementwiseConfig, ParallelMatmulConfig, SeparableConv2dParams, ThreadedCpuBackend, ThreadedCpuBackendConfig, add, - avg_pool2d_nhwc, batch_norm2d_nhwc, conv2d_nhwc, depthwise_conv2d_nhwc, layer_norm_last_dim, - log_softmax_last_dim, logsumexp_last_dim, matmul_2d, matmul_2d_sequential, max_pool2d_nhwc, - relu, separable_conv2d_nhwc, sigmoid, softmax_last_dim, + avg_pool2d_nhwc, batch_norm2d_nhwc, conv2d_nhwc, conv2d_nhwc_padded, depthwise_conv2d_nhwc, + layer_norm_last_dim, log_softmax_last_dim, logsumexp_last_dim, matmul_2d, matmul_2d_sequential, + max_pool2d_nhwc, relu, separable_conv2d_nhwc, sigmoid, softmax_last_dim, }; use yscv_tensor::Tensor; @@ -282,6 +282,55 @@ fn bench_conv_modes(c: &mut Criterion) { group.finish(); } +fn bench_winograd_conv_modes(c: &mut Criterion) { + let small_input = build_tensor(&[1, 32, 32, 8], 0.43); + let small_kernel = build_tensor(&[3, 3, 8, 16], 0.87); + let small_bias = build_tensor(&[16], 0.21); + + let yolo_p3_input = build_tensor(&[1, 80, 80, 128], 0.49); + let yolo_p3_kernel = build_tensor(&[3, 3, 128, 256], 0.83); + let yolo_p3_bias = build_tensor(&[256], 0.27); + + let mut group = c.benchmark_group("kernels_winograd_conv_modes"); + group.bench_function("winograd_3x3_s1_32x32x8_to16", |b| { + b.iter(|| { + let out = conv2d_nhwc_padded( + black_box(&small_input), + black_box(&small_kernel), + Some(black_box(&small_bias)), + 1, + 1, + 0, + 0, + 0, + 0, + yscv_kernels::Activation::Relu, + ) + .expect("winograd conv2d small"); + black_box(out); + }); + }); + group.bench_function("winograd_3x3_s1_yolo_p3_80x80x128_to256", |b| { + b.iter(|| { + let out = conv2d_nhwc_padded( + black_box(&yolo_p3_input), + black_box(&yolo_p3_kernel), + Some(black_box(&yolo_p3_bias)), + 1, + 1, + 0, + 0, + 0, + 0, + yscv_kernels::Activation::Relu, + ) + .expect("winograd conv2d yolo p3"); + black_box(out); + }); + }); + group.finish(); +} + fn bench_depthwise_conv_modes(c: &mut Criterion) { let input = build_tensor(&[1, 32, 32, 8], 0.28); let kernel = build_tensor(&[3, 3, 8, 2], 0.74); @@ -593,6 +642,7 @@ criterion_group!( bench_elementwise_modes, bench_pool_modes, bench_conv_modes, + bench_winograd_conv_modes, bench_depthwise_conv_modes, bench_separable_conv_modes, bench_batch_norm_modes, diff --git a/crates/yscv-kernels/src/ops/conv/gemm_conv.rs b/crates/yscv-kernels/src/ops/conv/gemm_conv.rs index 0eae647..f969ecf 100644 --- a/crates/yscv-kernels/src/ops/conv/gemm_conv.rs +++ b/crates/yscv-kernels/src/ops/conv/gemm_conv.rs @@ -372,6 +372,37 @@ fn winograd_transform_weights_f32(kernel: &[f32], c_in: usize, c_out: usize) -> #[cfg(all(feature = "blas", not(target_os = "macos")))] #[inline] fn winograd_input_tile(d: &[f32; 16], out: &mut [f32; 16]) { + #[cfg(target_arch = "aarch64")] + { + // SAFETY: `d` and `out` reference 16 contiguous f32 values; aarch64 implies NEON. + #[allow(unsafe_code)] + unsafe { + winograd_input_tile_neon(d, out); + } + } + + #[cfg(target_arch = "x86_64")] + { + // SAFETY: `d` and `out` reference 16 contiguous f32 values; x86_64 implies SSE. + #[allow(unsafe_code)] + unsafe { + winograd_input_tile_sse(d, out); + } + } + + #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))] + { + winograd_input_tile_scalar(d, out); + } +} + +#[cfg(all( + feature = "blas", + not(target_os = "macos"), + any(test, not(any(target_arch = "aarch64", target_arch = "x86_64"))) +))] +#[inline] +fn winograd_input_tile_scalar(d: &[f32; 16], out: &mut [f32; 16]) { // B^T * d → 4×4 intermediate (rows transformed) let mut bd = [0.0f32; 16]; for col in 0..4 { @@ -390,6 +421,101 @@ fn winograd_input_tile(d: &[f32; 16], out: &mut [f32; 16]) { } } +#[cfg(all(feature = "blas", not(target_os = "macos"), target_arch = "aarch64"))] +#[inline] +#[allow(unsafe_code, unsafe_op_in_unsafe_fn)] +unsafe fn transpose4x4_f32_neon( + r0: std::arch::aarch64::float32x4_t, + r1: std::arch::aarch64::float32x4_t, + r2: std::arch::aarch64::float32x4_t, + r3: std::arch::aarch64::float32x4_t, +) -> ( + std::arch::aarch64::float32x4_t, + std::arch::aarch64::float32x4_t, + std::arch::aarch64::float32x4_t, + std::arch::aarch64::float32x4_t, +) { + use std::arch::aarch64::{vcombine_f32, vget_high_f32, vget_low_f32, vtrnq_f32}; + + let t0 = vtrnq_f32(r0, r1); + let t1 = vtrnq_f32(r2, r3); + + let c0 = vcombine_f32(vget_low_f32(t0.0), vget_low_f32(t1.0)); + let c1 = vcombine_f32(vget_low_f32(t0.1), vget_low_f32(t1.1)); + let c2 = vcombine_f32(vget_high_f32(t0.0), vget_high_f32(t1.0)); + let c3 = vcombine_f32(vget_high_f32(t0.1), vget_high_f32(t1.1)); + + (c0, c1, c2, c3) +} + +#[cfg(all(feature = "blas", not(target_os = "macos"), target_arch = "aarch64"))] +#[inline] +#[allow(unsafe_code, unsafe_op_in_unsafe_fn)] +unsafe fn winograd_input_tile_neon(d: &[f32; 16], out: &mut [f32; 16]) { + use std::arch::aarch64::{vaddq_f32, vld1q_f32, vst1q_f32, vsubq_f32}; + + let r0 = vld1q_f32(d.as_ptr().add(0)); + let r1 = vld1q_f32(d.as_ptr().add(4)); + let r2 = vld1q_f32(d.as_ptr().add(8)); + let r3 = vld1q_f32(d.as_ptr().add(12)); + + // B^T * d → 4×4 intermediate (rows transformed) + let b0 = vsubq_f32(r0, r2); + let b1 = vaddq_f32(r1, r2); + let b2 = vsubq_f32(r2, r1); + let b3 = vsubq_f32(r1, r3); + + let (c0, c1, c2, c3) = transpose4x4_f32_neon(b0, b1, b2, b3); + + // (B^T * d) * B → 4×4 output (columns transformed) + let out0 = vsubq_f32(c0, c2); + let out1 = vaddq_f32(c1, c2); + let out2 = vsubq_f32(c2, c1); + let out3 = vsubq_f32(c1, c3); + + let (r_out0, r_out1, r_out2, r_out3) = transpose4x4_f32_neon(out0, out1, out2, out3); + + vst1q_f32(out.as_mut_ptr().add(0), r_out0); + vst1q_f32(out.as_mut_ptr().add(4), r_out1); + vst1q_f32(out.as_mut_ptr().add(8), r_out2); + vst1q_f32(out.as_mut_ptr().add(12), r_out3); +} + +#[cfg(all(feature = "blas", not(target_os = "macos"), target_arch = "x86_64"))] +#[inline] +#[allow(unsafe_code, unsafe_op_in_unsafe_fn)] +unsafe fn winograd_input_tile_sse(d: &[f32; 16], out: &mut [f32; 16]) { + use std::arch::x86_64::{ + _MM_TRANSPOSE4_PS, _mm_add_ps, _mm_loadu_ps, _mm_storeu_ps, _mm_sub_ps, + }; + + let r0 = _mm_loadu_ps(d.as_ptr().add(0)); + let r1 = _mm_loadu_ps(d.as_ptr().add(4)); + let r2 = _mm_loadu_ps(d.as_ptr().add(8)); + let r3 = _mm_loadu_ps(d.as_ptr().add(12)); + + // B^T * d → 4×4 intermediate (rows transformed) + let mut b0 = _mm_sub_ps(r0, r2); + let mut b1 = _mm_add_ps(r1, r2); + let mut b2 = _mm_sub_ps(r2, r1); + let mut b3 = _mm_sub_ps(r1, r3); + + _MM_TRANSPOSE4_PS(&mut b0, &mut b1, &mut b2, &mut b3); + + // (B^T * d) * B → 4×4 output (columns transformed) + let mut out0 = _mm_sub_ps(b0, b2); + let mut out1 = _mm_add_ps(b1, b2); + let mut out2 = _mm_sub_ps(b2, b1); + let mut out3 = _mm_sub_ps(b1, b3); + + _MM_TRANSPOSE4_PS(&mut out0, &mut out1, &mut out2, &mut out3); + + _mm_storeu_ps(out.as_mut_ptr().add(0), out0); + _mm_storeu_ps(out.as_mut_ptr().add(4), out1); + _mm_storeu_ps(out.as_mut_ptr().add(8), out2); + _mm_storeu_ps(out.as_mut_ptr().add(12), out3); +} + /// Winograd output transform: A^T * m * A, yielding 2×2 output from 4×4 product. /// /// A^T = [[1,1,1,0],[0,1,-1,-1]] @@ -958,3 +1084,50 @@ fn im2col_nhwc_padded_tile( } } } + +#[cfg(all( + test, + feature = "blas", + not(target_os = "macos"), + any(target_arch = "aarch64", target_arch = "x86_64") +))] +mod tests { + use super::*; + + #[test] + fn winograd_input_tile_vectorized_matches_scalar() { + let cases = [ + [0.0f32; 16], + core::array::from_fn(|i| i as f32), + core::array::from_fn(|i| i as f32 * 0.25 - 2.0), + core::array::from_fn(|i| if i % 2 == 0 { i as f32 } else { -(i as f32) }), + ]; + + for input in cases { + let mut scalar_output = [0.0f32; 16]; + let mut vectorized_output = [0.0f32; 16]; + + winograd_input_tile_scalar(&input, &mut scalar_output); + + #[cfg(target_arch = "aarch64")] + { + // SAFETY: test arrays are exactly 16 contiguous f32 values. + #[allow(unsafe_code)] + unsafe { + winograd_input_tile_neon(&input, &mut vectorized_output); + } + } + + #[cfg(target_arch = "x86_64")] + { + // SAFETY: test arrays are exactly 16 contiguous f32 values. + #[allow(unsafe_code)] + unsafe { + winograd_input_tile_sse(&input, &mut vectorized_output); + } + } + + assert_eq!(scalar_output, vectorized_output); + } + } +}