
[WIP] Persistent + grouped tile MoE batched GEMM #19219

Draft
Gasoonjia wants to merge 1 commit into hoist-activation-quant from persistent-on-hoist

Conversation

@Gasoonjia
Contributor

No description provided.

…-quant

Rewrites the four batched MoE GEMM kernels (BF16/INT8 GEMM1 + GEMM2) from
one-CTA-per-tile to a persistent grid: launch min(NUM_SMS, num_tiles)
programs, each of which loops over its share of (expert_block, n_block)
tiles via `tl.range(start_pid, num_tiles, NUM_SMS)`.
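
A minimal sketch of that launch pattern (the signature, the
`num_m_blocks`/`num_n_blocks` names, and the elided body are illustrative,
not the actual ExecuTorch kernel):

```python
import triton
import triton.language as tl


@triton.jit
def _persistent_gemm_sketch(  # illustrative signature, not the real kernel
    num_m_blocks,
    num_n_blocks,
    NUM_SMS: tl.constexpr,
):
    start_pid = tl.program_id(axis=0)
    num_tiles = num_m_blocks * num_n_blocks
    # Each of the min(NUM_SMS, num_tiles) launched programs strides over the
    # whole tile space instead of owning exactly one tile.
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
        pid_m = tile_id // num_n_blocks
        pid_n = tile_id % num_n_blocks
        # ... load the A/B blocks for (pid_m, pid_n), run the K-loop,
        # store the C tile ...


# Host side: launch at most NUM_SMS programs.
# grid = (min(NUM_SMS, num_tiles),)
```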

Tiles are visited in column-major-within-group order (Triton tutorial style)
so consecutive M-blocks of the same expert reuse B[expert, n_block, *] via L2.
Since moe_align_block_size already sorts blocks by expert, this gives clean
weight reuse across the group.
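
The tile_id -> (pid_m, pid_n) remap follows the grouped ordering from the
Triton matmul tutorial; a sketch (variable names are the tutorial's, not
necessarily the kernel's):

```python
# Inside the persistent loop: remap the linear tile_id so that up to
# GROUP_SIZE_M consecutive M-blocks share the same pid_n (column-major
# within a group), keeping the B[expert, n_block, *] tile hot in L2.
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
```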

Adds NUM_SMS (constexpr) and GROUP_SIZE_M (autotuned over {8, 16, 32}) to
all four batched kernels and their wrappers. NUM_SMS is queried once per
device via _num_sms() with lru_cache.
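
A sketch of the SM query and the autotune axis; the `_num_sms` name and the
{8, 16, 32} values are from this PR, but the body below and the rest of each
config are assumptions:

```python
from functools import lru_cache

import torch
import triton


@lru_cache(maxsize=None)
def _num_sms(device_index: int = 0) -> int:
    # Cached so the device property query happens once per device.
    return torch.cuda.get_device_properties(device_index).multi_processor_count


# GROUP_SIZE_M joins the autotune space; num_warps/num_stages are illustrative.
_group_m_configs = [
    triton.Config({"GROUP_SIZE_M": g}, num_warps=4, num_stages=3)
    for g in (8, 16, 32)
]
```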

Benchmark on A100 80GB, prefill=1642, decode=512, --cuda_graph
(persistent_on_hoist vs hoist baseline):

  hoist baseline:       6171 prefill tok/s (5941-6313),   99.0 decode tok/s
  persistent + grouped: 6127 prefill tok/s (6037-6315),   84.5 decode tok/s

Best-case prefill is essentially tied with the hoist baseline at this
prefill length on A100 (108 SMs). The hoist optimization had already
shortened the K-loop by removing per-tile activation quantization, so the
persistent kernel's L2 reuse gain doesn't show on top of it for M=1642.
The optimization should matter more for short prefills (M <= 512) and on
GPUs with fewer SMs / bigger L2 (e.g. RTX 4090 with 128 SMs and 72MB L2)
where wave quantization is more severe.

Behavior preserved: kernels are mathematically identical to the hoist
baseline; only the tile traversal order and grid launch shape changed.
@pytorch-bot

pytorch-bot Bot commented Apr 30, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19219

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 New Failures, 3 Cancelled Jobs, 3 Pending, 1 Unrelated Failure

As of commit a67881e with merge base cb4e5ae:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 30, 2026
@Gasoonjia Gasoonjia force-pushed the persistent-on-hoist branch from 015476d to a67881e Compare April 30, 2026 08:24
