-
Notifications
You must be signed in to change notification settings - Fork 34
Expand file tree
/
Copy pathquantizer.cpp
More file actions
71 lines (58 loc) · 2.64 KB
/
quantizer.cpp
File metadata and controls
71 lines (58 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <torch/extension.h>
#include "quant.h"
torch::Tensor quantize_q2_k(torch::Tensor& input) {
    // Row-major quantization (equivalent to block size [1, 256]) of a 2D
    // float32 CPU tensor using the Q2_K scheme.
    //
    // @param input  [nrows, ncols] float32 CPU tensor; ncols must be a
    //               multiple of QK_K. The caller's tensor is no longer
    //               reassigned when a contiguous copy is needed.
    // @return       uint8 CPU tensor of shape
    //               [nrows, blocks_per_row * sizeof(block_q2_K)] holding the
    //               packed Q2_K blocks, one run of blocks per input row.
    TORCH_CHECK(input.ndimension() == 2, "input must be 2D");
    TORCH_CHECK(input.size(1) % QK_K == 0, "ncols must be divisible by QK_K");
    TORCH_CHECK(input.dtype() == torch::kFloat32, "input must be float32");
    // data_ptr<float>() on a non-CPU tensor would yield a device pointer and
    // the CPU quantize loop below would fault — reject it up front.
    TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");

    // Use a local handle for the contiguous view instead of reassigning the
    // caller's reference (the original mutated `input` as a side effect).
    const torch::Tensor contig =
        input.is_contiguous() ? input : input.contiguous();

    const int64_t nrows = contig.size(0);
    const int64_t ncols = contig.size(1);
    const int64_t blocks_per_row = ncols / QK_K;
    const int64_t block_size = sizeof(block_q2_K);

    auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
    auto output = torch::empty({nrows, blocks_per_row * block_size}, options);

    const float* input_ptr = contig.data_ptr<float>();
    uint8_t* output_ptr = output.data_ptr<uint8_t>();

    // Rows are quantized independently, so parallelize over rows.
    #pragma omp parallel for
    for (int64_t row = 0; row < nrows; row++) {
        const float* row_input = input_ptr + row * ncols;
        block_q2_K* row_output = reinterpret_cast<block_q2_K*>(
            output_ptr + row * blocks_per_row * block_size);
        quantize_row_q2_K_ref(row_input, row_output, ncols);
    }
    return output;
}
torch::Tensor quantize_q3_k(torch::Tensor& input) {
    // Row-major quantization (equivalent to block size [1, 256]) of a 2D
    // float32 CPU tensor using the Q3_K scheme.
    //
    // @param input  [nrows, ncols] float32 CPU tensor; ncols must be a
    //               multiple of QK_K. The caller's tensor is no longer
    //               reassigned when a contiguous copy is needed.
    // @return       uint8 CPU tensor of shape
    //               [nrows, blocks_per_row * sizeof(block_q3_K)] holding the
    //               packed Q3_K blocks, one run of blocks per input row.
    TORCH_CHECK(input.ndimension() == 2, "input must be 2D");
    TORCH_CHECK(input.size(1) % QK_K == 0, "ncols must be divisible by QK_K");
    TORCH_CHECK(input.dtype() == torch::kFloat32, "input must be float32");
    // data_ptr<float>() on a non-CPU tensor would yield a device pointer and
    // the CPU quantize loop below would fault — reject it up front.
    TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");

    // Use a local handle for the contiguous view instead of reassigning the
    // caller's reference (the original mutated `input` as a side effect).
    const torch::Tensor contig =
        input.is_contiguous() ? input : input.contiguous();

    const int64_t nrows = contig.size(0);
    const int64_t ncols = contig.size(1);
    const int64_t blocks_per_row = ncols / QK_K;
    const int64_t block_size = sizeof(block_q3_K);

    auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
    auto output = torch::empty({nrows, blocks_per_row * block_size}, options);

    const float* input_ptr = contig.data_ptr<float>();
    uint8_t* output_ptr = output.data_ptr<uint8_t>();

    // Rows are quantized independently, so parallelize over rows.
    #pragma omp parallel for
    for (int64_t row = 0; row < nrows; row++) {
        const float* row_input = input_ptr + row * ncols;
        block_q3_K* row_output = reinterpret_cast<block_q3_K*>(
            output_ptr + row * blocks_per_row * block_size);
        quantize_row_q3_K_ref(row_input, row_output, ncols);
    }
    return output;
}
// Python bindings: exposes the two row-major K-quant entry points above.
// TORCH_EXTENSION_NAME is supplied by the PyTorch extension build so the
// module name matches the compiled extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("quantize_q2_k", &quantize_q2_k, "Quantize a tensor to Q2_K format");
m.def("quantize_q3_k", &quantize_q3_k, "Quantize a tensor to Q3_K format");
}