diff --git a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt index 806b8ff2..3b2bbfb2 100644 --- a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt +++ b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt @@ -74,17 +74,19 @@ public open class DefaultExecutionTape( val outputDTypes = (trace.attributes["outputDTypes"] as? List<*>)?.map { it?.toString() } val inputs = List(trace.inputs.size) { i -> + val ref = trace.inputs[i] TensorSpec( - name = trace.inputs[i].id, - shape = inputShapes?.getOrNull(i), - dtype = inputDTypes?.getOrNull(i) ?: "unknown", + name = ref.id, + shape = inputShapes?.getOrNull(i) ?: ref.shape.dimensions.toList(), + dtype = inputDTypes?.getOrNull(i) ?: ref.dtype.name, ) } val outputs = List(trace.outputs.size) { i -> + val ref = trace.outputs[i] TensorSpec( - name = trace.outputs[i].id, - shape = outputShapes?.getOrNull(i), - dtype = outputDTypes?.getOrNull(i) ?: "unknown", + name = ref.id, + shape = outputShapes?.getOrNull(i) ?: ref.shape.dimensions.toList(), + dtype = outputDTypes?.getOrNull(i) ?: ref.dtype.name, ) } diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/grad/DefaultExecutionTapeTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/grad/DefaultExecutionTapeTest.kt index d2326f39..8629c3a9 100644 --- a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/grad/DefaultExecutionTapeTest.kt +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/grad/DefaultExecutionTapeTest.kt @@ -14,6 +14,8 @@ import sk.ainet.lang.tensor.* import sk.ainet.lang.tensor.data.FloatArrayTensorData import sk.ainet.lang.graph.DefaultExecutionTape import sk.ainet.lang.tensor.ops.AddOperation +import sk.ainet.lang.trace.OpTrace +import kotlin.test.assertEquals class DefaultExecutionTapeTest { @@ -147,4 +149,38 @@ class DefaultExecutionTapeTest { assertTrue(prunedTape.operations[0].operation is AddOperation<*, *>) assertTrue(prunedTape.operations[1].operation is AddOperation<*, *>) } + + @Test + fun recordTrace_without_shape_attributes_falls_back_to_tensor_ref_shape() { + val trainCtx = createTrainCtx() + val input = trainCtx.fromFloatArray(Shape(1, 80, 3000), FP32::class, FloatArray(1 * 80 * 3000)) + val weight = trainCtx.fromFloatArray(Shape(384, 80, 3), FP32::class, FloatArray(384 * 80 * 3)) + val output = trainCtx.fromFloatArray(Shape(1, 384, 3000), FP32::class, FloatArray(1 * 384 * 3000)) + + val tape = DefaultExecutionTape(trainCtx.session) + tape.startRecording() + val inputRef = tape.session.refOf(input) + val weightRef = tape.session.refOf(weight) + val outputRef = tape.session.refOf(output) + + tape.recordTrace( + OpTrace( + opType = "conv1d", + inputs = listOf(inputRef, weightRef), + outputs = listOf(outputRef), + attributes = mapOf( + "stride" to 1, + "padding" to 1, + "dilation" to 1, + "groups" to 1 + ) + ) + ) + tape.stopRecording() + + val recorded = tape.operations.single { it.operation.name == "conv1d" } + assertEquals(listOf(1, 384, 3000), recorded.outputs.single().shape) + assertEquals(listOf(1, 80, 3000), recorded.inputs[0].shape) + assertEquals(listOf(384, 80, 3), recorded.inputs[1].shape) + } }