diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt index 65d9452a..22af3fd5 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt @@ -36,22 +36,28 @@ public class MlirValidator { for ((lineNum, line) in lines.withIndex()) { val trimmed = line.trim() - + // Skip empty lines and comments if (trimmed.isEmpty() || trimmed.startsWith("//")) continue - + // Check brace balance braceCount += trimmed.count { it == '{' } braceCount -= trimmed.count { it == '}' } - + // Check module structure if (trimmed.startsWith("module")) { if (inModule) { errors.add("Line ${lineNum + 1}: Nested modules not allowed") } inModule = true + // Module headers may carry a `module attributes { ... } {` + // preamble whose attribute dict contains `name = "value"` + // entries. These aren't SSA assignments and must not be + // fed into validateSSAValue, so stop processing this line + // here. + continue } - + // Check function structure if (trimmed.contains("func.func")) { if (!inModule) { @@ -59,7 +65,7 @@ public class MlirValidator { } inFunction = true } - + // Check for basic SSA value format if (trimmed.contains(" = ") && !validateSSAValue(trimmed)) { errors.add("Line ${lineNum + 1}: Invalid SSA value format") @@ -90,10 +96,13 @@ public class MlirValidator { for ((lineNum, line) in lines.withIndex()) { val trimmed = line.trim() - - // Skip empty lines and comments - if (trimmed.isEmpty() || trimmed.startsWith("//")) continue - + + // Skip empty lines, comments, and module header lines (which + // may carry a `module attributes { ... }` dictionary whose + // `name = "value"` entries look like SSA assignments but are + // not). 
+ if (trimmed.isEmpty() || trimmed.startsWith("//") || trimmed.startsWith("module")) continue + // Extract defined SSA values if (trimmed.contains(" = ")) { val parts = trimmed.split(" = ", limit = 2) @@ -162,8 +171,10 @@ public class MlirValidator { */ public fun validateModule(content: String): List { val errors = mutableListOf() - - if (!content.contains("module {")) { + + // Accept both the bare `module {` and the attributes-carrying + // `module attributes { ... } {` header forms. + if (!content.contains("module {") && !content.contains("module attributes")) { errors.add("Missing module declaration") } diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt index 4a6c3713..54bf4113 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt @@ -3,6 +3,8 @@ package sk.ainet.compile.hlo import sk.ainet.lang.graph.ComputeGraph import sk.ainet.lang.graph.GraphNode import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.tensor.ops.tensorEncoding +import sk.ainet.lang.tensor.storage.TensorEncoding /** * Main converter class that orchestrates the conversion process from ComputeGraph to StableHLO MLIR. @@ -45,9 +47,26 @@ public class StableHloConverter( // Build function signature with proper return types val functionSignature = buildFunctionSignature(inputNodes, outputSpecs, functionName) - - // Start building MLIR content - context.emitLine("module {") + + // Collect every TensorSpec with a non-null tensorEncoding into a + // single name -> encoding map. 
Emitting this as a structured + // MLIR attribute on the module header lets downstream tools + // enumerate every encoded tensor via one attribute lookup + // instead of string-matching against scattered comments. + val tensorEncodings = collectTensorEncodings(topo) + + // Start building MLIR content — promote to `module attributes` + // only when we have at least one encoded tensor. Dense graphs + // keep the bare `module {` header for byte-for-byte backward + // compatibility with existing round-trip tests. + if (tensorEncodings.isNotEmpty()) { + val dictEntries = tensorEncodings.entries + .sortedBy { it.key } + .joinToString(", ") { (name, encoding) -> "$name = \"${encoding.name}\"" } + context.emitLine("module attributes {skainet.tensor_encodings = {$dictEntries}} {") + } else { + context.emitLine("module {") + } context.emitLine(" func.func $functionSignature {") // Initialize input values in context @@ -176,6 +195,27 @@ public class StableHloConverter( } } + /** + * Walk every node's input and output specs once and collect the + * `name -> encoding` map of every tensor that carries a non-null + * [TensorEncoding]. Duplicates (the same name appearing in multiple + * nodes) collapse to a single entry — first-writer-wins. 
+     */
+    private fun collectTensorEncodings(nodes: List<GraphNode>): Map<String, TensorEncoding> {
+        val result = linkedMapOf<String, TensorEncoding>() // insertion-ordered name -> encoding
+        for (node in nodes) {
+            for (spec in node.outputs) {
+                val encoding = spec.tensorEncoding ?: continue // spec carries no encoding: nothing to record
+                if (spec.name !in result) result[spec.name] = encoding // first writer wins; avoids JVM-only putIfAbsent
+            }
+            for (spec in node.inputs) {
+                val encoding = spec.tensorEncoding ?: continue // spec carries no encoding: nothing to record
+                if (spec.name !in result) result[spec.name] = encoding // never overwrite an earlier entry
+            }
+        }
+        return result
+    }
+
     /**
      * Determine output specifications from output nodes
      */
diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/TensorEncodingsModuleAttributeTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/TensorEncodingsModuleAttributeTest.kt
new file mode 100644
index 00000000..634381ed
--- /dev/null
+++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/TensorEncodingsModuleAttributeTest.kt
@@ -0,0 +1,158 @@
+package sk.ainet.compile.hlo
+
+import sk.ainet.lang.graph.DefaultComputeGraph
+import sk.ainet.lang.graph.GraphEdge
+import sk.ainet.lang.graph.GraphNode
+import sk.ainet.lang.tensor.ops.AddOperation
+import sk.ainet.lang.tensor.ops.InputOperation
+import sk.ainet.lang.tensor.ops.TensorSpec
+import sk.ainet.lang.tensor.ops.withTensorEncoding
+import sk.ainet.lang.tensor.storage.TensorEncoding
+import sk.ainet.lang.types.DType
+import kotlin.test.Test
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+
+/**
+ * Covers the structured module-level attribute emission for #477:
+ * every TensorSpec flowing through the graph with a non-null
+ * tensorEncoding must appear in a single `skainet.tensor_encodings`
+ * dictionary on the emitted `module attributes { ... }` header, so
+ * downstream tools can read it with one attribute lookup instead of
+ * string-matching against scattered comments.
+ */
+class TensorEncodingsModuleAttributeTest {
+
+    @Test
+    fun encoded_weights_produce_module_attributes_block() {
+        val graph = DefaultComputeGraph()
+
+        val inputA = GraphNode(
+            id = "a",
+            operation = InputOperation(),
+            inputs = emptyList(),
+            outputs = listOf(TensorSpec("a", listOf(1, 4), "FP32"))
+        )
+
+        // Two weight inputs with distinct encodings, exactly the shape
+        // TraceToGraphBuilder.finalize() produces post-#469 when a session
+        // resolves quantized weights.
+        val q8Spec = TensorSpec("w_q8", listOf(1, 4), "FP32")
+            .withTensorEncoding(TensorEncoding.Q8_0)
+        val q8Node = GraphNode(
+            id = "w_q8",
+            operation = InputOperation(),
+            inputs = emptyList(),
+            outputs = listOf(q8Spec)
+        )
+
+        val q4Spec = TensorSpec("w_q4", listOf(1, 4), "FP32")
+            .withTensorEncoding(TensorEncoding.Q4_K)
+        val q4Node = GraphNode(
+            id = "w_q4",
+            operation = InputOperation(),
+            inputs = emptyList(),
+            outputs = listOf(q4Spec)
+        )
+
+        val add1 = GraphNode(
+            id = "add1",
+            operation = AddOperation(),
+            inputs = listOf(TensorSpec("a", listOf(1, 4), "FP32"), q8Spec), // the Q8_0-tagged spec is reused as an input here
+            outputs = listOf(TensorSpec("sum1", listOf(1, 4), "FP32"))
+        )
+        val add2 = GraphNode(
+            id = "add2",
+            operation = AddOperation(),
+            inputs = listOf(TensorSpec("sum1", listOf(1, 4), "FP32"), q4Spec), // the Q4_K-tagged spec is reused as an input here
+            outputs = listOf(TensorSpec("sum2", listOf(1, 4), "FP32"))
+        )
+
+        graph.addNode(inputA)
+        graph.addNode(q8Node)
+        graph.addNode(q4Node)
+        graph.addNode(add1)
+        graph.addNode(add2)
+        graph.addEdge(GraphEdge("e1", inputA, add1, 0, 0, inputA.outputs[0])) // NOTE(review): the two ints look like source/destination port indices — confirm against GraphEdge
+        graph.addEdge(GraphEdge("e2", q8Node, add1, 0, 1, q8Spec))
+        graph.addEdge(GraphEdge("e3", add1, add2, 0, 0, add1.outputs[0]))
+        graph.addEdge(GraphEdge("e4", q4Node, add2, 0, 1, q4Spec))
+
+        val mlir = toStableHlo(graph, "quant_chain").content // NOTE(review): second argument is presumably the emitted function name — confirm against toStableHlo
+        println("[DEBUG_LOG] module-attribute export:\n$mlir") // debug aid: prints the emitted MLIR text to stdout
+
+        // The emitted module header must carry a structured attribute
+        // enumerating every encoded tensor in one place.
+        assertTrue(
+            mlir.contains("module attributes"),
+            "module header must be emitted with `module attributes { ... }` when encodings are present"
+        )
+        assertTrue(
+            mlir.contains("skainet.tensor_encodings"),
+            "module attributes must include the `skainet.tensor_encodings` dictionary"
+        )
+
+        // Both encoded tensors must appear in the dictionary by name,
+        // each mapped to its TensorEncoding.name.
+        assertTrue(
+            mlir.contains("w_q8 = \"Q8_0\""), // substring check only: does not verify the entry sits inside the module header
+            "dictionary must map `w_q8` to `\"Q8_0\"`"
+        )
+        assertTrue(
+            mlir.contains("w_q4 = \"Q4_K\""), // substring check only: does not verify the entry sits inside the module header
+            "dictionary must map `w_q4` to `\"Q4_K\"`"
+        )
+    }
+
+    @Test
+    fun dense_graph_keeps_bare_module_header() {
+        // A graph with no encoding metadata must emit the bare
+        // `module {` header with no `attributes` block. A `null`
+        // tensorEncoding is the unknown / not-carried state — not
+        // Dense — and the emitter must stay silent.
+        val graph = DefaultComputeGraph()
+
+        val inputA = GraphNode(
+            id = "a",
+            operation = InputOperation(),
+            inputs = emptyList(),
+            outputs = listOf(TensorSpec("a", listOf(1, 4), "FP32"))
+        )
+        val inputB = GraphNode(
+            id = "b",
+            operation = InputOperation(),
+            inputs = emptyList(),
+            outputs = listOf(TensorSpec("b", listOf(1, 4), "FP32"))
+        )
+        val add = GraphNode(
+            id = "add1",
+            operation = AddOperation(),
+            inputs = listOf(
+                TensorSpec("a", listOf(1, 4), "FP32"),
+                TensorSpec("b", listOf(1, 4), "FP32")
+            ),
+            outputs = listOf(TensorSpec("c", listOf(1, 4), "FP32"))
+        )
+
+        graph.addNode(inputA)
+        graph.addNode(inputB)
+        graph.addNode(add)
+        graph.addEdge(GraphEdge("e1", inputA, add, 0, 0, inputA.outputs[0]))
+        graph.addEdge(GraphEdge("e2", inputB, add, 0, 1, inputB.outputs[0]))
+
+        val mlir = toStableHlo(graph, "dense_add").content // no spec in this graph was given a tensor encoding
+
+        assertFalse(
+            mlir.contains("module attributes"),
+            "dense graph must not emit a `module attributes` block"
+        )
+        assertFalse(
+            mlir.contains("skainet.tensor_encodings"),
+            "dense graph must not emit the `skainet.tensor_encodings` dictionary"
+        )
+        assertTrue(
+            mlir.contains("module {"), // the bare header form, with no attribute dictionary before the brace
+            "dense graph must keep the bare `module {` header"
+        )
+    }
+}