perf(diffusion): FLUX.1 transformer performance optimization #187

@m96-chan

Description

Problem

The current FLUX.1 implementation is significantly slower than the diffusers reference (~140x).

Performance Analysis

| Metric | PyGPUkit | Diffusers | Gap |
|---|---|---|---|
| Total time (4 steps) | ~420s | ~3s | ~140x |
| Per-block time | ~517ms | ~5ms | ~100x |

Root Cause

Excessive H2D/D2H transfers due to numpy fallbacks:

| Operation | Time (ms) | Issue |
|---|---|---|
| gpu_batched_matmul | 58 | Loop fallback on SM120 |
| gpu_layer_norm | 15 | Numpy fallback |
| gated_residual | 11 | Numpy fallback (broadcast) |
| gpu_modulate | 9 | Numpy fallback (broadcast) |

In total, ~58 to_numpy() calls per forward pass cause GPU sync overhead.
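
For illustration, this is the kind of round trip each fallback implies (CuPy is used here only as a stand-in for PyGPUkit's device arrays; PyGPUkit's actual transfer API may differ):

```python
import cupy as cp

def add_bias_with_fallback(x_dev, bias_dev):
    # D2H copies force a stream sync before the host-side broadcast math
    x = cp.asnumpy(x_dev)        # (B, T, D) device -> host
    bias = cp.asnumpy(bias_dev)  # (D,) device -> host
    out = x + bias               # numpy broadcast on the CPU
    return cp.asarray(out)       # H2D copy back to the device

def add_bias_on_device(x_dev, bias_dev):
    # Same broadcast, but everything stays on the GPU: no sync, no transfers
    return x_dev + bias_dev
```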

Required Optimizations

Phase 1: GPU-native broadcast operations (High Priority)

  • gpu_modulate(x, scale, shift) - AdaLN modulation
  • gpu_gated_residual(x, gate, attn_out) - Gated addition
  • gpu_add_broadcast - Element-wise add with broadcasting
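
A minimal NumPy sketch of the semantics these ops need to reproduce on-device (the shapes and the (B, T, D) / (B, D) broadcast pattern are assumptions based on the usual FLUX AdaLN layout, not the actual PyGPUkit API):

```python
import numpy as np

def modulate_ref(x, scale, shift):
    # AdaLN modulation: x is (B, T, D); scale/shift are (B, D), broadcast over tokens
    return x * (1.0 + scale[:, None, :]) + shift[:, None, :]

def gated_residual_ref(x, gate, attn_out):
    # Gated residual add: gate is (B, D), broadcast over the token axis
    return x + gate[:, None, :] * attn_out

def add_broadcast_ref(a, b):
    # Element-wise add with numpy-style broadcasting
    return a + b
```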

Phase 2: Batched matmul optimization (Medium Priority)

  • Fix SM120 batched matmul (currently uses loop fallback)
  • Single cuBLASLt call for all batches
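
To illustrate the intended change (CuPy shown as a stand-in; the real fix would go through PyGPUkit's cuBLASLt binding), compare the per-batch loop with a single batched call, which should map to one strided-batched GEMM instead of B separate launches:

```python
import cupy as cp

def batched_matmul_loop(a, b):
    # Current-style fallback: one GEMM launch per batch element
    return cp.stack([a[i] @ b[i] for i in range(a.shape[0])])

def batched_matmul_single(a, b):
    # Single batched call over the 3-D inputs
    return cp.matmul(a, b)

a = cp.random.standard_normal((24, 512, 64), dtype=cp.float32)
b = cp.random.standard_normal((24, 64, 512), dtype=cp.float32)
assert cp.allclose(batched_matmul_loop(a, b), batched_matmul_single(a, b), atol=1e-3)
```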

Phase 3: Fused operations (Low Priority)

  • Fused QKV projection
  • Fused gate+residual
  • Fused AdaLN (norm + modulate)
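
As a reference for what a fused AdaLN op would compute in a single pass (NumPy sketch of the math only; the fusion itself would live in the CUDA ops, and eps/affine handling here is an assumption):

```python
import numpy as np

def fused_adaln_ref(x, scale, shift, eps=1e-6):
    # LayerNorm (no affine) + AdaLN modulation in one conceptual op
    # x: (B, T, D); scale/shift: (B, D)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normed = (x - mean) / np.sqrt(var + eps)
    return normed * (1.0 + scale[:, None, :]) + shift[:, None, :]
```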

Expected Improvement

| Phase | Expected Speedup |
|---|---|
| Phase 1 | 10-20x |
| Phase 2 | 2-3x |
| Phase 3 | 1.5-2x |

Target: < 10s for 4-step generation (comparable to diffusers)

References

  • FLUX forward pass: src/pygpukit/diffusion/models/flux/model.py
  • GPU ops: src/pygpukit/diffusion/models/flux/ops.py
  • Attention: src/pygpukit/diffusion/models/flux/attention.py
