diff --git a/mlx/backend/cuda/quantized/qmm/cute_dequant.cuh b/mlx/backend/cuda/quantized/qmm/cute_dequant.cuh index e7f8dd30cf..7a1c6b106c 100644 --- a/mlx/backend/cuda/quantized/qmm/cute_dequant.cuh +++ b/mlx/backend/cuda/quantized/qmm/cute_dequant.cuh @@ -2,9 +2,137 @@ #pragma once +#include #include #include +namespace cutlass { + +using uint3b_t = integer_subbyte<3, false>; +using uint5b_t = integer_subbyte<5, false>; + +template +struct NumericArrayConverter { + static_assert(N % 8 == 0); + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(const source_type& source) { + result_type result; + auto* s_base = reinterpret_cast(&source); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 8; ++i) { + auto* s = s_base + i * 3; + result[i * 8] = T(s[0] & 0x07); + result[i * 8 + 1] = T((s[0] & 0x38) >> 3); + result[i * 8 + 2] = T((s[0] & 0xc0) >> 6) + T((s[1] & 0x01) << 2); + result[i * 8 + 3] = T((s[1] & 0x0e) >> 1); + result[i * 8 + 4] = T((s[1] & 0x70) >> 4); + result[i * 8 + 5] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x03) << 1); + result[i * 8 + 6] = T((s[2] & 0x1c) >> 2); + result[i * 8 + 7] = T((s[2] & 0xe0) >> 5); + } + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(const source_type& s) const { + return convert(s); + } +}; + +template +struct NumericArrayConverter { + static_assert(N % 8 == 0); + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(const source_type& source) { + result_type result; + auto* s_base = reinterpret_cast(&source); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 8; ++i) { + auto* s = s_base + i * 5; + result[i * 8] = T(s[0] & 0x1f); + result[i * 8 + 1] = T((s[0] & 0xe0) >> 5) + T((s[1] & 0x03) << 3); + result[i * 8 + 2] = T((s[1] & 0x7c) >> 2); + result[i * 8 + 3] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x0f) << 1); + result[i * 8 + 4] = T((s[2] & 0xf0) >> 4) + T((s[3] & 0x01) << 4); + result[i * 8 + 5] = T((s[3] & 0x3e) >> 1); + result[i * 8 + 6] = T((s[3] & 0xc0) >> 6) + T((s[4] & 0x07) << 2); + result[i * 8 + 7] = T((s[4] & 0xf8) >> 3); + } + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(const source_type& s) const { + return convert(s); + } +}; + +template +struct NumericArrayConverter { + static_assert(N % 4 == 0); + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(const source_type& source) { + result_type result; + auto* s_base = reinterpret_cast(&source); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + auto* s = s_base + i * 3; + result[i * 4] = T(s[0] & 0x3f); + result[i * 4 + 1] = T((s[0] >> 6) & 0x03) + T((s[1] & 0x0f) << 2); + result[i * 4 + 2] = T((s[1] >> 4) & 0x0f) + T((s[2] & 0x03) << 4); + result[i * 4 + 3] = T((s[2] >> 2) & 0x3f); + } + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(const source_type& s) const { + return convert(s); + } +}; + +} // namespace cutlass + +namespace cute { + +// Required by tiled copy for 3/5/6-bit weights. +struct uint24_t { + std::array bytes; +}; +struct uint40_t { + std::array bytes; +}; +struct uint48_t { + std::array bytes; +}; + +template <> +struct uint_bit<24> { + using type = uint24_t; +}; +template <> +struct uint_bit<40> { + using type = uint40_t; +}; +template <> +struct uint_bit<48> { + using type = uint48_t; +}; + +} // namespace cute + namespace cutlass_gemm { // Whether the quant type is affine quantization. diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cu b/mlx/backend/cuda/quantized/qmm/qmm.cu index 97a8f7f422..e6f7fc3215 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm.cu @@ -80,7 +80,7 @@ void qmm_sm90( qmm_impl_sm90( x, w, scales, biases, out, bits, group_size, encoder, s); }; - int m = out.shape(-2); + int m = out.ndim() > 1 ? out.shape(-2) : 1; if (m <= 16) { dispatch.template operator()<128, 16, 1>(); } else if (m <= 32) { @@ -163,7 +163,7 @@ void qmm_sm80( qmm_impl_sm80( x, w, scales, biases, out, bits, group_size, mode, encoder); }; - int m = out.shape(-2); + int m = out.ndim() > 1 ? out.shape(-2) : 1; if (m <= 16) { dispatch.template operator()<16>(); } else if (m <= 32) { @@ -208,9 +208,6 @@ bool supports_qmm_naive( if (biases && !biases->flags().row_contiguous) { return false; } - if (bits != 2 && bits != 4 && bits != 8) { - return false; - } return true; } @@ -230,7 +227,7 @@ void qmm_naive( x, w, scales, biases, out, bits, group_size, mode, encoder); }; dispatch_bool(transpose, [&](auto k_major) { - int m = out.shape(-2); + int m = out.ndim() > 1 ? out.shape(-2) : 1; if (m <= 16) { dispatch.template operator()<16, k_major.value>(); } else if (m <= 32) { diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh index fb187aa92d..cc387cfc8e 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh @@ -385,8 +385,14 @@ inline void dispatch_quant_types( dispatch_groups(group_size, tag, [&]() { if (bits == 2) { f.template operator()(); + } else if (bits == 3) { + f.template operator()(); } else if (bits == 4) { f.template operator()(); + } else if (bits == 5) { + f.template operator()(); + } else if (bits == 6) { + f.template operator()(); } else if (bits == 8) { f.template operator()(); } else { @@ -409,7 +415,7 @@ void qmm_impl_naive( QuantizationMode mode, cu::CommandEncoder& encoder) { const char* tag = "[quantized_matmul]"; - int m = out.shape(-2); + int m = out.ndim() > 1 ? out.shape(-2) : 1; int n = out.shape(-1); int k = x.shape(-1); int l = out.size() / (m * n); diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh index 895cdfdb5c..679e4dadea 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh @@ -435,7 +435,7 @@ void qmm_impl_sm80( QuantizationMode mode, cu::CommandEncoder& encoder) { const char* tag = "[quantized_matmul]"; - int m = out.shape(-2); + int m = out.ndim() > 1 ? out.shape(-2) : 1; int n = out.shape(-1); int k = x.shape(-1); int l = out.size() / (m * n); diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh index bb29cdafc5..ce79dbceeb 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh @@ -189,7 +189,7 @@ void qmm_impl_sm90( cu::CommandEncoder& encoder, Stream s) { const char* tag = "[quantized_matmul]"; - int m = out.shape(-2); + int m = out.ndim() > 1 ? out.shape(-2) : 1; int n = out.shape(-1); int k = x.shape(-1); int l = out.size() / (m * n); diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index c43e783b4e..ceda750a74 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -1,112 +1,12 @@ // Copyright © 2026 Apple Inc. #include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" #include "mlx/dtype_utils.h" #include #include -#include -#include - -namespace cutlass { - -using uint3b_t = integer_subbyte<3, false>; -using uint5b_t = integer_subbyte<5, false>; - -template -struct NumericArrayConverter { - static_assert(N % 8 == 0); - - using result_type = Array; - using source_type = Array; - - CUTLASS_HOST_DEVICE - static result_type convert(const source_type& source) { - result_type result; - auto* s_base = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - auto* s = s_base + i * 3; - result[i * 8] = T(s[0] & 0x07); - result[i * 8 + 1] = T((s[0] & 0x38) >> 3); - result[i * 8 + 2] = T((s[0] & 0xc0) >> 6) + T((s[1] & 0x01) << 2); - result[i * 8 + 3] = T((s[1] & 0x0e) >> 1); - result[i * 8 + 4] = T((s[1] & 0x70) >> 4); - result[i * 8 + 5] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x03) << 1); - result[i * 8 + 6] = T((s[2] & 0x1c) >> 2); - result[i * 8 + 7] = T((s[2] & 0xe0) >> 5); - } - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(const source_type& s) const { - return convert(s); - } -}; - -template -struct NumericArrayConverter { - static_assert(N % 8 == 0); - - using result_type = Array; - using source_type = Array; - - CUTLASS_HOST_DEVICE - static result_type convert(const source_type& source) { - result_type result; - auto* s_base = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - auto* s = s_base + i * 5; - result[i * 8] = T(s[0] & 0x1f); - result[i * 8 + 1] = T((s[0] & 0xe0) >> 5) + T((s[1] & 0x03) << 3); - result[i * 8 + 2] = T((s[1] & 0x7c) >> 2); - result[i * 8 + 3] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x0f) << 1); - result[i * 8 + 4] = T((s[2] & 0xf0) >> 4) + T((s[3] & 0x01) << 4); - result[i * 8 + 5] = T((s[3] & 0x3e) >> 1); - result[i * 8 + 6] = T((s[3] & 0xc0) >> 6) + T((s[4] & 0x07) << 2); - result[i * 8 + 7] = T((s[4] & 0xf8) >> 3); - } - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(const source_type& s) const { - return convert(s); - } -}; - -template -struct NumericArrayConverter { - static_assert(N % 4 == 0); - - using result_type = Array; - using source_type = Array; - - CUTLASS_HOST_DEVICE - static result_type convert(const source_type& source) { - result_type result; - auto* s_base = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 4; ++i) { - auto* s = s_base + i * 3; - result[i * 4] = T(s[0] & 0x3f); - result[i * 4 + 1] = T((s[0] >> 6) & 0x03) + T((s[1] & 0x0f) << 2); - result[i * 4 + 2] = T((s[1] >> 4) & 0x0f) + T((s[2] & 0x03) << 4); - result[i * 4 + 3] = T((s[2] >> 2) & 0x3f); - } - return result; - } - - CUTLASS_HOST_DEVICE - result_type operator()(const source_type& s) const { - return convert(s); - } -}; - -} // namespace cutlass namespace mlx::core { diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 31cd5038c2..507e7f1384 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -75,7 +75,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { } }; - int M = out.shape(-2); + int M = out.ndim() > 1 ? out.shape(-2) : 1; int N = out.shape(-1); int K = x.shape(-1); int B = out.size() / (M * N); diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index e666ec3389..8f91199b9b 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -27,7 +27,6 @@ "TestQuantized.test_qmm_shapes", "TestQuantized.test_fp_qvm", "TestQuantized.test_qvm", - "TestQuantized.test_qvm_splitk", "TestQuantized.test_qmv_small_non_multiples", "TestQuantized.test_small_matrix", "TestExportImport.test_export_quantized_model",