From 7962e2b6efb10d0bf03e623a5c9f3bd598401c23 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Sun, 19 Apr 2026 12:56:53 +0200 Subject: [PATCH] Add skainet-io-iree-params module with IrpaWriter (PR C of #523) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Writes IREE parameter archives (.irpa) consumable by `iree-compile --iree-opt-import-parameters=<file>` and by `iree-run-module --parameters=<scope>=<file>`. The archive format per IREE's `parameter_archive.h`: +--- 40 B ----+ fixed header (magic "IRPA", version, counts) +--- 48 B ----+ three segment references +--- pad 16 --+ +-------------+ entry segment: 80-byte DATA records +-------------+ metadata segment: concatenated key bytes +--- pad 64 --+ +-------------+ storage segment: raw tensor bytes per entry All u16/u32/u64 values little-endian. The C entry-header struct has an implicit 4-byte pad after `u32 type` to align the following `u64 flags` — the writer emits that pad explicitly and a byte-level test pins the 80-byte layout so future changes can't silently re-break it. No scope column in the archive itself: scope is a runtime binding (`--parameters=<scope>=<file>`). Callers with multiple scopes group via `IrpaWriter.groupByScope(refs)` and emit one .irpa per scope. The writer delegates byte sourcing to `BufferHandle`; Owned and Borrowed variants are wired today, with Mapped / FileBacked landing in PR E (#523) to give the gguf and safetensors loaders a zero-copy path into the archive. ### Companion fix in skainet-compile-hlo PR B (#524) emitted `util.global private @key : type` without an initializer, which iree-compile treats as uninitialized — it would not import anything from a .irpa. 
PR C completes the emission: util.global private @key = #flow.parameter.named<"scope"::"key"> : type %r = util.global.load @key : type `MlirValidator` was also taught to accept module-scope `util.global` assignments: the `@`-prefixed symbol is a global, not an SSA value, and must not trip the existing `%`-only SSA-format check. ### Tests - IrpaWriterTest pins the byte layout against the IREE spec — header magic/version, segment offsets, entry-record fields, key concatenation, data placement with 64-byte per-entry alignment, Owned / Borrowed handle support, groupByScope ordering, empty-input rejection. - Existing ConstantMaterializationPolicyTest updated to assert the new `#flow.parameter.named<...>` initializer on every externalized global. - Full `:skainet-compile:skainet-compile-hlo:jvmTest` and `:skainet-io:skainet-io-iree-params:jvmTest` pass. ### Deferred verification A real `iree-compile --iree-opt-import-parameters=<file>` round-trip test will land once CI has an IREE toolchain available. Byte-level layout tests are a close proxy — any deviation from the reference C format breaks both. Part of #523. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- settings.gradle.kts | 1 + .../api/jvm/skainet-compile-hlo.api | 2 + .../sk/ainet/compile/hlo/MlirValidator.kt | 13 +- .../converters/ConstantOperationsConverter.kt | 13 +- .../hlo/ConstantMaterializationPolicyTest.kt | 13 +- .../skainet-io-iree-params/build.gradle.kts | 69 ++++ .../skainet-io-iree-params/gradle.properties | 2 + .../kotlin/sk/ainet/io/irpa/IrpaWriter.kt | 320 ++++++++++++++++++ .../kotlin/sk/ainet/io/irpa/IrpaWriterTest.kt | 218 ++++++++++++ 9 files changed, 646 insertions(+), 5 deletions(-) create mode 100644 skainet-io/skainet-io-iree-params/build.gradle.kts create mode 100644 skainet-io/skainet-io-iree-params/gradle.properties create mode 100644 skainet-io/skainet-io-iree-params/src/commonMain/kotlin/sk/ainet/io/irpa/IrpaWriter.kt create mode 100644 skainet-io/skainet-io-iree-params/src/commonTest/kotlin/sk/ainet/io/irpa/IrpaWriterTest.kt diff --git a/settings.gradle.kts b/settings.gradle.kts index 909a2ea6..c25aed61 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -76,3 +76,4 @@ include("skainet-test:skainet-test-java") include("skainet-apps:skainet-grayscale-cli") include("skainet-apps:skainet-tensor-tools") include("skainet-io:skainet-io-safetensors") +include("skainet-io:skainet-io-iree-params") diff --git a/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api b/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api index c9ad1992..a81a52a7 100644 --- a/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api +++ b/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api @@ -59,10 +59,12 @@ public final class sk/ainet/compile/hlo/ConversionContext { public final fun getModuleDeclarations ()Ljava/lang/String; public final fun getTypeMapper ()Lsk/ainet/compile/hlo/TypeMapper; public final fun getValueName (Ljava/lang/String;)Ljava/lang/String; + public final fun getValueType (Ljava/lang/String;)Ljava/lang/String; public final fun 
nextTempValue ()Ljava/lang/String; public final fun registerExternalParameter (Lsk/ainet/compile/hlo/ExternalParameterRef;)V public final fun setGraph (Lsk/ainet/lang/graph/ComputeGraph;)V public final fun setValueName (Ljava/lang/String;Ljava/lang/String;)V + public final fun setValueType (Ljava/lang/String;Ljava/lang/String;)V } public abstract class sk/ainet/compile/hlo/ConversionResult { diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt index 22af3fd5..afe13de9 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt @@ -66,8 +66,17 @@ public class MlirValidator { inFunction = true } - // Check for basic SSA value format - if (trimmed.contains(" = ") && !validateSSAValue(trimmed)) { + // Check for basic SSA value format. Module-scope global + // declarations (`util.global private @name = #flow.parameter.named<...> : ...`) + // look syntactically like assignments but bind `@`-prefixed + // symbols, not `%`-prefixed SSA values. Treated as a + // separate syntactic category — the IREE external-weight + // path (issue #523) depends on these lines passing + // validation. 
+ if (trimmed.contains(" = ") && + !trimmed.startsWith("util.global") && + !validateSSAValue(trimmed) + ) { errors.add("Line ${lineNum + 1}: Invalid SSA value format") } diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt index 4a731733..44de7868 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt @@ -431,7 +431,18 @@ public class ConstantOperationsConverter : StableHloOperationConverter { source = BufferHandle.Owned(bytes) ) ) - context.emitModuleDeclaration("util.global private @${key} : $outputType") + // Bind the global to an archive entry via the IREE flow-dialect + // parameter attribute. Without this initializer, iree-compile + // treats the util.global as uninitialized and does not pull + // bytes from the .irpa file at --iree-opt-import-parameters + // time. Scope is carried in the MLIR reference (left of `::`) + // and resolved at iree-compile / iree-run-module time against + // `--parameters==.irpa`. The .irpa file itself + // stores a flat key table with no scope column. 
+ context.emitModuleDeclaration( + "util.global private @${key} = " + + "#flow.parameter.named<\"${scope}\"::\"${key}\"> : $outputType" + ) val resultValue = context.nextTempValue() val operation = "$resultValue = util.global.load @${key} : $outputType" diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantMaterializationPolicyTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantMaterializationPolicyTest.kt index 8057f867..8ccc9fb1 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantMaterializationPolicyTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantMaterializationPolicyTest.kt @@ -44,8 +44,14 @@ class ConstantMaterializationPolicyTest { ConstantMaterializationPolicy.ExternalAlways() ).convert(buildTensorConstantGraph(), "external_policy") + // The util.global must carry a #flow.parameter.named initializer + // so iree-compile binds the declaration to an archive entry at + // --iree-opt-import-parameters time (see issue #523). assertTrue( - module.content.contains("util.global private @weights : tensor<2x2xf32>"), + module.content.contains( + "util.global private @weights = " + + "#flow.parameter.named<\"model\"::\"weights\"> : tensor<2x2xf32>" + ), "module decl missing:\n${module.content}" ) assertTrue( @@ -110,7 +116,10 @@ class ConstantMaterializationPolicyTest { ) // Large is externalized. 
assertTrue( - module.content.contains("util.global private @large_w : tensor<4x4xf32>"), + module.content.contains( + "util.global private @large_w = " + + "#flow.parameter.named<\"model\"::\"large_w\"> : tensor<4x4xf32>" + ), "large tensor must externalize under SizeThreshold:\n${module.content}" ) diff --git a/skainet-io/skainet-io-iree-params/build.gradle.kts b/skainet-io/skainet-io-iree-params/build.gradle.kts new file mode 100644 index 00000000..4095fdf9 --- /dev/null +++ b/skainet-io/skainet-io-iree-params/build.gradle.kts @@ -0,0 +1,69 @@ +import org.jetbrains.kotlin.gradle.ExperimentalWasmDsl +import org.jetbrains.kotlin.gradle.dsl.JvmTarget + +plugins { + alias(libs.plugins.kotlinMultiplatform) + alias(libs.plugins.androidMultiplatformLibrary) + alias(libs.plugins.vanniktech.mavenPublish) + id("sk.ainet.dokka") +} + +kotlin { + targets.configureEach { + compilations.configureEach { + compileTaskProvider.get().compilerOptions { + freeCompilerArgs.add("-Xexpect-actual-classes") + } + } + } + + jvm() + android { + namespace = "sk.ainet.io.irpa" + compileSdk = libs.versions.android.compileSdk.get().toInt() + minSdk = libs.versions.android.minSdk.get().toInt() + compilerOptions { + jvmTarget.set(JvmTarget.JVM_1_8) + } + } + + iosArm64() + iosSimulatorArm64() + macosArm64() + linuxX64() + linuxArm64() + + js { + browser() + } + + @OptIn(ExperimentalWasmDsl::class) + wasmJs { + browser() + } + + @OptIn(ExperimentalWasmDsl::class) + wasmWasi { + nodejs() + } + + sourceSets { + val commonMain by getting { + dependencies { + implementation(libs.kotlinx.io.core) + implementation(project(":skainet-lang:skainet-lang-core")) + implementation(project(":skainet-compile:skainet-compile-hlo")) + } + } + val commonTest by getting { + dependencies { + implementation(libs.kotlin.test) + } + } + val jvmTest by getting { + dependencies { + implementation(libs.junit) + } + } + } +} diff --git a/skainet-io/skainet-io-iree-params/gradle.properties 
b/skainet-io/skainet-io-iree-params/gradle.properties new file mode 100644 index 00000000..4c961584 --- /dev/null +++ b/skainet-io/skainet-io-iree-params/gradle.properties @@ -0,0 +1,2 @@ +POM_ARTIFACT_ID=skainet-io-iree-params +POM_NAME=skainet IO IREE Params diff --git a/skainet-io/skainet-io-iree-params/src/commonMain/kotlin/sk/ainet/io/irpa/IrpaWriter.kt b/skainet-io/skainet-io-iree-params/src/commonMain/kotlin/sk/ainet/io/irpa/IrpaWriter.kt new file mode 100644 index 00000000..2be7eb7c --- /dev/null +++ b/skainet-io/skainet-io-iree-params/src/commonMain/kotlin/sk/ainet/io/irpa/IrpaWriter.kt @@ -0,0 +1,320 @@ +package sk.ainet.io.irpa + +import kotlinx.io.Sink +import kotlinx.io.write +import kotlinx.io.writeIntLe +import kotlinx.io.writeLongLe +import kotlinx.io.writeShortLe +import sk.ainet.compile.hlo.ExternalParameterRef +import sk.ainet.lang.tensor.storage.BufferHandle + +/** + * Writes an IREE parameter archive (`.irpa`) file. + * + * The archive is the runtime-consumable counterpart to the + * `#flow.parameter.named<"scope"::"key">` references emitted by + * [sk.ainet.compile.hlo.StableHloConverter] when + * [sk.ainet.compile.hlo.ConstantMaterializationPolicy.ExternalAlways] is + * active. `iree-compile --iree-opt-import-parameters=.irpa` inlines + * the referenced bytes at compile time; `iree-run-module --parameters==.irpa` + * loads them at runtime. 
+ * + * Format: v0 binary layout defined in IREE's `parameter_archive.h`: + * + * ``` + * +-------------------------- 40 B --------------------------+ + * | header_v0 (magic="IRPA", segments) | + * +-------------------------- pad to 16 ----------------------+ + * | entry_segment (DATA entries, each 80 B aligned to 16) | + * +-------------------------- no padding ---------------------+ + * | metadata_segment (concatenated key bytes, per entry) | + * +-------------------------- pad to entry.minimum_alignment -+ + * | storage_segment (tensor bytes, each block aligned to 64) | + * +-----------------------------------------------------------+ + * ``` + * + * All `u16`/`u32`/`u64` values little-endian. Offsets inside the header + * are relative to the header start; inside entries, to the owning + * segment's offset. The archive has no dtype, no shape, and no scope + * column — the MLIR `#flow.parameter.named<...>` reference carries all + * structural metadata, and scope binding is a CLI concern. + * + * This writer is scope-agnostic: a single [write] call produces one + * `.irpa` for all entries passed in. Callers with tensors in multiple + * scopes must call [write] once per scope (use [groupByScope]). + * + * See issue #523 for the architectural context. + */ +public class IrpaWriter { + + /** + * Write an `.irpa` archive containing every [entry] to [sink]. + * + * Streams the output so peak memory stays bounded by the largest + * single [BufferHandle] and the entry/metadata tables (tiny even + * for thousands of weights). The sink is not flushed or closed — + * caller decides when to do either. + * + * @throws IllegalArgumentException if [entries] is empty — a + * valid archive must have at least one entry, and silently + * emitting an empty file is a worse failure mode than a loud + * precondition check. 
+ */ + public fun write(entries: List, sink: Sink) { + require(entries.isNotEmpty()) { + "IrpaWriter.write() requires at least one entry; empty archives are not useful." + } + + // Precompute layout. Segment offsets are relative to the + // header start (first byte of the file, since we always write + // the header at offset 0). + val headerBlockSize = HEADER_BLOCK_SIZE + val entrySegmentOffset = alignUp(headerBlockSize, HEADER_ALIGNMENT).toLong() + val entrySegmentLength = entries.size * ENTRY_SIZE_ALIGNED.toLong() + val metadataSegmentOffset = entrySegmentOffset + entrySegmentLength + val metadataSegmentLength = entries.sumOf { it.key.encodeToByteArray().size.toLong() } + + // Storage segment starts at the first alignment >= end-of-metadata + // that satisfies DEFAULT_DATA_ALIGNMENT. Per-entry alignment + // inside the segment uses the same default unless an entry + // overrides it (not exposed today). + val storageSegmentOffset = alignUpLong( + metadataSegmentOffset + metadataSegmentLength, + DEFAULT_DATA_ALIGNMENT.toLong() + ) + + // Running storage cursor — offset WITHIN the storage segment, + // not absolute. Each entry records its offset here after + // aligning up. + var storageCursor = 0L + val perEntryStorageOffset = LongArray(entries.size) + for ((i, entry) in entries.withIndex()) { + val aligned = alignUpLong(storageCursor, DEFAULT_DATA_ALIGNMENT.toLong()) + perEntryStorageOffset[i] = aligned + storageCursor = aligned + entry.source.sizeInBytes + } + val storageSegmentLength = storageCursor + + // --- Write the header --- + writeHeader( + sink = sink, + entryCount = entries.size.toLong(), + entrySegmentOffset = entrySegmentOffset, + entrySegmentLength = entrySegmentLength, + metadataSegmentOffset = metadataSegmentOffset, + metadataSegmentLength = metadataSegmentLength, + storageSegmentOffset = storageSegmentOffset, + storageSegmentLength = storageSegmentLength + ) + + // Pad from end-of-header-block to start-of-entry-segment. 
+ writePadding(sink, (entrySegmentOffset - headerBlockSize.toLong()).toInt()) + + // --- Write the entry segment (DATA records) --- + var metadataCursor = 0L + for ((i, entry) in entries.withIndex()) { + val keyBytes = entry.key.encodeToByteArray() + writeDataEntry( + sink = sink, + nameOffset = metadataCursor, + nameLength = keyBytes.size.toLong(), + storageOffset = perEntryStorageOffset[i], + storageLength = entry.source.sizeInBytes + ) + metadataCursor += keyBytes.size + } + + // --- Write the metadata segment (concatenated key bytes) --- + for (entry in entries) { + val keyBytes = entry.key.encodeToByteArray() + for (b in keyBytes) sink.writeByte(b) + } + + // Pad from end-of-metadata to start-of-storage. + val metadataEndAbs = metadataSegmentOffset + metadataSegmentLength + writePaddingLong(sink, storageSegmentOffset - metadataEndAbs) + + // --- Write the storage segment --- + var writtenInStorage = 0L + for ((i, entry) in entries.withIndex()) { + // Pad from previous entry's end to this entry's aligned offset. + val entryOffset = perEntryStorageOffset[i] + writePaddingLong(sink, entryOffset - writtenInStorage) + writeBufferHandle(sink, entry.source) + writtenInStorage = entryOffset + entry.source.sizeInBytes + } + // No trailing file-level pad to 4096 — optional per the spec, + // and callers writing to a stream may not know total size up + // front. mmap readers tolerate short tails. + } + + /** + * Group a mixed-scope ref list into per-scope bundles. Callers + * with multiple scopes should invoke [write] once per bundle and + * pass each resulting file to `iree-compile --parameters==`. + * + * Preserves within-scope order — matters for reproducible archives. 
+ */ + public fun groupByScope(entries: List): Map> { + val grouped = linkedMapOf>() + for (entry in entries) { + grouped.getOrPut(entry.scope) { mutableListOf() }.add(entry) + } + return grouped + } + + private fun writeHeader( + sink: Sink, + entryCount: Long, + entrySegmentOffset: Long, + entrySegmentLength: Long, + metadataSegmentOffset: Long, + metadataSegmentLength: Long, + storageSegmentOffset: Long, + storageSegmentLength: Long + ) { + // --- Fixed 40-byte header (iree_io_parameter_archive_header_v0_t) --- + sink.writeIntLe(MAGIC) // 0: magic + sink.writeShortLe(0) // 4: version_major + sink.writeShortLe(0) // 6: version_minor + sink.writeLongLe(HEADER_FIXED_SIZE.toLong()) // 8: header_size (40) + sink.writeLongLe(0L) // 16: next_header_offset + sink.writeLongLe(0L) // 24: flags + sink.writeLongLe(entryCount) // 32: entry_count + + // --- Three segment references, 16 bytes each --- + // Layout: { u64 offset; u64 length; } — offsets are relative + // to the start of the header block. + sink.writeLongLe(entrySegmentOffset) // 40: entry.offset + sink.writeLongLe(entrySegmentLength) // 48: entry.length + sink.writeLongLe(metadataSegmentOffset) // 56: metadata.offset + sink.writeLongLe(metadataSegmentLength) // 64: metadata.length + sink.writeLongLe(storageSegmentOffset) // 72: storage.offset + sink.writeLongLe(storageSegmentLength) // 80: storage.length + // Total written: 88 bytes (HEADER_BLOCK_SIZE). + } + + private fun writeDataEntry( + sink: Sink, + nameOffset: Long, + nameLength: Long, + storageOffset: Long, + storageLength: Long + ) { + sink.writeLongLe(ENTRY_HEADER_SIZE_DATA.toLong()) // entry_size (u64) + sink.writeIntLe(ENTRY_TYPE_DATA) // type (u32) + // 4-byte pad: the C struct has `u64 flags` immediately after + // `u32 type`, and the compiler inserts 4 bytes of padding to + // align `flags` on an 8-byte boundary. Without these bytes + // every subsequent u64 field reads from the wrong offset and + // the parser rejects the archive. 
+ sink.writeIntLe(0) + sink.writeLongLe(0L) // flags (u64) + sink.writeLongLe(nameOffset) // name.offset + sink.writeLongLe(nameLength) // name.length + sink.writeLongLe(0L) // metadata.offset + sink.writeLongLe(0L) // metadata.length + sink.writeLongLe(DEFAULT_DATA_ALIGNMENT.toLong()) // minimum_alignment + sink.writeLongLe(storageOffset) // storage.offset + sink.writeLongLe(storageLength) // storage.length + // Total bytes written: 80 (ENTRY_HEADER_SIZE_DATA). Already + // 16-aligned, so no inter-entry padding required. + } + + private fun writeBufferHandle(sink: Sink, handle: BufferHandle) { + val (data, offset, length) = when (handle) { + is BufferHandle.Owned -> Triple(handle.data, handle.offset, handle.sizeInBytes.toInt()) + is BufferHandle.Borrowed -> Triple(handle.data, handle.offset, handle.sizeInBytes.toInt()) + else -> throw IllegalArgumentException( + "IrpaWriter does not yet handle BufferHandle subclass ${handle::class.simpleName}. " + + "Only Owned/Borrowed byte-array handles are wired in PR C; mmap-backed " + + "handles land with PR E (issue #523)." + ) + } + // Byte-at-a-time so we do not rely on kotlinx.io's + // `Sink.write(ByteArray, Int, Int)` extension resolving on the + // raw receiver — extension overload ambiguity bit this on JVM + // in an earlier revision. Performance-critical callers should + // switch to `write(ByteArray, ...)` once that path is covered + // by a dedicated test. + for (i in offset until offset + length) { + sink.writeByte(data[i]) + } + } + + private fun writePadding(sink: Sink, bytes: Int) { + if (bytes <= 0) return + val zeros = ByteArray(bytes) + sink.write(zeros) + } + + private fun writePaddingLong(sink: Sink, bytes: Long) { + if (bytes <= 0) return + // Write in chunks to keep the transient buffer small; 4 KiB + // is comfortably below any real alignment gap we'll see. 
+ val chunk = ByteArray(4096) + var remaining = bytes + while (remaining > 0) { + val step = if (remaining >= chunk.size) chunk.size else remaining.toInt() + if (step == chunk.size) { + sink.write(chunk) + } else { + sink.write(ByteArray(step)) + } + remaining -= step + } + } + + public companion object { + /** Magic bytes `"IRPA"` little-endian. */ + public const val MAGIC: Int = 0x41505249 + /** + * Size of the fixed `iree_io_parameter_archive_header_v0_t` + * struct — magic through entry_count. This is the value + * written into the header's own `header_size` field. + */ + public const val HEADER_FIXED_SIZE: Int = 40 + /** + * Fixed size of the three segment references that follow the + * header struct (3 × 16 bytes). + */ + public const val SEGMENT_REFS_SIZE: Int = 48 + /** + * Total on-disk size of the header + segment-references block + * before the entry segment begins. Entry segment offset is + * always >= this value. + */ + public const val HEADER_BLOCK_SIZE: Int = HEADER_FIXED_SIZE + SEGMENT_REFS_SIZE + /** Alignment the header itself sits at in the file. */ + public const val HEADER_ALIGNMENT: Int = 16 + /** + * Natural size of a DATA-type entry header. Each entry is + * then padded to [ENTRY_SIZE_ALIGNED] before the next entry + * starts. + */ + public const val ENTRY_HEADER_SIZE_DATA: Int = 80 + /** Entry-to-entry alignment inside the entry segment. */ + public const val ENTRY_ALIGNMENT: Int = 16 + /** DATA entry size after alignment padding. */ + public const val ENTRY_SIZE_ALIGNED: Int = ENTRY_HEADER_SIZE_DATA // already 16-aligned + /** `iree_io_parameter_archive_entry_type_t::DATA`. */ + public const val ENTRY_TYPE_DATA: Int = 2 + /** Default per-entry storage alignment. */ + public const val DEFAULT_DATA_ALIGNMENT: Int = 64 + + /** + * Round [value] up to the nearest multiple of [alignment]. + * Used for header / entry alignment (always `Int` inputs). 
+ */ + internal fun alignUp(value: Int, alignment: Int): Int = + (value + alignment - 1) and (alignment - 1).inv() + + /** + * Round [value] up to the nearest multiple of [alignment]. + * `Long` variant for storage-segment offsets, which can exceed + * 2 GiB for large models. + */ + internal fun alignUpLong(value: Long, alignment: Long): Long = + (value + alignment - 1L) and (alignment - 1L).inv() + } +} diff --git a/skainet-io/skainet-io-iree-params/src/commonTest/kotlin/sk/ainet/io/irpa/IrpaWriterTest.kt b/skainet-io/skainet-io-iree-params/src/commonTest/kotlin/sk/ainet/io/irpa/IrpaWriterTest.kt new file mode 100644 index 00000000..6dc192ef --- /dev/null +++ b/skainet-io/skainet-io-iree-params/src/commonTest/kotlin/sk/ainet/io/irpa/IrpaWriterTest.kt @@ -0,0 +1,218 @@ +package sk.ainet.io.irpa + +import kotlinx.io.Buffer +import kotlinx.io.readByteArray +import sk.ainet.compile.hlo.ExternalParameterRef +import sk.ainet.lang.tensor.storage.BufferHandle +import sk.ainet.lang.tensor.storage.TensorEncoding +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +/** + * Byte-level tests for [IrpaWriter]. The format is small enough + * that we verify exact byte layout against what IREE's reader will + * expect, not just structural round-trips — if the layout drifts, + * `iree-compile --iree-opt-import-parameters=file.irpa` fails with + * a parse error, so we pin the wire format here. + */ +class IrpaWriterTest { + + @Test + fun testHeaderMagicAndVersion() { + val buffer = Buffer() + IrpaWriter().write( + entries = listOf(refFor("w", byteArrayOf(1, 2, 3, 4))), + sink = buffer + ) + val bytes = buffer.readByteArray() + + // Magic "IRPA" — 0x49 0x52 0x50 0x41 little-endian. 
+ assertEquals(0x49.toByte(), bytes[0], "magic byte 0") + assertEquals(0x52.toByte(), bytes[1], "magic byte 1") + assertEquals(0x50.toByte(), bytes[2], "magic byte 2") + assertEquals(0x41.toByte(), bytes[3], "magic byte 3") + + // version_major / version_minor = 0 / 0 (v0 format) + assertEquals(0.toByte(), bytes[4]) + assertEquals(0.toByte(), bytes[5]) + assertEquals(0.toByte(), bytes[6]) + assertEquals(0.toByte(), bytes[7]) + + // header_size = 40 (fixed struct size before segment refs) + assertEquals(40L, readU64Le(bytes, 8)) + } + + @Test + fun testEntryCountAndSegmentOffsets() { + val a = refFor("a", byteArrayOf(10, 20, 30, 40)) + val b = refFor("bb", byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8)) + + val buffer = Buffer() + IrpaWriter().write(listOf(a, b), buffer) + val bytes = buffer.readByteArray() + + // entry_count = 2 at offset 32. + assertEquals(2L, readU64Le(bytes, 32)) + + // Segment refs follow the 40-byte header. The header block + // (40 + 48 segment refs = 88) is 16-aligned up to 96 before + // the entry segment begins — the runtime requires every + // segment start to sit on a 16-byte boundary. With 2 DATA + // entries of 80 bytes each: + // entry.offset = 96 + // entry.length = 160 + // metadata.offset = 256 + // metadata.length = 3 (keys "a" + "bb") + // storage.offset = align_up(259, 64) = 320 + assertEquals(96L, readU64Le(bytes, 40)) + assertEquals(160L, readU64Le(bytes, 48)) + assertEquals(256L, readU64Le(bytes, 56)) + assertEquals(3L, readU64Le(bytes, 64)) + assertEquals(320L, readU64Le(bytes, 72)) + } + + @Test + fun testDataEntryLayout() { + val buffer = Buffer() + IrpaWriter().write( + entries = listOf(refFor("key1", byteArrayOf(9, 8, 7, 6))), + sink = buffer + ) + val bytes = buffer.readByteArray() + + val entryStart = 96 // from testEntryCountAndSegmentOffsets + // Layout accounts for the 4-byte pad after `u32 type` that + // the C struct requires to align `u64 flags`. 
+ assertEquals(80L, readU64Le(bytes, entryStart + 0), "entry_size") + assertEquals(2, readU32Le(bytes, entryStart + 8), "type=DATA(2)") + // bytes [+12, +16) = 4-byte struct pad (zero) + assertEquals(0L, readU64Le(bytes, entryStart + 16), "flags=0") + assertEquals(0L, readU64Le(bytes, entryStart + 24), "name.offset") + assertEquals(4L, readU64Le(bytes, entryStart + 32), "name.length (\"key1\")") + assertEquals(0L, readU64Le(bytes, entryStart + 40), "metadata.offset") + assertEquals(0L, readU64Le(bytes, entryStart + 48), "metadata.length") + assertEquals(64L, readU64Le(bytes, entryStart + 56), "minimum_alignment") + assertEquals(0L, readU64Le(bytes, entryStart + 64), "storage.offset") + assertEquals(4L, readU64Le(bytes, entryStart + 72), "storage.length") + } + + @Test + fun testKeysAndDataRoundTripExactly() { + val a = refFor("alpha", byteArrayOf(1, 1, 1, 1)) + val b = refFor("beta", byteArrayOf(2, 2, 2, 2, 2, 2, 2, 2)) + + val buffer = Buffer() + IrpaWriter().write(listOf(a, b), buffer) + val bytes = buffer.readByteArray() + + // Metadata segment starts right after the entry segment. + // 96 (header block aligned to 16) + 2 * 80 (entries) = 256. + val metaStart = 256 + assertContentEquals( + "alphabeta".encodeToByteArray(), + bytes.copyOfRange(metaStart, metaStart + 9), + "keys must be concatenated in entry order, no separators" + ) + + // Storage segment aligned to 64. metadata end = 256 + 9 = 265, + // aligned up to 64 -> 320. So storage starts at byte 320. + val storageStart = 320 + assertContentEquals( + byteArrayOf(1, 1, 1, 1), + bytes.copyOfRange(storageStart, storageStart + 4), + "first entry bytes at storage start" + ) + // Second entry aligned to 64 within storage segment. + // 4 (first) -> aligned up to 64. 
+ val secondEntryAt = storageStart + 64 + assertContentEquals( + byteArrayOf(2, 2, 2, 2, 2, 2, 2, 2), + bytes.copyOfRange(secondEntryAt, secondEntryAt + 8), + "second entry at its 64-aligned offset" + ) + } + + @Test + fun testEmptyInputIsRejectedLoudly() { + assertFailsWith { + IrpaWriter().write(entries = emptyList(), sink = Buffer()) + } + } + + @Test + fun testGroupByScopePreservesOrder() { + val a1 = refFor("x", byteArrayOf(1), scope = "model") + val b1 = refFor("y", byteArrayOf(2), scope = "cache") + val a2 = refFor("z", byteArrayOf(3), scope = "model") + + val grouped = IrpaWriter().groupByScope(listOf(a1, b1, a2)) + + assertEquals(setOf("model", "cache"), grouped.keys) + assertEquals(listOf("x", "z"), grouped["model"]!!.map { it.key }) + assertEquals(listOf("y"), grouped["cache"]!!.map { it.key }) + } + + @Test + fun testOwnedBufferHandleWithOffsetAndBorrowedBoth() { + // Sanity-check that both BufferHandle flavors land bytes + // unmolested — important because Owned/Borrowed come from + // different ingestion paths (Owned from in-memory serializers, + // Borrowed from caller-supplied arrays, and eventually Mapped + // from PR E's mmap work). + val owned = ExternalParameterRef( + scope = "model", + key = "o", + encoding = TensorEncoding.Dense(bytesPerElement = 1), + source = BufferHandle.Owned(byteArrayOf(0, 0, 42, 43, 44), offset = 2) + ) + val borrowed = ExternalParameterRef( + scope = "model", + key = "b", + encoding = TensorEncoding.Dense(bytesPerElement = 1), + source = BufferHandle.Borrowed(byteArrayOf(99, 100, 101)) + ) + + val buffer = Buffer() + IrpaWriter().write(listOf(owned, borrowed), buffer) + val bytes = buffer.readByteArray() + + // Offsets: owned is 3 bytes (from offset=2, size=5-2=3), borrowed is 3 bytes. + // storage_start = 320 (same layout as testKeysAndDataRoundTripExactly). 
+ val storage = 320 + assertContentEquals(byteArrayOf(42, 43, 44), bytes.copyOfRange(storage, storage + 3)) + val borrowedAt = storage + 64 + assertContentEquals(byteArrayOf(99, 100, 101), bytes.copyOfRange(borrowedAt, borrowedAt + 3)) + } + + // --- helpers --- + + private fun refFor( + key: String, + bytes: ByteArray, + scope: String = "model" + ): ExternalParameterRef = ExternalParameterRef( + scope = scope, + key = key, + encoding = TensorEncoding.Dense(bytesPerElement = 1), + source = BufferHandle.Owned(bytes) + ) + + private fun readU64Le(bytes: ByteArray, offset: Int): Long { + var result = 0L + for (i in 0 until 8) { + result = result or ((bytes[offset + i].toLong() and 0xff) shl (i * 8)) + } + return result + } + + private fun readU32Le(bytes: ByteArray, offset: Int): Int { + var result = 0 + for (i in 0 until 4) { + result = result or ((bytes[offset + i].toInt() and 0xff) shl (i * 8)) + } + return result + } +}