Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,26 @@ import sk.ainet.lang.graph.GraphNode
public class ActivationOperationsConverter : StableHloOperationConverter {

// Operation names this converter can lower to StableHLO. Matching in
// convert() is case-insensitive (it lowercases node.operation.name), but
// registry lookup appears to be literal, so both "silu" and "SiLU"
// spellings are registered explicitly.
// NOTE(review): the scraped diff kept both the pre-change line
// ("swish" without a trailing comma) and the post-change line; only the
// post-change version is valid Kotlin and is kept here.
override val supportedOperations: Set<String> = setOf(
    "sigmoid", "softmax", "tanh", "gelu", "swish",
    // SiLU (Sigmoid Linear Unit) is the name every Llama / Mistral /
    // Qwen / Gemma family model uses for the same x * sigmoid(x)
    // activation that PyTorch historically called swish. Register
    // the alias so traced LLM graphs don't fall through to the
    // "no converter found" path.
    "silu", "SiLU",
)

override fun convert(
node: GraphNode,
operands: List<String>,
node: GraphNode,
operands: List<String>,
context: ConversionContext
): ConversionResult {
return when (node.operation.name.lowercase()) {
"sigmoid" -> convertSigmoid(node, operands, context)
"softmax" -> convertSoftmax(node, operands, context)
"tanh" -> convertTanh(node, operands, context)
"gelu" -> convertGelu(node, operands, context)
"swish" -> convertSwish(node, operands, context)
"swish", "silu" -> convertSwish(node, operands, context)
else -> ConversionResult.Unsupported(
node.operation.name,
"Operation not supported by ActivationOperationsConverter"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,29 @@ class ActivationOperationsConverterTest {
assertTrue(registry.isSupported("tanh"))
assertTrue(registry.isSupported("gelu"))
assertTrue(registry.isSupported("swish"))
// silu / SiLU is the Llama-family alias for the same
// x * sigmoid(x) lowering that swish uses.
assertTrue(registry.isSupported("silu"))
assertTrue(registry.isSupported("SiLU"))
}

@Test
fun testSiluAliasLowersLikeSwish() {
    // Arrange: trace a graph whose single activation node is named
    // "silu", then lower it through the extended converter set.
    val siluGraph = createActivationGraph("silu")
    val extendedConverter = StableHloConverterFactory.createExtended()
    val emitted = extendedConverter.convert(siluGraph, "test_silu")

    // Assert: silu must take the same lowering path as swish, i.e.
    // x * sigmoid(x). That expansion emits negate/exp/constant-1.0/
    // add/divide for the sigmoid, a final multiply against the input,
    // and preserves the 2x3 f32 tensor type end to end.
    val expectedFragments = listOf(
        "stablehlo.negate",
        "stablehlo.exponential",
        "stablehlo.constant dense<1.0>",
        "stablehlo.add",
        "stablehlo.divide",
        "stablehlo.multiply",
        "tensor<2x3xf32>",
    )
    expectedFragments.forEach { fragment ->
        assertTrue(emitted.content.contains(fragment))
    }
}

@Test
Expand Down
Loading