Add CUB cooperative collectives #9266

Draft
tpn wants to merge 4 commits into NVIDIA:main from tpn:dev/cuda-coop-cub-primitives

tpn commented Jun 5, 2026

Description

closes

Adds CUB cooperative collective adapters with result-placement semantics used by downstream cuda.coop frontends:

  • cub::WarpReduceBroadcast for warp reductions whose aggregate is returned to every logical lane.
  • cub::WarpReduceBatchedBroadcast for batched all-lane warp reductions.
  • cub::BlockReduceBroadcast for block reductions whose aggregate is returned to every thread.
  • cub::BlockRowReduce and cub::BlockRowReduceWarpBroadcast for row-shaped block reductions used by norm-style kernels.
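The "aggregate is returned to every logical lane" contract can be illustrated with a host-side model. This is a minimal sketch, not the PR's implementation: it assumes an XOR-butterfly exchange, and the names `Warp` and `butterfly_allreduce_sum` are hypothetical. Each array slot stands in for one lane's register, and each loop iteration corresponds to what would be a single `__shfl_xor_sync` plus an add on the device.

```cpp
#include <array>
#include <cassert>

// Host-side model of one 32-lane warp: slot i stands in for lane i's register.
// Illustrative only; these names are not part of CUB or of this PR.
constexpr int kWarpSize = 32;
using Warp = std::array<int, kWarpSize>;

// XOR-butterfly all-reduce over logical warps of `width` lanes.
// `width` must be a power of two no larger than 32.
// After log2(width) exchange steps, every lane in a logical warp holds
// that logical warp's sum, so no separate broadcast step is needed.
Warp butterfly_allreduce_sum(Warp lanes, int width)
{
  assert(width > 0 && width <= kWarpSize && (width & (width - 1)) == 0);
  for (int offset = width / 2; offset > 0; offset /= 2)
  {
    Warp next = lanes;
    for (int lane = 0; lane < kWarpSize; ++lane)
    {
      // On a GPU: value += __shfl_xor_sync(mask, value, offset);
      // lane ^ offset stays inside the lane's own aligned group of `width`.
      next[lane] = lanes[lane] + lanes[lane ^ offset];
    }
    lanes = next;
  }
  return lanes;
}
```

With lanes holding 1..32, a width of 32 leaves 528 in every lane, while a width of 4 leaves 10 in lanes 0 to 3, 26 in lanes 4 to 7, and so on. An owner-lane reduction, by contrast, only guarantees the result in one lane per logical warp.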

The implementation now lives in public CUB headers under the cub:: namespace:

  • cub/warp/warp_reduce_broadcast.cuh
  • cub/warp/warp_reduce_batched_broadcast.cuh
  • cub/block/block_reduce_broadcast.cuh
  • cub/block/block_row_reduce.cuh

This also adds focused Catch2 coverage and an nvbench target that compares the new adapters with equivalent handwritten or existing-CUB reduction idioms.

Style/guideline review notes:

  • Followed the repository AGENTS.md / CONTRIBUTING.md guidance: pre-commit-managed formatting, targeted CUB build/test, and local consistency with nearby CUB warp/block collectives.
  • Used CCCL device API macros, existing temporary-storage conventions, CUB namespace macros, SPDX headers, and CUB-style RST/Doxygen comments.

Performance and adoption map

The table below combines the focused C++ primitive benchmarks with the cuda.coop CUTE/Numba wrapper experiments. The main point is that these primitives are not all sold the same way: owner-lane batched reduction is a clear speedup; all-lane and block-broadcast forms are mostly about replacing bespoke shuffle/shared-memory boilerplate with a named primitive while preserving performance.

| Shape / likely consumer | Handwritten or existing baseline | Primitive / downstream spelling | Measured result | Why it matters |
| --- | --- | --- | --- | --- |
| Full-warp all-lane sum | Manual shuffle allreduce + broadcast | cub::WarpReduceBroadcast<T>::Sum(x) | 32.739 us -> 28.829 us, 1.136x | Tangible speedup and removes a hand-written shuffle tree for the common “every lane needs the aggregate” contract. |
| Four independent reductions, owner-lane result | Four serial CUB warp reductions | cub::WarpReduceBatched<T, 4>::Sum(items); Numba wrapper: coop.warp.batched_sum(inputs, threads_in_warp=32) | 96.818 us -> 27.187 us, 3.561x | Strongest perf case. The repeated-kernel SASS count drops from 2560 SHFL instructions to 768. Numba reaches the intended codegen shape too: owner batched has 0 SHFL and 4 REDUX instances in the wrapper experiment. |
| FlashAttention-style width-4 allreduce | Recursive Allreduce<4> / inline shuffle helper | cub::WarpReduceBroadcast<T, 4>::Sum(x) | 14.309 us -> 14.483 us, 0.988x | This is near parity, not a speedup, but it gives a standard CUB spelling for code that projects currently hand-roll with recursive templates. It also fixes the earlier wrong-shape generic coop result that was 0.109x. |
| Four all-lane reductions | Four direct width-4/all-lane reductions | cub::WarpReduceBatchedBroadcast<T, 4, Width>::Sum(items); Numba wrapper: coop.warp.batched_sum_broadcast(inputs, outputs, ...); CUTE wrapper shape: coop.batched_sum(items) | 86.668 us -> 88.848 us, 0.975x | Not the algorithmic speedup case. The sell is code shape: one primitive call instead of four handwritten allreduces, with roughly parity C++ performance. |
| Block reduction where every thread needs the result | CUB BlockReduce, shared scalar, if (threadIdx.x == 0), two barriers | cub::BlockReduceBroadcast<T, BlockDimX>::Sum(x); CUTE-facing shape: coop.sum(value) | 24.553 us -> 24.555 us, 1.000x in the isolated C++ bench; vLLM-like RMSNorm is parity for large H and 0.964x at H=4096 | Replaces fragile shared-memory broadcast boilerplate with one semantic primitive. For RMSNorm, global memory traffic dominates at larger hidden sizes, so this is mainly maintainability unless fused kernels reduce memory traffic. |
| FlashInfer CUTE RMSNorm anchor | FlashInfer CUTE utilities: warp_reduce, block_reduce, row_reduce_sum | CUTE-facing coop.sum(...) / row-reduction provider path | 28.56 us -> 29.07 us, 0.982x | Good code-reduction target: FlashInfer carries roughly 50+ lines of CUTE reduction utility code before RMSNorm logic. Current CUTE provider path is still slightly slower even after the provider boundary is inlined away. |
| CUTE wrapper microbench | Direct inline CUTE shuffle tree | cuda.coop.cutlass.cute.sum / cuda.coop.cutlass.cute.batched_sum | Scalar: 3.520 us -> 5.440 us, 0.647x. Batched x4: 3.968 us -> 11.936 us, 0.332x on the stale non-inlined provider path | This was the old blocker. The current DKG LTO-IR caller fix removes the provider call boundary, but CUTE still needs richer return/value plumbing before batched wrappers can reliably inherit the C++ primitive quality. |
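The drop from 2560 to 768 SHFL instructions reported for the owner-lane batched case is a ratio of 20 to 6 per reduction group. One scheme that reproduces that ratio is sketched below as a host-side model. This is an assumption about how such a count can arise, not a description of what `cub::WarpReduceBatched` actually does; the names `WarpItems`, `BatchedResult`, and `batched_sum_model` are hypothetical. The idea is that while a lane still holds several partial values, each exchange step trades half of them with its partner, so the live register count halves along with the shuffle cost; once one value remains, an ordinary butterfly finishes the job.

```cpp
#include <array>
#include <cassert>

// Host-side model; illustrative names, not CUB API.
constexpr int kWarpSize = 32;
constexpr int kBatch    = 4; // independent reductions per lane
static_assert((kBatch & (kBatch - 1)) == 0 && kBatch <= kWarpSize, "power of two");

using WarpItems = std::array<std::array<int, kBatch>, kWarpSize>;

struct BatchedResult
{
  std::array<int, kBatch> sums; // total of each item, read from a lane that owns it
  int shuffles_per_lane;        // modelled register exchanges per lane
};

// Baseline: kBatch separate reductions, log2(32) = 5 exchange steps each.
int serial_shuffles_per_lane()
{
  int steps = 0;
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) { ++steps; }
  return kBatch * steps;
}

BatchedResult batched_sum_model(const WarpItems& input)
{
  WarpItems vals = input;
  std::array<int, kWarpSize> first_item{}; // item index held in slot 0 of each lane
  int live = kBatch, offset = kWarpSize / 2, shuffles = 0;

  // Phase 1: keep one half of the live values, trade the other half.
  while (live > 1)
  {
    const int half = live / 2;
    WarpItems next = vals;
    for (int lane = 0; lane < kWarpSize; ++lane)
    {
      const int keep = (lane & offset) ? half : 0; // which half this lane keeps
      for (int i = 0; i < half; ++i)
      {
        next[lane][i] = vals[lane][keep + i] + vals[lane ^ offset][keep + i];
      }
      first_item[lane] += keep;
    }
    vals = next;
    shuffles += half; // one exchange per value traded
    live = half;
    offset /= 2;
  }

  // Phase 2: one live value per lane, plain butterfly over remaining offsets.
  for (; offset > 0; offset /= 2)
  {
    WarpItems next = vals;
    for (int lane = 0; lane < kWarpSize; ++lane)
    {
      next[lane][0] = vals[lane][0] + vals[lane ^ offset][0];
    }
    vals = next;
    ++shuffles;
  }

  BatchedResult result{};
  result.shuffles_per_lane = shuffles;
  for (int lane = 0; lane < kWarpSize; ++lane)
  {
    assert(first_item[lane] >= 0 && first_item[lane] < kBatch);
    result.sums[first_item[lane]] = vals[lane][0];
  }
  return result;
}
```

For four items on a 32-lane warp the model spends 2 + 1 + 3 = 6 exchanges per lane against 4 x 5 = 20 for the serial baseline, which is consistent with the 2560 -> 768 count if the benchmark kernel contains 128 such reduction groups. In this model each total ends up in a disjoint set of lanes (items 0 to 3 in the groups starting at lanes 0, 8, 16, and 24), which is why it yields an owner-lane result rather than an all-lane one.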

Takeaway: the C++ primitive layer already has one clear speedup case and several parity-with-less-boilerplate cases. The downstream Python story is credible as an API simplification path today, but CUTE still needs richer batched return plumbing before we should promise broad end-to-end speedups there.

Validation

  • pre-commit run --files cub/cub/warp/warp_reduce_broadcast.cuh cub/cub/warp/warp_reduce_batched_broadcast.cuh cub/cub/block/block_reduce_broadcast.cuh cub/cub/block/block_row_reduce.cuh cub/test/catch2_test_coop_collectives.cu cub/benchmarks/bench/collectives/coop_collectives.cu cub/cub/cub.cuh
  • git diff --check
  • ninja -C build/cub-cpp20 cub.test.coop_collectives
  • ctest --test-dir build/cub-cpp20 -R '^cub.test.coop_collectives$' --output-on-failure
  • ninja -C build/cub-benchmark cub.bench.collectives.coop_collectives.base

Checklist

  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Draft note: the new public headers include API comments and test snippets, but the CUB docs index/API page still needs a public-facing docs pass before this should be marked ready for review.

tpn added 3 commits May 30, 2026 20:03
Add experimental adapters and examples for warp and block collective result-placement semantics that are useful to cuda.coop frontends. The examples cover broadcast warp reduction, batched warp reduction, broadcast block reduction, row/segmented reductions, and scan/broadcast composition.

Co-Authored-By: GPT-5.5 xhigh, Codex v0.130.0
Signed-off-by: Trent Nelson <trent@trent.me>

Add NVBench coverage for the experimental cooperative collective adapters and the direct CUB patterns they replace. The benchmark compares broadcast warp reduction, batched warp reduction, broadcast block reduction, row reduction, and scan/broadcast composition.

Co-Authored-By: GPT-5.5 xhigh, Codex v0.130.0
Signed-off-by: Trent Nelson <trent@trent.me>

Add a shuffle allreduce fast path for warp reduce broadcast sums, covering both full warp and tiny logical-warp reductions without routing through owner-lane storage.

Add a batched all-lane broadcast adapter for commutative reductions and benchmark it against serial batched reductions and the owner-lane WarpReduceBatched primitive.

Co-Authored-By: GPT-5.5 xhigh, Codex v0.130.0
Signed-off-by: Trent Nelson <trent@trent.me>

copy-pr-bot Bot commented Jun 5, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

cccl-authenticator-app Bot moved this from Todo to In Progress in CCCL on Jun 5, 2026
Signed-off-by: Trent Nelson <trent@trent.me>
tpn changed the title from "Add experimental CUB cooperative collectives" to "Add CUB cooperative collectives" on Jun 5, 2026