[KDA] Add intra-card CP for chunk_delta_h forward in SM100 #70

Open
cherhh wants to merge 18 commits into inclusionAI:main from cherhh:dev.kcp.v2

Collaborator

@cherhh cherhh commented May 17, 2026

📌 Description

The serial bottleneck

chunk_gated_delta_rule_fwd_h runs a strictly sequential chunk recurrence inside each sequence
(h_t = decay_t · h_{t-1} + k_t^T @ v_t), so within one sequence the work cannot parallelize across
chunks, only across the (NUM_V_BLOCKS, H × num_seqs) grid.

This becomes a bottleneck when both of the following hold:

  • H × num_seqs is small — the baseline grid under‑utilizes the SMs. A single long sequence at H=4
    occupies only 2 × 4 × 1 = 8 SM “units” on a 152‑SM B200.
  • The varlen batch is highly skewed with a long‑tail sequence — the long seq’s serial recurrence
    dominates wall time, while short seqs finish early and leave SMs idle waiting on the one long chain.
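The occupancy arithmetic above can be reproduced with a minimal sketch. The function name is hypothetical; NUM_V_BLOCKS=2 and the 152-SM figure are the values quoted in this description.

```python
def baseline_grid_blocks(num_heads: int, num_seqs: int, num_v_blocks: int = 2) -> int:
    """Thread blocks launched by the baseline (NUM_V_BLOCKS, H * num_seqs) grid."""
    return num_v_blocks * num_heads * num_seqs


SM_COUNT = 152  # SM count quoted in this description for B200

# Single long sequence at H=4: the case called out in the first bullet.
blocks = baseline_grid_blocks(num_heads=4, num_seqs=1)
occupancy = blocks / SM_COUNT
print(f"{blocks} blocks on {SM_COUNT} SMs ({occupancy:.1%} of the SMs busy)")
```

Each block then walks its sequence's chunks serially, so adding SMs does not shorten the longest chain.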

Approach

Inspired by FLA’s intra‑card CP design (README), this PR splits long sequences into sub‑sequences on
the same card and parallelizes the recurrence across them via a 3‑stage pipeline:

  1. Pre‑scan — per sub‑seq, compute packed (he, m): exit state [K, V] + decay matrix [K, K].
  2. Merge — prefix‑scan across sub‑seqs of the same original seq to produce init states for
    non‑first sub‑seqs.
  3. Forward H — run the existing chunk_gated_delta_rule_fwd_h on the split sub‑seqs with the
    merged init states.
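The three stages can be illustrated with a toy model. This is a sketch under stated assumptions, not the PR's kernels: each chunk is modelled as a generic affine update h ← A·h + B, with A standing in for the [K, K] transition and B for the [K, V] contribution, and all helper names are hypothetical. It only checks that pre-scan plus merge reproduces the sequential result.

```python
import random


def matmul(a, b):
    """Plain-Python product of a [n, p] matrix and a [p, m] matrix."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def matadd(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def zeros(rows, cols):
    return [[0.0] * cols for _ in range(rows)]


def identity(n):
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def rand_matrix(rows, cols):
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def forward_h(chunks, h):
    """Stage 3 (and the baseline): sequential recurrence h <- A @ h + B."""
    for A, B in chunks:
        h = matadd(matmul(A, h), B)
    return h


def pre_scan(chunks, K, V):
    """Stage 1: exit state from a zero init (he) and accumulated transition (m)."""
    he, m = zeros(K, V), identity(K)
    for A, B in chunks:
        he = matadd(matmul(A, he), B)
        m = matmul(A, m)
    return he, m


def merge(packed, h0):
    """Stage 2: prefix scan, init[j+1] = m[j] @ init[j] + he[j]."""
    inits = [h0]
    for he, m in packed[:-1]:
        inits.append(matadd(matmul(m, inits[-1]), he))
    return inits


random.seed(0)
K, V, NUM_CHUNKS, SUBSEQ_CHUNKS = 3, 2, 12, 4
chunks = [(rand_matrix(K, K), rand_matrix(K, V)) for _ in range(NUM_CHUNKS)]
h0 = rand_matrix(K, V)

reference = forward_h(chunks, h0)  # one serial chain over all chunks

subseqs = [chunks[i:i + SUBSEQ_CHUNKS] for i in range(0, NUM_CHUNKS, SUBSEQ_CHUNKS)]
packed = [pre_scan(s, K, V) for s in subseqs]                      # independent per sub-seq
inits = merge(packed, h0)                                          # short serial scan
finals = [forward_h(s, init) for s, init in zip(subseqs, inits)]   # independent per sub-seq

max_diff = max(abs(x - y)
               for rx, ry in zip(finals[-1], reference)
               for x, y in zip(rx, ry))
print(f"max |CP - sequential| = {max_diff:.2e}")
```

In this model, stages 1 and 3 have no dependency between sub-sequences, so only the merge loop remains serial, and it runs over sub-sequences rather than chunks.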

Activation strategy

Auto‑dispatch is controlled by a set of entry conditions and guards so that workloads that don’t
benefit pay essentially zero overhead.

Entry conditions (in chunk_gated_delta_rule_fwd_h):

  • CULA_INTRACARD_CP is set to a value other than "0".
  • _no_cp=False (recursive calls pass _no_cp=True to prevent re‑entry).
  • cu_seqlens is not None (varlen mode).
  • g is None (scalar gate disabled; gk key‑gate is fine).
  • torch.is_inference_mode_enabled().
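The entry conditions above amount to a single conjunction. A minimal sketch follows; the function and parameter names are hypothetical, the inference-mode flag is passed in as a bool to keep the example free of a torch dependency, and treating an unset CULA_INTRACARD_CP as disabled is an assumption.

```python
import os


def cp_entry_allowed(*, no_cp, has_cu_seqlens, has_scalar_gate, inference_mode, env=None):
    """Return True only when every entry condition listed above holds."""
    env = os.environ if env is None else env
    # Assumption: an unset variable behaves like "0" (CP disabled).
    enabled = env.get("CULA_INTRACARD_CP", "0") != "0"
    return (
        enabled
        and not no_cp            # recursive calls pass _no_cp=True
        and has_cu_seqlens       # varlen mode only
        and not has_scalar_gate  # g must be None; gk is allowed
        and inference_mode       # e.g. torch.is_inference_mode_enabled()
    )


flags = dict(no_cp=False, has_cu_seqlens=True, has_scalar_gate=False, inference_mode=True)
print(cp_entry_allowed(**flags, env={"CULA_INTRACARD_CP": "1"}))
```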

Pre‑split guards (in should_use_intracard_cp):

  • Guard 0: baseline already saturates SMs (2·H·num_seqs ≥ SM).
  • Guard 1: longest seq too short to amortize CP (max_chunks < 256, i.e. < 16K tokens).
  • Guard 2: existing parallelism already high (Be·H > 10, where Be = Σchunks / max(chunks) is
    a length‑weighted effective batch size — Be → 1 means one dominant seq, Be → num_seqs means
    balanced).
  • Guard 3: expected sub‑seq too short to amortize the merge work, which scales with H
    (expected_subseq_chunks < 12·H).

Post‑split guard (in intracard_fwd_h, after prepare_subseq_cu_seqlens):

  • Guard 4: total_subseqs · NUM_V_BLOCKS · H > SM → fall back to baseline.
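The guards above can be sketched as plain Python. This is not the code in should_use_intracard_cp: the function names are hypothetical, the 64-token chunk size is inferred from "256 chunks, 16K tokens", and the expected sub-seq length (longest sequence divided over SM // (NUM_V_BLOCKS·H) splits) is an assumption chosen because it matches the 19 and 9 sub-seq counts in the benchmark table of this description.

```python
NUM_V_BLOCKS = 2
MIN_LONG_SEQ_CHUNKS = 256
MAX_BE_H = 10
MIN_SUBSEQ_CHUNKS_PER_HEAD = 12
CHUNK_SIZE = 64  # assumption: 256 chunks correspond to 16K tokens


def ceil_div(a, b):
    return -(-a // b)


def should_use_cp_sketch(seq_lens, num_heads, sm_count=152):
    """Pre-split guards 0-3; returns True when CP would be attempted."""
    chunks = [ceil_div(n, CHUNK_SIZE) for n in seq_lens]
    max_chunks = max(chunks)
    # Guard 0: the baseline grid already saturates the SMs.
    if NUM_V_BLOCKS * num_heads * len(seq_lens) >= sm_count:
        return False
    # Guard 1: the longest sequence is too short to amortize CP.
    if max_chunks < MIN_LONG_SEQ_CHUNKS:
        return False
    # Guard 2: effective batch size Be = sum(chunks) / max(chunks).
    be = sum(chunks) / max_chunks
    if be * num_heads > MAX_BE_H:
        return False
    # Guard 3 (assumed estimate of the expected sub-seq length).
    target_splits = max(1, sm_count // (NUM_V_BLOCKS * num_heads))
    expected_subseq_chunks = ceil_div(max_chunks, target_splits)
    if expected_subseq_chunks < MIN_SUBSEQ_CHUNKS_PER_HEAD * num_heads:
        return False
    return True


def post_split_ok(total_subseqs, num_heads, sm_count=152):
    """Guard 4: fall back to the baseline when the split grid exceeds the SM count."""
    return total_subseqs * NUM_V_BLOCKS * num_heads <= sm_count


for name, lens, heads in [("T=32K", [32768], 4), ("T=64K", [65536], 4),
                          ("128K+10x1K", [131072] + [1024] * 10, 8)]:
    print(name, f"H={heads}", should_use_cp_sketch(lens, heads))
```

Under these assumptions the sketch agrees with the pred column for the benchmark rows spot-checked here; the real predicate may compute the expected sub-seq length differently.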

In short, CP is enabled when H × num_seqs is small and Be is close to 1 — i.e., when the batch
contains a long‑tail sequence that would otherwise serialize the card.

The thresholds (MIN_SUBSEQ_CHUNKS=16, MIN_LONG_SEQ_CHUNKS=256, MAX_BE_H=10,
NUM_V_BLOCKS=2, MIN_SUBSEQ_CHUNKS_PER_HEAD=12) are manually tuned on B200 SM100 based on bench
sweeps and are isolated at the top of cula/ops/cp/chunk_delta_h.py.

🔍 Related Issues

Closes #20

🧪 Tests

python -m pytest tests/test_intracard_cp.py -v
Running 29 items in this shard: tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens0-4-False], tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens1-4-True], tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens2-4-True], tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens3-8-True], tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens4-4-True], tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens5-4-False], tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens6-4-True], tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens7-8-True], tests/test_intracard_cp.py::test_cp_autodispatch_with_h0[seq_lens0-4], tests/test_intracard_cp.py::test_cp_autodispatch_with_h0[seq_lens1-4], tests/test_intracard_cp.py::test_cp_autodispatch_vs_fla[32768-4], tests/test_intracard_cp.py::test_cp_autodispatch_vs_fla[65536-4], tests/test_intracard_cp.py::test_cp_autodispatch_vs_fla[32768-8], tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens0-4-False-False], tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens1-4-True-True], tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens2-4-True-True], tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens3-4-True-False], tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens4-4-False-True], tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens5-4-True-True], tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens6-4-True-False], tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens7-8-True-True], tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens0-4-False-False], tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens1-4-True-True], 
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens2-8-True-True], tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens3-4-True-True], tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens4-4-True-False], tests/test_intracard_cp.py::test_intracard_cp_stress_repeat[single-64K-H4-gk-h0], tests/test_intracard_cp.py::test_intracard_cp_stress_repeat[multi-64K+4K-H4-gk-h0], tests/test_intracard_cp.py::test_intracard_cp_h0_none_equiv_h0_zeros

tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens0-4-False] PASSED                                                                                   [  3%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens1-4-True] PASSED                                                                                    [  6%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens2-4-True] PASSED                                                                                    [ 10%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens3-8-True] PASSED                                                                                    [ 13%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens4-4-True] PASSED                                                                                    [ 17%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens5-4-False] PASSED                                                                                   [ 20%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens6-4-True] PASSED                                                                                    [ 24%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens7-8-True] PASSED                                                                                    [ 27%]
tests/test_intracard_cp.py::test_cp_autodispatch_with_h0[seq_lens0-4] PASSED                                                                                                  [ 31%]
tests/test_intracard_cp.py::test_cp_autodispatch_with_h0[seq_lens1-4] PASSED                                                                                                  [ 34%]
tests/test_intracard_cp.py::test_cp_autodispatch_vs_fla[32768-4] PASSED                                                                                                       [ 37%]
tests/test_intracard_cp.py::test_cp_autodispatch_vs_fla[65536-4] PASSED                                                                                                       [ 41%]
tests/test_intracard_cp.py::test_cp_autodispatch_vs_fla[32768-8] PASSED                                                                                                       [ 44%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens0-4-False-False] PASSED                                                                                  [ 48%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens1-4-True-True] PASSED                                                                                    [ 51%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens2-4-True-True] PASSED                                                                                    [ 55%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens3-4-True-False] PASSED                                                                                   [ 58%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens4-4-False-True] PASSED                                                                                   [ 62%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens5-4-True-True] PASSED                                                                                    [ 65%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens6-4-True-False] PASSED                                                                                   [ 68%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens7-8-True-True] PASSED                                                                                    [ 72%]
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens0-4-False-False] PASSED                                                                             [ 75%]
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens1-4-True-True] PASSED                                                                               [ 79%]
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens2-8-True-True] PASSED                                                                               [ 82%]
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens3-4-True-True] PASSED                                                                               [ 86%]
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens4-4-True-False] PASSED                                                                              [ 89%]
tests/test_intracard_cp.py::test_intracard_cp_stress_repeat[single-64K-H4-gk-h0] PASSED                                                                                       [ 93%]
tests/test_intracard_cp.py::test_intracard_cp_stress_repeat[multi-64K+4K-H4-gk-h0] PASSED                                                                                     [ 96%]
tests/test_intracard_cp.py::test_intracard_cp_h0_none_equiv_h0_zeros PASSED                                                                                                   [100%]

⚡ Performance

python benchmarks/bench_intracard_cp.py
====================================================================================================
 Intracard CP Benchmark: CP-on vs CP-off
====================================================================================================


==============================================================================================================
                       BENCHMARK REPORT: Intracard CP
                       CP-on vs CP-off (same kernel, different code paths)
                       D=128  dtype=bf16  safe_gate=True
                       Warmup=10  Iters=100
==============================================================================================================

  [H=4]
  ───────────────────────────────────────────────────────────────────────────────────────────────
  config                         T  pred  sub  │  CP_off(ms)   CP_on(ms)   Speedup
  ───────────────────────────────────────────────────────────────────────────────────────────────
  T=4K                        4096     N     0  │      0.3301      0.3301     1.00x
  T=8K                        8192     N     0  │      0.3828      0.3803     1.01x
  T=32K                      32768     N     0  │      0.8398      0.8437     1.00x
  T=64K                      65536     Y    19  │      1.5202      0.9777     1.55x
  T=128K                    131072     Y    19  │      2.9095      1.4970     1.94x
  8x4K                       32768     N     0  │      0.3959      0.3947     1.00x
  4x8K                       32768     N     0  │      0.4562      0.4568     1.00x
  2x16K                      32768     N     0  │      0.5850      0.5853     1.00x
  16K+16K                    32768     N     0  │      0.5870      0.5869     1.00x
  24K+8K                     32768     N     0  │      0.7168      0.7119     1.01x
  28K+4K                     32768     N     0  │      0.7783      0.7759     1.00x
  32K+256+256                33280     N     0  │      0.8420      0.8409     1.00x
  40K+1K+8K                  50176     N     0  │      1.0565      1.0541     1.00x
  64K+512+256+128            66432     Y    19  │      1.5299      1.0031     1.53x
  128K+1K                   132096     Y    19  │      2.9176      1.5270     1.91x
  128K+2x1K                 133120     Y    19  │      3.0040      1.5288     1.96x
  128K+5x1K                 136192     Y    19  │      2.9353      1.6150     1.82x
  128K+10x1K                141312     Y    19  │      2.9706      1.6973     1.75x
  ───────────────────────────────────────────────────────────────────────────────────────────────

  [H=8]
  ───────────────────────────────────────────────────────────────────────────────────────────────
  config                         T  pred  sub  │  CP_off(ms)   CP_on(ms)   Speedup
  ───────────────────────────────────────────────────────────────────────────────────────────────
  T=4K                        4096     N     0  │      0.3341      0.3327     1.00x
  T=8K                        8192     N     0  │      0.4019      0.4045     0.99x
  T=32K                      32768     N     0  │      1.0120      1.0104     1.00x
  T=64K                      65536     Y     9  │      1.8864      1.5261     1.24x
  T=128K                    131072     Y     9  │      3.6368      2.5798     1.41x
  8x4K                       32768     N     0  │      0.5672      0.5670     1.00x
  4x8K                       32768     N     0  │      0.6285      0.6281     1.00x
  2x16K                      32768     N     0  │      0.7592      0.7573     1.00x
  16K+16K                    32768     N     0  │      0.7573      0.7586     1.00x
  24K+8K                     32768     N     0  │      0.8854      0.8848     1.00x
  28K+4K                     32768     N     0  │      0.9480      0.9498     1.00x
  32K+256+256                33280     N     0  │      1.0220      1.0224     1.00x
  40K+1K+8K                  50176     N     0  │      1.3359      1.3390     1.00x
  64K+512+256+128            66432     Y     9  │      1.8987      1.6763     1.13x
  128K+1K                   132096     Y     9  │      3.6473      2.6425     1.38x
  128K+2x1K                 133120     Y     9  │      3.6638      2.7529     1.33x
  128K+5x1K                 136192     Y     9  │      3.6969      2.9505     1.25x
  128K+10x1K                141312     N     0  │      3.7646      3.7645     1.00x
  ───────────────────────────────────────────────────────────────────────────────────────────────

  CP triggered (13 configs): geo-mean=1.53x  best=1.96x  worst=1.13x
  CP bypassed  (23 configs): mean overhead=0.999x  max=1.006x  (1.00 = no regression)

==============================================================================================================
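The summary line for the triggered configs can be cross-checked from the rounded Speedup column. A minimal sketch, using only the 13 rows marked pred=Y in the tables above:

```python
import math

# Speedup values of the rows with pred=Y, copied from the tables above.
triggered_speedups = [
    1.55, 1.94, 1.53, 1.91, 1.96, 1.82, 1.75,  # H=4
    1.24, 1.41, 1.13, 1.38, 1.33, 1.25,        # H=8
]

# Geometric mean = exp(mean(log(x))).
geo_mean = math.exp(sum(math.log(s) for s in triggered_speedups) / len(triggered_speedups))
print(f"n={len(triggered_speedups)}  geo-mean={geo_mean:.2f}x  "
      f"best={max(triggered_speedups):.2f}x  worst={min(triggered_speedups):.2f}x")
```

Because the inputs are already rounded to two decimals, the result can differ slightly from a geo-mean computed on the raw timings.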

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operation, adding specialized kernels for pre-scan and merge stages along with an auto-dispatch mechanism for long sequences. Feedback focuses on optimizing the pre-scan kernel by gating the key gate (gk) loading and decay logic when unused, which would prevent potential out-of-memory errors from large zero-tensor allocations and reduce memory traffic.

Comment thread cula/ops/cp/pre_scan.py Outdated
Comment thread cula/ops/cp/pre_scan.py Outdated
Comment thread cula/ops/cp/pre_scan.py Outdated
@cherhh cherhh changed the title from "Dev.kcp.v2" to "[KDA] Add intra-card CP for chunk_delta_h forward in SM100" May 17, 2026
@cherhh cherhh requested review from KevinZeng08 and icavan May 17, 2026 17:16
Collaborator Author

cherhh commented May 18, 2026

@gemini-code-assist review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operator, featuring a three-stage pipeline (pre-scan, merge, and forward) to optimize long sequence processing. It includes specialized CuTeDSL kernels for Blackwell architectures, an auto-dispatch heuristic, and comprehensive benchmarks and tests. Reviewers identified a potential cache key collision in multi-GPU setups and recommended using torch.cuda.current_device() for safer indexing. Additionally, a function-level import of the math module should be moved to the top of the file to adhere to PEP 8 standards.

Comment thread cula/ops/cp/merge.py Outdated
Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operator, introducing a three-stage pipeline (pre-scan, merge, and forward H) to parallelize the processing of long sequences. The implementation includes specialized CuTeDSL kernels for fused pre-scan and prefix-scan merge operations, along with an auto-dispatch heuristic that triggers CP based on sequence length and SM occupancy. A new benchmark script and comprehensive accuracy tests are also provided. The reviewer noted a potential performance bottleneck where the auto-dispatch logic might trigger a synchronous GPU-to-CPU transfer if cu_seqlens_cpu is not provided, suggesting that this behavior should be documented or the heuristic refined to avoid unnecessary synchronization.

Comment thread cula/ops/chunk_delta_h.py Outdated
Collaborator Author

cherhh commented May 18, 2026

@gemini-code-assist review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces an intra-card context-parallel (CP) implementation for the chunk_delta_h operator, designed to accelerate the processing of long sequences on a single GPU. Key additions include a fused pre-scan kernel optimized for Blackwell architectures, a prefix-scan merge kernel, and an auto-dispatch heuristic that dynamically engages the CP path. The PR also includes a new benchmark suite and comprehensive accuracy tests. Reviewer feedback identifies opportunities to simplify the CP predicate API by removing unused parameters, optimize performance by moving local imports to the top level, and enhance robustness by replacing uninitialized dummy tensors with zero-initialized ones.

Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Comment thread cula/ops/cp/merge.py Outdated
Comment thread cula/ops/cp/pre_scan.py Outdated
cherhh and others added 3 commits May 18, 2026 11:40
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Collaborator Author

cherhh commented May 18, 2026

@gemini-code-assist review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operation, optimizing long sequence processing through a three-stage pipeline consisting of pre-scan, prefix-scan merge, and forward H stages. The implementation leverages CuTeDSL for high-performance kernels and includes an auto-dispatch heuristic to balance CP overhead against potential speedups. Additionally, the PR adds a benchmark suite and extensive tests to ensure numerical consistency with the non-CP baseline and PyTorch references. Reviewer feedback highlights opportunities to improve CPU efficiency by optimizing tensor construction in _prepare_chunk_indices and streamlining the metadata caching logic in the merge step.

Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Comment thread cula/ops/cp/merge.py Outdated
cherhh and others added 3 commits May 18, 2026 12:55
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Replace _get_meta_tensors LRU cache with direct torch.tensor calls.
The secondary cache required converting lists to tuples for hashing
on every merge_fwd call. Since the metadata lists are tiny (1-4 int32
values), the hashing overhead outweighs any benefit.

Each torch.tensor call produces an independently aligned allocation
required by from_dlpack(assumed_align=16). A single merged tensor with
slice views would not guarantee per-slice 16-byte alignment.
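The alignment argument in this commit message can be checked with byte-offset arithmetic. A minimal sketch, assuming the merged buffer itself starts on a 16-byte boundary; the list sizes are hypothetical examples within the 1-4 int32 range mentioned above.

```python
INT32_BYTES = 4
ALIGN = 16  # alignment promised via from_dlpack(assumed_align=16)


def packed_slice_offsets(lengths):
    """Byte offset of each slice when the lists are concatenated into one int32 buffer."""
    offsets, cursor = [], 0
    for n in lengths:
        offsets.append(cursor)
        cursor += n * INT32_BYTES
    return offsets


lengths = [1, 3, 2, 4]  # hypothetical metadata list sizes
offsets = packed_slice_offsets(lengths)
misaligned = [off for off in offsets if off % ALIGN != 0]
print("slice byte offsets:", offsets)
print("offsets violating 16-byte alignment:", misaligned)
```

Any slice that does not start at a multiple of four int32 elements breaks the assumed alignment, which is why separate allocations are the simpler choice for such tiny tensors.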
Collaborator Author

cherhh commented May 18, 2026

@gemini-code-assist review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces Intra-Card Context Parallel (CP) for the chunk_delta_h operation, which improves performance for long sequences on a single GPU by parallelizing across the sequence dimension. The implementation includes new CuTeDSL kernels for pre-scan and merge stages, an auto-dispatch heuristic, and comprehensive benchmarking and testing tools. A review comment correctly identified that the global _intracard_cache is not thread-safe and recommended implementing a threading.Lock to prevent race conditions when the function is called concurrently.

Comment thread cula/ops/cp/chunk_delta_h.py
Comment thread benchmarks/bench_intracard_cp.py Outdated
Comment thread cula/ops/cp/chunk_delta_h.py
Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Comment thread cula/ops/cp/pre_scan.py
Collaborator

@icavan icavan left a comment

LGTM

Comment thread cula/ops/cp/pre_scan.py
Comment thread third_party/flash-linear-attention
KevinZeng08 and others added 3 commits May 22, 2026 11:46
* upgrade fla and update b200 bench, update readme and fix lightning test param

* update h200 bench result with fla bug fixed

* update b200 bench

* update b200 bench

* fix readme

* remove useless repeat_interleave for fla

* fix readme

---------

Co-authored-by: boyu.zbw <boyu.zbw@antgroup.com>
1. compute_subseq_len: remove power-of-2 snap, use floor division for
   target_splits. This ensures ceil(seq_chunks / target_splits) <= target_splits
   so Guard 3 in intracard_fwd_h no longer fires spuriously for the common
   single-long-seq case.

2. should_use_intracard_cp: add Guard 3 (expected sub-seq length check).
   CP merge work scales with H; require expected_subseq_c >= 12*H chunks.
   Add MIN_SUBSEQ_CHUNKS_PER_HEAD=12 constant. Restructure Guard 2 from
   an inline return to an explicit if so Guard 3 can follow.

   This fixes worst-case degradation (28K+4K H=8: 0.87x → 1.00x). All
   previously-degraded configs now correctly bypass CP.

3. tests: update ACCURACY_CONFIGS, FINAL_STATE_CONFIGS, STRESS and
   h0_none_equiv to use 65536 (instead of 32768) as the long sequence so
   they still exercise the CP kernel path under the new Guard 3 threshold.

29/29 tests pass. Worst triggered speedup: 1.11x. Max bypass overhead: 1.016x.
Collaborator

@KevinZeng08 KevinZeng08 left a comment

LGTM

@KevinZeng08
Collaborator

Please resolve the file conflicts and refactor the benchmark to follow the main branch's conventions (see bench_kda.py); then we can merge it.

Collaborator Author

cherhh commented May 25, 2026 via email

Collaborator Author

cherhh commented May 27, 2026

@gemini-code-assist review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces an intra-card context parallel (CP) path for chunk_gated_delta_rule_fwd_h to unlock sequence-dimension parallelism on a single GPU. It includes a three-stage pipeline (pre-scan, prefix-scan merge, and forward H) along with auto-dispatch heuristics, benchmarks, tests, and documentation. The review comments provide highly actionable feedback, pointing out performance improvements by avoiding .item() inside loops, identifying an unused parameter in _precompute_intracard_indices, suggesting code simplifications in the benchmark suite, and recommending a NotImplementedError for the unimplemented SM90 pre-scan path to prevent a ModuleNotFoundError.

Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Comment thread cula/ops/cp/chunk_delta_h.py
Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Comment thread benchmarks/bench_intracard_cp.py
Comment thread cula/utils.py Outdated
Development

Successfully merging this pull request may close these issues.

Identify and implement CUDA optimization opportunities for Intracard CP (single-card sequence splitting)

3 participants