
FA4: Flash Attention 4 with WGMMA for SM120 #192

Description

@m96-chan

Flash Attention 4 - SM120 (RTX 5090) Implementation

⚠️ Important: SM100 vs SM120

The Modal blog's FA4 write-up targets SM100 (B100/B200 datacenter GPUs), NOT SM120!

| Feature | SM100 (datacenter) | SM120 (RTX 5090) |
| --- | --- | --- |
| MMA instruction | tcgen05.mma | mma.sync.block_scale |
| Tensor Memory | 256 KB TMEM | None |
| Cluster | Up to 16 SMs | 1x1x1 only |

Goal

Create the fastest attention kernel for SM120 (RTX 5090) using:

  • mma.sync.aligned.block_scale (NOT tcgen05)
  • NVFP4 format (4x throughput vs Ada FP8!)
  • TMA async loads (see the pipeline sketch after this list)
  • Smart softmax optimizations from the FA4 paper
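
As a rough illustration of the "TMA async loads" item, below is a minimal sketch of a triple-buffered K-tile prefetch loop. It uses the portable cuda::pipeline / cuda::memcpy_async path (which lowers to cp.async), not raw TMA descriptors; the tile size, stage count, kernel name, and parameters are illustrative assumptions, not the actual kernel interface.

```cpp
// Sketch only: portable async-copy pipeline (cp.async path), not raw TMA.
// TILE_ELEMS, STAGES, and the kernel signature are illustrative assumptions.
#include <cuda/pipeline>
#include <cuda_bf16.h>
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

constexpr int STAGES     = 3;         // triple-buffered K staging
constexpr int TILE_ELEMS = 64 * 128;  // one 64x128 BF16 K tile (hypothetical)

// Launch with dynamic smem = STAGES * TILE_ELEMS * sizeof(__nv_bfloat16).
__global__ void kv_prefetch_sketch(const __nv_bfloat16* __restrict__ gK, int n_tiles)
{
    extern __shared__ __nv_bfloat16 sK[];   // STAGES * TILE_ELEMS elements
    auto block = cg::this_thread_block();

    __shared__ cuda::pipeline_shared_state<cuda::thread_scope_block, STAGES> state;
    auto pipe = cuda::make_pipeline(block, &state);

    // Prime the pipeline: issue the first STAGES tile loads asynchronously.
    for (int s = 0; s < STAGES && s < n_tiles; ++s) {
        pipe.producer_acquire();
        cuda::memcpy_async(block,
                           sK + s * TILE_ELEMS,
                           gK + s * TILE_ELEMS,
                           sizeof(__nv_bfloat16) * TILE_ELEMS,
                           pipe);
        pipe.producer_commit();
    }

    for (int t = 0; t < n_tiles; ++t) {
        pipe.consumer_wait();                       // wait for tile t to land in smem
        const __nv_bfloat16* kTile = sK + (t % STAGES) * TILE_ELEMS;
        // ... compute Q*K^T and the softmax/PV update on kTile here ...
        (void)kTile;
        pipe.consumer_release();                    // free this stage's buffer

        int next = t + STAGES;                      // refill the freed stage
        if (next < n_tiles) {
            pipe.producer_acquire();
            cuda::memcpy_async(block,
                               sK + (next % STAGES) * TILE_ELEMS,
                               gK + next * TILE_ELEMS,
                               sizeof(__nv_bfloat16) * TILE_ELEMS,
                               pipe);
            pipe.producer_commit();
        }
    }
}
```

A true TMA version would drive bulk tensor copies from a CUtensorMap descriptor instead of memcpy_async, but the acquire/commit/wait/release structure over three stages stays the same.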

Target Performance

| | FA3 (current) | FA4 SM120 (target) |
| --- | --- | --- |
| MMA | WMMA 16x16x16 | block_scale 64x64x64 |
| Data type | BF16 | NVFP4 |
| Throughput | ~60 TFLOPS | 200+ TFLOPS (FP4) |

Implementation Phases

  • Phase 1: BF16 baseline with WMMA (sketch after this list)
  • Phase 2: Block-scaled FP8 (mma.sync.block_scale.e4m3)
  • Phase 3: NVFP4 (maximum throughput)
  • Phase 4: Optimization (polynomial exp, smart rescale)
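
For Phase 1, a minimal single-tile WMMA sketch with BF16 inputs and an FP32 accumulator. Matrix layouts, leading dimensions, and function/parameter names are illustrative, not the baseline kernel's actual interface.

```cpp
// Phase 1 sketch: one 16x16x16 BF16 WMMA step, FP32 accumulator.
// Pointer names and leading dimensions are illustrative assumptions.
#include <mma.h>
#include <cuda_bf16.h>
using namespace nvcuda;

__device__ void qk_tile_16x16x16(const __nv_bfloat16* q,  // 16x16 row-major Q tile
                                 const __nv_bfloat16* k,  // 16x16 col-major K tile (K^T)
                                 float* s, int ld)        // 16x16 score tile
{
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc;

    wmma::fill_fragment(acc, 0.0f);          // S = 0
    wmma::load_matrix_sync(a, q, ld);        // one warp cooperatively loads Q
    wmma::load_matrix_sync(b, k, ld);        // ... and K
    wmma::mma_sync(acc, a, b, acc);          // S += Q * K^T for this 16-wide K step
    wmma::store_matrix_sync(s, acc, ld, wmma::mem_row_major);
}
```

In the full Phase 1 kernel this would run once per warp per 16-wide step of the head dimension, with the FP32 accumulator feeding the online-softmax update.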

Key Advantages of SM120 Approach

  1. NVFP4 throughput: 4x vs Ada FP8
  2. Smaller memory footprint: 3-stage pipeline fits in 99KB smem (rough budget after this list)
  3. TMA available: Async bulk loads work on SM120
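
To sanity-check the 99KB claim, a rough budget under assumed tile shapes: 64x128 tiles, Q kept in BF16, K/V at roughly one byte per element in the FP8 phase. All sizes below are illustrative assumptions, not the kernel's actual configuration.

```cpp
// Illustrative smem budget only; tile sizes are assumptions, not the kernel's.
#include <cstddef>

constexpr int    TILE   = 64;        // rows per K/V tile (assumed)
constexpr int    HEAD_D = 128;       // head dimension (assumed)
constexpr int    STAGES = 3;
constexpr size_t Q_BF16 = TILE * HEAD_D * 2;                  // 16 KB, Q stays BF16
constexpr size_t K_FP8  = TILE * HEAD_D * 1;                  //  8 KB per K stage (~1 B/elem incl. scales)
constexpr size_t V_FP8  = TILE * HEAD_D * 1;                  //  8 KB per V stage
constexpr size_t TOTAL  = Q_BF16 + STAGES * (K_FP8 + V_FP8);  // 16 + 3*16 = 64 KB
static_assert(TOTAL <= 99 * 1024, "fits under SM120's 99 KB per-block smem opt-in");
```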

Applicable Ideas from Modal FA4 Blog

  • ✅ Warp specialization pattern
  • ✅ Cubic polynomial exp approximation (sketched together with the rescaling after this list)
  • ✅ Smart rescaling (10x fewer corrections)
  • ✅ Deep K/V buffering with TMA
  • ❌ tcgen05.mma (SM100 only)
  • ❌ Tensor Memory (SM100 only)
  • ❌ Multi-CTA cluster (SM100 only)
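
A sketch of the two softmax ideas checked above: a cubic polynomial standing in for exp, and a rescale that only fires when the running row max grows by more than a threshold. The coefficients, the threshold, and all names are illustrative assumptions; the polynomial is a plain Taylor cut of 2^f, not the tuned minimax fit a real kernel would use.

```cpp
#include <cmath>
#include <cuda_runtime.h>

// Cubic approximation of exp(x) via exp2: split x*log2(e) into integer and
// fractional parts, approximate 2^f on [0,1) with a cubic, rebuild with ldexpf.
// Coefficients are a simple Taylor cut, shown for illustration only.
__device__ __forceinline__ float exp_cubic(float x)
{
    const float t = x * 1.4426950408889634f;      // x * log2(e)
    const float n = floorf(t);
    const float f = t - n;                        // f in [0, 1)
    const float p = 1.0f + f * (0.69314718f + f * (0.24022651f + f * 0.05550411f));
    return ldexpf(p, static_cast<int>(n));        // p * 2^n
}

// Online-softmax row state with "smart" rescaling: the accumulator is only
// corrected when the running max grows by more than RESCALE_THRESH (the value
// here is an illustrative placeholder, not a tuned constant).
struct SoftmaxRow {
    float m   = -INFINITY;   // reference max currently baked into the accumulator
    float l   = 0.0f;        // running sum of exp(s - m)
    float acc = 0.0f;        // stand-in for the PV accumulator row
};

__device__ __forceinline__ void online_update(SoftmaxRow& r, float s)
{
    constexpr float RESCALE_THRESH = 4.0f;        // skip small corrections
    if (s > r.m + RESCALE_THRESH) {
        const float scale = exp_cubic(r.m - s);   // downweight old contributions
        r.l   *= scale;
        r.acc *= scale;
        r.m    = s;                               // new reference max
    }
    const float p = exp_cubic(s - r.m);           // may exceed 1 by up to exp(THRESH);
    r.l   += p;                                   // the FP32 accumulator absorbs it
    r.acc += p /* * v */;                         // V contribution omitted in sketch
}
```

The final output is acc / l, so the lagging reference max cancels out exactly; the threshold only has to keep exp(s - m) well inside FP32 range, which is why skipping most corrections is safe.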

Research Notes

See .serena/memories/fa4_sm120_research.md

