Flash Attention 4 - SM120 (RTX 5090) Implementation
⚠️ Important: SM100 vs SM120
The Modal blog's FA4 kernel targets SM100 (B100/B200 datacenter GPUs), NOT SM120!
| Feature | SM100 (Datacenter) | SM120 (RTX 5090) |
|---|---|---|
| MMA Instruction | tcgen05.mma | mma.sync.block_scale |
| Tensor Memory | 256KB TMEM | None |
| Cluster | Up to 16 SM | 1x1x1 only |
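Because the tcgen05/TMEM path does not exist on SM120, the kernel will need a compile-time split. A minimal sketch of such a guard, assuming plain `__CUDA_ARCH__` checks (compute capability 10.0 → 1000, 12.0 → 1200); the function names are placeholders, not code from this repo:

```cuda
// Compile-time dispatch: SM100 keeps its tcgen05/TMEM path, SM120 falls back
// to the register-accumulator mma.sync block_scale path. The called helpers
// are hypothetical stand-ins for the real tile-MMA implementations.
__device__ __forceinline__ void attention_tile_mma(/* tile arguments */) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && (__CUDA_ARCH__ < 1200)
    // SM100 family: tcgen05.mma with accumulators in Tensor Memory.
    // attention_tile_mma_tcgen05(/* ... */);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200)
    // SM120: block-scaled mma.sync with accumulators in registers.
    // attention_tile_mma_block_scale(/* ... */);
#endif
}
```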
Goal
Create the fastest attention kernel for SM120 (RTX 5090) using:
- mma.sync.aligned.block_scale (NOT tcgen05)
- NVFP4 format (4x throughput vs Ada FP8; see the block-quantization sketch after this list)
- TMA async loads
- Smart softmax optimizations from the FA4 paper
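For context on the NVFP4 item above: NVFP4 stores FP4 E2M1 values in blocks of 16 with one FP8 (E4M3) scale per block. A minimal host-side sketch of that block quantization, assuming a simple amax-to-6.0 scale choice (names are illustrative; a real kernel quantizes on the GPU, packs two FP4 values per byte, and stores the scale itself as E4M3):

```cuda
#include <cmath>

// FP4 E2M1 can represent magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6} plus sign.
// NVFP4 groups 16 values per block and attaches one FP8 (E4M3) scale.
constexpr int   kNvfp4BlockSize = 16;
constexpr float kE2M1Max        = 6.0f;

// Round a scaled value to the nearest representable E2M1 magnitude.
static float quantize_e2m1(float x) {
    static const float levels[] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    float a = std::fabs(x), best = 0.0f, best_err = 1e30f;
    for (float v : levels) {
        float err = std::fabs(a - v);
        if (err < best_err) { best_err = err; best = v; }
    }
    return std::copysign(best, x);
}

// Quantize one 16-element block: pick a scale so the block amax maps to 6.0,
// then store value / scale in E2M1. (The scale would be stored as E4M3.)
void quantize_nvfp4_block(const float* in, float* out_e2m1, float* out_scale) {
    float amax = 0.0f;
    for (int i = 0; i < kNvfp4BlockSize; ++i) amax = std::fmax(amax, std::fabs(in[i]));
    float scale = (amax > 0.0f) ? amax / kE2M1Max : 1.0f;
    *out_scale = scale;
    for (int i = 0; i < kNvfp4BlockSize; ++i)
        out_e2m1[i] = quantize_e2m1(in[i] / scale);
}
```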
Target Performance
| | FA3 (Current) | FA4 SM120 (Target) |
|---|---|---|
| MMA | WMMA 16x16x16 | block_scale 64x64x64 |
| Data Type | BF16 | NVFP4 |
| Throughput | ~60 TFLOPS | 200+ TFLOPS (FP4) |
Implementation Phases
- Phase 1: BF16 baseline with WMMA
- Phase 2: Block-scaled FP8 (mma.sync.block_scale.e4m3)
- Phase 3: NVFP4 (maximum throughput)
- Phase 4: Optimization (polynomial exp, smart rescale); see the cubic exp sketch after this list
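A minimal sketch of the Phase 4 polynomial exp, assuming the standard 2^i * 2^f range reduction; the coefficients below are just the Taylor truncation of 2^f (worst-case error around 0.6%), not the FA4 paper's tuned fit:

```cuda
// exp(x) = 2^(x * log2(e)): split the exponent into integer and fractional
// parts and approximate 2^f on [0,1) with a cubic polynomial, avoiding the
// MUFU/expf path. Coefficients are Taylor terms of 2^f, for illustration only.
__device__ __forceinline__ float exp_cubic(float x) {
    float t = x * 1.4426950f;   // x * log2(e)
    float i = floorf(t);        // integer part
    float f = t - i;            // fractional part in [0, 1)
    float p = 1.0f + f * (0.6931472f + f * (0.2402265f + f * 0.0555041f));
    return ldexpf(p, static_cast<int>(i));  // p * 2^i
}
```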
Key Advantages of SM120 Approach
- NVFP4 throughput: 4x vs Ada FP8
- Smaller memory footprint: 3-stage pipeline fits in 99KB smem
- TMA available: Async bulk loads work on SM120 (a multi-stage buffering sketch follows this list)
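A sketch of the multi-stage K/V buffering, written with the portable `cuda::pipeline` / `cuda::memcpy_async` abstraction so the staging structure is visible; the actual FA4 kernel would drive TMA bulk-tensor copies (e.g. via CUTLASS TMA atoms) into the same 3-stage ring. Tile shape, names, and the bf16 element type are illustrative:

```cuda
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_bf16.h>

namespace cg = cooperative_groups;

constexpr int kStages    = 3;         // 3-stage ring buffer in shared memory
constexpr int kTileElems = 64 * 128;  // illustrative K-tile: 64 rows x 128 head dim

__global__ void kv_pipeline_sketch(const __nv_bfloat16* __restrict__ K, int num_tiles) {
    // Launch with dynamic smem = kStages * kTileElems * sizeof(__nv_bfloat16) (48 KB).
    extern __shared__ __nv_bfloat16 smem_k[];

    auto block = cg::this_thread_block();
    __shared__ cuda::pipeline_shared_state<cuda::thread_scope_block, kStages> state;
    auto pipe = cuda::make_pipeline(block, &state);

    // Prologue: prefetch the first kStages tiles so compute never starts cold.
    for (int s = 0; s < kStages && s < num_tiles; ++s) {
        pipe.producer_acquire();
        cuda::memcpy_async(block, smem_k + s * kTileElems, K + (size_t)s * kTileElems,
                           sizeof(__nv_bfloat16) * kTileElems, pipe);
        pipe.producer_commit();
    }

    for (int t = 0; t < num_tiles; ++t) {
        pipe.consumer_wait();   // tile t is now resident in stage t % kStages
        // ... QK^T MMA + online softmax on smem_k + (t % kStages) * kTileElems ...
        pipe.consumer_release();

        int next = t + kStages; // refill the stage that was just freed
        if (next < num_tiles) {
            pipe.producer_acquire();
            cuda::memcpy_async(block, smem_k + (next % kStages) * kTileElems,
                               K + (size_t)next * kTileElems,
                               sizeof(__nv_bfloat16) * kTileElems, pipe);
            pipe.producer_commit();
        }
    }
}
```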
Applicable Ideas from Modal FA4 Blog
- ✅ Warp specialization pattern
- ✅ Cubic polynomial exp approximation
- ✅ Smart rescaling (10x fewer corrections; see the lazy-rescale sketch after this list)
- ✅ Deep K/V buffering with TMA
- ❌ tcgen05.mma (SM100 only)
- ❌ Tensor Memory (SM100 only)
- ❌ Multi-CTA cluster (SM100 only)
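For the smart-rescaling item, a minimal sketch of one lazy-rescale scheme: the accumulator's reference max is allowed to lag the true row max by a bounded amount, and the O(d) correction is only paid when the max grows past that bound. The threshold and names are illustrative and this is not claimed to be the FA4 paper's exact criterion; correctness only needs the accumulator and denominator to stay on the same scale:

```cuda
// Per-row online-softmax state. acc[] (the O accumulator row) and l are both
// kept on the exp(-ref_max) scale; ref_max may lag the true row max by at most
// rescale_log_thresh. Initialize ref_max = -INFINITY, l = 0, acc[] = 0.
struct SoftmaxRowState {
    float ref_max;  // reference max used for scaling (possibly stale)
    float l;        // running sum of exp(score - ref_max)
};

// Called once per K/V tile, before exponentiating that tile's scores.
// Most tiles take the cheap path and skip the accumulator correction.
__device__ void maybe_rescale_row(SoftmaxRowState& st, float tile_max,
                                  float* acc, int d,
                                  float rescale_log_thresh /* e.g. 2.0f, illustrative */) {
    if (tile_max > st.ref_max + rescale_log_thresh) {
        float scale = __expf(st.ref_max - tile_max);   // < 1, shrinks old terms
        for (int j = 0; j < d; ++j) acc[j] *= scale;   // correct the accumulator
        st.l *= scale;                                  // and the denominator
        st.ref_max = tile_max;                          // adopt the new reference
    }
    // This tile's P = exp(score - st.ref_max) may overshoot 1 by at most
    // exp(rescale_log_thresh), so the staleness bound also bounds the range.
}
```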
Research Notes
See .serena/memories/fa4_sm120_research.md
References
- CUTLASS example 79 - SM120 GEMM
- Modal FA4 Blog - SM100 reference
- CUTLASS Issue #2186 - SM120 support
Related
- SM120 (Blackwell) FA3 Attention Optimization (#191)
- Commit a7c814c - FA3 TMA determinism fix