From ca611a57df5b7b43888e4a0e4f0ce4da6749766b Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Sat, 18 Apr 2026 09:44:23 +0200 Subject: [PATCH] Track SSA value types so reshape emits the operand's declared type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ShapeOperationsConverter emitted `: ($outputType) -> $outputType` on every reshape / flatten / squeeze / unsqueeze, reusing the output type on the source side of the cast. When the operand was a function argument (`%arg0: tensor<80x3000xf32>`), the result was syntactically valid but wrong — e.g. `stablehlo.reshape %arg0 : (tensor<1x80x3000xf32>) -> tensor<1x80x3000xf32>` — and iree-compile rejected it with a type mismatch on every native SKaiNET DSL export that reached the IREE path. Fix threads an SSA-value-name -> MLIR-type map through ConversionContext: - StableHloConverter seeds `%argN -> mapTensorType(inputSpec)` when it writes the function signature, and records `resultValue -> outputType` after each successful op conversion. - ShapeOperationsConverter looks up the operand's real type via a new resolveOperandType helper (context.getValueType first, then node.inputs[0], then a dynamic fallback) and uses that on the source side of every reshape cast. The type map is consumed opt-in by converters that need it, so other converters (concat, slice, dot_general, transpose) are unaffected. Regression test `testReshapeOnArgUsesDeclaredArgType` builds an input -> unsqueeze graph and asserts the emitted MLIR contains `(tensor<80x3000xf32>) -> tensor<1x80x3000xf32>`. Full :skainet-compile-hlo:jvmTest passes. Closes #518 Co-Authored-By: Claude Opus 4.7 (1M context) --- .../sk/ainet/compile/hlo/ConversionContext.kt | 26 ++++++++- .../ainet/compile/hlo/StableHloConverter.kt | 16 +++++ .../converters/ShapeOperationsConverter.kt | 58 ++++++++++++++----- .../hlo/ShapeOperationsConverterTest.kt | 37 ++++++++++++ 4 files changed, 122 insertions(+), 15 deletions(-) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt index 925b3319..d799c03e 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt @@ -16,20 +16,41 @@ public class ConversionContext( private var graph: ComputeGraph? = null ) { private val valueNames = mutableMapOf() + private val valueTypes = mutableMapOf() private val stringBuilder = StringBuilder() private var tempCounter = 0 - + /** * Get the SSA value name for a node ID */ public fun getValueName(nodeId: String): String? = valueNames[nodeId] - + /** * Set the SSA value name for a node ID */ public fun setValueName(nodeId: String, valueName: String) { valueNames[nodeId] = valueName } + + /** + * Record the MLIR tensor type associated with an SSA value name. + * + * Lets converters look up the *declared* type of an operand — the + * type it actually has when the op consumes it — instead of having + * to re-derive it from downstream node.inputs metadata, which can + * reflect a post-op shape rather than the operand's true shape. + * Seeded for `%argN` by StableHloConverter when the function + * signature is emitted, then populated for each op's result. + */ + public fun setValueType(valueName: String, mlirType: String) { + valueTypes[valueName] = mlirType + } + + /** + * Get the MLIR tensor type for an SSA value name, or null if the + * value was produced by a converter that did not record its type. + */ + public fun getValueType(valueName: String): String? = valueTypes[valueName] /** * Generate the next temporary SSA value name @@ -118,6 +139,7 @@ public class ConversionContext( */ public fun clear() { valueNames.clear() + valueTypes.clear() stringBuilder.clear() tempCounter = 0 } diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt index 08c1e690..78ad1780 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt @@ -131,6 +131,14 @@ public class StableHloConverter( val valueName = "%arg$idx" context.setValueName(node.id, valueName) + // Seed the SSA type map with %argN's declared function-signature + // type so downstream ops can recover the operand type via + // context.getValueType(operands[0]) instead of re-deriving it + // (see issue #518). + node.outputs.firstOrNull()?.let { spec -> + context.setValueType(valueName, typeMapper.mapTensorType(spec)) + } + // Add comment for clarity node.outputs.firstOrNull()?.let { spec -> context.emitComment("input ${node.id}: ${spec.name} : ${typeMapper.mapTensorType(spec)}") @@ -184,6 +192,14 @@ public class StableHloConverter( when (result) { is ConversionResult.Success -> { context.setValueName(node.id, result.outputValueName) + // Record the result's MLIR type so downstream operands can + // look it up (see issue #518). Uses the node's declared + // first output spec — converters that produce types + // differing from node.outputs[0] can override this by + // calling context.setValueType directly. + node.outputs.firstOrNull()?.let { spec -> + context.setValueType(result.outputValueName, typeMapper.mapTensorType(spec)) + } } is ConversionResult.Failure -> { context.emitComment("Conversion failed for node ${node.id}: ${result.error}") diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt index f591f32f..86608118 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt @@ -190,17 +190,18 @@ public class ShapeOperationsConverter : StableHloOperationConverter { "Missing shape parameter for reshape node ${node.id}" ) } - + + val inputType = resolveOperandType(operands[0], node, context) val resultValue = context.nextTempValue() - val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($outputType) -> $outputType" + val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $outputType" context.emitOperation(operation) - + return ConversionResult.Success( outputValueName = resultValue, emittedOperations = listOf(operation) ) } - + /** * Convert flatten operation using stablehlo.reshape. * Flattens dimensions from startDim to endDim into a single dimension. @@ -226,9 +227,10 @@ public class ShapeOperationsConverter : StableHloOperationConverter { val endDim = node.operation.parameters["endDim"] as? Int ?: -1 context.emitComment("Flatten from dim $startDim to $endDim") - + + val inputType = resolveOperandType(operands[0], node, context) val resultValue = context.nextTempValue() - val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($outputType) -> $outputType" + val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $outputType" context.emitOperation(operation) return ConversionResult.Success( @@ -265,9 +267,10 @@ public class ShapeOperationsConverter : StableHloOperationConverter { } else { context.emitComment("Squeeze all singleton dimensions") } - + + val inputType = resolveOperandType(operands[0], node, context) val resultValue = context.nextTempValue() - val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($outputType) -> $outputType" + val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $outputType" context.emitOperation(operation) return ConversionResult.Success( @@ -304,16 +307,45 @@ public class ShapeOperationsConverter : StableHloOperationConverter { ) context.emitComment("Unsqueeze at dimension $dim") - - // For unsqueeze, we can use either reshape or broadcast_in_dim - // Using reshape is simpler for this implementation + + // For unsqueeze, we can use either reshape or broadcast_in_dim. + // Using reshape is simpler for this implementation. + val inputType = resolveOperandType(operands[0], node, context) val resultValue = context.nextTempValue() - val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($outputType) -> $outputType" + val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $outputType" context.emitOperation(operation) - + return ConversionResult.Success( outputValueName = resultValue, emittedOperations = listOf(operation) ) } + + /** + * Look up the MLIR type of an SSA operand. + * + * Preference order: + * 1. Declared type recorded in [ConversionContext.getValueType] — + * set either by the function-arg seeder (for `%argN`) or by the + * main converter after a prior op succeeded. This is the + * operand's actual type at the point of consumption. + * 2. `node.inputs[0]` — the edge metadata the caller wired. + * 3. Dynamic fallback — last resort when neither is available. + * + * Fixes #518: previous code used `outputType` on both sides of the + * reshape cast, which produced `(outputShape) -> outputShape` and + * broke `iree-compile` on any reshape/unsqueeze that consumed a + * function argument with a different declared shape. + */ + private fun resolveOperandType( + operandName: String, + node: GraphNode, + context: ConversionContext + ): String { + context.getValueType(operandName)?.let { return it } + node.inputs.firstOrNull()?.let { spec -> + return context.getTypeMapper().mapTensorType(spec) + } + return "tensor" + } } \ No newline at end of file diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ShapeOperationsConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ShapeOperationsConverterTest.kt index 54c0e5bf..9495e8b3 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ShapeOperationsConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ShapeOperationsConverterTest.kt @@ -1,8 +1,11 @@ package sk.ainet.compile.hlo import sk.ainet.compile.hlo.converters.ShapeOperationsConverter +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge import sk.ainet.lang.graph.GraphNode import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.ops.InputOperation import sk.ainet.lang.tensor.ops.Operation import sk.ainet.lang.tensor.ops.TensorSpec import sk.ainet.lang.tensor.ops.ValidationResult @@ -145,6 +148,40 @@ class ShapeOperationsConverterTest { assertTrue(result.error.contains("requires exactly 1 operand")) } + @Test + fun testReshapeOnArgUsesDeclaredArgType() { + // Regression for #518: reshape/unsqueeze consuming %arg0 must emit + // `: (declaredArgType) -> outputType`, not `(outputType) -> outputType`. + // Previous code reused outputType on both sides of the cast, which + // produced e.g. `(tensor<1x80x3000xf32>) -> tensor<1x80x3000xf32>` + // on an input that %arg0 actually had as `tensor<80x3000xf32>`, + // breaking iree-compile with a type mismatch. + val graph = DefaultComputeGraph() + val input = GraphNode( + id = "mel", + operation = InputOperation(), + inputs = emptyList(), + outputs = listOf(TensorSpec("mel", listOf(80, 3000), "FP32")) + ) + val unsqueeze = GraphNode( + id = "unsqueeze1", + operation = createMockOperation("unsqueeze", mapOf("dim" to 0)), + inputs = listOf(TensorSpec("mel", listOf(80, 3000), "FP32")), + outputs = listOf(TensorSpec("mel_b", listOf(1, 80, 3000), "FP32")) + ) + graph.addNode(input) + graph.addNode(unsqueeze) + graph.addEdge(GraphEdge("e1", input, unsqueeze, 0, 0, input.outputs[0])) + + val fullConverter = StableHloConverterFactory.createBasic() + val module = fullConverter.convert(graph, "issue_518") + + assertTrue( + module.content.contains("(tensor<80x3000xf32>) -> tensor<1x80x3000xf32>"), + "reshape must emit declared %arg0 type on source side, got:\n${module.content}" + ) + } + @Test fun testUnsupportedOperation() { val operation = createMockOperation("unknown_shape_op", emptyMap())