@@ -86,8 +86,14 @@ public class ActivationOperationsConverter : StableHloOperationConverter {
     }
 
     /**
-     * Convert softmax activation using stablehlo.reduce and stablehlo.broadcast_in_dim.
+     * Convert softmax activation using real reductions and broadcast_in_dim.
      * softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
+     *
+     * The max and sum terms are lowered to stablehlo.custom_call @reduce_max
+     * and @reduce_sum — matching ReductionOperationsConverter's style — then
+     * broadcast back to the input shape before subtract / divide. This replaces
+     * an earlier lowering that used `dense<0.0>` / `dense<1.0>` placeholder
+     * constants (see #467) and produced numerically wrong MLIR.
      */
     private fun convertSoftmax(
         node: GraphNode,
@@ -100,49 +106,77 @@ public class ActivationOperationsConverter : StableHloOperationConverter {
                 "Unsupported softmax arity for node ${node.id}"
             )
         }
 
         val outputSpec = node.outputs.firstOrNull()
-        val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
+        val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
             ?: "tensor<?xf32>"
-        val elementType = outputSpec?.let { context.getTypeMapper().mapDType(it.dtype) }
+        val elementType = outputSpec?.let { context.getTypeMapper().mapDType(it.dtype) }
             ?: "f32"
 
-        // Get axis parameter (default to last dimension)
-        val axis = node.operation.parameters["axis"] as? Int ?: -1
-        val actualAxis = if (axis < 0) {
-            // For simplicity, assume last dimension (1 for 2D tensor)
-            1
+        val inputShape = node.inputs.firstOrNull()?.shape ?: outputSpec?.shape ?: emptyList()
+        val rank = inputShape.size
+
+        // Normalize axis against rank. Default to the last dimension.
+        val rawAxis = node.operation.parameters["axis"] as? Int ?: -1
+        val axis = when {
+            rank == 0 -> 0
+            rawAxis < 0 -> rank + rawAxis
+            else -> rawAxis
+        }.coerceIn(0, (rank - 1).coerceAtLeast(0))
+
+        // Reduced tensor type: input shape with `axis` dimension removed.
+        val reducedShape = if (rank > 0) {
+            inputShape.filterIndexed { i, _ -> i != axis }
         } else {
-            axis
+            emptyList()
         }
 
-        // For a simplified softmax implementation, we'll use element-wise operations
-        // This is a basic implementation that works along the last dimension
+        val reducedType = if (reducedShape.isEmpty()) {
+            "tensor<$elementType>"
+        } else {
+            "tensor<${reducedShape.joinToString("x")}x$elementType>"
+        }
+
+        // Dimensions kept for broadcast_in_dim: every input dim except `axis`,
+        // mapped to its position in the reduced tensor.
+        val broadcastDims = (0 until rank).filter { it != axis }.joinToString(", ")
 
         val maxValue = context.nextTempValue()
+        val maxBroadcast = context.nextTempValue()
         val shiftedValue = context.nextTempValue()
         val expValue = context.nextTempValue()
         val sumValue = context.nextTempValue()
+        val sumBroadcast = context.nextTempValue()
         val resultValue = context.nextTempValue()
 
         val operations = listOf(
-            // Find maximum (simplified - using a constant for now)
-            "$maxValue = stablehlo.constant dense<0.0> : $outputType",
-
-            // Subtract max for numerical stability (simplified)
-            "$shiftedValue = stablehlo.subtract ${operands[0]}, $maxValue : $outputType",
-
-            // Apply exponential
+            // Reduce-max along the softmax axis (for numerical stability).
+            "$maxValue = stablehlo.custom_call @reduce_max(${operands[0]}) " +
+                "{dimensions = [$axis], keepdim = false} : $reducedType",
+
+            // Broadcast reduced max back to the input shape.
+            "$maxBroadcast = stablehlo.broadcast_in_dim $maxValue, " +
+                "dims = [$broadcastDims] : ($reducedType) -> $outputType",
+
+            // Subtract the max for numerical stability.
+            "$shiftedValue = stablehlo.subtract ${operands[0]}, $maxBroadcast : $outputType",
+
+            // Elementwise exponential.
             "$expValue = stablehlo.exponential $shiftedValue : $outputType",
 
-            // Sum (simplified - using a constant sum for now)
-            "$sumValue = stablehlo.constant dense<1.0> : $outputType",
-
-            // Divide to get final softmax
-            "$resultValue = stablehlo.divide $expValue, $sumValue : $outputType"
+            // Reduce-sum along the softmax axis.
+            "$sumValue = stablehlo.custom_call @reduce_sum($expValue) " +
+                "{dimensions = [$axis], keepdim = false} : $reducedType",
+
+            // Broadcast the sum back to the input shape.
+            "$sumBroadcast = stablehlo.broadcast_in_dim $sumValue, " +
+                "dims = [$broadcastDims] : ($reducedType) -> $outputType",
+
+            // Normalize.
+            "$resultValue = stablehlo.divide $expValue, $sumBroadcast : $outputType"
         )
 
         operations.forEach { context.emitOperation(it) }
 
         return ConversionResult.Success(
             outputValueName = resultValue,
             emittedOperations = operations
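
Worked example (not part of the diff): for the test's tensor<2x3xf32> input with the default axis of -1, the shape bookkeeping above yields axis = 1, reducedType = tensor<2xf32>, and broadcastDims = "0". A minimal standalone Kotlin sketch reproduces the derivation; the hardcoded shape, axis, and f32 element type are illustrative assumptions, and the rank-0 scalar branch is elided:

    fun main() {
        val inputShape = listOf(2, 3)
        val rank = inputShape.size                          // 2
        val rawAxis = -1                                    // default: last dimension
        val axis = when {
            rank == 0 -> 0
            rawAxis < 0 -> rank + rawAxis                   // -1 -> 1
            else -> rawAxis
        }.coerceIn(0, (rank - 1).coerceAtLeast(0))          // 1

        val reducedShape = inputShape.filterIndexed { i, _ -> i != axis }            // [2]
        val reducedType = "tensor<${reducedShape.joinToString("x")}xf32>"            // tensor<2xf32>
        val broadcastDims = (0 until rank).filter { it != axis }.joinToString(", ")  // "0"

        println("axis=$axis reducedType=$reducedType dims=[$broadcastDims]")
    }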
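Plugging those values into the string templates, the emitted module for the 2x3 case should contain operations of roughly this shape. SSA names such as %arg0 and %0 are placeholders for whatever context.nextTempValue() returns, and the custom_call / broadcast_in_dim spelling follows this converter's own templates rather than strict upstream StableHLO assembly:

    %0 = stablehlo.custom_call @reduce_max(%arg0) {dimensions = [1], keepdim = false} : tensor<2xf32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<2xf32>) -> tensor<2x3xf32>
    %2 = stablehlo.subtract %arg0, %1 : tensor<2x3xf32>
    %3 = stablehlo.exponential %2 : tensor<2x3xf32>
    %4 = stablehlo.custom_call @reduce_sum(%3) {dimensions = [1], keepdim = false} : tensor<2xf32>
    %5 = stablehlo.broadcast_in_dim %4, dims = [0] : (tensor<2xf32>) -> tensor<2x3xf32>
    %6 = stablehlo.divide %3, %5 : tensor<2x3xf32>

The test below pins exactly these markers: the presence of @reduce_max, @reduce_sum, and stablehlo.broadcast_in_dim, and the absence of the old dense<0.0> / dense<1.0> constants.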
@@ -2,6 +2,7 @@ package sk.ainet.compile.hlo
 
 import kotlin.test.Test
 import kotlin.test.assertTrue
+import kotlin.test.assertFalse
 import kotlin.test.assertNotNull
 import kotlin.test.assertEquals
 import sk.ainet.lang.graph.DefaultComputeGraph
@@ -65,6 +66,34 @@ class ActivationOperationsConverterTest {
         assertTrue(module.content.contains("stablehlo.exponential"))
         assertTrue(module.content.contains("stablehlo.divide"))
         assertTrue(module.content.contains("tensor<2x3xf32>"))
+
+        // Regression: softmax must NOT hardcode its max/sum terms as constants
+        // (issue #467 — the fake-max `dense<0.0>` and fake-sum `dense<1.0>` at
+        // full output shape produced numerically wrong MLIR).
+        assertFalse(
+            module.content.contains("stablehlo.constant dense<0.0> : tensor<2x3xf32>"),
+            "softmax must not emit a fake-max constant at output shape"
+        )
+        assertFalse(
+            module.content.contains("stablehlo.constant dense<1.0> : tensor<2x3xf32>"),
+            "softmax must not emit a fake-sum constant at output shape"
+        )
+
+        // The corrected lowering invokes real reductions (via custom_call for
+        // now — matching ReductionOperationsConverter style) plus a broadcast
+        // back to the input shape before the subtract/divide.
+        assertTrue(
+            module.content.contains("@reduce_max"),
+            "softmax must lower max(x) to a real reduction"
+        )
+        assertTrue(
+            module.content.contains("@reduce_sum"),
+            "softmax must lower sum(exp(...)) to a real reduction"
+        )
+        assertTrue(
+            module.content.contains("stablehlo.broadcast_in_dim"),
+            "softmax must broadcast reduced values back to the input shape"
+        )
     }
 
     @Test