Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions docs/tutorial/builder/graph_builder.md
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,139 @@ print(w.name) # "encoder.W"
builder.pop_module()
```

## Building Subgraphs for Control-Flow Ops

ONNX control-flow operators such as `Scan`, `Loop`, and `If` accept one or more
**graph-valued attributes** — graphs that define the body executed at each
iteration (or branch). `GraphBuilder.subgraph()` builds these inner graphs in
exactly the same imperative style as the outer graph, and the resulting
`ir.Graph` can be passed directly as an attribute.

The subgraph automatically inherits the opset version from the parent
`GraphBuilder`, so there is no need to specify it separately.

### Type annotations for subgraph inputs and outputs

`subgraph()` accepts `input_types` and `output_types` lists that describe
the types and shapes of each input and output. Each element can be either an
`ir.TypeAndShape` object or — more conveniently — an
`onnxscript` tensor-type expression:

| Expression | Meaning |
|----------------------|-----------------------------------------|
| `FLOAT` | Rank-0 scalar float tensor |
| `FLOAT[...]` | Float tensor of unknown rank |
| `FLOAT[1024]` | 1-D float tensor with 1024 elements |
| `FLOAT[3, 4]` | 2-D float tensor of shape (3, 4) |
| `FLOAT['M', 'N']` | 2-D float tensor with symbolic dims |

These types come from `onnxscript.onnx_types` (also importable from
`onnxscript` directly):

```python
from onnxscript.onnx_types import FLOAT, INT64
```

### Example: cumulative sum with Scan

The `Scan` op iterates over a sequence axis, threading a state vector through
each step. Here is how to build a cumulative-sum model with `subgraph()`:

```python
import onnx_ir as ir
import onnxscript
from onnxscript.onnx_types import FLOAT

D = 4 # feature dimension
N = 10 # sequence length

# --- Parent graph -----------------------------------------------------------
graph = ir.Graph(
name="cumsum_model",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 23},
)

# Initial accumulator (shape [D]) and input sequence (shape [N, D])
init_state = ir.Value(
name="init_state",
type=ir.TensorType(ir.DataType.FLOAT),
shape=ir.Shape([D]),
)
sequence = ir.Value(
name="sequence",
type=ir.TensorType(ir.DataType.FLOAT),
shape=ir.Shape([N, D]),
)
graph.inputs.extend([init_state, sequence])

builder = onnxscript.GraphBuilder(graph)
op = builder.op

# --- Scan body --------------------------------------------------------------
# The body receives one state slice (the running sum) and one scan slice
# (the current element of the sequence). It adds them and returns the new
# state both as the updated state and as a scan output.

def cumsum_body(op, state, x_i):
new_state = op.Add(state, x_i)
return new_state, new_state # (updated_state, scan_output_for_this_step)

body = builder.subgraph(
cumsum_body,
input_types=[FLOAT[D], FLOAT[D]], # state, x_i
output_types=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i
name="cumsum_body",
)

# --- Scan node --------------------------------------------------------------
# Inputs: init_state (1 state variable), sequence (1 scan input)
# Outputs: final_state, all_partial_sums (shape [N, D])
final_state, partial_sums = op.Scan(
init_state,
sequence,
body=body,
num_scan_inputs=1,
_outputs=2,
)
graph.outputs.extend([final_state, partial_sums])

model = ir.Model(graph=graph, ir_version=10)
```

Key points:

- `builder.subgraph(fn, input_types, output_types)` creates a fresh
`ir.Graph`, calls `fn(op, *inputs)` to trace the body, and wires up the
declared input/output types.
- The `fn` receives an `OpBuilder` as its first argument — exactly the same
API as the outer graph — so you can use the full builder feature set inside
a body (constants, module scopes, nested subgraphs, etc.).
- The returned `ir.Graph` is passed as the `body` keyword attribute of `Scan`.
- `_outputs=2` tells the builder that `Scan` returns two output values.

### Nested subgraphs

Because the `fn` receives an `OpBuilder`, and `OpBuilder` exposes
`op.builder`, you can reach the inner `GraphBuilder` and call `subgraph()`
recursively for doubly-nested control flow (e.g. a `Scan` inside a `Loop`):

```python
def outer_body(op, state, x_i):
# Build a nested subgraph inside the scan body
inner = op.builder.subgraph(
lambda iop, v: iop.Relu(v),
input_types=[FLOAT[D]],
output_types=[FLOAT[D]],
name="relu_body",
)
# ... use inner as a graph attribute of a nested op ...
new_state = op.Add(state, x_i)
return new_state, new_state
```

## Putting It All Together

Here is a complete example that builds a small model with two layers:
Expand Down
99 changes: 99 additions & 0 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,33 @@
return f"const_1d_{num}"


# Type accepted as an element of *input_types* / *output_types* by
# :meth:`GraphBuilder.subgraph`. Can be an already-resolved
# :class:`ir.TypeAndShape`, or a
# :class:`~onnxscript.onnx_types.TensorType` subclass such as ``FLOAT[1024]``.
TypeSpec = Union[ir.TypeAndShape, Any]


def _resolve_type_spec(spec: TypeSpec) -> ir.TypeAndShape:
    """Convert a *TypeSpec* to an :class:`ir.TypeAndShape`.

    Accepts either an :class:`ir.TypeAndShape` directly, or a
    :class:`~onnxscript.onnx_types.TensorType` subclass (e.g. ``FLOAT[1024]``
    or ``FLOAT['M', 'N']``).

    Raises:
        TypeError: If *spec* is neither an :class:`ir.TypeAndShape` nor a
            :class:`~onnxscript.onnx_types.TensorType` subclass.
    """
    # Lazy import to avoid a circular dependency: onnxscript.__init__ imports
    # onnx_types before builder, so by the time any call reaches here the
    # module is fully initialized — but a top-level import in builder.py
    # could break if builder is ever imported first.
    from onnxscript.onnx_types import TensorType  # pylint: disable=import-outside-toplevel

    if isinstance(spec, ir.TypeAndShape):
        return spec
    if isinstance(spec, type) and issubclass(spec, TensorType):
        return spec.to_ir()
    raise TypeError(f"Expected ir.TypeAndShape or a TensorType subclass, got {type(spec)!r}.")


class GraphBuilder:
"""Imperative builder for constructing ONNX IR graphs with automatic constant promotion, type casting, and shape inference."""

Expand Down Expand Up @@ -302,6 +329,78 @@
onnxscript.optimizer.basic_constant_propagation([node])
inference.infer_outputs(node)

def subgraph(
    self,
    trace_function: Callable,
    input_types: Sequence[TypeSpec],
    output_types: Sequence[TypeSpec],
    *,
    name: str = "subgraph",
) -> ir.Graph:
    """Trace a callable into a fresh :class:`ir.Graph` for use as a graph attribute.

    The produced graph carries the same default-domain opset version as this
    builder's graph, which makes it suitable as the body graph of control-flow
    operators such as ``Scan``, ``Loop``, and ``If``.

    Example - building a Scan body that adds two sequences element-wise::

        body = graph_builder.subgraph(
            lambda op, x, y: op.Add(x, y),
            input_types=[FLOAT[...], FLOAT[...]],
            output_types=[FLOAT[...]],
        )

    Args:
        trace_function: Callable of the form
            ``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``.
            Invoked exactly once with freshly created placeholder inputs to
            record the subgraph's nodes.
        input_types: One type per graph input. Each entry may be an
            :class:`ir.TypeAndShape` **or** a
            :class:`~onnxscript.onnx_types.TensorType` subclass (e.g.
            ``FLOAT[1024]`` or ``FLOAT['M', 'N']``).
        output_types: One type per graph output, in the same format as
            *input_types*.
        name: Name given to the resulting :class:`ir.Graph`.

    Returns:
        An :class:`ir.Graph` with populated inputs, outputs, and traced nodes,
        ready to be passed as a graph-valued attribute (e.g. the ``body``
        attribute of a ``Scan`` or ``Loop`` node).

    Raises:
        ValueError: If *trace_function* returns a different number of values
            than declared in *output_types*.
    """
    # Inherit the default-domain opset version from the parent graph.
    opset = self._graph.opset_imports[""]
    # Resolve declared types up front so malformed specs fail early.
    input_specs = [_resolve_type_spec(spec) for spec in input_types]
    output_specs = [_resolve_type_spec(spec) for spec in output_types]

    inner = ir.Graph(
        name=name,
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": opset},
    )
    inner.inputs.extend(
        ir.Value(name=f"input_{index}", type=spec.type, shape=spec.shape)
        for index, spec in enumerate(input_specs)
    )

    # Trace the body with a nested builder; nodes accumulate in `inner`.
    traced = trace_function(GraphBuilder(inner).op, *inner.inputs)
    if not isinstance(traced, Sequence):
        traced = [traced]
    if len(traced) != len(output_specs):
        raise ValueError(
            f"trace_function returned {len(traced)} output(s), "
            f"but {len(output_specs)} were declared in output_types."
        )
    # Stamp the declared types/shapes onto the traced output values.
    for value, spec in zip(traced, output_specs):
        value.type = spec.type
        value.merge_shapes(spec.shape)

    inner.outputs.extend(traced)
    return inner

def call_op(
self,
op_type: str,
Expand Down
119 changes: 119 additions & 0 deletions onnxscript/_internal/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import onnxscript._internal.builder as builder
from onnxscript import script
from onnxscript.onnx_types import DOUBLE, FLOAT

_default_opset_version = 23

Expand Down Expand Up @@ -819,5 +820,123 @@
self.assertIn("does not match", str(cm.exception))


class BuildSubgraphTest(unittest.TestCase):
    """Tests for GraphBuilder.subgraph()."""

    def _make_builder(self, opset_version: int = 23) -> builder.GraphBuilder:
        """Return a minimal GraphBuilder for the given opset version."""
        graph = ir.Graph(
            name="parent",
            inputs=[],
            outputs=[],
            nodes=[],
            opset_imports={"": opset_version},
        )
        return builder.GraphBuilder(graph)

    def test_basic_subgraph(self):
        """Subgraph returns a valid ir.Graph with correct inputs/outputs."""

        def _add(op, x, y):
            return op.Add(x, y)

        gb = self._make_builder()
        graph = gb.subgraph(
            _add,
            input_types=[FLOAT[3, 4], FLOAT[3, 4]],
            output_types=[FLOAT[3, 4]],
        )
        self.assertIsInstance(graph, ir.Graph)
        self.assertEqual(len(graph.inputs), 2)
        self.assertEqual(len(graph.outputs), 1)
        op_types = [node.op_type for node in graph]
        self.assertEqual(op_types, ["Add"])

    def test_subgraph_inherits_opset_version(self):
        """The subgraph opset version matches the parent GraphBuilder."""
        gb = self._make_builder(opset_version=17)
        graph = gb.subgraph(
            lambda op, x: op.Identity(x),
            input_types=[FLOAT[...]],
            output_types=[FLOAT[...]],
        )
        self.assertEqual(graph.opset_imports[""], 17)

    def test_subgraph_with_ir_type_and_shape(self):
        """Subgraph also accepts ir.TypeAndShape directly."""

        def _mul(op, x, y):
            return op.Mul(x, y)

        float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([2, 3]))
        gb = self._make_builder()
        graph = gb.subgraph(
            _mul,
            input_types=[float_2d, float_2d],
            output_types=[float_2d],
        )
        self.assertIsInstance(graph, ir.Graph)
        self.assertEqual(len(list(graph)), 1)
        self.assertEqual(next(iter(graph)).op_type, "Mul")

    def test_subgraph_multiple_outputs(self):
        """Subgraph handles multiple outputs."""

        def _add_and_mul(op, x, y):
            return op.Add(x, y), op.Mul(x, y)

        ts = FLOAT[...]
        gb = self._make_builder()
        graph = gb.subgraph(
            _add_and_mul,
            input_types=[ts, ts],
            output_types=[ts, ts],
        )
        self.assertEqual(len(graph.outputs), 2)

    def test_subgraph_output_count_mismatch_raises(self):
        """Subgraph raises ValueError when output count does not match."""

        def _returns_one(op, x, y):
            return op.Add(x, y)

        gb = self._make_builder()
        with self.assertRaises(ValueError):
            gb.subgraph(
                _returns_one,
                input_types=[FLOAT[...], FLOAT[...]],
                output_types=[FLOAT[...], FLOAT[...]],  # expects 2, gets 1
            )

    def test_subgraph_custom_name(self):
        """Subgraph passes the name through to the ir.Graph."""

        def _id(op, x):
            return op.Identity(x)

        gb = self._make_builder()
        graph = gb.subgraph(
            _id,
            input_types=[DOUBLE[...]],
            output_types=[DOUBLE[...]],
            name="scan_body",
        )
        self.assertEqual(graph.name, "scan_body")

    def test_invalid_type_spec_raises(self):
        """Subgraph raises TypeError for an unrecognized type specification."""

        def _id(op, x):
            return op.Identity(x)

        gb = self._make_builder()
        with self.assertRaises(TypeError):
            gb.subgraph(
                _id,
                input_types=["not_a_type_spec"],
                output_types=["not_a_type_spec"],
            )

# Run the test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
Loading
Loading