Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,21 @@ def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
return op.Constant(value_int=size)


def _move_initializers_to_graph(src: ir.Graph, dst: ir.Graph) -> None:
"""Move all initializers from src graph to dst graph, ensuring name uniqueness."""
counter: dict[str, int] = {}
for name in list(src.initializers):
initializer = src.initializers.pop(name)
# Ensure name uniqueness in the destination graph
new_name = name
while new_name in dst.initializers:
counter[name] = counter.get(name, 0) + 1
new_name = f"{name}_{counter[name]}"
if new_name != name:
initializer.name = new_name
dst.register_initializer(initializer)


@register("If")
def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
cond_input = _get_input(node, 0)
Expand Down Expand Up @@ -598,7 +613,11 @@ def rename(name):
# Avoid name collision.
sub_node.name = f"{node.name}_{sub_node.name}"

# TODO: we should handle initializers as well!
# Move initializers from the subgraph to the main graph to avoid losing them.
main_graph = node.graph
if main_graph is not None:
_move_initializers_to_graph(graph, main_graph)

return Replacement(formal_outs, graph_nodes)
return None

Expand Down
39 changes: 39 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,45 @@ def test_fold_if_cond(self):
self.assertEqual(optimized.graph[0].outputs[0].name, "z")
self.assertEqual(optimized.graph[0].op_type, "Mul")

def test_fold_if_cond_with_subgraph_initializer(self):
"""If branch initializers should be moved to the main graph when the branch is inlined."""
# A model with a non-constant condition; constants inside the then_branch will
# be folded into subgraph initializers on the first fold pass.
model = ir.from_onnx_text("""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[16, 16] x, bool cond) => (float[16, 16] z) {
two = Constant <value_float=2.0> ()
three = Constant <value_float=3.0> ()
z = If (cond) <
then_branch = then_graph () => (then_z) {
temp = Add (two, three)
then_z = Mul (temp, x)
},
else_branch = else_graph () => (else_z) {
else_z = Identity (x)
}
>
}
""")
# First fold: 'temp = Add(2.0, 3.0)' gets folded into a subgraph initializer.
_constant_folding.fold_constants(model)
optimizer.remove_unused_nodes(model)
if_node = next(n for n in model.graph if n.op_type == "If")
then_branch = if_node.attributes["then_branch"].as_graph()
self.assertIn("temp", then_branch.initializers)
self.assertNotIn("temp", model.graph.initializers)

# Make the condition constant (True) to trigger inlining of the then_branch.
const_true = ir.Value(name="const_true")
const_true.const_value = ir.Tensor(np.array(True))
if_node.replace_input_with(0, const_true)

# Second fold: the If is inlined; 'temp' must be moved to the main graph.
_constant_folding.fold_constants(model)
optimizer.remove_unused_nodes(model)
onnx.checker.check_model(ir.serde.serialize_model(model))
self.assertIn("temp", model.graph.initializers)

def test_fold_inside_if_branch(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
Expand Down
Loading