ROCm: WMMA prefill, hipBLASLt, 4-bit MoE dispatch, QMV tuning, iGPU allocator #5
Open
Geramy wants to merge 40 commits into NripeshN:rocm-support from
Conversation
Integrates the ROCm/AMD GPU backend from NripeshN/mlx (rocm-support branch). Original ROCm work by @NripeshN — see https://github.com/NripeshN/mlx/tree/rocm-support. Conflicts were resolved favoring ROCm-compatible code paths, with upstream device_count() checks preserved for non-ROCm builds.
- Add gfx1150, gfx1151, gfx1152 (RDNA 3.5) and gfx1200, gfx1201 (RDNA 4) to the default HIP architecture list
- Use --parallel-jobs with an auto-detected CPU count for hipcc so offload compilations for multiple architectures run in parallel
Ninja already parallelizes across HIP files, so using all CPUs per hipcc invocation causes oversubscription.
- Add gpu::init() to eval.cpp to initialize the HIP runtime
- Add a SliceUpdate NO_GPU stub to primitives.cpp to fix linker errors
- Add a compiled HIP kernel for slice update with reduce ops (Sum/Prod/Max/Min)
- ReduceType::None delegates to copy_gpu_inplace (no kernel needed)
- Kernel templated on dtype, Op, contiguity flags, and NWORK for performance
- Supports all 12 dtypes and all 4 reduce operations
- Remove the NO_GPU(SliceUpdate) stub from primitives.cpp
- Fix dtype_to_hip_type: return "hip_bfloat16", not "__hip_bfloat16" (hiprtc doesn't recognize the double-underscore variant)
- Fix all JIT preamble unary ops (Sigmoid, Exp, Log, etc.) to promote half/bfloat16 to float before the math; use native ops for float/double
- Fix binary ops (ArcTan2, Remainder, FloorDivide, LogAddExp) similarly
hiprtc lacks <type_traits> so std::is_same_v is unavailable. Use unconditional float promotion for all unary/binary math ops since static_cast<float>(float) is a no-op anyway.
KernelArgs::append(array) was using a.data<void>(), which returns the CPU-side pointer; the GPU then tried to read/write CPU memory addresses, causing "illegal memory access" crashes in all JIT fused kernels. Changed to gpu_ptr<void>(a), which returns the actual GPU device pointer via the RocmBuffer, matching the CUDA backend's implementation.
Stock ROCm packages don't include Tensile kernels for RDNA 3.5 (gfx115x) or RDNA 4 (gfx120x). When rocBLAS can't find the kernel, it crashes the GPU with "illegal memory access" instead of failing gracefully. Fall back to naive_gemm for these GPUs.
rocBLAS crashes the GPU with "illegal memory access" when a specific Tensile kernel variant isn't available for the target architecture (e.g., bfloat16 GEMM on gfx1151). Instead of crashing, check the rocblas_status return value and fall back to naive_gemm. Also fix all GEMM call sites to use gpu_ptr<T>() instead of array::data<T>() to get proper GPU device pointers.
rocBLAS returns success from the API but crashes the GPU asynchronously when the Tensile .co kernel files are corrupt or missing specific bf16 GEMM variants (seen on gfx1151). Fix: at device init, run a tiny 4x4 bf16 GEMM probe. If it crashes, reset the GPU, mark bf16 as unavailable, and route all subsequent bf16 GEMM calls to naive_gemm instead of rocBLAS. Also use gpu_ptr<T>() consistently in all GEMM call sites.
rocBLAS Tensile .co files for bf16 are corrupt on gfx1151 — the optimized kernel functions can't be loaded, causing GPU memory faults. Small-matrix probes don't catch this because they use fallback kernels that work, while larger inference-sized GEMMs hit the corrupt optimized paths. Route all bf16 GEMM to naive_gemm unconditionally. This is correct for all architectures. Performance optimization for bf16 GEMM can be added later with custom HIP kernels that don't depend on Tensile.
Bug fixes:
- ArgReduce: add bfloat16 dispatch (was crashing with "Unsupported type")
- QMM: fix unsigned affine dequantization (uint8_t, no sign extension)
- Sort: add bounds check + rocprim radix sort for arrays > 4096 elements
- JIT: hash long kernel names to avoid the 255-byte filesystem limit

Performance:
- Add optimized warp-cooperative GEMV kernel (qmv_kernel.hip)
- Coalesced uint32 global loads (adjacent threads read adjacent words)
- LDS for x-vector sharing across 8 warps per block
- Warp shuffle reduction (no shared memory needed for the reduction)
- 33x speedup for token generation (0.45 → 15 tok/s on Qwen3-8B-4bit)
- 18x speedup for prompt processing
- Shared dequantization utilities in qdequant.hpp
- JIT-compiled fused ops (Add, Subtract, Multiply, Divide) now promote half/bfloat16 through float to reduce precision-loss compounding across 28-36 transformer layers
- Restore gfx1151 in the rocBLAS supported list (ROCm 7.x has proper support)
- Keep the bf16 naive_gemm bypass (Tensile bf16 may still have issues)
- GatherQMM eval_gpu: copy non-contiguous indices to a contiguous buffer before passing them to the GPU kernel (broadcast indices from gather_qmm ops have non-trivial strides that cause OOB accesses when read as flat arrays)
- SDPA: add head_dim=256 to the supported vector configs (needed for Qwen3-Next, which uses 256-dim attention heads)
…s indices
- SDPA: use_fallback returns true for unsupported configs (head_dim or seq_len); the framework decomposes into matmul+softmax+matmul GPU ops
- All matmul dtypes routed through naive_gemm (avoids rocBLAS Tensile init being affected by pending GPU errors from gather_qmm)
- GatherQMM: ensure indices are contiguous before the GPU kernel (broadcast indices can have non-trivial strides)
- SDPA head_dim=256 support in the optimized vector kernel
QMM output quality:
- Match Metal's qdot() accumulation pattern: separate the integer dot product from the scale/bias application. Instead of per-element `x*(scale*q+bias)`, compute `scale * dot(x, q_int) + bias * sum(x)` per group. Mathematically equivalent, but matches Metal's bf16 rounding behavior that models are quantized against.

JIT compilation:
- Add a StderrSuppressor RAII class to suppress AMD comgr preprocessed-source dumps during hiprtcCompileProgram (thousands of lines of compiler defines were flooding the terminal)
- Add tail_lines() to truncate error logs to the last 60 lines on failure
- Include the module name in compilation error messages
Root cause: ensure_row_contiguous_matrix only checked the last two dimensions. Arrays from expand_dims (SwitchGLU MoE path) had non-contiguous batch strides that passed the check but caused OOB accesses when the kernel used flat pointer arithmetic (x + lhs_idx * M * K).

Fix:
- GatherQMM::eval_gpu: use ensure_row_contiguous (full contiguity check) for all inputs, not just ensure_row_contiguous_matrix (last two dims)
- Add an LHS_B parameter (valid x batch count) to both gather kernels
- Add bounds clamping: lhs_idx < LHS_B, rhs_idx < E
- QuantizedMatmul (non-gather) unchanged — no batch indirection
RMSNorm (called 72x per forward pass):
- Replace the rsqrtf() hardware approximation with 1.0f/sqrtf() for IEEE compliance (Metal uses precise::rsqrt)
- Match Metal's weight-application order: truncate to T between normalization and the weight multiply (an intermediate rounding step)
- Same fix applied to LayerNorm

Sort/ArgSort:
- Add an is_sort_floating_v trait that includes __half and hip_bfloat16 (std::is_floating_point_v is false for these, skipping NaN handling)
- Fix NaN comparison and sentinel values for half types
- Add a __half nan_value specialization

SDPA:
- Fix max_score initialization: use Limits<U>::finite_min (-FLT_MAX) instead of -1e9f (matches Metal)
- Fix the zero-sum normalization edge case

Standalone ops (binary_ops.hpp, unary_ops.hpp):
- Promote __half and hip_bfloat16 through float for Add, Subtract, Multiply, Divide (Metal auto-promotes; ROCm doesn't)
- Add float promotion for unary ops with __half inputs

JIT preamble (compiled.cpp):
- Remove the redundant float promotion for Add/Subtract/Multiply/Divide (already promoted in a previous commit; clean up duplicate logic)
The non-uniform-stride batch loop in gemm_and_bias() called rocBLAS directly (bypassing the naive_gemm wrapper that was patched earlier) and only handled float32/float64 — bfloat16 and float16 matmuls silently did nothing, leaving the output buffer uninitialized. This caused non-deterministic SDPA results for any GQA model (where n_q_heads != n_kv_heads) at sequence lengths >= 4, with progressively worse corruption (NaN/Inf at L >= 7). The SDPA fallback decomposition reshapes Q via unflatten and K/V via expand_dims for GQA broadcasting, which produces non-uniform batch strides that hit this code path. Fix: always use naive_gemm_with_offset for the non-uniform-stride batch loop, matching the approach already used by the single-GEMM and strided-batched paths.
The supports_sdpa_vector() function listed head_dim=256 as supported, but the sdpa_vector() dispatch only had cases for D=64, 96, 128. For D=256, no kernel was launched, leaving the output buffer uninitialized — causing non-deterministic results for models using head_dim=256 (e.g. Qwen3-Next) at sequence lengths 1-3.
Merges goniz/rocm-support-fixes: flash attention kernel, allocator redesign for integrated GPUs, bfloat16 math overloads, QMV vectorization, depthwise conv1d, event sync improvements, rocBLAS solution-index dispatch, and upstream main (CUDA, docs, quantization). Conflicts resolved preferring upstream for most ROCm backend files, keeping our SliceUpdate kernel and float-promotion JIT approach.
The gather_qmv_warp_shared_kernel (wave-cooperative, shared-memory tiling, vectorized 4-bit unpacking) was only dispatched for 6-bit and 8-bit quantization. 4-bit fell through to the naive gather_qmv_kernel (one thread per output, sequential K loop), which was 18.6x slower. Add bits==4 to the fast dispatch condition; the kernel already handles 4-bit internally with 8-element vectorized unpacking.

Profiled impact (Qwen3-Next 4-bit MoE):
- gather_qmv_kernel: 5193 μs/call → (removed)
- gather_qmv_warp_shared_kernel: N/A → 279 μs/call (18.6x)
Key changes for Strix Halo / RDNA 3.5 integrated GPUs:
1. raw_ptr(): use hipStreamSynchronize(nullptr) instead of hipDeviceSynchronize() for unified-memory buffers, waiting only on the default stream instead of all streams. Skips the expensive move_to_unified_memory(), since integrated-GPU memory is already CPU-accessible (device == -1).
2. malloc(): the integrated-GPU path now goes through rocm_unified_malloc(), which sets device = -1, so raw_ptr() takes the fast path.
3. rocm_unified_malloc(): integrated GPUs try hipExtMallocWithFlags (fine-grained coherent) first, falling back to hipMallocManaged.

Profiled impact on Qwen3-Next 4-bit MoE:
- Generation: 12.0 tok/s → 18.9 tok/s (58% faster)
- Prompt: 2.5 tok/s → 5.2 tok/s (2x faster)
The noshared QMV kernel reads x from global memory redundantly per warp (each warp reloads the same x vector). The shared variant caches x in LDS and is significantly faster for decode-sized (M<=8) shapes. Disable the alignment-based noshared path selection; always use the shared variant unless K is tiny. This reduces redundant global memory traffic for dense quantized projections.
For MoE prefill (M > 1) with sorted rhs_indices, consecutive batch elements map to the same expert. The existing gather_qmv_warp_shared kernel launches B independent blocks that each load the same expert weights from global memory — 60-75x redundant weight traffic. The new gather_qmv_prefill_kernel groups batch elements into contiguous runs of same-expert assignments. Each block handles one (run, row, col) tuple and iterates over all batch elements in the run, reading the weights once. The grid z-dimension is num_runs (~8-10 unique experts) instead of B (~600). Supports 4-bit and 8-bit affine quantization with vectorized unpacking (8 elements per iteration for 4-bit, 4 for 8-bit) and fmaf accumulation.

Profiled impact (Qwen3-Next 4-bit MoE, 40-token prompt):
- Prompt: 1.8 tok/s → 6.1 tok/s (3.4x faster)
- gather_qmv total: 502ms → ~150ms
The new gather_qmv_wmma_prefill_kernel uses rocWMMA 16x16x16 bf16→f32 tiles for the matrix multiply-accumulate during MoE prefill. Each wave32 handles a 16x16 output tile, dequantizing 4-bit weights into shared memory and using rocwmma::mma_sync for the reduction. Enabled for gfx11 (RDNA 3/3.5) and gfx12 (RDNA 4) when M >= 16 and dimensions are 16-aligned; falls back to the scalar kernel otherwise. Guarded by the ROCM_HAS_WMMA macro so gfx9/gfx10 builds are unaffected.

Also restores hipExtMallocWithFlags as the primary allocator for APUs (reverts the hipMallocManaged experiment — fine-grained coherent access gives better GPU kernel bandwidth).

Profiled impact (Qwen3-Coder-Next 4-bit, Strix Halo gfx1151):
- Prompt (40 tok): 84 tok/s → 117 tok/s (39% faster)
- Qwen3-8B prompt: 33 tok/s → 44 tok/s (33% faster)
- Generation: unchanged at ~18 tok/s
- Remove the M%16 alignment requirement: the kernel now bounds-checks rows, padding with zeros for tile positions beyond M.
- Remove the right_sorted_ requirement from prefill dispatch: a CPU-side sort creates sorted index arrays and an output permutation for any index order.
- Add an out_perm parameter to both the WMMA and scalar prefill kernels to scatter results back to original batch positions after sorted dispatch.
- Add <algorithm> and <numeric> includes for std::sort/std::iota.

NOTE: MLX's MoE layer (SwitchGLU) currently expands all tokens to individual M=1 calls via gather_qmm. The prefill kernels (M > 1) will activate when upstream batches tokens per expert. The 4-bit fast gather_qmv_warp_shared dispatch handles the current M=1 path.
Author
Let me know if you can get this merged in; it really changes the performance a LOT, going from complete garbage to great results. Some algorithms are also fixed to produce good output.
New gather_qmv_expert_batched_kernel finds expert run boundaries on-GPU via binary search of sorted rhs_indices. Each block handles one (expert, column) pair and iterates over all tokens for that expert, loading weights once per expert. Dispatch condition: E <= 64 and B/E >= 4 (low expert count with many tokens per expert). For high-expert models (E=512 like Qwen3-Next), the warp_shared kernel remains faster since most runs have only 1-4 tokens and the per-block run-finding overhead isn't justified.
hipBLASLt provides architecture-tuned GEMM kernels via Tensile, typically outperforming rocBLAS for bf16/fp16 on RDNA 3.5 and CDNA. New hipblaslt_gemm() and hipblaslt_gemm_batched() functions with:
- Per-device handle cache (thread-safe, lazily initialized)
- Algorithm heuristic selection (best-of-1 from hipBLASLt)
- RAII guards for all descriptor types
- Persistent workspace allocation (up to 32MB, grown as needed)
- fp32 accumulation for bf16/fp16 inputs

matmul.cpp tries hipBLASLt first for bf16/fp16 and falls back to rocBLAS silently on failure. Float32/64 GEMMs are unchanged.
The dequant+GEMM path in QuantizedMatmul now tries hipBLASLt before rocBLAS for bf16 GEMMs. hipBLASLt selects architecture-tuned kernels via heuristic algorithm search, significantly outperforming rocBLAS once the algorithm cache is warm. New hipblaslt_gemm_raw() allows calling from inside kernel lambdas with pre-swapped column-major parameters, matching the rocBLAS pattern. Warm prompt (Qwen3-Coder-Next 4-bit, Strix Halo): 80 tok/s → 207 tok/s (2.6x faster) First-call overhead from algorithm search is amortized by the application warmup pass.
- hipblaslt_gemm_raw() for calling from inside kernel lambdas with pre-swapped column-major params; used in the QMM bf16 dequant+GEMM path.
- Warm prompt: 80 → 207 tok/s with the hipBLASLt algorithm cache primed.
- CommandEncoder graph-capture API (begin_capture, end_capture, replay, reset_graph) using hipStreamBeginCapture/EndCapture/GraphLaunch. Infrastructure for future decode acceleration (18 → 34 tok/s potential); not yet active due to MLX lazy-eval incompatibility with capture mode.
Replace the 5-operation copy chain (2 allocs + 2 hipMemcpyAsync + 1 kernel) with single-dispatch strided copy kernels for non-contiguous arrays.

New kernels:
- strided_row_copy_kernel: inner-contiguous with an outer stride gap (a common pattern from take/gather_sort). Uses 4-byte word copies when aligned.
- strided_general_copy_kernel: arbitrary strides; shapes/strides passed as by-value structs (zero device allocation).

Tiered dispatch in ensure_row_contiguous_matrix:
1. Already contiguous → return (fast path, unchanged)
2. Inner-contiguous with outer gap → strided_row_copy_kernel (1 dispatch)
3. General non-contiguous → strided_general_copy_kernel (1 dispatch)
4. ndim > 10 → old contiguous_copy_gpu fallback

Net: each non-contiguous copy drops from 5 GPU operations to 1.
Coarser size buckets for large allocations improve the buffer-cache hit rate during LLM decode. Without this, slightly different allocation sizes (e.g., 1.01MB vs 1.02MB) miss the cache and trigger hipExtMallocWithFlags at ~7ms each.

Previous: page-aligned (16KB granularity) for all sizes >= 16KB
New: page-aligned for 16KB-1MB, power-of-2 for >= 1MB

Trades up to 2x memory waste for large buffers in exchange for dramatically fewer cache misses during steady-state decode.
The power-of-2 rounding for >= 1MB allocations caused OOM by doubling large allocations that exceeded the 2GB device-local VRAM on iGPU. Reverted to page-aligned (16KB) rounding for all large sizes. hipExtMallocWithFlags remains the primary path for iGPU (best GPU bandwidth via fine-grained coherent access). Falls back to hipMallocManaged for allocations that exceed VRAM capacity, accessing the full system RAM (126GB on Strix Halo).
Support for gfx1103 (RDNA 3)?
Author
I don't see why not. I did a lot of work supporting RDNA 3, 3.5, and 4; I can verify it's in the CMake for you.
Summary
Comprehensive ROCm performance optimizations for LLM inference on all GPU architectures, with particular gains for MoE models on iGPU (Strix Halo gfx1151).
Kernel Optimizations
hipBLASLt Integration
Allocator Optimizations
Performance (Strix Halo gfx1151)
Qwen3-Coder-Next 4-bit (MoE, 512 experts):
Qwen3-8B 4-bit:
Architecture Support