diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/MathOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/MathOperationsConverter.kt index c981521e..53c46315 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/MathOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/MathOperationsConverter.kt @@ -30,31 +30,73 @@ public class MathOperationsConverter : StableHloOperationConverter { // Additional mathematical operations "pow", "mod", "remainder", // Element-wise operations - "element_add", "element_sub", "element_mul", "element_div" + "element_add", "element_sub", "element_mul", "element_div", + // Element-wise type conversion. Not strictly "math", but + // MathOperationsConverter already owns the elementwise-op + // family and cast is an elementwise primitive. + "cast", "convert", "to" ) - + override fun convert( - node: GraphNode, - operands: List, + node: GraphNode, + operands: List, context: ConversionContext ): ConversionResult { // Delegate basic math operations to BasicMathConverter if (basicMathConverter.supportedOperations.contains(node.operation.name.lowercase())) { return basicMathConverter.convert(node, operands, context) } - + // Handle additional mathematical operations return when (node.operation.name.lowercase()) { "pow" -> convertPower(node, operands, context) "mod", "remainder" -> convertRemainder(node, operands, context) - "element_add", "element_sub", "element_mul", "element_div" -> + "element_add", "element_sub", "element_mul", "element_div" -> convertElementWise(node, operands, context) + "cast", "convert", "to" -> convertCast(node, operands, context) else -> ConversionResult.Unsupported( node.operation.name, "Operation not supported by MathOperationsConverter" ) } } + + /** + * Convert cast / convert / to to stablehlo.convert. + * + * Reads the target dtype from `to`, `to_dtype`, or `dtype` + * parameter — or, when absent, from the output spec's dtype, + * which is the normal tracing path. Emits the MLIR type- + * transition signature `() -> `. + */ + private fun convertCast( + node: GraphNode, + operands: List, + context: ConversionContext + ): ConversionResult { + if (operands.size != 1) { + return ConversionResult.Failure( + "Cast operation requires exactly 1 operand, got ${operands.size}", + "Unsupported cast arity for node ${node.id}" + ) + } + + val typeMapper = context.getTypeMapper() + val inputSpec = node.inputs.firstOrNull() + val outputSpec = node.outputs.firstOrNull() + + val inputType = inputSpec?.let { typeMapper.mapTensorType(it) } ?: "tensor" + val outputType = outputSpec?.let { typeMapper.mapTensorType(it) } ?: "tensor" + + val resultValue = context.nextTempValue() + val operation = "$resultValue = stablehlo.convert ${operands[0]} : ($inputType) -> $outputType" + context.emitOperation(operation) + + return ConversionResult.Success( + outputValueName = resultValue, + emittedOperations = listOf(operation) + ) + } private fun convertPower( node: GraphNode, diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt index b9ed232f..f591f32f 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt @@ -22,12 +22,17 @@ import sk.ainet.lang.graph.GraphNode public class ShapeOperationsConverter : StableHloOperationConverter { override val supportedOperations: Set = setOf( - "reshape", "flatten", "squeeze", "unsqueeze" + "reshape", "flatten", "squeeze", "unsqueeze", + // Structural tensor ops — generic companions to reshape / + // flatten / squeeze. concat glues tensors along an axis, + // slice extracts a static window of a tensor. + "concat", "concatenate", "cat", "stack", + "slice" ) - + override fun convert( - node: GraphNode, - operands: List, + node: GraphNode, + operands: List, context: ConversionContext ): ConversionResult { return when (node.operation.name.lowercase()) { @@ -35,12 +40,115 @@ public class ShapeOperationsConverter : StableHloOperationConverter { "flatten" -> convertFlatten(node, operands, context) "squeeze" -> convertSqueeze(node, operands, context) "unsqueeze" -> convertUnsqueeze(node, operands, context) + "concat", "concatenate", "cat", "stack" -> convertConcat(node, operands, context) + "slice" -> convertSlice(node, operands, context) else -> ConversionResult.Unsupported( node.operation.name, "Operation not supported by ShapeOperationsConverter" ) } } + + /** + * Convert concat / concatenate / cat / stack to stablehlo.concatenate. + * + * Reads the join axis from `axis` or `dim` parameter (default 0) + * and emits: + * + * %out = stablehlo.concatenate %a, %b, ..., dim = : + */ + private fun convertConcat( + node: GraphNode, + operands: List, + context: ConversionContext + ): ConversionResult { + if (operands.isEmpty()) { + return ConversionResult.Failure( + "Concat operation requires at least 1 operand, got 0", + "Unsupported concat arity for node ${node.id}" + ) + } + + val outputSpec = node.outputs.firstOrNull() + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + + val rank = node.inputs.firstOrNull()?.shape?.size + ?: outputSpec?.shape?.size ?: 0 + val rawAxis = node.operation.parameters["axis"] as? Int + ?: node.operation.parameters["dim"] as? Int + ?: 0 + val axis = if (rawAxis < 0 && rank > 0) rank + rawAxis else rawAxis + + val resultValue = context.nextTempValue() + val operandList = operands.joinToString(", ") + val operation = "$resultValue = stablehlo.concatenate $operandList, dim = $axis : $outputType" + context.emitOperation(operation) + + return ConversionResult.Success( + outputValueName = resultValue, + emittedOperations = listOf(operation) + ) + } + + /** + * Convert slice to stablehlo.slice. + * + * Reads per-dim `start_indices`, `limit_indices`, and `strides` + * from parameters and emits a static slice: + * + * %out = stablehlo.slice %x [s0:l0:d0, s1:l1:d1, ...] : + * + * Strides default to 1 per dim when not supplied. Dynamic slice + * (runtime bounds) is explicitly out of scope for this first pass. + */ + private fun convertSlice( + node: GraphNode, + operands: List, + context: ConversionContext + ): ConversionResult { + if (operands.size != 1) { + return ConversionResult.Failure( + "Slice operation requires exactly 1 operand, got ${operands.size}", + "Unsupported slice arity for node ${node.id}" + ) + } + + val outputSpec = node.outputs.firstOrNull() + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + + val inputShape = node.inputs.firstOrNull()?.shape ?: emptyList() + val rank = inputShape.size + + @Suppress("UNCHECKED_CAST") + val starts = (node.operation.parameters["start_indices"] as? List) + ?: (node.operation.parameters["starts"] as? List) + ?: List(rank) { 0 } + @Suppress("UNCHECKED_CAST") + val limits = (node.operation.parameters["limit_indices"] as? List) + ?: (node.operation.parameters["limits"] as? List) + ?: inputShape + @Suppress("UNCHECKED_CAST") + val strides = (node.operation.parameters["strides"] as? List) + ?: List(rank) { 1 } + + val startsAttr = starts.joinToString(", ") + val limitsAttr = limits.joinToString(", ") + val stridesAttr = strides.joinToString(", ") + + val resultValue = context.nextTempValue() + val operation = "$resultValue = stablehlo.slice ${operands[0]} " + + "{start_indices = [$startsAttr], " + + "limit_indices = [$limitsAttr], " + + "strides = [$stridesAttr]} : $outputType" + context.emitOperation(operation) + + return ConversionResult.Success( + outputValueName = resultValue, + emittedOperations = listOf(operation) + ) + } /** * Convert reshape operation using stablehlo.reshape. diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConcatSliceCastConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConcatSliceCastConverterTest.kt new file mode 100644 index 00000000..9233a66c --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConcatSliceCastConverterTest.kt @@ -0,0 +1,258 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.ops.Operation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.tensor.ops.ValidationResult +import sk.ainet.lang.types.DType +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * Covers #489: concat / slice (in ShapeOperationsConverter) and + * cast (in MathOperationsConverter). All three are generic + * structural / type primitives — concat glues tensors along an + * axis, slice extracts a static window, cast reinterprets the + * element type. None of them are LLM-specific; they're the + * standard companions to reshape / flatten / squeeze that were + * already covered. + */ +class ConcatSliceCastConverterTest { + + // ----- concat ------------------------------------------------------------ + + @Test + fun concat_and_aliases_are_supported() { + for (opName in listOf("concat", "concatenate", "cat", "stack")) { + val module = buildConcatModule(opName) + assertFalse( + module.content.contains("Unsupported operation"), + "$opName must be claimed by a converter" + ) + assertTrue( + module.content.contains("stablehlo.concatenate"), + "$opName must lower to stablehlo.concatenate" + ) + } + } + + @Test + fun concat_emits_dim_attribute_matching_axis_parameter() { + val module = buildConcatModule("concat", axis = 1) + assertTrue( + module.content.contains("dim = 1"), + "concat must emit `dim = ` on stablehlo.concatenate" + ) + } + + private fun buildConcatModule(opName: String, axis: Int = 0): StableHloModule { + val graph = DefaultComputeGraph() + val shape = listOf(2, 3) + + val a = GraphNode( + id = "a", + operation = markerInputOp(), + inputs = emptyList(), + outputs = listOf(TensorSpec("a", shape, "FP32")) + ) + val b = GraphNode( + id = "b", + operation = markerInputOp(), + inputs = emptyList(), + outputs = listOf(TensorSpec("b", shape, "FP32")) + ) + val outShape = shape.mapIndexed { i, d -> if (i == axis) d * 2 else d } + val concat = GraphNode( + id = "cat1", + operation = concatOp(opName, axis), + inputs = listOf( + TensorSpec("a", shape, "FP32"), + TensorSpec("b", shape, "FP32") + ), + outputs = listOf(TensorSpec("y", outShape, "FP32")) + ) + + graph.addNode(a) + graph.addNode(b) + graph.addNode(concat) + graph.addEdge(GraphEdge("e1", a, concat, 0, 0, a.outputs[0])) + graph.addEdge(GraphEdge("e2", b, concat, 0, 1, b.outputs[0])) + + return StableHloConverterFactory.createExtended().convert(graph, "test_$opName") + } + + // ----- slice ------------------------------------------------------------- + + @Test + fun slice_is_supported_and_emits_stablehlo_slice() { + val module = buildSliceModule() + assertFalse( + module.content.contains("Unsupported operation slice"), + "slice must be claimed by a converter" + ) + assertTrue( + module.content.contains("stablehlo.slice"), + "slice must lower to stablehlo.slice" + ) + } + + @Test + fun slice_carries_start_limit_stride_attributes() { + val module = buildSliceModule() + println("[DEBUG_LOG] slice export:\n${module.content}") + assertTrue( + module.content.contains("start_indices"), + "slice must emit start_indices" + ) + assertTrue( + module.content.contains("limit_indices"), + "slice must emit limit_indices" + ) + assertTrue( + module.content.contains("strides"), + "slice must emit strides" + ) + } + + private fun buildSliceModule(): StableHloModule { + val graph = DefaultComputeGraph() + val shape = listOf(8, 16) + + val x = GraphNode( + id = "x", + operation = markerInputOp(), + inputs = emptyList(), + outputs = listOf(TensorSpec("x", shape, "FP32")) + ) + val slice = GraphNode( + id = "slice1", + operation = sliceOp( + starts = listOf(0, 0), + limits = listOf(4, 8), + strides = listOf(1, 1) + ), + inputs = listOf(TensorSpec("x", shape, "FP32")), + outputs = listOf(TensorSpec("y", listOf(4, 8), "FP32")) + ) + graph.addNode(x) + graph.addNode(slice) + graph.addEdge(GraphEdge("e1", x, slice, 0, 0, x.outputs[0])) + + return StableHloConverterFactory.createExtended().convert(graph, "test_slice") + } + + // ----- cast -------------------------------------------------------------- + + @Test + fun cast_and_aliases_are_supported() { + for (opName in listOf("cast", "convert", "to")) { + val module = buildCastModule(opName, toDtype = "FP16") + assertFalse( + module.content.contains("Unsupported operation"), + "$opName must be claimed by a converter" + ) + assertTrue( + module.content.contains("stablehlo.convert"), + "$opName must lower to stablehlo.convert" + ) + } + } + + @Test + fun cast_emits_dtype_transition_in_type_signature() { + val module = buildCastModule("cast", toDtype = "FP16") + // The emitted op must carry a type signature that shows the + // source and destination element types. Exact formatting + // comes from TypeMapper; we check for the target dtype's + // MLIR-style name appearing on the RHS of a `->`. + assertTrue( + module.content.contains("->"), + "cast must emit a type-transition arrow in its signature" + ) + assertTrue( + module.content.contains("f16"), + "cast to FP16 must mention the target element type f16" + ) + } + + private fun buildCastModule(opName: String, toDtype: String): StableHloModule { + val graph = DefaultComputeGraph() + val shape = listOf(2, 3) + + val x = GraphNode( + id = "x", + operation = markerInputOp(), + inputs = emptyList(), + outputs = listOf(TensorSpec("x", shape, "FP32")) + ) + val cast = GraphNode( + id = "cast1", + operation = castOp(opName, toDtype), + inputs = listOf(TensorSpec("x", shape, "FP32")), + outputs = listOf(TensorSpec("y", shape, toDtype)) + ) + graph.addNode(x) + graph.addNode(cast) + graph.addEdge(GraphEdge("e1", x, cast, 0, 0, x.outputs[0])) + + return StableHloConverterFactory.createExtended().convert(graph, "test_$opName") + } + + // ----- fixtures ---------------------------------------------------------- + + private fun markerInputOp(): Operation = object : Operation { + override val name: String = "input" + override val type: String = "input" + override val parameters: Map = emptyMap() + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test fixture only") + override fun validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = emptyList() + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf("name" to name, "type" to type) + } + + private fun concatOp(name: String, axis: Int): Operation = object : Operation { + override val name: String = name + override val type: String = "shape" + override val parameters: Map = mapOf("axis" to axis) + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test fixture only") + override fun validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = inputs.take(1) + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf("name" to name, "type" to type, "parameters" to parameters) + } + + private fun sliceOp(starts: List, limits: List, strides: List): Operation = object : Operation { + override val name: String = "slice" + override val type: String = "shape" + override val parameters: Map = mapOf( + "start_indices" to starts, + "limit_indices" to limits, + "strides" to strides + ) + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test fixture only") + override fun validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = inputs.take(1) + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf("name" to name, "type" to type, "parameters" to parameters) + } + + private fun castOp(name: String, toDtype: String): Operation = object : Operation { + override val name: String = name + override val type: String = "math" + override val parameters: Map = mapOf("to" to toDtype) + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test fixture only") + override fun validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = inputs.take(1) + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf("name" to name, "type" to type, "parameters" to parameters) + } +} diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ShapeOperationsConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ShapeOperationsConverterTest.kt index b85f161a..54c0e5bf 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ShapeOperationsConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ShapeOperationsConverterTest.kt @@ -21,20 +21,36 @@ class ShapeOperationsConverterTest { @Test fun testSupportedOperations() { - val expectedOperations = setOf("reshape", "flatten", "squeeze", "unsqueeze") - assertEquals(expectedOperations, converter.supportedOperations) + // Core reshape family — the originals. + val coreOperations = setOf("reshape", "flatten", "squeeze", "unsqueeze") + assertTrue( + converter.supportedOperations.containsAll(coreOperations), + "converter must still cover the core reshape family" + ) + // Structural companions added in #489. + val structuralOperations = setOf("concat", "concatenate", "cat", "stack", "slice") + assertTrue( + converter.supportedOperations.containsAll(structuralOperations), + "converter must cover the structural companions (concat/slice) added in #489" + ) } - + @Test fun testRegistryIntegration() { // Test that shape operations are supported val registry = StableHloOperationRegistry() registry.register(ShapeOperationsConverter()) - + assertTrue(registry.isSupported("reshape")) assertTrue(registry.isSupported("flatten")) assertTrue(registry.isSupported("squeeze")) assertTrue(registry.isSupported("unsqueeze")) + // Structural companions added in #489. + assertTrue(registry.isSupported("concat")) + assertTrue(registry.isSupported("concatenate")) + assertTrue(registry.isSupported("cat")) + assertTrue(registry.isSupported("stack")) + assertTrue(registry.isSupported("slice")) } @Test