[KDA] Add intra-card CP for chunk_delta_h forward in SM100#70
Code Review
This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operation, adding specialized kernels for pre-scan and merge stages along with an auto-dispatch mechanism for long sequences. Feedback focuses on optimizing the pre-scan kernel by gating the key gate (gk) loading and decay logic when unused, which would prevent potential out-of-memory errors from large zero-tensor allocations and reduce memory traffic.
@gemini-code-assist review
Code Review
This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operator, featuring a three-stage pipeline (pre-scan, merge, and forward) to optimize long sequence processing. It includes specialized CuTeDSL kernels for Blackwell architectures, an auto-dispatch heuristic, and comprehensive benchmarks and tests. Reviewers identified a potential cache key collision in multi-GPU setups and recommended using torch.cuda.current_device() for safer indexing. Additionally, a function-level import of the math module should be moved to the top of the file to adhere to PEP 8 standards.
Code Review
This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operator, introducing a three-stage pipeline (pre-scan, merge, and forward H) to parallelize the processing of long sequences. The implementation includes specialized CuTeDSL kernels for fused pre-scan and prefix-scan merge operations, along with an auto-dispatch heuristic that triggers CP based on sequence length and SM occupancy. A new benchmark script and comprehensive accuracy tests are also provided. The reviewer noted a potential performance bottleneck where the auto-dispatch logic might trigger a synchronous GPU-to-CPU transfer if cu_seqlens_cpu is not provided, suggesting that this behavior should be documented or the heuristic refined to avoid unnecessary synchronization.
@gemini-code-assist review
Code Review
This pull request introduces an intra-card context-parallel (CP) implementation for the chunk_delta_h operator, designed to accelerate the processing of long sequences on a single GPU. Key additions include a fused pre-scan kernel optimized for Blackwell architectures, a prefix-scan merge kernel, and an auto-dispatch heuristic that dynamically engages the CP path. The PR also includes a new benchmark suite and comprehensive accuracy tests. Reviewer feedback identifies opportunities to simplify the CP predicate API by removing unused parameters, optimize performance by moving local imports to the top level, and enhance robustness by replacing uninitialized dummy tensors with zero-initialized ones.
@gemini-code-assist review
Code Review
This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operation, optimizing long sequence processing through a three-stage pipeline consisting of pre-scan, prefix-scan merge, and forward H stages. The implementation leverages CuTeDSL for high-performance kernels and includes an auto-dispatch heuristic to balance CP overhead against potential speedups. Additionally, the PR adds a benchmark suite and extensive tests to ensure numerical consistency with the non-CP baseline and PyTorch references. Reviewer feedback highlights opportunities to improve CPU efficiency by optimizing tensor construction in _prepare_chunk_indices and streamlining the metadata caching logic in the merge step.
Replace _get_meta_tensors LRU cache with direct torch.tensor calls. The secondary cache required converting lists to tuples for hashing on every merge_fwd call. Since the metadata lists are tiny (1-4 int32 values), the hashing overhead outweighs any benefit. Each torch.tensor call produces an independently aligned allocation required by from_dlpack(assumed_align=16). A single merged tensor with slice views would not guarantee per-slice 16-byte alignment.
@gemini-code-assist review
Code Review
This pull request introduces Intra-Card Context Parallel (CP) for the chunk_delta_h operation, which improves performance for long sequences on a single GPU by parallelizing across the sequence dimension. The implementation includes new CuTeDSL kernels for pre-scan and merge stages, an auto-dispatch heuristic, and comprehensive benchmarking and testing tools. A review comment correctly identified that the global _intracard_cache is not thread-safe and recommended implementing a threading.Lock to prevent race conditions when the function is called concurrently.
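The lock recommendation in this review could look roughly like the sketch below. It is a generic pattern, not the PR's code: the names `_cache`, `_cache_lock`, and `get_or_build` are hypothetical, and the real `_intracard_cache` may use different keys and values.

```python
import threading

# Hypothetical stand-ins for a module-level cache and its guard.
_cache = {}
_cache_lock = threading.Lock()


def get_or_build(key, builder):
    """Return the cached value for `key`, building it at most once.

    The lock is held while `builder` runs, so concurrent callers with the
    same key cannot both build and overwrite each other's entry.
    """
    with _cache_lock:
        try:
            return _cache[key]
        except KeyError:
            value = _cache[key] = builder()
            return value
```

Holding the lock during the build keeps the sketch simple at the cost of serializing all builds; a per-key lock would reduce contention if building is expensive.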
* upgrade fla and update b200 bench, update readme and fix lightning test param
* update h200 bench result with fla bug fixed
* update b200 bench
* update b200 bench
* fix readme
* remove useless repeat_interleave for fla
* fix readme

---------

Co-authored-by: boyu.zbw <boyu.zbw@antgroup.com>
1. compute_subseq_len: remove power-of-2 snap, use floor division for target_splits. This ensures ceil(seq_chunks / target_splits) <= target_splits so Guard 3 in intracard_fwd_h no longer fires spuriously for the common single-long-seq case.
2. should_use_intracard_cp: add Guard 3 (expected sub-seq length check). CP merge work scales with H; require expected_subseq_c >= 12*H chunks. Add MIN_SUBSEQ_CHUNKS_PER_HEAD=12 constant. Restructure Guard 2 from an inline return to an explicit if so Guard 3 can follow. This fixes worst-case degradation (28K+4K H=8: 0.87x → 1.00x). All previously-degraded configs now correctly bypass CP.
3. tests: update ACCURACY_CONFIGS, FINAL_STATE_CONFIGS, STRESS and h0_none_equiv to use 65536 (instead of 32768) as the long sequence so they still exercise the CP kernel path under the new Guard 3 threshold.

29/29 tests pass. Worst triggered speedup: 1.11x. Max bypass overhead: 1.016x.
resolve file conflicts and refactor the benchmark according to the main branch like
Okay. Aside from resolving the conflicts and aligning the benchmark with bench_kda.py, I'll also fix the remaining issues discussed earlier.
@gemini-code-assist review
Code Review
This pull request introduces an intra-card context parallel (CP) path for chunk_gated_delta_rule_fwd_h to unlock sequence-dimension parallelism on a single GPU. It includes a three-stage pipeline (pre-scan, prefix-scan merge, and forward H) along with auto-dispatch heuristics, benchmarks, tests, and documentation. The review comments provide highly actionable feedback, pointing out performance improvements by avoiding .item() inside loops, identifying an unused parameter in _precompute_intracard_indices, suggesting code simplifications in the benchmark suite, and recommending a NotImplementedError for the unimplemented SM90 pre-scan path to prevent a ModuleNotFoundError.
…otImplementedError for SM90 pre_scan
📌 Description
The serial bottleneck
chunk_gated_delta_rule_fwd_h runs a strictly sequential chunk recurrence inside each sequence
(h_t = decay_t · h_{t-1} + k_t^T @ v_t), so within one sequence the work cannot parallelize across
chunks, only across the (NUM_V_BLOCKS, H × num_seqs) grid.

This becomes a bottleneck when both of the following hold:
- H × num_seqs is small: the baseline grid under-utilizes the SMs. A single long sequence at H=4 occupies only 2 × 4 × 1 = 8 SM "units" on a 152-SM B200.
- … dominates wall time, while short seqs finish early and leave SMs idle waiting on the one long chain.
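The occupancy arithmetic above can be reproduced with a few lines. This is only an illustration: the function names are hypothetical, the SM count is the figure quoted in this description, and "one block per SM" is a simplifying assumption rather than a statement about the actual kernel's scheduling.

```python
def baseline_grid_blocks(num_heads: int, num_seqs: int, num_v_blocks: int = 2) -> int:
    """Blocks launched by the baseline (NUM_V_BLOCKS, H x num_seqs) grid."""
    return num_v_blocks * num_heads * num_seqs


def occupancy_fraction(num_heads: int, num_seqs: int, num_sms: int,
                       num_v_blocks: int = 2) -> float:
    """Fraction of SMs that get a block, assuming one block per SM."""
    blocks = baseline_grid_blocks(num_heads, num_seqs, num_v_blocks)
    return min(1.0, blocks / num_sms)


# Single long sequence at H=4 on a card with 152 SMs (figure from the text).
print(baseline_grid_blocks(num_heads=4, num_seqs=1))       # 8
print(f"{occupancy_fraction(4, 1, num_sms=152):.1%}")      # 5.3%
```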
Approach
Inspired by FLA’s intra‑card CP design (README), this PR splits long sequences into sub‑sequences on
the same card and parallelizes the recurrence across them via a 3‑stage pipeline:
1. Pre-scan: … (he, m): exit state [K, V] + decay matrix [K, K].
2. Merge: … non-first sub-seqs.
3. Forward H: … chunk_gated_delta_rule_fwd_h on the split sub-seqs with the merged init states.
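To see why a three-stage split can reproduce the sequential result, here is a scalar toy version of the idea. It is a sketch under strong assumptions: the state is a single number instead of a [K, V] matrix, the per-step update is the affine map h ← a·h + b, and all function names are hypothetical. The real kernels work on matrices, where the summary m is a [K, K] matrix applied by matrix product.

```python
def run_chain(steps, h0):
    """Sequential recurrence h_t = a_t * h_{t-1} + b_t over (a_t, b_t) pairs."""
    h = h0
    for a, b in steps:
        h = a * h + b
    return h


def pre_scan(steps):
    """Summarize a sub-sequence as (he, m): exit state from zero init, total decay."""
    he, m = 0, 1
    for a, b in steps:
        he = a * he + b
        m = a * m
    return he, m


def merge(summaries, h0):
    """Prefix scan over summaries, giving the init state of every sub-sequence."""
    inits = [h0]
    for he, m in summaries[:-1]:
        inits.append(m * inits[-1] + he)
    return inits


def split_forward(steps, h0, num_splits):
    """Pre-scan, merge, then run each sub-sequence from its merged init state."""
    size = -(-len(steps) // num_splits)  # ceil division
    subseqs = [steps[i:i + size] for i in range(0, len(steps), size)]
    summaries = [pre_scan(s) for s in subseqs]   # stage 1: independent per sub-seq
    inits = merge(summaries, h0)                 # stage 2: short serial scan
    exits = [run_chain(s, h) for s, h in zip(subseqs, inits)]  # stage 3: independent
    return exits[-1]


steps = [(2, 1), (3, 2), (1, 5), (2, 0), (3, 1)]  # integers keep the comparison exact
assert split_forward(steps, h0=1, num_splits=2) == run_chain(steps, h0=1)
```

Stages 1 and 3 have no dependencies between sub-sequences, so only the merge in stage 2 remains serial, and it runs over one summary per sub-sequence rather than one step per chunk.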
Activation strategy
Auto‑dispatch is controlled by a set of entry conditions and guards so that workloads that don’t
benefit pay essentially zero overhead.
Entry conditions (in chunk_gated_delta_rule_fwd_h):
- CULA_INTRACARD_CP is set to a value other than "0".
- _no_cp=False (recursive calls pass _no_cp=True to prevent re-entry).
- cu_seqlens is not None (varlen mode).
- g is None (scalar gate disabled; gk key-gate is fine).
- torch.is_inference_mode_enabled().

Pre-split guards (in should_use_intracard_cp):
- … (2·H·num_seqs ≥ SM).
- … (max_chunks < 256 ↔ < 16K tokens).
- … (Be·H > 10, where Be = Σchunks / max(chunks) is a length-weighted effective batch size: Be → 1 means one dominant seq, Be → num_seqs means balanced).
- … (expected_subseq_chunks < 12·H).

Post-split guard (in intracard_fwd_h, after prepare_subseq_cu_seqlens):
- total_subseqs · NUM_V_BLOCKS · H > SM → fall back to baseline.

In short, CP is enabled when H × num_seqs is small and Be is close to 1, i.e., when the batch contains a long-tail sequence that would otherwise serialize the card.
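The guard conditions can be read as a chain of early exits. The sketch below is not the PR's should_use_intracard_cp: the function names and signatures are hypothetical, expected_subseq_chunks is taken as an input because its computation is not shown here, and the literal 2 in the first guard is assumed to be NUM_V_BLOCKS. The constants use the values quoted in this description.

```python
MIN_LONG_SEQ_CHUNKS = 256
MAX_BE_H = 10
NUM_V_BLOCKS = 2
MIN_SUBSEQ_CHUNKS_PER_HEAD = 12


def effective_batch(seq_chunks):
    """Be = sum(chunks) / max(chunks): 1 for one dominant seq, num_seqs if balanced."""
    return sum(seq_chunks) / max(seq_chunks)


def pre_split_guards_pass(seq_chunks, num_heads, num_sms, expected_subseq_chunks):
    """Return True only if none of the four bypass conditions fires."""
    if NUM_V_BLOCKS * num_heads * len(seq_chunks) >= num_sms:
        return False  # baseline grid already covers the SMs
    if max(seq_chunks) < MIN_LONG_SEQ_CHUNKS:
        return False  # no sequence is long enough
    if effective_batch(seq_chunks) * num_heads > MAX_BE_H:
        return False  # batch is too balanced for the given head count
    if expected_subseq_chunks < MIN_SUBSEQ_CHUNKS_PER_HEAD * num_heads:
        return False  # sub-sequences would be too short relative to H
    return True


def post_split_guard_pass(total_subseqs, num_heads, num_sms):
    """After splitting, fall back to baseline if the new grid exceeds the SM count."""
    return total_subseqs * NUM_V_BLOCKS * num_heads <= num_sms


# One 1024-chunk sequence, H=4: Be = 1, so every pre-split guard passes.
print(pre_split_guards_pass([1024], num_heads=4, num_sms=152, expected_subseq_chunks=64))
# Eight equal sequences, H=4: Be = 8 and Be*H = 32 > 10, so CP is bypassed.
print(pre_split_guards_pass([1024] * 8, num_heads=4, num_sms=152, expected_subseq_chunks=64))
```

With 256 chunks described as equivalent to 16K tokens, a chunk holds 64 tokens, so the 1024-chunk example corresponds to a 65536-token sequence.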
The thresholds (MIN_SUBSEQ_CHUNKS=16, MIN_LONG_SEQ_CHUNKS=256, MAX_BE_H=10, NUM_V_BLOCKS=2, MIN_SUBSEQ_CHUNKS_PER_HEAD=12) are manually tuned on B200 SM100 based on bench sweeps and are isolated at the top of cula/ops/cp/chunk_delta_h.py.

🔍 Related Issues
Closes #20
🧪 Tests
⚡ Performance