Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 92 additions & 11 deletions skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,63 @@ public final class sk/ainet/compile/hlo/ConstantFoldingPass : sk/ainet/compile/h
public fun getName ()Ljava/lang/String;
}

public abstract interface class sk/ainet/compile/hlo/ConstantMaterializationPolicy {
}

public final class sk/ainet/compile/hlo/ConstantMaterializationPolicy$ExternalAlways : sk/ainet/compile/hlo/ConstantMaterializationPolicy {
public fun <init> ()V
public fun <init> (Ljava/lang/String;)V
public synthetic fun <init> (Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun component1 ()Ljava/lang/String;
public final fun copy (Ljava/lang/String;)Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$ExternalAlways;
public static synthetic fun copy$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$ExternalAlways;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$ExternalAlways;
public fun equals (Ljava/lang/Object;)Z
public final fun getScope ()Ljava/lang/String;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
}

public final class sk/ainet/compile/hlo/ConstantMaterializationPolicy$InlineAlways : sk/ainet/compile/hlo/ConstantMaterializationPolicy {
public static final field INSTANCE Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$InlineAlways;
public fun equals (Ljava/lang/Object;)Z
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
}

public final class sk/ainet/compile/hlo/ConstantMaterializationPolicy$SizeThreshold : sk/ainet/compile/hlo/ConstantMaterializationPolicy {
public fun <init> (JLjava/lang/String;)V
public synthetic fun <init> (JLjava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun component1 ()J
public final fun component2 ()Ljava/lang/String;
public final fun copy (JLjava/lang/String;)Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$SizeThreshold;
public static synthetic fun copy$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$SizeThreshold;JLjava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$SizeThreshold;
public fun equals (Ljava/lang/Object;)Z
public final fun getBytes ()J
public final fun getScope ()Ljava/lang/String;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
}

public final class sk/ainet/compile/hlo/ConversionContext {
public fun <init> (Lsk/ainet/compile/hlo/TypeMapper;)V
public fun <init> (Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/lang/graph/ComputeGraph;)V
public synthetic fun <init> (Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/lang/graph/ComputeGraph;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)V
public synthetic fun <init> (Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun clear ()V
public final fun emitComment (Ljava/lang/String;)V
public final fun emitEncodingAnnotation (Ljava/lang/String;ILsk/ainet/lang/tensor/ops/TensorSpec;)V
public final fun emitLine (Ljava/lang/String;)V
public final fun emitModuleDeclaration (Ljava/lang/String;)V
public final fun emitOperation (Ljava/lang/String;)V
public final fun getContent ()Ljava/lang/String;
public final fun getExternalParameters ()Ljava/util/List;
public final fun getInputNodes (Lsk/ainet/lang/graph/GraphNode;)Ljava/util/List;
public final fun getMaterializationPolicy ()Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;
public final fun getModuleDeclarations ()Ljava/lang/String;
public final fun getTypeMapper ()Lsk/ainet/compile/hlo/TypeMapper;
public final fun getValueName (Ljava/lang/String;)Ljava/lang/String;
public final fun nextTempValue ()Ljava/lang/String;
public final fun registerExternalParameter (Lsk/ainet/compile/hlo/ExternalParameterRef;)V
public final fun setGraph (Lsk/ainet/lang/graph/ComputeGraph;)V
public final fun setValueName (Ljava/lang/String;Ljava/lang/String;)V
}
Expand Down Expand Up @@ -77,6 +122,23 @@ public final class sk/ainet/compile/hlo/DeadCodeEliminationPass : sk/ainet/compi
public fun getName ()Ljava/lang/String;
}

public final class sk/ainet/compile/hlo/ExternalParameterRef {
public fun <init> (Ljava/lang/String;Ljava/lang/String;Lsk/ainet/lang/tensor/storage/TensorEncoding;Lsk/ainet/lang/tensor/storage/BufferHandle;)V
public final fun component1 ()Ljava/lang/String;
public final fun component2 ()Ljava/lang/String;
public final fun component3 ()Lsk/ainet/lang/tensor/storage/TensorEncoding;
public final fun component4 ()Lsk/ainet/lang/tensor/storage/BufferHandle;
public final fun copy (Ljava/lang/String;Ljava/lang/String;Lsk/ainet/lang/tensor/storage/TensorEncoding;Lsk/ainet/lang/tensor/storage/BufferHandle;)Lsk/ainet/compile/hlo/ExternalParameterRef;
public static synthetic fun copy$default (Lsk/ainet/compile/hlo/ExternalParameterRef;Ljava/lang/String;Ljava/lang/String;Lsk/ainet/lang/tensor/storage/TensorEncoding;Lsk/ainet/lang/tensor/storage/BufferHandle;ILjava/lang/Object;)Lsk/ainet/compile/hlo/ExternalParameterRef;
public fun equals (Ljava/lang/Object;)Z
public final fun getEncoding ()Lsk/ainet/lang/tensor/storage/TensorEncoding;
public final fun getKey ()Ljava/lang/String;
public final fun getScope ()Ljava/lang/String;
public final fun getSource ()Lsk/ainet/lang/tensor/storage/BufferHandle;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
}

public final class sk/ainet/compile/hlo/MlirValidator {
public fun <init> ()V
public final fun validate (Ljava/lang/String;)Ljava/util/List;
Expand Down Expand Up @@ -108,35 +170,48 @@ public final class sk/ainet/compile/hlo/RegistryStats {
}

public final class sk/ainet/compile/hlo/StableHloConverter {
public fun <init> (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;)V
public fun <init> (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;)V
public synthetic fun <init> (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)V
public synthetic fun <init> (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun convert (Lsk/ainet/lang/graph/ComputeGraph;Ljava/lang/String;)Lsk/ainet/compile/hlo/StableHloModule;
public static synthetic fun convert$default (Lsk/ainet/compile/hlo/StableHloConverter;Lsk/ainet/lang/graph/ComputeGraph;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule;
public final fun convertWithOptimization (Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/hlo/StableHloOptimizer;)Lsk/ainet/compile/hlo/StableHloModule;
}

public final class sk/ainet/compile/hlo/StableHloConverterFactory {
public static final field INSTANCE Lsk/ainet/compile/hlo/StableHloConverterFactory;
public final fun createBasic ()Lsk/ainet/compile/hlo/StableHloConverter;
public final fun createCustom (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;)Lsk/ainet/compile/hlo/StableHloConverter;
public static synthetic fun createCustom$default (Lsk/ainet/compile/hlo/StableHloConverterFactory;Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter;
public final fun createExtended ()Lsk/ainet/compile/hlo/StableHloConverter;
public final fun createFast ()Lsk/ainet/compile/hlo/StableHloConverter;
public static final fun createBasic ()Lsk/ainet/compile/hlo/StableHloConverter;
public static final fun createBasic (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)Lsk/ainet/compile/hlo/StableHloConverter;
public static synthetic fun createBasic$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter;
public static final fun createCustom (Lsk/ainet/compile/hlo/StableHloOperationRegistry;)Lsk/ainet/compile/hlo/StableHloConverter;
public static final fun createCustom (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;)Lsk/ainet/compile/hlo/StableHloConverter;
public static final fun createCustom (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;)Lsk/ainet/compile/hlo/StableHloConverter;
public static final fun createCustom (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)Lsk/ainet/compile/hlo/StableHloConverter;
public static synthetic fun createCustom$default (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter;
public static final fun createExtended ()Lsk/ainet/compile/hlo/StableHloConverter;
public static final fun createExtended (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)Lsk/ainet/compile/hlo/StableHloConverter;
public static synthetic fun createExtended$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter;
public static final fun createFast ()Lsk/ainet/compile/hlo/StableHloConverter;
public static final fun createFast (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)Lsk/ainet/compile/hlo/StableHloConverter;
public static synthetic fun createFast$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter;
}

public final class sk/ainet/compile/hlo/StableHloModule {
public fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;Ljava/util/List;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;Ljava/util/List;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun component1 ()Ljava/lang/String;
public final fun component2 ()Ljava/lang/String;
public final fun component3 ()Ljava/util/List;
public final fun component4 ()Ljava/util/List;
public final fun component5 ()Ljava/util/Map;
public final fun component6 ()Ljava/util/List;
public final fun contentLines ()Lkotlin/sequences/Sequence;
public final fun copy (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;)Lsk/ainet/compile/hlo/StableHloModule;
public static synthetic fun copy$default (Lsk/ainet/compile/hlo/StableHloModule;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule;
public final fun copy (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;Ljava/util/List;)Lsk/ainet/compile/hlo/StableHloModule;
public static synthetic fun copy$default (Lsk/ainet/compile/hlo/StableHloModule;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;Ljava/util/List;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule;
public fun equals (Ljava/lang/Object;)Z
public final fun getContent ()Ljava/lang/String;
public final fun getExternalParameters ()Ljava/util/List;
public final fun getFunctionName ()Ljava/lang/String;
public final fun getInputSpecs ()Ljava/util/List;
public final fun getMetadata ()Ljava/util/Map;
Expand Down Expand Up @@ -210,6 +285,12 @@ public final class sk/ainet/compile/hlo/converters/ConstantOperationsConverter :
public fun getSupportedOperations ()Ljava/util/Set;
}

public final class sk/ainet/compile/hlo/converters/GatherOperationsConverter : sk/ainet/compile/hlo/StableHloOperationConverter {
public fun <init> ()V
public fun convert (Lsk/ainet/lang/graph/GraphNode;Ljava/util/List;Lsk/ainet/compile/hlo/ConversionContext;)Lsk/ainet/compile/hlo/ConversionResult;
public fun getSupportedOperations ()Ljava/util/Set;
}

public final class sk/ainet/compile/hlo/converters/LegacyOperationsConverter : sk/ainet/compile/hlo/StableHloOperationConverter {
public fun <init> ()V
public fun convert (Lsk/ainet/lang/graph/GraphNode;Ljava/util/List;Lsk/ainet/compile/hlo/ConversionContext;)Lsk/ainet/compile/hlo/ConversionResult;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package sk.ainet.compile.hlo

/**
* Serialize a list of numeric values into a little-endian [ByteArray]
* matching the given SKaiNET dtype string.
*
* Used by the converter to materialize `node.operation.parameters["values"]`
* / `["initial_value"]` lists into bytes when policy is
* [ConstantMaterializationPolicy.ExternalAlways]. Pads with zeros when
* [values] is shorter than [expectedElements] so under-filled
* initializations emit a well-formed buffer of the declared shape.
*
* Kept in `commonMain` — uses [Float.toRawBits] / [Double.toRawBits]
* rather than JVM-only streams so the same code runs on all KMP
* targets the HLO module currently supports.
*
* Throws for dtypes the converter does not yet externalize. Callers
* should catch and fall back to inline emission with a diagnostic.
*/
internal fun numberListToLittleEndianBytes(
values: List<*>,
dtype: String,
expectedElements: Int
): ByteArray {
val count = expectedElements.coerceAtLeast(values.size)
val normalized = dtype.uppercase()

return when (normalized) {
"FP32", "F32", "FLOAT32" -> {
val bytes = ByteArray(count * 4)
for (i in 0 until count) {
val v = (values.getOrNull(i) as? Number)?.toFloat() ?: 0.0f
val bits = v.toRawBits()
bytes[i * 4] = (bits and 0xff).toByte()
bytes[i * 4 + 1] = (bits ushr 8 and 0xff).toByte()
bytes[i * 4 + 2] = (bits ushr 16 and 0xff).toByte()
bytes[i * 4 + 3] = (bits ushr 24 and 0xff).toByte()
}
bytes
}
"FP64", "F64", "FLOAT64" -> {
val bytes = ByteArray(count * 8)
for (i in 0 until count) {
val v = (values.getOrNull(i) as? Number)?.toDouble() ?: 0.0
val bits = v.toRawBits()
for (b in 0 until 8) {
bytes[i * 8 + b] = ((bits ushr (b * 8)) and 0xff).toByte()
}
}
bytes
}
"I32", "INT32" -> {
val bytes = ByteArray(count * 4)
for (i in 0 until count) {
val v = (values.getOrNull(i) as? Number)?.toInt() ?: 0
bytes[i * 4] = (v and 0xff).toByte()
bytes[i * 4 + 1] = (v ushr 8 and 0xff).toByte()
bytes[i * 4 + 2] = (v ushr 16 and 0xff).toByte()
bytes[i * 4 + 3] = (v ushr 24 and 0xff).toByte()
}
bytes
}
"I64", "INT64" -> {
val bytes = ByteArray(count * 8)
for (i in 0 until count) {
val v = (values.getOrNull(i) as? Number)?.toLong() ?: 0L
for (b in 0 until 8) {
bytes[i * 8 + b] = ((v ushr (b * 8)) and 0xff).toByte()
}
}
bytes
}
else -> throw IllegalArgumentException(
"External parameter materialization not yet implemented for dtype=$dtype. " +
"Supported: FP32, FP64, I32, I64."
)
}
}

/**
* Expected element count for a (possibly empty) shape. Empty shape
* (scalar) means one element; `null` / absent dims degrade to 0 so the
* caller can detect "no declared shape".
*/
internal fun elementCountFromShape(shape: List<Int>?): Int {
if (shape == null) return 0
if (shape.isEmpty()) return 1
return shape.fold(1) { acc, d -> acc * d }
}
Loading
Loading