
Track per-sample shape through DSL spatial layers (#535)#538

Merged
michalharakal merged 1 commit into develop from fix/535-flatten-shape-inference
Apr 21, 2026

Conversation

@michalharakal
Contributor

Closes #535.

Stacked on #537. Once #537 lands on develop, the base will retarget
automatically and the diff will collapse to just this PR's changes.

Summary

  • Add a per-sample currentShape: IntArray? tracker to StageImpl and NeuralNetworkDslImpl so flatten() knows the real spatial shape instead of guessing.
  • Add an input(inputShape: IntArray, ...) overload — required to seed the tracker for any CNN architecture.
  • Wire conv1d/2d/3d, maxPool2d, avgPool2d, and upsample2d to update currentShape using ConvShapeUtils (the same single-source-of-truth helper introduced in #537, "Compute real conv output shapes in graph operations (#536)", extended here with pool2dOutputShape and upsample2dOutputShape).
  • Make flatten() compute lastDimension from currentShape honoring startDim/endDim. The hardcoded 1568 fallback is gone. When no input shape is declared, lastDimension is left untouched so existing flatten-only runtime tests keep building; a downstream dense() will surface the gap.
  • Conv* layers now seed inChannels from the tracked input shape (when rank matches), so the user no longer has to repeat inChannels = N after declaring a multi-dim input.
  • Update MnistCnn to declare input(intArrayOf(1, 28, 28)) since the magic constant is gone.
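The tracker-plus-flatten idea above can be sketched in plain Kotlin. This is a minimal illustration, not the actual StageImpl/NeuralNetworkDslImpl code: the property name currentShape matches the PR, but the ShapeTracker class and the flattenedSize helper are assumptions.

```kotlin
// Illustrative sketch of per-sample shape tracking for flatten().
// `currentShape` mirrors the PR's tracker; the class and helper names are hypothetical.
class ShapeTracker {
    var currentShape: IntArray? = null

    // Seed the tracker, e.g. input(intArrayOf(1, 28, 28)) for MNIST.
    fun input(shape: IntArray) {
        currentShape = shape
    }

    // Flattened size over dims [startDim, endDim] of the per-sample shape.
    // Returns null when no input shape was declared (mirrors "leave
    // lastDimension untouched" for bare flatten() usage).
    fun flattenedSize(startDim: Int = 0, endDim: Int = -1): Int? {
        val shape = currentShape ?: return null
        val end = if (endDim < 0) shape.size + endDim else endDim
        var product = 1
        for (i in startDim..end) product *= shape[i]
        return product
    }
}
```

For example, 32 channels at 7x7 before flatten() gives 32 * 7 * 7 = 1568, which is exactly the value the old hardcoded fallback assumed; any other architecture now gets its own correct product instead.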

Why now

This builds on the conv shape-inference work from #530, #532, and #537. With those three covering the eager and graph paths for conv shape math, flatten() was the last shape-inference gap visible from the DSL. Issue #535 reported it crashing custom architectures with ArrayIndexOutOfBoundsException.

Test plan

  • ./gradlew :skainet-lang:skainet-lang-core:jvmTest (incl. new CnnShapeInferenceTest, 8 cases — MNIST, custom 64-channel CNN, conv1d Whisper-style, upsample, avgPool, stage propagation, backward-compat for bare flatten)
  • ./gradlew :skainet-lang:skainet-lang-models:jvmTest (MNIST CNN model now uses input(intArrayOf(1, 28, 28)))
  • ./gradlew :skainet-compile:skainet-compile-core:jvmTest
  • ./gradlew :skainet-compile:skainet-compile-hlo:jvmTest
  • Verify the 64-channel CNN reported in #535 ("Fix: Implement dynamic shape for Flatten layer to replace hardcoded fallback") builds and runs end-to-end without ArrayIndexOutOfBoundsException

🤖 Generated with Claude Code

The DSL's flatten() previously fell back to a hardcoded lastDimension =
1568 (the value that happens to fit the MNIST CNN reference model).
Any other architecture - e.g. a 64-channel CNN over 32x32 inputs - hit
ArrayIndexOutOfBoundsException in the following dense layer.

Add a per-sample shape tracker (currentShape: IntArray?) to StageImpl
and NeuralNetworkDslImpl, plus a new input(intArrayOf(...)) overload
that seeds it. conv1d/2d/3d, maxPool2d, avgPool2d, and upsample2d now
update currentShape using the same arithmetic as VoidTensorOps via
ConvShapeUtils (extended with pool2d and upsample2d helpers, building
on the helper introduced in #537). flatten() reads currentShape and
honors startDim / endDim instead of guessing 1568. When no input shape
is declared we leave lastDimension untouched so existing flatten-only
runtime tests keep working - dense() will surface the gap with a clear
error if it actually matters.

Update MnistCnn to declare input(intArrayOf(1, 28, 28)) so it works
under the new shape inference instead of relying on the magic constant.
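The pool and upsample arithmetic the commit describes follows the standard output-size formula. A hedged sketch (the real helpers live in ConvShapeUtils; the function names pool2dOutputShape/upsample2dOutputShape come from the PR text, but these signatures are assumptions):

```kotlin
// Standard conv/pool output-size formula: floor((n + 2p - k) / s) + 1.
// Helper names follow the PR's description; signatures are assumed.
fun convOutDim(inSize: Int, kernel: Int, stride: Int = 1, padding: Int = 0): Int =
    (inSize + 2 * padding - kernel) / stride + 1

// Per-sample shape (channels, h, w) after a 2D pooling window:
// channel count is unchanged, spatial dims shrink by the formula above.
fun pool2dOutputShape(shape: IntArray, kernel: Int, stride: Int): IntArray =
    intArrayOf(
        shape[0],
        convOutDim(shape[1], kernel, stride),
        convOutDim(shape[2], kernel, stride),
    )

// Per-sample shape after nearest-neighbor upsampling by an integer scale.
fun upsample2dOutputShape(shape: IntArray, scale: Int): IntArray =
    intArrayOf(shape[0], shape[1] * scale, shape[2] * scale)
```

So a 2x2 max-pool with stride 2 takes a 28x28 feature map to 14x14, and two such pools reach the 7x7 maps whose product the new flatten() computes directly.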

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Base automatically changed from fix/536-conv-inferoutputs-shape to develop April 19, 2026 20:14
@michalharakal michalharakal merged commit 0aa9ba3 into develop Apr 21, 2026
5 checks passed
michalharakal added a commit that referenced this pull request Apr 21, 2026
Bump VERSION_NAME to 0.19.0 in the root gradle.properties, expand the
CHANGELOG [0.19.0] - 2026-04-20 section to cover the full 130 commits
since 0.18.0 — not just the tokenizer work but the StableHLO → IREE
lowering pipeline (softmax/layerNorm/rmsnorm real lowerings, gather/
embedding/concat/slice/cast converters, ConstantMaterializationPolicy,
dense<v> splat folding, SSA type tracking), the new skainet-io-iree-
params IrpaWriter, skainet-backend-api module, Antora docs migration
with Diátaxis layout, Java API polish (#400), androidNativeArm32
target, and the graph/DSL shape-inference fixes (#535, #536, #537,
#538) that unblock non-MNIST CNN architectures and Whisper-encoder
HLO compilation. Refresh the README install snippet and "What's New"
section to reflect the 0.19.0 highlights, and note the tokenizer
milestone on the Q2 2026 roadmap line. Ops docs regenerated so the
stamped version matches the new VERSION_NAME.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal mentioned this pull request Apr 21, 2026
@michalharakal michalharakal deleted the fix/535-flatten-shape-inference branch April 21, 2026 13:39
