Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions onnxscript/_internal/_inliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,29 @@


def instantiate(
function: ir.Function,
graph: ir.Graph,
inputs: Sequence[ir.Value | None],
attributes: Mapping[str, ir.Attr],
*,
prefix: str = "",
) -> tuple[list[ir.Node], list[ir.Value | None]]:
"""Instantiate (inline) a function, substituting inputs and attributes.
"""Instantiate (inline) a graph, substituting inputs and attributes.

Args:
function: The function to instantiate.
inputs: Actual input values to bind to the function's formal parameters.
graph: The graph to instantiate.
inputs: Actual input values to bind to the graph's formal parameters.
attributes: Attribute values to substitute for reference attributes.
prefix: Optional prefix to prepend to node and output names.

Returns:
A tuple of (nodes, outputs) where nodes are the cloned function body
and outputs are the values corresponding to the function's outputs.
A tuple of (nodes, outputs) where nodes are the cloned graph body
and outputs are the values corresponding to the graph's outputs.
"""
formal_inputs = function.inputs
formal_inputs = graph.inputs
if len(inputs) > len(formal_inputs):
raise ValueError(
f"Too many inputs: got {len(inputs)}, "
f"but function has {len(formal_inputs)} parameters."
f"but graph has {len(formal_inputs)} parameters."
)
value_map: dict[ir.Value, ir.Value | None] = dict(zip(formal_inputs, inputs))

Expand All @@ -50,7 +50,8 @@ def rename(node: ir.Node) -> None:
metadata_props={},
post_process=rename,
resolve_ref_attrs=True,
allow_outer_scope_values=True,
)
nodes = [cloner.clone_node(n) for n in function]
outputs = [value_map.get(v) for v in function.outputs]
nodes = [cloner.clone_node(n) for n in graph]
outputs = [value_map.get(v) for v in graph.outputs]
return nodes, outputs
15 changes: 7 additions & 8 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,25 +450,24 @@ def call(
**kwargs,
):
if isinstance(function, ir.Function):
function_ir = function
graph = function.graph
elif isinstance(function, onnxscript.OnnxFunction):
function_proto = function.to_function_proto()
function_ir = ir.serde.deserialize_function(function_proto)
graph = function.graph()
else:
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
output_renaming: dict[str, str] = {}
if _outputs is not None:
if len(_outputs) != len(function_ir.outputs):
if len(_outputs) != len(graph.outputs):
raise ValueError(
f"Number of provided output names {_outputs} does not match "
f"number of function outputs {len(function_ir.outputs)}."
f"number of function outputs {len(graph.outputs)}."
)
for output, name in zip(function_ir.outputs, _outputs):
for output, name in zip(graph.outputs, _outputs):
output_renaming[output.name] = self._qualify_value_name(name)
else:
for output in function_ir.outputs:
for output in graph.outputs:
output_renaming[output.name] = self._qualify_value_name(output.name)
nodes, outputs = _inliner.instantiate(function_ir, args, kwargs)
nodes, outputs = _inliner.instantiate(graph, args, kwargs)
if _prefix:
self.push_module(_prefix)
for node in nodes:
Expand Down
26 changes: 26 additions & 0 deletions onnxscript/_internal/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import onnx_ir as ir

import onnxscript._internal.builder as builder
import onnxscript.testing
from onnxscript import script
from onnxscript.onnx_types import DOUBLE, FLOAT

Expand Down Expand Up @@ -713,6 +714,31 @@ def add_mul(X, Y):
self.assertEqual(nodes[0].op_type, "Add")
self.assertEqual(nodes[1].op_type, "Mul")

def test_call_with_outer_scope_value(self):
"""Test that script supports references to pre-existing values."""
# Create a GraphBuilder first
op, x, y = _create_builder_with_inputs()
product = op.Mul(x, y)

@script()
def add_product(X):
return op.Add(X, product) # Reference to 'product' from outer scope

x_plus = op.call(add_product, x, _outputs=["x_plus"])
y_plus = op.call(add_product, y, _outputs=["y_plus"])

op.builder.graph.outputs.extend([x_plus, y_plus])

# Now, create the same graph directly:
op2, x2, y2 = _create_builder_with_inputs()
product2 = op2.Mul(x2, y2)
x2_plus = op2.Add(x2, product2, _outputs=["x_plus"])
y2_plus = op2.Add(y2, product2, _outputs=["y_plus"])
op2.builder.graph.outputs.extend([x2_plus, y2_plus])

# Verify that the two graphs are structurally equivalent
onnxscript.testing.assert_isomorphic_graph(op.builder.graph, op2.builder.graph)

def test_call_with_prefix_option(self):
"""Test that GraphBuilder.call respects the _prefix option for hierarchical naming."""
# Create a GraphBuilder first
Expand Down
3 changes: 3 additions & 0 deletions onnxscript/_internal/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ def _to_onnx_var(
if isinstance(val, values.SymbolValue):
if isinstance(val.value, ir.Value):
return val.value
if isinstance(val, ir.Value):
# An outer-scope ir.Value (e.g., from a closure variable) can be used directly.
return val
# Assume value is a python-value convertible to a tensor
# TODO: check if value is convertible to a TensorProto, so that we can
# produce a better error _message otherwise
Expand Down
9 changes: 9 additions & 0 deletions onnxscript/_internal/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,15 @@ def to_function_proto(self) -> onnx.FunctionProto:
"""Converts the function into :class:`onnx.FunctionProto`."""
return self.function_ir.to_function_proto()

def graph(self) -> ir.Graph:
"""Returns the IR graph representation of this function.

Returns:
The :class:`ir.Graph` representing the computation graph of this function.
NOTE: This is not a copy, and should not be modified by the caller.
"""
return self.function_ir.graph

def to_model_proto(self, **kwargs):
"""Converts the function into :class:`onnx.ModelProto`."""
if self.function_ir.attrs and any(
Expand Down
8 changes: 8 additions & 0 deletions onnxscript/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ def _to_function_proto(f):
return f.to_function_proto()
if isinstance(f, str):
return parser.parse_function(f)
if isinstance(f, ir.Function):
return ir.to_proto(f)
raise TypeError(f"Cannot convert {type(f)} to FunctionProto")


Expand All @@ -330,6 +332,8 @@ def _to_graph_proto(g):
return g.to_model_proto().graph
if isinstance(g, str):
return parser.parse_graph(g)
if isinstance(g, ir.Graph):
return ir.to_proto(g)
raise TypeError(f"Cannot convert {type(g)} to ModelProto")


Expand All @@ -342,6 +346,10 @@ def _to_function_or_graph(obj):
return obj.graph
if isinstance(obj, onnxscript.OnnxFunction):
return obj.to_function_proto()
if isinstance(obj, ir.Function):
return ir.to_proto(obj)
if isinstance(obj, ir.Graph):
return ir.to_proto(obj)
raise TypeError(f"Cannot convert {type(obj)} to FunctionProto or GraphProto")


Expand Down
Loading