Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ data types.

## Arg sort routines on arrays
```cpp
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan, bool descending);
std::vector<size_t> arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending);
std::vector<size_t> arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan);
Comment thread
AnkitAhlawat7742 marked this conversation as resolved.
```
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double,
Expand Down
2 changes: 1 addition & 1 deletion lib/x86simdsort-avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
} \
template <> \
std::vector<size_t> argsort( \
type *arr, size_t arrsize, bool hasnan, bool descending) \
const type *arr, size_t arrsize, bool hasnan, bool descending) \
{ \
return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \
} \
Expand Down
2 changes: 1 addition & 1 deletion lib/x86simdsort-internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
bool hasnan = false, \
bool descending = false); \
template <typename T> \
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, \
XSS_HIDE_SYMBOL std::vector<size_t> argsort(const T *arr, \
size_t arrsize, \
bool hasnan = false, \
bool descending = false); \
Expand Down
2 changes: 1 addition & 1 deletion lib/x86simdsort-scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace scalar {
}
template <typename T>
std::vector<size_t>
argsort(T *arr, size_t arrsize, bool hasnan, bool reversed)
argsort(const T *arr, size_t arrsize, bool hasnan, bool reversed)
{
UNUSED(hasnan);
std::vector<size_t> arg(arrsize);
Expand Down
2 changes: 1 addition & 1 deletion lib/x86simdsort-skx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
} \
template <> \
std::vector<size_t> argsort( \
type *arr, size_t arrsize, bool hasnan, bool descending) \
const type *arr, size_t arrsize, bool hasnan, bool descending) \
{ \
return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \
} \
Expand Down
4 changes: 2 additions & 2 deletions lib/x86simdsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ namespace x86simdsort {

#define DECLARE_INTERNAL_argsort(TYPE) \
static std::vector<size_t> (*internal_argsort##TYPE)( \
TYPE *, size_t, bool, bool) \
const TYPE *, size_t, bool, bool) \
= NULL; \
template <> \
std::vector<size_t> argsort( \
TYPE *arr, size_t arrsize, bool hasnan, bool descending) \
const TYPE *arr, size_t arrsize, bool hasnan, bool descending) \
{ \
return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \
}
Expand Down
6 changes: 4 additions & 2 deletions lib/x86simdsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr,

// argsort
template <typename T>
XSS_EXPORT_SYMBOL std::vector<size_t>
argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
XSS_EXPORT_SYMBOL std::vector<size_t> argsort(const T *arr,
size_t arrsize,
bool hasnan = false,
bool descending = false);

// argselect
template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion src/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Equivalent to `np.argsort` in
[NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html).

```cpp
void x86simdsortStatic::argsort<T>(T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false);
void x86simdsortStatic::argsort<T>(const T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false);
```
Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and
`double`.
Expand Down
19 changes: 12 additions & 7 deletions src/x86simdsort-static-incl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ X86_SIMD_SORT_FINLINE void partial_qsort(T *arr,
bool descending = false);

template <typename T>
X86_SIMD_SORT_FINLINE std::vector<size_t>
argsort(T *arr, size_t size, bool hasnan = false, bool descending = false);
X86_SIMD_SORT_FINLINE std::vector<size_t> argsort(const T *arr,
size_t size,
bool hasnan = false,
bool descending = false);

/* argsort API required by NumPy: */
template <typename T>
X86_SIMD_SORT_FINLINE void argsort(T *arr,
X86_SIMD_SORT_FINLINE void argsort(const T *arr,
size_t *arg,
size_t size,
bool hasnan = false,
Expand Down Expand Up @@ -90,14 +92,17 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key,
ISA##_partial_qsort(arr, k, size, hasnan, descending); \
} \
template <typename T> \
X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort( \
T *arr, size_t *arg, size_t size, bool hasnan, bool descending) \
X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort(const T *arr, \
size_t *arg, \
size_t size, \
bool hasnan, \
bool descending) \
{ \
ISA##_argsort(arr, arg, size, hasnan, descending); \
} \
template <typename T> \
X86_SIMD_SORT_FINLINE std::vector<size_t> x86simdsortStatic::argsort( \
T *arr, size_t size, bool hasnan, bool descending) \
const T *arr, size_t size, bool hasnan, bool descending) \
{ \
std::vector<size_t> indices(size); \
std::iota(indices.begin(), indices.end(), 0); \
Expand Down Expand Up @@ -211,4 +216,4 @@ XSS_METHODS(avx2)
#error "x86simdsortStatic methods needs to be compiled with avx512/avx2 specific flags"
#endif // (__AVX512VL__ && __AVX512DQ__) || AVX2

#endif // X86_SIMD_SORT_STATIC_METHODS
#endif // X86_SIMD_SORT_STATIC_METHODS
12 changes: 8 additions & 4 deletions src/xss-common-argsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,10 @@ X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
bool hasnan = false,
bool descending = false)
{
xss_argsort<T, zmm_vector, ymm_vector>(
arr, arg, arrsize, hasnan, descending);
// Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation
using base_t = std::remove_const_t<T>;
xss_argsort<base_t, zmm_vector, ymm_vector>(
const_cast<base_t *>(arr), arg, arrsize, hasnan, descending);
}

template <typename T>
Expand All @@ -686,8 +688,10 @@ X86_SIMD_SORT_INLINE void avx2_argsort(T *arr,
bool hasnan = false,
bool descending = false)
{
xss_argsort<T, avx2_vector, avx2_half_vector>(
arr, arg, arrsize, hasnan, descending);
// Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation
Comment thread
AnkitAhlawat7742 marked this conversation as resolved.
using base_t = std::remove_const_t<T>;
xss_argsort<base_t, avx2_vector, avx2_half_vector>(
const_cast<base_t *>(arr), arg, arrsize, hasnan, descending);
}

/* argselect methods for 32-bit and 64-bit dtypes */
Expand Down
Loading