
Compute real conv output shapes in graph operations (#536) #537

Merged

michalharakal merged 1 commit into develop from fix/536-conv-inferoutputs-shape
Apr 19, 2026

Conversation

@michalharakal
Contributor

Closes #536.

Summary

  • Add ConvShapeUtils as a public, single-source-of-truth helper for conv1d/2d/3d output-shape math.
  • Make VoidTensorOps.calculateConv{1,2,3}dShape thin delegates to the helper.
  • Fix Conv{1,2,3}dOperation.inferOutputs (tensor/ops/TensorOperations.kt) to compute the real output shape from inputs[0], inputs[1], and parameters["stride"|"padding"|"dilation"] instead of echoing inputs[0].shape. Falls back to a null shape only when input/weight shape is unknown or wrong rank. Conv2d/Conv3d accept either the Pair/Triple parameters written by RecordingExecution or scalar Int (treated symmetrically).
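The per-axis arithmetic these helpers centralize is the standard convolution output-size formula. A minimal, language-agnostic sketch (function name hypothetical, not the actual ConvShapeUtils API):

```python
def conv_out_dim(in_dim, kernel, stride=1, padding=0, dilation=1):
    """Standard conv output size for one spatial dimension."""
    return (in_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Whisper-encoder-style conv1d stem: kernel 3, padding 1
assert conv_out_dim(3000, 3, stride=1, padding=1) == 3000  # same-length conv
assert conv_out_dim(3000, 3, stride=2, padding=1) == 1500  # stride-2 downsample
```

With this arithmetic available at inferOutputs time, the graph operation can publish a concrete shape instead of the unknown `tensor<?xf32>`.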

This unblocks the StableHLO export path: every stablehlo.convolution was previously emitted as tensor<?xf32> because the graph operation never published a real output shape. PRs #530 and #532 fixed adjacent pieces (TensorRef fallback, recording decorator wiring) but not inferOutputs itself.

Test plan

  • ./gradlew :skainet-lang:skainet-lang-core:jvmTest (incl. new ConvOperationInferOutputsTest, 8 cases)
  • ./gradlew :skainet-compile:skainet-compile-core:jvmTest
  • ./gradlew :skainet-compile:skainet-compile-hlo:jvmTest
  • Re-run the Whisper encoder StableHLO export and confirm zero tensor<? occurrences in encoder_skainet.mlir
  • Confirm iree-compile accepts the module without tools/fix_stablehlo_mlir.py

🤖 Generated with Claude Code

Conv{1,2,3}dOperation.inferOutputs previously echoed inputs[0].shape,
ignoring the weight shape and stride/padding/dilation parameters. This
left every stablehlo.convolution result as tensor<?xf32>, blocking
iree-compile on the Whisper encoder.

Extract the shape math into a public ConvShapeUtils object so the
eager (VoidTensorOps) and graph-emission paths share one source of
truth, and rewrite the three inferOutputs methods to use it. Conv2d
and Conv3d accept either Pair/Triple (as written by RecordingExecution)
or scalar Int parameters.
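Accepting either Pair/Triple or scalar Int amounts to normalizing each parameter to one value per spatial axis before applying the formula. A hypothetical sketch of that normalization (not the actual Kotlin signatures):

```python
def normalize_param(value, rank):
    """Broadcast a scalar to `rank` spatial axes; pass per-axis tuples through."""
    if isinstance(value, int):
        return (value,) * rank
    value = tuple(value)
    if len(value) != rank:
        raise ValueError(f"expected {rank} values, got {len(value)}")
    return value

assert normalize_param(2, 2) == (2, 2)       # scalar treated symmetrically
assert normalize_param((1, 2), 2) == (1, 2)  # Pair-style parameter unchanged
```

Normalizing up front keeps the shape math itself identical regardless of which parameter form RecordingExecution wrote into the graph.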

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@github-actions

📖 Documentation Preview

The documentation has been built successfully for this PR.

Generated Files:

  • Operator documentation: docs/modules/operators/_generated_/
  • JSON schema output: operators.json

Artifacts:

  • Download the documentation-preview-537 artifact to view the complete documentation locally.


michalharakal added a commit that referenced this pull request Apr 19, 2026
The DSL's flatten() previously fell back to a hardcoded lastDimension =
1568 (the value that happens to fit the MNIST CNN reference model).
Any other architecture - e.g. a 64-channel CNN over 32x32 inputs - hit
ArrayIndexOutOfBounds in the following dense layer.

Add a per-sample shape tracker (currentShape: IntArray?) to StageImpl
and NeuralNetworkDslImpl, plus a new input(intArrayOf(...)) overload
that seeds it. conv1d/2d/3d, maxPool2d, avgPool2d, and upsample2d now
update currentShape using the same arithmetic as VoidTensorOps via
ConvShapeUtils (extended with pool2d and upsample2d helpers, building
on the helper introduced in #537). flatten() reads currentShape and
honors startDim / endDim instead of guessing 1568. When no input shape
is declared we leave lastDimension untouched so existing flatten-only
runtime tests keep working - dense() will surface the gap with a clear
error if it actually matters.
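The tracker's effect on the failing 64-channel case can be sketched with the same per-axis arithmetic the commit describes (helper names hypothetical, mirroring what ConvShapeUtils-style pool/conv helpers would compute):

```python
def conv2d_shape(shape, out_channels, kernel, stride=1, padding=0, dilation=1):
    """Track (channels, height, width) through a square-kernel conv2d."""
    _, h, w = shape
    def out(d):
        return (d + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    return (out_channels, out(h), out(w))

def pool2d_shape(shape, kernel, stride=None):
    """Track the shape through a square-window pooling layer."""
    stride = stride or kernel
    c, h, w = shape
    return (c, (h - kernel) // stride + 1, (w - kernel) // stride + 1)

# 64-channel CNN over 32x32 inputs (the case that previously crashed):
s = (1, 32, 32)
s = conv2d_shape(s, 64, kernel=3, padding=1)  # (64, 32, 32)
s = pool2d_shape(s, kernel=2)                 # (64, 16, 16)
flat = s[0] * s[1] * s[2]                     # 16384, not the hardcoded 1568
```

With the shape threaded through each layer, flatten() can read the true last dimension instead of guessing the MNIST-specific constant.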

Update MnistCnn to declare input(intArrayOf(1, 28, 28)) so it works
under the new shape inference instead of relying on the magic constant.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit c04257c into develop Apr 19, 2026
9 checks passed
@michalharakal michalharakal deleted the fix/536-conv-inferoutputs-shape branch April 19, 2026 20:14
michalharakal added a commit that referenced this pull request Apr 21, 2026
Bump VERSION_NAME to 0.19.0 in the root gradle.properties, expand the
CHANGELOG [0.19.0] - 2026-04-20 section to cover the full 130 commits
since 0.18.0 — not just the tokenizer work but the StableHLO → IREE
lowering pipeline (softmax/layerNorm/rmsnorm real lowerings, gather/
embedding/concat/slice/cast converters, ConstantMaterializationPolicy,
dense<v> splat folding, SSA type tracking), the new skainet-io-iree-
params IrpaWriter, skainet-backend-api module, Antora docs migration
with Diátaxis layout, Java API polish (#400), androidNativeArm32
target, and the graph/DSL shape-inference fixes (#535, #536, #537,
#538) that unblock non-MNIST CNN architectures and Whisper-encoder
HLO compilation. Refresh the README install snippet and "What's New"
section to reflect the 0.19.0 highlights, and note the tokenizer
milestone on the Q2 2026 roadmap line. Ops docs regenerated so the
stamped version matches the new VERSION_NAME.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal mentioned this pull request Apr 21, 2026


Development

Successfully merging this pull request may close these issues.

Conv{1,2,3}dOperation.inferOutputs echoes input shape instead of computing output shape
