-
Notifications
You must be signed in to change notification settings - Fork 401
cuda::device::warp_match_any
#9243
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
fbusato
wants to merge
9
commits into
NVIDIA:main
Choose a base branch
from
fbusato:warp_match_any
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+296
−7
Open
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
53ab696
implementation
fbusato a5064cf
header
fbusato efeaa27
unit test
fbusato d5c0bdb
documentation
fbusato feb923e
Update libcudacxx/test/libcudacxx/cuda/warp/warp_match_any.pass.cpp
fbusato c0fe5cf
Update libcudacxx/test/libcudacxx/cuda/warp/warp_match_any.pass.cpp
fbusato faa21ec
formatting
fbusato a17e799
Merge branch 'warp_match_any' of https://github.com/fbusato/cccl into…
fbusato de51be0
avoid launching a kernel
fbusato File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| .. _libcudacxx-extended-api-warp-warp-match-any: | ||
|
|
||
| ``cuda::device::warp_match_any`` | ||
| ================================ | ||
|
|
||
| Defined in ``<cuda/warp>`` header. | ||
|
|
||
| .. code:: cuda | ||
|
|
||
| namespace cuda::device { | ||
|
|
||
| template <typename T> | ||
| [[nodiscard]] __device__ lane_mask | ||
| warp_match_any(const T& data, lane_mask = lane_mask::all()); | ||
|
|
||
| } // namespace cuda::device | ||
|
|
||
| The functionality provides a generalized and safe alternative to CUDA warp match any intrinsic ``__match_any_sync``. | ||
| The function allows bitwise comparison of any data size, including raw arrays, pointers, and structs. | ||
|
|
||
| .. note:: | ||
|
|
||
| The underlying CUDA intrinsic does not provide memory ordering. | ||
|
|
||
| **Parameters** | ||
|
|
||
| - ``data``: data to compare. | ||
| - ``lane_mask``: mask of the active lanes. | ||
|
|
||
| **Return value** | ||
|
|
||
| - A ``lane_mask`` representing the non-exited lanes in ``lane_mask`` that have the same bitwise value for ``data`` as the calling lane. | ||
|
|
||
| **Constraints** | ||
|
|
||
| - ``T`` shall be trivially copyable, see :ref:`cuda::is_trivially_copyable <libcudacxx-extended-api-type_traits-is_trivially_copyable>`. | ||
| - When ``__builtin_clear_padding`` is not supported, ``T`` shall have no padding bits, that is, ``T``'s value representation shall be identical to its object representation. | ||
|
|
||
| **Preconditions** | ||
|
|
||
| - The functionality is only supported on ``SM >= 70``. | ||
| - ``lane_mask`` must be non-zero. | ||
|
|
||
| **Undefined Behavior** | ||
|
|
||
| - ``lane_mask`` must represent a subset of the active lanes. | ||
| - All non-exited lanes specified by ``lane_mask`` must execute the function with the same ``lane_mask`` value. | ||
|
|
||
| **Performance considerations** | ||
|
|
||
| - The function calls the PTX instruction ``match.sync`` :math:`ceil\left(\frac{sizeof(data)}{4}\right)` times. | ||
| - The function is faster when called with a mask representing all active lanes in a warp (default value of the second parameter ``lane_mask``). | ||
|
|
||
| **References** | ||
|
|
||
| - `CUDA match_any Intrinsics <https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/cpp-language-extensions.html#warp-match-functions>`_ | ||
| - `PTX match.sync instruction <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-match-sync>`_ | ||
|
|
||
| Example | ||
| ------- | ||
|
|
||
| .. code:: cuda | ||
|
|
||
| #include <cuda/std/array> | ||
| #include <cuda/std/cassert> | ||
| #include <cuda/warp> | ||
|
|
||
| struct MyStruct { | ||
| double x; // 8 bytes | ||
| int y; // 4 bytes | ||
| }; // 4 bytes of padding | ||
|
|
||
| __global__ void warp_match_kernel() { | ||
| { | ||
| auto mask = cuda::device::warp_match_any(threadIdx.x / 4); | ||
| auto expected = cuda::device::lane_mask{0b1111 << ((threadIdx.x / 4) * 4)}; | ||
| assert(mask == expected); | ||
| } | ||
| { | ||
| auto mask = cuda::device::warp_match_any(2); | ||
| auto expected = cuda::device::lane_mask{0xFFFFFFFF}; | ||
| assert(mask == expected); | ||
| } | ||
| { | ||
| // compile error, except when __builtin_clear_padding is supported | ||
| auto mask = cuda::device::warp_match_any(MyStruct{1.0, 3}); | ||
| auto expected = cuda::device::lane_mask{0xFFFFFFFF}; | ||
| assert(mask == expected); | ||
| } | ||
| } | ||
|
|
||
| int main() { | ||
| warp_match_kernel<<<1, 32>>>(); | ||
| cudaDeviceSynchronize(); | ||
| return 0; | ||
| } | ||
|
|
||
| `See it on Godbolt 🔗 <https://godbolt.org/z/Ys1McG8nv>`_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // Part of libcu++, the C++ Standard Library for your entire system, | ||
| // under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef _CUDA___WARP_WARP_MATCH_ANY_H | ||
| #define _CUDA___WARP_WARP_MATCH_ANY_H | ||
|
|
||
| #include <cuda/std/detail/__config> | ||
|
|
||
| #if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) | ||
| # pragma GCC system_header | ||
| #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) | ||
| # pragma clang system_header | ||
| #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) | ||
| # pragma system_header | ||
| #endif // no system header | ||
|
|
||
| #if _CCCL_CUDA_COMPILATION() | ||
|
|
||
| # include <cuda/__cmath/ceil_div.h> | ||
| # include <cuda/__type_traits/is_bitwise_comparable.h> | ||
| # include <cuda/__type_traits/is_trivially_copyable.h> | ||
| # include <cuda/__warp/lane_mask.h> | ||
| # include <cuda/std/__cstring/memcpy.h> | ||
| # include <cuda/std/__memory/addressof.h> | ||
| # include <cuda/std/cstdint> | ||
|
|
||
| # include <cuda/std/__cccl/prologue.h> | ||
|
|
||
| _CCCL_BEGIN_NAMESPACE_CUDA_DEVICE | ||
|
|
||
| extern "C" _CCCL_DEVICE void __cuda__match_any_sync_is_not_supported_before_SM_70__(); | ||
|
|
||
| //! @brief Returns the mask of lanes with the same bitwise value as the calling lane. | ||
| //! | ||
| //! @param[in] __data The data to compare across lanes. | ||
| //! @param[in] __lane_mask The mask of participating lanes. | ||
| //! | ||
| //! @return A lane mask containing lanes in `__lane_mask` whose `__data` matches the calling lane's data. | ||
| template <class _Tp> | ||
| [[nodiscard]] _CCCL_DEVICE_API lane_mask | ||
| warp_match_any(const _Tp& __data, const lane_mask __lane_mask = lane_mask::all()) noexcept | ||
| { | ||
| static_assert(is_trivially_copyable_v<_Tp>, "data must be trivially copyable"); | ||
| _CCCL_ASSERT(__lane_mask != lane_mask::none(), "lane_mask must be non-zero"); | ||
|
|
||
| constexpr int __ratio = ::cuda::ceil_div(sizeof(_Tp), sizeof(::cuda::std::uint32_t)); | ||
| ::cuda::std::uint32_t __array[__ratio]{}; | ||
|
|
||
| # if defined(_CCCL_BUILTIN_CLEAR_PADDING) | ||
| auto __data_copy = __data; | ||
| _CCCL_BUILTIN_CLEAR_PADDING(&__data_copy); | ||
| const auto __data_ptr = ::cuda::std::addressof(__data_copy); | ||
| # else // ^^^ _CCCL_BUILTIN_CLEAR_PADDING ^^^ / vvv !_CCCL_BUILTIN_CLEAR_PADDING vvv | ||
| static_assert(is_bitwise_comparable_v<_Tp>, "data must be bitwise comparable"); | ||
| const auto __data_ptr = ::cuda::std::addressof(__data); | ||
| # endif // _CCCL_BUILTIN_CLEAR_PADDING | ||
| ::cuda::std::memcpy(__array, __data_ptr, sizeof(_Tp)); | ||
|
|
||
| lane_mask __ret = __lane_mask; | ||
| _CCCL_PRAGMA_UNROLL_FULL() | ||
| for (int i = 0; i < __ratio; ++i) | ||
| { | ||
| ::cuda::std::uint32_t __match_any_result = 0; | ||
| NV_IF_ELSE_TARGET(NV_PROVIDES_SM_70, | ||
| (__match_any_result = ::__match_any_sync(__lane_mask.value(), __array[i]);), | ||
| (::cuda::device::__cuda__match_any_sync_is_not_supported_before_SM_70__();)); | ||
| __ret &= lane_mask{__match_any_result}; | ||
| } | ||
| return __ret; | ||
| } | ||
|
|
||
| _CCCL_END_NAMESPACE_CUDA_DEVICE | ||
|
|
||
| # include <cuda/std/__cccl/epilogue.h> | ||
|
|
||
| #endif // _CCCL_CUDA_COMPILATION() | ||
| #endif // _CUDA___WARP_WARP_MATCH_ANY_H | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
libcudacxx/test/libcudacxx/cuda/warp/warp_match_any.pass.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // Part of the libcu++ Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // UNSUPPORTED: pre-sm-70 | ||
|
|
||
| // UNSUPPORTED: enable-tile | ||
| // error: asm statement is unsupported in tile code | ||
|
|
||
| #include <cuda/std/array> | ||
| #include <cuda/std/cassert> | ||
| #include <cuda/std/cstdint> | ||
| #include <cuda/warp> | ||
|
|
||
| #include "test_macros.h" | ||
|
|
||
| TEST_DEVICE_FUNC uint32_t make_low_mask(unsigned count) | ||
| { | ||
| return count == 32 ? 0xFFFFFFFF : ((1u << count) - 1); | ||
| } | ||
|
|
||
| TEST_DEVICE_FUNC uint32_t make_stride_mask(unsigned count, unsigned step, unsigned remainder) | ||
| { | ||
| uint32_t mask = 0; | ||
| for (unsigned lane = 0; lane < count; ++lane) | ||
| { | ||
| if ((lane % step) == remainder) | ||
| { | ||
| mask |= uint32_t{1} << lane; | ||
| } | ||
| } | ||
| return mask; | ||
| } | ||
|
|
||
| template <typename T> | ||
| TEST_DEVICE_FUNC void test_all_equal(T value = T{}) | ||
| { | ||
| for (unsigned i = 1; i <= 32; ++i) | ||
| { | ||
| auto mask = cuda::device::lane_mask{make_low_mask(i)}; | ||
| if (threadIdx.x < i) | ||
| { | ||
| assert(cuda::device::warp_match_any(value, mask) == mask); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // two different groups of lanes | ||
| template <typename T> | ||
| TEST_DEVICE_FUNC void test_grouped(T valueA = T{}, T valueB = T{1}) | ||
| { | ||
| for (unsigned i = 2; i <= 32; ++i) | ||
| { | ||
| auto mask = cuda::device::lane_mask{make_low_mask(i)}; | ||
| if (threadIdx.x < i) | ||
| { | ||
| auto value = threadIdx.x % 2 == 0 ? valueA : valueB; | ||
| auto expected = cuda::device::lane_mask{make_stride_mask(i, 2, threadIdx.x % 2)}; | ||
| assert(cuda::device::warp_match_any(value, mask) == expected); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| TEST_DEVICE_FUNC void test() | ||
| { | ||
| using array_t = cuda::std::array<char, 6>; | ||
| test_all_equal<uint8_t>(); | ||
| test_all_equal<uint16_t>(); | ||
| test_all_equal<uint32_t>(); | ||
| test_all_equal<uint64_t>(); | ||
| #if _CCCL_HAS_INT128() | ||
| test_all_equal<__uint128_t>(); | ||
| #endif // _CCCL_HAS_INT128() | ||
| test_all_equal(char3{0, 0, 0}); | ||
| test_all_equal(array_t{0, 0, 0, 0, 0, 0}); | ||
|
|
||
| test_grouped<uint8_t>(); | ||
| test_grouped<uint16_t>(); | ||
| test_grouped<uint32_t>(); | ||
| test_grouped<uint64_t>(); | ||
| #if _CCCL_HAS_INT128() | ||
| test_grouped<__uint128_t>(); | ||
| #endif // _CCCL_HAS_INT128() | ||
| test_grouped(char3{0, 0, 0}, char3{1, 1, 1}); | ||
| test_grouped(array_t{0, 0, 0, 0, 0, 0}, array_t{1, 1, 1, 1, 1, 1}); | ||
| } | ||
|
|
||
| int main(int, char**) | ||
| { | ||
| NV_DISPATCH_TARGET(NV_IS_HOST, (cuda_thread_count = 32;), NV_IS_DEVICE, (test();)) | ||
| return 0; | ||
| } |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Important: This introduces a needless copy. I believe this should only copy if
is_bitwise_comparable_vis falseUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, I'm aware of this issue. However,
is_bitwise_comparable_valso checks padding. I need to introduce another (internal) traits for that. I will open a second PRThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fine by me, although I would like to have an integer overload that does not do any of that and just forwards to the builtin
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it is needed. Direct call and
cuda::device::warp_match_anyproduce identical code as expected https://godbolt.org/z/sKaWz5dv1