Skip to content

Fix stablehlo.reshape source type for %arg0 consumers (closes #518)#521

Merged
michalharakal merged 1 commit intodevelopfrom
feature/hlo-input-type-tracking
Apr 18, 2026
Merged

Fix stablehlo.reshape source type for %arg0 consumers (closes #518)#521
michalharakal merged 1 commit intodevelopfrom
feature/hlo-input-type-tracking

Conversation

@michalharakal
Copy link
Copy Markdown
Contributor

Summary

  • Thread an SSA-value-name → MLIR-type map through ConversionContext; seed it with %argN declared types at function-signature emission and populate it after each successful op conversion.
  • Fix ShapeOperationsConverter to look up the operand's real type via the map (with node.inputs[0] fallback) instead of reusing outputType on the source side of every stablehlo.reshape cast.

Background

iree-compile rejected SKaiNET-emitted MLIR on every reshape / unsqueeze consuming a function argument. The converter emitted : (outputType) -> outputType, producing e.g. stablehlo.reshape %arg0 : (tensor<1x80x3000xf32>) -> tensor<1x80x3000xf32> when %arg0 was actually tensor<80x3000xf32>. Third of four items from the skainet-whisper IREE Vulkan bring-up; first two landed in #520.

Test plan

Closes #518

🤖 Generated with Claude Code

ShapeOperationsConverter emitted `: ($outputType) -> $outputType` on
every reshape / flatten / squeeze / unsqueeze, reusing the output type
on the source side of the cast. When the operand was a function
argument (`%arg0: tensor<80x3000xf32>`), the result was syntactically
valid but wrong — e.g. `stablehlo.reshape %arg0 : (tensor<1x80x3000xf32>)
-> tensor<1x80x3000xf32>` — and iree-compile rejected it with a type
mismatch on every native SKaiNET DSL export that reached the IREE path.

Fix threads an SSA-value-name -> MLIR-type map through ConversionContext:

- StableHloConverter seeds `%argN -> mapTensorType(inputSpec)` when it
  writes the function signature, and records `resultValue -> outputType`
  after each successful op conversion.
- ShapeOperationsConverter looks up the operand's real type via a new
  resolveOperandType helper (context.getValueType first, then
  node.inputs[0], then a dynamic fallback) and uses that on the source
  side of every reshape cast.

The type map is consumed opt-in by converters that need it, so other
converters (concat, slice, dot_general, transpose) are unaffected.

Regression test `testReshapeOnArgUsesDeclaredArgType` builds an
input -> unsqueeze graph and asserts the emitted MLIR contains
`(tensor<80x3000xf32>) -> tensor<1x80x3000xf32>`. Full
:skainet-compile-hlo:jvmTest passes.

Closes #518

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit 9eafa11 into develop Apr 18, 2026
4 checks passed
@michalharakal michalharakal deleted the feature/hlo-input-type-tracking branch April 18, 2026 19:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

HLO: function input vs weight constant disambiguation in StableHloConverter

1 participant