Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,12 @@ class GGUFReader(
// Divide first to avoid overflow, then multiply. For quantized tensors,
// nElems must be divisible by blockSize, so this is exact.
val numBlocks = nElems.toLong() / blockSize
val nBytes = (numBlocks * typeSize).toInt()
val nBytesLong = numBlocks * typeSize.toLong()
require(nBytesLong <= Int.MAX_VALUE) {
"Tensor '$tensorName' is $nBytesLong bytes (> 2 GB). " +
"Use StreamingGGUFReader with loadTensorStorageMapped() instead."
}
val nBytes = nBytesLong.toInt()
val dataOffs = startOffs + offsetTensor[0].toInt()

// For non-native/quantized types (including unknown), tensor payload is stored as bytes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ public class StreamingGGUFReader private constructor(
* @return Raw bytes for the tensor
*/
public fun loadTensorData(tensor: StreamingTensorInfo): ByteArray {
return source.readAt(tensor.absoluteDataOffset, tensor.nBytes)
require(tensor.nBytes <= Int.MAX_VALUE) {
"Tensor '${tensor.name}' is ${tensor.nBytes} bytes (> 2 GB). " +
"Use loadTensorStorageMapped() for file-backed zero-copy access instead."
}
return source.readAt(tensor.absoluteDataOffset, tensor.nBytes.toInt())
}

/**
Expand All @@ -96,7 +100,11 @@ public class StreamingGGUFReader private constructor(
* @return Number of bytes read
*/
public fun loadTensorData(tensor: StreamingTensorInfo, buffer: ByteArray, offset: Int = 0): Int {
return source.readAt(tensor.absoluteDataOffset, buffer, offset, tensor.nBytes)
require(tensor.nBytes <= Int.MAX_VALUE) {
"Tensor '${tensor.name}' is ${tensor.nBytes} bytes (> 2 GB). " +
"Use loadTensorStorageMapped() for file-backed zero-copy access instead."
}
return source.readAt(tensor.absoluteDataOffset, buffer, offset, tensor.nBytes.toInt())
}

// ========== TensorStorage Loading ==========
Expand Down Expand Up @@ -145,7 +153,7 @@ public class StreamingGGUFReader private constructor(
buffer = BufferHandle.FileBacked(
path = filePath,
fileOffset = tensor.absoluteDataOffset,
sizeInBytes = tensor.nBytes.toLong()
sizeInBytes = tensor.nBytes
),
placement = Placement.MMAP_WEIGHTS
)
Expand Down Expand Up @@ -339,24 +347,24 @@ public class StreamingGGUFReader private constructor(
for ((index, info) in parsedTensors.withIndex()) {
val nElements = if (info.dims.isEmpty()) 0L else info.dims.fold(1UL) { acc, d -> acc * d }.toLong()

val nBytes: Int = if (info.ggmlType.isUnknown) {
val nBytes: Long = if (info.ggmlType.isUnknown) {
// For unknown types, estimate size from next tensor's offset
val sortedIndex = sortedByOffset.indexOfFirst { it.name == info.name }
if (sortedIndex < sortedByOffset.size - 1) {
// Use gap to next tensor as size estimate
val nextOffset = sortedByOffset[sortedIndex + 1].relativeOffset
(nextOffset - info.relativeOffset).toInt()
nextOffset - info.relativeOffset
} else {
// Last tensor - estimate from element count assuming 1 byte per element
// This is a rough fallback; actual loading may need adjustment
nElements.toInt()
nElements
}
} else {
// Known type - calculate from quantization parameters
val (blockSize, typeSize) = GGML_QUANT_SIZES[info.ggmlType]
?: (1 to 1) // Fallback for types in enum but not in size map
val numBlocks = nElements / blockSize
(numBlocks * typeSize).toInt()
numBlocks * typeSize.toLong()
}

_tensors.add(
Expand Down Expand Up @@ -534,7 +542,7 @@ public data class StreamingTensorInfo(
/** Total number of elements */
val nElements: Long,
/** Size in bytes (estimated for unknown types) */
val nBytes: Int,
val nBytes: Long,
/** Offset relative to data section start */
val relativeOffset: Long,
/** Absolute byte offset in file (set after parsing) */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
package sk.ainet.io.gguf

import sk.ainet.io.RandomAccessSource
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue

/**
 * Regression coverage for the Int overflow in tensor byte-size accounting.
 *
 * Background: [StreamingTensorInfo.nBytes] used to be an `Int`, so any tensor
 * whose payload exceeds Int.MAX_VALUE (~2 GB) produced a truncated/negative
 * size. This blocked Gemma 4 E4B (per_layer_token_embd.weight is ~1.4 GB in
 * Q4_K_M, and the Long→Int cast overflowed earlier in the arithmetic chain
 * for larger tensors).
 *
 * Root cause: `(numBlocks * typeSize).toInt()` silently truncated a Long
 * result exceeding Int.MAX_VALUE.
 *
 * Fix under test: nBytes is now a Long and all intermediate arithmetic stays
 * in Long; [StreamingGGUFReader.loadTensorData] rejects > 2 GB tensors with a
 * clear error instead of truncating.
 */
class LargeTensorNBytesOverflowTest {

    /**
     * nBytes must stay in Long arithmetic and come out exact for a tensor
     * whose payload exceeds 2 GB.
     *
     * Modeled on (a scaled-up) Gemma 4 E4B per_layer_token_embd.weight:
     * Q4_K has blockSize=256, typeSize=144, so shape [524288, 10752] gives
     * 5,637,144,576 elements → 22,019,784 blocks → 3,170,848,896 bytes,
     * which is > Int.MAX_VALUE and would have gone negative under the bug.
     */
    @Test
    fun nBytesComputationDoesNotOverflowForLargeTensors() {
        // Shape chosen so the Q4_K byte size lands above Int.MAX_VALUE.
        val reader = openQ4kReader("large_weight", listOf(524288UL, 10752UL))

        assertEquals(1, reader.tensors.size)
        val tensor = reader.tensors[0]

        val (elements, bytes) = q4kElementsAndBytes(524288L, 10752L)

        assertEquals("large_weight", tensor.name)
        assertEquals(elements, tensor.nElements)
        assertTrue(tensor.nBytes > 0, "nBytes must be positive, got ${tensor.nBytes}")
        assertEquals(bytes, tensor.nBytes,
            "nBytes computation should not overflow: expected $bytes, got ${tensor.nBytes}")
        assertTrue(tensor.nBytes > Int.MAX_VALUE,
            "This tensor should exceed Int.MAX_VALUE to test the fix")
    }

    /** A comfortably small tensor must still produce the exact Q4_K byte count. */
    @Test
    fun nBytesCorrectForNormalSizedTensor() {
        // 256 x 256 = 65,536 elements — far below any overflow threshold.
        val reader = openQ4kReader("normal_weight", listOf(256UL, 256UL))
        val tensor = reader.tensors[0]

        val (elements, bytes) = q4kElementsAndBytes(256L, 256L)

        assertEquals(elements, tensor.nElements)
        assertEquals(bytes, tensor.nBytes)
    }

    /** Heap loading of a > 2 GB tensor must fail fast with IllegalArgumentException. */
    @Test
    fun loadTensorDataRejectsLargeTensor() {
        val reader = openQ4kReader("huge_weight", listOf(524288UL, 10752UL))
        val tensor = reader.tensors[0]

        // A ByteArray cannot hold > Int.MAX_VALUE bytes, so the reader must
        // refuse with a clear error rather than truncate the size.
        assertFailsWith<IllegalArgumentException> {
            reader.loadTensorData(tensor)
        }
    }

    /** The mmap-style path must handle > 2 GB tensors: it only records offsets. */
    @Test
    fun loadTensorStorageMappedWorksForLargeTensor() {
        val reader = openQ4kReader("huge_weight", listOf(524288UL, 10752UL))
        val tensor = reader.tensors[0]

        // File-backed handles never materialize the payload on the heap, so
        // size is only bounded by Long.
        val storage = reader.loadTensorStorageMapped(tensor, "/fake/path.gguf")
        assertTrue(storage.isFileBacked)
        assertEquals(tensor.nBytes, storage.buffer.sizeInBytes)
    }

    /**
     * The exact Gemma 4 E4B tensor that surfaced the bug:
     * per_layer_token_embd.weight, shape [262144, 10752], Q4_K_M.
     * Its 1,585,424,448 bytes fit in Int, but intermediate overflow in the
     * old arithmetic chain still broke it.
     */
    @Test
    fun gemma4StyleTensorDoesNotOverflow() {
        val reader = openQ4kReader("per_layer_token_embd.weight", listOf(262144UL, 10752UL))
        val tensor = reader.tensors[0]

        val (elements, bytes) = q4kElementsAndBytes(262144L, 10752L)

        assertEquals(elements, tensor.nElements)
        assertEquals(bytes, tensor.nBytes)
        assertTrue(tensor.nBytes > 0, "nBytes must not be negative (overflow)")
        assertTrue(tensor.nBytes < Int.MAX_VALUE,
            "Gemma 4 E4B PLE weight should fit in Int: ${tensor.nBytes}")
    }

    // ========== Helpers ==========

    /** Opens a reader over a synthetic single-tensor Q4_K GGUF with the given shape. */
    private fun openQ4kReader(tensorName: String, dims: List<ULong>): StreamingGGUFReader {
        val gguf = encodeSingleTensorGguf(tensorName, dims, GGMLQuantizationType.Q4_K)
        return StreamingGGUFReader.open(sourceOf(gguf))
    }

    /** Expected (nElements, nBytes) for a 2-D Q4_K tensor, computed entirely in Long. */
    private fun q4kElementsAndBytes(rows: Long, cols: Long): Pair<Long, Long> {
        val elements = rows * cols
        val bytes = (elements / Q4K_BLOCK_SIZE) * Q4K_TYPE_SIZE
        return elements to bytes
    }

    /**
     * Encode a minimal valid GGUF v3 file holding exactly one tensor
     * (metadata only — the data section is never read by these tests).
     *
     * Layout (all little-endian):
     * - magic (u32) = 0x46554747
     * - version (u32) = 3
     * - tensor count (u64) = 1
     * - KV count (u64) = 0
     * - one tensor-info record:
     *   name length (u64), name bytes, ndims (u32), dims (u64 each),
     *   ggml type (u32), relative data offset (u64)
     * - 32 bytes of zero slack for alignment padding
     */
    private fun encodeSingleTensorGguf(
        tensorName: String,
        dims: List<ULong>,
        ggmlType: GGMLQuantizationType
    ): ByteArray {
        val name = tensorName.encodeToByteArray()
        // Record size: 8 (name len) + name + 4 (ndims) + 8 per dim + 4 (type) + 8 (offset).
        val infoSize = 8 + name.size + 4 + 8 * dims.size + 4 + 8
        // 24-byte header + record + 32 bytes of alignment slack.
        val out = ByteArray(24 + infoSize + 32)
        var cursor = 0

        // Closures advance the shared cursor so field order reads top-to-bottom.
        fun u32(value: UInt) { putUInt32LE(out, cursor, value); cursor += 4 }
        fun u64(value: ULong) { putUInt64LE(out, cursor, value); cursor += 8 }

        u32(GGUF_MAGIC)              // magic
        u32(3u)                      // version
        u64(1u)                      // tensor count
        u64(0u)                      // KV count

        u64(name.size.toULong())     // tensor name length
        name.copyInto(out, cursor)   // tensor name bytes
        cursor += name.size
        u32(dims.size.toUInt())      // ndims
        for (d in dims) u64(d)       // dims
        u32(ggmlType.value.toUInt()) // ggml type id
        u64(0u)                      // relative offset into the data section

        return out
    }

    /** Write [value] at [at] as 4 little-endian bytes. */
    private fun putUInt32LE(buf: ByteArray, at: Int, value: UInt) {
        var bits = value.toInt()
        repeat(4) { i ->
            buf[at + i] = (bits and 0xFF).toByte()
            bits = bits ushr 8
        }
    }

    /** Write [value] at [at] as 8 little-endian bytes. */
    private fun putUInt64LE(buf: ByteArray, at: Int, value: ULong) {
        var bits = value.toLong()
        repeat(8) { i ->
            buf[at + i] = (bits and 0xFF).toByte()
            bits = bits ushr 8
        }
    }

    /** Wrap an in-memory byte array as a [RandomAccessSource] for the reader. */
    private fun sourceOf(bytes: ByteArray): RandomAccessSource = object : RandomAccessSource {
        override val size: Long get() = bytes.size.toLong()

        override fun readAt(position: Long, length: Int): ByteArray {
            val start = position.toInt()
            return bytes.copyOfRange(start, start + length)
        }

        override fun readAt(position: Long, buffer: ByteArray, offset: Int, length: Int): Int {
            val start = position.toInt()
            bytes.copyInto(buffer, offset, start, start + length)
            return length
        }

        override fun close() = Unit
    }

    private companion object {
        // Q4_K quantization parameters: 256 elements per block, 144 bytes per block.
        const val Q4K_BLOCK_SIZE = 256L
        const val Q4K_TYPE_SIZE = 144L
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class QuantizedModelLoadTest {

// Verify nBytes calculation: (nElements / blockSize) * typeSize
val (blockSize, typeSize) = GGML_QUANT_SIZES[tensor.tensorType]!!
val expectedBytes = (tensor.nElements / blockSize) * typeSize
assertEquals(expectedBytes.toInt(), tensor.nBytes,
val expectedBytes = (tensor.nElements.toLong() / blockSize) * typeSize
assertEquals(expectedBytes, tensor.nBytes,
"Tensor ${tensor.name}: nBytes mismatch - expected $expectedBytes, got ${tensor.nBytes}")
}

Expand All @@ -59,7 +59,7 @@ class QuantizedModelLoadTest {
if (smallTensor != null) {
println("\nLoading smallest tensor: ${smallTensor.name} (${smallTensor.nBytes} bytes)")
val data = reader.loadTensorData(smallTensor)
assertEquals(smallTensor.nBytes, data.size,
assertEquals(smallTensor.nBytes, data.size.toLong(),
"Loaded data size should match nBytes")
println("Successfully loaded ${data.size} bytes!")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class StreamingGGUFReaderTest {
val data = reader.loadTensor(firstTensor.name)

// Verify data size matches expected
assertEquals(firstTensor.nBytes, data.size,
assertEquals(firstTensor.nBytes, data.size.toLong(),
"Loaded data size ${data.size} doesn't match expected ${firstTensor.nBytes}")

// Load same tensor again using TensorInfo directly
Expand All @@ -151,12 +151,12 @@ class StreamingGGUFReaderTest {
assertTrue(reader.tensors.isNotEmpty(), "No tensors to test")

val tensor = reader.tensors.first()
val buffer = ByteArray(tensor.nBytes + 100) // Extra space
val buffer = ByteArray(tensor.nBytes.toInt() + 100) // Extra space

// Load into buffer at offset
val bytesRead = reader.loadTensorData(tensor, buffer, 10)

assertEquals(tensor.nBytes, bytesRead, "Bytes read mismatch")
assertEquals(tensor.nBytes, bytesRead.toLong(), "Bytes read mismatch")

// First 10 bytes should be zero (untouched)
for (i in 0 until 10) {
Expand All @@ -165,7 +165,7 @@ class StreamingGGUFReaderTest {

// Compare with direct load
val directData = reader.loadTensorData(tensor)
for (i in 0 until tensor.nBytes) {
for (i in 0 until tensor.nBytes.toInt()) {
assertEquals(directData[i], buffer[10 + i],
"Data mismatch at index $i")
}
Expand All @@ -181,7 +181,7 @@ class StreamingGGUFReaderTest {
// Try to load each tensor and verify size matches expected
for (tensor in streamingReader.tensors) {
val data = streamingReader.loadTensorData(tensor)
assertEquals(tensor.nBytes, data.size,
assertEquals(tensor.nBytes, data.size.toLong(),
"Data size mismatch for tensor '${tensor.name}': expected ${tensor.nBytes}, got ${data.size}")
}
}
Expand Down
Loading