Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,20 @@ public class ConstantOperationsConverter : StableHloOperationConverter {
* 1D [3]: dense<[v0, v1, v2]>
* 2D [2,3]: dense<[[v0,v1,v2],[v3,v4,v5]]>
* 4D [1,3,1,1]: dense<[[[[v0],[v1],[v2]]]]>
*
* Splat collapse: when every element is the same value and the input
* list fully covers the shape, emit the single-scalar splat form
* (`dense<v>` ≡ `dense<[[v, v, ...], ...]>` for any rank). This is the
* first-pass lever against the 151 MB MLIR-text blowup described in
* #519 — uninitialized VoidTensorOps-backed weights are uniform by
* construction and compress from O(N*M) characters down to one.
*/
private fun formatTensorValues(values: List<*>, outputSpec: TensorSpec?): String {
val shape = outputSpec?.shape ?: emptyList()
val expectedSize = if (shape.isEmpty()) values.size else shape.fold(1) { acc, d -> acc * d }
if (values.isNotEmpty() && values.size >= expectedSize && values.toSet().size == 1) {
return formatConstantValue(values[0] as Number)
}

return when {
values.isEmpty() -> "0.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,76 @@ class ConstantOperationsConverterTest {
assertTrue(module.content.contains("dense<2.0>"))
}

@Test
fun testTensorConstantWithUniformValuesEmitsSplat() {
// Regression for #519: a uniform-value list (common for
// VoidTensorOps-initialized weights) must collapse to the
// dense<v> splat form rather than expanding into an N-element
// nested array literal. Without the fix, a 10x10 zero tensor
// renders as 100 floats in text; with it, one scalar.
val graph = DefaultComputeGraph()
val inputOp = InputOperation<DType, Any>()
val inputNode = GraphNode(
id = "in",
operation = inputOp,
inputs = emptyList(),
outputs = listOf(TensorSpec("x", listOf(10, 10), "FP32"))
)
val zeros = List(100) { 0.0f }
val weightOp = createConstantOperation(
"tensor_constant",
mapOf("values" to zeros)
)
val weightNode = GraphNode(
id = "w",
operation = weightOp,
inputs = emptyList(),
outputs = listOf(TensorSpec("w", listOf(10, 10), "FP32"))
)
graph.addNode(inputNode)
graph.addNode(weightNode)

val fullConverter = StableHloConverterFactory.createExtended()
val module = fullConverter.convert(graph, "test_uniform_splat")

assertTrue(
module.content.contains("dense<0.0>"),
"uniform-zero tensor must collapse to splat form, got:\n${module.content}"
)
// No spelled-out array for the uniform constant.
assertTrue(
!module.content.contains("dense<[[0.0, 0.0"),
"uniform splat must not also emit a nested array literal"
)
}

@Test
fun testTensorConstantWithNonUniformValuesKeepsNestedLiteral() {
// Opposite direction: when values differ, we must still spell
// them out — splat is only for uniform lists, not a blanket
// compression.
val graph = DefaultComputeGraph()
val weightOp = createConstantOperation(
"tensor_constant",
mapOf("values" to listOf(1.0f, 2.0f, 3.0f, 4.0f))
)
val weightNode = GraphNode(
id = "w",
operation = weightOp,
inputs = emptyList(),
outputs = listOf(TensorSpec("w", listOf(2, 2), "FP32"))
)
graph.addNode(weightNode)

val fullConverter = StableHloConverterFactory.createExtended()
val module = fullConverter.convert(graph, "test_non_uniform")

assertTrue(
module.content.contains("[[1.0, 2.0], [3.0, 4.0]]"),
"non-uniform tensor must keep nested array literal, got:\n${module.content}"
)
}

// Helper methods to create test graphs

private fun createGraphWithInputAndConstant(): DefaultComputeGraph {
Expand Down
Loading