From fa892c02bbd8da04c049012a4da65d229b47b642 Mon Sep 17 00:00:00 2001 From: "Ankit.Ahlawat@ibm.com" Date: Wed, 11 Feb 2026 14:31:42 +0530 Subject: [PATCH 1/6] Make argsort const-correct --- README.md | 2 +- lib/x86simdsort-avx2.cpp | 2 +- lib/x86simdsort-internal.h | 2 +- lib/x86simdsort-scalar.h | 2 +- lib/x86simdsort-skx.cpp | 2 +- lib/x86simdsort.cpp | 4 ++-- lib/x86simdsort.h | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index eb4e1e35..b79e44ac 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ data types. ## Arg sort routines on arrays ```cpp -std::vector arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan, bool descending); +std::vector arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending); std::vector arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan); ``` Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double, diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index c00591e4..656cf128 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -23,7 +23,7 @@ } \ template <> \ std::vector argsort( \ - type *arr, size_t arrsize, bool hasnan, bool descending) \ + const type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \ } \ diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index a9ded641..f8a14c05 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -44,7 +44,7 @@ bool hasnan = false, \ bool descending = false); \ template \ - XSS_HIDE_SYMBOL std::vector argsort(T *arr, \ + XSS_HIDE_SYMBOL std::vector argsort(const T *arr, \ size_t arrsize, \ bool hasnan = false, \ bool descending = false); \ diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 3dc737ca..95fab42a 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -71,7 +71,7 @@ namespace scalar { } template std::vector - argsort(T *arr, size_t arrsize, bool hasnan, bool reversed) + argsort(const T *arr, size_t arrsize, bool hasnan, bool reversed) { UNUSED(hasnan); std::vector arg(arrsize); diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 7d9d5aa4..f4c41255 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -23,7 +23,7 @@ } \ template <> \ std::vector argsort( \ - type *arr, size_t arrsize, bool hasnan, bool descending) \ + const type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \ } \ diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 8ef9aadb..7aecbead 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -88,11 +88,11 @@ namespace x86simdsort { #define DECLARE_INTERNAL_argsort(TYPE) \ static std::vector (*internal_argsort##TYPE)( \ - TYPE *, size_t, bool, bool) \ + const TYPE *, size_t, bool, bool) \ = NULL; \ template <> \ std::vector argsort( \ - TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + const TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ { \ return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \ } diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index e1402fe0..4f17b167 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -36,7 +36,7 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr, // argsort template XSS_EXPORT_SYMBOL std::vector -argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); +argsort(const T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // argselect template From 1b2cacc467ef46dc51b3f0c12debe7cb405c3886 Mon Sep 17 00:00:00 2001 From: "Ankit.Ahlawat@ibm.com" Date: Wed, 11 Feb 2026 20:13:38 +0530 Subject: [PATCH 2/6] Propagate const-correct argsort to static API and docs --- src/README.md | 2 +- src/x86simdsort-static-incl.h | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/README.md b/src/README.md index 87757b2a..2e52a451 100644 --- a/src/README.md +++ b/src/README.md @@ -63,7 +63,7 @@ Equivalent to `np.argsort` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html). ```cpp -void x86simdsortStatic::argsort(T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false); +void x86simdsortStatic::argsort(const T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 519abe6c..26252a41 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -26,11 +26,11 @@ X86_SIMD_SORT_FINLINE void partial_qsort(T *arr, template X86_SIMD_SORT_FINLINE std::vector -argsort(T *arr, size_t size, bool hasnan = false, bool descending = false); +argsort(const T *arr, size_t size, bool hasnan = false, bool descending = false); /* argsort API required by NumPy: */ template -X86_SIMD_SORT_FINLINE void argsort(T *arr, +X86_SIMD_SORT_FINLINE void argsort(const T *arr, size_t *arg, size_t size, bool hasnan = false, @@ -91,13 +91,13 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort( \ - T *arr, size_t *arg, size_t size, bool hasnan, bool descending) \ + const T *arr, size_t *arg, size_t size, bool hasnan, bool descending) \ { \ ISA##_argsort(arr, arg, size, hasnan, descending); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argsort( \ - T *arr, size_t size, bool hasnan, bool descending) \ + const T *arr, size_t size, bool hasnan, bool descending) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ @@ -211,4 +211,4 @@ XSS_METHODS(avx2) #error "x86simdsortStatic methods needs to be compiled with avx512/avx2 specific flags" #endif // (__AVX512VL__ && __AVX512DQ__) || AVX2 -#endif // X86_SIMD_SORT_STATIC_METHODS +#endif // X86_SIMD_SORT_STATIC_METHODS \ No newline at end of file From bc06c20f9bd64d8fecd5414f270312723930daea Mon Sep 17 00:00:00 2001 From: "Ankit.Ahlawat@ibm.com" Date: Thu, 12 Feb 2026 00:00:53 +0530 Subject: [PATCH 3/6] argsort const is dropped only for SIMD type --- src/xss-common-argsort.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 6c071c2d..ad209fc1 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -675,8 +675,10 @@ X86_SIMD_SORT_INLINE void avx512_argsort(T *arr, bool hasnan = false, bool descending = false) { - xss_argsort( - arr, arg, arrsize, hasnan, descending); + // Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation + using base_t = std::remove_const_t; + xss_argsort( + const_cast(arr), arg, arrsize, hasnan, descending); } template @@ -686,8 +688,10 @@ X86_SIMD_SORT_INLINE void avx2_argsort(T *arr, bool hasnan = false, bool descending = false) { - xss_argsort( - arr, arg, arrsize, hasnan, descending); + // Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation + using base_t = std::remove_const_t; + xss_argsort( + const_cast(arr), arg, arrsize, hasnan, descending); } /* argselect methods for 32-bit and 64-bit dtypes */ From 7b49ea24e774b156a2a7f1c4a1f7de6dfad6d866 Mon Sep 17 00:00:00 2001 From: "Ankit.Ahlawat@ibm.com" Date: Fri, 13 Feb 2026 16:28:06 +0530 Subject: [PATCH 4/6] resolve clang-format issues --- lib/x86simdsort-avx2.cpp | 2 +- lib/x86simdsort.h | 6 ++++-- src/x86simdsort-static-incl.h | 15 ++++++++++----- src/xss-common-keyvaluesort.hpp | 18 ++++++------------ 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 74dfccae..1e0761e0 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -23,7 +23,7 @@ } \ template <> \ std::vector argsort( \ - const type *arr, size_t arrsize, bool hasnan, bool descending) \ + const type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \ } \ diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index 4f17b167..c4109187 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -35,8 +35,10 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr, // argsort template -XSS_EXPORT_SYMBOL std::vector -argsort(const T *arr, size_t arrsize, bool hasnan = false, bool descending = false); +XSS_EXPORT_SYMBOL std::vector argsort(const T *arr, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argselect template diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 26252a41..852da1f1 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -25,8 +25,10 @@ X86_SIMD_SORT_FINLINE void partial_qsort(T *arr, bool descending = false); template -X86_SIMD_SORT_FINLINE std::vector -argsort(const T *arr, size_t size, bool hasnan = false, bool descending = false); +X86_SIMD_SORT_FINLINE std::vector argsort(const T *arr, + size_t size, + bool hasnan = false, + bool descending = false); /* argsort API required by NumPy: */ template @@ -90,14 +92,17 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, ISA##_partial_qsort(arr, k, size, hasnan, descending); \ } \ template \ - X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort( \ - const T *arr, size_t *arg, size_t size, bool hasnan, bool descending) \ + X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort(const T *arr, \ + size_t *arg, \ + size_t size, \ + bool hasnan, \ + bool descending) \ { \ ISA##_argsort(arr, arg, size, hasnan, descending); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argsort( \ - const T *arr, size_t size, bool hasnan, bool descending) \ + const T *arr, size_t size, bool hasnan, bool descending) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 3a07e01b..dda03f32 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -563,10 +563,8 @@ X86_SIMD_SORT_INLINE void kvselect_(type1_t *keys, template - typename full_vector, - template - typename half_vector> + template typename full_vector, + template typename half_vector> X86_SIMD_SORT_INLINE void xss_qsort_kv( T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan, bool descending) { @@ -654,10 +652,8 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv( template - typename full_vector, - template - typename half_vector> + template typename full_vector, + template typename half_vector> X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, T2 *indexes, arrsize_t k, @@ -719,10 +715,8 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, template - typename full_vector, - template - typename half_vector> + template typename full_vector, + template typename half_vector> X86_SIMD_SORT_INLINE void xss_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, From f1da85b1f20c7022b1a605c78a40502f0c9e2d7c Mon Sep 17 00:00:00 2001 From: "Ankit.Ahlawat@ibm.com" Date: Sat, 14 Feb 2026 10:38:18 +0530 Subject: [PATCH 5/6] Normalize line endings and fix clang-format --- src/xss-common-argsort.h | 1512 +++++++++++++++++++------------------- 1 file changed, 754 insertions(+), 758 deletions(-) diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index ad209fc1..dffefb0a 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -1,758 +1,754 @@ -/******************************************************************* - * Copyright (C) 2022 Intel Corporation - * SPDX-License-Identifier: BSD-3-Clause - * Authors: Raghuveer Devulapalli - * ****************************************************************/ - -#ifndef XSS_COMMON_ARGSORT -#define XSS_COMMON_ARGSORT - -#include "xss-network-keyvaluesort.hpp" -#include - -template -X86_SIMD_SORT_INLINE void std_argselect_withnan( - T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) -{ - std::nth_element(arg + left, - arg + k, - arg + right, - [arr](arrsize_t a, arrsize_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { - return arr[a] < arr[b]; - } - else if (std::isnan(arr[a])) { - return false; - } - else { - return true; - } - }); -} - -/* argsort using std::sort */ -template -X86_SIMD_SORT_INLINE void -std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) -{ - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { - return arr[left] < arr[right]; - } - else if (std::isnan(arr[left])) { - return false; - } - else { - return true; - } - }); -} - -/* argsort using std::sort */ -template -X86_SIMD_SORT_INLINE void -std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) -{ - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - // sort indices according to corresponding array element - return arr[left] < arr[right]; - }); -} - -/* - * Parition one ZMM register based on the pivot and returns the index of the - * last element that is less than equal to the pivot. - */ -template -X86_SIMD_SORT_INLINE int32_t partition_vec_avx512(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - /* which elements are larger than the pivot */ - typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); - int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); - argtype::mask_compressstoreu( - arg + left, vtype::knot_opmask(gt_mask), arg_vec); - argtype::mask_compressstoreu( - arg + right - amount_gt_pivot, gt_mask, arg_vec); - *smallest_vec = vtype::min(curr_vec, *smallest_vec); - *biggest_vec = vtype::max(curr_vec, *biggest_vec); - return amount_gt_pivot; -} - -/* - * Parition one AVX2 register based on the pivot and returns the index of the - * last element that is less than equal to the pivot. - */ -template -X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - /* which elements are larger than the pivot */ - typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec); - typename argtype::opmask_t ge_mask - = resize_mask(ge_mask_vtype); - - auto l_store = arg + left; - auto r_store = arg + right - vtype::numlanes; - - int amount_ge_pivot - = argtype::double_compressstore(l_store, r_store, ge_mask, arg_vec); - - *smallest_vec = vtype::min(curr_vec, *smallest_vec); - *biggest_vec = vtype::max(curr_vec, *biggest_vec); - - return amount_ge_pivot; -} - -template -X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - if constexpr (vtype::vec_type == simd_type::AVX512) { - return partition_vec_avx512(arg, - left, - right, - arg_vec, - curr_vec, - pivot_vec, - smallest_vec, - biggest_vec); - } - else if constexpr (vtype::vec_type == simd_type::AVX2) { - return partition_vec_avx2(arg, - left, - right, - arg_vec, - curr_vec, - pivot_vec, - smallest_vec, - biggest_vec); - } - else { - static_assert(sizeof(argreg_t) == 0, "Should not get here"); - } -} - -/* - * Parition an array based on the pivot and returns the index of the - * last element that is less than equal to the pivot. - */ -template -X86_SIMD_SORT_INLINE arrsize_t argpartition(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { - *smallest = std::min(*smallest, arr[arg[left]], comparison_func); - *biggest = std::max(*biggest, arr[arg[left]], comparison_func); - if (!comparison_func(arr[arg[left]], pivot)) { - std::swap(arg[left], arg[--right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using reg_t = typename vtype::reg_t; - using argreg_t = typename argtype::reg_t; - reg_t pivot_vec = vtype::set1(pivot); - reg_t min_vec = vtype::set1(*smallest); - reg_t max_vec = vtype::set1(*biggest); - - if (right - left == vtype::numlanes) { - argreg_t argvec = argtype::loadu(arg + left); - reg_t vec = vtype::i64gather(arr, arg + left); - int32_t amount_gt_pivot - = partition_vec(arg, - left, - left + vtype::numlanes, - argvec, - vec, - pivot_vec, - &min_vec, - &max_vec); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return left + (vtype::numlanes - amount_gt_pivot); - } - - // first and last vtype::numlanes values are partitioned at the end - argreg_t argvec_left = argtype::loadu(arg + left); - reg_t vec_left = vtype::i64gather(arr, arg + left); - argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); - reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements - left += vtype::numlanes; - right -= vtype::numlanes; - while (right - left != 0) { - argreg_t arg_vec; - reg_t curr_vec; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= vtype::numlanes; - arg_vec = argtype::loadu(arg + right); - curr_vec = vtype::i64gather(arr, arg + right); - } - else { - arg_vec = argtype::loadu(arg + left); - curr_vec = vtype::i64gather(arr, arg + left); - left += vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec, - curr_vec, - pivot_vec, - &min_vec, - &max_vec); - ; - r_store -= amount_gt_pivot; - l_store += (vtype::numlanes - amount_gt_pivot); - } - - /* partition and save vec_left and vec_right */ - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left, - vec_left, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec(arg, - l_store, - l_store + vtype::numlanes, - argvec_right, - vec_right, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -template -X86_SIMD_SORT_INLINE arrsize_t argpartition_unrolled(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - if (right - left <= 8 * num_unroll * vtype::numlanes) { - return argpartition( - arr, arg, left, right, pivot, smallest, biggest); - } - /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; - --i) { - *smallest = std::min(*smallest, arr[arg[left]], comparison_func); - *biggest = std::max(*biggest, arr[arg[left]], comparison_func); - if (!comparison_func(arr[arg[left]], pivot)) { - std::swap(arg[left], arg[--right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using reg_t = typename vtype::reg_t; - using argreg_t = typename argtype::reg_t; - reg_t pivot_vec = vtype::set1(pivot); - reg_t min_vec = vtype::set1(*smallest); - reg_t max_vec = vtype::set1(*biggest); - - // first and last vtype::numlanes values are partitioned at the end - reg_t vec_left[num_unroll], vec_right[num_unroll]; - argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); - vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); - argvec_right[ii] = argtype::loadu( - arg + (right - vtype::numlanes * (num_unroll - ii))); - vec_right[ii] = vtype::i64gather( - arr, arg + (right - vtype::numlanes * (num_unroll - ii))); - } - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements - left += num_unroll * vtype::numlanes; - right -= num_unroll * vtype::numlanes; - while (right - left != 0) { - argreg_t arg_vec[num_unroll]; - reg_t curr_vec[num_unroll]; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= num_unroll * vtype::numlanes; - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] - = argtype::loadu(arg + right + ii * vtype::numlanes); - curr_vec[ii] = vtype::i64gather( - arr, arg + right + ii * vtype::numlanes); - } - } - else { - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); - curr_vec[ii] = vtype::i64gather( - arr, arg + left + ii * vtype::numlanes); - } - left += num_unroll * vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec[ii], - curr_vec[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - } - - /* partition and save vec_left and vec_right */ - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left[ii], - vec_left[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_right[ii], - vec_right[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, - arrsize_t *arg, - const arrsize_t left, - const arrsize_t right) -{ - if constexpr (vtype::numlanes == 8) { - if (right - left >= vtype::numlanes) { - // median of 8 - arrsize_t size = (right - left) / 8; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[arg[left + size]], - arr[arg[left + 2 * size]], - arr[arg[left + 3 * size]], - arr[arg[left + 4 * size]], - arr[arg[left + 5 * size]], - arr[arg[left + 6 * size]], - arr[arg[left + 7 * size]], - arr[arg[left + 8 * size]]); - // pivot will never be a nan, since there are no nan's! - reg_t sort = vtype::sort_vec(rand_vec); - return ((type_t *)&sort)[4]; - } - else { - return arr[arg[right]]; - } - } - else if constexpr (vtype::numlanes == 4) { - if (right - left >= vtype::numlanes) { - // median of 4 - arrsize_t size = (right - left) / 4; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[arg[left + size]], - arr[arg[left + 2 * size]], - arr[arg[left + 3 * size]], - arr[arg[left + 4 * size]]); - // pivot will never be a nan, since there are no nan's! - reg_t sort = vtype::sort_vec(rand_vec); - return ((type_t *)&sort)[2]; - } - else { - return arr[arg[right]]; - } - } -} - -template -X86_SIMD_SORT_INLINE void argsort_(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - arrsize_t max_iters, - arrsize_t task_threshold) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std_argsort(arr, arg, left, right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 64 - */ - if (right + 1 - left <= 256) { - argsort_n( - arr, arg + left, (int32_t)(right + 1 - left)); - return; - } - type_t pivot = get_pivot_64bit(arr, arg, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - arrsize_t pivot_index = argpartition_unrolled( - arr, arg, left, right + 1, pivot, &smallest, &biggest); -#ifdef XSS_COMPILE_OPENMP - if (pivot != smallest) { - bool parallel_left = (pivot_index - left) > task_threshold; - if (parallel_left) { -#pragma omp task - argsort_(arr, - arg, - left, - pivot_index - 1, - max_iters - 1, - task_threshold); - } - else { - argsort_(arr, - arg, - left, - pivot_index - 1, - max_iters - 1, - task_threshold); - } - } - if (pivot != biggest) { - bool parallel_right = (right - pivot_index) > task_threshold; - - if (parallel_right) { -#pragma omp task - argsort_(arr, - arg, - pivot_index, - right, - max_iters - 1, - task_threshold); - } - else { - argsort_(arr, - arg, - pivot_index, - right, - max_iters - 1, - task_threshold); - } - } -#else - UNUSED(task_threshold); - if (pivot != smallest) - argsort_( - arr, arg, left, pivot_index - 1, max_iters - 1, 0); - if (pivot != biggest) - argsort_( - arr, arg, pivot_index, right, max_iters - 1, 0); -#endif -} - -template -X86_SIMD_SORT_INLINE void argselect_(type_t *arr, - arrsize_t *arg, - arrsize_t pos, - arrsize_t left, - arrsize_t right, - arrsize_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std_argsort(arr, arg, left, right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 64 - */ - if (right + 1 - left <= 256) { - argsort_n( - arr, arg + left, (int32_t)(right + 1 - left)); - return; - } - type_t pivot = get_pivot_64bit(arr, arg, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - arrsize_t pivot_index = argpartition_unrolled( - arr, arg, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - argselect_( - arr, arg, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - argselect_( - arr, arg, pos, pivot_index, right, max_iters - 1); -} - -/* argsort methods for 32-bit and 64-bit dtypes */ -template - typename full_vector, - template - typename half_vector> -X86_SIMD_SORT_INLINE void xss_argsort(T *arr, - arrsize_t *arg, - arrsize_t arrsize, - bool hasnan = false, - bool descending = false) -{ - - using vectype = typename std::conditional, - full_vector>::type; - - using argtype = - typename std::conditional, - full_vector>::type; - - if (arrsize > 1) { - /* simdargsort does not work for float/double arrays with nan */ - if constexpr (xss::fp::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argsort_withnan(arr, arg, 0, arrsize); - - if (descending) { std::reverse(arg, arg + arrsize); } - - return; - } - } - UNUSED(hasnan); - - /* early exit for already sorted arrays: float/double with nan never reach here*/ - auto comp = descending ? Comparator::STDSortComparator - : Comparator::STDSortComparator; - if (std::is_sorted(arr, arr + arrsize, comp)) { return; } - -#ifdef XSS_COMPILE_OPENMP - - bool use_parallel = arrsize > 10000; - - if (use_parallel) { - int thread_count = xss_get_num_threads(); - arrsize_t task_threshold - = std::max((arrsize_t)10000, arrsize / 100); - - // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ - // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems - // Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays -#pragma omp parallel num_threads(thread_count) -#pragma omp single - argsort_(arr, - arg, - 0, - arrsize - 1, - 2 * (arrsize_t)log2(arrsize), - task_threshold); -#pragma omp taskwait - } - else { - argsort_(arr, - arg, - 0, - arrsize - 1, - 2 * (arrsize_t)log2(arrsize), - std::numeric_limits::max()); - } -#else - argsort_( - arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); -#endif - - if (descending) { std::reverse(arg, arg + arrsize); } - } - -#ifdef __MMX__ - // Workaround for compiler bug generating MMX instructions without emms - _mm_empty(); -#endif -} - -template -X86_SIMD_SORT_INLINE void avx512_argsort(T *arr, - arrsize_t *arg, - arrsize_t arrsize, - bool hasnan = false, - bool descending = false) -{ - // Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation - using base_t = std::remove_const_t; - xss_argsort( - const_cast(arr), arg, arrsize, hasnan, descending); -} - -template -X86_SIMD_SORT_INLINE void avx2_argsort(T *arr, - arrsize_t *arg, - arrsize_t arrsize, - bool hasnan = false, - bool descending = false) -{ - // Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation - using base_t = std::remove_const_t; - xss_argsort( - const_cast(arr), arg, arrsize, hasnan, descending); -} - -/* argselect methods for 32-bit and 64-bit dtypes */ -template - typename full_vector, - template - typename half_vector> -X86_SIMD_SORT_INLINE void xss_argselect(T *arr, - arrsize_t *arg, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) -{ - /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */ - using vectype = typename std::conditional, - full_vector>::type; - - using argtype = - typename std::conditional, - full_vector>::type; - - if (arrsize > 1) { - if constexpr (xss::fp::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan(arr, arg, k, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argselect_( - arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } - -#ifdef __MMX__ - // Workaround for compiler bug generating MMX instructions without emms - _mm_empty(); -#endif -} - -template -X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, - arrsize_t *arg, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) -{ - xss_argselect(arr, arg, k, arrsize, hasnan); -} - -template -X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, - arrsize_t *arg, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) -{ - xss_argselect( - arr, arg, k, arrsize, hasnan); -} - -#endif // XSS_COMMON_ARGSORT +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * ****************************************************************/ + +#ifndef XSS_COMMON_ARGSORT +#define XSS_COMMON_ARGSORT + +#include "xss-network-keyvaluesort.hpp" +#include + +template +X86_SIMD_SORT_INLINE void std_argselect_withnan( + T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) +{ + std::nth_element(arg + left, + arg + k, + arg + right, + [arr](arrsize_t a, arrsize_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] < arr[b]; + } + else if (std::isnan(arr[a])) { + return false; + } + else { + return true; + } + }); +} + +/* argsort using std::sort */ +template +X86_SIMD_SORT_INLINE void +std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) +{ + std::sort(arg + left, + arg + right, + [arr](arrsize_t left, arrsize_t right) -> bool { + if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { + return arr[left] < arr[right]; + } + else if (std::isnan(arr[left])) { + return false; + } + else { + return true; + } + }); +} + +/* argsort using std::sort */ +template +X86_SIMD_SORT_INLINE void +std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) +{ + std::sort(arg + left, + arg + right, + [arr](arrsize_t left, arrsize_t right) -> bool { + // sort indices according to corresponding array element + return arr[left] < arr[right]; + }); +} + +/* + * Parition one ZMM register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE int32_t partition_vec_avx512(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); + int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); + argtype::mask_compressstoreu( + arg + left, vtype::knot_opmask(gt_mask), arg_vec); + argtype::mask_compressstoreu( + arg + right - amount_gt_pivot, gt_mask, arg_vec); + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + return amount_gt_pivot; +} + +/* + * Parition one AVX2 register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec); + typename argtype::opmask_t ge_mask + = resize_mask(ge_mask_vtype); + + auto l_store = arg + left; + auto r_store = arg + right - vtype::numlanes; + + int amount_ge_pivot + = argtype::double_compressstore(l_store, r_store, ge_mask, arg_vec); + + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + + return amount_ge_pivot; +} + +template +X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + if constexpr (vtype::vec_type == simd_type::AVX512) { + return partition_vec_avx512(arg, + left, + right, + arg_vec, + curr_vec, + pivot_vec, + smallest_vec, + biggest_vec); + } + else if constexpr (vtype::vec_type == simd_type::AVX2) { + return partition_vec_avx2(arg, + left, + right, + arg_vec, + curr_vec, + pivot_vec, + smallest_vec, + biggest_vec); + } + else { + static_assert(sizeof(argreg_t) == 0, "Should not get here"); + } +} + +/* + * Parition an array based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE arrsize_t argpartition(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + using argreg_t = typename argtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + argreg_t argvec = argtype::loadu(arg + left); + reg_t vec = vtype::i64gather(arr, arg + left); + int32_t amount_gt_pivot + = partition_vec(arg, + left, + left + vtype::numlanes, + argvec, + vec, + pivot_vec, + &min_vec, + &max_vec); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return left + (vtype::numlanes - amount_gt_pivot); + } + + // first and last vtype::numlanes values are partitioned at the end + argreg_t argvec_left = argtype::loadu(arg + left); + reg_t vec_left = vtype::i64gather(arr, arg + left); + argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); + reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec; + reg_t curr_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= vtype::numlanes; + arg_vec = argtype::loadu(arg + right); + curr_vec = vtype::i64gather(arr, arg + right); + } + else { + arg_vec = argtype::loadu(arg + left); + curr_vec = vtype::i64gather(arr, arg + left); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec, + curr_vec, + pivot_vec, + &min_vec, + &max_vec); + ; + r_store -= amount_gt_pivot; + l_store += (vtype::numlanes - amount_gt_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left, + vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec(arg, + l_store, + l_store + vtype::numlanes, + argvec_right, + vec_right, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +X86_SIMD_SORT_INLINE arrsize_t argpartition_unrolled(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + if (right - left <= 8 * num_unroll * vtype::numlanes) { + return argpartition( + arr, arg, left, right, pivot, smallest, biggest); + } + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; + --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + using argreg_t = typename argtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + // first and last vtype::numlanes values are partitioned at the end + reg_t vec_left[num_unroll], vec_right[num_unroll]; + argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); + vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); + argvec_right[ii] = argtype::loadu( + arg + (right - vtype::numlanes * (num_unroll - ii))); + vec_right[ii] = vtype::i64gather( + arr, arg + (right - vtype::numlanes * (num_unroll - ii))); + } + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec[num_unroll]; + reg_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= num_unroll * vtype::numlanes; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] + = argtype::loadu(arg + right + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + right + ii * vtype::numlanes); + } + } + else { + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + left + ii * vtype::numlanes); + } + left += num_unroll * vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec[ii], + curr_vec[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + } + + /* partition and save vec_left and vec_right */ + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left[ii], + vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_right[ii], + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, + arrsize_t *arg, + const arrsize_t left, + const arrsize_t right) +{ + if constexpr (vtype::numlanes == 8) { + if (right - left >= vtype::numlanes) { + // median of 8 + arrsize_t size = (right - left) / 8; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]], + arr[arg[left + 5 * size]], + arr[arg[left + 6 * size]], + arr[arg[left + 7 * size]], + arr[arg[left + 8 * size]]); + // pivot will never be a nan, since there are no nan's! + reg_t sort = vtype::sort_vec(rand_vec); + return ((type_t *)&sort)[4]; + } + else { + return arr[arg[right]]; + } + } + else if constexpr (vtype::numlanes == 4) { + if (right - left >= vtype::numlanes) { + // median of 4 + arrsize_t size = (right - left) / 4; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]]); + // pivot will never be a nan, since there are no nan's! + reg_t sort = vtype::sort_vec(rand_vec); + return ((type_t *)&sort)[2]; + } + else { + return arr[arg[right]]; + } + } +} + +template +X86_SIMD_SORT_INLINE void argsort_(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + arrsize_t max_iters, + arrsize_t task_threshold) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 256) { + argsort_n( + arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + arrsize_t pivot_index = argpartition_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); +#ifdef XSS_COMPILE_OPENMP + if (pivot != smallest) { + bool parallel_left = (pivot_index - left) > task_threshold; + if (parallel_left) { +#pragma omp task + argsort_(arr, + arg, + left, + pivot_index - 1, + max_iters - 1, + task_threshold); + } + else { + argsort_(arr, + arg, + left, + pivot_index - 1, + max_iters - 1, + task_threshold); + } + } + if (pivot != biggest) { + bool parallel_right = (right - pivot_index) > task_threshold; + + if (parallel_right) { +#pragma omp task + argsort_(arr, + arg, + pivot_index, + right, + max_iters - 1, + task_threshold); + } + else { + argsort_(arr, + arg, + pivot_index, + right, + max_iters - 1, + task_threshold); + } + } +#else + UNUSED(task_threshold); + if (pivot != smallest) + argsort_( + arr, arg, left, pivot_index - 1, max_iters - 1, 0); + if (pivot != biggest) + argsort_( + arr, arg, pivot_index, right, max_iters - 1, 0); +#endif +} + +template +X86_SIMD_SORT_INLINE void argselect_(type_t *arr, + arrsize_t *arg, + arrsize_t pos, + arrsize_t left, + arrsize_t right, + arrsize_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 256) { + argsort_n( + arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + arrsize_t pivot_index = argpartition_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); + if ((pivot != smallest) && (pos < pivot_index)) + argselect_( + arr, arg, pos, left, pivot_index - 1, max_iters - 1); + else if ((pivot != biggest) && (pos >= pivot_index)) + argselect_( + arr, arg, pos, pivot_index, right, max_iters - 1); +} + +/* argsort methods for 32-bit and 64-bit dtypes */ +template typename full_vector, + template typename half_vector> +X86_SIMD_SORT_INLINE void xss_argsort(T *arr, + arrsize_t *arg, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) +{ + + using vectype = typename std::conditional, + full_vector>::type; + + using argtype = + typename std::conditional, + full_vector>::type; + + if (arrsize > 1) { + /* simdargsort does not work for float/double arrays with nan */ + if constexpr (xss::fp::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argsort_withnan(arr, arg, 0, arrsize); + + if (descending) { std::reverse(arg, arg + arrsize); } + + return; + } + } + UNUSED(hasnan); + + /* early exit for already sorted arrays: float/double with nan never reach here*/ + auto comp = descending ? Comparator::STDSortComparator + : Comparator::STDSortComparator; + if (std::is_sorted(arr, arr + arrsize, comp)) { return; } + +#ifdef XSS_COMPILE_OPENMP + + bool use_parallel = arrsize > 10000; + + if (use_parallel) { + int thread_count = xss_get_num_threads(); + arrsize_t task_threshold + = std::max((arrsize_t)10000, arrsize / 100); + + // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ + // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems + // Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays +#pragma omp parallel num_threads(thread_count) +#pragma omp single + argsort_(arr, + arg, + 0, + arrsize - 1, + 2 * (arrsize_t)log2(arrsize), + task_threshold); +#pragma omp taskwait + } + else { + argsort_(arr, + arg, + 0, + arrsize - 1, + 2 * (arrsize_t)log2(arrsize), + std::numeric_limits::max()); + } +#else + argsort_( + arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); +#endif + + if (descending) { std::reverse(arg, arg + arrsize); } + } + +#ifdef __MMX__ + // Workaround for compiler bug generating MMX instructions without emms + _mm_empty(); +#endif +} + +template +X86_SIMD_SORT_INLINE void avx512_argsort(T *arr, + arrsize_t *arg, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) +{ + // Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation + using base_t = std::remove_const_t; + xss_argsort( + const_cast(arr), arg, arrsize, hasnan, descending); +} + +template +X86_SIMD_SORT_INLINE void avx2_argsort(T *arr, + arrsize_t *arg, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) +{ + // Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation + using base_t = std::remove_const_t; + xss_argsort( + const_cast(arr), arg, arrsize, hasnan, descending); +} + +/* argselect methods for 32-bit and 64-bit dtypes */ +template typename full_vector, + template typename half_vector> +X86_SIMD_SORT_INLINE void xss_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */ + using vectype = typename std::conditional, + full_vector>::type; + + using argtype = + typename std::conditional, + full_vector>::type; + + if (arrsize > 1) { + if constexpr (xss::fp::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argselect_( + arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } + +#ifdef __MMX__ + // Workaround for compiler bug generating MMX instructions without emms + _mm_empty(); +#endif +} + +template +X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + xss_argselect(arr, arg, k, arrsize, hasnan); +} + +template +X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + xss_argselect( + arr, arg, k, arrsize, hasnan); +} + +#endif // XSS_COMMON_ARGSORT From 5b86d5ed260b7fbaa7b730da78218f9ba6b79f1e Mon Sep 17 00:00:00 2001 From: "Ankit.Ahlawat@ibm.com" Date: Sun, 15 Feb 2026 13:55:09 +0530 Subject: [PATCH 6/6] Format templates using clang-format-18 to match CI --- src/xss-common-argsort.h | 12 ++++++++---- src/xss-common-keyvaluesort.hpp | 18 ++++++++++++------ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index dffefb0a..9af9e709 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -586,8 +586,10 @@ X86_SIMD_SORT_INLINE void argselect_(type_t *arr, /* argsort methods for 32-bit and 64-bit dtypes */ template typename full_vector, - template typename half_vector> + template + typename full_vector, + template + typename half_vector> X86_SIMD_SORT_INLINE void xss_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, @@ -694,8 +696,10 @@ X86_SIMD_SORT_INLINE void avx2_argsort(T *arr, /* argselect methods for 32-bit and 64-bit dtypes */ template typename full_vector, - template typename half_vector> + template + typename full_vector, + template + typename half_vector> X86_SIMD_SORT_INLINE void xss_argselect(T *arr, arrsize_t *arg, arrsize_t k, diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index dda03f32..3a07e01b 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -563,8 +563,10 @@ X86_SIMD_SORT_INLINE void kvselect_(type1_t *keys, template typename full_vector, - template typename half_vector> + template + typename full_vector, + template + typename half_vector> X86_SIMD_SORT_INLINE void xss_qsort_kv( T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan, bool descending) { @@ -652,8 +654,10 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv( template typename full_vector, - template typename half_vector> + template + typename full_vector, + template + typename half_vector> X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, T2 *indexes, arrsize_t k, @@ -715,8 +719,10 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, template typename full_vector, - template typename half_vector> + template + typename full_vector, + template + typename half_vector> X86_SIMD_SORT_INLINE void xss_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k,