
SM120 (Blackwell) FA3 Attention Optimization #191

@m96-chan

Description

Overview

The FA3 TMA attention kernel is working correctly with deterministic output. This issue tracks further SM120-specific optimizations.

Current Performance (RTX 5090, SM120a)

| seq_len | Kernel-only  | E2E (cached) |
| ------- | ------------ | ------------ |
| 1024    | 51.21 TFLOPS | 46.49 TFLOPS |
| 2048    | 59.86 TFLOPS | 57.39 TFLOPS |

  • Correctness: PASS
  • Determinism: PASS (fixed in a7c814c)
  • TMA cache hit rate: 99.4%

Optimization Opportunities

High Priority

  • WGMMA Migration - Replace WMMA with Warp Group MMA (SM90+ instruction); see the PTX sketch after this list

    • Expected: +20-30% throughput
    • Requires PTX inline assembly or CUTLASS 3.x
    • Reference: CUTLASS wgmma.mma_async examples
  • FP8 Attention - E4M3/E5M2 format support; see the quantization sketch after this list

    • Expected: +50-100% in memory-bandwidth-bound cases
    • Requires FP8 Q/K/V inputs and accumulator handling
    • SM120 supports native FP8 Tensor Core ops
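
Not a drop-in patch, just a minimal sketch of the inline-PTX pattern CUTLASS's SM90 GMMA atoms use, specialized to the smallest shape (m64n8k16, f16 inputs, f32 accumulator). The helper name, the descriptor arguments, and the immediate wait are ours, not the kernel's; `wgmma.mma_async` is gated to architecture-specific targets like `sm_90a`, so confirming that `sm_120a` accepts it is part of this work item.

```cuda
#include <cstdint>

// Hypothetical helper (name/signature ours): one m64n8k16 warp-group MMA,
// D(f32) = A(f16) * B(f16) [+ D], following CUTLASS's SM90 GMMA atoms.
// All 128 threads of the warp group must execute this together.
// desc_a/desc_b are 64-bit shared-memory matrix descriptors built elsewhere.
__device__ void wgmma_f16f16_f32_m64n8k16(float &d0, float &d1,
                                          float &d2, float &d3,
                                          uint64_t desc_a, uint64_t desc_b,
                                          int scale_d) { // 0: D=AB, 1: D+=AB
  asm volatile(
      "{\n"
      "  .reg .pred p;\n"
      "  setp.ne.b32 p, %6, 0;\n"
      "  wgmma.fence.sync.aligned;\n"        // order prior register/smem writes
      "  wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n"
      "    {%0, %1, %2, %3}, %4, %5, p, 1, 1, 0, 0;\n"
      "  wgmma.commit_group.sync.aligned;\n"
      "  wgmma.wait_group.sync.aligned 0;\n" // a real kernel defers this wait
      "}\n"
      : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
      : "l"(desc_a), "l"(desc_b), "r"(scale_d));
}
```

An FA3-style schedule would issue several of these back to back and only `wait_group` right before softmax needs the scores; that deferred wait is where the projected overlap gain would come from.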
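For the FP8 path, the cast itself is straightforward with the `cuda_fp8.h` types (CUDA 11.8+); the real work is the scale bookkeeping. A sketch, with our own kernel and parameter names, of a per-tensor E4M3 quantization pre-pass for Q/K/V:

```cuda
#include <cuda_fp8.h>   // __nv_fp8_e4m3 / __nv_fp8_e5m2, CUDA 11.8+

// Hypothetical pre-pass (names ours): quantize f32 activations to E4M3 with
// a per-tensor scale chosen so values land in E4M3's ~±448 finite range.
// The attention kernel then folds scale_q * scale_k into the f32 accumulator
// after QK^T, which is the "accumulator handling" noted above.
__global__ void quantize_to_e4m3(const float *__restrict__ in,
                                 __nv_fp8_e4m3 *__restrict__ out,
                                 float inv_scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // The __nv_fp8_e4m3(float) conversion rounds to nearest and saturates
    // to the maximum finite E4M3 value rather than producing inf.
    out[i] = __nv_fp8_e4m3(in[i] * inv_scale);
  }
}
```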

Medium Priority

  • Tile Size Tuning - Optimize for SM120's 99KB smem limit; see the budget check after this list

    • Current: TILE_Q=32, TILE_KV=64, ~96KB smem
    • Could increase TILE_Q to 64 with careful layout
  • Pipeline Depth - Increase from 2-stage to 3-stage; see the pipeline skeleton after this list

    • Need to balance smem usage vs. occupancy
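
For the tile-size question, a compile-time budget check is the useful guard. The constants below (head_dim, stage count, f16 tiles) are assumptions for illustration, not the kernel's actual layout; the ~96KB figure reported above for TILE_Q=32 shows the real kernel holds more buffers than this toy accounting, but the shape of the arithmetic is the same:

```cuda
// Back-of-envelope smem budget for a TILE_Q=64 variant. HEAD_DIM, STAGES,
// and the f16 element type are assumptions, not values from the kernel.
constexpr size_t HEAD_DIM = 128;
constexpr size_t STAGES   = 2;
constexpr size_t TILE_Q   = 64;   // proposed (kernel currently uses 32)
constexpr size_t TILE_KV  = 64;
constexpr size_t B_F16    = 2;

constexpr size_t smem_q   = TILE_Q * HEAD_DIM * B_F16;               // 16 KB
constexpr size_t smem_kv  = STAGES * 2 * TILE_KV * HEAD_DIM * B_F16; // 64 KB
constexpr size_t smem_all = smem_q + smem_kv;                        // 80 KB

// The 99KB limit is opt-in: the launch must raise
// cudaFuncAttributeMaxDynamicSharedMemorySize above the 48 KB default.
static_assert(smem_all <= 99 * 1024, "exceeds SM120's 99KB smem limit");
```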
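For the 2-stage to 3-stage change, the staging logic is the part worth sketching. This skeleton uses the libcudacxx `cuda::pipeline` API with per-thread `memcpy_async` as a stand-in; the actual kernel drives TMA instead, but the acquire/commit/wait/release structure and the modular stage indexing carry over. All names and tile parameters are placeholders:

```cuda
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_fp16.h>

namespace cg = cooperative_groups;

constexpr int STAGES = 3;  // was 2; each stage costs one more KV tile of smem

// Skeleton of the 3-stage producer/consumer loop (placeholder names).
__device__ void kv_loop(const __half *gmem_kv, __half *smem_kv,
                        size_t tile_elems, int num_tiles) {
  auto block = cg::this_thread_block();
  __shared__ cuda::pipeline_shared_state<cuda::thread_scope_block, STAGES> st;
  auto pipe = cuda::make_pipeline(block, &st);

  auto fetch_tile = [&](int t) {
    pipe.producer_acquire();
    cuda::memcpy_async(block, smem_kv + (t % STAGES) * tile_elems,
                       gmem_kv + t * tile_elems,
                       tile_elems * sizeof(__half), pipe);
    pipe.producer_commit();
  };

  for (int t = 0; t < STAGES - 1 && t < num_tiles; ++t)
    fetch_tile(t);                         // prologue: prefill STAGES-1 tiles

  for (int t = 0; t < num_tiles; ++t) {
    if (t + STAGES - 1 < num_tiles)
      fetch_tile(t + STAGES - 1);          // keep the pipe full
    pipe.consumer_wait();                  // stage (t % STAGES) is resident
    // ... QK^T, online softmax, PV on smem stage (t % STAGES) ...
    pipe.consumer_release();
  }
}
```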

Low Priority

  • Swizzled Layout - Bank conflict reduction; see the XOR-swizzle sketch after this list
  • Epilogue Fusion - Fuse output scaling/casting into the store; sketch after this list
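
The standard fix for the bank conflicts is an XOR swizzle on the shared-memory chunk index. A sketch with illustrative parameters (64 f16 columns per row, 16-byte chunks), not the kernel's actual layout:

```cuda
// XOR-swizzled offset into a [rows][64] f16 tile in shared memory.
// XOR-ing low row bits into the 16-byte chunk index permutes which banks a
// column-wise access pattern touches, removing the conflicts of a naive
// row-major layout. Parameters are illustrative.
__device__ __forceinline__ int swizzle_offset(int row, int col) {
  constexpr int COLS  = 64;  // f16 elements per row (128 B = eight 16 B chunks)
  constexpr int CHUNK = 8;   // 8 x f16 = one 16 B vector chunk
  int chunk = (col / CHUNK) ^ (row % 8);   // permute chunk index by row
  return row * COLS + chunk * CHUNK + (col % CHUNK);
}
```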
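Epilogue fusion is small but cheap to add: fold the softmax normalizer and the f32-to-f16 cast into the store that writes O, rather than running a separate scaling pass. Names are illustrative:

```cuda
#include <cuda_fp16.h>

// Fused epilogue (illustrative names): one multiply plus one narrowing cast
// at store time, instead of writing f32 out and rescaling in a second kernel.
__device__ __forceinline__ void store_o(float acc, float row_inv_sum,
                                        __half *o, int idx) {
  o[idx] = __float2half_rn(acc * row_inv_sum);
}
```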

Profiling TODO

  • Run Nsight Compute (ncu) to identify current bottlenecks
  • Compare against the cuDNN and FlashAttention-3 reference implementations

References

  • FlashAttention-3 paper (Dao et al., 2024)
  • CUTLASS 3.x Hopper/Blackwell examples
  • NVIDIA Blackwell Architecture Whitepaper

Related

  • Commit a7c814c: Fixed non-determinism bug (union race condition)
  • TMA descriptor cache implementation
