From 9a12185f53d1ba6870e9ed0cb9bfe54f86a9a799 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Sun, 19 Apr 2026 21:32:22 +0200 Subject: [PATCH] Compute real conv output shapes in graph operations (#536) Conv{1,2,3}dOperation.inferOutputs previously echoed inputs[0].shape, ignoring the weight shape and stride/padding/dilation parameters. This left every stablehlo.convolution result as tensor, blocking iree-compile on the Whisper encoder. Extract the shape math into a public ConvShapeUtils object so the eager (VoidTensorOps) and graph-emission paths share one source of truth, and rewrite the three inferOutputs methods to use it. Conv2d and Conv3d accept either Pair/Triple (as written by RecordingExecution) or scalar Int parameters. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../ainet/lang/tensor/ops/ConvShapeUtils.kt | 98 ++++++++++++ .../ainet/lang/tensor/ops/TensorOperations.kt | 59 +++++++- .../sk/ainet/lang/tensor/ops/VoidTensorOps.kt | 91 +++-------- .../ops/ConvOperationInferOutputsTest.kt | 142 ++++++++++++++++++ 4 files changed, 315 insertions(+), 75 deletions(-) create mode 100644 skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/ConvShapeUtils.kt create mode 100644 skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/ConvOperationInferOutputsTest.kt diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/ConvShapeUtils.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/ConvShapeUtils.kt new file mode 100644 index 00000000..2b5b5cdd --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/ConvShapeUtils.kt @@ -0,0 +1,98 @@ +package sk.ainet.lang.tensor.ops + +/** + * Single source of truth for convolution output-shape arithmetic. 
+ * + * Both eager execution (`VoidTensorOps`) and graph emission + * (`Conv{1,2,3}dOperation.inferOutputs`) must produce identical shapes; + * keeping the formula here prevents the two paths from diverging. + */ +public object ConvShapeUtils { + + /** + * Conv1d output shape: input `(batch, in_channels, length)`, + * weight `(out_channels, in_channels_per_group, kernel_length)`, + * result `(batch, out_channels, out_length)`. + */ + public fun conv1dOutputShape( + inputShape: IntArray, + weightShape: IntArray, + stride: Int, + padding: Int, + dilation: Int + ): IntArray { + require(inputShape.size == 3) { + "Conv1d input must be rank 3 (batch, channels, length), got rank ${inputShape.size}" + } + require(weightShape.size == 3) { + "Conv1d weight must be rank 3 (out_channels, in_channels, kernel_length), got rank ${weightShape.size}" + } + val batch = inputShape[0] + val outChannels = weightShape[0] + val inLength = inputShape[2] + val kernel = weightShape[2] + val outLength = (inLength + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1 + return intArrayOf(batch, outChannels, outLength) + } + + /** + * Conv2d output shape: input `(batch, in_channels, height, width)`, + * weight `(out_channels, in_channels_per_group, kernel_h, kernel_w)`, + * result `(batch, out_channels, out_h, out_w)`. 
+ */ + public fun conv2dOutputShape( + inputShape: IntArray, + weightShape: IntArray, + stride: Pair<Int, Int>, + padding: Pair<Int, Int>, + dilation: Pair<Int, Int> + ): IntArray { + require(inputShape.size == 4) { + "Conv2d input must be rank 4 (batch, channels, height, width), got rank ${inputShape.size}" + } + require(weightShape.size == 4) { + "Conv2d weight must be rank 4 (out_channels, in_channels, kernel_h, kernel_w), got rank ${weightShape.size}" + } + val batch = inputShape[0] + val outChannels = weightShape[0] + val inH = inputShape[2] + val inW = inputShape[3] + val kH = weightShape[2] + val kW = weightShape[3] + val outH = (inH + 2 * padding.first - dilation.first * (kH - 1) - 1) / stride.first + 1 + val outW = (inW + 2 * padding.second - dilation.second * (kW - 1) - 1) / stride.second + 1 + return intArrayOf(batch, outChannels, outH, outW) + } + + /** + * Conv3d output shape: input `(batch, in_channels, depth, height, width)`, + * weight `(out_channels, in_channels_per_group, kernel_d, kernel_h, kernel_w)`, + * result `(batch, out_channels, out_d, out_h, out_w)`. 
+ */ + public fun conv3dOutputShape( + inputShape: IntArray, + weightShape: IntArray, + stride: Triple<Int, Int, Int>, + padding: Triple<Int, Int, Int>, + dilation: Triple<Int, Int, Int> + ): IntArray { + require(inputShape.size == 5) { + "Conv3d input must be rank 5 (batch, channels, depth, height, width), got rank ${inputShape.size}" + } + require(weightShape.size == 5) { + "Conv3d weight must be rank 5 (out_channels, in_channels, kernel_d, kernel_h, kernel_w), got rank ${weightShape.size}" + } + val batch = inputShape[0] + val outChannels = weightShape[0] + val inD = inputShape[2] + val inH = inputShape[3] + val inW = inputShape[4] + val kD = weightShape[2] + val kH = weightShape[3] + val kW = weightShape[4] + val outD = (inD + 2 * padding.first - dilation.first * (kD - 1) - 1) / stride.first + 1 + val outH = (inH + 2 * padding.second - dilation.second * (kH - 1) - 1) / stride.second + 1 + val outW = (inW + 2 * padding.third - dilation.third * (kW - 1) - 1) / stride.third + 1 + return intArrayOf(batch, outChannels, outD, outH, outW) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt index 2057d5e7..5c39b8c5 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt @@ -438,7 +438,18 @@ public class Conv1dOperation( override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> { require(inputs.size >= 2) { "Conv1d operation requires at least 2 inputs" } - val outputShape = inputs[0].shape + val inShape = inputs[0].shape + val wShape = inputs[1].shape + val stride = (parameters["stride"] as? Int) ?: 1 + val padding = (parameters["padding"] as? Int) ?: 0 + val dilation = (parameters["dilation"] as? 
Int) ?: 1 + val outputShape = if (inShape != null && wShape != null && inShape.size == 3 && wShape.size == 3) { + ConvShapeUtils.conv1dOutputShape( + inShape.toIntArray(), wShape.toIntArray(), stride, padding, dilation + ).toList() + } else { + null + } return listOf( TensorSpec( name = "conv1d_output", @@ -470,7 +481,18 @@ public class Conv2dOperation( override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> { require(inputs.size >= 2) { "Conv2d operation requires at least 2 inputs" } - val outputShape = inputs[0].shape + val inShape = inputs[0].shape + val wShape = inputs[1].shape + val stride = pairParam("stride", 1) + val padding = pairParam("padding", 0) + val dilation = pairParam("dilation", 1) + val outputShape = if (inShape != null && wShape != null && inShape.size == 4 && wShape.size == 4) { + ConvShapeUtils.conv2dOutputShape( + inShape.toIntArray(), wShape.toIntArray(), stride, padding, dilation + ).toList() + } else { + null + } return listOf( TensorSpec( name = "conv2d_output", @@ -481,6 +503,16 @@ ) } + private fun pairParam(name: String, default: Int): Pair<Int, Int> { + val raw = parameters[name] ?: return default to default + @Suppress("UNCHECKED_CAST") + return when (raw) { + is Pair<*, *> -> raw as Pair<Int, Int> + is Int -> raw to raw + else -> default to default + } + } + override fun clone(newParameters: Map<String, Any>): Operation = Conv2dOperation(newParameters) } @@ -502,7 +534,18 @@ public class Conv3dOperation( override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> { require(inputs.size >= 2) { "Conv3d operation requires at least 2 inputs" } - val outputShape = inputs[0].shape + val inShape = inputs[0].shape + val wShape = inputs[1].shape + val stride = tripleParam("stride", 1) + val padding = tripleParam("padding", 0) + val dilation = tripleParam("dilation", 1) + val outputShape = if (inShape != null && wShape != null && inShape.size == 5 && wShape.size == 5) { + ConvShapeUtils.conv3dOutputShape( + inShape.toIntArray(), wShape.toIntArray(), stride, padding, 
dilation + ).toList() + } else { + null + } return listOf( TensorSpec( name = "conv3d_output", @@ -513,6 +556,16 @@ ) } + private fun tripleParam(name: String, default: Int): Triple<Int, Int, Int> { + val raw = parameters[name] ?: return Triple(default, default, default) + @Suppress("UNCHECKED_CAST") + return when (raw) { + is Triple<*, *, *> -> raw as Triple<Int, Int, Int> + is Int -> Triple(raw, raw, raw) + else -> Triple(default, default, default) + } + } + override fun clone(newParameters: Map<String, Any>): Operation = Conv3dOperation(newParameters) } diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt index eef1859c..8530d449 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt @@ -750,23 +750,11 @@ public class VoidTensorOps : TensorOps { stride: Int, padding: Int, dilation: Int - ): Shape { - if (inputShape.rank != 3) { - throw IllegalArgumentException("Conv1d input must be 3D tensor (batch, channels, length)") - } - if (weightShape.rank != 3) { - throw IllegalArgumentException("Conv1d weight must be 3D tensor (out_channels, in_channels, kernel_length)") - } - - val batch = inputShape.dimensions[0] - val outChannels = weightShape.dimensions[0] - val inputLength = inputShape.dimensions[2] - val kernelLength = weightShape.dimensions[2] - - val outputLength = ((inputLength + 2 * padding - dilation * (kernelLength - 1) - 1) / stride) + 1 - - return Shape(batch, outChannels, outputLength) - } + ): Shape = Shape( + ConvShapeUtils.conv1dOutputShape( + inputShape.dimensions, weightShape.dimensions, stride, padding, dilation + ) + ) /** * Calculates the result shape for conv3d operation. 
@@ -780,33 +768,11 @@ public class VoidTensorOps : TensorOps { stride: Triple<Int, Int, Int>, padding: Triple<Int, Int, Int>, dilation: Triple<Int, Int, Int> - ): Shape { - if (inputShape.rank != 5) { - throw IllegalArgumentException("Conv3d input must be 5D tensor (batch, channels, depth, height, width)") - } - if (weightShape.rank != 5) { - throw IllegalArgumentException("Conv3d weight must be 5D tensor (out_channels, in_channels, kernel_d, kernel_h, kernel_w)") - } - - val batch = inputShape.dimensions[0] - val outChannels = weightShape.dimensions[0] - val inputDepth = inputShape.dimensions[2] - val inputHeight = inputShape.dimensions[3] - val inputWidth = inputShape.dimensions[4] - val kernelDepth = weightShape.dimensions[2] - val kernelHeight = weightShape.dimensions[3] - val kernelWidth = weightShape.dimensions[4] - - val (strideD, strideH, strideW) = stride - val (padD, padH, padW) = padding - val (dilationD, dilationH, dilationW) = dilation - - val outputDepth = ((inputDepth + 2 * padD - dilationD * (kernelDepth - 1) - 1) / strideD) + 1 - val outputHeight = ((inputHeight + 2 * padH - dilationH * (kernelHeight - 1) - 1) / strideH) + 1 - val outputWidth = ((inputWidth + 2 * padW - dilationW * (kernelWidth - 1) - 1) / strideW) + 1 - - return Shape(batch, outChannels, outputDepth, outputHeight, outputWidth) - } + ): Shape = Shape( + ConvShapeUtils.conv3dOutputShape( + inputShape.dimensions, weightShape.dimensions, stride, padding, dilation + ) + ) /** * Calculates the result shape for convTranspose1d operation. 
@@ -832,35 +798,16 @@ public class VoidTensorOps : TensorOps { * Output shape: (batch, out_channels, out_height, out_width) */ private fun calculateConv2dShape( - inputShape: Shape, - weightShape: Shape, - stride: Pair<Int, Int>, - padding: Pair<Int, Int>, + inputShape: Shape, + weightShape: Shape, + stride: Pair<Int, Int>, + padding: Pair<Int, Int>, dilation: Pair<Int, Int> - ): Shape { - if (inputShape.rank != 4) { - throw IllegalArgumentException("Conv2d input must be 4D tensor (batch, channels, height, width)") - } - if (weightShape.rank != 4) { - throw IllegalArgumentException("Conv2d weight must be 4D tensor (out_channels, in_channels, kernel_h, kernel_w)") - } - - val batch = inputShape.dimensions[0] - val outChannels = weightShape.dimensions[0] - val inputHeight = inputShape.dimensions[2] - val inputWidth = inputShape.dimensions[3] - val kernelHeight = weightShape.dimensions[2] - val kernelWidth = weightShape.dimensions[3] - - val (strideH, strideW) = stride - val (padH, padW) = padding - val (dilationH, dilationW) = dilation - - val outputHeight = ((inputHeight + 2 * padH - dilationH * (kernelHeight - 1) - 1) / strideH) + 1 - val outputWidth = ((inputWidth + 2 * padW - dilationW * (kernelWidth - 1) - 1) / strideW) + 1 - - return Shape(batch, outChannels, outputHeight, outputWidth) - } + ): Shape = Shape( + ConvShapeUtils.conv2dOutputShape( + inputShape.dimensions, weightShape.dimensions, stride, padding, dilation + ) + ) /** * Calculates the result shape for maxPool2d operation. 
diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/ConvOperationInferOutputsTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/ConvOperationInferOutputsTest.kt new file mode 100644 index 00000000..9c6c6e22 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/ConvOperationInferOutputsTest.kt @@ -0,0 +1,142 @@ +package sk.ainet.lang.tensor.ops + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import sk.ainet.lang.types.FP32 + +class ConvOperationInferOutputsTest { + + private fun spec(name: String, shape: List<Int>?): TensorSpec = + TensorSpec(name = name, shape = shape, dtype = FP32::class.simpleName!!) + + // ----- Conv1d ----- + + @Test + fun conv1d_inferOutputs_uses_weight_shape_and_params() { + val op = Conv1dOperation( + mapOf("stride" to 1, "padding" to 1, "dilation" to 1, "groups" to 1) + ) + val out = op.inferOutputs( + listOf( + spec("input", listOf(1, 80, 3000)), + spec("weight", listOf(384, 80, 3)) + ) + ).single() + + assertEquals(listOf(1, 384, 3000), out.shape) + assertEquals("conv1d_output", out.name) + } + + @Test + fun conv1d_inferOutputs_default_params_when_missing() { + val op = Conv1dOperation() + val out = op.inferOutputs( + listOf( + spec("input", listOf(2, 3, 28)), + spec("weight", listOf(16, 3, 3)) + ) + ).single() + + assertEquals(listOf(2, 16, 26), out.shape) + } + + @Test + fun conv1d_inferOutputs_returns_null_shape_when_input_unknown() { + val op = Conv1dOperation(mapOf("stride" to 1, "padding" to 0)) + val out = op.inferOutputs( + listOf( + spec("input", null), + spec("weight", listOf(16, 3, 3)) + ) + ).single() + + assertNull(out.shape) + } + + @Test + fun conv1d_inferOutputs_returns_null_shape_when_rank_mismatch() { + val op = Conv1dOperation() + val out = op.inferOutputs( + listOf( + spec("input", listOf(28)), + spec("weight", listOf(16, 3, 3)) + ) + ).single() + + 
assertNull(out.shape) + } + + // ----- Conv2d ----- + + @Test + fun conv2d_inferOutputs_uses_pair_params() { + val op = Conv2dOperation( + mapOf( + "stride" to (2 to 2), + "padding" to (1 to 1), + "dilation" to (1 to 1), + "groups" to 1 + ) + ) + val out = op.inferOutputs( + listOf( + spec("input", listOf(1, 3, 32, 32)), + spec("weight", listOf(16, 3, 3, 3)) + ) + ).single() + + // (32 + 2 - 2 - 1) / 2 + 1 = 16 + assertEquals(listOf(1, 16, 16, 16), out.shape) + } + + @Test + fun conv2d_inferOutputs_accepts_int_param_as_symmetric() { + val op = Conv2dOperation( + mapOf("stride" to 1, "padding" to 0, "dilation" to 1) + ) + val out = op.inferOutputs( + listOf( + spec("input", listOf(1, 3, 28, 28)), + spec("weight", listOf(16, 3, 3, 3)) + ) + ).single() + + assertEquals(listOf(1, 16, 26, 26), out.shape) + } + + // ----- Conv3d ----- + + @Test + fun conv3d_inferOutputs_uses_triple_params() { + val op = Conv3dOperation( + mapOf( + "stride" to Triple(1, 1, 1), + "padding" to Triple(0, 0, 0), + "dilation" to Triple(1, 1, 1), + "groups" to 1 + ) + ) + val out = op.inferOutputs( + listOf( + spec("input", listOf(1, 3, 16, 16, 16)), + spec("weight", listOf(8, 3, 3, 3, 3)) + ) + ).single() + + assertEquals(listOf(1, 8, 14, 14, 14), out.shape) + } + + @Test + fun conv3d_inferOutputs_default_params_when_missing() { + val op = Conv3dOperation() + val out = op.inferOutputs( + listOf( + spec("input", listOf(1, 3, 8, 8, 8)), + spec("weight", listOf(4, 3, 3, 3, 3)) + ) + ).single() + + assertEquals(listOf(1, 4, 6, 6, 6), out.shape) + } +}