diff --git a/onnxscript/nn/__init__.py b/onnxscript/nn/__init__.py index 248c9ef40e..53d377ecb9 100644 --- a/onnxscript/nn/__init__.py +++ b/onnxscript/nn/__init__.py @@ -6,5 +6,6 @@ from onnxscript.nn._module import Module from onnxscript.nn._module_list import ModuleList from onnxscript.nn._parameter import Parameter +from onnxscript.nn._sequential import Sequential -__all__ = ["Module", "ModuleList", "Parameter"] +__all__ = ["Module", "ModuleList", "Parameter", "Sequential"] diff --git a/onnxscript/nn/_module.py b/onnxscript/nn/_module.py index 084703703e..277ee3b6e6 100644 --- a/onnxscript/nn/_module.py +++ b/onnxscript/nn/_module.py @@ -7,7 +7,7 @@ import onnx_ir as ir -from onnxscript._internal.builder import GraphBuilder, OpBuilder +from onnxscript._internal import builder as _builder from onnxscript.nn._parameter import Parameter @@ -70,8 +70,8 @@ def __setattr__(self, name: str, value: Any) -> None: else: object.__setattr__(self, name, value) - def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: - builder: GraphBuilder = op.builder + def __call__(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: + builder = op.builder module_name = self._name or "" class_name = type(self).__qualname__ builder.push_module(module_name, class_name) @@ -85,7 +85,7 @@ def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: builder.pop_module() return result - def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: + def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: """Define the computation performed by this module. Must be overridden by subclasses. 
Receives an ``OpBuilder`` as the diff --git a/onnxscript/nn/_module_list.py b/onnxscript/nn/_module_list.py index e71ab2a855..06850018b1 100644 --- a/onnxscript/nn/_module_list.py +++ b/onnxscript/nn/_module_list.py @@ -5,7 +5,7 @@ from typing import Any, Iterator, overload -from onnxscript._internal.builder import OpBuilder +from onnxscript._internal import builder as _builder from onnxscript.nn._module import Module @@ -47,7 +47,11 @@ def _set_name(self, name: str) -> None: def _register_child(self, key: str, module: Module) -> None: """Register a child module under the given string key.""" if module._name is None: # pylint: disable=protected-access - object.__setattr__(module, "_name", key) + # Qualify with parent name if already set (e.g. after append) + if self._name is not None: + module._set_name(f"{self._name}.{key}") # pylint: disable=protected-access + else: + object.__setattr__(module, "_name", key) self._modules[key] = module object.__setattr__(self, key, module) @@ -89,7 +93,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[Module]: return iter(self._modules.values()) - def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: + def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError( "ModuleList is not callable directly. " "Iterate over its children and call them individually." 
class SequentialTest(unittest.TestCase):
    """Exercise ``Sequential``: chaining, tuple/list/None hand-off, and parameter naming."""

    def _make_input(
        self, name: str = "input", shape: list[int] | None = None
    ) -> tuple[ir.Graph, OpBuilder, ir.Value]:
        """Return a new graph, its OpBuilder, and a FLOAT value added to the graph inputs.

        ``shape`` falls back to ``[3]`` when it is not given.
        """
        dims = [3] if shape is None else shape
        graph, op = _create_graph_and_op()
        value = ir.Value(
            name=name,
            type=ir.TensorType(ir.DataType.FLOAT),
            shape=ir.Shape(dims),
        )
        graph.inputs.append(value)
        return graph, op, value

    # -- basic forward chaining --

    def test_chains_forward_calls(self):
        """Three chained children each add one ``Add`` node to the graph."""

        class AddOne(Module):
            def forward(self, op, x):
                return op.Add(x, op.Constant(value_float=1.0))

        graph, op, inp = self._make_input()
        chain = Sequential([AddOne(), AddOne(), AddOne()])
        out = chain(op, inp)

        self.assertIsInstance(out, ir.Value)
        kinds = [node.op_type for node in graph]
        self.assertEqual(kinds.count("Add"), 3)

    def test_single_module_passthrough(self):
        """A Sequential holding one module returns that module's output."""

        class PassThrough(Module):
            def forward(self, op, x):
                return op.Identity(x)

        _, op, inp = self._make_input()
        chain = Sequential([PassThrough()])
        out = chain(op, inp)
        self.assertIsInstance(out, ir.Value)

    def test_empty_raises(self):
        """Calling a Sequential with no children raises RuntimeError."""
        chain = Sequential()
        _, op = _create_graph_and_op()
        with self.assertRaises(RuntimeError):
            chain(op, None)

    # -- tuple / list passthrough (PyTorch-aligned: no unpacking) --

    def test_tuple_result_passed_as_single_arg(self):
        """A tuple returned by one child reaches the next child as a single argument."""

        class SplitTwo(Module):
            """Produces two outputs packed in a tuple."""

            def forward(self, op, x):
                first = op.Identity(x)
                second = op.Identity(x)
                return (first, second)

        class UnpackAndAdd(Module):
            """Takes one tuple argument and unpacks it itself."""

            def forward(self, op, pair):
                lhs, rhs = pair
                return op.Add(lhs, rhs)

        graph, op, inp = self._make_input()
        chain = Sequential([SplitTwo(), UnpackAndAdd()])
        out = chain(op, inp)

        self.assertIsInstance(out, ir.Value)
        kinds = [node.op_type for node in graph]
        self.assertEqual(kinds.count("Identity"), 2)
        self.assertEqual(kinds.count("Add"), 1)

    def test_list_result_passed_as_single_arg(self):
        """A list returned by one child reaches the next child as a single argument."""

        class SplitTwoList(Module):
            def forward(self, op, x):
                first = op.Identity(x)
                second = op.Identity(x)
                return [first, second]

        class UnpackAndAdd(Module):
            def forward(self, op, pair):
                lhs, rhs = pair
                return op.Add(lhs, rhs)

        _, op, inp = self._make_input()
        chain = Sequential([SplitTwoList(), UnpackAndAdd()])
        out = chain(op, inp)
        self.assertIsInstance(out, ir.Value)

    def test_tuple_passthrough_chain(self):
        """A tuple survives a child that simply hands its argument back."""

        class ReturnPair(Module):
            def forward(self, op, x):
                first = op.Identity(x)
                second = op.Identity(x)
                return (first, second)

        class TupleIdentity(Module):
            """Gets a tuple and returns the very same object."""

            def forward(self, op, pair):
                return pair

        _, op, inp = self._make_input()
        chain = Sequential([ReturnPair(), TupleIdentity()])
        out = chain(op, inp)
        self.assertIsInstance(out, tuple)
        self.assertEqual(len(out), 2)

    def test_single_module_returns_tuple_unchanged(self):
        """With one child, a tuple return value comes back from the Sequential as-is."""

        class ReturnPair(Module):
            def forward(self, op, x):
                first = op.Identity(x)
                second = op.Identity(x)
                return (first, second)

        _, op, inp = self._make_input()
        chain = Sequential([ReturnPair()])
        out = chain(op, inp)
        self.assertIsInstance(out, tuple)
        self.assertEqual(len(out), 2)

    def test_last_module_tuple_return_is_preserved(self):
        """The tuple returned by the last child is the result of the Sequential."""

        class Identity(Module):
            def forward(self, op, x):
                return op.Identity(x)

        class SplitThree(Module):
            def forward(self, op, x):
                first = op.Identity(x)
                second = op.Identity(x)
                third = op.Identity(x)
                return (first, second, third)

        _, op, inp = self._make_input()
        chain = Sequential([Identity(), SplitThree()])
        out = chain(op, inp)
        self.assertIsInstance(out, tuple)
        self.assertEqual(len(out), 3)

    def test_none_passthrough(self):
        """A child returning None makes the next child receive None."""

        class ReturnNone(Module):
            def forward(self, op, x):  # pylint: disable=unused-argument
                return None

        class AcceptNone(Module):
            def forward(self, op, x):
                # Record what was received so the test can inspect it.
                self.received = x
                return x

        _, op = _create_graph_and_op()
        sink = AcceptNone()
        chain = Sequential([ReturnNone(), sink])
        out = chain(op, "anything")
        self.assertIsNone(out)
        self.assertIsNone(sink.received)

    # -- parameter naming --

    def test_parameter_naming(self):
        """Initializer names carry the numeric child index, as with ModuleList."""

        class Linear(Module):
            def __init__(self, in_f, out_f):
                super().__init__()
                self.weight = Parameter([out_f, in_f], name="weight")

            def forward(self, op, x):
                return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))

        class Model(Module):
            def __init__(self):
                super().__init__("model")
                self.layers = Sequential([Linear(4, 4), Linear(4, 4)])

            def forward(self, op, x):
                return self.layers(op, x)

        graph, op, inp = self._make_input(shape=[1, 4])
        model = Model()
        model(op, inp)

        for expected_name in ("model.layers.0.weight", "model.layers.1.weight"):
            self.assertIn(expected_name, graph.initializers)

    def test_parameterless_modules(self):
        """Children with and without parameters can be mixed (SiLU followed by Linear)."""

        class SiLU(Module):
            def forward(self, op, x):
                return op.Mul(x, op.Sigmoid(x))

        class Linear(Module):
            def __init__(self, size):
                super().__init__()
                self.weight = Parameter([size, size], name="weight")

            def forward(self, op, x):
                return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))

        chain = Sequential([SiLU(), Linear(4)])
        params_by_name = dict(chain.named_parameters())
        # Index 0 (SiLU) contributes nothing; index 1 (Linear) contributes its weight.
        self.assertIn("1.weight", params_by_name)
        self.assertEqual(len(params_by_name), 1)

    def test_append_produces_correct_initializer_names(self):
        """Children appended after the container is attached to a parent are named correctly.

        The Sequential is created empty, assigned to a parent Module, and only
        then filled through ``append``. The resulting initializer names must
        look like ``parent.seq.0.weight`` and must not repeat the container
        name (``parent.seq.seq.0.weight``).
        """

        class Linear(Module):
            def __init__(self, size):
                super().__init__()
                self.weight = Parameter([size, size], name="weight")

            def forward(self, op, x):
                return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))

        class Model(Module):
            def __init__(self):
                super().__init__("model")
                self.blocks = Sequential([])
                # These appends happen AFTER __setattr__ named the Sequential "blocks".
                self.blocks.append(Linear(4))
                self.blocks.append(Linear(4))

            def forward(self, op, x):
                return self.blocks(op, x)

        graph, op, inp = self._make_input(shape=[1, 4])
        model = Model()
        model(op, inp)

        for expected_name in ("model.blocks.0.weight", "model.blocks.1.weight"):
            self.assertIn(expected_name, graph.initializers)
        self.assertEqual(len(graph.initializers), 2)

    def test_modulelist_append_produces_correct_initializer_names(self):
        """ModuleList children appended after parent registration are named correctly.

        Covers two ModuleLists used in an interleaved fashion (a mid-block with
        separate ``resnets`` and ``attentions`` lists). Children appended after
        the lists were attached to the parent must yield names such as
        ``parent.resnets.1.weight``.
        """

        class Linear(Module):
            def __init__(self, size):
                super().__init__()
                self.weight = Parameter([size, size], name="weight")

            def forward(self, op, x):
                return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))

        class MidBlock(Module):
            def __init__(self):
                super().__init__()
                self.resnets = ModuleList([Linear(4)])
                self.attentions = ModuleList([])
                # Both lists already have their names when these appends run.
                self.attentions.append(Linear(4))
                self.resnets.append(Linear(4))

            def forward(self, op, x):
                x = self.resnets[0](op, x)
                # Index-based on purpose: attention i is followed by resnet i + 1.
                for i in range(len(self.attentions)):  # pylint: disable=consider-using-enumerate
                    x = self.attentions[i](op, x)
                    x = self.resnets[i + 1](op, x)
                return x

        class Model(Module):
            def __init__(self):
                super().__init__("model")
                self.mid = MidBlock()

            def forward(self, op, x):
                return self.mid(op, x)

        graph, op, inp = self._make_input(shape=[1, 4])
        model = Model()
        model(op, inp)

        actual_names = set(graph.initializers.keys())
        self.assertEqual(
            actual_names,
            {
                "model.mid.resnets.0.weight",
                "model.mid.resnets.1.weight",
                "model.mid.attentions.0.weight",
            },
        )
class Sequential(_module_list.ModuleList):
    """A container that calls its children in order, mirroring ``torch.nn.Sequential``.

    Children are registered with string keys ``"0"``, ``"1"``, etc., just like
    ``ModuleList``. ``forward`` passes the output of each child as the single
    input of the next child and returns the output of the last one.

    Example::

        class SiLU(Module):
            def forward(self, op, x):
                return op.Mul(x, op.Sigmoid(x))

        # SiLU at index 0 has no parameters; Linear sits at index 1, so
        # (assuming Linear defines ``weight`` and ``bias``) the parameter
        # names are "mod.1.weight" and "mod.1.bias".
        mod = Sequential([SiLU(), Linear(4, 4)])

        # Calling mod(op, x) is equivalent to:
        #     x = silu(op, x)
        #     x = linear(op, x)
    """

    def _set_name(self, name: str) -> None:
        """Set this container's name and reset every child to its plain key.

        Unlike ``ModuleList._set_name``, which fully qualifies children
        (ModuleList is iterated externally), a Sequential is invoked through
        ``__call__``, which already pushes the container's own name onto the
        builder stack. Children therefore keep the simple names ``"0"``,
        ``"1"``, ... so the container name is not prefixed twice.

        Args:
            name: The name to store on this container.
        """
        object.__setattr__(self, "_name", name)
        for key, child in self._modules.items():
            child._set_name(key)  # pylint: disable=protected-access

    def _register_child(self, key: str, module: _module_list.Module) -> None:
        """Register ``module`` as a child under the string ``key``.

        Unlike ``ModuleList._register_child``, which qualifies the child name
        with the parent name, a child that has no name yet is given the bare
        ``key``, because ``__call__`` already pushes the Sequential's own name.
        A child that already has a name keeps it.

        Args:
            key: String index of the child (``"0"``, ``"1"``, ...).
            module: The child module to store and expose as attribute ``key``.
        """
        if module._name is None:  # pylint: disable=protected-access
            object.__setattr__(module, "_name", key)
        self._modules[key] = module
        object.__setattr__(self, key, module)

    def forward(self, op: _builder.OpBuilder, input: Any) -> Any:  # pylint: disable=redefined-builtin
        """Run each child in order, feeding every output to the next child.

        Mirrors ``torch.nn.Sequential.forward``: each child receives exactly
        one positional argument (the output of the previous child). Tuples,
        lists and ``None`` are handed over as-is, without unpacking.

        Args:
            op: The ``OpBuilder`` passed to every child call.
            input: The value given to the first child.

        Returns:
            Whatever the last child returns.

        Raises:
            RuntimeError: If the container has no children.
        """
        if len(self) == 0:
            raise RuntimeError("Cannot call forward on an empty Sequential container")
        # Thread the value through a local instead of rebinding the parameter.
        value = input
        for module in self:
            value = module(op, value)
        return value

    def __repr__(self) -> str:
        """Return a multi-line listing with one ``(key): <child repr>`` entry per child."""
        lines = ["Sequential("]
        for name, module in self._modules.items():
            # Indent the continuation lines of a multi-line child repr.
            mod_repr = repr(module).replace("\n", "\n ")
            lines.append(f" ({name}): {mod_repr}")
        lines.append(")")
        return "\n".join(lines)