diff --git a/docs/tutorial/builder/graph_builder.md b/docs/tutorial/builder/graph_builder.md index 55bd83a90e..27eeeb4f02 100644 --- a/docs/tutorial/builder/graph_builder.md +++ b/docs/tutorial/builder/graph_builder.md @@ -326,6 +326,139 @@ print(w.name) # "encoder.W" builder.pop_module() ``` +## Building Subgraphs for Control-Flow Ops + +ONNX control-flow operators such as `Scan`, `Loop`, and `If` accept one or more +**graph-valued attributes** — graphs that define the body executed at each +iteration (or branch). `GraphBuilder.subgraph()` builds these inner graphs in +exactly the same imperative style as the outer graph, and the resulting +`ir.Graph` can be passed directly as an attribute. + +The subgraph automatically inherits the opset version from the parent +`GraphBuilder`, so there is no need to specify it separately. + +### Type annotations for subgraph inputs and outputs + +`subgraph()` accepts `input_types` and `output_types` lists that describe +the types and shapes of each input and output. Each element can be either an +`ir.TypeAndShape` object or — more conveniently — an +`onnxscript` tensor-type expression: + +| Expression | Meaning | +|----------------------|-----------------------------------------| +| `FLOAT` | Rank-0 scalar float tensor | +| `FLOAT[...]` | Float tensor of unknown rank | +| `FLOAT[1024]` | 1-D float tensor with 1024 elements | +| `FLOAT[3, 4]` | 2-D float tensor of shape (3, 4) | +| `FLOAT['M', 'N']` | 2-D float tensor with symbolic dims | + +These types come from `onnxscript.onnx_types` (also importable from +`onnxscript` directly): + +```python +from onnxscript.onnx_types import FLOAT, INT64 +``` + +### Example: cumulative sum with Scan + +The `Scan` op iterates over a sequence axis, threading a state vector through +each step. 
Here is how to build a cumulative-sum model with `subgraph()`: + +```python +import onnx_ir as ir +import onnxscript +from onnxscript.onnx_types import FLOAT + +D = 4 # feature dimension +N = 10 # sequence length + +# --- Parent graph ----------------------------------------------------------- +graph = ir.Graph( + name="cumsum_model", + inputs=[], + outputs=[], + nodes=[], + opset_imports={"": 23}, +) + +# Initial accumulator (shape [D]) and input sequence (shape [N, D]) +init_state = ir.Value( + name="init_state", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape([D]), +) +sequence = ir.Value( + name="sequence", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape([N, D]), +) +graph.inputs.extend([init_state, sequence]) + +builder = onnxscript.GraphBuilder(graph) +op = builder.op + +# --- Scan body -------------------------------------------------------------- +# The body receives one state slice (the running sum) and one scan slice +# (the current element of the sequence). It adds them and returns the new +# state both as the updated state and as a scan output. 
+ +def cumsum_body(op, state, x_i): + new_state = op.Add(state, x_i) + return new_state, new_state # (updated_state, scan_output_for_this_step) + +body = builder.subgraph( + cumsum_body, + input_types=[FLOAT[D], FLOAT[D]], # state, x_i + output_types=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i + name="cumsum_body", +) + +# --- Scan node -------------------------------------------------------------- +# Inputs: init_state (1 state variable), sequence (1 scan input) +# Outputs: final_state, all_partial_sums (shape [N, D]) +final_state, partial_sums = op.Scan( + init_state, + sequence, + body=body, + num_scan_inputs=1, + _outputs=2, +) +graph.outputs.extend([final_state, partial_sums]) + +model = ir.Model(graph=graph, ir_version=10) +``` + +Key points: + +- `builder.subgraph(fn, input_types, output_types)` creates a fresh + `ir.Graph`, calls `fn(op, *inputs)` to trace the body, and wires up the + declared input/output types. +- The `fn` receives an `OpBuilder` as its first argument — exactly the same + API as the outer graph — so you can use the full builder feature set inside + a body (constants, module scopes, nested subgraphs, etc.). +- The returned `ir.Graph` is passed as the `body` keyword attribute of `Scan`. +- `_outputs=2` tells the builder that `Scan` returns two output values. + +### Nested subgraphs + +Because the `fn` receives an `OpBuilder`, and `OpBuilder` exposes +`op.builder`, you can reach the inner `GraphBuilder` and call `subgraph()` +recursively for doubly-nested control flow (e.g. a `Scan` inside a `Loop`): + +```python +def outer_body(op, state, x_i): + # Build a nested subgraph inside the scan body + inner = op.builder.subgraph( + lambda iop, v: iop.Relu(v), + input_types=[FLOAT[D]], + output_types=[FLOAT[D]], + name="relu_body", + ) + # ... use inner as a graph attribute of a nested op ... 
+ new_state = op.Add(state, x_i) + return new_state, new_state +``` + ## Putting It All Together Here is a complete example that builds a small model with two layers: diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index d60db2c7da..fb81f3451f 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -74,6 +74,33 @@ def _constant_name( return f"const_1d_{num}" +# Type accepted as an element of *input_types* / *output_types* by +# :meth:`GraphBuilder.subgraph`. Can be an already-resolved +# :class:`ir.TypeAndShape`, or a +# :class:`~onnxscript.onnx_types.TensorType` subclass such as ``FLOAT[1024]``. +TypeSpec = Union[ir.TypeAndShape, Any] + + +def _resolve_type_spec(spec: TypeSpec) -> ir.TypeAndShape: + """Convert a *TypeSpec* to an :class:`ir.TypeAndShape`. + + Accepts either an :class:`ir.TypeAndShape` directly, or a + :class:`~onnxscript.onnx_types.TensorType` subclass (e.g. ``FLOAT[1024]`` + or ``FLOAT['M', 'N']``). + """ + # Lazy import to avoid a circular dependency: onnxscript.__init__ imports + # onnx_types (line ~106) before builder (line ~132), so by the time any + # call reaches here the module is fully initialised — but a top-level + # import in builder.py could break if builder is ever imported first. 
+ from onnxscript.onnx_types import TensorType # pylint: disable=import-outside-toplevel + + if isinstance(spec, ir.TypeAndShape): + return spec + if isinstance(spec, type) and issubclass(spec, TensorType): + return spec.to_ir() + raise TypeError(f"Expected ir.TypeAndShape or a TensorType subclass, got {type(spec)!r}.") + + class GraphBuilder: """Imperative builder for constructing ONNX IR graphs with automatic constant promotion, type casting, and shape inference.""" @@ -302,6 +329,78 @@ def add_node(self, node: ir.Node) -> None: onnxscript.optimizer.basic_constant_propagation([node]) inference.infer_outputs(node) + def subgraph( + self, + trace_function: Callable, + input_types: Sequence[TypeSpec], + output_types: Sequence[TypeSpec], + *, + name: str = "subgraph", + ) -> ir.Graph: + """Build an :class:`ir.Graph` suitable for use as a graph-valued attribute. + + The subgraph inherits the opset version from this :class:`GraphBuilder`. + It is particularly useful for constructing the body graphs of control-flow ops + such as ``Scan``, ``Loop``, and ``If``. + + Example - building a Scan body that adds two sequences element-wise:: + + body = graph_builder.subgraph( + lambda op, x, y: op.Add(x, y), + input_types=[FLOAT[...], FLOAT[...]], + output_types=[FLOAT[...]], + ) + + Args: + trace_function: A callable with signature + ``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``. + It is called once with freshly created placeholder inputs to record the + graph topology. + input_types: Types for each graph input. Each element may be an + :class:`ir.TypeAndShape` **or** a + :class:`~onnxscript.onnx_types.TensorType` subclass (e.g. + ``FLOAT[1024]`` or ``FLOAT['M', 'N']``). + output_types: Types for each graph output, in the same format as + *input_types*. + name: Name of the resulting :class:`ir.Graph`. + + Returns: + An :class:`ir.Graph` whose inputs and outputs are populated and whose + nodes record the operations traced by *trace_function*. 
This graph can be + passed directly as a graph-valued attribute (e.g. the ``body`` attribute of + a ``Scan`` or ``Loop`` node). + """ + opset_version = self._graph.opset_imports[""] + resolved_inputs = [_resolve_type_spec(t) for t in input_types] + resolved_outputs = [_resolve_type_spec(t) for t in output_types] + + subgraph = ir.Graph( + name=name, + inputs=[], + outputs=[], + nodes=[], + opset_imports={"": opset_version}, + ) + + for i, ts in enumerate(resolved_inputs): + subgraph.inputs.append(ir.Value(name=f"input_{i}", type=ts.type, shape=ts.shape)) + + sub_builder = GraphBuilder(subgraph) + outputs = trace_function(sub_builder.op, *subgraph.inputs) + if not isinstance(outputs, Sequence): + outputs = [outputs] + if len(outputs) != len(resolved_outputs): + raise ValueError( + f"trace_function returned {len(outputs)} output(s), " + f"but {len(resolved_outputs)} were declared in output_types." + ) + for output, ts in zip(outputs, resolved_outputs): + output.type = ts.type + output.merge_shapes(ts.shape) + + subgraph.outputs.extend(outputs) + return subgraph + def call_op( self, op_type: str, diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 8dbb81525a..1ff8460b67 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -11,6 +11,7 @@ import onnxscript._internal.builder as builder from onnxscript import script +from onnxscript.onnx_types import DOUBLE, FLOAT _default_opset_version = 23 @@ -819,5 +820,123 @@ def add_mul(X, Y): self.assertIn("does not match", str(cm.exception)) +class BuildSubgraphTest(unittest.TestCase): + """Tests for GraphBuilder.subgraph().""" + + def _make_builder(self, opset_version: int = 23) -> builder.GraphBuilder: + """Return a minimal GraphBuilder for the given opset version.""" + graph = ir.Graph( + name="parent", + inputs=[], + outputs=[], + nodes=[], + opset_imports={"": opset_version}, + ) + return builder.GraphBuilder(graph) + + def 
test_basic_subgraph(self): + """Subgraph returns a valid ir.Graph with correct inputs/outputs.""" + + def _add(op, x, y): + return op.Add(x, y) + + gb = self._make_builder() + graph = gb.subgraph( + _add, + input_types=[FLOAT[3, 4], FLOAT[3, 4]], + output_types=[FLOAT[3, 4]], + ) + self.assertIsInstance(graph, ir.Graph) + self.assertEqual(len(graph.inputs), 2) + self.assertEqual(len(graph.outputs), 1) + op_types = [node.op_type for node in graph] + self.assertEqual(op_types, ["Add"]) + + def test_subgraph_inherits_opset_version(self): + """The subgraph opset version matches the parent GraphBuilder.""" + gb = self._make_builder(opset_version=17) + graph = gb.subgraph( + lambda op, x: op.Identity(x), + input_types=[FLOAT[...]], + output_types=[FLOAT[...]], + ) + self.assertEqual(graph.opset_imports[""], 17) + + def test_subgraph_with_ir_type_and_shape(self): + """Subgraph also accepts ir.TypeAndShape directly.""" + + def _mul(op, x, y): + return op.Mul(x, y) + + float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([2, 3])) + gb = self._make_builder() + graph = gb.subgraph( + _mul, + input_types=[float_2d, float_2d], + output_types=[float_2d], + ) + self.assertIsInstance(graph, ir.Graph) + self.assertEqual(len(list(graph)), 1) + self.assertEqual(next(iter(graph)).op_type, "Mul") + + def test_subgraph_multiple_outputs(self): + """Subgraph handles multiple outputs.""" + + def _add_and_mul(op, x, y): + return op.Add(x, y), op.Mul(x, y) + + ts = FLOAT[...] 
+ gb = self._make_builder() + graph = gb.subgraph( + _add_and_mul, + input_types=[ts, ts], + output_types=[ts, ts], + ) + self.assertEqual(len(graph.outputs), 2) + + def test_subgraph_output_count_mismatch_raises(self): + """Subgraph raises ValueError when output count does not match.""" + + def _returns_one(op, x, y): + return op.Add(x, y) + + gb = self._make_builder() + with self.assertRaises(ValueError): + gb.subgraph( + _returns_one, + input_types=[FLOAT[...], FLOAT[...]], + output_types=[FLOAT[...], FLOAT[...]], # expects 2, gets 1 + ) + + def test_subgraph_custom_name(self): + """Subgraph passes the name through to the ir.Graph.""" + + def _id(op, x): + return op.Identity(x) + + gb = self._make_builder() + graph = gb.subgraph( + _id, + input_types=[DOUBLE[...]], + output_types=[DOUBLE[...]], + name="scan_body", + ) + self.assertEqual(graph.name, "scan_body") + + def test_invalid_type_spec_raises(self): + """Subgraph raises TypeError for an unrecognised type specification.""" + + def _id(op, x): + return op.Identity(x) + + gb = self._make_builder() + with self.assertRaises(TypeError): + gb.subgraph( + _id, + input_types=["not_a_type_spec"], + output_types=["not_a_type_spec"], + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index 9642e3f111..b0a4006329 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -99,6 +99,30 @@ def to_type_proto(cls) -> onnx.TypeProto: shape = [cls.shape] # example: "FLOAT[10]" return onnx.helper.make_tensor_type_proto(cls.dtype, shape) # noqa: TID251 + @classmethod + def to_ir(cls) -> ir.TypeAndShape: + """Return an :class:`ir.TypeAndShape` representing this tensor type and shape. + + This enables using ONNX Script tensor-type notation (e.g. ``FLOAT[1024]`` + or ``FLOAT['M', 'N']``) wherever an :class:`ir.TypeAndShape` is expected, + such as the *input_types* / *output_types* arguments of + :meth:`onnxscript._internal.builder.GraphBuilder.subgraph`. 
+ """ + ir_type = ir.TensorType(cls.dtype) + if cls.shape is None: + # No subscript (e.g. ``FLOAT``): treat as scalar / rank-0 tensor. + ir_shape: ir.Shape | None = ir.Shape([]) + elif cls.shape is Ellipsis: + # ``FLOAT[...]``: tensor of unknown rank. + ir_shape = None + elif isinstance(cls.shape, tuple): + # ``FLOAT[3, 4]`` or ``FLOAT['M', 'N']``: explicit dims. + ir_shape = ir.Shape(list(cls.shape)) + else: + # ``FLOAT[1024]``: single-dimension 1-D tensor. + ir_shape = ir.Shape([cls.shape]) + return ir.TypeAndShape(ir_type, ir_shape) + @classmethod def to_string(cls) -> str: return f"tensor({cls.__name__.lower()})" diff --git a/onnxscript/onnx_types_test.py b/onnxscript/onnx_types_test.py new file mode 100644 index 0000000000..9898188209 --- /dev/null +++ b/onnxscript/onnx_types_test.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import unittest + +import onnx_ir as ir + +from onnxscript.onnx_types import FLOAT, INT64 + + +class TensorTypeToIrTest(unittest.TestCase): + """Tests for TensorType.to_ir().""" + + def test_scalar_type(self): + """FLOAT (no subscript) maps to rank-0 tensor (empty shape).""" + ts = FLOAT.to_ir() + self.assertIsInstance(ts, ir.TypeAndShape) + self.assertEqual(ts.type, ir.TensorType(ir.DataType.FLOAT)) + self.assertIsNotNone(ts.shape) + self.assertEqual(len(ts.shape), 0) + + def test_unknown_rank(self): + """FLOAT[...] 
maps to unknown-rank (shape=None).""" + ts = FLOAT[...].to_ir() + self.assertIsInstance(ts, ir.TypeAndShape) + self.assertIsNone(ts.shape) + + def test_single_dim(self): + """FLOAT[1024] maps to a 1-D tensor with dimension 1024.""" + ts = FLOAT[1024].to_ir() + self.assertIsNotNone(ts.shape) + self.assertEqual(len(ts.shape), 1) + self.assertEqual(ts.shape[0], 1024) + + def test_multi_dim_int(self): + """FLOAT[3, 4] maps to a 2-D tensor with dims (3, 4).""" + ts = FLOAT[3, 4].to_ir() + self.assertIsNotNone(ts.shape) + self.assertEqual(len(ts.shape), 2) + self.assertEqual(ts.shape[0], 3) + self.assertEqual(ts.shape[1], 4) + + def test_symbolic_dims(self): + """FLOAT['M', 'N'] maps to a 2-D tensor with symbolic dims.""" + ts = FLOAT["M", "N"].to_ir() + self.assertIsNotNone(ts.shape) + self.assertEqual(len(ts.shape), 2) + + def test_other_dtype(self): + """INT64[...] preserves the correct dtype.""" + ts = INT64[...].to_ir() + self.assertEqual(ts.type, ir.TensorType(ir.DataType.INT64)) + + +if __name__ == "__main__": + unittest.main()