diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 574ddd8aef..26375cc4dc 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -561,6 +561,21 @@ def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return op.Constant(value_int=size) +def _move_initializers_to_graph(src: ir.Graph, dst: ir.Graph) -> None: + """Move all initializers from src graph to dst graph, ensuring name uniqueness.""" + counter: dict[str, int] = {} + for name in list(src.initializers): + initializer = src.initializers.pop(name) + # Ensure name uniqueness in the destination graph + new_name = name + while new_name in dst.initializers: + counter[name] = counter.get(name, 0) + 1 + new_name = f"{name}_{counter[name]}" + if new_name != name: + initializer.name = new_name + dst.register_initializer(initializer) + + @register("If") def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue: cond_input = _get_input(node, 0) @@ -598,7 +613,11 @@ def rename(name): # Avoid name collision. sub_node.name = f"{node.name}_{sub_node.name}" - # TODO: we should handle initializers as well! + # Move initializers from the subgraph to the main graph to avoid losing them. 
+ main_graph = node.graph + if main_graph is not None: + _move_initializers_to_graph(graph, main_graph) + return Replacement(formal_outs, graph_nodes) return None diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 080af9c2f3..75a454a39d 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -130,6 +130,45 @@ def test_fold_if_cond(self): self.assertEqual(optimized.graph[0].outputs[0].name, "z") self.assertEqual(optimized.graph[0].op_type, "Mul") + def test_fold_if_cond_with_subgraph_initializer(self): + """If branch initializers should be moved to the main graph when the branch is inlined.""" + # A model with a non-constant condition; constants inside the then_branch will + # be folded into subgraph initializers on the first fold pass. + model = ir.from_onnx_text(""" + <ir_version: 10, opset_import: ["" : 20]> + agraph (float[16, 16] x, bool cond) => (float[16, 16] z) { + two = Constant <value_float = 2.0> () + three = Constant <value_float = 3.0> () + z = If (cond) < + then_branch = then_graph () => (then_z) { + temp = Add (two, three) + then_z = Mul (temp, x) + }, + else_branch = else_graph () => (else_z) { + else_z = Identity (x) + } + > + } + """) + # First fold: 'temp = Add(2.0, 3.0)' gets folded into a subgraph initializer. + _constant_folding.fold_constants(model) + optimizer.remove_unused_nodes(model) + if_node = next(n for n in model.graph if n.op_type == "If") + then_branch = if_node.attributes["then_branch"].as_graph() + self.assertIn("temp", then_branch.initializers) + self.assertNotIn("temp", model.graph.initializers) + + # Make the condition constant (True) to trigger inlining of the then_branch. + const_true = ir.Value(name="const_true") + const_true.const_value = ir.Tensor(np.array(True)) + if_node.replace_input_with(0, const_true) + + # Second fold: the If is inlined; 'temp' must be moved to the main graph. 
+ _constant_folding.fold_constants(model) + optimizer.remove_unused_nodes(model) + onnx.checker.check_model(ir.serde.serialize_model(model)) + self.assertIn("temp", model.graph.initializers) + def test_fold_inside_if_branch(self): model = """