diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt index e2f233d5..81c42dc7 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt @@ -2,6 +2,7 @@ package sk.ainet.compile.hlo import sk.ainet.compile.hlo.converters.ActivationOperationsConverter import sk.ainet.compile.hlo.converters.ConstantOperationsConverter +import sk.ainet.compile.hlo.converters.GatherOperationsConverter import sk.ainet.compile.hlo.converters.LegacyOperationsConverter import sk.ainet.compile.hlo.converters.LinalgOperationsConverter import sk.ainet.compile.hlo.converters.MathOperationsConverter @@ -46,6 +47,10 @@ public object StableHloConverterFactory { // Register constant operations converter registry.register(ConstantOperationsConverter()) + // Register gather / embedding / index_select converter — the + // LLM front-door op for token-id → embedding lookups. + registry.register(GatherOperationsConverter()) + return StableHloConverter(registry, typeMapper, validator) } @@ -81,6 +86,10 @@ public object StableHloConverterFactory { // Register constant operations converter registry.register(ConstantOperationsConverter()) + // Register gather / embedding / index_select converter — the + // LLM front-door op for token-id → embedding lookups. 
+ registry.register(GatherOperationsConverter()) + return StableHloConverter(registry, typeMapper, validator) } diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/GatherOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/GatherOperationsConverter.kt new file mode 100644 index 00000000..98b8eff3 --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/GatherOperationsConverter.kt @@ -0,0 +1,153 @@ +package sk.ainet.compile.hlo.converters + +import sk.ainet.compile.hlo.ConversionContext +import sk.ainet.compile.hlo.ConversionResult +import sk.ainet.compile.hlo.StableHloOperationConverter +import sk.ainet.lang.graph.GraphNode + +/** + * Converter for memory-access / indexing operations. + * + * Today that's just `gather` and its framework aliases — the + * critical path for LLM exports, where every transformer forward + * pass begins with a token-id → embedding lookup. Without a + * converter for `gather` / `embedding` / `index_select`, a traced + * Llama / Mistral / Qwen / Gemma model fails at the very first + * operation and never reaches the norms, activations, or attention + * that the other P1 converters cover. + * + * The target lowering is the canonical `embedding(input_ids)` + * shape: a 1-D index tensor indexing the leading dimension of a + * 2-D embedding weight. Higher-rank gathers (attention-side index + * gathers, multi-dim scatter/gather) can be added in follow-up PRs + * once a traced model surfaces them; scoping this converter to the + * LLM front-door case keeps review tight. 
+ * + * Emitted shape: + * + * ```mlir + * %out = stablehlo.gather(%weights, %indices) + * { dimension_numbers = #stablehlo.gather< + * offset_dims = [1], + * collapsed_slice_dims = [0], + * start_index_map = [0], + * index_vector_dim = 1>, + * slice_sizes = array, + * indices_are_sorted = false } + * : (tensor, tensor) + * -> tensor + * ``` + * + * The `slice_sizes` vector is derived from the weight shape: a 1 + * along the gathered axis and the full extent of every other + * dimension. `offset_dims`, `collapsed_slice_dims`, and + * `start_index_map` are computed from the single gather axis. + */ +public class GatherOperationsConverter : StableHloOperationConverter { + + override val supportedOperations: Set = setOf( + "gather", "embedding", "Embedding", "index_select" + ) + + override fun convert( + node: GraphNode, + operands: List, + context: ConversionContext + ): ConversionResult { + return when (node.operation.name.lowercase()) { + "gather", "embedding", "index_select" -> convertGather(node, operands, context) + else -> ConversionResult.Unsupported( + node.operation.name, + "Operation not supported by GatherOperationsConverter" + ) + } + } + + private fun convertGather( + node: GraphNode, + operands: List, + context: ConversionContext + ): ConversionResult { + if (operands.size < 2) { + return ConversionResult.Failure( + "Gather operation requires 2 operands (weights, indices), got ${operands.size}", + "Unsupported gather arity for node ${node.id}" + ) + } + + val weightSpec = node.inputs.getOrNull(0) + val indicesSpec = node.inputs.getOrNull(1) + val outputSpec = node.outputs.firstOrNull() + + val typeMapper = context.getTypeMapper() + val weightType = weightSpec?.let { typeMapper.mapTensorType(it) } ?: "tensor" + val indicesType = indicesSpec?.let { typeMapper.mapTensorType(it) } ?: "tensor" + val outputType = outputSpec?.let { typeMapper.mapTensorType(it) } ?: "tensor" + + val weightShape = weightSpec?.shape ?: emptyList() + val weightRank = 
weightShape.size + val indicesRank = indicesSpec?.shape?.size ?: 1 + + // Gather axis. Default to 0 (the conventional embedding-lookup + // shape) and normalize negative axes against the weight rank. + val rawAxis = node.operation.parameters["axis"] as? Int + ?: node.operation.parameters["dim"] as? Int + ?: 0 + val axis = when { + weightRank == 0 -> 0 + rawAxis < 0 -> weightRank + rawAxis + else -> rawAxis + }.coerceIn(0, (weightRank - 1).coerceAtLeast(0)) + + // offset_dims: the axes of the output that carry "the rest of + // the row" — every weight axis except the gathered one, offset + // by the indices rank (which sits at the beginning of the + // output shape for a canonical gather). + val offsetDims = (0 until weightRank) + .filter { it != axis } + .mapIndexed { i, _ -> indicesRank + i } + .joinToString(", ") + + // collapsed_slice_dims: the axes of the weight that are + // "picked" by the indices — just the gathered axis for this + // single-axis case. + val collapsedSliceDims = "$axis" + + // start_index_map: index `i` in the indices tensor maps to + // start coordinate along the weight's gathered axis. + val startIndexMap = "$axis" + + // index_vector_dim: the axis of the indices tensor that holds + // the multi-dim coordinate. For a 1-D index tensor indexing a + // single axis, this is the rank (i.e. one past the last dim), + // following StableHLO convention that a trailing scalar + // "implicit index vector" is allowed. + val indexVectorDim = indicesRank + + // slice_sizes: a 1 along the gathered axis, the full extent + // along every other axis. 
+ val sliceSizes = weightShape.mapIndexed { i, extent -> + if (i == axis) 1 else extent + }.joinToString(", ") + + val weightOperand = operands[0] + val indicesOperand = operands[1] + val resultValue = context.nextTempValue() + val gatherOp = "$resultValue = stablehlo.gather($weightOperand, $indicesOperand) " + + "{ dimension_numbers = #stablehlo.gather<" + + "offset_dims = [$offsetDims], " + + "collapsed_slice_dims = [$collapsedSliceDims], " + + "start_index_map = [$startIndexMap], " + + "index_vector_dim = $indexVectorDim>, " + + "slice_sizes = array<i64: $sliceSizes>, " + + "indices_are_sorted = false } " + + ": ($weightType, $indicesType) -> $outputType" + + context.emitOperation(gatherOp) + + return ConversionResult.Success( + outputValueName = resultValue, + emittedOperations = listOf(gatherOp) + ) + } +} diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/GatherConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/GatherConverterTest.kt new file mode 100644 index 00000000..c875ec63 --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/GatherConverterTest.kt @@ -0,0 +1,165 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.ops.Operation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.tensor.ops.ValidationResult +import sk.ainet.lang.types.DType +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * Covers the gather / embedding converter for #483. Every LLM export + * begins with a token-id → embedding + * lookup and the StableHLO + * emitter had no converter for `gather` / `embedding` today — a + * traced Llama / Mistral / Qwen / Gemma forward pass therefore failed + * at the very first operation. 
+ * + * Target is the canonical `embedding(input_ids)` shape: 1-D index + * tensor indexing the leading dimension of a 2-D embedding weight. + * The lowering follows the StableHLO gather custom assembly that + * downstream MLIR tools (IREE in particular) expect. + */ +class GatherConverterTest { + + @Test + fun gather_and_embedding_aliases_are_supported() { + val module = buildEmbeddingModule(opName = "gather") + assertTrue(module.content.contains("stablehlo.gather")) + assertFalse( + module.content.contains("Unsupported operation gather"), + "`gather` must be claimed by a converter, not dropped as unsupported" + ) + assertFalse( + module.content.contains("No converter found"), + "`gather` must be claimed by a converter, not left without a handler" + ) + } + + @Test + fun embedding_alias_routes_to_same_lowering() { + val module = buildEmbeddingModule(opName = "embedding") + assertTrue(module.content.contains("stablehlo.gather")) + assertFalse(module.content.contains("Unsupported operation")) + } + + @Test + fun index_select_alias_routes_to_same_lowering() { + val module = buildEmbeddingModule(opName = "index_select") + assertTrue(module.content.contains("stablehlo.gather")) + assertFalse(module.content.contains("Unsupported operation")) + } + + @Test + fun embedding_lowering_carries_canonical_dim_numbers_and_slice_sizes() { + val module = buildEmbeddingModule(opName = "embedding") + println("[DEBUG_LOG] gather/embedding export:\n${module.content}") + + // The emitted op must carry the dim_numbers / slice_sizes + // custom assembly that downstream MLIR tools expect for a + // 1-D index tensor gathering rows from a 2-D weight. 
+ assertTrue( + module.content.contains("dimension_numbers"), + "gather must emit a dimension_numbers attribute" + ) + assertTrue( + module.content.contains("offset_dims = [1]"), + "gather must declare offset_dims = [1] for an axis-0 row gather on a 2-D weight" + ) + assertTrue( + module.content.contains("collapsed_slice_dims = [0]"), + "gather must declare collapsed_slice_dims = [0] for the gathered axis" + ) + assertTrue( + module.content.contains("start_index_map = [0]"), + "gather must declare start_index_map = [0]" + ) + assertTrue( + module.content.contains("slice_sizes = array"), + "gather must declare slice_sizes = [1, hidden_size=4] matching the weight row shape" + ) + + // Tight regression check: the gather operands must be the + // actual SSA value names, not a bracketed list expression. + // (Earlier draft accidentally emitted + // `stablehlo.gather([%arg0, %arg1][0], [%arg0, %arg1][1])` + // because of a `$operands[0]` Kotlin string-template pitfall.) + assertTrue( + module.content.contains("stablehlo.gather(%arg0, %arg1)"), + "gather must reference operands as bare SSA values, not `[%arg0, %arg1][0]`" + ) + assertFalse( + module.content.contains("stablehlo.gather([%"), + "gather must not emit operand lists as Kotlin-string `[..., ...][0]` junk" + ) + } + + private fun buildEmbeddingModule(opName: String): StableHloModule { + val graph = DefaultComputeGraph() + + val vocabSize = 8 + val hiddenSize = 4 + val seqLen = 3 + + val weightNode = GraphNode( + id = "W", + operation = markerInputOp(), + inputs = emptyList(), + outputs = listOf(TensorSpec("W", listOf(vocabSize, hiddenSize), "FP32")) + ) + val indicesNode = GraphNode( + id = "ids", + operation = markerInputOp(), + inputs = emptyList(), + outputs = listOf(TensorSpec("ids", listOf(seqLen), "INT32")) + ) + val gatherNode = GraphNode( + id = "embed1", + operation = gatherOp(opName, axis = 0), + inputs = listOf( + TensorSpec("W", listOf(vocabSize, hiddenSize), "FP32"), + TensorSpec("ids", 
listOf(seqLen), "INT32") + ), + outputs = listOf(TensorSpec("y", listOf(seqLen, hiddenSize), "FP32")) + ) + + graph.addNode(weightNode) + graph.addNode(indicesNode) + graph.addNode(gatherNode) + graph.addEdge(GraphEdge("e1", weightNode, gatherNode, 0, 0, weightNode.outputs[0])) + graph.addEdge(GraphEdge("e2", indicesNode, gatherNode, 0, 1, indicesNode.outputs[0])) + + val converter = StableHloConverterFactory.createExtended() + return converter.convert(graph, "test_$opName") + } + + private fun markerInputOp(): Operation = object : Operation { + override val name: String = "input" + override val type: String = "input" + override val parameters: Map = emptyMap() + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test fixture only") + override fun validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = emptyList() + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf("name" to name, "type" to type) + } + + private fun gatherOp(name: String, axis: Int): Operation = object : Operation { + override val name: String = name + override val type: String = "indexing" + override val parameters: Map = mapOf("axis" to axis) + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test fixture only") + override fun validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = inputs.take(1) + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf( + "name" to name, "type" to type, "parameters" to parameters + ) + } +}