diff --git a/onnxscript/_internal/_inliner.py b/onnxscript/_internal/_inliner.py index 6a4d6d6742..ba38f573e5 100644 --- a/onnxscript/_internal/_inliner.py +++ b/onnxscript/_internal/_inliner.py @@ -10,29 +10,29 @@ def instantiate( - function: ir.Function, + graph: ir.Graph, inputs: Sequence[ir.Value | None], attributes: Mapping[str, ir.Attr], *, prefix: str = "", ) -> tuple[list[ir.Node], list[ir.Value | None]]: - """Instantiate (inline) a function, substituting inputs and attributes. + """Instantiate (inline) a graph, substituting inputs and attributes. Args: - function: The function to instantiate. - inputs: Actual input values to bind to the function's formal parameters. + graph: The graph to instantiate. + inputs: Actual input values to bind to the graph's formal parameters. attributes: Attribute values to substitute for reference attributes. prefix: Optional prefix to prepend to node and output names. Returns: - A tuple of (nodes, outputs) where nodes are the cloned function body - and outputs are the values corresponding to the function's outputs. + A tuple of (nodes, outputs) where nodes are the cloned graph body + and outputs are the values corresponding to the graph's outputs. """ - formal_inputs = function.inputs + formal_inputs = graph.inputs if len(inputs) > len(formal_inputs): raise ValueError( f"Too many inputs: got {len(inputs)}, " - f"but function has {len(formal_inputs)} parameters." + f"but graph has {len(formal_inputs)} parameters." ) value_map: dict[ir.Value, ir.Value | None] = dict(zip(formal_inputs, inputs)) @@ -50,7 +50,8 @@ def rename(node: ir.Node) -> None: metadata_props={}, post_process=rename, resolve_ref_attrs=True, + allow_outer_scope_values=True, ) - nodes = [cloner.clone_node(n) for n in function] - outputs = [value_map.get(v) for v in function.outputs] + nodes = [cloner.clone_node(n) for n in graph] + outputs = [value_map.get(v) for v in graph.outputs] return nodes, outputs diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index fb81f3451f..06b6edaa85 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -450,25 +450,24 @@ def call( **kwargs, ): if isinstance(function, ir.Function): - function_ir = function + graph = function.graph elif isinstance(function, onnxscript.OnnxFunction): - function_proto = function.to_function_proto() - function_ir = ir.serde.deserialize_function(function_proto) + graph = function.graph() else: raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction") output_renaming: dict[str, str] = {} if _outputs is not None: - if len(_outputs) != len(function_ir.outputs): + if len(_outputs) != len(graph.outputs): raise ValueError( f"Number of provided output names {_outputs} does not match " - f"number of function outputs {len(function_ir.outputs)}." + f"number of function outputs {len(graph.outputs)}." ) - for output, name in zip(function_ir.outputs, _outputs): + for output, name in zip(graph.outputs, _outputs): output_renaming[output.name] = self._qualify_value_name(name) else: - for output in function_ir.outputs: + for output in graph.outputs: output_renaming[output.name] = self._qualify_value_name(output.name) - nodes, outputs = _inliner.instantiate(function_ir, args, kwargs) + nodes, outputs = _inliner.instantiate(graph, args, kwargs) if _prefix: self.push_module(_prefix) for node in nodes: diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 1ff8460b67..ffc1ab44a4 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -10,6 +10,7 @@ import onnx_ir as ir import onnxscript._internal.builder as builder +import onnxscript.testing from onnxscript import script from onnxscript.onnx_types import DOUBLE, FLOAT @@ -713,6 +714,31 @@ def add_mul(X, Y): self.assertEqual(nodes[0].op_type, "Add") self.assertEqual(nodes[1].op_type, "Mul") + def test_call_with_outer_scope_value(self): + """Test that script supports references to pre-existing values.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + product = op.Mul(x, y) + + @script() + def add_product(X): + return op.Add(X, product) # Reference to 'product' from outer scope + + x_plus = op.call(add_product, x, _outputs=["x_plus"]) + y_plus = op.call(add_product, y, _outputs=["y_plus"]) + + op.builder.graph.outputs.extend([x_plus, y_plus]) + + # Now, create the same graph directly: + op2, x2, y2 = _create_builder_with_inputs() + product2 = op2.Mul(x2, y2) + x2_plus = op2.Add(x2, product2, _outputs=["x_plus"]) + y2_plus = op2.Add(y2, product2, _outputs=["y_plus"]) + op2.builder.graph.outputs.extend([x2_plus, y2_plus]) + + # Verify that the two graphs are structurally equivalent + onnxscript.testing.assert_isomorphic_graph(op.builder.graph, op2.builder.graph) + def test_call_with_prefix_option(self): """Test that GraphBuilder.call respects the _prefix option for hierarchical naming.""" # Create a GraphBuilder first diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index dd215a7c06..468aa41675 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -377,6 +377,9 @@ def _to_onnx_var( if isinstance(val, values.SymbolValue): if isinstance(val.value, ir.Value): return val.value + if isinstance(val, ir.Value): + # An outer-scope ir.Value (e.g., from a closure variable) can be used directly. + return val # Assume value is a python-value convertible to a tensor # TODO: check if value is convertible to a TensorProto, so that we can # produce a better error _message otherwise diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 0019802f36..051cb3e686 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -333,6 +333,15 @@ def to_function_proto(self) -> onnx.FunctionProto: """Converts the function into :class:`onnx.FunctionProto`.""" return self.function_ir.to_function_proto() + def graph(self) -> ir.Graph: + """Returns the IR graph representation of this function. + + Returns: + The :class:`ir.Graph` representing the computation graph of this function. + NOTE: This is not a copy, and should not be modified by the caller. + """ + return self.function_ir.graph + def to_model_proto(self, **kwargs): """Converts the function into :class:`onnx.ModelProto`.""" if self.function_ir.attrs and any( diff --git a/onnxscript/testing/__init__.py b/onnxscript/testing/__init__.py index 048b45e7e8..0b40a4aa35 100644 --- a/onnxscript/testing/__init__.py +++ b/onnxscript/testing/__init__.py @@ -320,6 +320,8 @@ def _to_function_proto(f): return f.to_function_proto() if isinstance(f, str): return parser.parse_function(f) + if isinstance(f, ir.Function): + return ir.to_proto(f) raise TypeError(f"Cannot convert {type(f)} to FunctionProto") @@ -330,6 +332,8 @@ def _to_graph_proto(g): return g.to_model_proto().graph if isinstance(g, str): return parser.parse_graph(g) + if isinstance(g, ir.Graph): + return ir.to_proto(g) raise TypeError(f"Cannot convert {type(g)} to ModelProto") @@ -342,6 +346,10 @@ def _to_function_or_graph(obj): return obj.graph if isinstance(obj, onnxscript.OnnxFunction): return obj.to_function_proto() + if isinstance(obj, ir.Function): + return ir.to_proto(obj) + if isinstance(obj, ir.Graph): + return ir.to_proto(obj) raise TypeError(f"Cannot convert {type(obj)} to FunctionProto or GraphProto")