Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions mlx/backend/cuda/quantized/qmm/cute_dequant.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,137 @@

#pragma once

#include <cute/numeric/numeric_types.hpp>
#include <cute/tensor.hpp>
#include <cutlass/numeric_conversion.h>

namespace cutlass {

using uint3b_t = integer_subbyte<3, false>;
using uint5b_t = integer_subbyte<5, false>;

template <typename T, int N, FloatRoundStyle Round>
struct NumericArrayConverter<T, uint3b_t, N, Round> {
  // 3-bit values are packed 8-per-3-bytes, so convert whole 24-bit groups.
  static_assert(N % 8 == 0);

  using result_type = Array<T, N>;
  using source_type = Array<uint3b_t, N>;

  /// Unpacks N unsigned 3-bit values (little-endian bit order, 8 values per
  /// 3 bytes) into an Array<T, N>.
  ///
  /// Values that straddle a byte boundary are reassembled in the integer
  /// domain and converted to T once, instead of converting each fragment to
  /// T and adding — the disjoint bit fields make `|` equivalent to `+`, and
  /// this avoids an extra conversion plus a T-typed (floating-point, when T
  /// is half/bfloat16) add per straddling element.
  CUTLASS_HOST_DEVICE
  static result_type convert(const source_type& source) {
    result_type result;
    auto* s_base = reinterpret_cast<const uint8_t*>(&source);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 8; ++i) {
      auto* s = s_base + i * 3;
      result[i * 8] = T(s[0] & 0x07);
      result[i * 8 + 1] = T((s[0] >> 3) & 0x07);
      // Bits 6-7 of byte 0 are the low 2 bits; bit 0 of byte 1 is bit 2.
      result[i * 8 + 2] = T(((s[0] >> 6) & 0x03) | ((s[1] & 0x01) << 2));
      result[i * 8 + 3] = T((s[1] >> 1) & 0x07);
      result[i * 8 + 4] = T((s[1] >> 4) & 0x07);
      // Bit 7 of byte 1 is the low bit; bits 0-1 of byte 2 are bits 1-2.
      result[i * 8 + 5] = T(((s[1] >> 7) & 0x01) | ((s[2] & 0x03) << 1));
      result[i * 8 + 6] = T((s[2] >> 2) & 0x07);
      result[i * 8 + 7] = T((s[2] >> 5) & 0x07);
    }
    return result;
  }

  CUTLASS_HOST_DEVICE
  result_type operator()(const source_type& s) const {
    return convert(s);
  }
};

template <typename T, int N, FloatRoundStyle Round>
struct NumericArrayConverter<T, uint5b_t, N, Round> {
  // 5-bit values are packed 8-per-5-bytes, so convert whole 40-bit groups.
  static_assert(N % 8 == 0);

  using result_type = Array<T, N>;
  using source_type = Array<uint5b_t, N>;

  /// Unpacks N unsigned 5-bit values (little-endian bit order, 8 values per
  /// 5 bytes) into an Array<T, N>.
  ///
  /// Values that straddle a byte boundary are reassembled in the integer
  /// domain and converted to T once, instead of converting each fragment to
  /// T and adding — the disjoint bit fields make `|` equivalent to `+`, and
  /// this avoids an extra conversion plus a T-typed (floating-point, when T
  /// is half/bfloat16) add per straddling element.
  CUTLASS_HOST_DEVICE
  static result_type convert(const source_type& source) {
    result_type result;
    auto* s_base = reinterpret_cast<const uint8_t*>(&source);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 8; ++i) {
      auto* s = s_base + i * 5;
      result[i * 8] = T(s[0] & 0x1f);
      // High 3 bits of byte 0 are the low bits; low 2 bits of byte 1 on top.
      result[i * 8 + 1] = T(((s[0] >> 5) & 0x07) | ((s[1] & 0x03) << 3));
      result[i * 8 + 2] = T((s[1] >> 2) & 0x1f);
      result[i * 8 + 3] = T(((s[1] >> 7) & 0x01) | ((s[2] & 0x0f) << 1));
      result[i * 8 + 4] = T(((s[2] >> 4) & 0x0f) | ((s[3] & 0x01) << 4));
      result[i * 8 + 5] = T((s[3] >> 1) & 0x1f);
      result[i * 8 + 6] = T(((s[3] >> 6) & 0x03) | ((s[4] & 0x07) << 2));
      result[i * 8 + 7] = T((s[4] >> 3) & 0x1f);
    }
    return result;
  }

  CUTLASS_HOST_DEVICE
  result_type operator()(const source_type& s) const {
    return convert(s);
  }
};

template <typename T, int N, FloatRoundStyle Round>
struct NumericArrayConverter<T, uint6b_t, N, Round> {
  // 6-bit values are packed 4-per-3-bytes, so convert whole 24-bit groups.
  static_assert(N % 4 == 0);

  using result_type = Array<T, N>;
  using source_type = Array<uint6b_t, N>;

  /// Unpacks N unsigned 6-bit values (little-endian bit order, 4 values per
  /// 3 bytes) into an Array<T, N>.
  ///
  /// Values that straddle a byte boundary are reassembled in the integer
  /// domain and converted to T once, instead of converting each fragment to
  /// T and adding — the disjoint bit fields make `|` equivalent to `+`, and
  /// this avoids an extra conversion plus a T-typed (floating-point, when T
  /// is half/bfloat16) add per straddling element.
  CUTLASS_HOST_DEVICE
  static result_type convert(const source_type& source) {
    result_type result;
    auto* s_base = reinterpret_cast<const uint8_t*>(&source);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 4; ++i) {
      auto* s = s_base + i * 3;
      result[i * 4] = T(s[0] & 0x3f);
      // High 2 bits of byte 0 are the low bits; low 4 bits of byte 1 on top.
      result[i * 4 + 1] = T(((s[0] >> 6) & 0x03) | ((s[1] & 0x0f) << 2));
      result[i * 4 + 2] = T(((s[1] >> 4) & 0x0f) | ((s[2] & 0x03) << 4));
      result[i * 4 + 3] = T((s[2] >> 2) & 0x3f);
    }
    return result;
  }

  CUTLASS_HOST_DEVICE
  result_type operator()(const source_type& s) const {
    return convert(s);
  }
};

} // namespace cutlass

namespace cute {

// Required by tiled copy for 3/5/6-bit weights.
//
// cute selects a storage/register type for a copy via uint_bit<Bits>, but
// only power-of-two widths are defined upstream. The trivially-copyable
// byte-array PODs below provide exact storage for 24-, 40- and 48-bit
// chunks of packed sub-byte weights (e.g. the 3-byte and 5-byte groups
// consumed by the uint3b/uint5b/uint6b dequant converters above).
// NOTE(review): presumably these widths match the per-thread copy vector
// sizes chosen by the qmm kernels — confirm against the tiled-copy setup.

// 24 bits = 3 bytes of packed weight data.
struct uint24_t {
  std::array<std::uint8_t, 3> bytes;
};
// 40 bits = 5 bytes of packed weight data.
struct uint40_t {
  std::array<std::uint8_t, 5> bytes;
};
// 48 bits = 6 bytes of packed weight data.
struct uint48_t {
  std::array<std::uint8_t, 6> bytes;
};

// Plug the PODs into cute's uint_bit customization point so non-power-of-two
// bit widths resolve to a usable type.
template <>
struct uint_bit<24> {
  using type = uint24_t;
};
template <>
struct uint_bit<40> {
  using type = uint40_t;
};
template <>
struct uint_bit<48> {
  using type = uint48_t;
};

} // namespace cute

namespace cutlass_gemm {

// Whether the quant type is affine quantization.
Expand Down
9 changes: 3 additions & 6 deletions mlx/backend/cuda/quantized/qmm/qmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void qmm_sm90(
qmm_impl_sm90<TileShapeMN, ClusterShape>(
x, w, scales, biases, out, bits, group_size, encoder, s);
};
int m = out.shape(-2);
int m = out.ndim() > 1 ? out.shape(-2) : 1;
if (m <= 16) {
dispatch.template operator()<128, 16, 1>();
} else if (m <= 32) {
Expand Down Expand Up @@ -163,7 +163,7 @@ void qmm_sm80(
qmm_impl_sm80<TileM>(
x, w, scales, biases, out, bits, group_size, mode, encoder);
};
int m = out.shape(-2);
int m = out.ndim() > 1 ? out.shape(-2) : 1;
if (m <= 16) {
dispatch.template operator()<16>();
} else if (m <= 32) {
Expand Down Expand Up @@ -208,9 +208,6 @@ bool supports_qmm_naive(
if (biases && !biases->flags().row_contiguous) {
return false;
}
if (bits != 2 && bits != 4 && bits != 8) {
return false;
}
return true;
}

Expand All @@ -230,7 +227,7 @@ void qmm_naive(
x, w, scales, biases, out, bits, group_size, mode, encoder);
};
dispatch_bool(transpose, [&](auto k_major) {
int m = out.shape(-2);
int m = out.ndim() > 1 ? out.shape(-2) : 1;
if (m <= 16) {
dispatch.template operator()<16, k_major.value>();
} else if (m <= 32) {
Expand Down
8 changes: 7 additions & 1 deletion mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,14 @@ inline void dispatch_quant_types(
dispatch_groups(group_size, tag, [&]<int group_size>() {
if (bits == 2) {
f.template operator()<cutlass::uint2b_t, T, group_size>();
} else if (bits == 3) {
f.template operator()<cutlass::uint3b_t, T, group_size>();
} else if (bits == 4) {
f.template operator()<cutlass::uint4b_t, T, group_size>();
} else if (bits == 5) {
f.template operator()<cutlass::uint5b_t, T, group_size>();
} else if (bits == 6) {
f.template operator()<cutlass::uint6b_t, T, group_size>();
} else if (bits == 8) {
f.template operator()<uint8_t, T, group_size>();
} else {
Expand All @@ -409,7 +415,7 @@ void qmm_impl_naive(
QuantizationMode mode,
cu::CommandEncoder& encoder) {
const char* tag = "[quantized_matmul]";
int m = out.shape(-2);
int m = out.ndim() > 1 ? out.shape(-2) : 1;
int n = out.shape(-1);
int k = x.shape(-1);
int l = out.size() / (m * n);
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ void qmm_impl_sm80(
QuantizationMode mode,
cu::CommandEncoder& encoder) {
const char* tag = "[quantized_matmul]";
int m = out.shape(-2);
int m = out.ndim() > 1 ? out.shape(-2) : 1;
int n = out.shape(-1);
int k = x.shape(-1);
int l = out.size() / (m * n);
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ void qmm_impl_sm90(
cu::CommandEncoder& encoder,
Stream s) {
const char* tag = "[quantized_matmul]";
int m = out.shape(-2);
int m = out.ndim() > 1 ? out.shape(-2) : 1;
int n = out.shape(-1);
int k = x.shape(-1);
int l = out.size() / (m * n);
Expand Down
102 changes: 1 addition & 101 deletions mlx/backend/cuda/quantized/qmm/qmv.cu
Original file line number Diff line number Diff line change
@@ -1,112 +1,12 @@
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh"
#include "mlx/backend/cuda/quantized/qmm/qmm.h"
#include "mlx/dtype_utils.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cute/numeric/numeric_types.hpp>
#include <cutlass/numeric_conversion.h>

namespace cutlass {

using uint3b_t = integer_subbyte<3, false>;
using uint5b_t = integer_subbyte<5, false>;

template <typename T, int N, FloatRoundStyle Round>
struct NumericArrayConverter<T, uint3b_t, N, Round> {
static_assert(N % 8 == 0);

using result_type = Array<T, N>;
using source_type = Array<uint3b_t, N>;

CUTLASS_HOST_DEVICE
static result_type convert(const source_type& source) {
result_type result;
auto* s_base = reinterpret_cast<const uint8_t*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 8; ++i) {
auto* s = s_base + i * 3;
result[i * 8] = T(s[0] & 0x07);
result[i * 8 + 1] = T((s[0] & 0x38) >> 3);
result[i * 8 + 2] = T((s[0] & 0xc0) >> 6) + T((s[1] & 0x01) << 2);
result[i * 8 + 3] = T((s[1] & 0x0e) >> 1);
result[i * 8 + 4] = T((s[1] & 0x70) >> 4);
result[i * 8 + 5] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x03) << 1);
result[i * 8 + 6] = T((s[2] & 0x1c) >> 2);
result[i * 8 + 7] = T((s[2] & 0xe0) >> 5);
}
return result;
}

CUTLASS_HOST_DEVICE
result_type operator()(const source_type& s) const {
return convert(s);
}
};

template <typename T, int N, FloatRoundStyle Round>
struct NumericArrayConverter<T, uint5b_t, N, Round> {
static_assert(N % 8 == 0);

using result_type = Array<T, N>;
using source_type = Array<uint5b_t, N>;

CUTLASS_HOST_DEVICE
static result_type convert(const source_type& source) {
result_type result;
auto* s_base = reinterpret_cast<const uint8_t*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 8; ++i) {
auto* s = s_base + i * 5;
result[i * 8] = T(s[0] & 0x1f);
result[i * 8 + 1] = T((s[0] & 0xe0) >> 5) + T((s[1] & 0x03) << 3);
result[i * 8 + 2] = T((s[1] & 0x7c) >> 2);
result[i * 8 + 3] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x0f) << 1);
result[i * 8 + 4] = T((s[2] & 0xf0) >> 4) + T((s[3] & 0x01) << 4);
result[i * 8 + 5] = T((s[3] & 0x3e) >> 1);
result[i * 8 + 6] = T((s[3] & 0xc0) >> 6) + T((s[4] & 0x07) << 2);
result[i * 8 + 7] = T((s[4] & 0xf8) >> 3);
}
return result;
}

CUTLASS_HOST_DEVICE
result_type operator()(const source_type& s) const {
return convert(s);
}
};

template <typename T, int N, FloatRoundStyle Round>
struct NumericArrayConverter<T, uint6b_t, N, Round> {
static_assert(N % 4 == 0);

using result_type = Array<T, N>;
using source_type = Array<uint6b_t, N>;

CUTLASS_HOST_DEVICE
static result_type convert(const source_type& source) {
result_type result;
auto* s_base = reinterpret_cast<const uint8_t*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
auto* s = s_base + i * 3;
result[i * 4] = T(s[0] & 0x3f);
result[i * 4 + 1] = T((s[0] >> 6) & 0x03) + T((s[1] & 0x0f) << 2);
result[i * 4 + 2] = T((s[1] >> 4) & 0x0f) + T((s[2] & 0x03) << 4);
result[i * 4 + 3] = T((s[2] >> 2) & 0x3f);
}
return result;
}

CUTLASS_HOST_DEVICE
result_type operator()(const source_type& s) const {
return convert(s);
}
};

} // namespace cutlass

namespace mlx::core {

Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/cuda/quantized/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
}
};

int M = out.shape(-2);
int M = out.ndim() > 1 ? out.shape(-2) : 1;
int N = out.shape(-1);
int K = x.shape(-1);
int B = out.size() / (M * N);
Expand Down
1 change: 0 additions & 1 deletion python/tests/cuda_skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
"TestQuantized.test_qmm_shapes",
"TestQuantized.test_fp_qvm",
"TestQuantized.test_qvm",
"TestQuantized.test_qvm_splitk",
"TestQuantized.test_qmv_small_non_multiples",
"TestQuantized.test_small_matrix",
"TestExportImport.test_export_quantized_model",
Expand Down
Loading