diff --git a/README.md b/README.md index eb4e1e3..b79e44a 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ data types. ## Arg sort routines on arrays ```cpp -std::vector arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan, bool descending); +std::vector arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending); std::vector arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan); ``` Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double, diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 58305dc..1e0761e 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -23,7 +23,7 @@ } \ template <> \ std::vector argsort( \ - type *arr, size_t arrsize, bool hasnan, bool descending) \ + const type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \ } \ diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index a9ded64..f8a14c0 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -44,7 +44,7 @@ bool hasnan = false, \ bool descending = false); \ template \ - XSS_HIDE_SYMBOL std::vector argsort(T *arr, \ + XSS_HIDE_SYMBOL std::vector argsort(const T *arr, \ size_t arrsize, \ bool hasnan = false, \ bool descending = false); \ diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 3dc737c..95fab42 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -71,7 +71,7 @@ namespace scalar { } template std::vector - argsort(T *arr, size_t arrsize, bool hasnan, bool reversed) + argsort(const T *arr, size_t arrsize, bool hasnan, bool reversed) { UNUSED(hasnan); std::vector arg(arrsize); diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 7d9d5aa..f4c4125 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -23,7 +23,7 @@ } \ template <> \ std::vector argsort( \ - type *arr, size_t arrsize, bool hasnan, bool descending) \ + const type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \ } \ diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 8ef9aad..7aecbea 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -88,11 +88,11 @@ namespace x86simdsort { #define DECLARE_INTERNAL_argsort(TYPE) \ static std::vector (*internal_argsort##TYPE)( \ - TYPE *, size_t, bool, bool) \ + const TYPE *, size_t, bool, bool) \ = NULL; \ template <> \ std::vector argsort( \ - TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + const TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ { \ return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \ } diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index e1402fe..c410918 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -35,8 +35,10 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr, // argsort template -XSS_EXPORT_SYMBOL std::vector -argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); +XSS_EXPORT_SYMBOL std::vector argsort(const T *arr, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argselect template diff --git a/src/README.md b/src/README.md index 87757b2..2e52a45 100644 --- a/src/README.md +++ b/src/README.md @@ -63,7 +63,7 @@ Equivalent to `np.argsort` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html). ```cpp -void x86simdsortStatic::argsort(T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false); +void x86simdsortStatic::argsort(const T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 519abe6..852da1f 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -25,12 +25,14 @@ X86_SIMD_SORT_FINLINE void partial_qsort(T *arr, bool descending = false); template -X86_SIMD_SORT_FINLINE std::vector -argsort(T *arr, size_t size, bool hasnan = false, bool descending = false); +X86_SIMD_SORT_FINLINE std::vector argsort(const T *arr, + size_t size, + bool hasnan = false, + bool descending = false); /* argsort API required by NumPy: */ template -X86_SIMD_SORT_FINLINE void argsort(T *arr, +X86_SIMD_SORT_FINLINE void argsort(const T *arr, size_t *arg, size_t size, bool hasnan = false, @@ -90,14 +92,17 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, ISA##_partial_qsort(arr, k, size, hasnan, descending); \ } \ template \ - X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort( \ - T *arr, size_t *arg, size_t size, bool hasnan, bool descending) \ + X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort(const T *arr, \ + size_t *arg, \ + size_t size, \ + bool hasnan, \ + bool descending) \ { \ ISA##_argsort(arr, arg, size, hasnan, descending); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argsort( \ - T *arr, size_t size, bool hasnan, bool descending) \ + const T *arr, size_t size, bool hasnan, bool descending) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ @@ -211,4 +216,4 @@ XSS_METHODS(avx2) #error "x86simdsortStatic methods needs to be compiled with avx512/avx2 specific flags" #endif // (__AVX512VL__ && __AVX512DQ__) || AVX2 -#endif // X86_SIMD_SORT_STATIC_METHODS +#endif // X86_SIMD_SORT_STATIC_METHODS \ No newline at end of file diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 7b805fa..9af9e70 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -675,8 +675,10 @@ X86_SIMD_SORT_INLINE void avx512_argsort(T *arr, bool hasnan = false, bool descending = false) { - xss_argsort( - arr, arg, arrsize, hasnan, descending); + // Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation + using base_t = std::remove_const_t; + xss_argsort( + const_cast(arr), arg, arrsize, hasnan, descending); } template @@ -686,8 +688,10 @@ X86_SIMD_SORT_INLINE void avx2_argsort(T *arr, bool hasnan = false, bool descending = false) { - xss_argsort( - arr, arg, arrsize, hasnan, descending); + // Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation + using base_t = std::remove_const_t; + xss_argsort( + const_cast(arr), arg, arrsize, hasnan, descending); } /* argselect methods for 32-bit and 64-bit dtypes */