From 34a474ba7942874374226e467bf2c05e9e9e4125 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 13 Apr 2026 12:11:31 +0200 Subject: [PATCH 1/2] Add TensorEncoding accessor for TensorSpec.metadata (#469 step 1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces three additive helpers under sk.ainet.lang.tensor.ops: - `TENSOR_ENCODING_METADATA_KEY` — the shared metadata key so raw map callers agree with the typed accessors. - `TensorSpec.tensorEncoding: TensorEncoding?` — extension getter that reads the encoding stored in metadata, or `null` when the producer did not populate it. A `null` return is intentionally distinct from `TensorEncoding.Dense`. - `TensorSpec.withTensorEncoding(TensorEncoding?)` — returns a copy with the encoding set (or removed for `null`), preserving all other metadata entries untouched. - `TensorData<*, *>.inferTensorEncoding()` — single source of truth mapping concrete `TensorData` subclasses to their `TensorEncoding`. Today that collapses to `PackedBlockStorage` (Q4_K, Q8_0, TernaryPacked, TurboQuant) which already exposes its own `encoding` field, so the helper is one line but centralizes the contract for future non-packed layouts. Unit tests cover unset reads, round-trips for Q8_0, Q4_K, TernaryPacked, and Dense, clearing via `null`, preservation of unrelated metadata, and overwrite semantics. No TraceToGraphBuilder / loader plumbing in this step — that is issue #469 step 2, intentionally scoped to its own commit so the data-model change lands in isolation. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../lang/tensor/ops/TensorSpecEncoding.kt | 54 ++++++++++++ .../lang/tensor/ops/TensorSpecEncodingTest.kt | 87 +++++++++++++++++++ 2 files changed, 141 insertions(+) create mode 100644 skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncoding.kt create mode 100644 skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncodingTest.kt diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncoding.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncoding.kt new file mode 100644 index 00000000..9cc4e620 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncoding.kt @@ -0,0 +1,54 @@ +package sk.ainet.lang.tensor.ops + +import sk.ainet.lang.tensor.data.TensorData +import sk.ainet.lang.tensor.storage.PackedBlockStorage +import sk.ainet.lang.tensor.storage.TensorEncoding + +/** + * Metadata key used to carry a [TensorEncoding] on a [TensorSpec]. + * + * Exposed so that callers that need to read/write the raw metadata map + * directly (for interop, serialization round-trips, etc.) use the same + * string the typed accessors below use. + */ +public const val TENSOR_ENCODING_METADATA_KEY: String = "tensorEncoding" + +/** + * Physical storage encoding carried on this spec, or `null` if the producer + * did not populate it. + * + * A `null` return means "unknown / not carried through the graph" — it is + * NOT equivalent to [TensorEncoding.Dense]. Consumers that need a concrete + * encoding should treat `null` as unknown and fall back to dtype-driven + * defaults rather than assuming dense. + */ +public val TensorSpec.tensorEncoding: TensorEncoding? + get() = metadata[TENSOR_ENCODING_METADATA_KEY] as? TensorEncoding + +/** + * Return a copy of this spec with [encoding] stored in its metadata map. 
+ * Passing `null` removes the entry; passing a non-null value adds or + * replaces it, leaving all other metadata untouched. + */ +public fun TensorSpec.withTensorEncoding(encoding: TensorEncoding?): TensorSpec { + val newMetadata: Map<String, Any> = if (encoding == null) { + metadata - TENSOR_ENCODING_METADATA_KEY + } else { + metadata + (TENSOR_ENCODING_METADATA_KEY to encoding) + } + return copy(metadata = newMetadata) +} + +/** + * Infer a [TensorEncoding] from a concrete [TensorData] instance, or return + * `null` when the layout is dense / unknown. Single source of truth for the + * data-subclass → encoding mapping so trace builders and loaders agree. + * + * Any [TensorData] implementing [PackedBlockStorage] already exposes its + * own `encoding`, so this helper is one line today but centralizes the + * contract for future non-packed quantized layouts. + */ +public fun TensorData<*, *>.inferTensorEncoding(): TensorEncoding? = when (this) { + is PackedBlockStorage -> this.encoding + else -> null +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncodingTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncodingTest.kt new file mode 100644 index 00000000..bd4022d0 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncodingTest.kt @@ -0,0 +1,87 @@ +package sk.ainet.lang.tensor.ops + +import sk.ainet.lang.tensor.storage.TensorEncoding + +import kotlin.test.Test + +import kotlin.test.assertEquals + +import kotlin.test.assertNull + +import kotlin.test.assertSame + +import kotlin.test.assertTrue + +class TensorSpecEncodingTest { + + @Test + fun unset_encoding_reads_as_null() { + val spec = TensorSpec(name = "x", shape = listOf(2, 3), dtype = "FP32") + assertNull(spec.tensorEncoding) + } + + @Test + fun withTensorEncoding_round_trips_Q8_0() { + val spec = TensorSpec(name = "w", shape = listOf(32), dtype = "FP32") + val annotated = 
spec.withTensorEncoding(TensorEncoding.Q8_0) + + assertSame(TensorEncoding.Q8_0, annotated.tensorEncoding) + // Original spec is untouched — TensorSpec is a data class and the + // helper returns a copy. + assertNull(spec.tensorEncoding) + } + + @Test + fun withTensorEncoding_round_trips_Q4_K() { + val spec = TensorSpec(name = "w", shape = listOf(256), dtype = "FP32") + val annotated = spec.withTensorEncoding(TensorEncoding.Q4_K) + assertSame(TensorEncoding.Q4_K, annotated.tensorEncoding) + } + + @Test + fun withTensorEncoding_round_trips_TernaryPacked() { + val spec = TensorSpec(name = "w", shape = listOf(128), dtype = "FP32") + val annotated = spec.withTensorEncoding(TensorEncoding.TernaryPacked) + assertSame(TensorEncoding.TernaryPacked, annotated.tensorEncoding) + } + + @Test + fun withTensorEncoding_round_trips_Dense() { + val spec = TensorSpec(name = "x", shape = listOf(4), dtype = "FP32") + val dense = TensorEncoding.Dense(bytesPerElement = 4) + val annotated = spec.withTensorEncoding(dense) + assertEquals(dense, annotated.tensorEncoding) + } + + @Test + fun passing_null_removes_the_encoding_entry() { + val spec = TensorSpec(name = "w", shape = listOf(32), dtype = "FP32") + .withTensorEncoding(TensorEncoding.Q8_0) + assertSame(TensorEncoding.Q8_0, spec.tensorEncoding) + + val cleared = spec.withTensorEncoding(null) + assertNull(cleared.tensorEncoding) + assertTrue( + !cleared.metadata.containsKey(TENSOR_ENCODING_METADATA_KEY), + "clearing should remove the metadata key entirely, not leave a null" + ) + } + + @Test + fun withTensorEncoding_preserves_other_metadata() { + val spec = TensorSpec( + name = "w", + shape = listOf(32), + dtype = "FP32", + metadata = mapOf("owner" to "attention.q_proj", "frozen" to true) + ) + val annotated = spec.withTensorEncoding(TensorEncoding.Q8_0) + + assertEquals("attention.q_proj", annotated.metadata["owner"]) + assertEquals(true, annotated.metadata["frozen"]) + assertSame(TensorEncoding.Q8_0, annotated.tensorEncoding) + } + + 
@Test + fun replacing_encoding_overwrites_previous_value() { + val spec = TensorSpec(name = "w", shape = listOf(32), dtype = "FP32") + .withTensorEncoding(TensorEncoding.Q8_0) + .withTensorEncoding(TensorEncoding.Q4_K) + assertSame(TensorEncoding.Q4_K, spec.tensorEncoding) + } +} From e9ddce67b53131f6a4f31ccd3a1444c501498c60 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 13 Apr 2026 12:22:45 +0200 Subject: [PATCH 2/2] Propagate TensorEncoding through TraceToGraphBuilder.finalize (#469 step 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When finalize() resolves an unresolved tensor via the session, also derive its TensorEncoding via TensorData.inferTensorEncoding() and attach it to the produced node's output spec and to every outgoing edge's tensor spec. Applies symmetrically to both the weight-node branch (when FloatArrayTensorData is extractable) and the input-placeholder branch (when it isn't, which is exactly what used to happen to Q4_K / Q8_0 / Ternary weights). Net effect: a session-resolved weight backed by Q8_0 data now reaches downstream compile stages with `spec.tensorEncoding == TensorEncoding.Q8_0` instead of being silently downgraded to a lossy FP32 placeholder. Existing skainet-compile-dag tests stay green — the change is additive and dense / FloatArray paths see no behavior difference (inferTensorEncoding returns null, withTensorEncoding(null) is a no-op). 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../ainet/lang/trace/TraceToGraphBuilder.kt | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt index b31990a1..55a5a304 100644 --- a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt +++ b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt @@ -6,6 +6,8 @@ import sk.ainet.lang.graph.GraphNode import sk.ainet.lang.tensor.ops.Operation import sk.ainet.lang.tensor.ops.TensorSpec import sk.ainet.lang.tensor.ops.ValidationResult +import sk.ainet.lang.tensor.ops.inferTensorEncoding +import sk.ainet.lang.tensor.ops.withTensorEncoding /** * Shared builder to convert OpTrace streams into a ComputeGraph. @@ -258,8 +260,14 @@ public class TraceToGraphBuilder( // Try to resolve as a constant from the session val tensor = if (!forceInput && embedConstants) session?.resolve(firstRef.tensorRef) else null val constantValues = tensor?.let { extractFloatArray(it) } + // Resolved tensors that carry a concrete storage encoding (Q4_K, + // Q8_0, TernaryPacked, TurboQuant, …) propagate it onto the + // produced spec so later compile stages can preserve the + // quantization instead of silently re-materializing FP32. 
+ val encoding = tensor?.data?.inferTensorEncoding() val syntheticNode: GraphNode + val producedSpec: TensorSpec if (constantValues != null) { // Create a constant/weight node with embedded values val weightShape = tensor!!.shape.dimensions.toList() @@ -273,15 +281,16 @@ public class TraceToGraphBuilder( "trainable" to false ) ) + producedSpec = TensorSpec( + name = tensorId, + shape = weightShape, + dtype = weightDtype + ).withTensorEncoding(encoding) syntheticNode = GraphNode( id = nodeId, operation = op, inputs = emptyList(), - outputs = listOf(TensorSpec( - name = tensorId, - shape = weightShape, - dtype = weightDtype - )) + outputs = listOf(producedSpec) ) } else { // Create an input placeholder node @@ -291,17 +300,19 @@ public class TraceToGraphBuilder( type = "input", parameters = emptyMap() ) + producedSpec = spec.withTensorEncoding(encoding) syntheticNode = GraphNode( id = nodeId, operation = op, inputs = emptyList(), - outputs = listOf(spec) + outputs = listOf(producedSpec) ) } graph.addNode(syntheticNode) - // Wire edges to all consumers + // Wire edges to all consumers, propagating the encoding on the + // edge tensor spec so every consumer sees the quantization hint. for (ref in refs) { graph.addEdge( GraphEdge( @@ -310,13 +321,13 @@ public class TraceToGraphBuilder( destination = ref.consumerNode, sourceOutputIndex = 0, destinationInputIndex = ref.inputIndex, - tensorSpec = ref.spec + tensorSpec = ref.spec.withTensorEncoding(encoding) ) ) } // Register as producer - producersByTensorId[tensorId] = Producer(syntheticNode, 0, spec) + producersByTensorId[tensorId] = Producer(syntheticNode, 0, producedSpec) } unresolvedByTensorId.clear() }