Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package sk.ainet.lang.tensor.ops

/**
* Single source of truth for convolution output-shape arithmetic.
*
* Both eager execution (`VoidTensorOps`) and graph emission
* (`Conv{1,2,3}dOperation.inferOutputs`) must produce identical shapes;
* keeping the formula here prevents the two paths from diverging.
*/
public object ConvShapeUtils {

/**
* Conv1d output shape: input `(batch, in_channels, length)`,
* weight `(out_channels, in_channels_per_group, kernel_length)`,
* result `(batch, out_channels, out_length)`.
*/
public fun conv1dOutputShape(
inputShape: IntArray,
weightShape: IntArray,
stride: Int,
padding: Int,
dilation: Int
): IntArray {
require(inputShape.size == 3) {
"Conv1d input must be rank 3 (batch, channels, length), got rank ${inputShape.size}"
}
require(weightShape.size == 3) {
"Conv1d weight must be rank 3 (out_channels, in_channels, kernel_length), got rank ${weightShape.size}"
}
val batch = inputShape[0]
val outChannels = weightShape[0]
val inLength = inputShape[2]
val kernel = weightShape[2]
val outLength = (inLength + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
return intArrayOf(batch, outChannels, outLength)
}

/**
* Conv2d output shape: input `(batch, in_channels, height, width)`,
* weight `(out_channels, in_channels_per_group, kernel_h, kernel_w)`,
* result `(batch, out_channels, out_h, out_w)`.
*/
public fun conv2dOutputShape(
inputShape: IntArray,
weightShape: IntArray,
stride: Pair<Int, Int>,
padding: Pair<Int, Int>,
dilation: Pair<Int, Int>
): IntArray {
require(inputShape.size == 4) {
"Conv2d input must be rank 4 (batch, channels, height, width), got rank ${inputShape.size}"
}
require(weightShape.size == 4) {
"Conv2d weight must be rank 4 (out_channels, in_channels, kernel_h, kernel_w), got rank ${weightShape.size}"
}
val batch = inputShape[0]
val outChannels = weightShape[0]
val inH = inputShape[2]
val inW = inputShape[3]
val kH = weightShape[2]
val kW = weightShape[3]
val outH = (inH + 2 * padding.first - dilation.first * (kH - 1) - 1) / stride.first + 1
val outW = (inW + 2 * padding.second - dilation.second * (kW - 1) - 1) / stride.second + 1
return intArrayOf(batch, outChannels, outH, outW)
}

/**
* Conv3d output shape: input `(batch, in_channels, depth, height, width)`,
* weight `(out_channels, in_channels_per_group, kernel_d, kernel_h, kernel_w)`,
* result `(batch, out_channels, out_d, out_h, out_w)`.
*/
public fun conv3dOutputShape(
inputShape: IntArray,
weightShape: IntArray,
stride: Triple<Int, Int, Int>,
padding: Triple<Int, Int, Int>,
dilation: Triple<Int, Int, Int>
): IntArray {
require(inputShape.size == 5) {
"Conv3d input must be rank 5 (batch, channels, depth, height, width), got rank ${inputShape.size}"
}
require(weightShape.size == 5) {
"Conv3d weight must be rank 5 (out_channels, in_channels, kernel_d, kernel_h, kernel_w), got rank ${weightShape.size}"
}
val batch = inputShape[0]
val outChannels = weightShape[0]
val inD = inputShape[2]
val inH = inputShape[3]
val inW = inputShape[4]
val kD = weightShape[2]
val kH = weightShape[3]
val kW = weightShape[4]
val outD = (inD + 2 * padding.first - dilation.first * (kD - 1) - 1) / stride.first + 1
val outH = (inH + 2 * padding.second - dilation.second * (kH - 1) - 1) / stride.second + 1
val outW = (inW + 2 * padding.third - dilation.third * (kW - 1) - 1) / stride.third + 1
return intArrayOf(batch, outChannels, outD, outH, outW)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,18 @@ public class Conv1dOperation<T : DType, V>(

override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
require(inputs.size >= 2) { "Conv1d operation requires at least 2 inputs" }
val outputShape = inputs[0].shape
val inShape = inputs[0].shape
val wShape = inputs[1].shape
val stride = (parameters["stride"] as? Int) ?: 1
val padding = (parameters["padding"] as? Int) ?: 0
val dilation = (parameters["dilation"] as? Int) ?: 1
val outputShape = if (inShape != null && wShape != null && inShape.size == 3 && wShape.size == 3) {
ConvShapeUtils.conv1dOutputShape(
inShape.toIntArray(), wShape.toIntArray(), stride, padding, dilation
).toList()
} else {
null
}
return listOf(
TensorSpec(
name = "conv1d_output",
Expand Down Expand Up @@ -470,7 +481,18 @@ public class Conv2dOperation<T : DType, V>(

override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
require(inputs.size >= 2) { "Conv2d operation requires at least 2 inputs" }
val outputShape = inputs[0].shape
val inShape = inputs[0].shape
val wShape = inputs[1].shape
val stride = pairParam("stride", 1)
val padding = pairParam("padding", 0)
val dilation = pairParam("dilation", 1)
val outputShape = if (inShape != null && wShape != null && inShape.size == 4 && wShape.size == 4) {
ConvShapeUtils.conv2dOutputShape(
inShape.toIntArray(), wShape.toIntArray(), stride, padding, dilation
).toList()
} else {
null
}
return listOf(
TensorSpec(
name = "conv2d_output",
Expand All @@ -481,6 +503,16 @@ public class Conv2dOperation<T : DType, V>(
)
}

private fun pairParam(name: String, default: Int): Pair<Int, Int> {
val raw = parameters[name] ?: return default to default
@Suppress("UNCHECKED_CAST")
return when (raw) {
is Pair<*, *> -> raw as Pair<Int, Int>
is Int -> raw to raw
else -> default to default
}
}

override fun clone(newParameters: Map<String, Any>): Operation = Conv2dOperation<T, V>(newParameters)
}

Expand All @@ -502,7 +534,18 @@ public class Conv3dOperation<T : DType, V>(

override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
require(inputs.size >= 2) { "Conv3d operation requires at least 2 inputs" }
val outputShape = inputs[0].shape
val inShape = inputs[0].shape
val wShape = inputs[1].shape
val stride = tripleParam("stride", 1)
val padding = tripleParam("padding", 0)
val dilation = tripleParam("dilation", 1)
val outputShape = if (inShape != null && wShape != null && inShape.size == 5 && wShape.size == 5) {
ConvShapeUtils.conv3dOutputShape(
inShape.toIntArray(), wShape.toIntArray(), stride, padding, dilation
).toList()
} else {
null
}
return listOf(
TensorSpec(
name = "conv3d_output",
Expand All @@ -513,6 +556,16 @@ public class Conv3dOperation<T : DType, V>(
)
}

private fun tripleParam(name: String, default: Int): Triple<Int, Int, Int> {
val raw = parameters[name] ?: return Triple(default, default, default)
@Suppress("UNCHECKED_CAST")
return when (raw) {
is Triple<*, *, *> -> raw as Triple<Int, Int, Int>
is Int -> Triple(raw, raw, raw)
else -> Triple(default, default, default)
}
}

override fun clone(newParameters: Map<String, Any>): Operation = Conv3dOperation<T, V>(newParameters)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -750,23 +750,11 @@ public class VoidTensorOps : TensorOps {
stride: Int,
padding: Int,
dilation: Int
): Shape {
if (inputShape.rank != 3) {
throw IllegalArgumentException("Conv1d input must be 3D tensor (batch, channels, length)")
}
if (weightShape.rank != 3) {
throw IllegalArgumentException("Conv1d weight must be 3D tensor (out_channels, in_channels, kernel_length)")
}

val batch = inputShape.dimensions[0]
val outChannels = weightShape.dimensions[0]
val inputLength = inputShape.dimensions[2]
val kernelLength = weightShape.dimensions[2]

val outputLength = ((inputLength + 2 * padding - dilation * (kernelLength - 1) - 1) / stride) + 1

return Shape(batch, outChannels, outputLength)
}
): Shape = Shape(
ConvShapeUtils.conv1dOutputShape(
inputShape.dimensions, weightShape.dimensions, stride, padding, dilation
)
)

/**
* Calculates the result shape for conv3d operation.
Expand All @@ -780,33 +768,11 @@ public class VoidTensorOps : TensorOps {
stride: Triple<Int, Int, Int>,
padding: Triple<Int, Int, Int>,
dilation: Triple<Int, Int, Int>
): Shape {
if (inputShape.rank != 5) {
throw IllegalArgumentException("Conv3d input must be 5D tensor (batch, channels, depth, height, width)")
}
if (weightShape.rank != 5) {
throw IllegalArgumentException("Conv3d weight must be 5D tensor (out_channels, in_channels, kernel_d, kernel_h, kernel_w)")
}

val batch = inputShape.dimensions[0]
val outChannels = weightShape.dimensions[0]
val inputDepth = inputShape.dimensions[2]
val inputHeight = inputShape.dimensions[3]
val inputWidth = inputShape.dimensions[4]
val kernelDepth = weightShape.dimensions[2]
val kernelHeight = weightShape.dimensions[3]
val kernelWidth = weightShape.dimensions[4]

val (strideD, strideH, strideW) = stride
val (padD, padH, padW) = padding
val (dilationD, dilationH, dilationW) = dilation

val outputDepth = ((inputDepth + 2 * padD - dilationD * (kernelDepth - 1) - 1) / strideD) + 1
val outputHeight = ((inputHeight + 2 * padH - dilationH * (kernelHeight - 1) - 1) / strideH) + 1
val outputWidth = ((inputWidth + 2 * padW - dilationW * (kernelWidth - 1) - 1) / strideW) + 1

return Shape(batch, outChannels, outputDepth, outputHeight, outputWidth)
}
): Shape = Shape(
ConvShapeUtils.conv3dOutputShape(
inputShape.dimensions, weightShape.dimensions, stride, padding, dilation
)
)

/**
* Calculates the result shape for convTranspose1d operation.
Expand All @@ -832,35 +798,16 @@ public class VoidTensorOps : TensorOps {
* Output shape: (batch, out_channels, out_height, out_width)
*/
private fun calculateConv2dShape(
inputShape: Shape,
weightShape: Shape,
stride: Pair<Int, Int>,
padding: Pair<Int, Int>,
inputShape: Shape,
weightShape: Shape,
stride: Pair<Int, Int>,
padding: Pair<Int, Int>,
dilation: Pair<Int, Int>
): Shape {
if (inputShape.rank != 4) {
throw IllegalArgumentException("Conv2d input must be 4D tensor (batch, channels, height, width)")
}
if (weightShape.rank != 4) {
throw IllegalArgumentException("Conv2d weight must be 4D tensor (out_channels, in_channels, kernel_h, kernel_w)")
}

val batch = inputShape.dimensions[0]
val outChannels = weightShape.dimensions[0]
val inputHeight = inputShape.dimensions[2]
val inputWidth = inputShape.dimensions[3]
val kernelHeight = weightShape.dimensions[2]
val kernelWidth = weightShape.dimensions[3]

val (strideH, strideW) = stride
val (padH, padW) = padding
val (dilationH, dilationW) = dilation

val outputHeight = ((inputHeight + 2 * padH - dilationH * (kernelHeight - 1) - 1) / strideH) + 1
val outputWidth = ((inputWidth + 2 * padW - dilationW * (kernelWidth - 1) - 1) / strideW) + 1

return Shape(batch, outChannels, outputHeight, outputWidth)
}
): Shape = Shape(
ConvShapeUtils.conv2dOutputShape(
inputShape.dimensions, weightShape.dimensions, stride, padding, dilation
)
)

/**
* Calculates the result shape for maxPool2d operation.
Expand Down
Loading
Loading