From 753907643025b836b7e110bad63991d688491295 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Sat, 18 Apr 2026 09:36:47 +0200 Subject: [PATCH] Fix stablehlo.transpose and stablehlo.dot_general MLIR emission MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit iree-compile rejected MLIR emitted by LinalgOperationsConverter because the converter used syntax that does not match the StableHLO spec: - transpose emitted `: outputType` instead of `: (inputType) -> outputType`, which iree-compile flagged as "invalid kind of type specified". - dot_general emitted `contracting_dims = [[N], [M]]` (ArrayAttr-style) instead of `contracting_dims = [N] x [M]`, omitted batching_dims entirely for batched matmuls, and dropped the `(lhsType, rhsType) ->` input-type clause — all three are required by the StableHLO parser. LinalgOperationsConverter now threads input TensorSpecs through to the emission, derives batching_dims from lhs rank for bmm (supporting 2D, 3D, and 4D+ inputs instead of the previous 3D-only assumption), and writes the `(lhs, rhs) -> result` type tuple. Tests updated accordingly; full :skainet-compile-hlo:jvmTest passes. These are the first two of four items from a skainet-whisper IREE Vulkan bring-up on branch `feature/iree-vulkan-gpu` (SKaiNET 0.18.0). The remaining two are tracked separately. Related: #518, #519 Co-Authored-By: Claude Opus 4.7 (1M context) --- .../converters/LinalgOperationsConverter.kt | 89 +++++++++++-------- .../hlo/LinalgOperationsConverterTest.kt | 14 +-- 2 files changed, 61 insertions(+), 42 deletions(-) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt index cb58b51d..c3dd716a 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt @@ -65,19 +65,23 @@ public class LinalgOperationsConverter : StableHloOperationConverter { } val outputSpec = node.outputs.firstOrNull() - val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor" - + val lhsSpec = node.inputs.getOrNull(0) + val rhsSpec = node.inputs.getOrNull(1) + val lhsType = lhsSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + val rhsType = rhsSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + val resultValue = context.nextTempValue() - - // Standard matmul: contract last dimension of left operand with - // second-to-last dimension of right operand + + // Standard matmul: contract last dimension of left operand with + // second-to-last dimension of right operand. // For 2D: A[M,K] x B[K,N] -> C[M,N] - // contracting_dims = [[1], [0]] means: - // - dimension 1 (K) of left operand - // - dimension 0 (K) of right operand - val operation = "$resultValue = stablehlo.dot_general ${operands[0]}, ${operands[1]}, contracting_dims = [[1], [0]] : $outputType" - + // contracting_dims = [1] x [0] means dim 1 of lhs with dim 0 of rhs. + val operation = "$resultValue = stablehlo.dot_general ${operands[0]}, ${operands[1]}, contracting_dims = [1] x [0] : ($lhsType, $rhsType) -> $outputType" + context.emitOperation(operation) return ConversionResult.Success( @@ -107,29 +111,42 @@ public class LinalgOperationsConverter : StableHloOperationConverter { } val outputSpec = node.outputs.firstOrNull() - val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor" - - // Determine batch dimensions from the operation parameters or infer from shapes - val batchDims = node.operation.parameters["batch_dims"] as? List<*> - val batchDimsStr = if (batchDims != null && batchDims.isNotEmpty()) { - val dims = batchDims.joinToString(", ") - "[[${dims}], [${dims}]]" + val lhsSpec = node.inputs.getOrNull(0) + val rhsSpec = node.inputs.getOrNull(1) + val lhsType = lhsSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + val rhsType = rhsSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + + // Infer batching rank: for A[..., M, K] x B[..., K, N], batching dims + // are all leading dims except the last two. Falls back to rank 3 if + // shape is unknown (matches prior hard-coded behavior). + val rank = lhsSpec?.shape?.size ?: rhsSpec?.shape?.size ?: 3 + val batchCount = (rank - 2).coerceAtLeast(0) + val explicitBatch = node.operation.parameters["batch_dims"] as? List<*> + val batchDimsList = if (explicitBatch != null && explicitBatch.isNotEmpty()) { + explicitBatch.map { it.toString() } } else { - // Default: assume first dimension is batch - "[[0], [0]]" + (0 until batchCount).map { it.toString() } } - + val contractingLhs = rank - 1 + val contractingRhs = (rank - 2).coerceAtLeast(0) + val resultValue = context.nextTempValue() - - // Batch matmul: preserve batch dimensions and contract matrix dimensions - // For 3D: A[B,M,K] x B[B,K,N] -> C[B,M,N] - // batch_dims = [[0], [0]] means batch dimension 0 is preserved - // contracting_dims = [[2], [1]] means: - // - dimension 2 (K) of left operand - // - dimension 1 (K) of right operand - val operation = "$resultValue = stablehlo.dot_general ${operands[0]}, ${operands[1]}, contracting_dims = [[2], [1]], batch_dims = $batchDimsStr : $outputType" - + + // Batch matmul: preserve batch dimensions and contract matrix dimensions. + // For A[..., M, K] x B[..., K, N]: batching_dims cover leading dims, + // contracting_dims = [rank-1] x [rank-2]. + val batchClause = if (batchDimsList.isNotEmpty()) { + val b = batchDimsList.joinToString(", ") + "batching_dims = [$b] x [$b], " + } else { + "" + } + val operation = "$resultValue = stablehlo.dot_general ${operands[0]}, ${operands[1]}, ${batchClause}contracting_dims = [$contractingLhs] x [$contractingRhs] : ($lhsType, $rhsType) -> $outputType" + context.emitOperation(operation) return ConversionResult.Success( @@ -162,27 +179,29 @@ public class LinalgOperationsConverter : StableHloOperationConverter { } val outputSpec = node.outputs.firstOrNull() - val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor" - + val inputSpec = node.inputs.firstOrNull() + val inputType = inputSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + // Get permutation from operation parameters val permutation = node.operation.parameters["permutation"] as? List<*> ?: node.operation.parameters["perm"] as? List<*> ?: node.operation.parameters["axes"] as? List<*> - + val permutationStr = if (permutation != null && permutation.isNotEmpty()) { // Use provided permutation permutation.joinToString(", ") } else { // Default: reverse all dimensions // For 2D: [1, 0], for 3D: [2, 1, 0], etc. - val inputSpec = context.getInputNodes(node).firstOrNull()?.outputs?.firstOrNull() val rank = inputSpec?.shape?.size ?: 2 (rank - 1 downTo 0).joinToString(", ") } - + val resultValue = context.nextTempValue() - val operation = "$resultValue = stablehlo.transpose ${operands[0]}, dims = [$permutationStr] : $outputType" + val operation = "$resultValue = stablehlo.transpose ${operands[0]}, dims = [$permutationStr] : ($inputType) -> $outputType" context.emitOperation(operation) return ConversionResult.Success( diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/LinalgOperationsConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/LinalgOperationsConverterTest.kt index 003daeb5..783b013d 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/LinalgOperationsConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/LinalgOperationsConverterTest.kt @@ -47,8 +47,8 @@ class LinalgOperationsConverterTest { val module = converter.convert(graph, "test_matmul") assertTrue(module.content.contains("stablehlo.dot_general")) - assertTrue(module.content.contains("contracting_dims = [[1], [0]]")) - assertTrue(module.content.contains("tensor<3x5xf32>")) + assertTrue(module.content.contains("contracting_dims = [1] x [0]")) + assertTrue(module.content.contains("(tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>")) assertEquals("test_matmul", module.functionName) } @@ -96,9 +96,9 @@ class LinalgOperationsConverterTest { val module = converter.convert(graph, "test_batch_matmul") assertTrue(module.content.contains("stablehlo.dot_general")) - assertTrue(module.content.contains("contracting_dims = [[2], [1]]")) - assertTrue(module.content.contains("batch_dims = [[0], [0]]")) - assertTrue(module.content.contains("tensor<4x3x5xf32>")) + assertTrue(module.content.contains("batching_dims = [0] x [0]")) + assertTrue(module.content.contains("contracting_dims = [2] x [1]")) + assertTrue(module.content.contains("(tensor<4x3x4xf32>, tensor<4x4x5xf32>) -> tensor<4x3x5xf32>")) } @Test @@ -109,7 +109,7 @@ class LinalgOperationsConverterTest { assertTrue(module.content.contains("stablehlo.transpose")) assertTrue(module.content.contains("dims = [1, 0]")) - assertTrue(module.content.contains("tensor<3x2xf32>")) + assertTrue(module.content.contains("(tensor<2x3xf32>) -> tensor<3x2xf32>")) } @Test @@ -140,7 +140,7 @@ class LinalgOperationsConverterTest { assertTrue(module.content.contains("stablehlo.transpose")) assertTrue(module.content.contains("dims = [0, 2, 1]")) - assertTrue(module.content.contains("tensor<2x4x3xf32>")) + assertTrue(module.content.contains("(tensor<2x3x4xf32>) -> tensor<2x4x3xf32>")) } @Test