Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions apps/bench/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,19 @@ fn main() {
let _c = yscv_kernels::conv2d_nhwc(&conv_input, &conv_kernel, None, 1, 1).unwrap();
});

// Large-image convolution benchmark. Tensor dims mirror the label below:
// input 1x640x640x3, kernel 3x3x3x32.
// NOTE(review): kernel dims are presumably [kh, kw, c_in, c_out] and the trailing
// `None, 1, 1` presumably mean "no bias" plus two stride values — confirm against
// the `conv2d_nhwc` signature.
let yolo_stem_input = Tensor::zeros(vec![1, 640, 640, 3]).unwrap();
let yolo_stem_kernel = Tensor::zeros(vec![3, 3, 3, 32]).unwrap();
// Both operands are created once, outside the closure, so only the convolution
// call itself sits inside the benchmarked closure.
bench_n("conv2d_nhwc_yolo_640x640_3x3_32", 100, || {
// The result is bound to `_c` and dropped at the end of the closure body.
let _c =
yscv_kernels::conv2d_nhwc(&yolo_stem_input, &yolo_stem_kernel, None, 1, 1).unwrap();
});

// Channel-heavy convolution benchmark. Tensor dims mirror the label below:
// input 1x80x80x128, kernel 3x3x128x256.
// NOTE(review): kernel dims are presumably [kh, kw, c_in, c_out] and the trailing
// `None, 1, 1` presumably mean "no bias" plus two stride values — confirm against
// the `conv2d_nhwc` signature.
let yolo_p3_input = Tensor::zeros(vec![1, 80, 80, 128]).unwrap();
let yolo_p3_kernel = Tensor::zeros(vec![3, 3, 128, 256]).unwrap();
// Operands are created once, outside the closure; only the convolution call is
// inside the benchmarked closure.
bench_n("conv2d_nhwc_yolo_p3_80x80_3x3_256", 100, || {
// The result is bound to `_c` and dropped at the end of the closure body.
let _c = yscv_kernels::conv2d_nhwc(&yolo_p3_input, &yolo_p3_kernel, None, 1, 1).unwrap();
});

// --- Activations (vs PyTorch) ---
bench_n("sigmoid_1M", 100, || {
let _s = yscv_kernels::sigmoid(&a1m);
Expand Down
40 changes: 40 additions & 0 deletions crates/yscv-kernels/benches/kernels_cpu_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,45 @@ fn bench_conv_modes(c: &mut Criterion) {
group.finish();
}

/// Registers the `kernels_winograd_conv_modes` Criterion group.
///
/// Two `conv2d_nhwc` calls are measured, each with a 3x3 kernel, a bias tensor and
/// `1, 1` as the trailing arguments: a small 1x32x32x8 input with 16 output channels,
/// and a 1x80x80x128 input with 256 output channels (labelled "yolo_p3").
///
/// NOTE(review): the labels say "winograd", but whether `conv2d_nhwc` actually takes
/// a Winograd path for these shapes is decided inside the kernel crate and is not
/// visible from this function.
fn bench_winograd_conv_modes(c: &mut Criterion) {
    // All operands are built up front, before the group is opened.
    let small_input = build_tensor(&[1, 32, 32, 8], 0.43);
    let small_kernel = build_tensor(&[3, 3, 8, 16], 0.87);
    let small_bias = build_tensor(&[16], 0.21);

    let yolo_p3_input = build_tensor(&[1, 80, 80, 128], 0.49);
    let yolo_p3_kernel = build_tensor(&[3, 3, 128, 256], 0.83);
    let yolo_p3_bias = build_tensor(&[256], 0.27);

    // (benchmark id, input, kernel, bias, panic message if the convolution fails)
    let cases = [
        (
            "winograd_3x3_s1_32x32x8_to16",
            &small_input,
            &small_kernel,
            &small_bias,
            "winograd conv2d small",
        ),
        (
            "winograd_3x3_s1_yolo_p3_80x80x128_to256",
            &yolo_p3_input,
            &yolo_p3_kernel,
            &yolo_p3_bias,
            "winograd conv2d yolo p3",
        ),
    ];

    let mut group = c.benchmark_group("kernels_winograd_conv_modes");
    for (id, input, kernel, bias, failure) in cases {
        group.bench_function(id, |bencher| {
            bencher.iter(|| {
                // `black_box` on operands and result keeps the optimizer from
                // folding the call away.
                let result = conv2d_nhwc(
                    black_box(input),
                    black_box(kernel),
                    Some(black_box(bias)),
                    1,
                    1,
                )
                .expect(failure);
                black_box(result);
            });
        });
    }
    group.finish();
}

fn bench_depthwise_conv_modes(c: &mut Criterion) {
let input = build_tensor(&[1, 32, 32, 8], 0.28);
let kernel = build_tensor(&[3, 3, 8, 2], 0.74);
Expand Down Expand Up @@ -593,6 +632,7 @@ criterion_group!(
bench_elementwise_modes,
bench_pool_modes,
bench_conv_modes,
bench_winograd_conv_modes,
bench_depthwise_conv_modes,
bench_separable_conv_modes,
bench_batch_norm_modes,
Expand Down
173 changes: 173 additions & 0 deletions crates/yscv-kernels/src/ops/conv/gemm_conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,37 @@ fn winograd_transform_weights_f32(kernel: &[f32], c_in: usize, c_out: usize) ->
/// Applies the Winograd input transform to one 16-element tile, writing the
/// transformed tile into `out`.
///
/// The implementation is selected at compile time: an SSE kernel on x86_64, a NEON
/// kernel on aarch64, and the scalar routine on every other architecture. Exactly one
/// of the cfg-gated arms below survives compilation.
#[cfg(all(feature = "blas", not(target_os = "macos")))]
#[inline]
fn winograd_input_tile(d: &[f32; 16], out: &mut [f32; 16]) {
    // Portable path for targets without a hand-written SIMD kernel.
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    winograd_input_tile_scalar(d, out);

    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: both arguments are references to exactly 16 contiguous f32 values,
        // and SSE is part of the x86_64 baseline.
        #[allow(unsafe_code)]
        unsafe {
            winograd_input_tile_sse(d, out);
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: both arguments are references to exactly 16 contiguous f32 values,
        // and NEON is part of the aarch64 baseline.
        #[allow(unsafe_code)]
        unsafe {
            winograd_input_tile_neon(d, out);
        }
    }
}

#[cfg(all(
feature = "blas",
not(target_os = "macos"),
any(test, not(any(target_arch = "aarch64", target_arch = "x86_64")))
))]
#[inline]
fn winograd_input_tile_scalar(d: &[f32; 16], out: &mut [f32; 16]) {
// B^T * d → 4×4 intermediate (rows transformed)
let mut bd = [0.0f32; 16];
for col in 0..4 {
Expand All @@ -390,6 +421,101 @@ fn winograd_input_tile(d: &[f32; 16], out: &mut [f32; 16]) {
}
}

/// Transposes a 4×4 f32 matrix held as four NEON row vectors and returns its four
/// columns, in order, as vectors.
///
/// # Safety
///
/// The body only shuffles register values and touches no memory. It is `unsafe`
/// because it calls NEON intrinsics; callers in this file rely on NEON being part of
/// the aarch64 baseline.
#[cfg(all(feature = "blas", not(target_os = "macos"), target_arch = "aarch64"))]
#[inline]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn transpose4x4_f32_neon(
    r0: std::arch::aarch64::float32x4_t,
    r1: std::arch::aarch64::float32x4_t,
    r2: std::arch::aarch64::float32x4_t,
    r3: std::arch::aarch64::float32x4_t,
) -> (
    std::arch::aarch64::float32x4_t,
    std::arch::aarch64::float32x4_t,
    std::arch::aarch64::float32x4_t,
    std::arch::aarch64::float32x4_t,
) {
    use std::arch::aarch64::{vcombine_f32, vget_high_f32, vget_low_f32, vtrnq_f32};

    // Step 1: transpose the 2×2 lane blocks within each pair of rows.
    //   top.0 = [r0[0], r1[0], r0[2], r1[2]]    top.1 = [r0[1], r1[1], r0[3], r1[3]]
    //   bottom.0 / bottom.1 hold the same pattern for r2 and r3.
    let top = vtrnq_f32(r0, r1);
    let bottom = vtrnq_f32(r2, r3);

    // Step 2: stitch matching halves together; low halves give columns 0 and 1,
    // high halves give columns 2 and 3.
    (
        vcombine_f32(vget_low_f32(top.0), vget_low_f32(bottom.0)),
        vcombine_f32(vget_low_f32(top.1), vget_low_f32(bottom.1)),
        vcombine_f32(vget_high_f32(top.0), vget_high_f32(bottom.0)),
        vcombine_f32(vget_high_f32(top.1), vget_high_f32(bottom.1)),
    )
}

/// NEON version of the Winograd input transform (`B^T * d * B`) for one 4×4 tile.
///
/// `d` is read as four rows of four floats (offsets 0, 4, 8 and 12) and `out` is
/// written in the same layout. Each pass combines four vectors `(v0, v1, v2, v3)` into
/// `(v0 - v2, v1 + v2, v2 - v1, v1 - v3)`; a transpose between the passes lets the
/// same lane-wise arithmetic act on rows first and on columns second.
///
/// # Safety
///
/// All loads and stores stay inside the 16-element arrays behind `d` and `out`. The
/// function is `unsafe` because it uses raw-pointer NEON load/store intrinsics;
/// callers in this file rely on NEON being part of the aarch64 baseline.
#[cfg(all(feature = "blas", not(target_os = "macos"), target_arch = "aarch64"))]
#[inline]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn winograd_input_tile_neon(d: &[f32; 16], out: &mut [f32; 16]) {
    use std::arch::aarch64::{vaddq_f32, vld1q_f32, vst1q_f32, vsubq_f32};

    let src = d.as_ptr();
    let (row0, row1, row2, row3) = (
        vld1q_f32(src),
        vld1q_f32(src.add(4)),
        vld1q_f32(src.add(8)),
        vld1q_f32(src.add(12)),
    );

    // Pass 1 — B^T * d: combine whole rows, then transpose so the columns of the
    // intermediate become vectors.
    let (col0, col1, col2, col3) = transpose4x4_f32_neon(
        vsubq_f32(row0, row2),
        vaddq_f32(row1, row2),
        vsubq_f32(row2, row1),
        vsubq_f32(row1, row3),
    );

    // Pass 2 — (B^T * d) * B: the same combination on the columns, then transpose
    // back to row order for the store.
    let (res0, res1, res2, res3) = transpose4x4_f32_neon(
        vsubq_f32(col0, col2),
        vaddq_f32(col1, col2),
        vsubq_f32(col2, col1),
        vsubq_f32(col1, col3),
    );

    let dst = out.as_mut_ptr();
    vst1q_f32(dst, res0);
    vst1q_f32(dst.add(4), res1);
    vst1q_f32(dst.add(8), res2);
    vst1q_f32(dst.add(12), res3);
}

/// SSE version of the Winograd input transform (`B^T * d * B`) for one 4×4 tile.
///
/// `d` is read as four rows of four floats (offsets 0, 4, 8 and 12) using unaligned
/// loads, and `out` is written in the same layout using unaligned stores. Each pass
/// combines four vectors `(v0, v1, v2, v3)` into `(v0 - v2, v1 + v2, v2 - v1, v1 - v3)`;
/// an in-place transpose between the passes lets the same lane-wise arithmetic act on
/// rows first and on columns second.
///
/// # Safety
///
/// All loads and stores stay inside the 16-element arrays behind `d` and `out`. The
/// function is `unsafe` because it uses raw-pointer SSE load/store intrinsics;
/// callers in this file rely on SSE being part of the x86_64 baseline.
#[cfg(all(feature = "blas", not(target_os = "macos"), target_arch = "x86_64"))]
#[inline]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn winograd_input_tile_sse(d: &[f32; 16], out: &mut [f32; 16]) {
    use std::arch::x86_64::{
        _MM_TRANSPOSE4_PS, _mm_add_ps, _mm_loadu_ps, _mm_storeu_ps, _mm_sub_ps,
    };

    let src = d.as_ptr();
    let (row0, row1, row2, row3) = (
        _mm_loadu_ps(src),
        _mm_loadu_ps(src.add(4)),
        _mm_loadu_ps(src.add(8)),
        _mm_loadu_ps(src.add(12)),
    );

    // Pass 1 — B^T * d: combine whole rows.
    let (mut mid0, mut mid1, mut mid2, mut mid3) = (
        _mm_sub_ps(row0, row2),
        _mm_add_ps(row1, row2),
        _mm_sub_ps(row2, row1),
        _mm_sub_ps(row1, row3),
    );

    // Transpose in place so each vector now holds one column of the intermediate.
    _MM_TRANSPOSE4_PS(&mut mid0, &mut mid1, &mut mid2, &mut mid3);

    // Pass 2 — (B^T * d) * B: the same combination applied to the columns.
    let (mut res0, mut res1, mut res2, mut res3) = (
        _mm_sub_ps(mid0, mid2),
        _mm_add_ps(mid1, mid2),
        _mm_sub_ps(mid2, mid1),
        _mm_sub_ps(mid1, mid3),
    );

    // Transpose back to row order before storing.
    _MM_TRANSPOSE4_PS(&mut res0, &mut res1, &mut res2, &mut res3);

    let dst = out.as_mut_ptr();
    _mm_storeu_ps(dst, res0);
    _mm_storeu_ps(dst.add(4), res1);
    _mm_storeu_ps(dst.add(8), res2);
    _mm_storeu_ps(dst.add(12), res3);
}

/// Winograd output transform: A^T * m * A, yielding 2×2 output from 4×4 product.
///
/// A^T = [[1,1,1,0],[0,1,-1,-1]]
Expand Down Expand Up @@ -958,3 +1084,50 @@ fn im2col_nhwc_padded_tile(
}
}
}

// Compiled only for test builds on the two architectures that have a SIMD kernel, so
// there is always a vectorized implementation to compare with the scalar reference.
#[cfg(all(
    test,
    feature = "blas",
    not(target_os = "macos"),
    any(target_arch = "aarch64", target_arch = "x86_64")
))]
mod tests {
    use super::*;

    /// Runs the SIMD kernel compiled for the current architecture on `input` and
    /// returns the transformed tile.
    fn vectorized_tile(input: &[f32; 16]) -> [f32; 16] {
        let mut output = [0.0f32; 16];

        #[cfg(target_arch = "aarch64")]
        {
            // SAFETY: test arrays are exactly 16 contiguous f32 values.
            #[allow(unsafe_code)]
            unsafe {
                winograd_input_tile_neon(input, &mut output);
            }
        }

        #[cfg(target_arch = "x86_64")]
        {
            // SAFETY: test arrays are exactly 16 contiguous f32 values.
            #[allow(unsafe_code)]
            unsafe {
                winograd_input_tile_sse(input, &mut output);
            }
        }

        output
    }

    /// The SIMD kernel and the public dispatcher must both reproduce the scalar
    /// reference exactly for every sample tile.
    #[test]
    fn winograd_input_tile_vectorized_matches_scalar() {
        // Sample tiles: all zeros, an ascending ramp, a fractional ramp that crosses
        // zero, and a ramp with alternating signs.
        let cases: [[f32; 16]; 4] = [
            [0.0f32; 16],
            core::array::from_fn(|i| i as f32),
            core::array::from_fn(|i| i as f32 * 0.25 - 2.0),
            core::array::from_fn(|i| if i % 2 == 0 { i as f32 } else { -(i as f32) }),
        ];

        for (case_idx, input) in cases.iter().enumerate() {
            let mut scalar_output = [0.0f32; 16];
            winograd_input_tile_scalar(input, &mut scalar_output);

            // Report the case index so a failure points at the offending tile.
            assert_eq!(
                scalar_output,
                vectorized_tile(input),
                "SIMD kernel diverges from the scalar reference for case {}",
                case_idx
            );

            // Also check the entry point production code calls, not just the raw kernel.
            let mut dispatched_output = [0.0f32; 16];
            winograd_input_tile(input, &mut dispatched_output);
            assert_eq!(
                scalar_output,
                dispatched_output,
                "winograd_input_tile diverges from the scalar reference for case {}",
                case_idx
            );
        }
    }
}
Loading