From 3cf2fdbdf1034be3f63b338af5386aa92e08637a Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 13 Apr 2026 12:55:29 +0200 Subject: [PATCH 1/2] Pin missing skainet.tensor_encodings module attribute (#477) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds TensorEncodingsModuleAttributeTest with two cases: 1. encoded_weights_produce_module_attributes_block — builds a graph with two weight inputs carrying distinct encodings (TensorEncoding.Q8_0 and TensorEncoding.Q4_K) via withTensorEncoding, runs the converter, and asserts the emitted module header is `module attributes { ... }` with a single `skainet.tensor_encodings = { w_q4 = "Q4_K", w_q8 = "Q8_0" }` dictionary enumerating every encoded tensor in one place. Red against StableHloConverter today — the metadata is only exposed as scattered per-op comments. 2. dense_graph_keeps_bare_module_header — a dense FP32 graph with no encoding metadata must preserve the bare `module {` header and must not introduce a spurious `attributes` block. A `null` tensorEncoding is the unknown / not-carried state, intentionally distinct from TensorEncoding.Dense, and the emitter must stay silent. Green baseline. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../hlo/TensorEncodingsModuleAttributeTest.kt | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/TensorEncodingsModuleAttributeTest.kt diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/TensorEncodingsModuleAttributeTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/TensorEncodingsModuleAttributeTest.kt new file mode 100644 index 00000000..634381ed --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/TensorEncodingsModuleAttributeTest.kt @@ -0,0 +1,158 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.ops.AddOperation +import sk.ainet.lang.tensor.ops.InputOperation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.tensor.ops.withTensorEncoding +import sk.ainet.lang.tensor.storage.TensorEncoding +import sk.ainet.lang.types.DType +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * Covers the structured module-level attribute emission for #477: + * every TensorSpec flowing through the graph with a non-null + * tensorEncoding must appear in a single `skainet.tensor_encodings` + * dictionary on the emitted `module attributes { ... }` header, so + * downstream tools can read it with one attribute lookup instead of + * string-matching against scattered comments. + */ +class TensorEncodingsModuleAttributeTest { + + @Test + fun encoded_weights_produce_module_attributes_block() { + val graph = DefaultComputeGraph() + + val inputA = GraphNode( + id = "a", + operation = InputOperation(), + inputs = emptyList(), + outputs = listOf(TensorSpec("a", listOf(1, 4), "FP32")) + ) + + // Two weight inputs with distinct encodings, exactly the shape + // TraceToGraphBuilder.finalize() produces post-#469 when a session + // resolves quantized weights. + val q8Spec = TensorSpec("w_q8", listOf(1, 4), "FP32") + .withTensorEncoding(TensorEncoding.Q8_0) + val q8Node = GraphNode( + id = "w_q8", + operation = InputOperation(), + inputs = emptyList(), + outputs = listOf(q8Spec) + ) + + val q4Spec = TensorSpec("w_q4", listOf(1, 4), "FP32") + .withTensorEncoding(TensorEncoding.Q4_K) + val q4Node = GraphNode( + id = "w_q4", + operation = InputOperation(), + inputs = emptyList(), + outputs = listOf(q4Spec) + ) + + val add1 = GraphNode( + id = "add1", + operation = AddOperation(), + inputs = listOf(TensorSpec("a", listOf(1, 4), "FP32"), q8Spec), + outputs = listOf(TensorSpec("sum1", listOf(1, 4), "FP32")) + ) + val add2 = GraphNode( + id = "add2", + operation = AddOperation(), + inputs = listOf(TensorSpec("sum1", listOf(1, 4), "FP32"), q4Spec), + outputs = listOf(TensorSpec("sum2", listOf(1, 4), "FP32")) + ) + + graph.addNode(inputA) + graph.addNode(q8Node) + graph.addNode(q4Node) + graph.addNode(add1) + graph.addNode(add2) + graph.addEdge(GraphEdge("e1", inputA, add1, 0, 0, inputA.outputs[0])) + graph.addEdge(GraphEdge("e2", q8Node, add1, 0, 1, q8Spec)) + graph.addEdge(GraphEdge("e3", add1, add2, 0, 0, add1.outputs[0])) + graph.addEdge(GraphEdge("e4", q4Node, add2, 0, 1, q4Spec)) + + val mlir = toStableHlo(graph, "quant_chain").content + println("[DEBUG_LOG] module-attribute export:\n$mlir") + + // The emitted module header must carry a structured attribute + // enumerating every encoded tensor in one place. + assertTrue( + mlir.contains("module attributes"), + "module header must be emitted with `module attributes { ... }` when encodings are present" + ) + assertTrue( + mlir.contains("skainet.tensor_encodings"), + "module attributes must include the `skainet.tensor_encodings` dictionary" + ) + + // Both encoded tensors must appear in the dictionary by name, + // each mapped to its TensorEncoding.name. + assertTrue( + mlir.contains("w_q8 = \"Q8_0\""), + "dictionary must map `w_q8` to `\"Q8_0\"`" + ) + assertTrue( + mlir.contains("w_q4 = \"Q4_K\""), + "dictionary must map `w_q4` to `\"Q4_K\"`" + ) + } + + @Test + fun dense_graph_keeps_bare_module_header() { + // A graph with no encoding metadata must emit the bare + // `module {` header with no `attributes` block. A `null` + // tensorEncoding is the unknown / not-carried state — not + // Dense — and the emitter must stay silent. + val graph = DefaultComputeGraph() + + val inputA = GraphNode( + id = "a", + operation = InputOperation(), + inputs = emptyList(), + outputs = listOf(TensorSpec("a", listOf(1, 4), "FP32")) + ) + val inputB = GraphNode( + id = "b", + operation = InputOperation(), + inputs = emptyList(), + outputs = listOf(TensorSpec("b", listOf(1, 4), "FP32")) + ) + val add = GraphNode( + id = "add1", + operation = AddOperation(), + inputs = listOf( + TensorSpec("a", listOf(1, 4), "FP32"), + TensorSpec("b", listOf(1, 4), "FP32") + ), + outputs = listOf(TensorSpec("c", listOf(1, 4), "FP32")) + ) + + graph.addNode(inputA) + graph.addNode(inputB) + graph.addNode(add) + graph.addEdge(GraphEdge("e1", inputA, add, 0, 0, inputA.outputs[0])) + graph.addEdge(GraphEdge("e2", inputB, add, 0, 1, inputB.outputs[0])) + + val mlir = toStableHlo(graph, "dense_add").content + + assertFalse( + mlir.contains("module attributes"), + "dense graph must not emit a `module attributes` block" + ) + assertFalse( + mlir.contains("skainet.tensor_encodings"), + "dense graph must not emit the `skainet.tensor_encodings` dictionary" + ) + assertTrue( + mlir.contains("module {"), + "dense graph must keep the bare `module {` header" + ) + } +} From ee9f23f798b6350068590eebfecc258bdc2b6085 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 13 Apr 2026 12:55:29 +0200 Subject: [PATCH 2/2] Emit skainet.tensor_encodings module attribute (#477) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Teaches StableHloConverter to collect every TensorSpec with a non-null tensorEncoding from the graph in a single pre-emit walk and surface them as a structured MLIR attribute on the module header: module attributes {skainet.tensor_encodings = {w_q4 = "Q4_K", w_q8 = "Q8_0"}} { func.func @quant_chain(...) -> (...) { ... } } Downstream tools (IREE, any MLIR pass reading SKaiNET output) can now enumerate every encoded tensor via one attribute lookup instead of string-matching against the scattered `tensor_encoding` comments introduced in #473. Dense graphs with no encoding metadata continue to emit the bare `module {` header byte-for- byte-identically to today — the promotion only fires when the collect phase finds at least one non-null encoding. The existing per-op comments are left in place; a follow-up can remove them if we decide the module attribute alone is sufficient. Collect phase is first-writer-wins: if the same tensor name surfaces in multiple nodes it collapses to a single map entry. Walks both inputs and outputs of every node so weight-operand references catch encodings that didn't originate as node outputs. Also tweaks MlirValidator to accept the new header shape: - validateSyntax treats `module` lines as module preamble and stops processing them at the header check, so the dict's `name = "value"` entries are no longer fed into the SSA assignment validator (which would fire on the absence of a `%` prefix). - validateSemantics skips `module`-prefixed lines for the same reason in its defined-values extraction. - validateModule accepts both `module {` and `module attributes` as valid module declaration forms. Full :skainet-compile:skainet-compile-hlo:jvmTest suite stays green — no existing test was asserting on `module {` in a way that excludes the attributes form. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../sk/ainet/compile/hlo/MlirValidator.kt | 33 ++++++++----- .../ainet/compile/hlo/StableHloConverter.kt | 46 +++++++++++++++++-- 2 files changed, 65 insertions(+), 14 deletions(-) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt index 65d9452a..22af3fd5 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt @@ -36,22 +36,28 @@ public class MlirValidator { for ((lineNum, line) in lines.withIndex()) { val trimmed = line.trim() - + // Skip empty lines and comments if (trimmed.isEmpty() || trimmed.startsWith("//")) continue - + // Check brace balance braceCount += trimmed.count { it == '{' } braceCount -= trimmed.count { it == '}' } - + // Check module structure if (trimmed.startsWith("module")) { if (inModule) { errors.add("Line ${lineNum + 1}: Nested modules not allowed") } inModule = true + // Module headers may carry a `module attributes { ... } {` + // preamble whose attribute dict contains `name = "value"` + // entries. These aren't SSA assignments and must not be + // fed into validateSSAValue, so stop processing this line + // here. + continue } - + // Check function structure if (trimmed.contains("func.func")) { if (!inModule) { @@ -59,7 +65,7 @@ public class MlirValidator { } inFunction = true } - + // Check for basic SSA value format if (trimmed.contains(" = ") && !validateSSAValue(trimmed)) { errors.add("Line ${lineNum + 1}: Invalid SSA value format") @@ -90,10 +96,13 @@ public class MlirValidator { for ((lineNum, line) in lines.withIndex()) { val trimmed = line.trim() - - // Skip empty lines and comments - if (trimmed.isEmpty() || trimmed.startsWith("//")) continue - + + // Skip empty lines, comments, and module header lines (which + // may carry a `module attributes { ... }` dictionary whose + // `name = "value"` entries look like SSA assignments but are + // not). + if (trimmed.isEmpty() || trimmed.startsWith("//") || trimmed.startsWith("module")) continue + // Extract defined SSA values if (trimmed.contains(" = ")) { val parts = trimmed.split(" = ", limit = 2) @@ -162,8 +171,10 @@ public class MlirValidator { */ public fun validateModule(content: String): List { val errors = mutableListOf() - - if (!content.contains("module {")) { + + // Accept both the bare `module {` and the attributes-carrying + // `module attributes { ... } {` header forms. + if (!content.contains("module {") && !content.contains("module attributes")) { errors.add("Missing module declaration") } diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt index 4a6c3713..54bf4113 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt @@ -3,6 +3,8 @@ package sk.ainet.compile.hlo import sk.ainet.lang.graph.ComputeGraph import sk.ainet.lang.graph.GraphNode import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.tensor.ops.tensorEncoding +import sk.ainet.lang.tensor.storage.TensorEncoding /** * Main converter class that orchestrates the conversion process from ComputeGraph to StableHLO MLIR. @@ -45,9 +47,26 @@ public class StableHloConverter( // Build function signature with proper return types val functionSignature = buildFunctionSignature(inputNodes, outputSpecs, functionName) - - // Start building MLIR content - context.emitLine("module {") + + // Collect every TensorSpec with a non-null tensorEncoding into a + // single name -> encoding map. Emitting this as a structured + // MLIR attribute on the module header lets downstream tools + // enumerate every encoded tensor via one attribute lookup + // instead of string-matching against scattered comments. + val tensorEncodings = collectTensorEncodings(topo) + + // Start building MLIR content — promote to `module attributes` + // only when we have at least one encoded tensor. Dense graphs + // keep the bare `module {` header for byte-for-byte backward + // compatibility with existing round-trip tests. + if (tensorEncodings.isNotEmpty()) { + val dictEntries = tensorEncodings.entries + .sortedBy { it.key } + .joinToString(", ") { (name, encoding) -> "$name = \"${encoding.name}\"" } + context.emitLine("module attributes {skainet.tensor_encodings = {$dictEntries}} {") + } else { + context.emitLine("module {") + } context.emitLine(" func.func $functionSignature {") // Initialize input values in context @@ -176,6 +195,27 @@ public class StableHloConverter( } } + /** + * Walk every node's input and output specs once and collect the + * `name -> encoding` map of every tensor that carries a non-null + * [TensorEncoding]. Duplicates (the same name appearing in multiple + * nodes) collapse to a single entry — first-writer-wins. + */ + private fun collectTensorEncodings(nodes: List): Map { + val result = linkedMapOf() + for (node in nodes) { + for (spec in node.outputs) { + val encoding = spec.tensorEncoding ?: continue + result.putIfAbsent(spec.name, encoding) + } + for (spec in node.inputs) { + val encoding = spec.tensorEncoding ?: continue + result.putIfAbsent(spec.name, encoding) + } + } + return result + } + /** * Determine output specifications from output nodes */