diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt index 61318dbb..8e3c1933 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt @@ -348,9 +348,20 @@ public class ConstantOperationsConverter : StableHloOperationConverter { * 1D [3]: dense<[v0, v1, v2]> * 2D [2,3]: dense<[[v0,v1,v2],[v3,v4,v5]]> * 4D [1,3,1,1]: dense<[[[[v0],[v1],[v2]]]]> + * + * Splat collapse: when every element is the same value and the input + * list fully covers the shape, emit the single-scalar splat form + * (`dense` ≡ `dense<[[v, v, ...], ...]>` for any rank). This is the + * first-pass lever against the 151 MB MLIR-text blowup described in + * #519 — uninitialized VoidTensorOps-backed weights are uniform by + * construction and compress from O(N*M) characters down to one. */ private fun formatTensorValues(values: List<*>, outputSpec: TensorSpec?): String { val shape = outputSpec?.shape ?: emptyList() + val expectedSize = if (shape.isEmpty()) values.size else shape.fold(1) { acc, d -> acc * d } + if (values.isNotEmpty() && values.size >= expectedSize && values.toSet().size == 1) { + return formatConstantValue(values[0] as Number) + } return when { values.isEmpty() -> "0.0" diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantOperationsConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantOperationsConverterTest.kt index 791da0d8..bb66876c 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantOperationsConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConstantOperationsConverterTest.kt @@ -59,6 +59,76 @@ class ConstantOperationsConverterTest { assertTrue(module.content.contains("dense<2.0>")) } + @Test + fun testTensorConstantWithUniformValuesEmitsSplat() { + // Regression for #519: a uniform-value list (common for + // VoidTensorOps-initialized weights) must collapse to the + // dense splat form rather than expanding into an N-element + // nested array literal. Without the fix, a 10x10 zero tensor + // renders as 100 floats in text; with it, one scalar. + val graph = DefaultComputeGraph() + val inputOp = InputOperation() + val inputNode = GraphNode( + id = "in", + operation = inputOp, + inputs = emptyList(), + outputs = listOf(TensorSpec("x", listOf(10, 10), "FP32")) + ) + val zeros = List(100) { 0.0f } + val weightOp = createConstantOperation( + "tensor_constant", + mapOf("values" to zeros) + ) + val weightNode = GraphNode( + id = "w", + operation = weightOp, + inputs = emptyList(), + outputs = listOf(TensorSpec("w", listOf(10, 10), "FP32")) + ) + graph.addNode(inputNode) + graph.addNode(weightNode) + + val fullConverter = StableHloConverterFactory.createExtended() + val module = fullConverter.convert(graph, "test_uniform_splat") + + assertTrue( + module.content.contains("dense<0.0>"), + "uniform-zero tensor must collapse to splat form, got:\n${module.content}" + ) + // No spelled-out array for the uniform constant. + assertTrue( + !module.content.contains("dense<[[0.0, 0.0"), + "uniform splat must not also emit a nested array literal" + ) + } + + @Test + fun testTensorConstantWithNonUniformValuesKeepsNestedLiteral() { + // Opposite direction: when values differ, we must still spell + // them out — splat is only for uniform lists, not a blanket + // compression. + val graph = DefaultComputeGraph() + val weightOp = createConstantOperation( + "tensor_constant", + mapOf("values" to listOf(1.0f, 2.0f, 3.0f, 4.0f)) + ) + val weightNode = GraphNode( + id = "w", + operation = weightOp, + inputs = emptyList(), + outputs = listOf(TensorSpec("w", listOf(2, 2), "FP32")) + ) + graph.addNode(weightNode) + + val fullConverter = StableHloConverterFactory.createExtended() + val module = fullConverter.convert(graph, "test_non_uniform") + + assertTrue( + module.content.contains("[[1.0, 2.0], [3.0, 4.0]]"), + "non-uniform tensor must keep nested array literal, got:\n${module.content}" + ) + } + // Helper methods to create test graphs private fun createGraphWithInputAndConstant(): DefaultComputeGraph {