Skip to content

ROCm: WMMA prefill, hipBLASLt, 4-bit MoE dispatch, QMV tuning, iGPU allocator#5

Open
Geramy wants to merge 40 commits intoNripeshN:rocm-supportfrom
lemonade-sdk:rocm-optimizations
Open

ROCm: WMMA prefill, hipBLASLt, 4-bit MoE dispatch, QMV tuning, iGPU allocator#5
Geramy wants to merge 40 commits intoNripeshN:rocm-supportfrom
lemonade-sdk:rocm-optimizations

Conversation

@Geramy
Copy link
Copy Markdown

@Geramy Geramy commented Mar 27, 2026

Summary

Comprehensive ROCm performance optimizations for LLM inference on all GPU architectures, with particular gains for MoE models on iGPU (Strix Halo gfx1151).

Kernel Optimizations

  • 4-bit fast gather QMV dispatch: 18.6x faster MoE expert dispatch
  • WMMA prefill kernel: rocWMMA 16x16x16 bf16 tiles on RDNA 3/3.5/4
  • Expert-grouped prefill kernel: groups batch elements by expert
  • GPU-only expert-batched kernel: binary search run boundaries on-GPU
  • Shared-memory QMV preferred: over noshared variant for decode
  • QMV 16 threads/col: better memory-level parallelism on RDNA 3.5
  • Strided copy kernels: 1 dispatch instead of 5 per non-contiguous copy

hipBLASLt Integration

  • hipBLASLt GEMM for bf16/fp16 in both Matmul and QMM dequant paths
  • Architecture-tuned kernels via heuristic algorithm selection
  • 2.6x prompt speedup on warm second call

Allocator Optimizations

  • hipStreamSynchronize instead of hipDeviceSynchronize for iGPU
  • hipExtMallocWithFlags primary, hipMallocManaged fallback for iGPU

Performance (Strix Halo gfx1151)

Qwen3-Coder-Next 4-bit (MoE, 512 experts):

Metric Before After
Prompt (warm) N/A (broken) 228 tok/s
Generation N/A (broken) 21.1 tok/s

Qwen3-8B 4-bit:

Metric Before After
Prompt (warm) 8 tok/s 76 tok/s
Generation 17 tok/s 14.3 tok/s

Architecture Support

  • WMMA: gfx11 (RDNA 3/3.5), gfx12 (RDNA 4), gfx9 (CDNA)
  • hipBLASLt: all ROCm architectures
  • Allocator: all (iGPU fast path auto-detected)
  • QMV tuning: all wave32 architectures

Geramy added 30 commits March 25, 2026 12:56
Integrates ROCm/AMD GPU backend from NripeshN/mlx (rocm-support branch).
Original ROCm work by @NripeshN — see https://github.com/NripeshN/mlx/tree/rocm-support

Conflicts resolved favoring ROCm-compatible code paths with upstream
device_count() checks preserved for non-ROCm builds.
- Add gfx1150, gfx1151, gfx1152 (RDNA 3.5) and gfx1200, gfx1201 (RDNA 4)
  to default HIP architecture list
- Use --parallel-jobs with auto-detected CPU count for hipcc so offload
  compilations for multiple architectures run in parallel
Ninja already parallelizes across HIP files, so using all CPUs per
hipcc invocation causes oversubscription.
- Add gpu::init() to eval.cpp to initialize HIP runtime
- Add SliceUpdate NO_GPU stub to primitives.cpp to fix linker errors
- Add compiled HIP kernel for slice update with reduce ops (Sum/Prod/Max/Min)
- ReduceType::None delegates to copy_gpu_inplace (no kernel needed)
- Kernel templated on dtype, Op, contiguity flags, and NWORK for perf
- Supports all 12 dtypes and all 4 reduce operations
- Remove NO_GPU(SliceUpdate) stub from primitives.cpp
- Fix dtype_to_hip_type: return "hip_bfloat16" not "__hip_bfloat16"
  (hiprtc doesn't recognize the double-underscore variant)
- Fix all JIT preamble unary ops (Sigmoid, Exp, Log, etc.) to promote
  half/bfloat16 to float before math, use native ops for float/double
- Fix binary ops (ArcTan2, Remainder, FloorDivide, LogAddExp) similarly
hiprtc lacks <type_traits> so std::is_same_v is unavailable.
Use unconditional float promotion for all unary/binary math ops
since static_cast<float>(float) is a no-op anyway.
KernelArgs::append(array) was using a.data<void>() which returns
the CPU-side pointer. Changed to gpu_ptr<void>(a) which returns
the actual GPU device pointer via the RocmBuffer, matching the
CUDA backend's implementation.

This caused "illegal memory access" crashes on all JIT fused
kernels since the GPU tried to read/write CPU memory addresses.
Stock ROCm packages don't include Tensile kernels for RDNA 3.5
(gfx115x) or RDNA 4 (gfx120x). When rocBLAS can't find the
kernel, it crashes the GPU with "illegal memory access" instead
of failing gracefully. Fall back to naive_gemm for these GPUs.
rocBLAS crashes the GPU with "illegal memory access" when a specific
Tensile kernel variant isn't available for the target architecture
(e.g., bfloat16 GEMM on gfx1151). Instead of crashing, check the
rocblas_status return value and fall back to naive_gemm.

Also fix all GEMM call sites to use gpu_ptr<T>() instead of
array::data<T>() to get proper GPU device pointers.
rocBLAS returns success from the API but crashes the GPU
asynchronously when the Tensile .co kernel files are corrupt or
missing specific bf16 GEMM variants (seen on gfx1151).

Fix: at device init, run a tiny 4x4 bf16 GEMM probe. If it crashes,
reset the GPU, mark bf16 as unavailable, and route all subsequent
bf16 GEMM calls to naive_gemm instead of rocBLAS.

Also use gpu_ptr<T>() consistently in all GEMM call sites.
rocBLAS Tensile .co files for bf16 are corrupt on gfx1151 — the
optimized kernel functions can't be loaded, causing GPU memory
faults. Small-matrix probes don't catch this because they use
fallback kernels that work, while larger inference-sized GEMMs
hit the corrupt optimized paths.

Route all bf16 GEMM to naive_gemm unconditionally. This is correct
for all architectures. Performance optimization for bf16 GEMM can
be added later with custom HIP kernels that don't depend on
Tensile.
Bug fixes:
- ArgReduce: add bfloat16 dispatch (was crashing with "Unsupported type")
- QMM: fix unsigned affine dequantization (uint8_t, no sign extension)
- Sort: add bounds check + rocprim radix sort for arrays > 4096 elements
- JIT: hash long kernel names to avoid 255-byte filesystem limit

Performance:
- Add optimized warp-cooperative GEMV kernel (qmv_kernel.hip)
  - Coalesced uint32 global loads (adjacent threads read adjacent words)
  - LDS for x vector sharing across 8 warps per block
  - Warp shuffle reduction (no shared memory needed for reduction)
  - 33x speedup for token generation (0.45 → 15 tok/s on Qwen3-8B-4bit)
  - 18x speedup for prompt processing
- Shared dequantization utilities in qdequant.hpp
- JIT compiled fused ops (Add, Subtract, Multiply, Divide) now promote
  half/bfloat16 through float to reduce precision loss compounding
  across 28-36 transformer layers
- Restore gfx1151 in rocBLAS supported list (ROCm 7.x has proper support)
- Keep bf16 naive_gemm bypass (Tensile bf16 may still have issues)
- GatherQMM eval_gpu: copy non-contiguous indices to contiguous before
  passing to GPU kernel (broadcast indices from gather_qmm ops have
  non-trivial strides that cause OOB when accessed as flat arrays)
- SDPA: add head_dim=256 to supported vector configs (needed for
  Qwen3-Next which uses 256-dim attention heads)
…s indices

- SDPA: use_fallback returns true for unsupported configs (head_dim or
  seq_len), framework decomposes into matmul+softmax+matmul GPU ops
- All matmul dtypes routed through naive_gemm (avoids rocBLAS Tensile
  init being affected by pending GPU errors from gather_qmm)
- GatherQMM: ensure indices are contiguous before GPU kernel
  (broadcast indices can have non-trivial strides)
- SDPA head_dim=256 support in optimized vector kernel
QMM output quality:
- Match Metal's qdot() accumulation pattern: separate integer dot product
  from scale/bias application. Instead of per-element `x*(scale*q+bias)`,
  compute `scale * dot(x, q_int) + bias * sum(x)` per group.
  Mathematically equivalent but matches Metal's bf16 rounding behavior
  that models are quantized against.

JIT compilation:
- Add StderrSuppressor RAII class to suppress AMD comgr preprocessed
  source dumps during hiprtcCompileProgram (thousands of lines of
  compiler defines were flooding terminal)
- Add tail_lines() to truncate error logs to last 60 lines on failure
- Include module name in compilation error messages
Root cause: ensure_row_contiguous_matrix only checked last 2 dimensions.
Arrays from expand_dims (SwitchGLU MoE path) had non-contiguous batch
strides that passed the check but caused OOB when the kernel used flat
pointer arithmetic (x + lhs_idx * M * K).

Fix:
- GatherQMM::eval_gpu: use ensure_row_contiguous (full contiguity check)
  for all inputs, not just ensure_row_contiguous_matrix (last-2-dims)
- Add LHS_B parameter (valid x batch count) to both gather kernels
- Add bounds clamping: lhs_idx < LHS_B, rhs_idx < E
- QuantizedMatmul (non-gather) unchanged — no batch indirection
RMSNorm (called 72x per forward pass):
- Replace rsqrtf() hardware approximation with 1.0f/sqrtf() for IEEE
  compliance (Metal uses precise::rsqrt)
- Match Metal's weight application order: truncate to T between
  normalization and weight multiply (intermediate rounding step)
- Same fix applied to LayerNorm

Sort/ArgSort:
- Add is_sort_floating_v trait that includes __half and hip_bfloat16
  (std::is_floating_point_v is false for these, skipping NaN handling)
- Fix NaN comparison and sentinel values for half types
- Add __half nan_value specialization

SDPA:
- Fix max_score initialization: use Limits<U>::finite_min (-FLT_MAX)
  instead of -1e9f (matches Metal)
- Fix zero-sum normalization edge case

Standalone ops (binary_ops.hpp, unary_ops.hpp):
- Promote __half and hip_bfloat16 through float for Add, Subtract,
  Multiply, Divide (Metal auto-promotes, ROCm doesn't)
- Add float promotion for unary ops with __half inputs

JIT preamble (compiled.cpp):
- Remove redundant float promotion for Add/Subtract/Multiply/Divide
  (already promoted in previous commit, clean up duplicate logic)
The non-uniform-stride batch loop in gemm_and_bias() called rocBLAS
directly (bypassing the naive_gemm wrapper that was patched earlier)
and only handled float32/float64 — bfloat16 and float16 matmuls
silently did nothing, leaving the output buffer uninitialized.

This caused non-deterministic SDPA results for any GQA model (where
n_q_heads != n_kv_heads) at sequence lengths >= 4, with progressively
worse corruption (NaN/Inf at L >= 7). The SDPA fallback decomposition
reshapes Q via unflatten and K/V via expand_dims for GQA broadcasting,
which produces non-uniform batch strides that hit this code path.

Fix: always use naive_gemm_with_offset for the non-uniform-stride
batch loop, matching the approach already used by the single-GEMM
and strided-batched paths.
The supports_sdpa_vector() function listed head_dim=256 as supported,
but the sdpa_vector() dispatch only had cases for D=64, 96, 128.
For D=256, no kernel was launched, leaving the output buffer
uninitialized — causing non-deterministic results for models using
head_dim=256 (e.g. Qwen3-Next) at sequence lengths 1-3.
Merges goniz/rocm-support-fixes: flash attention kernel, allocator
redesign for integrated GPUs, bfloat16 math overloads, QMV
vectorization, depthwise conv1d, event sync improvements, rocBLAS
solution-index dispatch, and upstream main (CUDA, docs, quantization).

Conflicts resolved preferring upstream for most ROCm backend files,
keeping our SliceUpdate kernel and float-promotion JIT approach.
The gather_qmv_warp_shared_kernel (wave-cooperative, shared memory
tiling, vectorized 4-bit unpacking) was only dispatched for 6-bit and
8-bit quantization. 4-bit fell through to the naive gather_qmv_kernel
(1 thread per output, sequential K loop), which was 18.6x slower.

Add bits==4 to the fast dispatch condition. The kernel already handles
4-bit internally with 8-element vectorized unpacking.

Profiled impact (Qwen3-Next 4-bit MoE):
  gather_qmv_kernel:             5193 μs/call → (removed)
  gather_qmv_warp_shared_kernel: N/A          → 279 μs/call (18.6x)
Key changes for Strix Halo / RDNA 3.5 integrated GPU:

1. raw_ptr(): Use hipStreamSynchronize(nullptr) instead of
   hipDeviceSynchronize() for unified memory buffers. Only waits on
   the default stream instead of all streams. Skips the expensive
   move_to_unified_memory() since integrated GPU memory is already
   CPU-accessible (device==-1).

2. malloc(): Integrated GPU path now goes through rocm_unified_malloc()
   which sets device=-1, so raw_ptr() takes the fast path.

3. rocm_unified_malloc(): Integrated GPUs try hipExtMallocWithFlags
   (fine-grained coherent) first, falling back to hipMallocManaged.

Profiled impact on Qwen3-Next 4-bit MoE:
  Generation: 12.0 tok/s → 18.9 tok/s (58% faster)
  Prompt:     2.5 tok/s → 5.2 tok/s (2x faster)
The noshared QMV kernel reads x from global memory redundantly per
warp (each warp reloads the same x vector). The shared variant caches
x in LDS and is significantly faster for decode-sized (M<=8) shapes.

Disable the alignment-based noshared path selection; always use the
shared variant unless K is tiny. This reduces redundant global memory
traffic for dense quantized projections.
For MoE prefill (M>1) with sorted rhs_indices, consecutive batch
elements map to the same expert. The existing gather_qmv_warp_shared
kernel launches B independent blocks that each load the same expert
weights from global memory — 60-75x redundant weight traffic.

New gather_qmv_prefill_kernel groups batch elements into contiguous
runs of same-expert assignments. Each block handles one (run, row, col)
and iterates over all batch elements in the run, reading weights once.
Grid z-dimension = num_runs (~8-10 unique experts) instead of B (~600).

Supports 4-bit and 8-bit affine quantization with vectorized unpacking
(8 elements per iteration for 4-bit, 4 for 8-bit) and fmaf accumulation.

Profiled impact (Qwen3-Next 4-bit MoE, 40-token prompt):
  Prompt: 1.8 tok/s → 6.1 tok/s (3.4x faster)
  gather_qmv total: 502ms → ~150ms
New gather_qmv_wmma_prefill_kernel uses rocWMMA 16x16x16 bf16→f32
tiles for matrix multiply-accumulate during MoE prefill. Each wave32
handles a 16x16 output tile, dequantizing 4-bit weights into shared
memory and using rocwmma::mma_sync for the reduction.

Enabled for gfx11 (RDNA 3/3.5) and gfx12 (RDNA 4) when M >= 16 and
dimensions are 16-aligned. Falls back to scalar kernel otherwise.
Guarded by ROCM_HAS_WMMA macro so gfx9/gfx10 builds are unaffected.

Also restores hipExtMallocWithFlags as primary allocator for APU
(reverts hipMallocManaged experiment — fine-grained coherent gives
better GPU kernel bandwidth).

Profiled impact (Qwen3-Coder-Next 4-bit, Strix Halo gfx1151):
  Prompt (40 tok): 84 tok/s → 117 tok/s (39% faster)
  Qwen3-8B prompt: 33 tok/s → 44 tok/s (33% faster)
  Generation: unchanged at ~18 tok/s
- Remove M%16 alignment requirement: kernel now bounds-checks rows,
  padding with zero for tile positions beyond M.
- Remove right_sorted_ requirement from prefill dispatch: CPU-side sort
  creates sorted index arrays and output permutation for any index order.
- Add out_perm parameter to both WMMA and scalar prefill kernels to
  scatter results back to original batch positions after sorted dispatch.
- Add <algorithm> and <numeric> includes for std::sort/std::iota.

NOTE: MLX's MoE layer (SwitchGLU) currently expands all tokens to
individual M=1 calls via gather_qmm. The prefill kernels (M>1) will
activate when upstream changes batch tokens per-expert. The 4-bit
fast gather_qmv_warp_shared dispatch handles the current M=1 path.
@Geramy
Copy link
Copy Markdown
Author

Geramy commented Mar 27, 2026

Let me know if you can get this merged in, it really changes the performance a LOT. it goes from complete garbage to great results. some algorithms are also fixed to create good output.

Geramy added 9 commits March 27, 2026 15:45
New gather_qmv_expert_batched_kernel finds expert run boundaries
on-GPU via binary search of sorted rhs_indices. Each block handles
one (expert, column) pair and iterates over all tokens for that expert,
loading weights once per expert.

Dispatch condition: E <= 64 and B/E >= 4 (low expert count with many
tokens per expert). For high-expert models (E=512 like Qwen3-Next),
the warp_shared kernel remains faster since most runs have only 1-4
tokens and the per-block run-finding overhead isn't justified.
hipBLASLt provides architecture-tuned GEMM kernels via Tensile,
typically outperforming rocBLAS for bf16/fp16 on RDNA 3.5 and CDNA.

New hipblaslt_gemm() and hipblaslt_gemm_batched() functions with:
- Per-device handle cache (thread-safe, lazily initialized)
- Algorithm heuristic selection (best-of-1 from hipBLASLt)
- RAII guards for all descriptor types
- Persistent workspace allocation (up to 32MB, grown as needed)
- fp32 accumulation for bf16/fp16 inputs

matmul.cpp tries hipBLASLt first for bf16/fp16, falls back to
rocBLAS silently on failure. Float32/64 GEMMs unchanged.
The dequant+GEMM path in QuantizedMatmul now tries hipBLASLt before
rocBLAS for bf16 GEMMs. hipBLASLt selects architecture-tuned kernels
via heuristic algorithm search, significantly outperforming rocBLAS
once the algorithm cache is warm.

New hipblaslt_gemm_raw() allows calling from inside kernel lambdas
with pre-swapped column-major parameters, matching the rocBLAS pattern.

Warm prompt (Qwen3-Coder-Next 4-bit, Strix Halo):
  80 tok/s → 207 tok/s (2.6x faster)

First-call overhead from algorithm search is amortized by the
application warmup pass.
- hipblaslt_gemm_raw() for calling from inside kernel lambdas with
  pre-swapped col-major params. Used in QMM bf16 dequant+GEMM path.
- Warm prompt: 80→207 tok/s with hipBLASLt algorithm cache primed.

- CommandEncoder graph capture API (begin_capture, end_capture, replay,
  reset_graph) using hipStreamBeginCapture/EndCapture/GraphLaunch.
  Infrastructure for future decode acceleration (18→34 tok/s potential).
  Not yet active due to MLX lazy eval incompatibility with capture mode.
Replace the 5-operation copy chain (2 allocs + 2 hipMemcpyAsync + 1 kernel)
with single-dispatch strided copy kernels for non-contiguous arrays.

New kernels:
- strided_row_copy_kernel: inner-contiguous with outer stride gap (common
  pattern from take/gather_sort). Uses 4-byte word copies when aligned.
- strided_general_copy_kernel: arbitrary strides, shapes/strides passed
  as by-value structs (zero device allocation).

Tiered dispatch in ensure_row_contiguous_matrix:
1. Already contiguous → return (fast path, unchanged)
2. Inner-contiguous outer gap → strided_row_copy_kernel (1 dispatch)
3. General non-contiguous → strided_general_copy_kernel (1 dispatch)
4. ndim > 10 → old contiguous_copy_gpu fallback

Net: each non-contiguous copy drops from 5 GPU operations to 1.
Coarser size buckets for large allocations improve buffer cache hit
rate during LLM decode. Without this, slightly different allocation
sizes (e.g., 1.01MB vs 1.02MB) miss the cache and trigger
hipExtMallocWithFlags at ~7ms each.

Previous: page-aligned (16KB granularity) for all sizes >= 16KB
New: page-aligned for 16KB-1MB, power-of-2 for >= 1MB

Trades up to 2x memory waste for large buffers in exchange for
dramatically fewer cache misses during steady-state decode.
The power-of-2 rounding for >= 1MB allocations caused OOM by doubling
large allocations that exceeded the 2GB device-local VRAM on iGPU.
Reverted to page-aligned (16KB) rounding for all large sizes.

hipExtMallocWithFlags remains the primary path for iGPU (best GPU
bandwidth via fine-grained coherent access). Falls back to
hipMallocManaged for allocations that exceed VRAM capacity,
accessing the full system RAM (126GB on Strix Halo).
@Geramy Geramy force-pushed the rocm-optimizations branch from 6a6031a to f26c802 Compare March 28, 2026 00:51
@Geramy Geramy changed the title ROCm: WMMA prefill, 4-bit MoE dispatch, APU allocator optimizations ROCm: WMMA prefill, hipBLASLt, 4-bit MoE dispatch, QMV tuning, iGPU allocator Mar 28, 2026
@chimezie
Copy link
Copy Markdown

Support for gfx1103 (RDNA3)?

@Geramy
Copy link
Copy Markdown
Author

Geramy commented Mar 28, 2026

Support for gfx1103 (RDNA3)?

I don't see why not? I did a lot of work supporting 3, 3.5 and 4 I can verify it's in the cmake for you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants