diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt
index f04d1e91..0d67bc14 100644
--- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt
+++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt
@@ -22,12 +22,18 @@ import sk.ainet.lang.graph.GraphNode
 public class ActivationOperationsConverter : StableHloOperationConverter {
 
     override val supportedOperations: Set<String> = setOf(
-        "sigmoid", "softmax", "tanh", "gelu", "swish"
+        "sigmoid", "softmax", "tanh", "gelu", "swish",
+        // SiLU (Sigmoid Linear Unit) is the name Llama-, Mistral-, and
+        // Qwen-family models use for the same x * sigmoid(x)
+        // activation this converter already lowers as swish. Register
+        // the alias so traced LLM graphs don't fall through to the
+        // "no converter found" path.
+        "silu", "SiLU"
     )
-
+
     override fun convert(
-        node: GraphNode,
-        operands: List,
+        node: GraphNode,
+        operands: List,
         context: ConversionContext
     ): ConversionResult {
         return when (node.operation.name.lowercase()) {
@@ -35,7 +41,7 @@ public class ActivationOperationsConverter : StableHloOperationConverter {
             "softmax" -> convertSoftmax(node, operands, context)
             "tanh" -> convertTanh(node, operands, context)
             "gelu" -> convertGelu(node, operands, context)
-            "swish" -> convertSwish(node, operands, context)
+            "swish", "silu" -> convertSwish(node, operands, context)
             else -> ConversionResult.Unsupported(
                 node.operation.name,
                 "Operation not supported by ActivationOperationsConverter"
diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ActivationOperationsConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ActivationOperationsConverterTest.kt
index 7ecaa25f..6e91c242 100644
--- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ActivationOperationsConverterTest.kt
+++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ActivationOperationsConverterTest.kt
@@ -37,6 +37,29 @@ class ActivationOperationsConverterTest {
         assertTrue(registry.isSupported("tanh"))
         assertTrue(registry.isSupported("gelu"))
         assertTrue(registry.isSupported("swish"))
+        // silu / SiLU is the Llama-family alias for the same
+        // x * sigmoid(x) lowering that swish uses.
+        assertTrue(registry.isSupported("silu"))
+        assertTrue(registry.isSupported("SiLU"))
+    }
+
+    @Test
+    fun testSiluAliasLowersLikeSwish() {
+        val graph = createActivationGraph("silu")
+        val converter = StableHloConverterFactory.createExtended()
+        val module = converter.convert(graph, "test_silu")
+
+        // silu == swish: x * sigmoid(x). The emitted MLIR must
+        // contain the same ops the swish path produces: the
+        // sigmoid expansion (negate, exp, constant 1.0, add, divide)
+        // plus the final multiply with the original input.
+        assertTrue(module.content.contains("stablehlo.negate"))
+        assertTrue(module.content.contains("stablehlo.exponential"))
+        assertTrue(module.content.contains("stablehlo.constant dense<1.0>"))
+        assertTrue(module.content.contains("stablehlo.add"))
+        assertTrue(module.content.contains("stablehlo.divide"))
+        assertTrue(module.content.contains("stablehlo.multiply"))
+        assertTrue(module.content.contains("tensor<2x3xf32>"))
     }
 
     @Test
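Reviewer note: the alias is safe to route through `convertSwish` because silu and swish are the same function. Below is a minimal Kotlin sketch, separate from the diff, that checks numerically that x * sigmoid(x) equals the expanded form the test asserts on (negate, exponential, add 1.0, divide, multiply); all names in it are illustrative scratch code, not skainet APIs.

```kotlin
import kotlin.math.abs
import kotlin.math.exp

// Direct definition: silu(x) = swish(x) = x * sigmoid(x).
fun silu(x: Double): Double = x * (1.0 / (1.0 + exp(-x)))

// The same value computed step by step, mirroring the lowering the
// test asserts on: negate -> exponential -> add 1.0 -> divide, then
// multiply by the original input.
fun siluExpanded(x: Double): Double {
    val negated = -x          // stablehlo.negate
    val expNeg = exp(negated) // stablehlo.exponential
    val denom = 1.0 + expNeg  // stablehlo.constant dense<1.0> + stablehlo.add
    val sigmoid = 1.0 / denom // stablehlo.divide
    return x * sigmoid        // stablehlo.multiply
}

fun main() {
    for (x in listOf(-4.0, -1.0, 0.0, 0.5, 3.0)) {
        check(abs(silu(x) - siluExpanded(x)) < 1e-12) { "mismatch at x=$x" }
    }
    println("silu matches the swish expansion on all samples")
}
```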