## Overview
The FA3 TMA attention kernel is working correctly with deterministic output. This issue tracks further SM120-specific optimizations.
## Current Performance (RTX 5090, SM120a)
| seq_len | Kernel-only | E2E cached |
|---|---|---|
| 1024 | 51.21 TFLOPS | 46.49 TFLOPS |
| 2048 | 59.86 TFLOPS | 57.39 TFLOPS |
- Correctness: PASS
- Determinism: PASS (fixed in a7c814c)
- TMA cache hit rate: 99.4%
## Optimization Opportunities

### High Priority
- [ ] **WGMMA Migration** - Replace WMMA with Warp Group MMA (SM90+ instruction)
  - Expected: +20-30% throughput
  - Requires PTX inline assembly or CUTLASS 3.x
  - Reference: CUTLASS `wgmma.mma_async` examples (see the PTX sketch after this list)
- [ ] **FP8 Attention** - E4M3/E5M2 format support
  - Expected: +50-100% in memory-bandwidth-bound cases
  - Requires FP8 Q/K/V inputs and accumulator handling (see the conversion sketch after this list)
  - SM120 supports native FP8 Tensor Core ops
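If the WGMMA migration goes the inline-PTX route, here is a minimal sketch of a single warpgroup MMA issue using the smallest f16-input shape (m64n8k16, fp32 accumulators). The helper name is hypothetical, `desc_a`/`desc_b` are assumed to be prebuilt 64-bit shared-memory matrix descriptors (address, leading/stride byte offsets, swizzle mode per the PTX ISA), and the wait is placed right after commit for clarity; a real kernel would overlap it with other work. Not validated on SM120 hardware.

```cuda
#include <cstdint>

// Hypothetical wrapper: one wgmma.mma_async for the m64n8k16 f16 x f16 -> f32
// shape, executed cooperatively by a 128-thread warpgroup. Each thread holds
// 4 of the 64x8 fp32 accumulators.
__device__ inline void wgmma_m64n8k16_f16f16_f32(float d[4],
                                                 uint64_t desc_a,
                                                 uint64_t desc_b,
                                                 bool accumulate) {
    asm volatile(
        "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, %6, 0;\n"            // p: accumulate into d, or overwrite
        "wgmma.fence.sync.aligned;\n"         // order prior accumulator accesses
        "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 "
        "{%0, %1, %2, %3}, %4, %5, p, 1, 1, 0, 0;\n"
        "wgmma.commit_group.sync.aligned;\n"
        "wgmma.wait_group.sync.aligned 0;\n"  // sketch only: no compute overlap
        "}\n"
        : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
        : "l"(desc_a), "l"(desc_b), "r"((int)accumulate));
}
```

For the FP8 item, the input side can go through the toolkit's `cuda_fp8.h` types; below is a sketch of a per-tensor-scale quantization pre-pass (the kernel name and scaling scheme are assumptions, not the kernel's actual design). The attention kernel would then fold `scale_q * scale_k` back into the fp32 accumulator after the QK^T matmul.

```cuda
#include <cuda_fp8.h>

// Hypothetical pre-pass: quantize a tensor to E4M3 with one scale factor.
// The __nv_fp8_e4m3 constructor does a saturating round-to-nearest cast.
__global__ void quantize_to_e4m3(const float* __restrict__ in,
                                 __nv_fp8_e4m3* __restrict__ out,
                                 float inv_scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = __nv_fp8_e4m3(in[i] * inv_scale);
    }
}
```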
### Medium Priority
- [ ] **Tile Size Tuning** - Optimize for SM120's 99KB smem limit
  - Current: TILE_Q=32, TILE_KV=64, ~96KB smem
  - Could increase TILE_Q to 64 with careful layout (see the budget sketch after this list)
- [ ] **Pipeline Depth** - Increase from 2 stages to 3
  - Need to balance smem usage vs. occupancy
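To make the tile-size and pipeline-depth trade-offs concrete, here is a compile-time budget sketch. The head dimension, element sizes, and the set of buffers kept in smem are assumptions; the real kernel's layout may differ.

```cuda
#include <cstddef>

// Hypothetical per-block smem budget: fp16 Q/K/V tiles, an fp32 score tile,
// and NUM_STAGES-deep buffering of the K and V tiles.
constexpr size_t HEAD_DIM   = 128;  // assumption
constexpr size_t NUM_STAGES = 2;    // current pipeline depth
constexpr size_t TILE_Q     = 32;
constexpr size_t TILE_KV    = 64;

constexpr size_t smem_bytes =
      TILE_Q * HEAD_DIM * 2                    // Q tile (fp16)
    + NUM_STAGES * TILE_KV * HEAD_DIM * 2 * 2  // K and V tiles per stage (fp16)
    + TILE_Q * TILE_KV * 4;                    // score tile (fp32)

// SM120 allows up to 99KB of smem per block (opt-in, as on Ada-class parts).
static_assert(smem_bytes <= 99 * 1024, "tile config exceeds SM120 smem limit");
```

Under these assumptions, TILE_Q=64 at 2 stages lands at 96KB, just under the limit, while a third stage adds another 32KB of K/V buffering and would force TILE_KV down; that is the smem-vs-occupancy balance the pipeline item refers to.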
### Low Priority
- **Swizzled Layout** - Bank-conflict reduction (see the XOR sketch below)
- **Epilogue Fusion** - Fuse output scaling/casting
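For the swizzle item, the standard trick is to XOR a column-chunk index with low bits of the row so that a column walk through the tile touches different bank groups; a minimal sketch for a 64-column fp16 tile (tile shape and function name are assumptions):

```cuda
// Hypothetical XOR swizzle for a 64-column fp16 tile (128 bytes per row, so
// each row spans all 32 4-byte banks exactly once). Without swizzling, a
// fixed-column access hits the same bank on every row; XORing the 16-byte
// column chunk with the row's low 3 bits spreads it across chunks.
__device__ inline int swizzle_offset(int row, int col) {  // col in [0, 64)
    int chunk = col >> 3;                  // 8 fp16 = one 16-byte chunk
    int swizzled = chunk ^ (row & 7);      // rotate chunk by row
    return row * 64 + (swizzled << 3) + (col & 7);
}
```

Epilogue fusion is simpler: apply the row rescale and fp16 cast in the same loop that writes O, e.g. `out[i] = __float2half(acc[i] * row_rescale);`, rather than in a separate scaling pass (illustrative, not the kernel's current code).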
## Profiling TODO
- Run NCU to identify current bottlenecks
- Compare with cuDNN/FlashAttention-3 reference
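A reasonable starting invocation (standard Nsight Compute flags; the kernel-name filter and binary name are assumptions): `ncu --set full --kernel-name regex:attention ./bench`, then read the Speed of Light and Memory Workload Analysis sections for the limiter.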
## References
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (Shah et al., 2024)
- CUTLASS 3.x Hopper/Blackwell examples
- NVIDIA Blackwell Architecture Whitepaper
## Related
- Commit a7c814c: Fixed non-determinism bug (union race condition)
- TMA descriptor cache implementation