diff --git a/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api b/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api index 9e45c197..c9ad1992 100644 --- a/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api +++ b/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api @@ -4,18 +4,63 @@ public final class sk/ainet/compile/hlo/ConstantFoldingPass : sk/ainet/compile/h public fun getName ()Ljava/lang/String; } +public abstract interface class sk/ainet/compile/hlo/ConstantMaterializationPolicy { +} + +public final class sk/ainet/compile/hlo/ConstantMaterializationPolicy$ExternalAlways : sk/ainet/compile/hlo/ConstantMaterializationPolicy { + public fun ()V + public fun (Ljava/lang/String;)V + public synthetic fun (Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Ljava/lang/String; + public final fun copy (Ljava/lang/String;)Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$ExternalAlways; + public static synthetic fun copy$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$ExternalAlways;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$ExternalAlways; + public fun equals (Ljava/lang/Object;)Z + public final fun getScope ()Ljava/lang/String; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/compile/hlo/ConstantMaterializationPolicy$InlineAlways : sk/ainet/compile/hlo/ConstantMaterializationPolicy { + public static final field INSTANCE Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$InlineAlways; + public fun equals (Ljava/lang/Object;)Z + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/compile/hlo/ConstantMaterializationPolicy$SizeThreshold : sk/ainet/compile/hlo/ConstantMaterializationPolicy { + public fun (JLjava/lang/String;)V + public synthetic fun (JLjava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()J + public final fun component2 ()Ljava/lang/String; + public final fun copy (JLjava/lang/String;)Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$SizeThreshold; + public static synthetic fun copy$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$SizeThreshold;JLjava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/ConstantMaterializationPolicy$SizeThreshold; + public fun equals (Ljava/lang/Object;)Z + public final fun getBytes ()J + public final fun getScope ()Ljava/lang/String; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + public final class sk/ainet/compile/hlo/ConversionContext { + public fun (Lsk/ainet/compile/hlo/TypeMapper;)V public fun (Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/lang/graph/ComputeGraph;)V - public synthetic fun (Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/lang/graph/ComputeGraph;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)V + public synthetic fun (Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun clear ()V public final fun emitComment (Ljava/lang/String;)V + public final fun emitEncodingAnnotation (Ljava/lang/String;ILsk/ainet/lang/tensor/ops/TensorSpec;)V public final fun emitLine (Ljava/lang/String;)V + public final fun emitModuleDeclaration (Ljava/lang/String;)V public final fun emitOperation (Ljava/lang/String;)V public final fun getContent ()Ljava/lang/String; + public final fun getExternalParameters ()Ljava/util/List; public final fun getInputNodes (Lsk/ainet/lang/graph/GraphNode;)Ljava/util/List; + public final fun getMaterializationPolicy ()Lsk/ainet/compile/hlo/ConstantMaterializationPolicy; + public final fun getModuleDeclarations ()Ljava/lang/String; public final fun getTypeMapper ()Lsk/ainet/compile/hlo/TypeMapper; public final fun getValueName (Ljava/lang/String;)Ljava/lang/String; public final fun nextTempValue ()Ljava/lang/String; + public final fun registerExternalParameter (Lsk/ainet/compile/hlo/ExternalParameterRef;)V public final fun setGraph (Lsk/ainet/lang/graph/ComputeGraph;)V public final fun setValueName (Ljava/lang/String;Ljava/lang/String;)V } @@ -77,6 +122,23 @@ public final class sk/ainet/compile/hlo/DeadCodeEliminationPass : sk/ainet/compi public fun getName ()Ljava/lang/String; } +public final class sk/ainet/compile/hlo/ExternalParameterRef { + public fun (Ljava/lang/String;Ljava/lang/String;Lsk/ainet/lang/tensor/storage/TensorEncoding;Lsk/ainet/lang/tensor/storage/BufferHandle;)V + public final fun component1 ()Ljava/lang/String; + public final fun component2 ()Ljava/lang/String; + public final fun component3 ()Lsk/ainet/lang/tensor/storage/TensorEncoding; + public final fun component4 ()Lsk/ainet/lang/tensor/storage/BufferHandle; + public final fun copy (Ljava/lang/String;Ljava/lang/String;Lsk/ainet/lang/tensor/storage/TensorEncoding;Lsk/ainet/lang/tensor/storage/BufferHandle;)Lsk/ainet/compile/hlo/ExternalParameterRef; + public static synthetic fun copy$default (Lsk/ainet/compile/hlo/ExternalParameterRef;Ljava/lang/String;Ljava/lang/String;Lsk/ainet/lang/tensor/storage/TensorEncoding;Lsk/ainet/lang/tensor/storage/BufferHandle;ILjava/lang/Object;)Lsk/ainet/compile/hlo/ExternalParameterRef; + public fun equals (Ljava/lang/Object;)Z + public final fun getEncoding ()Lsk/ainet/lang/tensor/storage/TensorEncoding; + public final fun getKey ()Ljava/lang/String; + public final fun getScope ()Ljava/lang/String; + public final fun getSource ()Lsk/ainet/lang/tensor/storage/BufferHandle; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + public final class sk/ainet/compile/hlo/MlirValidator { public fun ()V public final fun validate (Ljava/lang/String;)Ljava/util/List; @@ -108,8 +170,10 @@ public final class sk/ainet/compile/hlo/RegistryStats { } public final class sk/ainet/compile/hlo/StableHloConverter { + public fun (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;)V public fun (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;)V - public synthetic fun (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)V + public synthetic fun (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun convert (Lsk/ainet/lang/graph/ComputeGraph;Ljava/lang/String;)Lsk/ainet/compile/hlo/StableHloModule; public static synthetic fun convert$default (Lsk/ainet/compile/hlo/StableHloConverter;Lsk/ainet/lang/graph/ComputeGraph;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule; public final fun convertWithOptimization (Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/hlo/StableHloOptimizer;)Lsk/ainet/compile/hlo/StableHloModule; @@ -117,26 +181,37 @@ public final class sk/ainet/compile/hlo/StableHloConverter { public final class sk/ainet/compile/hlo/StableHloConverterFactory { public static final field INSTANCE Lsk/ainet/compile/hlo/StableHloConverterFactory; - public final fun createBasic ()Lsk/ainet/compile/hlo/StableHloConverter; - public final fun createCustom (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;)Lsk/ainet/compile/hlo/StableHloConverter; - public static synthetic fun createCustom$default (Lsk/ainet/compile/hlo/StableHloConverterFactory;Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter; - public final fun createExtended ()Lsk/ainet/compile/hlo/StableHloConverter; - public final fun createFast ()Lsk/ainet/compile/hlo/StableHloConverter; + public static final fun createBasic ()Lsk/ainet/compile/hlo/StableHloConverter; + public static final fun createBasic (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)Lsk/ainet/compile/hlo/StableHloConverter; + public static synthetic fun createBasic$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter; + public static final fun createCustom (Lsk/ainet/compile/hlo/StableHloOperationRegistry;)Lsk/ainet/compile/hlo/StableHloConverter; + public static final fun createCustom (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;)Lsk/ainet/compile/hlo/StableHloConverter; + public static final fun createCustom (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;)Lsk/ainet/compile/hlo/StableHloConverter; + public static final fun createCustom (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)Lsk/ainet/compile/hlo/StableHloConverter; + public static synthetic fun createCustom$default (Lsk/ainet/compile/hlo/StableHloOperationRegistry;Lsk/ainet/compile/hlo/TypeMapper;Lsk/ainet/compile/hlo/MlirValidator;Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter; + public static final fun createExtended ()Lsk/ainet/compile/hlo/StableHloConverter; + public static final fun createExtended (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)Lsk/ainet/compile/hlo/StableHloConverter; + public static synthetic fun createExtended$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter; + public static final fun createFast ()Lsk/ainet/compile/hlo/StableHloConverter; + public static final fun createFast (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;)Lsk/ainet/compile/hlo/StableHloConverter; + public static synthetic fun createFast$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter; } public final class sk/ainet/compile/hlo/StableHloModule { - public fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;)V - public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;Ljava/util/List;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;Ljava/util/List;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/lang/String; public final fun component2 ()Ljava/lang/String; public final fun component3 ()Ljava/util/List; public final fun component4 ()Ljava/util/List; public final fun component5 ()Ljava/util/Map; + public final fun component6 ()Ljava/util/List; public final fun contentLines ()Lkotlin/sequences/Sequence; - public final fun copy (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;)Lsk/ainet/compile/hlo/StableHloModule; - public static synthetic fun copy$default (Lsk/ainet/compile/hlo/StableHloModule;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule; + public final fun copy (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;Ljava/util/List;)Lsk/ainet/compile/hlo/StableHloModule; + public static synthetic fun copy$default (Lsk/ainet/compile/hlo/StableHloModule;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;Ljava/util/List;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule; public fun equals (Ljava/lang/Object;)Z public final fun getContent ()Ljava/lang/String; + public final fun getExternalParameters ()Ljava/util/List; public final fun getFunctionName ()Ljava/lang/String; public final fun getInputSpecs ()Ljava/util/List; public final fun getMetadata ()Ljava/util/Map; @@ -210,6 +285,12 @@ public final class sk/ainet/compile/hlo/converters/ConstantOperationsConverter : public fun getSupportedOperations ()Ljava/util/Set; } +public final class sk/ainet/compile/hlo/converters/GatherOperationsConverter : sk/ainet/compile/hlo/StableHloOperationConverter { + public fun ()V + public fun convert (Lsk/ainet/lang/graph/GraphNode;Ljava/util/List;Lsk/ainet/compile/hlo/ConversionContext;)Lsk/ainet/compile/hlo/ConversionResult; + public fun getSupportedOperations ()Ljava/util/Set; +} + public final class sk/ainet/compile/hlo/converters/LegacyOperationsConverter : sk/ainet/compile/hlo/StableHloOperationConverter { public fun ()V public fun convert (Lsk/ainet/lang/graph/GraphNode;Ljava/util/List;Lsk/ainet/compile/hlo/ConversionContext;)Lsk/ainet/compile/hlo/ConversionResult; diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConstantByteSerializer.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConstantByteSerializer.kt new file mode 100644 index 00000000..d21eb646 --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConstantByteSerializer.kt @@ -0,0 +1,89 @@ +package sk.ainet.compile.hlo + +/** + * Serialize a list of numeric values into a little-endian [ByteArray] + * matching the given SKaiNET dtype string. + * + * Used by the converter to materialize `node.operation.parameters["values"]` + * / `["initial_value"]` lists into bytes when policy is + * [ConstantMaterializationPolicy.ExternalAlways]. Pads with zeros when + * [values] is shorter than [expectedElements] so under-filled + * initializations emit a well-formed buffer of the declared shape. + * + * Kept in `commonMain` — uses [Float.toRawBits] / [Double.toRawBits] + * rather than JVM-only streams so the same code runs on all KMP + * targets the HLO module currently supports. + * + * Throws for dtypes the converter does not yet externalize. Callers + * should catch and fall back to inline emission with a diagnostic. + */ +internal fun numberListToLittleEndianBytes( + values: List<*>, + dtype: String, + expectedElements: Int +): ByteArray { + val count = expectedElements.coerceAtLeast(values.size) + val normalized = dtype.uppercase() + + return when (normalized) { + "FP32", "F32", "FLOAT32" -> { + val bytes = ByteArray(count * 4) + for (i in 0 until count) { + val v = (values.getOrNull(i) as? Number)?.toFloat() ?: 0.0f + val bits = v.toRawBits() + bytes[i * 4] = (bits and 0xff).toByte() + bytes[i * 4 + 1] = (bits ushr 8 and 0xff).toByte() + bytes[i * 4 + 2] = (bits ushr 16 and 0xff).toByte() + bytes[i * 4 + 3] = (bits ushr 24 and 0xff).toByte() + } + bytes + } + "FP64", "F64", "FLOAT64" -> { + val bytes = ByteArray(count * 8) + for (i in 0 until count) { + val v = (values.getOrNull(i) as? Number)?.toDouble() ?: 0.0 + val bits = v.toRawBits() + for (b in 0 until 8) { + bytes[i * 8 + b] = ((bits ushr (b * 8)) and 0xff).toByte() + } + } + bytes + } + "I32", "INT32" -> { + val bytes = ByteArray(count * 4) + for (i in 0 until count) { + val v = (values.getOrNull(i) as? Number)?.toInt() ?: 0 + bytes[i * 4] = (v and 0xff).toByte() + bytes[i * 4 + 1] = (v ushr 8 and 0xff).toByte() + bytes[i * 4 + 2] = (v ushr 16 and 0xff).toByte() + bytes[i * 4 + 3] = (v ushr 24 and 0xff).toByte() + } + bytes + } + "I64", "INT64" -> { + val bytes = ByteArray(count * 8) + for (i in 0 until count) { + val v = (values.getOrNull(i) as? Number)?.toLong() ?: 0L + for (b in 0 until 8) { + bytes[i * 8 + b] = ((v ushr (b * 8)) and 0xff).toByte() + } + } + bytes + } + else -> throw IllegalArgumentException( + "External parameter materialization not yet implemented for dtype=$dtype. " + + "Supported: FP32, FP64, I32, I64." + ) + } +} + +/** + * Expected element count for a (possibly empty) shape. Empty shape + * (scalar) means one element; `null` / absent dims degrade to 0 so the + * caller can detect "no declared shape". + */ +internal fun elementCountFromShape(shape: List?): Int { + if (shape == null) return 0 + if (shape.isEmpty()) return 1 + return shape.fold(1) { acc, d -> acc * d } +} diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConstantMaterialization.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConstantMaterialization.kt new file mode 100644 index 00000000..7d86e8a9 --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConstantMaterialization.kt @@ -0,0 +1,96 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.tensor.storage.BufferHandle +import sk.ainet.lang.tensor.storage.TensorEncoding + +/** + * Policy governing how a constant tensor (a weight, a bias, a frozen + * parameter) is materialized into the emitted StableHLO text. + * + * The intent is to decouple "what the converter sees" (a graph node + * with values) from "how those values reach the deployed runtime" + * (inline bytes vs external parameter archive). The seam exists so a + * caller can flip between modes without changing the graph, and so the + * converter does not grow a private weight-format. + * + * Introduced as the load-bearing decision point of the architecture + * tracked in issue #523. + */ +public sealed interface ConstantMaterializationPolicy { + + /** + * Every constant is written into the emitted text as + * `stablehlo.constant dense<...>`. Matches the historical behavior + * and is the default so existing callers are unaffected. + */ + public data object InlineAlways : ConstantMaterializationPolicy + + /** + * Every candidate constant — currently `tensor_constant`, + * `dense_constant`, `parameter`, `param`, `weight`, `bias` — is + * lifted out of the IR. The converter emits a `util.global` module + * declaration and a `util.global.load` reference, and records an + * [ExternalParameterRef] on the resulting module. A downstream + * packager (e.g. the `skainet-io-iree-params` module planned in PR + * C) turns the refs into a `.irpa` sidecar that + * `iree-compile --iree-opt-import-parameters=` resolves. + * + * If a candidate node has no accompanying byte source (no `values` + * / `initial_value` list and no external handle), the converter + * falls back to inline with a diagnostic comment — better to emit + * working IR than to reference bytes that do not exist. + * + * @property scope Namespace written into the emitted + * `util.global.load` reference (`@::@`) and into + * the [ExternalParameterRef]. "model" is the conventional + * default; callers may override per-module. + */ + public data class ExternalAlways(val scope: String = "model") : ConstantMaterializationPolicy + + /** + * Hybrid policy: small constants stay inline, large ones go + * external. The threshold is measured in **logical bytes** — + * `elementCount * bytesPerElement` computed from the output + * [sk.ainet.lang.tensor.ops.TensorSpec] — not the MLIR text size. + * This keeps the decision independent of downstream splat / dense + * formatting. + * + * @property bytes Minimum logical size (inclusive) at which a + * constant is externalized. + * @property scope Namespace for externalized constants (see + * [ExternalAlways.scope]). + */ + public data class SizeThreshold( + val bytes: Long, + val scope: String = "model" + ) : ConstantMaterializationPolicy +} + +/** + * Reference to a weight tensor that has been lifted out of the emitted + * StableHLO text and moved behind an `util.global.load` reference. The + * converter produces these; a downstream packager consumes them to + * write an IREE parameter archive (`.irpa`). + * + * The converter does not copy bytes — it passes the [source] handle + * through unchanged. Callers that back a handle with `mmap` (planned + * in PR E for skainet-io-gguf and skainet-io-safetensors) get a + * true zero-copy path all the way from the source file to the `.irpa`. + * + * @property scope Parameter-archive scope (`@::@` in the + * emitted MLIR, and the scope name inside the `.irpa` container). + * @property key Symbolic name; matches `TensorSpec.name` by convention + * so the archive is addressable by tensor identity. + * @property encoding Physical [TensorEncoding] (Dense / Q4_K / Q8_0 / + * TurboQuant / TernaryPacked / Opaque). Preserved so the packager + * can blit quantized blocks verbatim instead of re-quantizing. + * @property source [BufferHandle] backing the tensor bytes. May be + * an in-memory copy today; PR E replaces these with mmap windows + * into the source GGUF / safetensors file. + */ +public data class ExternalParameterRef( + val scope: String, + val key: String, + val encoding: TensorEncoding, + val source: BufferHandle +) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt index 925b3319..cafce501 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt @@ -11,12 +11,24 @@ import sk.ainet.lang.tensor.ops.tensorEncoding * This class manages SSA value names, type mapping, and MLIR code generation * during the conversion process from ComputeGraph to StableHLO. */ -public class ConversionContext( +public class ConversionContext @kotlin.jvm.JvmOverloads constructor( private val typeMapper: TypeMapper, - private var graph: ComputeGraph? = null + private var graph: ComputeGraph? = null, + /** + * Governs whether constant tensors are inlined as `dense<...>` or + * lifted into `util.global` module declarations. Default + * [ConstantMaterializationPolicy.InlineAlways] preserves historical + * behavior for every caller that constructs a context without + * naming a policy — the external path is strictly opt-in. + * See issue #523 for the architecture context. + */ + public val materializationPolicy: ConstantMaterializationPolicy = + ConstantMaterializationPolicy.InlineAlways ) { private val valueNames = mutableMapOf() private val stringBuilder = StringBuilder() + private val moduleDeclarationsBuilder = StringBuilder() + private val externalParams = mutableListOf() private var tempCounter = 0 /** @@ -57,6 +69,39 @@ public class ConversionContext( stringBuilder.appendLine(" // $comment") } + /** + * Emit a module-scope declaration (e.g. `util.global private @w : ...`). + * + * Module-scope lines sit between `module {` and the enclosing + * `func.func` in the final MLIR output. [StableHloConverter] + * buffers them separately so callers can emit them at any point + * during node processing without disturbing the function body. + */ + public fun emitModuleDeclaration(line: String) { + moduleDeclarationsBuilder.appendLine(" $line") + } + + /** + * Return every module-scope declaration emitted so far. Used by + * [StableHloConverter] when assembling the final content. + */ + public fun getModuleDeclarations(): String = moduleDeclarationsBuilder.toString() + + /** + * Register an externalized constant tensor. The converter records + * these alongside MLIR emission so a downstream packager (see PR C + * in issue #523) can write them into an IREE `.irpa` archive. + */ + public fun registerExternalParameter(ref: ExternalParameterRef) { + externalParams += ref + } + + /** + * Snapshot of every externalized constant registered during this + * conversion. Surfaced on [StableHloModule.externalParameters]. + */ + public fun getExternalParameters(): List = externalParams.toList() + /** * Emit a `tensor_encoding` diagnostic comment when [spec] carries a * non-null `tensorEncoding` (set via [sk.ainet.lang.tensor.ops.withTensorEncoding]). @@ -119,6 +164,8 @@ public class ConversionContext( public fun clear() { valueNames.clear() stringBuilder.clear() + moduleDeclarationsBuilder.clear() + externalParams.clear() tempCounter = 0 } } \ No newline at end of file diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt index 08c1e690..77c57c42 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt @@ -12,17 +12,25 @@ import sk.ainet.lang.tensor.storage.TensorEncoding * This class provides a modular architecture for converting computational graphs to StableHLO format, * using a registry-based system for operation mapping and a conversion context for state management. */ -public class StableHloConverter( +public class StableHloConverter @kotlin.jvm.JvmOverloads constructor( private val registry: StableHloOperationRegistry, private val typeMapper: TypeMapper, - private val validator: MlirValidator? = null + private val validator: MlirValidator? = null, + /** + * Governs how constant tensors are materialized into the emitted + * MLIR — inline `dense<...>` (default, historical) or lifted out + * behind `util.global.load` references (see issue #523). Handed + * through to every [ConversionContext] this converter creates. + */ + private val materializationPolicy: ConstantMaterializationPolicy = + ConstantMaterializationPolicy.InlineAlways ) { - + /** * Convert a ComputeGraph to StableHLO MLIR format */ public fun convert(graph: ComputeGraph, functionName: String = "main"): StableHloModule { - val context = ConversionContext(typeMapper, graph) + val context = ConversionContext(typeMapper, graph, materializationPolicy) // Pre-conversion validation (allow orphaned nodes for backward compatibility) val validationResult = graph.validate() @@ -55,34 +63,38 @@ public class StableHloConverter( // instead of string-matching against scattered comments. val tensorEncodings = collectTensorEncodings(topo) - // Start building MLIR content — promote to `module attributes` - // only when we have at least one encoded tensor. Dense graphs - // keep the bare `module {` header for byte-for-byte backward - // compatibility with existing round-trip tests. - if (tensorEncodings.isNotEmpty()) { + // Process nodes first, then assemble the final content. + // Converters populate two buffers on the context — op emissions + // for the function body and any module-scope declarations + // (e.g. `util.global` decls under an external-materialization + // policy). Deferring assembly lets us inject module-scope + // lines between `module {` and `func.func` without string + // surgery. + initializeInputValues(inputNodes, context) + processNodes(topo, context) + generateReturnStatement(outputNodes, context) + + val moduleHeader = if (tensorEncodings.isNotEmpty()) { val dictEntries = tensorEncodings.entries .sortedBy { it.key } .joinToString(", ") { (name, encoding) -> "$name = \"${encoding.name}\"" } - context.emitLine("module attributes {skainet.tensor_encodings = {$dictEntries}} {") + "module attributes {skainet.tensor_encodings = {$dictEntries}} {" } else { - context.emitLine("module {") + "module {" } - context.emitLine(" func.func $functionSignature {") - - // Initialize input values in context - initializeInputValues(inputNodes, context) - - // Process nodes in topological order - processNodes(topo, context) - - // Generate return statement with output values - generateReturnStatement(outputNodes, context) - - // Close function and module - context.emitLine(" }") - context.emitLine("}") - - val content = context.getContent() + + val assembled = StringBuilder() + assembled.appendLine(moduleHeader) + val moduleDecls = context.getModuleDeclarations() + if (moduleDecls.isNotEmpty()) { + assembled.append(moduleDecls) + } + assembled.appendLine(" func.func $functionSignature {") + assembled.append(context.getContent()) + assembled.appendLine(" }") + assembled.appendLine("}") + + val content = assembled.toString() // Optional validation of generated MLIR validator?.validate(content)?.let { errors -> @@ -95,7 +107,8 @@ public class StableHloConverter( content = content, functionName = functionName, inputSpecs = inputNodes.mapNotNull { it.outputs.firstOrNull() }, - outputSpecs = outputSpecs + outputSpecs = outputSpecs, + externalParameters = context.getExternalParameters() ) } diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt index cf6bd46f..00fb0eea 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt @@ -21,9 +21,16 @@ public object StableHloConverterFactory { /** * Create a converter with basic operations support (add, matmul, relu) + * + * @param policy Controls inline vs external constant materialization. + * Defaults to [ConstantMaterializationPolicy.InlineAlways] for + * backward compatibility; see issue #523. */ @JvmStatic - public fun createBasic(): StableHloConverter { + @kotlin.jvm.JvmOverloads + public fun createBasic( + policy: ConstantMaterializationPolicy = ConstantMaterializationPolicy.InlineAlways + ): StableHloConverter { val registry = StableHloOperationRegistry() val typeMapper = TypeMapper() val validator = MlirValidator() @@ -53,14 +60,19 @@ public object StableHloConverterFactory { // LLM front-door op for token-id \u2192 embedding lookups. registry.register(GatherOperationsConverter()) - return StableHloConverter(registry, typeMapper, validator) + return StableHloConverter(registry, typeMapper, validator, policy) } /** * Create a converter with extended operations support + * + * @param policy See [createBasic]. */ @JvmStatic - public fun createExtended(): StableHloConverter { + @kotlin.jvm.JvmOverloads + public fun createExtended( + policy: ConstantMaterializationPolicy = ConstantMaterializationPolicy.InlineAlways + ): StableHloConverter { val registry = StableHloOperationRegistry() val typeMapper = TypeMapper() val validator = MlirValidator() @@ -93,17 +105,22 @@ public object StableHloConverterFactory { // LLM front-door op for token-id \u2192 embedding lookups. registry.register(GatherOperationsConverter()) - return StableHloConverter(registry, typeMapper, validator) + return StableHloConverter(registry, typeMapper, validator, policy) } /** * Create a converter without validation (for performance) + * + * @param policy See [createBasic]. */ @JvmStatic - public fun createFast(): StableHloConverter { + @kotlin.jvm.JvmOverloads + public fun createFast( + policy: ConstantMaterializationPolicy = ConstantMaterializationPolicy.InlineAlways + ): StableHloConverter { val registry = StableHloOperationRegistry() val typeMapper = TypeMapper() - + registry.register(LegacyOperationsConverter()) registry.register(MathOperationsConverter()) registry.register(LinalgOperationsConverter()) @@ -113,19 +130,22 @@ public object StableHloConverterFactory { registry.register(ReductionOperationsConverter()) registry.register(ConstantOperationsConverter()) - return StableHloConverter(registry, typeMapper, null) + return StableHloConverter(registry, typeMapper, null, policy) } - + /** * Create a custom converter with the provided components + * + * @param policy See [createBasic]. */ @JvmStatic @kotlin.jvm.JvmOverloads public fun createCustom( registry: StableHloOperationRegistry, typeMapper: TypeMapper = TypeMapper(), - validator: MlirValidator? = MlirValidator() + validator: MlirValidator? = MlirValidator(), + policy: ConstantMaterializationPolicy = ConstantMaterializationPolicy.InlineAlways ): StableHloConverter { - return StableHloConverter(registry, typeMapper, validator) + return StableHloConverter(registry, typeMapper, validator, policy) } } \ No newline at end of file diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt index 61318dbb..9c17441c 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt @@ -1,10 +1,17 @@ package sk.ainet.compile.hlo.converters +import sk.ainet.compile.hlo.ConstantMaterializationPolicy import sk.ainet.compile.hlo.ConversionContext import sk.ainet.compile.hlo.ConversionResult +import sk.ainet.compile.hlo.ExternalParameterRef import sk.ainet.compile.hlo.StableHloOperationConverter +import sk.ainet.compile.hlo.elementCountFromShape +import sk.ainet.compile.hlo.numberListToLittleEndianBytes import sk.ainet.lang.graph.GraphNode import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.tensor.ops.tensorEncoding +import sk.ainet.lang.tensor.storage.BufferHandle +import sk.ainet.lang.tensor.storage.TensorEncoding /** * Converter for constant value operations in StableHLO. @@ -142,22 +149,28 @@ public class ConstantOperationsConverter : StableHloOperationConverter { "Missing 'values' or 'data' parameter for tensor constant", "No data specified for tensor constant ${node.id}" ) - + val outputSpec = node.outputs.firstOrNull() - val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor" - + + // Policy seam (issue #523): if caller has opted into external + // materialization, lift this constant into a util.global behind + // a util.global.load reference and record an + // ExternalParameterRef for the downstream .irpa packager. + tryMaterializeExternal(node, outputSpec, outputType, values, context)?.let { return it } + val resultValue = context.nextTempValue() val formattedValues = formatTensorValues(values, outputSpec) val operation = "$resultValue = stablehlo.constant dense<$formattedValues> : $outputType" context.emitOperation(operation) - + return ConversionResult.Success( outputValueName = resultValue, emittedOperations = listOf(operation) ) } - + /** * Convert a splat constant (single value broadcasted to tensor shape) */ @@ -207,15 +220,23 @@ public class ConstantOperationsConverter : StableHloOperationConverter { val isTrainable = node.operation.parameters["trainable"] as? Boolean ?: true val outputSpec = node.outputs.firstOrNull() - val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor" - - val resultValue = context.nextTempValue() - + // Add comment about parameter nature val paramType = if (isTrainable) "trainable parameter" else "frozen parameter" context.emitComment("${node.operation.name} ${node.id}: $paramType") - + + // Policy seam — same shape as convertTensorConstant. Only a + // List-valued initial_value can be externalized with bytes + // available today; Number (splat) and missing cases fall + // through to the inline path intentionally. + if (initialValue is List<*>) { + tryMaterializeExternal(node, outputSpec, outputType, initialValue, context)?.let { return it } + } + + val resultValue = context.nextTempValue() + val operation = when { initialValue is Number -> { "$resultValue = stablehlo.constant dense<${formatConstantValue(initialValue)}> : $outputType" @@ -230,7 +251,7 @@ public class ConstantOperationsConverter : StableHloOperationConverter { "$resultValue = stablehlo.constant dense<0.0> : $outputType" } } - + context.emitOperation(operation) return ConversionResult.Success( @@ -341,6 +362,102 @@ public class ConstantOperationsConverter : StableHloOperationConverter { } } + /** + * Policy-driven externalization seam (issue #523). + * + * Returns a [ConversionResult.Success] when the caller's + * [ConstantMaterializationPolicy] dictates that this constant + * should live in an IREE parameter archive instead of inline in + * the emitted MLIR. Returns `null` for the inline path — caller + * falls through to its existing `dense<...>` emission. + * + * When externalized, three things happen: + * 1. Bytes are serialized from [values] into a [BufferHandle.Owned]. + * 2. An [ExternalParameterRef] is recorded on the context so a + * downstream packager can write the `.irpa`. + * 3. `util.global private @ : ` is emitted at module + * scope and `%r = util.global.load @ : ` is emitted + * in the function body. Both use the tensor's declared name + * (or a synthetic fallback if absent) as the key. + * + * Falls back to inline (returns `null`) if the dtype is not yet + * serializable or the spec lacks a shape — better to emit working + * IR than to emit an external reference pointing at no bytes. + */ + private fun tryMaterializeExternal( + node: GraphNode, + outputSpec: TensorSpec?, + outputType: String, + values: List<*>, + context: ConversionContext + ): ConversionResult? { + val policy = context.materializationPolicy + if (policy is ConstantMaterializationPolicy.InlineAlways) return null + if (outputSpec == null) return null + + val encoding = outputSpec.tensorEncoding + ?: TensorEncoding.Dense(bytesPerElement = bytesPerElement(outputSpec.dtype)) + val elementCount = elementCountFromShape(outputSpec.shape) + if (elementCount <= 0) return null + + val logicalBytes = encoding.physicalBytes(elementCount.toLong()) ?: return null + val scope = when (policy) { + is ConstantMaterializationPolicy.InlineAlways -> return null + is ConstantMaterializationPolicy.ExternalAlways -> policy.scope + is ConstantMaterializationPolicy.SizeThreshold -> { + if (logicalBytes < policy.bytes) return null + policy.scope + } + } + + // Serialize now. Fall back to inline on unsupported dtype — + // a loud exception here would defeat the "default path is + // safe" invariant of the seam. + val bytes = try { + numberListToLittleEndianBytes(values, outputSpec.dtype, elementCount) + } catch (e: IllegalArgumentException) { + context.emitComment( + "external materialization fell back to inline for ${node.id}: ${e.message}" + ) + return null + } + + val key = outputSpec.name.ifEmpty { node.id } + context.registerExternalParameter( + ExternalParameterRef( + scope = scope, + key = key, + encoding = encoding, + source = BufferHandle.Owned(bytes) + ) + ) + context.emitModuleDeclaration("util.global private @${key} : $outputType") + + val resultValue = context.nextTempValue() + val operation = "$resultValue = util.global.load @${key} : $outputType" + context.emitOperation(operation) + + return ConversionResult.Success( + outputValueName = resultValue, + emittedOperations = listOf(operation) + ) + } + + /** + * Rough bytes-per-element for the default [TensorEncoding.Dense] + * fallback when a spec does not carry an explicit encoding. + * Narrowly scoped — intentionally does not handle the packed + * quantization encodings, which must be carried on the spec + * itself. + */ + private fun bytesPerElement(dtype: String): Int = when (dtype.uppercase()) { + "FP32", "F32", "FLOAT32", "I32", "INT32", "UI32", "UINT32" -> 4 + "FP64", "F64", "FLOAT64", "I64", "INT64", "UI64", "UINT64" -> 8 + "FP16", "F16", "FLOAT16", "BF16", "BFLOAT16", "I16", "INT16", "UI16", "UINT16" -> 2 + "I8", "INT8", "UI8", "UINT8", "BOOL", "BOOLEAN" -> 1 + else -> 4 + } + /** * Format tensor values for MLIR dense constant. * MLIR dense<> syntax requires nested brackets matching the tensor rank: diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/dag2hlo.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/dag2hlo.kt index c9b313bf..406165e2 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/dag2hlo.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/dag2hlo.kt @@ -11,7 +11,14 @@ public data class StableHloModule( val functionName: String = "main", val inputSpecs: List = emptyList(), val outputSpecs: List = emptyList(), - val metadata: Map = emptyMap() + val metadata: Map = emptyMap(), + /** + * Constants that were lifted out of [content] as `util.global` + * references under [ConstantMaterializationPolicy.ExternalAlways]. + * A downstream packager writes these into an IREE `.irpa`; see + * issue #523. Empty under the default inline policy. + */ + val externalParameters: List = emptyList() ) { /** * Validate this module using the provided validator diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantMaterializationPolicyTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantMaterializationPolicyTest.kt new file mode 100644 index 00000000..8057f867 --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantMaterializationPolicyTest.kt @@ -0,0 +1,198 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.ops.InputOperation +import sk.ainet.lang.tensor.ops.Operation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.tensor.ops.ValidationResult +import sk.ainet.lang.types.DType +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * End-to-end tests for the [ConstantMaterializationPolicy] seam + * introduced for issue #523. Covers the three policies through the + * full [StableHloConverterFactory.createBasic] pipeline so that every + * handoff (converter → context → module → output) is exercised. + */ +class ConstantMaterializationPolicyTest { + + @Test + fun testDefaultPolicyIsInlineAlways() { + // Callers that construct a converter without naming a policy + // must see the historical inline emission — the seam only + // activates when explicitly opted into. + val module = StableHloConverterFactory.createBasic() + .convert(buildTensorConstantGraph(), "default_policy") + + assertTrue(module.content.contains("stablehlo.constant")) + assertTrue(module.externalParameters.isEmpty()) + assertFalse(module.content.contains("util.global")) + } + + @Test + fun testExternalAlwaysEmitsUtilGlobalAndRegistersRef() { + // With ExternalAlways the converter must: + // - emit `util.global private @ : ` at module scope, + // - emit `util.global.load @` instead of inline dense<>, + // - register a matching ExternalParameterRef on the module. + val module = StableHloConverterFactory.createBasic( + ConstantMaterializationPolicy.ExternalAlways() + ).convert(buildTensorConstantGraph(), "external_policy") + + assertTrue( + module.content.contains("util.global private @weights : tensor<2x2xf32>"), + "module decl missing:\n${module.content}" + ) + assertTrue( + module.content.contains("util.global.load @weights : tensor<2x2xf32>"), + "util.global.load missing:\n${module.content}" + ) + // No inline dense<> for the externalized constant — only the + // `module attributes` clause on the header is permitted to + // mention the tensor name. Nothing in the function body + // should spell values out. + assertFalse( + module.content.contains("stablehlo.constant dense<"), + "externalized tensor leaked an inline stablehlo.constant:\n${module.content}" + ) + + assertEquals(1, module.externalParameters.size) + val ref = module.externalParameters.single() + assertEquals("model", ref.scope) + assertEquals("weights", ref.key) + // 4 f32 elements = 16 bytes + assertEquals(16L, ref.source.sizeInBytes) + } + + @Test + fun testSizeThresholdSplitsBySize() { + // A 2x2 f32 tensor is 16 bytes; threshold of 32 bytes must + // keep it inline. Raise the stakes by also including a 4x4 + // f32 tensor (64 bytes) which should be externalized. + val graph = DefaultComputeGraph() + graph.addNode( + GraphNode( + id = "small", + operation = mockConstantOp( + "tensor_constant", + mapOf("values" to listOf(1.0f, 2.0f, 3.0f, 4.0f)) + ), + inputs = emptyList(), + outputs = listOf(TensorSpec("small_w", listOf(2, 2), "FP32")) + ) + ) + graph.addNode( + GraphNode( + id = "large", + operation = mockConstantOp( + "tensor_constant", + mapOf("values" to List(16) { 0.0f }) + ), + inputs = emptyList(), + outputs = listOf(TensorSpec("large_w", listOf(4, 4), "FP32")) + ) + ) + + val module = StableHloConverterFactory.createBasic( + ConstantMaterializationPolicy.SizeThreshold(bytes = 32L) + ).convert(graph, "threshold") + + // Small stays inline — `stablehlo.constant dense<...>` shows up + // for the 2x2 tensor. + assertTrue( + module.content.contains("stablehlo.constant dense<"), + "small tensor must remain inline under SizeThreshold:\n${module.content}" + ) + // Large is externalized. + assertTrue( + module.content.contains("util.global private @large_w : tensor<4x4xf32>"), + "large tensor must externalize under SizeThreshold:\n${module.content}" + ) + + // Exactly one ref — the large one. + assertEquals(1, module.externalParameters.size) + assertEquals("large_w", module.externalParameters.single().key) + } + + @Test + fun testModuleAttrsHeaderStillEmittedAboveUtilGlobal() { + // When a tensor carries a tensorEncoding we already emit a + // `module attributes { skainet.tensor_encodings = {...} } {` + // header. The new util.global decls must slot in AFTER that + // header, before func.func — otherwise IREE's parser chokes. + // This test is aspirational for now: we only assert both + // lines exist in the expected relative order. + val graph = DefaultComputeGraph() + graph.addNode( + GraphNode( + id = "w", + operation = mockConstantOp( + "tensor_constant", + mapOf("values" to listOf(1.0f, 2.0f, 3.0f, 4.0f)) + ), + inputs = emptyList(), + outputs = listOf(TensorSpec("wkey", listOf(2, 2), "FP32")) + ) + ) + val module = StableHloConverterFactory.createBasic( + ConstantMaterializationPolicy.ExternalAlways() + ).convert(graph, "ordering") + + val headerIdx = module.content.indexOf("module {") + val utilIdx = module.content.indexOf("util.global private") + val funcIdx = module.content.indexOf("func.func") + assertTrue(headerIdx >= 0, "module { missing") + assertTrue(utilIdx > headerIdx, "util.global must follow module header") + assertTrue(funcIdx > utilIdx, "func.func must follow util.global decls") + } + + private fun buildTensorConstantGraph(): DefaultComputeGraph { + val graph = DefaultComputeGraph() + graph.addNode( + GraphNode( + id = "input", + operation = InputOperation(), + inputs = emptyList(), + outputs = listOf(TensorSpec("x", listOf(2, 2), "FP32")) + ) + ) + graph.addNode( + GraphNode( + id = "w", + operation = mockConstantOp( + "tensor_constant", + mapOf("values" to listOf(1.0f, 2.0f, 3.0f, 4.0f)) + ), + inputs = emptyList(), + outputs = listOf(TensorSpec("weights", listOf(2, 2), "FP32")) + ) + ) + return graph + } + + private fun mockConstantOp(name: String, parameters: Map): Operation = + object : Operation { + override val name: String = name + override val type: String = "constant" + override val parameters: Map = parameters + + override fun execute(inputs: List>): List> { + throw UnsupportedOperationException("mock") + } + + override fun validateInputs(inputs: List): ValidationResult = + ValidationResult.Valid + + override fun inferOutputs(inputs: List): List = emptyList() + + override fun clone(newParameters: Map): Operation = this + + override fun serialize(): Map = + mapOf("name" to name, "type" to type, "parameters" to parameters) + } +}