
perf(llm): Batch decode missing CUDA Graph and zero-alloc optimizations #181

@m96-chan

Description

The LLM batch decode path (forward_batch_zero_alloc) still carries several unimplemented performance-optimization TODOs.

Locations

1. CUDA Graph capture (model.py:916)

```python
# TODO: CUDA Graph capture can be added once this path is validated.
```

The M=1 decode has CUDA Graph support (DecodeM1Graph), but batch decode (M>1) does not.
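
As a reference point, below is a minimal sketch of what an M>1 analogue of DecodeM1Graph could look like. It is written against PyTorch's torch.cuda.CUDAGraph API purely for concreteness; the BatchDecodeGraph name, the single-argument forward_batch_zero_alloc call, and the buffer shapes are all assumptions, not pygpukit's actual API.

```python
import torch

class BatchDecodeGraph:
    """Hypothetical M>1 analogue of DecodeM1Graph: capture one batch
    decode step into a CUDA Graph and replay it with fixed buffers."""

    def __init__(self, model, batch_size, hidden_dim, device="cuda"):
        self.model = model
        # Graph replay reuses fixed pointers, so inputs must live in a
        # static buffer that is copied into, never reallocated.
        self.static_in = torch.empty(batch_size, hidden_dim, device=device)
        self.graph = torch.cuda.CUDAGraph()
        self.static_out = None

    def capture(self):
        # Warm-up on a side stream is required before graph capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            self.model.forward_batch_zero_alloc(self.static_in)
        torch.cuda.current_stream().wait_stream(s)

        # Record one decode step; subsequent replays rerun these kernels.
        with torch.cuda.graph(self.graph):
            self.static_out = self.model.forward_batch_zero_alloc(self.static_in)

    def replay(self, hidden_states):
        # Refresh the static input, then replay the recorded kernels.
        self.static_in.copy_(hidden_states)
        self.graph.replay()
        return self.static_out
```

Capture is only practical once the decode step is allocation-free, which is why items 2 and 3 below are effectively prerequisites.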

2. Zero-alloc Attention (model.py:965)

```python
# TODO: Add forward_fixed_cache_batch_zero_alloc to Attention class
attn_out = block.attn.forward_fixed_cache_batch(
    norm_out_buf, start_position, context_len
)
```

The batch attention path still allocates intermediate buffers.
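
A hypothetical sketch of the missing zero-alloc variant follows, with PyTorch-style out= matmuls and in-place ops standing in for pygpukit's kernels. It is single-head and treats the M tokens as consecutive positions in one cache to stay short (real batch decode indexes one cache position per sequence); bufs is a pre-allocated container like the one sketched under Required Work, and its field names are invented.

```python
import math

import torch

def forward_fixed_cache_batch_zero_alloc(self, x, start_position, context_len, bufs):
    # Every intermediate is written into a pre-allocated slice of `bufs`
    # instead of a freshly allocated tensor (single-head, simplified).
    M = x.shape[0]
    q = torch.matmul(x, self.wq.T, out=bufs.q[:M])
    k = torch.matmul(x, self.wk.T, out=bufs.k[:M])
    v = torch.matmul(x, self.wv.T, out=bufs.v[:M])

    # New K/V rows land in the fixed cache in place.
    self.k_cache[start_position:start_position + M].copy_(k)
    self.v_cache[start_position:start_position + M].copy_(v)

    # Scores over the valid cache prefix. out= keeps the caller
    # allocation-free; a real kernel would write the strided view directly.
    keys = self.k_cache[:context_len]
    scores = torch.matmul(q, keys.T, out=bufs.scores[:M, :context_len])
    scores.mul_(1.0 / math.sqrt(q.shape[-1]))

    # In-place softmax; amax/sum still allocate tiny [M, 1] tensors, so a
    # truly zero-alloc path would want a fused softmax kernel.
    scores.sub_(scores.amax(dim=-1, keepdim=True))
    scores.exp_()
    scores.div_(scores.sum(dim=-1, keepdim=True))

    torch.matmul(scores, self.v_cache[:context_len], out=bufs.attn_out[:M])
    return torch.matmul(bufs.attn_out[:M], self.wo.T, out=bufs.proj_out[:M])
```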

3. Zero-alloc MLP (model.py:985)

```python
# TODO: Add zero-alloc MLP path
mlp_out = block.mlp(norm_out_buf)
```

The MLP layer allocates fresh buffers for the gate_proj, up_proj, and down_proj outputs on every call.
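
Below is a sketch of one way the zero-alloc variant could look: the gate/up/down output buffers are allocated once at construction and reused via out= matmuls plus an in-place activation. PyTorch ops stand in for pygpukit kernels, and ZeroAllocMLP, forward_zero_alloc, and the SwiGLU-style structure are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

class ZeroAllocMLP:
    """Hypothetical SwiGLU-style MLP whose intermediates are allocated
    once and reused on every forward call."""

    def __init__(self, hidden_dim, inter_dim, max_batch, device="cuda"):
        # Weights would be loaded from the checkpoint in practice.
        self.w_gate = torch.empty(inter_dim, hidden_dim, device=device)
        self.w_up = torch.empty(inter_dim, hidden_dim, device=device)
        self.w_down = torch.empty(hidden_dim, inter_dim, device=device)
        # Pre-allocated once for the largest batch, sliced per call.
        self.gate_buf = torch.empty(max_batch, inter_dim, device=device)
        self.up_buf = torch.empty(max_batch, inter_dim, device=device)
        self.out_buf = torch.empty(max_batch, hidden_dim, device=device)

    def forward_zero_alloc(self, x):
        M = x.shape[0]
        torch.matmul(x, self.w_gate.T, out=self.gate_buf[:M])
        torch.matmul(x, self.w_up.T, out=self.up_buf[:M])
        # SiLU and the gating product both run in place: no new tensors.
        F.silu(self.gate_buf[:M], inplace=True)
        self.gate_buf[:M].mul_(self.up_buf[:M])
        torch.matmul(self.gate_buf[:M], self.w_down.T, out=self.out_buf[:M])
        return self.out_buf[:M]
```

The return value is a view into out_buf, so callers must consume or copy it before the next call overwrites it.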

Impact

  • Batch decode performance: Currently allocates buffers per layer per token
  • Memory fragmentation: Many small allocations stress the memory pool
  • Kernel launch overhead: CUDA Graph capture would cut per-step kernel launch cost in the batch decode loop

Current Performance (v0.2.11)

| Batch Size | Per-Token Latency (µs) | Throughput |
|-----------:|-----------------------:|-----------:|
| 1          | 381,303                | 2.6 tok/s  |
| 8          | 55,845                 | 17.9 tok/s |

With zero-alloc + CUDA Graph, expect ~20-30% improvement.
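
Against the batch-8 numbers above, a 20-30% gain would mean roughly 43,000-46,500 µs per token, i.e. about 21.5-23.3 tok/s.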

Required Work

  1. Add forward_fixed_cache_batch_zero_alloc to Attention class
  2. Add zero-alloc variants to MLP (reuse gate/up/down buffers)
  3. Implement CUDA Graph capture for batch decode loop
  4. Update DecodeBuffers with batch-specific pre-allocated buffers (one possible layout is sketched after this list)
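
For item 4, one option is a container holding every per-step intermediate, sized once for the maximum batch. The field names and shapes below are guesses tied to the sketches above, not the existing DecodeBuffers definition.

```python
from dataclasses import dataclass

import torch

@dataclass
class BatchDecodeBuffers:
    """Hypothetical batch extension of DecodeBuffers: one pre-allocated
    tensor per intermediate used by the batch decode loop."""
    norm_out: torch.Tensor  # [max_batch, hidden_dim]
    q: torch.Tensor         # [max_batch, hidden_dim]
    k: torch.Tensor         # [max_batch, hidden_dim]
    v: torch.Tensor         # [max_batch, hidden_dim]
    scores: torch.Tensor    # [max_batch, max_context]
    attn_out: torch.Tensor  # [max_batch, hidden_dim]
    proj_out: torch.Tensor  # [max_batch, hidden_dim]
    gate: torch.Tensor      # [max_batch, inter_dim]
    up: torch.Tensor        # [max_batch, inter_dim]
    mlp_out: torch.Tensor   # [max_batch, hidden_dim]

    @classmethod
    def allocate(cls, max_batch, hidden_dim, inter_dim, max_context, device="cuda"):
        def e(*shape):
            return torch.empty(*shape, device=device)
        return cls(
            norm_out=e(max_batch, hidden_dim),
            q=e(max_batch, hidden_dim),
            k=e(max_batch, hidden_dim),
            v=e(max_batch, hidden_dim),
            scores=e(max_batch, max_context),
            attn_out=e(max_batch, hidden_dim),
            proj_out=e(max_batch, hidden_dim),
            gate=e(max_batch, inter_dim),
            up=e(max_batch, inter_dim),
            mlp_out=e(max_batch, hidden_dim),
        )
```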

Related

  • M=1 CUDA Graph: src/pygpukit/llm/decode/m1_graph.py
  • Batch decode: src/pygpukit/llm/decode/batch.py
