diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt index 94cbd4a5..0eedc3f9 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt @@ -17,6 +17,7 @@ import sk.ainet.lang.nn.normalization.LayerNormalization import sk.ainet.lang.nn.topology.MLP import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.ops.ConvShapeUtils import sk.ainet.lang.tensor.ops.UpsampleMode import sk.ainet.lang.types.DType import sk.ainet.context.ExecutionContext @@ -110,6 +111,22 @@ public interface NeuralNetworkDsl : NetworkDslItem { */ public fun input(inputSize: Int, id: String = "", requiresGrad: Boolean = false) + /** + * Declares a multi-dimensional input shape (per-sample, batch dimension excluded). + * + * Required when downstream spatial layers (conv, pool, upsample) need to feed + * a `flatten()` -> `dense()` chain — without this, the DSL cannot know the + * flattened feature count ahead of time. + * + * Example: `input(intArrayOf(1, 28, 28))` declares one-channel 28x28 images. + * + * @param inputShape Per-sample shape (e.g. `[channels, height, width]`) + * @param id Optional identifier for the layer + * @param requiresGrad Whether the input requires gradients (default: false) + */ + public fun input(inputShape: IntArray, id: String = "", requiresGrad: Boolean = false): Unit = + input(inputShape.fold(1) { a, b -> a * b }, id, requiresGrad) + /** * Creates a flatten layer that reshapes multidimensional tensors into 1D. * Useful for transitioning from convolutional to dense layers. @@ -1042,6 +1059,108 @@ public class AvgPool2dImpl( // AttentionImpl moved to skainet-transformers llm-core TransformerDsl.kt // Stage implementation +/** + * Computes the flattened-feature dimension for `flatten(startDim, endDim)` given a + * known per-sample shape (batch dimension excluded). + * + * Tensor shape is conceptually `[batch] + currentShape`. `startDim`/`endDim` follow + * PyTorch semantics on that full tensor shape and may be negative. Returns `null` + * when the span doesn't include any per-sample dimension (e.g. flatten only over + * the batch axis), which leaves the dense feature dimension unchanged. + */ +/** + * Per-sample shape produced by a `MaxPool2d`/`AvgPool2d` given the layer's kernel/stride/padding. + * Returns `null` when the input shape is unknown or not rank-3 `(C, H, W)`. + */ +internal fun pool2dNextShape( + currentShape: IntArray?, + kernelSize: Pair, + stride: Pair, + padding: Pair +): IntArray? { + val shape = currentShape ?: return null + if (shape.size != 3) return null + val out = ConvShapeUtils.pool2dOutputShape( + intArrayOf(1, shape[0], shape[1], shape[2]), + kernelSize, stride, padding + ) + return intArrayOf(out[1], out[2], out[3]) +} + +/** + * Per-sample shape produced by `Upsample2d` given the scale factors. + * Returns `null` when the input shape is unknown or not rank-3 `(C, H, W)`. + */ +internal fun upsample2dNextShape(currentShape: IntArray?, scale: Pair): IntArray? { + val shape = currentShape ?: return null + if (shape.size != 3) return null + val out = ConvShapeUtils.upsample2dOutputShape( + intArrayOf(1, shape[0], shape[1], shape[2]), + scale + ) + return intArrayOf(out[1], out[2], out[3]) +} + +/** + * Per-sample shape produced by Conv1d given the layer's kernel/stride/padding/dilation. + * Returns `null` when the input shape is unknown or not rank-2 `(C_in, L)`. + */ +internal fun Conv1dImpl<*, *>.nextShapeFor(currentShape: IntArray?): IntArray? { + val shape = currentShape ?: return null + if (shape.size != 2) return null + val out = ConvShapeUtils.conv1dOutputShape( + intArrayOf(1, shape[0], shape[1]), + intArrayOf(outChannels, inChannels / groups, kernelSize), + stride, padding, dilation + ) + return intArrayOf(out[1], out[2]) +} + +/** + * Per-sample shape produced by Conv2d given the layer's kernel/stride/padding/dilation. + * Returns `null` when the input shape is unknown or not rank-3 `(C_in, H, W)`. + */ +internal fun Conv2dImpl<*, *>.nextShapeFor(currentShape: IntArray?): IntArray? { + val shape = currentShape ?: return null + if (shape.size != 3) return null + val out = ConvShapeUtils.conv2dOutputShape( + intArrayOf(1, shape[0], shape[1], shape[2]), + intArrayOf(outChannels, inChannels / groups, kernelSize.first, kernelSize.second), + stride, padding, dilation + ) + return intArrayOf(out[1], out[2], out[3]) +} + +/** + * Per-sample shape produced by Conv3d given the layer's kernel/stride/padding/dilation. + * Returns `null` when the input shape is unknown or not rank-4 `(C_in, D, H, W)`. + */ +internal fun Conv3dImpl<*, *>.nextShapeFor(currentShape: IntArray?): IntArray? { + val shape = currentShape ?: return null + if (shape.size != 4) return null + val out = ConvShapeUtils.conv3dOutputShape( + intArrayOf(1, shape[0], shape[1], shape[2], shape[3]), + intArrayOf(outChannels, inChannels / groups, kernelSize.first, kernelSize.second, kernelSize.third), + stride, padding, dilation + ) + return intArrayOf(out[1], out[2], out[3], out[4]) +} + +internal fun flattenedDimensionFor(currentShape: IntArray, startDim: Int, endDim: Int): Int? { + val tensorRank = currentShape.size + 1 + val sd = if (startDim < 0) tensorRank + startDim else startDim + val ed = if (endDim < 0) tensorRank + endDim else endDim + require(sd in 0 until tensorRank && ed in 0 until tensorRank && sd <= ed) { + "Invalid flatten range startDim=$startDim, endDim=$endDim for tensor rank $tensorRank" + } + val perSampleStart = (sd - 1).coerceAtLeast(0) + val perSampleEndInclusive = ed - 1 + if (perSampleEndInclusive < 0 || perSampleStart > perSampleEndInclusive) return null + var product = 1 + for (i in perSampleStart..perSampleEndInclusive) product *= currentShape[i] + return product +} + public class StageImpl( override val executionContext: ExecutionContext, private val id: String, @@ -1050,11 +1169,21 @@ public class StageImpl( public val modules: MutableList> = mutableListOf>() public var lastDimension: Int = 0 public var inputDimension: Int = 0 + /** Per-sample shape (batch excluded) tracked through spatial layers; `null` if unknown. */ + public var currentShape: IntArray? = null public fun create(): Module = MLP(*modules.toTypedArray(), name = id) override fun input(inputSize: Int, id: String, requiresGrad: Boolean) { lastDimension = inputSize + currentShape = intArrayOf(inputSize) + modules.add(Input(name = getDefaultName(id, "Input", modules.size), requiresGrad = requiresGrad)) + } + + override fun input(inputShape: IntArray, id: String, requiresGrad: Boolean) { + require(inputShape.isNotEmpty()) { "input(inputShape) requires at least one dimension" } + currentShape = inputShape.copyOf() + lastDimension = inputShape.fold(1) { a, b -> a * b } modules.add(Input(name = getDefaultName(id, "Input", modules.size), requiresGrad = requiresGrad)) } @@ -1065,22 +1194,20 @@ public class StageImpl( ) impl.content() modules += impl.create() - // For flatten, we need to calculate the flattened size - // This is a simple approach - assume we're flattening from start_dim=1 (keeping batch dimension) - // The lastDimension should be set based on actual tensor dimensions, but for now - // we'll use a placeholder approach that works with typical CNN architectures - // TODO: Implement proper shape inference based on actual input dimensions - if (lastDimension == 0) { - // Fallback for the MNIST CNN test case with input (1,1,28,28) - // After conv1(16ch) + pool -> conv2(32ch) + pool, we get (1,32,7,7) - // Flattening from dim 1 gives size 32*7*7 = 1568 - lastDimension = 1568 // TODO: calculate from tracked shapes + val shape = currentShape + val inferred = if (shape != null) flattenedDimensionFor(shape, impl.startDim, impl.endDim) else null + if (inferred != null) { + lastDimension = inferred + currentShape = intArrayOf(inferred) } + // If we couldn't infer (no input shape declared), leave lastDimension untouched. + // A subsequent dense() will surface the missing dimension as a clear error. } override fun dense(outputDimension: Int, id: String, content: DENSE.() -> Unit) { val inputDimension = lastDimension lastDimension = outputDimension + currentShape = intArrayOf(outputDimension) val impl = DenseImpl( executionContext, inputDimension = inputDimension, @@ -1105,6 +1232,7 @@ public class StageImpl( impl.content() // Update lastDimension based on the units set in the content block lastDimension = impl.outputDimension + currentShape = intArrayOf(impl.outputDimension) // dense layer consists of linear module and activation function module (2 modules) modules += impl.create() } @@ -1116,16 +1244,20 @@ public class StageImpl( override fun sequential(content: NeuralNetworkDsl.() -> Unit) { val sequentialImpl = NeuralNetworkDslImpl(executionContext, kClass) sequentialImpl.lastDimension = lastDimension + sequentialImpl.currentShape = currentShape?.copyOf() sequentialImpl.content() lastDimension = sequentialImpl.lastDimension + currentShape = sequentialImpl.currentShape?.copyOf() modules += sequentialImpl.create() } override fun stage(id: String, content: NeuralNetworkDsl.() -> Unit) { val stageImpl = StageImpl(executionContext, id, kClass) stageImpl.lastDimension = lastDimension + stageImpl.currentShape = currentShape?.copyOf() stageImpl.content() lastDimension = stageImpl.lastDimension + currentShape = stageImpl.currentShape?.copyOf() modules += stageImpl.create() } @@ -1220,7 +1352,7 @@ public class StageImpl( // Create Conv2dImpl with default inChannels=1, can be modified via DSL val conv2dImpl = Conv2dImpl( executionContext, - initialInChannels = 1, // Default value, can be overridden in content block + initialInChannels = currentShape?.takeIf { it.size == 3 }?.get(0) ?: 1, initialOutChannels = outChannels, initialKernelSize = kernelSize, initialStride = stride, @@ -1237,6 +1369,7 @@ public class StageImpl( // Create and add the Conv2d module modules.add(conv2dImpl.create()) + currentShape = conv2dImpl.nextShapeFor(currentShape) } override fun conv2d( @@ -1245,7 +1378,7 @@ public class StageImpl( ) { val conv2dImpl = Conv2dImpl( executionContext = executionContext, - initialInChannels = 1, + initialInChannels = currentShape?.takeIf { it.size == 3 }?.get(0) ?: 1, initialOutChannels = 1, initialKernelSize = 1 to 1, initialStride = 1 to 1, @@ -1258,6 +1391,7 @@ public class StageImpl( ) conv2dImpl.content() modules.add(conv2dImpl.create()) + currentShape = conv2dImpl.nextShapeFor(currentShape) } @@ -1273,6 +1407,7 @@ public class StageImpl( padding = padding, name = getDefaultName(id, "MaxPool2d", modules.size) ) + currentShape = pool2dNextShape(currentShape, kernelSize, stride, padding) } override fun maxPool2d( @@ -1288,6 +1423,7 @@ public class StageImpl( ) impl.content() modules += impl.create() + currentShape = pool2dNextShape(currentShape, impl.kernelSize, impl.stride, impl.padding) } override fun upsample2d( @@ -1302,6 +1438,7 @@ public class StageImpl( alignCorners = alignCorners, name = getDefaultName(id, "Upsample2d", modules.size) ) + currentShape = upsample2dNextShape(currentShape, scale) } override fun upsample2d(id: String, content: UPSAMPLE2D.() -> Unit) { @@ -1314,6 +1451,7 @@ public class StageImpl( ) impl.content() modules += impl.create() + currentShape = upsample2dNextShape(currentShape, impl.scale) } override fun conv1d( @@ -1329,7 +1467,7 @@ public class StageImpl( ) { val conv1dImpl = Conv1dImpl( executionContext = executionContext, - initialInChannels = 1, + initialInChannels = currentShape?.takeIf { it.size == 2 }?.get(0) ?: 1, initialOutChannels = outChannels, initialKernelSize = kernelSize, initialStride = stride, @@ -1342,6 +1480,7 @@ public class StageImpl( ) conv1dImpl.content() modules.add(conv1dImpl.create()) + currentShape = conv1dImpl.nextShapeFor(currentShape) } override fun conv3d( @@ -1357,7 +1496,7 @@ public class StageImpl( ) { val conv3dImpl = Conv3dImpl( executionContext = executionContext, - initialInChannels = 1, + initialInChannels = currentShape?.takeIf { it.size == 4 }?.get(0) ?: 1, initialOutChannels = outChannels, initialKernelSize = kernelSize, initialStride = stride, @@ -1370,6 +1509,7 @@ public class StageImpl( ) conv3dImpl.content() modules.add(conv3dImpl.create()) + currentShape = conv3dImpl.nextShapeFor(currentShape) } override fun avgPool2d( @@ -1386,6 +1526,7 @@ public class StageImpl( countIncludePad = countIncludePad, name = getDefaultName(id, "AvgPool2d", modules.size) ) + currentShape = pool2dNextShape(currentShape, kernelSize, stride, padding) } override fun softmax(dim: Int, id: String) { @@ -1403,11 +1544,21 @@ public class NeuralNetworkDslImpl( public val modules: MutableList> = mutableListOf>() public var lastDimension: Int = 0 + /** Per-sample shape (batch excluded) tracked through spatial layers; `null` if unknown. */ + public var currentShape: IntArray? = null public fun create(): Module = NetworkBuilder().add(*modules.toTypedArray()).build() override fun input(inputSize: Int, id: String, requiresGrad: Boolean) { lastDimension = inputSize + currentShape = intArrayOf(inputSize) + modules.add(Input(name = getDefaultName(id, "Input", modules.size), requiresGrad = requiresGrad)) + } + + override fun input(inputShape: IntArray, id: String, requiresGrad: Boolean) { + require(inputShape.isNotEmpty()) { "input(inputShape) requires at least one dimension" } + currentShape = inputShape.copyOf() + lastDimension = inputShape.fold(1) { a, b -> a * b } modules.add(Input(name = getDefaultName(id, "Input", modules.size), requiresGrad = requiresGrad)) } @@ -1419,22 +1570,20 @@ public class NeuralNetworkDslImpl( ) impl.content() modules += impl.create() - // For flatten, we need to calculate the flattened size - // This is a simple approach - assume we're flattening from start_dim=1 (keeping batch dimension) - // The lastDimension should be set based on actual tensor dimensions, but for now - // we'll use a placeholder approach that works with typical CNN architectures - // TODO: Implement proper shape inference based on actual input dimensions - if (lastDimension == 0) { - // Fallback for the MNIST CNN test case with input (1,1,28,28) - // After conv1(16ch) + pool -> conv2(32ch) + pool, we get (1,32,7,7) - // Flattening from dim 1 gives size 32*7*7 = 1568 - lastDimension = 1568 // TODO: calculate from tracked shapes + val shape = currentShape + val inferred = if (shape != null) flattenedDimensionFor(shape, impl.startDim, impl.endDim) else null + if (inferred != null) { + lastDimension = inferred + currentShape = intArrayOf(inferred) } + // If we couldn't infer (no input shape declared), leave lastDimension untouched. + // A subsequent dense() will surface the missing dimension as a clear error. } override fun dense(outputDimension: Int, id: String, content: DENSE.() -> Unit) { val inputDimension = lastDimension lastDimension = outputDimension + currentShape = intArrayOf(outputDimension) val impl = DenseImpl( executionContext = executionContext, inputDimension = inputDimension, @@ -1459,6 +1608,7 @@ public class NeuralNetworkDslImpl( impl.content() // Update lastDimension based on the units set in the content block lastDimension = impl.outputDimension + currentShape = intArrayOf(impl.outputDimension) // dense layer consists of linear module and activation function module (2 modules) modules += impl.create() } @@ -1470,16 +1620,20 @@ public class NeuralNetworkDslImpl( override fun sequential(content: NeuralNetworkDsl.() -> Unit) { val sequentialImpl = NeuralNetworkDslImpl(executionContext, kClass) sequentialImpl.lastDimension = lastDimension + sequentialImpl.currentShape = currentShape?.copyOf() sequentialImpl.content() lastDimension = sequentialImpl.lastDimension + currentShape = sequentialImpl.currentShape?.copyOf() modules += sequentialImpl.create() } override fun stage(id: String, content: NeuralNetworkDsl.() -> Unit) { val stageImpl = StageImpl(executionContext, id, kClass) stageImpl.lastDimension = lastDimension + stageImpl.currentShape = currentShape?.copyOf() stageImpl.content() lastDimension = stageImpl.lastDimension + currentShape = stageImpl.currentShape?.copyOf() modules += stageImpl.create() } @@ -1574,7 +1728,7 @@ public class NeuralNetworkDslImpl( // Create Conv2dImpl with default inChannels=1, can be modified via DSL val conv2dImpl = Conv2dImpl( executionContext = executionContext, - initialInChannels = 1, // Default value, can be overridden in content block + initialInChannels = currentShape?.takeIf { it.size == 3 }?.get(0) ?: 1, initialOutChannels = outChannels, initialKernelSize = kernelSize, initialStride = stride, @@ -1591,6 +1745,7 @@ public class NeuralNetworkDslImpl( // Create and add the Conv2d module modules.add(conv2dImpl.create()) + currentShape = conv2dImpl.nextShapeFor(currentShape) } override fun conv2d( @@ -1599,7 +1754,7 @@ public class NeuralNetworkDslImpl( ) { val conv2dImpl = Conv2dImpl( executionContext = executionContext, - initialInChannels = 1, + initialInChannels = currentShape?.takeIf { it.size == 3 }?.get(0) ?: 1, initialOutChannels = 1, initialKernelSize = 1 to 1, initialStride = 1 to 1, @@ -1612,6 +1767,7 @@ public class NeuralNetworkDslImpl( ) conv2dImpl.content() modules.add(conv2dImpl.create()) + currentShape = conv2dImpl.nextShapeFor(currentShape) } override fun maxPool2d( @@ -1626,6 +1782,7 @@ public class NeuralNetworkDslImpl( padding = padding, name = getDefaultName(id, "MaxPool2d", modules.size) ) + currentShape = pool2dNextShape(currentShape, kernelSize, stride, padding) } override fun maxPool2d( @@ -1641,6 +1798,7 @@ public class NeuralNetworkDslImpl( ) impl.content() modules.add(impl.create()) + currentShape = pool2dNextShape(currentShape, impl.kernelSize, impl.stride, impl.padding) } override fun upsample2d( @@ -1655,6 +1813,7 @@ public class NeuralNetworkDslImpl( alignCorners = alignCorners, name = getDefaultName(id, "Upsample2d", modules.size) ) + currentShape = upsample2dNextShape(currentShape, scale) } override fun upsample2d( @@ -1670,6 +1829,7 @@ public class NeuralNetworkDslImpl( ) impl.content() modules.add(impl.create()) + currentShape = upsample2dNextShape(currentShape, impl.scale) } override fun conv1d( @@ -1685,7 +1845,7 @@ public class NeuralNetworkDslImpl( ) { val conv1dImpl = Conv1dImpl( executionContext = executionContext, - initialInChannels = 1, + initialInChannels = currentShape?.takeIf { it.size == 2 }?.get(0) ?: 1, initialOutChannels = outChannels, initialKernelSize = kernelSize, initialStride = stride, @@ -1698,6 +1858,7 @@ public class NeuralNetworkDslImpl( ) conv1dImpl.content() modules.add(conv1dImpl.create()) + currentShape = conv1dImpl.nextShapeFor(currentShape) } override fun conv3d( @@ -1713,7 +1874,7 @@ public class NeuralNetworkDslImpl( ) { val conv3dImpl = Conv3dImpl( executionContext = executionContext, - initialInChannels = 1, + initialInChannels = currentShape?.takeIf { it.size == 4 }?.get(0) ?: 1, initialOutChannels = outChannels, initialKernelSize = kernelSize, initialStride = stride, @@ -1726,6 +1887,7 @@ public class NeuralNetworkDslImpl( ) conv3dImpl.content() modules.add(conv3dImpl.create()) + currentShape = conv3dImpl.nextShapeFor(currentShape) } override fun avgPool2d( @@ -1742,6 +1904,7 @@ public class NeuralNetworkDslImpl( countIncludePad = countIncludePad, name = getDefaultName(id, "AvgPool2d", modules.size) ) + currentShape = pool2dNextShape(currentShape, kernelSize, stride, padding) } override fun softmax(dim: Int, id: String) { diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/ConvShapeUtils.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/ConvShapeUtils.kt index 2b5b5cdd..3d37843f 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/ConvShapeUtils.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/ConvShapeUtils.kt @@ -95,4 +95,45 @@ public object ConvShapeUtils { val outW = (inW + 2 * padding.third - dilation.third * (kW - 1) - 1) / stride.third + 1 return intArrayOf(batch, outChannels, outD, outH, outW) } + + /** + * 2D pool (max/avg) output shape: input `(batch, channels, height, width)`, + * result `(batch, channels, out_h, out_w)`. Channels are preserved. + */ + public fun pool2dOutputShape( + inputShape: IntArray, + kernelSize: Pair, + stride: Pair, + padding: Pair + ): IntArray { + require(inputShape.size == 4) { + "Pool2d input must be rank 4 (batch, channels, height, width), got rank ${inputShape.size}" + } + val batch = inputShape[0] + val channels = inputShape[1] + val inH = inputShape[2] + val inW = inputShape[3] + val outH = (inH + 2 * padding.first - kernelSize.first) / stride.first + 1 + val outW = (inW + 2 * padding.second - kernelSize.second) / stride.second + 1 + return intArrayOf(batch, channels, outH, outW) + } + + /** + * 2D upsample output shape: input `(batch, channels, height, width)`, + * result `(batch, channels, height * scale_h, width * scale_w)`. + */ + public fun upsample2dOutputShape( + inputShape: IntArray, + scale: Pair + ): IntArray { + require(inputShape.size == 4) { + "Upsample2d input must be rank 4 (batch, channels, height, width), got rank ${inputShape.size}" + } + return intArrayOf( + inputShape[0], + inputShape[1], + inputShape[2] * scale.first, + inputShape[3] * scale.second + ) + } } diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/nn/dsl/CnnShapeInferenceTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/nn/dsl/CnnShapeInferenceTest.kt new file mode 100644 index 00000000..2d19f0d2 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/nn/dsl/CnnShapeInferenceTest.kt @@ -0,0 +1,132 @@ +package sk.ainet.lang.nn.dsl + +import sk.ainet.lang.types.FP32 +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +/** + * Verifies that the DSL tracks per-sample shapes through conv/pool/upsample/flatten + * so that a downstream `dense()` receives the correct input dimension. + * + * Before #535 was fixed, `flatten()` fell back to a hardcoded `1568` (the MNIST CNN + * value), which broke any other architecture. + */ +class CnnShapeInferenceTest { + + @Test + fun input_intArray_sets_currentShape_and_flat_lastDimension() { + val builder = NeuralNetworkDslImpl( + DefaultNetworkExecutionContext, FP32::class + ) + builder.input(intArrayOf(1, 28, 28)) + assertContentEquals(intArrayOf(1, 28, 28), builder.currentShape) + assertEquals(1 * 28 * 28, builder.lastDimension) + } + + @Test + fun mnist_cnn_chain_infers_1568_after_flatten() { + val builder = NeuralNetworkDslImpl( + DefaultNetworkExecutionContext, FP32::class + ) + builder.apply { + input(intArrayOf(1, 28, 28)) + // 28 + 2*2 - (5-1) - 1 + 1 = 28 + conv2d(outChannels = 16, kernelSize = 5 to 5, stride = 1 to 1, padding = 2 to 2) + // (28 + 0 - 2)/2 + 1 = 14 + maxPool2d(kernelSize = 2 to 2, stride = 2 to 2) + conv2d(outChannels = 32, kernelSize = 5 to 5, stride = 1 to 1, padding = 2 to 2) + // 14 -> 7 + maxPool2d(kernelSize = 2 to 2, stride = 2 to 2) + } + assertContentEquals(intArrayOf(32, 7, 7), builder.currentShape) + builder.flatten() + assertEquals(32 * 7 * 7, builder.lastDimension) + } + + @Test + fun custom_64_channel_cnn_does_not_collide_with_old_1568() { + // The architecture from issue #535 that used to crash because flatten() hardcoded 1568. + val builder = NeuralNetworkDslImpl( + DefaultNetworkExecutionContext, FP32::class + ) + builder.apply { + input(intArrayOf(3, 32, 32)) + conv2d(outChannels = 32, kernelSize = 3 to 3, padding = 1 to 1) + // 32 + 0 - (2) ) / 2 + 1 = 16 + maxPool2d(kernelSize = 2 to 2) + conv2d(outChannels = 64, kernelSize = 3 to 3, padding = 1 to 1) + maxPool2d(kernelSize = 2 to 2) // 8 + } + assertContentEquals(intArrayOf(64, 8, 8), builder.currentShape) + builder.flatten() + assertEquals(64 * 8 * 8, builder.lastDimension) + } + + @Test + fun conv1d_chain_tracks_length_correctly() { + val builder = NeuralNetworkDslImpl( + DefaultNetworkExecutionContext, FP32::class + ) + builder.apply { + input(intArrayOf(80, 3000)) // (channels, length) — Whisper-style mel input + conv1d(outChannels = 384, kernelSize = 3, padding = 1) + } + // (3000 + 2 - 2 - 1)/1 + 1 = 3000 + assertContentEquals(intArrayOf(384, 3000), builder.currentShape) + } + + @Test + fun upsample2d_doubles_spatial_dims() { + val builder = NeuralNetworkDslImpl( + DefaultNetworkExecutionContext, FP32::class + ) + builder.apply { + input(intArrayOf(8, 16, 16)) + upsample2d(scale = 2 to 2) + } + assertContentEquals(intArrayOf(8, 32, 32), builder.currentShape) + } + + @Test + fun avgPool2d_reduces_spatial_dims() { + val builder = NeuralNetworkDslImpl( + DefaultNetworkExecutionContext, FP32::class + ) + builder.apply { + input(intArrayOf(16, 32, 32)) + avgPool2d(kernelSize = 4 to 4, stride = 4 to 4) + } + assertContentEquals(intArrayOf(16, 8, 8), builder.currentShape) + } + + @Test + fun stage_propagates_shape_in_and_out() { + val builder = NeuralNetworkDslImpl( + DefaultNetworkExecutionContext, FP32::class + ) + builder.apply { + input(intArrayOf(1, 28, 28)) + stage("conv1") { + conv2d(outChannels = 16, kernelSize = 5 to 5, padding = 2 to 2) + maxPool2d(kernelSize = 2 to 2) + } + } + assertContentEquals(intArrayOf(16, 14, 14), builder.currentShape) + } + + @Test + fun flatten_without_input_shape_leaves_lastDimension_untouched() { + // Backward-compat: building a sequential with bare flatten (no input) must + // not throw, because some tests use it as a runtime-only module. + val builder = NeuralNetworkDslImpl( + DefaultNetworkExecutionContext, FP32::class + ) + builder.flatten() + assertEquals(0, builder.lastDimension) + assertNotNull(builder.modules.firstOrNull()) + } +} + +private val DefaultNetworkExecutionContext = sk.ainet.lang.nn.DefaultNeuralNetworkExecutionContext() diff --git a/skainet-lang/skainet-lang-models/src/commonMain/kotlin/sk/ainet/lang/model/dnn/cnn/MNIST.kt b/skainet-lang/skainet-lang-models/src/commonMain/kotlin/sk/ainet/lang/model/dnn/cnn/MNIST.kt index 92b3032d..802a84cf 100644 --- a/skainet-lang/skainet-lang-models/src/commonMain/kotlin/sk/ainet/lang/model/dnn/cnn/MNIST.kt +++ b/skainet-lang/skainet-lang-models/src/commonMain/kotlin/sk/ainet/lang/model/dnn/cnn/MNIST.kt @@ -54,6 +54,11 @@ public class MnistCnn : Model, Tensor = definition { network(executionContext) { sequential { + // Per-sample input shape (channels, height, width). Required for the + // DSL to track shapes through conv/pool stages and infer the flatten + // size before the dense layers. + input(intArrayOf(1, 28, 28)) + // Stage: "conv1" stage("conv1") { conv2d(outChannels = 16, kernelSize = 5 to 5, stride = 1 to 1, padding = 2 to 2)