Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,41 @@ public class ConversionContext(
private var graph: ComputeGraph? = null
) {
private val valueNames = mutableMapOf<String, String>()
private val valueTypes = mutableMapOf<String, String>()
private val stringBuilder = StringBuilder()
private var tempCounter = 0

/**
 * Look up the SSA value name previously registered for [nodeId].
 *
 * @return the recorded value name (e.g. `%arg0`, `%0`), or null when
 *         no name has been set for this node yet.
 */
public fun getValueName(nodeId: String): String? {
    return valueNames[nodeId]
}

/**
 * Register [valueName] as the SSA value produced for node [nodeId],
 * replacing any previously recorded name for that node.
 */
public fun setValueName(nodeId: String, valueName: String) {
    valueNames += nodeId to valueName
}

/**
 * Record the MLIR tensor type carried by the SSA value [valueName].
 *
 * This lets converters look up the *declared* type of an operand — the
 * type it actually has at the point an op consumes it — rather than
 * re-deriving it from downstream node.inputs metadata, which can
 * reflect a post-op shape instead of the operand's true shape.
 * StableHloConverter seeds this map for each `%argN` when it emits the
 * function signature, then records each op's result type as it goes.
 */
public fun setValueType(valueName: String, mlirType: String) {
    valueTypes += valueName to mlirType
}

/**
 * Look up the MLIR tensor type recorded for the SSA value [valueName].
 *
 * @return the recorded type, or null when the value came from a
 *         converter that never called [setValueType].
 */
public fun getValueType(valueName: String): String? {
    return valueTypes[valueName]
}

/**
* Generate the next temporary SSA value name
Expand Down Expand Up @@ -118,6 +139,7 @@ public class ConversionContext(
*/
public fun clear() {
    // Drop all per-conversion state so the context can be reused for
    // another graph: name and type maps, emitted text, and the
    // temp-value counter.
    sequenceOf(valueNames, valueTypes).forEach { it.clear() }
    stringBuilder.setLength(0)
    tempCounter = 0
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ public class StableHloConverter(
val valueName = "%arg$idx"
context.setValueName(node.id, valueName)

// Seed the SSA type map with %argN's declared function-signature
// type so downstream ops can recover the operand type via
// context.getValueType(operands[0]) instead of re-deriving it
// (see issue #518).
node.outputs.firstOrNull()?.let { spec ->
context.setValueType(valueName, typeMapper.mapTensorType(spec))
}

// Add comment for clarity
node.outputs.firstOrNull()?.let { spec ->
context.emitComment("input ${node.id}: ${spec.name} : ${typeMapper.mapTensorType(spec)}")
Expand Down Expand Up @@ -184,6 +192,14 @@ public class StableHloConverter(
when (result) {
is ConversionResult.Success -> {
context.setValueName(node.id, result.outputValueName)
// Record the result's MLIR type so downstream operands can
// look it up (see issue #518). Uses the node's declared
// first output spec — converters that produce types
// differing from node.outputs[0] can override this by
// calling context.setValueType directly.
node.outputs.firstOrNull()?.let { spec ->
context.setValueType(result.outputValueName, typeMapper.mapTensorType(spec))
}
}
is ConversionResult.Failure -> {
context.emitComment("Conversion failed for node ${node.id}: ${result.error}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,18 @@ public class ShapeOperationsConverter : StableHloOperationConverter {
"Missing shape parameter for reshape node ${node.id}"
)
}


val inputType = resolveOperandType(operands[0], node, context)
val resultValue = context.nextTempValue()
val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($outputType) -> $outputType"
val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $outputType"
context.emitOperation(operation)

return ConversionResult.Success(
outputValueName = resultValue,
emittedOperations = listOf(operation)
)
}

/**
* Convert flatten operation using stablehlo.reshape.
* Flattens dimensions from startDim to endDim into a single dimension.
Expand All @@ -226,9 +227,10 @@ public class ShapeOperationsConverter : StableHloOperationConverter {
val endDim = node.operation.parameters["endDim"] as? Int ?: -1

context.emitComment("Flatten from dim $startDim to $endDim")


val inputType = resolveOperandType(operands[0], node, context)
val resultValue = context.nextTempValue()
val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($outputType) -> $outputType"
val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $outputType"
context.emitOperation(operation)

return ConversionResult.Success(
Expand Down Expand Up @@ -265,9 +267,10 @@ public class ShapeOperationsConverter : StableHloOperationConverter {
} else {
context.emitComment("Squeeze all singleton dimensions")
}


val inputType = resolveOperandType(operands[0], node, context)
val resultValue = context.nextTempValue()
val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($outputType) -> $outputType"
val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $outputType"
context.emitOperation(operation)

return ConversionResult.Success(
Expand Down Expand Up @@ -304,16 +307,45 @@ public class ShapeOperationsConverter : StableHloOperationConverter {
)

context.emitComment("Unsqueeze at dimension $dim")

// For unsqueeze, we can use either reshape or broadcast_in_dim
// Using reshape is simpler for this implementation

// For unsqueeze, we can use either reshape or broadcast_in_dim.
// Using reshape is simpler for this implementation.
val inputType = resolveOperandType(operands[0], node, context)
val resultValue = context.nextTempValue()
val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($outputType) -> $outputType"
val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $outputType"
context.emitOperation(operation)

return ConversionResult.Success(
outputValueName = resultValue,
emittedOperations = listOf(operation)
)
}

/**
 * Resolve the MLIR type of the SSA operand [operandName].
 *
 * Preference order:
 * 1. The declared type recorded in [ConversionContext.getValueType] —
 *    seeded for `%argN` from the function signature and updated by the
 *    main converter after each successful op. This is the operand's
 *    actual type at the point of consumption.
 * 2. `node.inputs[inputIndex]` — the edge metadata the caller wired,
 *    mapped through the context's type mapper.
 * 3. A dynamic fallback when neither is available.
 *    NOTE(review): the fallback assumes an f32 element type — fine as
 *    a last resort for the current f32-only paths, but verify callers
 *    never rely on it for integer tensors.
 *
 * Fixes #518: previous code used `outputType` on both sides of the
 * reshape cast, which produced `(outputShape) -> outputShape` and
 * broke `iree-compile` on any reshape/unsqueeze that consumed a
 * function argument with a different declared shape.
 *
 * @param operandName SSA name of the operand being consumed.
 * @param node the graph node whose input metadata serves as fallback.
 * @param context conversion context holding the recorded value types.
 * @param inputIndex which `node.inputs` entry to fall back to; defaults
 *        to 0 for the single-operand shape ops in this converter.
 */
private fun resolveOperandType(
    operandName: String,
    node: GraphNode,
    context: ConversionContext,
    inputIndex: Int = 0
): String {
    context.getValueType(operandName)?.let { return it }
    node.inputs.getOrNull(inputIndex)?.let { spec ->
        return context.getTypeMapper().mapTensorType(spec)
    }
    return "tensor<?xf32>"
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package sk.ainet.compile.hlo

import sk.ainet.compile.hlo.converters.ShapeOperationsConverter
import sk.ainet.lang.graph.DefaultComputeGraph
import sk.ainet.lang.graph.GraphEdge
import sk.ainet.lang.graph.GraphNode
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.ops.InputOperation
import sk.ainet.lang.tensor.ops.Operation
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.ValidationResult
Expand Down Expand Up @@ -145,6 +148,40 @@ class ShapeOperationsConverterTest {
assertTrue(result.error.contains("requires exactly 1 operand"))
}

@Test
fun testReshapeOnArgUsesDeclaredArgType() {
    // Regression for #518: a reshape/unsqueeze that consumes %arg0 must
    // cast from the argument's declared type — i.e. emit
    // `: (declaredArgType) -> outputType`, never `(outputType) -> outputType`.
    // The old code reused outputType on both sides of the cast, producing
    // e.g. `(tensor<1x80x3000xf32>) -> tensor<1x80x3000xf32>` for an
    // %arg0 actually declared as `tensor<80x3000xf32>`, which made
    // iree-compile reject the module with a type mismatch.
    val source = GraphNode(
        id = "mel",
        operation = InputOperation<DType, Any>(),
        inputs = emptyList(),
        outputs = listOf(TensorSpec("mel", listOf(80, 3000), "FP32"))
    )
    val expand = GraphNode(
        id = "unsqueeze1",
        operation = createMockOperation("unsqueeze", mapOf("dim" to 0)),
        inputs = listOf(TensorSpec("mel", listOf(80, 3000), "FP32")),
        outputs = listOf(TensorSpec("mel_b", listOf(1, 80, 3000), "FP32"))
    )
    val graph = DefaultComputeGraph().apply {
        addNode(source)
        addNode(expand)
        addEdge(GraphEdge("e1", source, expand, 0, 0, source.outputs[0]))
    }

    val module = StableHloConverterFactory.createBasic().convert(graph, "issue_518")

    assertTrue(
        module.content.contains("(tensor<80x3000xf32>) -> tensor<1x80x3000xf32>"),
        "reshape must emit declared %arg0 type on source side, got:\n${module.content}"
    )
}

@Test
fun testUnsupportedOperation() {
val operation = createMockOperation("unknown_shape_op", emptyMap())
Expand Down
Loading