diff --git a/cub/cub/detail/warpspeed/look_ahead.cuh b/cub/cub/detail/warpspeed/look_ahead.cuh index 9e91e09acb3..855fadcffde 100644 --- a/cub/cub/detail/warpspeed/look_ahead.cuh +++ b/cub/cub/detail/warpspeed/look_ahead.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -261,6 +262,88 @@ template return aggrExclusiveCtaCur; // must only be valid in lane_0 } +// Deterministic version of warpIncrementalLookahead that returns the same aggrExclusiveCta. The difference is that it +// always starts the lookahead from a tile index that is a multiple of 32: it shifts the left pointer (idxTilePrev) down +// to the nearest multiple of 32 and reduces from there. Because every reduction begins at the same fixed tiles, no +// matter which tiles happened to finish first, the order in which values are summed is always the same and the result +// is identical on every run. idxTilePrev/aggrExclusiveCtaPrev are updated by reference to the last multiple of 32. +template +[[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE AccumT warpIncrementalLookaheadStable( + SpecialRegisters specialRegisters, + tile_state_t* ptrTileStates, + int& idxTilePrev, + AccumT& aggrExclusiveCtaPrev, + const int idxTileNext, + ScanOpT& scan_op) +{ + const int laneIdx = specialRegisters.laneIdx; + const ::cuda::std::uint32_t lanemaskEq = ::cuda::ptx::get_sreg_lanemask_eq(); + + // Adjust the left pointer down to the nearest 32-multiple so we do batched sums + int idxTileCur = (idxTilePrev / 32) * 32; + AccumT aggrExclusiveCtaCur = aggrExclusiveCtaPrev; + + using warp_reduce_t = WarpReduce; + static_assert(sizeof(typename warp_reduce_t::TempStorage) <= 4, + "WarpReduce with non-trivial temporary storage is not supported yet in this kernel."); + [[maybe_unused]] typename warp_reduce_t::TempStorage temp_storage; + + using warp_reduce_or_t = WarpReduce<::cuda::std::uint32_t>; + typename warp_reduce_or_t::TempStorage temp_storage_or; + warp_reduce_or_t warp_reduce_or{temp_storage_or}; 
+ constexpr ::cuda::std::bit_or<::cuda::std::uint32_t> or_op{}; + + while (idxTileCur < idxTileNext) + { + tile_state_t regTmpStates[numTileStatesPerThread]; + warpLoadLookahead(laneIdx, regTmpStates, ptrTileStates, idxTileCur, idxTileNext); + + for (int idx = 0; idx < numTileStatesPerThread; ++idx) + { + // Bitmask with a 1 bit in the position of the current lane if current lane has a tile aggregate + const ::cuda::std::uint32_t lane_has_aggregate = + lanemaskEq * (regTmpStates[idx].state == scan_state::tile_aggregate); + + // Bitmask with 1 bits indicating which lane has a tile aggregate + const ::cuda::std::uint32_t warp_has_aggregate_mask = warp_reduce_or.Reduce(lane_has_aggregate, or_op); + + // Bitmask with 1 bits for all rightmost lanes having a tile aggregate + const ::cuda::std::uint32_t warp_right_aggregates_mask = warp_has_aggregate_mask & (~warp_has_aggregate_mask - 1); + + const ::cuda::std::uint32_t warp_right_aggregates_count = ::cuda::std::popcount(warp_right_aggregates_mask); + + // Only reduce once a fixed number of contiguous tile aggregates are available, so the reduction order is fixed. + const ::cuda::std::uint32_t expected_count = + static_cast<::cuda::std::uint32_t>(::cuda::std::min(32, idxTileNext - idxTileCur)); + if (warp_right_aggregates_count < expected_count) + { + break; + } + + const bool use_value = lanemaskEq & warp_right_aggregates_mask; + const AccumT value = use_value ? regTmpStates[idx].value : cuda::identity_element(); + const AccumT local_aggr = warp_reduce_t{temp_storage}.Reduce(value, scan_op); + + if (expected_count == 32) + { + aggrExclusiveCtaCur = idxTileCur == 0 ? local_aggr : scan_op(aggrExclusiveCtaCur, local_aggr); + idxTileCur += 32; + } + else + { + const AccumT full_aggr = idxTileCur == 0 ? 
local_aggr : scan_op(aggrExclusiveCtaCur, local_aggr); + idxTilePrev = idxTileCur; + aggrExclusiveCtaPrev = aggrExclusiveCtaCur; + return full_aggr; + } + } + } + + idxTilePrev = idxTileNext; + aggrExclusiveCtaPrev = aggrExclusiveCtaCur; + return aggrExclusiveCtaCur; // must only be valid in lane_0 +} + #endif // __cccl_ptx_isa >= 860 } // namespace detail::warpspeed diff --git a/cub/cub/device/dispatch/kernels/kernel_scan.cuh b/cub/cub/device/dispatch/kernels/kernel_scan.cuh index ad959dee9dd..f6a9e91be9b 100644 --- a/cub/cub/device/dispatch/kernels/kernel_scan.cuh +++ b/cub/cub/device/dispatch/kernels/kernel_scan.cuh @@ -205,12 +205,12 @@ __launch_bounds__(device_scan_launch_bounds, 1) _CCCL_KERNEL_ATT if constexpr (active_policy.algorithm == scan_algorithm::warpspeed) { #if _CCCL_CUDACC_AT_LEAST(12, 8) - NV_IF_TARGET( - NV_PROVIDES_SM_100, ({ - auto scan_params = scanKernelParams, it_value_t, AccumT>{ - d_in, d_out, tile_state.warpspeed, num_items, num_stages}; - device_scan_warpspeed_body(scan_params, scan_op, init_value); - })); + NV_IF_TARGET(NV_PROVIDES_SM_100, ({ + auto scan_params = scanKernelParams, it_value_t, AccumT>{ + d_in, d_out, tile_state.warpspeed, num_items, num_stages}; + device_scan_warpspeed_body( + scan_params, scan_op, init_value); + })); #else static_assert(sizeof(d_in) == 0, "Implementation bug: Tuning policy selected warpspeed, but CUDA compiler does not support it"); diff --git a/cub/cub/device/dispatch/kernels/kernel_scan_warpspeed.cuh b/cub/cub/device/dispatch/kernels/kernel_scan_warpspeed.cuh index d1ef2a31ba0..a2b800a40ef 100644 --- a/cub/cub/device/dispatch/kernels/kernel_scan_warpspeed.cuh +++ b/cub/cub/device/dispatch/kernels/kernel_scan_warpspeed.cuh @@ -259,7 +259,8 @@ template + bool ForceInclusive, + bool StableReductionOrder = false> struct warpspeed_scan_closure { static constexpr scan_warpspeed_policy policy = current_policy().warpspeed; @@ -327,14 +328,27 @@ struct warpspeed_scan_closure if (!is_first_tile) { - AccumT 
regAggrExclusiveCta = warpspeed::warpIncrementalLookahead( - specialRegisters, params.ptrTileStates, idxTilePrev, AggrExclusiveCtaPrev, idxTile, scan_op); - if (squad.isLeaderThread()) + if constexpr (StableReductionOrder) + { + // The stable-order version updates idxTilePrev/AggrExclusiveCtaPrev itself + AccumT regAggrExclusiveCta = warpspeed::warpIncrementalLookaheadStable( + specialRegisters, params.ptrTileStates, idxTilePrev, AggrExclusiveCtaPrev, idxTile, scan_op); + if (squad.isLeaderThread()) + { + refAggrExclusiveCtaW.data() = regAggrExclusiveCta; + } + } + else { - refAggrExclusiveCtaW.data() = regAggrExclusiveCta; + AccumT regAggrExclusiveCta = warpspeed::warpIncrementalLookahead( + specialRegisters, params.ptrTileStates, idxTilePrev, AggrExclusiveCtaPrev, idxTile, scan_op); + if (squad.isLeaderThread()) + { + refAggrExclusiveCtaW.data() = regAggrExclusiveCta; + } + AggrExclusiveCtaPrev = regAggrExclusiveCta; + idxTilePrev = idxTile; } - AggrExclusiveCtaPrev = regAggrExclusiveCta; - idxTilePrev = idxTile; } } @@ -825,6 +839,7 @@ struct warpspeed_scan_closure template ; + using closure_t = warpspeed_scan_closure< + PolicySelector, + InputT, + OutputT, + AccumT, + ScanOpT, + RealInitValueT, + ForceInclusive, + StableReductionOrder>; warpspeed::squadDispatch( specialRegisters, closure_t::scanSquads, [&](warpspeed::Squad squad) _CCCL_FORCEINLINE_LAMBDA { // we load the initial value after the squad dispatch, so only the squads needing it emit an LDG diff --git a/cub/cub/device/dispatch/tuning/tuning_scan.cuh b/cub/cub/device/dispatch/tuning/tuning_scan.cuh index 78823b83033..67d35672766 100644 --- a/cub/cub/device/dispatch/tuning/tuning_scan.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_scan.cuh @@ -1035,7 +1035,9 @@ struct policy_selector [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator()(::cuda::compute_capability cc) const -> scan_policy { // we first try to get the valid warpspeed implementation. 
if we can't run it, fall back to the old scan impl. - if (!require_stable_reduction_order) + // For stable reduction order (fp + plus), warpspeed can only be used on sm_100+; older arches fall back to classic + // lookback stable reduction order version below. + if (!require_stable_reduction_order || cc >= ::cuda::compute_capability{10, 0}) { const auto warpspeed_policy_opt = get_warpspeed_policy(cc); if (warpspeed_policy_opt && can_use_warpspeed(cc, *warpspeed_policy_opt))