Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public class SimpleExecutionTape : ExecutionTape {

// Stable naming helpers to align with Exporters.kt expectations
private fun stableInputName(op: Operation, index: Int, total: Int): String = when (op) {
is Conv2dOperation<*, *> -> when (index) {
is Conv1dOperation<*, *>, is Conv2dOperation<*, *>, is Conv3dOperation<*, *> -> when (index) {
0 -> "input"
1 -> "weight"
2 -> "bias"
Expand Down Expand Up @@ -243,7 +243,22 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
padding: Int,
dilation: Int,
groups: Int
): Tensor<T, V> = base.conv1d(input, weight, bias, stride, padding, dilation, groups)
): Tensor<T, V> {
val out = base.conv1d(input, weight, bias, stride, padding, dilation, groups)
val params = mapOf(
"stride" to stride,
"padding" to padding,
"dilation" to dilation,
"groups" to groups
)
@Suppress("UNCHECKED_CAST")
record(
Conv1dOperation<T, V>(params),
listOf(input, weight) + listOfNotNull(bias) as List<Tensor<T, V>>,
listOf(out)
)
return out
}

override fun <T : DType, V> conv2d(
input: Tensor<T, V>,
Expand Down Expand Up @@ -278,7 +293,22 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
padding: Triple<Int, Int, Int>,
dilation: Triple<Int, Int, Int>,
groups: Int
): Tensor<T, V> = base.conv3d(input, weight, bias, stride, padding, dilation, groups)
): Tensor<T, V> {
val out = base.conv3d(input, weight, bias, stride, padding, dilation, groups)
val params = mapOf(
"stride" to stride,
"padding" to padding,
"dilation" to dilation,
"groups" to groups
)
@Suppress("UNCHECKED_CAST")
record(
Conv3dOperation<T, V>(params),
listOf(input, weight) + listOfNotNull(bias) as List<Tensor<T, V>>,
listOf(out)
)
return out
}

override fun <T : DType, V> maxPool2d(
input: Tensor<T, V>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.VoidOpsTensor
import sk.ainet.lang.tensor.data.DenseTensorDataFactory
import sk.ainet.lang.tensor.ops.AddOperation
import sk.ainet.lang.tensor.ops.Conv1dOperation
import sk.ainet.lang.tensor.ops.Conv3dOperation
import sk.ainet.lang.tensor.ops.ReluOperation
import sk.ainet.lang.tensor.ops.TensorOps
import sk.ainet.lang.tensor.ops.VoidTensorOps
Expand Down Expand Up @@ -83,4 +85,58 @@ class SimpleExecutionTapeTest {
assertFalse(copy.isRecording, "Copies keep recording state but not recording by default")
assertNotNull(copy.operations.first(), "Copy retains recorded op")
}

@Test
fun records_conv1d_through_recording_decorator() {
val input = VoidOpsTensor<FP32, Float>(
dataFactory.zeros(Shape(1, 80, 3000), FP32::class), FP32::class
)
val weight = VoidOpsTensor<FP32, Float>(
dataFactory.zeros(Shape(384, 80, 3), FP32::class), FP32::class
)
val bias = VoidOpsTensor<FP32, Float>(
dataFactory.zeros(Shape(384), FP32::class), FP32::class
)

val tape = Execution.withTape {
val ops = Execution.recordingOps(VoidTensorOps())
ops.conv1d<FP32, Float>(input, weight, bias, stride = 1, padding = 1, dilation = 1, groups = 1)
}

val recorded = tape.operations.single { it.operation is Conv1dOperation<*, *> }
assertEquals(listOf("input", "weight", "bias"), recorded.inputs.map { it.name })
assertEquals(listOf(1, 80, 3000), recorded.inputs[0].shape)
assertEquals(listOf(384, 80, 3), recorded.inputs[1].shape)
assertEquals(listOf(384), recorded.inputs[2].shape)
assertEquals(listOf(1, 384, 3000), recorded.outputs.single().shape)
assertEquals(1, recorded.operation.parameters["stride"])
assertEquals(1, recorded.operation.parameters["padding"])
}

@Test
fun records_conv3d_through_recording_decorator() {
val input = VoidOpsTensor<FP32, Float>(
dataFactory.zeros(Shape(1, 3, 16, 16, 16), FP32::class), FP32::class
)
val weight = VoidOpsTensor<FP32, Float>(
dataFactory.zeros(Shape(8, 3, 3, 3, 3), FP32::class), FP32::class
)

val tape = Execution.withTape {
val ops = Execution.recordingOps(VoidTensorOps())
ops.conv3d<FP32, Float>(
input, weight, bias = null,
stride = Triple(1, 1, 1),
padding = Triple(0, 0, 0),
dilation = Triple(1, 1, 1),
groups = 1
)
}

val recorded = tape.operations.single { it.operation is Conv3dOperation<*, *> }
assertEquals(listOf("input", "weight"), recorded.inputs.map { it.name })
assertEquals(listOf(1, 3, 16, 16, 16), recorded.inputs[0].shape)
assertEquals(listOf(8, 3, 3, 3, 3), recorded.inputs[1].shape)
assertEquals(listOf(1, 8, 14, 14, 14), recorded.outputs.single().shape)
}
}
20 changes: 20 additions & 0 deletions skainet-lang/skainet-lang-core/api/android/skainet-lang-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -3224,6 +3224,16 @@ public abstract class sk/ainet/lang/tensor/ops/BaseOperation : sk/ainet/lang/ten
public fun toString ()Ljava/lang/String;
}

public final class sk/ainet/lang/tensor/ops/Conv1dOperation : sk/ainet/lang/tensor/ops/BaseOperation {
public fun <init> ()V
public fun <init> (Ljava/util/Map;)V
public synthetic fun <init> (Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun clone (Ljava/util/Map;)Lsk/ainet/lang/tensor/ops/Operation;
public fun execute (Ljava/util/List;)Ljava/util/List;
public fun inferOutputs (Ljava/util/List;)Ljava/util/List;
public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult;
}

public final class sk/ainet/lang/tensor/ops/Conv2dOperation : sk/ainet/lang/tensor/ops/BaseOperation {
public fun <init> ()V
public fun <init> (Ljava/util/Map;)V
Expand All @@ -3234,6 +3244,16 @@ public final class sk/ainet/lang/tensor/ops/Conv2dOperation : sk/ainet/lang/tens
public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult;
}

public final class sk/ainet/lang/tensor/ops/Conv3dOperation : sk/ainet/lang/tensor/ops/BaseOperation {
public fun <init> ()V
public fun <init> (Ljava/util/Map;)V
public synthetic fun <init> (Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun clone (Ljava/util/Map;)Lsk/ainet/lang/tensor/ops/Operation;
public fun execute (Ljava/util/List;)Ljava/util/List;
public fun inferOutputs (Ljava/util/List;)Ljava/util/List;
public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult;
}

public abstract interface class sk/ainet/lang/tensor/ops/DifferentiableTensorOps {
public abstract fun absBackward (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Ljava/util/Map;)Ljava/util/List;
public abstract fun addBackward (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Ljava/util/Map;)Ljava/util/List;
Expand Down
20 changes: 20 additions & 0 deletions skainet-lang/skainet-lang-core/api/jvm/skainet-lang-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -3366,6 +3366,16 @@ public abstract class sk/ainet/lang/tensor/ops/BaseOperation : sk/ainet/lang/ten
public fun toString ()Ljava/lang/String;
}

public final class sk/ainet/lang/tensor/ops/Conv1dOperation : sk/ainet/lang/tensor/ops/BaseOperation {
public fun <init> ()V
public fun <init> (Ljava/util/Map;)V
public synthetic fun <init> (Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun clone (Ljava/util/Map;)Lsk/ainet/lang/tensor/ops/Operation;
public fun execute (Ljava/util/List;)Ljava/util/List;
public fun inferOutputs (Ljava/util/List;)Ljava/util/List;
public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult;
}

public final class sk/ainet/lang/tensor/ops/Conv2dOperation : sk/ainet/lang/tensor/ops/BaseOperation {
public fun <init> ()V
public fun <init> (Ljava/util/Map;)V
Expand All @@ -3376,6 +3386,16 @@ public final class sk/ainet/lang/tensor/ops/Conv2dOperation : sk/ainet/lang/tens
public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult;
}

public final class sk/ainet/lang/tensor/ops/Conv3dOperation : sk/ainet/lang/tensor/ops/BaseOperation {
public fun <init> ()V
public fun <init> (Ljava/util/Map;)V
public synthetic fun <init> (Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun clone (Ljava/util/Map;)Lsk/ainet/lang/tensor/ops/Operation;
public fun execute (Ljava/util/List;)Ljava/util/List;
public fun inferOutputs (Ljava/util/List;)Ljava/util/List;
public fun validateInputs (Ljava/util/List;)Lsk/ainet/lang/tensor/ops/ValidationResult;
}

public abstract interface class sk/ainet/lang/tensor/ops/DifferentiableTensorOps {
public abstract fun absBackward (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Ljava/util/Map;)Ljava/util/List;
public abstract fun addBackward (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Ljava/util/Map;)Ljava/util/List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,22 +420,54 @@ public class TransposeOperation<T : DType, V>(
/**
* Convolutional operations
*/
public class Conv1dOperation<T : DType, V>(
parameters: Map<String, Any> = emptyMap()
) : BaseOperation("conv1d", "nn", parameters) {

override fun <T2 : DType, V2> execute(inputs: List<Tensor<T2, V2>>): List<Tensor<T2, V2>> {
require(inputs.size >= 2) { "Conv1d operation requires at least 2 inputs" }
throw UnsupportedOperationException("Direct execution not supported in graph mode")
}

override fun validateInputs(inputs: List<TensorSpec>): ValidationResult {
if (inputs.size < 2 || inputs.size > 3) {
return ValidationResult.Invalid(listOf("Conv1d operation requires 2-3 inputs, got ${inputs.size}"))
}
return ValidationResult.Valid
}

override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
require(inputs.size >= 2) { "Conv1d operation requires at least 2 inputs" }
val outputShape = inputs[0].shape
return listOf(
TensorSpec(
name = "conv1d_output",
shape = outputShape,
dtype = inputs[0].dtype,
requiresGrad = inputs.any { it.requiresGrad }
)
)
}

override fun clone(newParameters: Map<String, Any>): Operation = Conv1dOperation<T, V>(newParameters)
}

public class Conv2dOperation<T : DType, V>(
parameters: Map<String, Any> = emptyMap()
) : BaseOperation("conv2d", "nn", parameters) {

override fun <T2 : DType, V2> execute(inputs: List<Tensor<T2, V2>>): List<Tensor<T2, V2>> {
require(inputs.size >= 2) { "Conv2d operation requires at least 2 inputs" }
throw UnsupportedOperationException("Direct execution not supported in graph mode")
}

override fun validateInputs(inputs: List<TensorSpec>): ValidationResult {
if (inputs.size < 2 || inputs.size > 3) {
return ValidationResult.Invalid(listOf("Conv2d operation requires 2-3 inputs, got ${inputs.size}"))
}
return ValidationResult.Valid
}

override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
require(inputs.size >= 2) { "Conv2d operation requires at least 2 inputs" }
val outputShape = inputs[0].shape
Expand All @@ -448,10 +480,42 @@ public class Conv2dOperation<T : DType, V>(
)
)
}

override fun clone(newParameters: Map<String, Any>): Operation = Conv2dOperation<T, V>(newParameters)
}

public class Conv3dOperation<T : DType, V>(
parameters: Map<String, Any> = emptyMap()
) : BaseOperation("conv3d", "nn", parameters) {

override fun <T2 : DType, V2> execute(inputs: List<Tensor<T2, V2>>): List<Tensor<T2, V2>> {
require(inputs.size >= 2) { "Conv3d operation requires at least 2 inputs" }
throw UnsupportedOperationException("Direct execution not supported in graph mode")
}

override fun validateInputs(inputs: List<TensorSpec>): ValidationResult {
if (inputs.size < 2 || inputs.size > 3) {
return ValidationResult.Invalid(listOf("Conv3d operation requires 2-3 inputs, got ${inputs.size}"))
}
return ValidationResult.Valid
}

override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
require(inputs.size >= 2) { "Conv3d operation requires at least 2 inputs" }
val outputShape = inputs[0].shape
return listOf(
TensorSpec(
name = "conv3d_output",
shape = outputShape,
dtype = inputs[0].dtype,
requiresGrad = inputs.any { it.requiresGrad }
)
)
}

override fun clone(newParameters: Map<String, Any>): Operation = Conv3dOperation<T, V>(newParameters)
}

/**
* Pooling operations
*/
Expand Down
Loading