From 3c6838f0056c2b81e7d060c4296e597f17c6b48a Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Sun, 19 Apr 2026 19:18:27 +0200 Subject: [PATCH] Fall back to TensorRef shape/dtype when trace attrs are missing (#530) The KSP-generated tracer wrapper only populates outputShapes/inputShapes attributes for a hard-coded allowlist (conv2d, unary/binary/scalar ops, shape ops). conv1d and conv3d fall through AttributeStrategy.DefaultMapping, which records only stride/padding/dilation/groups. Without the shape attributes, DefaultExecutionTape.recordTrace built TensorSpecs with shape=null, and the StableHLO converter emitted tensor for conv1d output (12 occurrences in the Whisper encoder), which iree-compile rejects as a type mismatch against the static bias. The TensorRef in the trace already carries the correct static Shape captured from the runtime tensor, so use it as the fallback. Same treatment for dtype. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../ainet/lang/graph/DefaultExecutionTape.kt | 14 ++++---- .../compile/grad/DefaultExecutionTapeTest.kt | 36 +++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt index 806b8ff2..3b2bbfb2 100644 --- a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt +++ b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt @@ -74,17 +74,19 @@ public open class DefaultExecutionTape( val outputDTypes = (trace.attributes["outputDTypes"] as? List<*>)?.map { it?.toString() } val inputs = List(trace.inputs.size) { i -> + val ref = trace.inputs[i] TensorSpec( - name = trace.inputs[i].id, - shape = inputShapes?.getOrNull(i), - dtype = inputDTypes?.getOrNull(i) ?: "unknown", + name = ref.id, + shape = inputShapes?.getOrNull(i) ?: ref.shape.dimensions.toList(), + dtype = inputDTypes?.getOrNull(i) ?: ref.dtype.name, ) } val outputs = List(trace.outputs.size) { i -> + val ref = trace.outputs[i] TensorSpec( - name = trace.outputs[i].id, - shape = outputShapes?.getOrNull(i), - dtype = outputDTypes?.getOrNull(i) ?: "unknown", + name = ref.id, + shape = outputShapes?.getOrNull(i) ?: ref.shape.dimensions.toList(), + dtype = outputDTypes?.getOrNull(i) ?: ref.dtype.name, ) } diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/grad/DefaultExecutionTapeTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/grad/DefaultExecutionTapeTest.kt index d2326f39..8629c3a9 100644 --- a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/grad/DefaultExecutionTapeTest.kt +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/grad/DefaultExecutionTapeTest.kt @@ -14,6 +14,8 @@ import sk.ainet.lang.tensor.* import sk.ainet.lang.tensor.data.FloatArrayTensorData import sk.ainet.lang.graph.DefaultExecutionTape import sk.ainet.lang.tensor.ops.AddOperation +import sk.ainet.lang.trace.OpTrace +import kotlin.test.assertEquals class DefaultExecutionTapeTest { @@ -147,4 +149,38 @@ class DefaultExecutionTapeTest { assertTrue(prunedTape.operations[0].operation is AddOperation<*, *>) assertTrue(prunedTape.operations[1].operation is AddOperation<*, *>) } + + @Test + fun recordTrace_without_shape_attributes_falls_back_to_tensor_ref_shape() { + val trainCtx = createTrainCtx() + val input = trainCtx.fromFloatArray(Shape(1, 80, 3000), FP32::class, FloatArray(1 * 80 * 3000)) + val weight = trainCtx.fromFloatArray(Shape(384, 80, 3), FP32::class, FloatArray(384 * 80 * 3)) + val output = trainCtx.fromFloatArray(Shape(1, 384, 3000), FP32::class, FloatArray(1 * 384 * 3000)) + + val tape = DefaultExecutionTape(trainCtx.session) + tape.startRecording() + val inputRef = tape.session.refOf(input) + val weightRef = tape.session.refOf(weight) + val outputRef = tape.session.refOf(output) + + tape.recordTrace( + OpTrace( + opType = "conv1d", + inputs = listOf(inputRef, weightRef), + outputs = listOf(outputRef), + attributes = mapOf( + "stride" to 1, + "padding" to 1, + "dilation" to 1, + "groups" to 1 + ) + ) + ) + tape.stopRecording() + + val recorded = tape.operations.single { it.operation.name == "conv1d" } + assertEquals(listOf(1, 384, 3000), recorded.outputs.single().shape) + assertEquals(listOf(1, 80, 3000), recorded.inputs[0].shape) + assertEquals(listOf(384, 80, 3), recorded.inputs[1].shape) + } }