Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,23 @@ public class LinalgOperationsConverter : StableHloOperationConverter {
}

val outputSpec = node.outputs.firstOrNull()
val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
?: "tensor<?x?xf32>"

val lhsSpec = node.inputs.getOrNull(0)
val rhsSpec = node.inputs.getOrNull(1)
val lhsType = lhsSpec?.let { context.getTypeMapper().mapTensorType(it) }
?: "tensor<?x?xf32>"
val rhsType = rhsSpec?.let { context.getTypeMapper().mapTensorType(it) }
?: "tensor<?x?xf32>"

val resultValue = context.nextTempValue()
// Standard matmul: contract last dimension of left operand with
// second-to-last dimension of right operand

// Standard matmul: contract last dimension of left operand with
// second-to-last dimension of right operand.
// For 2D: A[M,K] x B[K,N] -> C[M,N]
// contracting_dims = [[1], [0]] means:
// - dimension 1 (K) of left operand
// - dimension 0 (K) of right operand
val operation = "$resultValue = stablehlo.dot_general ${operands[0]}, ${operands[1]}, contracting_dims = [[1], [0]] : $outputType"

// contracting_dims = [1] x [0] means dim 1 of lhs with dim 0 of rhs.
val operation = "$resultValue = stablehlo.dot_general ${operands[0]}, ${operands[1]}, contracting_dims = [1] x [0] : ($lhsType, $rhsType) -> $outputType"

context.emitOperation(operation)

return ConversionResult.Success(
Expand Down Expand Up @@ -107,29 +111,42 @@ public class LinalgOperationsConverter : StableHloOperationConverter {
}

val outputSpec = node.outputs.firstOrNull()
val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
?: "tensor<?x?x?xf32>"

// Determine batch dimensions from the operation parameters or infer from shapes
val batchDims = node.operation.parameters["batch_dims"] as? List<*>
val batchDimsStr = if (batchDims != null && batchDims.isNotEmpty()) {
val dims = batchDims.joinToString(", ")
"[[${dims}], [${dims}]]"
val lhsSpec = node.inputs.getOrNull(0)
val rhsSpec = node.inputs.getOrNull(1)
val lhsType = lhsSpec?.let { context.getTypeMapper().mapTensorType(it) }
?: "tensor<?x?x?xf32>"
val rhsType = rhsSpec?.let { context.getTypeMapper().mapTensorType(it) }
?: "tensor<?x?x?xf32>"

// Infer batching rank: for A[..., M, K] x B[..., K, N], batching dims
// are all leading dims except the last two. Falls back to rank 3 if
// shape is unknown (matches prior hard-coded behavior).
val rank = lhsSpec?.shape?.size ?: rhsSpec?.shape?.size ?: 3
val batchCount = (rank - 2).coerceAtLeast(0)
val explicitBatch = node.operation.parameters["batch_dims"] as? List<*>
val batchDimsList = if (explicitBatch != null && explicitBatch.isNotEmpty()) {
explicitBatch.map { it.toString() }
} else {
// Default: assume first dimension is batch
"[[0], [0]]"
(0 until batchCount).map { it.toString() }
}

val contractingLhs = rank - 1
val contractingRhs = (rank - 2).coerceAtLeast(0)

val resultValue = context.nextTempValue()

// Batch matmul: preserve batch dimensions and contract matrix dimensions
// For 3D: A[B,M,K] x B[B,K,N] -> C[B,M,N]
// batch_dims = [[0], [0]] means batch dimension 0 is preserved
// contracting_dims = [[2], [1]] means:
// - dimension 2 (K) of left operand
// - dimension 1 (K) of right operand
val operation = "$resultValue = stablehlo.dot_general ${operands[0]}, ${operands[1]}, contracting_dims = [[2], [1]], batch_dims = $batchDimsStr : $outputType"


// Batch matmul: preserve batch dimensions and contract matrix dimensions.
// For A[..., M, K] x B[..., K, N]: batching_dims cover leading dims,
// contracting_dims = [rank-1] x [rank-2].
val batchClause = if (batchDimsList.isNotEmpty()) {
val b = batchDimsList.joinToString(", ")
"batching_dims = [$b] x [$b], "
} else {
""
}
val operation = "$resultValue = stablehlo.dot_general ${operands[0]}, ${operands[1]}, ${batchClause}contracting_dims = [$contractingLhs] x [$contractingRhs] : ($lhsType, $rhsType) -> $outputType"

context.emitOperation(operation)

return ConversionResult.Success(
Expand Down Expand Up @@ -162,27 +179,29 @@ public class LinalgOperationsConverter : StableHloOperationConverter {
}

val outputSpec = node.outputs.firstOrNull()
val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
?: "tensor<?x?xf32>"

val inputSpec = node.inputs.firstOrNull()
val inputType = inputSpec?.let { context.getTypeMapper().mapTensorType(it) }
?: "tensor<?x?xf32>"

// Get permutation from operation parameters
val permutation = node.operation.parameters["permutation"] as? List<*>
?: node.operation.parameters["perm"] as? List<*>
?: node.operation.parameters["axes"] as? List<*>

val permutationStr = if (permutation != null && permutation.isNotEmpty()) {
// Use provided permutation
permutation.joinToString(", ")
} else {
// Default: reverse all dimensions
// For 2D: [1, 0], for 3D: [2, 1, 0], etc.
val inputSpec = context.getInputNodes(node).firstOrNull()?.outputs?.firstOrNull()
val rank = inputSpec?.shape?.size ?: 2
(rank - 1 downTo 0).joinToString(", ")
}

val resultValue = context.nextTempValue()
val operation = "$resultValue = stablehlo.transpose ${operands[0]}, dims = [$permutationStr] : $outputType"
val operation = "$resultValue = stablehlo.transpose ${operands[0]}, dims = [$permutationStr] : ($inputType) -> $outputType"
context.emitOperation(operation)

return ConversionResult.Success(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class LinalgOperationsConverterTest {
val module = converter.convert(graph, "test_matmul")

assertTrue(module.content.contains("stablehlo.dot_general"))
assertTrue(module.content.contains("contracting_dims = [[1], [0]]"))
assertTrue(module.content.contains("tensor<3x5xf32>"))
assertTrue(module.content.contains("contracting_dims = [1] x [0]"))
assertTrue(module.content.contains("(tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>"))
assertEquals("test_matmul", module.functionName)
}

Expand Down Expand Up @@ -96,9 +96,9 @@ class LinalgOperationsConverterTest {
val module = converter.convert(graph, "test_batch_matmul")

assertTrue(module.content.contains("stablehlo.dot_general"))
assertTrue(module.content.contains("contracting_dims = [[2], [1]]"))
assertTrue(module.content.contains("batch_dims = [[0], [0]]"))
assertTrue(module.content.contains("tensor<4x3x5xf32>"))
assertTrue(module.content.contains("batching_dims = [0] x [0]"))
assertTrue(module.content.contains("contracting_dims = [2] x [1]"))
assertTrue(module.content.contains("(tensor<4x3x4xf32>, tensor<4x4x5xf32>) -> tensor<4x3x5xf32>"))
}

@Test
Expand All @@ -109,7 +109,7 @@ class LinalgOperationsConverterTest {

assertTrue(module.content.contains("stablehlo.transpose"))
assertTrue(module.content.contains("dims = [1, 0]"))
assertTrue(module.content.contains("tensor<3x2xf32>"))
assertTrue(module.content.contains("(tensor<2x3xf32>) -> tensor<3x2xf32>"))
}

@Test
Expand Down Expand Up @@ -140,7 +140,7 @@ class LinalgOperationsConverterTest {

assertTrue(module.content.contains("stablehlo.transpose"))
assertTrue(module.content.contains("dims = [0, 2, 1]"))
assertTrue(module.content.contains("tensor<2x4x3xf32>"))
assertTrue(module.content.contains("(tensor<2x3x4xf32>) -> tensor<2x4x3xf32>"))
}

@Test
Expand Down
Loading