From 789d36a8fe4af4fdb46121cf68cd74adbf439316 Mon Sep 17 00:00:00 2001 From: Adam Bien Date: Sun, 19 Apr 2026 12:01:35 +0200 Subject: [PATCH 1/2] additional information / output added --- .../beehive/gpullama3/model/loader/DevstralModelLoader.java | 6 +----- .../org/beehive/gpullama3/model/loader/GraniteLoader.java | 6 +----- .../beehive/gpullama3/model/loader/LlamaModelLoader.java | 6 +----- .../beehive/gpullama3/model/loader/MistralModelLoader.java | 6 +----- .../org/beehive/gpullama3/model/loader/Phi3ModelLoader.java | 6 +----- .../beehive/gpullama3/model/loader/Qwen2ModelLoader.java | 6 +----- .../beehive/gpullama3/model/loader/Qwen3ModelLoader.java | 6 +----- 7 files changed, 7 insertions(+), 35 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/loader/DevstralModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/DevstralModelLoader.java index 8c230b2f..4adc283d 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/DevstralModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/DevstralModelLoader.java @@ -143,11 +143,7 @@ protected Weights createStandardWeights(Map tensorEntri // @formatter:off @Override protected Weights createTornadoVMWeights(Map tensorEntries, DevstralConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { - GGMLType ggmlType = outputWeight.ggmlType(); - - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); - } + GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType()); if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); diff --git a/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java index c22491e7..cde4cc3b 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java @@ -136,11 +136,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { - GGMLType ggmlType = outputWeight.ggmlType(); - - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); - } + GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType()); // Validate supported types if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { diff --git a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java index e621b37a..bb58f03e 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java @@ -106,11 +106,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { - GGMLType ggmlType = outputWeight.ggmlType(); - - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); - } + GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType()); // Validate supported types if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { diff --git 
a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java index 52e0178d..577b103f 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java @@ -116,11 +116,7 @@ protected Weights createStandardWeights(Map tensorEntri // @formatter:off @Override protected Weights createTornadoVMWeights(Map tensorEntries, MistralConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { - GGMLType ggmlType = outputWeight.ggmlType(); - - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); - } + GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType()); // Validate supported types if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java index a43c4095..2bdd8223 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java @@ -126,11 +126,7 @@ protected Weights createStandardWeights(Map tensorEntri // @formatter:off @Override protected Weights createTornadoVMWeights(Map tensorEntries, Phi3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { - GGMLType ggmlType = outputWeight.ggmlType(); - - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); - } + GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType()); // Validate supported types if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java index 2e3d8002..4c18aacd 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java @@ -126,11 +126,7 @@ protected Weights createStandardWeights(Map tensorEntri @Override protected Weights createTornadoVMWeights(Map tensorEntries, Qwen2Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { - GGMLType ggmlType = outputWeight.ggmlType(); - - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); - } + GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType()); // Validate supported types if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java index 57e833cf..fea962c8 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java @@ -129,11 +129,7 @@ protected Weights createStandardWeights(Map tensorEntri protected Weights createTornadoVMWeights(Map tensorEntries, Qwen3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) 
{ - System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); - } - - GGMLType ggmlType = outputWeight.ggmlType(); + GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType()); final int nl = config.numberOfLayers(); From 58f7e2a4b8c5a2708ed89953901904d775e9495d Mon Sep 17 00:00:00 2001 From: Adam Bien Date: Sun, 19 Apr 2026 12:01:53 +0200 Subject: [PATCH 2/2] support for Q4 quantization added --- llamaTornado | 23 +++- .../model/loader/AbstractModelLoader.java | 36 +++++ .../gpullama3/model/loader/ModelLoader.java | 64 ++++++++- .../tensor/standard/Q4_KFloatTensor.java | 127 +++++++++++++++++ .../tensor/standard/Q5_KFloatTensor.java | 130 ++++++++++++++++++ .../tensor/standard/Q6_KFloatTensor.java | 123 +++++++++++++++++ 6 files changed, 500 insertions(+), 3 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tensor/standard/Q4_KFloatTensor.java create mode 100644 src/main/java/org/beehive/gpullama3/tensor/standard/Q5_KFloatTensor.java create mode 100644 src/main/java/org/beehive/gpullama3/tensor/standard/Q6_KFloatTensor.java diff --git a/llamaTornado b/llamaTornado index 6cc303d5..068c7946 100755 --- a/llamaTornado +++ b/llamaTornado @@ -12,7 +12,7 @@ record Config( double temperature, double topP, long seed, int maxTokens, boolean stream, boolean echo, boolean interactive, boolean instruct, boolean useGpu, Backend backend, String gpuMemory, - String heapMin, String heapMax, + String heapMin, String heapMax, String directMemory, boolean debug, boolean profiler, String profilerDumpDir, boolean printBytecodes, boolean threads, boolean printKernel, boolean fullDump, boolean verboseInit, @@ -37,6 +37,7 @@ Config parseArgs(String[] args) { String gpuMemory = "14GB"; String heapMin = "20g"; String heapMax = "20g"; + String directMemory = null; boolean debug = false; boolean profiler = false; String profilerDumpDir = null; @@ -71,6 +72,7 @@ Config parseArgs(String[] args) { case "--gpu-memory" -> gpuMemory = args[++i]; case "--heap-min" -> heapMin = args[++i]; case "--heap-max" -> heapMax = args[++i]; + case "--direct-memory" -> directMemory = args[++i]; case "--debug" -> debug = true; case "--profiler" -> profiler = true; case "--profiler-dump-dir" -> profilerDumpDir = args[++i]; @@ -101,12 +103,27 @@ Config parseArgs(String[] args) { profilerDumpDir = System.getenv("LLAMA_ROOT") + "/profiler-log.json"; } + // Default direct memory to 3x heap to accommodate K-quant dequantization + if (directMemory == null) { + directMemory = parseAndScale(heapMax, 3); + } + return new Config(modelPath, prompt, systemPrompt, temperature, topP, seed, maxTokens, - stream, echo, interactive, instruct, useGpu, backend, gpuMemory, heapMin, heapMax, + stream, echo, interactive, instruct, useGpu, backend, gpuMemory, heapMin, heapMax, directMemory, debug, profiler, profilerDumpDir, printBytecodes, threads, printKernel, fullDump, verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose); } +String parseAndScale(String memoryValue, int multiplier) { + var matcher = java.util.regex.Pattern.compile("(\\d+)([gGmM]?)").matcher(memoryValue); + if (matcher.matches()) { + long value = Long.parseLong(matcher.group(1)) * multiplier; + String suffix = matcher.group(2).isEmpty() ? 
"" : matcher.group(2); + return value + suffix; + } + return memoryValue; +} + void printUsage() { IO.println(""" Usage: %s --model [options] @@ -138,6 +155,7 @@ void printUsage() { --gpu-memory GPU memory allocation (default: 14GB) --heap-min Min JVM heap (default: 20g) --heap-max Max JVM heap (default: 20g) + --direct-memory Max direct buffer memory (default: 3x heap-max) Debug: --debug Enable debug output @@ -195,6 +213,7 @@ List buildCommand(Config cfg, String javaHome, String tornadoSdk, String "-XX:+EnableJVMCI", "-Xms" + cfg.heapMin(), "-Xmx" + cfg.heapMax(), + "-XX:MaxDirectMemorySize=" + cfg.directMemory(), "--enable-preview", "-Djava.library.path=" + tornadoSdk + "/lib", "-Djdk.module.showModuleResolution=false", diff --git a/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java index c994c71d..9bbefcad 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.model.loader; +import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.tensor.GGUF; import org.beehive.gpullama3.tensor.GGMLTensorEntry; import org.beehive.gpullama3.auxiliary.Pair; @@ -8,6 +9,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tokenizer.Vocabulary; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import java.io.IOException; import java.nio.channels.FileChannel; @@ -40,10 +42,39 @@ protected String getModelQuantization(Map metadata) { return switch (modelQuantizationAsInt) { case 1 -> "FP16"; case 7 -> "Q8_0"; + case 14, 15 -> "Q8_0"; // Q4_K_S, Q4_K_M (K-quants use Q8_0 activations) + case 16, 17 -> "Q8_0"; // Q5_K_S, Q5_K_M + case 18 -> "Q8_0"; // Q6_K default -> throw new UnsupportedOperationException("Unsupported quantization format: " + modelQuantizationAsInt + " (as int)."); }; } + /** + * Returns the effective GPU weight type for TornadoVM execution. + * K-quant types (Q4_K, Q5_K, Q6_K) are dequantized to Q8_0 at load time. + */ + protected static GGMLType effectiveGpuWeightType(GGMLType ggmlType) { + return switch (ggmlType) { + case F16, F32, Q8_0 -> ggmlType; + case Q4_K, Q5_K, Q6_K -> GGMLType.Q8_0; + default -> ggmlType; + }; + } + + private static String fileTypeName(int fileType) { + return switch (fileType) { + case 0 -> "F32"; + case 1 -> "F16"; + case 7 -> "Q8_0"; + case 14 -> "Q4_K_S"; + case 15 -> "Q4_K_M"; + case 16 -> "Q5_K_S"; + case 17 -> "Q5_K_M"; + case 18 -> "Q6_K"; + default -> "type_" + fileType; + }; + } + /** * Template method that defines the model loading workflow. Subclasses should not override this method. 
* @@ -123,6 +154,11 @@ public Weights loadWeights(Map tensorEntries, C config) // Delegate to specific implementation if (useTornadovm) { + GGMLType gpuType = effectiveGpuWeightType(outputWeight.ggmlType()); + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + int fileType = (int) gguf.getMetadata().get("general.file_type"); + System.out.println("Loading model weights in TornadoVM format (" + fileTypeName(fileType) + " -> " + gpuType + ")"); + } return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); } else { return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java index 83b25987..656f035e 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java @@ -15,6 +15,7 @@ import uk.ac.manchester.tornado.api.types.arrays.*; import java.io.IOException; +import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; import java.nio.ByteOrder; @@ -121,6 +122,9 @@ public static FloatTensor loadTensor(GGMLTensorEntry entry) { case F32 -> new FP32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); + case Q4_K -> new Q4_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); + case Q5_K -> new Q5_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); + case Q6_K -> new Q6_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); case F16 -> new FP16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); default -> throw new UnsupportedOperationException("Quantization format " + ggmlType); }; @@ -149,11 +153,69 @@ public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) { case F32 -> FP32TornadoTensor.fromTornadoMemorySegment(entry.memorySegment()); case F16 -> FP16TornadoTensor.fromTornadoMemorySegment(entry.memorySegment()); case Q8_0 -> Q8_0TornadoTensor.fromTornadoMemorySegment(entry.memorySegment()); - case Q4_0 -> throw new UnsupportedOperationException("Q4 format not supported yet"); + case Q4_K, Q5_K, Q6_K -> dequantizeToQ8_0TornadoTensor(entry); + case Q4_0 -> throw new UnsupportedOperationException("Q4_0 format not supported for TornadoVM yet"); default -> throw new UnsupportedOperationException("Quantization format " + ggmlType); }; } + /** + * Dequantizes a K-quant tensor (Q4_K, Q5_K, Q6_K) to Q8_0 format for TornadoVM/GPU execution. + * This is a load-time conversion that allows K-quant models to run on GPU with existing Q8_0 kernels. + */ + private static Q8_0TornadoTensor dequantizeToQ8_0TornadoTensor(GGMLTensorEntry entry) { + // The entry's memorySegment includes a TornadoVM ARRAY_HEADER prefix (16 bytes of zeros). + // Slice past it so the K-quant FloatTensor reads raw tensor data starting at byte 0. 
+ long headerBytes = TornadoNativeArray.ARRAY_HEADER; + GGMLTensorEntry dataEntry = new GGMLTensorEntry( + entry.mappedFile(), entry.name(), entry.ggmlType(), entry.shape(), + entry.memorySegment().asSlice(headerBytes)); + FloatTensor sourceTensor = loadTensor(dataEntry); + int numElements = sourceTensor.size(); + int blockSize = 32; + int blocksNeeded = (numElements + blockSize - 1) / blockSize; + int q8BlockBytes = 34; // 2 bytes scale + 32 bytes quants + int q8BytesNeeded = blocksNeeded * q8BlockBytes; + + byte[] q8Data = new byte[q8BytesNeeded]; + + for (int b = 0; b < blocksNeeded; b++) { + int start = b * blockSize; + int end = Math.min(start + blockSize, numElements); + + // Find max absolute value for scale + float maxAbs = 0; + for (int i = start; i < end; i++) { + maxAbs = Math.max(maxAbs, Math.abs(sourceTensor.getFloat(i))); + } + float scale = maxAbs / 127.0f; + + // Write scale as fp16 (little-endian) + short scaleF16 = Float.floatToFloat16(scale); + int blockOff = b * q8BlockBytes; + q8Data[blockOff] = (byte) (scaleF16 & 0xFF); + q8Data[blockOff + 1] = (byte) ((scaleF16 >> 8) & 0xFF); + + // Quantize values + float invScale = scale != 0 ? 1.0f / scale : 0; + for (int i = start; i < end; i++) { + int qi = Math.round(sourceTensor.getFloat(i) * invScale); + qi = Math.max(-128, Math.min(127, qi)); + q8Data[blockOff + 2 + (i - start)] = (byte) qi; + } + } + + // Allocate native memory with TornadoNativeArray header, matching GGUF.loadTensorsTornado layout + MemorySegment nativeSegment = Arena.ofAuto().allocate(headerBytes + q8BytesNeeded, 4); + // Zero out the header + for (int i = 0; i < headerBytes; i++) { + nativeSegment.set(ValueLayout.JAVA_BYTE, i, (byte) 0); + } + // Copy Q8_0 data after header + MemorySegment.copy(MemorySegment.ofArray(q8Data), 0, nativeSegment, headerBytes, q8BytesNeeded); + return Q8_0TornadoTensor.fromTornadoMemorySegment(nativeSegment); + } + /** * Dispatcher method for loading a TornadoVM tensor array based on type. * Used in GPU-path. diff --git a/src/main/java/org/beehive/gpullama3/tensor/standard/Q4_KFloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/Q4_KFloatTensor.java new file mode 100644 index 00000000..d25322ad --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tensor/standard/Q4_KFloatTensor.java @@ -0,0 +1,127 @@ +package org.beehive.gpullama3.tensor.standard; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.Float16; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.VectorSpecies; + +import java.lang.foreign.MemorySegment; + +/** + * {@link FloatTensor} quantized in the {@link GGMLType#Q4_K} format. + * + *

<p>Q4_K uses super-blocks of 256 elements, each containing:
+ * <ul>
+ *   <li>2 bytes: d (super-block scale, fp16)</li>
+ *   <li>2 bytes: dmin (super-block min, fp16)</li>
+ *   <li>12 bytes: scales/mins for 8 sub-blocks (packed 6-bit values)</li>
+ *   <li>128 bytes: 4-bit quantized values</li>
+ * </ul>
+ */ +public final class Q4_KFloatTensor extends FloatTensor { + + private static final int QK_K = 256; + private static final int BLOCK_SIZE = GGMLType.Q4_K.getTypeSize(); // 144 + + // Offsets within a block + private static final int D_OFFSET = 0; + private static final int DMIN_OFFSET = 2; + private static final int SCALES_OFFSET = 4; + private static final int QS_OFFSET = 16; // 4 + 12 + + final int size; + final MemorySegment memorySegment; + + public Q4_KFloatTensor(int size, MemorySegment memorySegment) { + this.size = size; + this.memorySegment = memorySegment; + } + + @Override + public int size() { + return size; + } + + @Override + public void setFloat(int index, float value) { + throw new UnsupportedOperationException("setFloat"); + } + + @Override + protected FloatVector getFloatVector(VectorSpecies species, int index) { + throw new UnsupportedOperationException("getFloatVector"); + } + + @Override + public GGMLType type() { + return GGMLType.Q4_K; + } + + @Override + public MemorySegment asMemorySegment() { + return memorySegment; + } + + /** + * Unpacks the 6-bit scale value for a given sub-block index. + * The 12 scale bytes encode 8 scale/min pairs in a packed format. + */ + private static int getScaleK4(int j, MemorySegment ms, long scalesOffset) { + if (j < 4) { + return Byte.toUnsignedInt(readByte(ms, scalesOffset + j)) & 63; + } else { + return (Byte.toUnsignedInt(readByte(ms, scalesOffset + j + 4)) & 0xF) + | ((Byte.toUnsignedInt(readByte(ms, scalesOffset + j - 4)) >> 6) << 4); + } + } + + /** + * Unpacks the 6-bit min value for a given sub-block index. + */ + private static int getMinK4(int j, MemorySegment ms, long scalesOffset) { + if (j < 4) { + return Byte.toUnsignedInt(readByte(ms, scalesOffset + j + 4)) & 63; + } else { + return (Byte.toUnsignedInt(readByte(ms, scalesOffset + j + 4)) >> 4) + | ((Byte.toUnsignedInt(readByte(ms, scalesOffset + j)) >> 6) << 4); + } + } + + @Override + public float getFloat(int index) { + assert 0 <= index && index < size; + int blockIndex = index / QK_K; + int withinBlock = index % QK_K; + long blockOffset = (long) blockIndex * BLOCK_SIZE; + + float d = Float.float16ToFloat(readShort(memorySegment, blockOffset + D_OFFSET)); + float dmin = Float.float16ToFloat(readShort(memorySegment, blockOffset + DMIN_OFFSET)); + long scalesOff = blockOffset + SCALES_OFFSET; + + // Each group of 64 elements uses 2 sub-blocks (low nibble / high nibble) + int pairIndex = withinBlock / 64; // 0..3 + int posInPair = withinBlock % 64; // 0..63 + + int subBlock; + int q; + if (posInPair < 32) { + subBlock = pairIndex * 2; + byte qByte = readByte(memorySegment, blockOffset + QS_OFFSET + pairIndex * 32 + posInPair); + q = Byte.toUnsignedInt(qByte) & 0xF; + } else { + subBlock = pairIndex * 2 + 1; + byte qByte = readByte(memorySegment, blockOffset + QS_OFFSET + pairIndex * 32 + (posInPair - 32)); + q = (Byte.toUnsignedInt(qByte) >> 4) & 0xF; + } + + int sc = getScaleK4(subBlock, memorySegment, scalesOff); + int m = getMinK4(subBlock, memorySegment, scalesOff); + + return d * sc * q - dmin * m; + } + + @Override + public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { + return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tensor/standard/Q5_KFloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/Q5_KFloatTensor.java new file mode 100644 index 00000000..55c250bf --- /dev/null +++ 
b/src/main/java/org/beehive/gpullama3/tensor/standard/Q5_KFloatTensor.java @@ -0,0 +1,130 @@ +package org.beehive.gpullama3.tensor.standard; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.Float16; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.VectorSpecies; + +import java.lang.foreign.MemorySegment; + +/** + * {@link FloatTensor} quantized in the {@link GGMLType#Q5_K} format. + * + *

<p>Q5_K uses super-blocks of 256 elements, each containing:
+ * <ul>
+ *   <li>2 bytes: d (super-block scale, fp16)</li>
+ *   <li>2 bytes: dmin (super-block min, fp16)</li>
+ *   <li>12 bytes: scales/mins for 8 sub-blocks (packed 6-bit values, same as Q4_K)</li>
+ *   <li>32 bytes: qh (5th bit of each quant)</li>
+ *   <li>128 bytes: qs (lower 4 bits of quants)</li>
+ * </ul>
+ */ +public final class Q5_KFloatTensor extends FloatTensor { + + private static final int QK_K = 256; + private static final int BLOCK_SIZE = GGMLType.Q5_K.getTypeSize(); // 176 + + // Offsets within a block + private static final int D_OFFSET = 0; + private static final int DMIN_OFFSET = 2; + private static final int SCALES_OFFSET = 4; + private static final int QH_OFFSET = 16; // 32 bytes for 5th bit + private static final int QS_OFFSET = 48; // 128 bytes for lower 4 bits + + final int size; + final MemorySegment memorySegment; + + public Q5_KFloatTensor(int size, MemorySegment memorySegment) { + this.size = size; + this.memorySegment = memorySegment; + } + + @Override + public int size() { + return size; + } + + @Override + public void setFloat(int index, float value) { + throw new UnsupportedOperationException("setFloat"); + } + + @Override + protected FloatVector getFloatVector(VectorSpecies species, int index) { + throw new UnsupportedOperationException("getFloatVector"); + } + + @Override + public GGMLType type() { + return GGMLType.Q5_K; + } + + @Override + public MemorySegment asMemorySegment() { + return memorySegment; + } + + private static int getScaleK4(int j, MemorySegment ms, long scalesOffset) { + if (j < 4) { + return Byte.toUnsignedInt(readByte(ms, scalesOffset + j)) & 63; + } else { + return (Byte.toUnsignedInt(readByte(ms, scalesOffset + j + 4)) & 0xF) + | ((Byte.toUnsignedInt(readByte(ms, scalesOffset + j - 4)) >> 6) << 4); + } + } + + private static int getMinK4(int j, MemorySegment ms, long scalesOffset) { + if (j < 4) { + return Byte.toUnsignedInt(readByte(ms, scalesOffset + j + 4)) & 63; + } else { + return (Byte.toUnsignedInt(readByte(ms, scalesOffset + j + 4)) >> 4) + | ((Byte.toUnsignedInt(readByte(ms, scalesOffset + j)) >> 6) << 4); + } + } + + @Override + public float getFloat(int index) { + assert 0 <= index && index < size; + int blockIndex = index / QK_K; + int withinBlock = index % QK_K; + long blockOffset = (long) blockIndex * BLOCK_SIZE; + + float d = Float.float16ToFloat(readShort(memorySegment, blockOffset + D_OFFSET)); + float dmin = Float.float16ToFloat(readShort(memorySegment, blockOffset + DMIN_OFFSET)); + long scalesOff = blockOffset + SCALES_OFFSET; + + int pairIndex = withinBlock / 64; // 0..3 + int posInPair = withinBlock % 64; // 0..63 + + int subBlock; + int q; + int highBit; + if (posInPair < 32) { + subBlock = pairIndex * 2; + byte qsByte = readByte(memorySegment, blockOffset + QS_OFFSET + pairIndex * 32 + posInPair); + q = Byte.toUnsignedInt(qsByte) & 0xF; + // 5th bit from qh: bit position is (pairIndex * 2) for low nibble elements + byte qhByte = readByte(memorySegment, blockOffset + QH_OFFSET + posInPair); + highBit = (Byte.toUnsignedInt(qhByte) >> (pairIndex * 2)) & 1; + } else { + subBlock = pairIndex * 2 + 1; + byte qsByte = readByte(memorySegment, blockOffset + QS_OFFSET + pairIndex * 32 + (posInPair - 32)); + q = (Byte.toUnsignedInt(qsByte) >> 4) & 0xF; + // 5th bit from qh: bit position is (pairIndex * 2 + 1) for high nibble elements + byte qhByte = readByte(memorySegment, blockOffset + QH_OFFSET + (posInPair - 32)); + highBit = (Byte.toUnsignedInt(qhByte) >> (pairIndex * 2 + 1)) & 1; + } + + q += highBit * 16; // Add the 5th bit + + int sc = getScaleK4(subBlock, memorySegment, scalesOff); + int m = getMinK4(subBlock, memorySegment, scalesOff); + + return d * sc * q - dmin * m; + } + + @Override + public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { + return FloatTensor.scalarDot(this, thisOffset, that, 
thatOffset, size); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tensor/standard/Q6_KFloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/Q6_KFloatTensor.java new file mode 100644 index 00000000..602dcc2d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tensor/standard/Q6_KFloatTensor.java @@ -0,0 +1,123 @@ +package org.beehive.gpullama3.tensor.standard; + +import org.beehive.gpullama3.tensor.GGMLType; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.VectorSpecies; + +import java.lang.foreign.MemorySegment; + +/** + * {@link FloatTensor} quantized in the {@link GGMLType#Q6_K} format. + * + *

<p>Q6_K uses super-blocks of 256 elements, each containing:
+ * <ul>
+ *   <li>128 bytes: ql (lower 4 bits of 6-bit quants)</li>
+ *   <li>64 bytes: qh (upper 2 bits of 6-bit quants)</li>
+ *   <li>16 bytes: scales (signed 8-bit per 16-element sub-block)</li>
+ *   <li>2 bytes: d (super-block scale, fp16)</li>
+ * </ul>
+ */ +public final class Q6_KFloatTensor extends FloatTensor { + + private static final int QK_K = 256; + private static final int BLOCK_SIZE = GGMLType.Q6_K.getTypeSize(); // 210 + + // Offsets within a block + private static final int QL_OFFSET = 0; // 128 bytes + private static final int QH_OFFSET = 128; // 64 bytes + private static final int SCALES_OFFSET = 192; // 16 bytes + private static final int D_OFFSET = 208; // 2 bytes + + final int size; + final MemorySegment memorySegment; + + public Q6_KFloatTensor(int size, MemorySegment memorySegment) { + this.size = size; + this.memorySegment = memorySegment; + } + + @Override + public int size() { + return size; + } + + @Override + public void setFloat(int index, float value) { + throw new UnsupportedOperationException("setFloat"); + } + + @Override + protected FloatVector getFloatVector(VectorSpecies species, int index) { + throw new UnsupportedOperationException("getFloatVector"); + } + + @Override + public GGMLType type() { + return GGMLType.Q6_K; + } + + @Override + public MemorySegment asMemorySegment() { + return memorySegment; + } + + @Override + public float getFloat(int index) { + assert 0 <= index && index < size; + int blockIndex = index / QK_K; + int withinBlock = index % QK_K; + long blockOffset = (long) blockIndex * BLOCK_SIZE; + + float d = Float.float16ToFloat(readShort(memorySegment, blockOffset + D_OFFSET)); + + // The block is split into two halves of 128 elements each + int halfIndex = withinBlock / 128; // 0 or 1 + int posInHalf = withinBlock % 128; // 0..127 + + // Within each half, there are 4 groups of 32 elements + int groupInHalf = posInHalf / 32; // 0..3 + int posInGroup = posInHalf % 32; // 0..31 + + // ql/qh pointers advance by 64/32 per half + long qlBase = blockOffset + QL_OFFSET + halfIndex * 64; + long qhBase = blockOffset + QH_OFFSET + halfIndex * 32; + long scBase = blockOffset + SCALES_OFFSET + halfIndex * 8; + + // Scale index: is = posInGroup / 16 (0 or 1), then offset by group + int is = posInGroup / 16; + + int qValue; + switch (groupInHalf) { + case 0 -> { + int ql = Byte.toUnsignedInt(readByte(memorySegment, qlBase + posInGroup)); + int qh = Byte.toUnsignedInt(readByte(memorySegment, qhBase + posInGroup)); + qValue = ((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32; + return d * (byte) readByte(memorySegment, scBase + is) * qValue; + } + case 1 -> { + int ql = Byte.toUnsignedInt(readByte(memorySegment, qlBase + 32 + posInGroup)); + int qh = Byte.toUnsignedInt(readByte(memorySegment, qhBase + posInGroup)); + qValue = ((ql & 0xF) | (((qh >> 2) & 3) << 4)) - 32; + return d * (byte) readByte(memorySegment, scBase + is + 2) * qValue; + } + case 2 -> { + int ql = Byte.toUnsignedInt(readByte(memorySegment, qlBase + posInGroup)); + int qh = Byte.toUnsignedInt(readByte(memorySegment, qhBase + posInGroup)); + qValue = ((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32; + return d * (byte) readByte(memorySegment, scBase + is + 4) * qValue; + } + case 3 -> { + int ql = Byte.toUnsignedInt(readByte(memorySegment, qlBase + 32 + posInGroup)); + int qh = Byte.toUnsignedInt(readByte(memorySegment, qhBase + posInGroup)); + qValue = ((ql >> 4) | (((qh >> 6) & 3) << 4)) - 32; + return d * (byte) readByte(memorySegment, scBase + is + 6) * qValue; + } + default -> throw new AssertionError(); + } + } + + @Override + public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { + return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); + } +}
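
For context (not part of either patch): the 3x direct-memory default can be sanity-checked from the block sizes above. Q4_K packs 256 weights into 144 bytes, Q5_K into 176 and Q6_K into 210, while the Q8_0 blocks produced at load time take 34 bytes per 32 weights, so the load-time conversion grows the off-heap weight footprint by roughly 1.3x (Q6_K) to 1.9x (Q4_K). A minimal, self-contained sketch of that arithmetic (class name is illustrative, not from the patch):

// Back-of-the-envelope check of the K-quant -> Q8_0 expansion factors
// (illustrative only; block sizes taken from the tensor classes in this patch).
public class KQuantExpansion {
    public static void main(String[] args) {
        double q8PerWeight  = 34.0 / 32;    // Q8_0: 2-byte fp16 scale + 32 quants per 32 weights
        double q4kPerWeight = 144.0 / 256;  // Q4_K super-block bytes / elements
        double q5kPerWeight = 176.0 / 256;
        double q6kPerWeight = 210.0 / 256;
        System.out.printf("Q4_K -> Q8_0: %.2fx%n", q8PerWeight / q4kPerWeight); // ~1.89x
        System.out.printf("Q5_K -> Q8_0: %.2fx%n", q8PerWeight / q5kPerWeight); // ~1.55x
        System.out.printf("Q6_K -> Q8_0: %.2fx%n", q8PerWeight / q6kPerWeight); // ~1.30x
    }
}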