Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/avx2-32bit-half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct avx2_half_vector<int32_t> {
return _mm_mask_i32gather_epi32(
src, (const int *)base, index, mask, scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
}
Expand Down Expand Up @@ -237,7 +237,7 @@ struct avx2_half_vector<uint32_t> {
return _mm_mask_i32gather_epi32(
src, (const int *)base, index, mask, scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
}
Expand Down Expand Up @@ -421,7 +421,7 @@ struct avx2_half_vector<float> {
return _mm_mask_i32gather_ps(
src, (const float *)base, index, _mm_castsi128_ps(mask), scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
}
Expand Down
6 changes: 3 additions & 3 deletions src/avx2-64bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct avx2_vector<int64_t> {
return _mm256_mask_i32gather_epi64(
src, (const long long int *)base, index, mask, scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
}
Expand Down Expand Up @@ -269,7 +269,7 @@ struct avx2_vector<uint64_t> {
return _mm256_mask_i32gather_epi64(
src, (const long long int *)base, index, mask, scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
}
Expand Down Expand Up @@ -499,7 +499,7 @@ struct avx2_vector<double> {
scale);
;
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
}
Expand Down
12 changes: 6 additions & 6 deletions src/avx512-64bit-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct ymm_vector<float> {
{
return _mm256_mmask_i32gather_ps(src, mask, index, base, scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[7]],
arr[ind[6]],
Expand Down Expand Up @@ -293,7 +293,7 @@ struct ymm_vector<uint32_t> {
{
return _mm256_mmask_i32gather_epi32(src, mask, index, base, scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[7]],
arr[ind[6]],
Expand Down Expand Up @@ -481,7 +481,7 @@ struct ymm_vector<int32_t> {
{
return _mm256_mmask_i32gather_epi32(src, mask, index, base, scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[7]],
arr[ind[6]],
Expand Down Expand Up @@ -680,7 +680,7 @@ struct zmm_vector<int64_t> {
{
return _mm512_mask_i32gather_epi64(src, mask, index, base, scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[7]],
arr[ind[6]],
Expand Down Expand Up @@ -843,7 +843,7 @@ struct zmm_vector<uint64_t> {
{
return _mm512_mask_i32gather_epi64(src, mask, index, base, scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[7]],
arr[ind[6]],
Expand Down Expand Up @@ -1062,7 +1062,7 @@ struct zmm_vector<double> {
{
return _mm512_mask_i32gather_pd(src, mask, index, base, scale);
}
static reg_t i64gather(type_t *arr, arrsize_t *ind)
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
{
return set(arr[ind[7]],
arr[ind[6]],
Expand Down
60 changes: 28 additions & 32 deletions src/xss-common-argsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
#include <numeric>

template <typename T>
X86_SIMD_SORT_INLINE void std_argselect_withnan(
T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right)
X86_SIMD_SORT_INLINE void std_argselect_withnan(const T *arr,
arrsize_t *arg,
arrsize_t k,
arrsize_t left,
arrsize_t right)
{
std::nth_element(arg + left,
arg + k,
Expand All @@ -32,8 +35,10 @@ X86_SIMD_SORT_INLINE void std_argselect_withnan(

/* argsort using std::sort */
template <typename T>
X86_SIMD_SORT_INLINE void
std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
X86_SIMD_SORT_INLINE void std_argsort_withnan(const T *arr,
arrsize_t *arg,
arrsize_t left,
arrsize_t right)
{
std::sort(arg + left,
arg + right,
Expand All @@ -53,7 +58,7 @@ std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
/* argsort using std::sort */
template <typename T>
X86_SIMD_SORT_INLINE void
std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
std_argsort(const T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
{
std::sort(arg + left,
arg + right,
Expand Down Expand Up @@ -172,7 +177,7 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg,
* last element that is less than equal to the pivot.
*/
template <typename vtype, typename argtype, typename type_t>
X86_SIMD_SORT_INLINE arrsize_t argpartition(type_t *arr,
X86_SIMD_SORT_INLINE arrsize_t argpartition(const type_t *arr,
arrsize_t *arg,
arrsize_t left,
arrsize_t right,
Expand Down Expand Up @@ -291,7 +296,7 @@ template <typename vtype,
typename argtype,
int num_unroll,
typename type_t = typename vtype::type_t>
X86_SIMD_SORT_INLINE arrsize_t argpartition_unrolled(type_t *arr,
X86_SIMD_SORT_INLINE arrsize_t argpartition_unrolled(const type_t *arr,
arrsize_t *arg,
arrsize_t left,
arrsize_t right,
Expand Down Expand Up @@ -422,7 +427,7 @@ X86_SIMD_SORT_INLINE arrsize_t argpartition_unrolled(type_t *arr,
}

template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
X86_SIMD_SORT_INLINE type_t get_pivot_64bit(const type_t *arr,
arrsize_t *arg,
const arrsize_t left,
const arrsize_t right)
Expand Down Expand Up @@ -468,7 +473,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
}

template <typename vtype, typename argtype, typename type_t>
X86_SIMD_SORT_INLINE void argsort_(type_t *arr,
X86_SIMD_SORT_INLINE void argsort_(const type_t *arr,
arrsize_t *arg,
arrsize_t left,
arrsize_t right,
Expand Down Expand Up @@ -549,7 +554,7 @@ X86_SIMD_SORT_INLINE void argsort_(type_t *arr,
}

template <typename vtype, typename argtype, typename type_t>
X86_SIMD_SORT_INLINE void argselect_(type_t *arr,
X86_SIMD_SORT_INLINE void argselect_(const type_t *arr,
arrsize_t *arg,
arrsize_t pos,
arrsize_t left,
Expand Down Expand Up @@ -590,7 +595,7 @@ template <typename T,
typename full_vector,
template <typename...>
typename half_vector>
X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
X86_SIMD_SORT_INLINE void xss_argsort(const T *arr,
arrsize_t *arg,
arrsize_t arrsize,
bool hasnan = false,
Expand Down Expand Up @@ -669,29 +674,25 @@ X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
}

template <typename T>
X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
X86_SIMD_SORT_INLINE void avx512_argsort(const T *arr,
arrsize_t *arg,
arrsize_t arrsize,
bool hasnan = false,
bool descending = false)
{
// Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation
using base_t = std::remove_const_t<T>;
xss_argsort<base_t, zmm_vector, ymm_vector>(
const_cast<base_t *>(arr), arg, arrsize, hasnan, descending);
xss_argsort<T, zmm_vector, ymm_vector>(
arr, arg, arrsize, hasnan, descending);
}

template <typename T>
X86_SIMD_SORT_INLINE void avx2_argsort(T *arr,
X86_SIMD_SORT_INLINE void avx2_argsort(const T *arr,
arrsize_t *arg,
arrsize_t arrsize,
bool hasnan = false,
bool descending = false)
{
// Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation
using base_t = std::remove_const_t<T>;
xss_argsort<base_t, avx2_vector, avx2_half_vector>(
const_cast<base_t *>(arr), arg, arrsize, hasnan, descending);
xss_argsort<T, avx2_vector, avx2_half_vector>(
arr, arg, arrsize, hasnan, descending);
}

/* argselect methods for 32-bit and 64-bit dtypes */
Expand All @@ -700,7 +701,7 @@ template <typename T,
typename full_vector,
template <typename...>
typename half_vector>
X86_SIMD_SORT_INLINE void xss_argselect(T *arr,
X86_SIMD_SORT_INLINE void xss_argselect(const T *arr,
arrsize_t *arg,
arrsize_t k,
arrsize_t arrsize,
Expand Down Expand Up @@ -735,29 +736,24 @@ X86_SIMD_SORT_INLINE void xss_argselect(T *arr,
}

template <typename T>
X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
X86_SIMD_SORT_INLINE void avx512_argselect(const T *arr,
arrsize_t *arg,
arrsize_t k,
arrsize_t arrsize,
bool hasnan = false)
{
// Safe: argselect never mutates arr; const is dropped only for SIMD type instantiation
using base_t = std::remove_const_t<T>;
xss_argselect<base_t, zmm_vector, ymm_vector>(
const_cast<base_t *>(arr), arg, k, arrsize, hasnan);
xss_argselect<T, zmm_vector, ymm_vector>(arr, arg, k, arrsize, hasnan);
}

template <typename T>
X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
X86_SIMD_SORT_INLINE void avx2_argselect(const T *arr,
arrsize_t *arg,
arrsize_t k,
arrsize_t arrsize,
bool hasnan = false)
{
// Safe: argselect never mutates arr; const is dropped only for SIMD type instantiation
using base_t = std::remove_const_t<T>;
xss_argselect<base_t, avx2_vector, avx2_half_vector>(
const_cast<base_t *>(arr), arg, k, arrsize, hasnan);
xss_argselect<T, avx2_vector, avx2_half_vector>(
arr, arg, k, arrsize, hasnan);
}

#endif // XSS_COMMON_ARGSORT
2 changes: 1 addition & 1 deletion src/xss-common-qsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ X86_SIMD_SORT_INLINE arrsize_t replace_nan_with_inf(T *arr, arrsize_t size)
}

template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size)
X86_SIMD_SORT_INLINE bool array_has_nan(const type_t *arr, arrsize_t size)
{
using opmask_t = typename vtype::opmask_t;
using reg_t = typename vtype::reg_t;
Expand Down
4 changes: 2 additions & 2 deletions src/xss-network-keyvaluesort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ bitonic_fullmerge_n_vec(typename keyType::reg_t *keys,

template <typename keyType, typename indexType, int numVecs>
X86_SIMD_SORT_INLINE void
argsort_n_vec(typename keyType::type_t *keys, arrsize_t *indices, int N)
argsort_n_vec(const typename keyType::type_t *keys, arrsize_t *indices, int N)
{
using kreg_t = typename keyType::reg_t;
using ireg_t = typename indexType::reg_t;
Expand Down Expand Up @@ -354,7 +354,7 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys,

template <typename keyType, typename indexType, int maxN>
X86_SIMD_SORT_INLINE void
argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N)
argsort_n(const typename keyType::type_t *keys, arrsize_t *indices, int N)
{
static_assert(keyType::numlanes == indexType::numlanes,
"invalid pairing of value/index types");
Expand Down
Loading