diff --git a/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt b/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt index 4612c837..bcfbc8a6 100644 --- a/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt +++ b/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt @@ -118,7 +118,7 @@ public class SimpleExecutionTape : ExecutionTape { // Stable naming helpers to align with Exporters.kt expectations private fun stableInputName(op: Operation, index: Int, total: Int): String = when (op) { - is Conv2dOperation<*, *> -> when (index) { + is Conv1dOperation<*, *>, is Conv2dOperation<*, *>, is Conv3dOperation<*, *> -> when (index) { 0 -> "input" 1 -> "weight" 2 -> "bias" @@ -243,7 +243,22 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor padding: Int, dilation: Int, groups: Int - ): Tensor = base.conv1d(input, weight, bias, stride, padding, dilation, groups) + ): Tensor { + val out = base.conv1d(input, weight, bias, stride, padding, dilation, groups) + val params = mapOf( + "stride" to stride, + "padding" to padding, + "dilation" to dilation, + "groups" to groups + ) + @Suppress("UNCHECKED_CAST") + record( + Conv1dOperation(params), + listOf(input, weight) + listOfNotNull(bias) as List>, + listOf(out) + ) + return out + } override fun conv2d( input: Tensor, @@ -278,7 +293,22 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor padding: Triple, dilation: Triple, groups: Int - ): Tensor = base.conv3d(input, weight, bias, stride, padding, dilation, groups) + ): Tensor { + val out = base.conv3d(input, weight, bias, stride, padding, dilation, groups) + val params = mapOf( + "stride" to stride, + "padding" to padding, + "dilation" to dilation, + "groups" to groups + ) + @Suppress("UNCHECKED_CAST") + record( + Conv3dOperation(params), + listOf(input, weight) + listOfNotNull(bias) as List>, + listOf(out) + ) + return out + } override fun maxPool2d( input: Tensor, diff --git a/skainet-compile/skainet-compile-core/src/commonTest/kotlin/sk/ainet/tape/SimpleExecutionTapeTest.kt b/skainet-compile/skainet-compile-core/src/commonTest/kotlin/sk/ainet/tape/SimpleExecutionTapeTest.kt index 128c4a01..dbc80c5b 100644 --- a/skainet-compile/skainet-compile-core/src/commonTest/kotlin/sk/ainet/tape/SimpleExecutionTapeTest.kt +++ b/skainet-compile/skainet-compile-core/src/commonTest/kotlin/sk/ainet/tape/SimpleExecutionTapeTest.kt @@ -10,6 +10,8 @@ import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.VoidOpsTensor import sk.ainet.lang.tensor.data.DenseTensorDataFactory import sk.ainet.lang.tensor.ops.AddOperation +import sk.ainet.lang.tensor.ops.Conv1dOperation +import sk.ainet.lang.tensor.ops.Conv3dOperation import sk.ainet.lang.tensor.ops.ReluOperation import sk.ainet.lang.tensor.ops.TensorOps import sk.ainet.lang.tensor.ops.VoidTensorOps @@ -83,4 +85,58 @@ class SimpleExecutionTapeTest { assertFalse(copy.isRecording, "Copies keep recording state but not recording by default") assertNotNull(copy.operations.first(), "Copy retains recorded op") } + + @Test + fun records_conv1d_through_recording_decorator() { + val input = VoidOpsTensor( + dataFactory.zeros(Shape(1, 80, 3000), FP32::class), FP32::class + ) + val weight = VoidOpsTensor( + dataFactory.zeros(Shape(384, 80, 3), FP32::class), FP32::class + ) + val bias = VoidOpsTensor( + dataFactory.zeros(Shape(384), FP32::class), FP32::class + ) + + val tape = Execution.withTape { + val ops = Execution.recordingOps(VoidTensorOps()) + ops.conv1d(input, weight, bias, stride = 1, padding = 1, dilation = 1, groups = 1) + } + + val recorded = tape.operations.single { it.operation is Conv1dOperation<*, *> } + assertEquals(listOf("input", "weight", "bias"), recorded.inputs.map { it.name }) + assertEquals(listOf(1, 80, 3000), recorded.inputs[0].shape) + assertEquals(listOf(384, 80, 3), recorded.inputs[1].shape) + assertEquals(listOf(384), recorded.inputs[2].shape) + assertEquals(listOf(1, 384, 3000), recorded.outputs.single().shape) + assertEquals(1, recorded.operation.parameters["stride"]) + assertEquals(1, recorded.operation.parameters["padding"]) + } + + @Test + fun records_conv3d_through_recording_decorator() { + val input = VoidOpsTensor( + dataFactory.zeros(Shape(1, 3, 16, 16, 16), FP32::class), FP32::class + ) + val weight = VoidOpsTensor( + dataFactory.zeros(Shape(8, 3, 3, 3, 3), FP32::class), FP32::class + ) + + val tape = Execution.withTape { + val ops = Execution.recordingOps(VoidTensorOps()) + ops.conv3d( + input, weight, bias = null, + stride = Triple(1, 1, 1), + padding = Triple(0, 0, 0), + dilation = Triple(1, 1, 1), + groups = 1 + ) + } + + val recorded = tape.operations.single { it.operation is Conv3dOperation<*, *> } + assertEquals(listOf("input", "weight"), recorded.inputs.map { it.name }) + assertEquals(listOf(1, 3, 16, 16, 16), recorded.inputs[0].shape) + assertEquals(listOf(8, 3, 3, 3, 3), recorded.inputs[1].shape) + assertEquals(listOf(1, 8, 14, 14, 14), recorded.outputs.single().shape) + } } diff --git a/skainet-lang/skainet-lang-core/api/android/skainet-lang-core.api b/skainet-lang/skainet-lang-core/api/android/skainet-lang-core.api index b41148bc..c9c152e6 100644 --- a/skainet-lang/skainet-lang-core/api/android/skainet-lang-core.api +++ b/skainet-lang/skainet-lang-core/api/android/skainet-lang-core.api @@ -3224,6 +3224,16 @@ public abstract class sk/ainet/lang/tensor/ops/BaseOperation : sk/ainet/lang/ten public fun toString ()Ljava/lang/String; } +public final class sk/ainet/lang/tensor/ops/Conv1dOperation : sk/ainet/lang/tensor/ops/BaseOperation { + public fun ()V + public fun (Ljava/util/Map;)V + public synthetic fun (Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun clone (Ljava/util/Map;)Lsk/ainet/lang/tensor/ops/Operation; + public fun execute (Ljava/util/List;)Ljava/util/List; + public fun inferOutputs (Ljava/util/List;)Ljava/util/List; + public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult; +} + public final class sk/ainet/lang/tensor/ops/Conv2dOperation : sk/ainet/lang/tensor/ops/BaseOperation { public fun ()V public fun (Ljava/util/Map;)V @@ -3234,6 +3244,16 @@ public final class sk/ainet/lang/tensor/ops/Conv2dOperation : sk/ainet/lang/tens public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult; } +public final class sk/ainet/lang/tensor/ops/Conv3dOperation : sk/ainet/lang/tensor/ops/BaseOperation { + public fun ()V + public fun (Ljava/util/Map;)V + public synthetic fun (Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun clone (Ljava/util/Map;)Lsk/ainet/lang/tensor/ops/Operation; + public fun execute (Ljava/util/List;)Ljava/util/List; + public fun inferOutputs (Ljava/util/List;)Ljava/util/List; + public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult; +} + public abstract interface class sk/ainet/lang/tensor/ops/DifferentiableTensorOps { public abstract fun absBackward (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Ljava/util/Map;)Ljava/util/List; public abstract fun addBackward (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Ljava/util/Map;)Ljava/util/List; diff --git a/skainet-lang/skainet-lang-core/api/jvm/skainet-lang-core.api b/skainet-lang/skainet-lang-core/api/jvm/skainet-lang-core.api index 8bd07b3e..dab99d7e 100644 --- a/skainet-lang/skainet-lang-core/api/jvm/skainet-lang-core.api +++ b/skainet-lang/skainet-lang-core/api/jvm/skainet-lang-core.api @@ -3366,6 +3366,16 @@ public abstract class sk/ainet/lang/tensor/ops/BaseOperation : sk/ainet/lang/ten public fun toString ()Ljava/lang/String; } +public final class sk/ainet/lang/tensor/ops/Conv1dOperation : sk/ainet/lang/tensor/ops/BaseOperation { + public fun ()V + public fun (Ljava/util/Map;)V + public synthetic fun (Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun clone (Ljava/util/Map;)Lsk/ainet/lang/tensor/ops/Operation; + public fun execute (Ljava/util/List;)Ljava/util/List; + public fun inferOutputs (Ljava/util/List;)Ljava/util/List; + public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult; +} + public final class sk/ainet/lang/tensor/ops/Conv2dOperation : sk/ainet/lang/tensor/ops/BaseOperation { public fun ()V public fun (Ljava/util/Map;)V @@ -3376,6 +3386,16 @@ public final class sk/ainet/lang/tensor/ops/Conv2dOperation : sk/ainet/lang/tens public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult; } +public final class sk/ainet/lang/tensor/ops/Conv3dOperation : sk/ainet/lang/tensor/ops/BaseOperation { + public fun ()V + public fun (Ljava/util/Map;)V + public synthetic fun (Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun clone (Ljava/util/Map;)Lsk/ainet/lang/tensor/ops/Operation; + public fun execute (Ljava/util/List;)Ljava/util/List; + public fun inferOutputs (Ljava/util/List;)Ljava/util/List; + public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult; +} + public abstract interface class sk/ainet/lang/tensor/ops/DifferentiableTensorOps { public abstract fun absBackward (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Ljava/util/Map;)Ljava/util/List; public abstract fun addBackward (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Ljava/util/Map;)Ljava/util/List; diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt index e4de7315..2057d5e7 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt @@ -420,22 +420,54 @@ public class TransposeOperation( /** * Convolutional operations */ +public class Conv1dOperation( + parameters: Map = emptyMap() +) : BaseOperation("conv1d", "nn", parameters) { + + override fun execute(inputs: List>): List> { + require(inputs.size >= 2) { "Conv1d operation requires at least 2 inputs" } + throw UnsupportedOperationException("Direct execution not supported in graph mode") + } + + override fun validateInputs(inputs: List): ValidationResult { + if (inputs.size < 2 || inputs.size > 3) { + return ValidationResult.Invalid(listOf("Conv1d operation requires 2-3 inputs, got ${inputs.size}")) + } + return ValidationResult.Valid + } + + override fun inferOutputs(inputs: List): List { + require(inputs.size >= 2) { "Conv1d operation requires at least 2 inputs" } + val outputShape = inputs[0].shape + return listOf( + TensorSpec( + name = "conv1d_output", + shape = outputShape, + dtype = inputs[0].dtype, + requiresGrad = inputs.any { it.requiresGrad } + ) + ) + } + + override fun clone(newParameters: Map): Operation = Conv1dOperation(newParameters) +} + public class Conv2dOperation( parameters: Map = emptyMap() ) : BaseOperation("conv2d", "nn", parameters) { - + override fun execute(inputs: List>): List> { require(inputs.size >= 2) { "Conv2d operation requires at least 2 inputs" } throw UnsupportedOperationException("Direct execution not supported in graph mode") } - + override fun validateInputs(inputs: List): ValidationResult { if (inputs.size < 2 || inputs.size > 3) { return ValidationResult.Invalid(listOf("Conv2d operation requires 2-3 inputs, got ${inputs.size}")) } return ValidationResult.Valid } - + override fun inferOutputs(inputs: List): List { require(inputs.size >= 2) { "Conv2d operation requires at least 2 inputs" } val outputShape = inputs[0].shape @@ -448,10 +480,42 @@ public class Conv2dOperation( ) ) } - + override fun clone(newParameters: Map): Operation = Conv2dOperation(newParameters) } +public class Conv3dOperation( + parameters: Map = emptyMap() +) : BaseOperation("conv3d", "nn", parameters) { + + override fun execute(inputs: List>): List> { + require(inputs.size >= 2) { "Conv3d operation requires at least 2 inputs" } + throw UnsupportedOperationException("Direct execution not supported in graph mode") + } + + override fun validateInputs(inputs: List): ValidationResult { + if (inputs.size < 2 || inputs.size > 3) { + return ValidationResult.Invalid(listOf("Conv3d operation requires 2-3 inputs, got ${inputs.size}")) + } + return ValidationResult.Valid + } + + override fun inferOutputs(inputs: List): List { + require(inputs.size >= 2) { "Conv3d operation requires at least 2 inputs" } + val outputShape = inputs[0].shape + return listOf( + TensorSpec( + name = "conv3d_output", + shape = outputShape, + dtype = inputs[0].dtype, + requiresGrad = inputs.any { it.requiresGrad } + ) + ) + } + + override fun clone(newParameters: Map): Operation = Conv3dOperation(newParameters) +} + /** * Pooling operations */