128 changes: 88 additions & 40 deletions crates/core_arch/src/x86/avx2.rs
@@ -2778,8 +2778,17 @@ pub fn _mm256_sign_epi8(a: __m256i, b: __m256i) -> __m256i {
 #[target_feature(enable = "avx2")]
 #[cfg_attr(test, assert_instr(vpsllw))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
-pub fn _mm256_sll_epi16(a: __m256i, count: __m128i) -> __m256i {
-    unsafe { transmute(psllw(a.as_i16x16(), count.as_i16x8())) }
+#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
+pub const fn _mm256_sll_epi16(a: __m256i, count: __m128i) -> __m256i {
+    let shift = count.as_u64x2().as_array()[0];
+    unsafe {
+        if shift >= 16 {
+            _mm256_setzero_si256()
+        } else {
+            // SAFETY: We checked above that the shift is less than 16 bits.
+            simd_shl(a.as_u16x16(), u16x16::splat(shift as u16)).as_m256i()
+        }
+    }
 }
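Note (not part of the diff): the `sll` family reads its shift count from the low 64 bits of `count`, and any count at or above the element width zeroes every lane. A minimal runtime sketch of that behavior, with illustrative values, assuming AVX2 has been detected at runtime:

use core::arch::x86_64::*;

fn sll_demo() {
    // SAFETY: the caller must verify AVX2 support first, e.g. with
    // is_x86_feature_detected!("avx2").
    unsafe {
        let a = _mm256_set1_epi16(0x00FF);
        let in_range = _mm_set_epi64x(0, 4); // shift left by 4
        let too_big = _mm_set_epi64x(0, 99); // count >= 16: all lanes become 0
        assert_eq!(_mm256_extract_epi16::<0>(_mm256_sll_epi16(a, in_range)), 0x0FF0);
        assert_eq!(_mm256_extract_epi16::<0>(_mm256_sll_epi16(a, too_big)), 0);
    }
}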

/// Shifts packed 32-bit integers in `a` left by `count` while
/// shifting in zeros, and returns the result
@@ -2790,8 +2799,17 @@ pub fn _mm256_sll_epi16(a: __m256i, count: __m128i) -> __m256i {
 #[target_feature(enable = "avx2")]
 #[cfg_attr(test, assert_instr(vpslld))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
-pub fn _mm256_sll_epi32(a: __m256i, count: __m128i) -> __m256i {
-    unsafe { transmute(pslld(a.as_i32x8(), count.as_i32x4())) }
+#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
+pub const fn _mm256_sll_epi32(a: __m256i, count: __m128i) -> __m256i {
+    let shift = count.as_u64x2().as_array()[0];
+    unsafe {
+        if shift >= 32 {
+            _mm256_setzero_si256()
+        } else {
+            // SAFETY: We checked above that the shift is less than 32 bits.
+            simd_shl(a.as_u32x8(), u32x8::splat(shift as u32)).as_m256i()
+        }
+    }
 }

/// Shifts packed 64-bit integers in `a` left by `count` while
/// shifting in zeros, and returns the result
@@ -2802,8 +2820,17 @@ pub fn _mm256_sll_epi32(a: __m256i, count: __m128i) -> __m256i {
 #[target_feature(enable = "avx2")]
 #[cfg_attr(test, assert_instr(vpsllq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
-pub fn _mm256_sll_epi64(a: __m256i, count: __m128i) -> __m256i {
-    unsafe { transmute(psllq(a.as_i64x4(), count.as_i64x2())) }
+#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
+pub const fn _mm256_sll_epi64(a: __m256i, count: __m128i) -> __m256i {
+    let shift = count.as_u64x2().as_array()[0];
+    unsafe {
+        if shift >= 64 {
+            _mm256_setzero_si256()
+        } else {
+            // SAFETY: We checked above that the shift is less than 64 bits.
+            simd_shl(a.as_u64x4(), u64x4::splat(shift as u64)).as_m256i()
+        }
+    }
 }

/// Shifts packed 16-bit integers in `a` left by `IMM8` while
/// shifting in zeros, return the results;
@@ -3030,8 +3057,13 @@ pub const fn _mm256_sllv_epi64(a: __m256i, count: __m256i) -> __m256i {
 #[target_feature(enable = "avx2")]
 #[cfg_attr(test, assert_instr(vpsraw))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
-pub fn _mm256_sra_epi16(a: __m256i, count: __m128i) -> __m256i {
-    unsafe { transmute(psraw(a.as_i16x16(), count.as_i16x8())) }
+#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
+pub const fn _mm256_sra_epi16(a: __m256i, count: __m128i) -> __m256i {
+    let shift = count.as_u64x2().as_array()[0].min(15);
+    unsafe {
+        // SAFETY: The shift was clamped above to at most 15, less than the
+        // 16-bit element width.
+        simd_shr(a.as_i16x16(), i16x16::splat(shift as i16)).as_m256i()
+    }
 }
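Note (not part of the diff): unlike the logical shifts above, the arithmetic shifts clamp the count instead of zeroing, because shifting in sign bits means an out-of-range count behaves like a shift by `width - 1`: every lane collapses to 0 or -1. An illustrative sketch, assuming AVX2 has been detected at runtime:

use core::arch::x86_64::*;

fn sra_demo() {
    // SAFETY: requires AVX2; check with is_x86_feature_detected!("avx2").
    unsafe {
        let a = _mm256_set1_epi16(-8);
        let too_big = _mm_set_epi64x(0, 1000); // clamped to 15
        // -8 >> 15 (arithmetic) is -1; _mm256_extract_epi16 zero-extends,
        // so the lane reads back as 0xFFFF.
        assert_eq!(_mm256_extract_epi16::<0>(_mm256_sra_epi16(a, too_big)), 0xFFFF);
    }
}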

/// Shifts packed 32-bit integers in `a` right by `count` while
/// shifting in sign bits.
@@ -3042,8 +3074,13 @@ pub fn _mm256_sra_epi16(a: __m256i, count: __m128i) -> __m256i {
 #[target_feature(enable = "avx2")]
 #[cfg_attr(test, assert_instr(vpsrad))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
-pub fn _mm256_sra_epi32(a: __m256i, count: __m128i) -> __m256i {
-    unsafe { transmute(psrad(a.as_i32x8(), count.as_i32x4())) }
+#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
+pub const fn _mm256_sra_epi32(a: __m256i, count: __m128i) -> __m256i {
+    let shift = count.as_u64x2().as_array()[0].min(31);
+    unsafe {
+        // SAFETY: The shift was clamped above to at most 31, less than the
+        // 32-bit element width.
+        simd_shr(a.as_i32x8(), i32x8::splat(shift as i32)).as_m256i()
+    }
 }

/// Shifts packed 16-bit integers in `a` right by `IMM8` while
/// shifting in sign bits.
@@ -3197,8 +3234,17 @@ pub const fn _mm256_bsrli_epi128<const IMM8: i32>(a: __m256i) -> __m256i {
 #[target_feature(enable = "avx2")]
 #[cfg_attr(test, assert_instr(vpsrlw))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
-pub fn _mm256_srl_epi16(a: __m256i, count: __m128i) -> __m256i {
-    unsafe { transmute(psrlw(a.as_i16x16(), count.as_i16x8())) }
+#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
+pub const fn _mm256_srl_epi16(a: __m256i, count: __m128i) -> __m256i {
+    let shift = count.as_u64x2().as_array()[0];
+    unsafe {
+        if shift >= 16 {
+            _mm256_setzero_si256()
+        } else {
+            // SAFETY: We checked above that the shift is less than 16 bits.
+            simd_shr(a.as_u16x16(), u16x16::splat(shift as u16)).as_m256i()
+        }
+    }
 }

/// Shifts packed 32-bit integers in `a` right by `count` while shifting in
/// zeros, and returns the result
@@ -3209,8 +3255,17 @@ pub fn _mm256_srl_epi16(a: __m256i, count: __m128i) -> __m256i {
 #[target_feature(enable = "avx2")]
 #[cfg_attr(test, assert_instr(vpsrld))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
-pub fn _mm256_srl_epi32(a: __m256i, count: __m128i) -> __m256i {
-    unsafe { transmute(psrld(a.as_i32x8(), count.as_i32x4())) }
+#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
+pub const fn _mm256_srl_epi32(a: __m256i, count: __m128i) -> __m256i {
+    let shift = count.as_u64x2().as_array()[0];
+    unsafe {
+        if shift >= 32 {
+            _mm256_setzero_si256()
+        } else {
+            // SAFETY: We checked above that the shift is less than 32 bits.
+            simd_shr(a.as_u32x8(), u32x8::splat(shift as u32)).as_m256i()
+        }
+    }
 }

/// Shifts packed 64-bit integers in `a` right by `count` while shifting in
/// zeros, and returns the result
@@ -3221,8 +3276,17 @@ pub fn _mm256_srl_epi32(a: __m256i, count: __m128i) -> __m256i {
 #[target_feature(enable = "avx2")]
 #[cfg_attr(test, assert_instr(vpsrlq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
-pub fn _mm256_srl_epi64(a: __m256i, count: __m128i) -> __m256i {
-    unsafe { transmute(psrlq(a.as_i64x4(), count.as_i64x2())) }
+#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
+pub const fn _mm256_srl_epi64(a: __m256i, count: __m128i) -> __m256i {
+    let shift = count.as_u64x2().as_array()[0];
+    unsafe {
+        if shift >= 64 {
+            _mm256_setzero_si256()
+        } else {
+            // SAFETY: We checked above that the shift is less than 64 bits.
+            simd_shr(a.as_u64x4(), u64x4::splat(shift as u64)).as_m256i()
+        }
+    }
 }
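For reference, the per-lane semantics shared by the three `srl` functions above reduce to this scalar model (hypothetical helper, not part of the diff):

// Hypothetical scalar model of VPSRLW/VPSRLD/VPSRLQ: `lane` is assumed to
// already fit within `width` bits.
fn srl_scalar(lane: u64, count: u64, width: u32) -> u64 {
    if count >= u64::from(width) {
        0 // out-of-range counts clear the lane
    } else {
        lane >> count
    }
}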

/// Shifts packed 16-bit integers in `a` right by `IMM8` while shifting in
/// zeros
@@ -3919,22 +3983,6 @@ unsafe extern "C" {
     fn psignw(a: i16x16, b: i16x16) -> i16x16;
     #[link_name = "llvm.x86.avx2.psign.d"]
     fn psignd(a: i32x8, b: i32x8) -> i32x8;
-    #[link_name = "llvm.x86.avx2.psll.w"]
-    fn psllw(a: i16x16, count: i16x8) -> i16x16;
-    #[link_name = "llvm.x86.avx2.psll.d"]
-    fn pslld(a: i32x8, count: i32x4) -> i32x8;
-    #[link_name = "llvm.x86.avx2.psll.q"]
-    fn psllq(a: i64x4, count: i64x2) -> i64x4;
-    #[link_name = "llvm.x86.avx2.psra.w"]
-    fn psraw(a: i16x16, count: i16x8) -> i16x16;
-    #[link_name = "llvm.x86.avx2.psra.d"]
-    fn psrad(a: i32x8, count: i32x4) -> i32x8;
-    #[link_name = "llvm.x86.avx2.psrl.w"]
-    fn psrlw(a: i16x16, count: i16x8) -> i16x16;
-    #[link_name = "llvm.x86.avx2.psrl.d"]
-    fn psrld(a: i32x8, count: i32x4) -> i32x8;
-    #[link_name = "llvm.x86.avx2.psrl.q"]
-    fn psrlq(a: i64x4, count: i64x2) -> i64x4;
     #[link_name = "llvm.x86.avx2.pshuf.b"]
     fn pshufb(a: u8x32, b: u8x32) -> u8x32;
     #[link_name = "llvm.x86.avx2.permd"]
@@ -5184,23 +5232,23 @@ mod tests {
     }
 
     #[simd_test(enable = "avx2")]
-    fn test_mm256_sll_epi16() {
+    const fn test_mm256_sll_epi16() {
         let a = _mm256_set1_epi16(0xFF);
         let b = _mm_insert_epi16::<0>(_mm_set1_epi16(0), 4);
         let r = _mm256_sll_epi16(a, b);
         assert_eq_m256i(r, _mm256_set1_epi16(0xFF0));
     }
 
     #[simd_test(enable = "avx2")]
-    fn test_mm256_sll_epi32() {
+    const fn test_mm256_sll_epi32() {
         let a = _mm256_set1_epi32(0xFFFF);
         let b = _mm_insert_epi32::<0>(_mm_set1_epi32(0), 4);
         let r = _mm256_sll_epi32(a, b);
         assert_eq_m256i(r, _mm256_set1_epi32(0xFFFF0));
     }
 
     #[simd_test(enable = "avx2")]
-    fn test_mm256_sll_epi64() {
+    const fn test_mm256_sll_epi64() {
         let a = _mm256_set1_epi64x(0xFFFFFFFF);
         let b = _mm_insert_epi64::<0>(_mm_set1_epi64x(0), 4);
         let r = _mm256_sll_epi64(a, b);
@@ -5275,15 +5323,15 @@ mod tests {
     }
 
     #[simd_test(enable = "avx2")]
-    fn test_mm256_sra_epi16() {
+    const fn test_mm256_sra_epi16() {
         let a = _mm256_set1_epi16(-1);
         let b = _mm_setr_epi16(1, 0, 0, 0, 0, 0, 0, 0);
         let r = _mm256_sra_epi16(a, b);
         assert_eq_m256i(r, _mm256_set1_epi16(-1));
     }
 
     #[simd_test(enable = "avx2")]
-    fn test_mm256_sra_epi32() {
+    const fn test_mm256_sra_epi32() {
         let a = _mm256_set1_epi32(-1);
         let b = _mm_insert_epi32::<0>(_mm_set1_epi32(0), 1);
         let r = _mm256_sra_epi32(a, b);
@@ -5345,23 +5393,23 @@ mod tests {
     }
 
     #[simd_test(enable = "avx2")]
-    fn test_mm256_srl_epi16() {
+    const fn test_mm256_srl_epi16() {
         let a = _mm256_set1_epi16(0xFF);
         let b = _mm_insert_epi16::<0>(_mm_set1_epi16(0), 4);
         let r = _mm256_srl_epi16(a, b);
         assert_eq_m256i(r, _mm256_set1_epi16(0xF));
     }
 
     #[simd_test(enable = "avx2")]
-    fn test_mm256_srl_epi32() {
+    const fn test_mm256_srl_epi32() {
         let a = _mm256_set1_epi32(0xFFFF);
         let b = _mm_insert_epi32::<0>(_mm_set1_epi32(0), 4);
         let r = _mm256_srl_epi32(a, b);
         assert_eq_m256i(r, _mm256_set1_epi32(0xFFF));
     }
 
     #[simd_test(enable = "avx2")]
-    fn test_mm256_srl_epi64() {
+    const fn test_mm256_srl_epi64() {
         let a = _mm256_set1_epi64x(0xFFFFFFFF);
         let b = _mm_setr_epi64x(4, 0);
         let r = _mm256_srl_epi64(a, b);
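One edge case the updated tests could also exercise, sketched here as a hypothetical addition in the same style as the tests above: counts at or above the element width must yield all-zero results.

#[simd_test(enable = "avx2")]
const fn test_mm256_srl_epi16_large_count() {
    // Hypothetical test, not in the diff: a count >= 16 clears every lane.
    let a = _mm256_set1_epi16(0xFF);
    let b = _mm_insert_epi16::<0>(_mm_set1_epi16(0), 16);
    let r = _mm256_srl_epi16(a, b);
    assert_eq_m256i(r, _mm256_set1_epi16(0));
}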