From 634f59aa4ca33c2e2ffa8a49c91795447edb70c7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 23 Feb 2026 16:44:02 -0800 Subject: [PATCH 1/7] Add nn.Sequential: callable ModuleList that chains forward calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit nn.Sequential mirrors torch.nn.Sequential — children are registered with numeric keys and forward() passes each child's output as the next child's input. Key design decisions: - Subclasses ModuleList for registration, indexing, iteration - Overrides _set_name to keep children with simple '0', '1' names (avoids double-prefixing since __call__ already pushes the parent name) - Raises RuntimeError on empty Sequential This enables matching HF diffusers naming conventions (e.g. nn.Sequential(SiLU(), Linear(...)) producing 'mod.1.weight') without needing preprocess_weights renames. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/nn/__init__.py | 3 +- onnxscript/nn/_module_test.py | 104 ++++++++++++++++++++++++++++++++++ onnxscript/nn/_sequential.py | 65 +++++++++++++++++++++ 3 files changed, 171 insertions(+), 1 deletion(-) create mode 100644 onnxscript/nn/_sequential.py diff --git a/onnxscript/nn/__init__.py b/onnxscript/nn/__init__.py index 248c9ef40e..53d377ecb9 100644 --- a/onnxscript/nn/__init__.py +++ b/onnxscript/nn/__init__.py @@ -6,5 +6,6 @@ from onnxscript.nn._module import Module from onnxscript.nn._module_list import ModuleList from onnxscript.nn._parameter import Parameter +from onnxscript.nn._sequential import Sequential -__all__ = ["Module", "ModuleList", "Parameter"] +__all__ = ["Module", "ModuleList", "Parameter", "Sequential"] diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py index 8c68a13e9a..96b697dbc2 100644 --- a/onnxscript/nn/_module_test.py +++ b/onnxscript/nn/_module_test.py @@ -828,5 +828,109 @@ def test_modulelist_not_directly_callable(self): ml(op) +class SequentialTest(unittest.TestCase): + def test_sequential_chains_forward_calls(self): + """Sequential calls children in order, passing output to next.""" + + class AddOne(Module): + def __init__(self): + super().__init__() + + def forward(self, op, x): + return op.Add(x, op.Constant(value_float=1.0)) + + from onnxscript.nn._sequential import Sequential + + graph, op = _create_graph_and_op() + x = ir.Value( + name="input", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape([3]), + ) + graph.inputs.append(x) + + seq = Sequential([AddOne(), AddOne()]) + result = seq(op, x) + + self.assertIsInstance(result, ir.Value) + op_types = [node.op_type for node in graph] + # Two Add ops (one Constant + Add per child) + self.assertEqual(op_types.count("Add"), 2) + + def test_sequential_parameter_naming(self): + """Sequential produces numeric-indexed parameter names like ModuleList.""" + + class Linear(Module): + def __init__(self, in_f, out_f): + super().__init__() + self.weight = Parameter([out_f, in_f], name="weight") + + def forward(self, op, x): + return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0])) + + from onnxscript.nn._sequential import Sequential + + class Model(Module): + def __init__(self): + super().__init__("model") + self.layers = Sequential([Linear(4, 4), Linear(4, 4)]) + + def forward(self, op, x): + return self.layers(op, x) + + graph, op = _create_graph_and_op() + x = ir.Value( + name="input", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape([1, 4]), + ) + graph.inputs.append(x) + + m = Model() + m(op, x) + + self.assertIn("model.layers.0.weight", graph.initializers) + self.assertIn("model.layers.1.weight", graph.initializers) + + def test_sequential_with_parameterless_modules(self): + """Sequential works with mixed param/no-param children (like SiLU + Linear).""" + + class SiLU(Module): + def forward(self, op, x): + return op.Mul(x, op.Sigmoid(x)) + + class Linear(Module): + def __init__(self, size): + super().__init__() + self.weight = Parameter([size, size], name="weight") + + def forward(self, op, x): + return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0])) + + from onnxscript.nn._sequential import Sequential + + seq = Sequential([SiLU(), Linear(4)]) + named = dict(seq.named_parameters()) + # SiLU at index 0 has no params; Linear at index 1 has weight + self.assertIn("1.weight", named) + self.assertEqual(len(named), 1) + + def test_sequential_empty_raises(self): + """Empty Sequential raises RuntimeError on forward.""" + from onnxscript.nn._sequential import Sequential + + seq = Sequential() + _, op = _create_graph_and_op() + with self.assertRaises(RuntimeError): + seq(op, None) + + def test_sequential_import_from_nn(self): + """Sequential is importable from onnxscript.nn.""" + from onnxscript.nn import Sequential as Seq + from onnxscript.nn._sequential import Sequential + + self.assertIs(Seq, Sequential) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/nn/_sequential.py b/onnxscript/nn/_sequential.py new file mode 100644 index 0000000000..a03c422292 --- /dev/null +++ b/onnxscript/nn/_sequential.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Any + +from onnxscript._internal.builder import OpBuilder +from onnxscript.nn._module import Module +from onnxscript.nn._module_list import ModuleList + + +class Sequential(ModuleList): + """A sequential container that calls children in order, mirroring ``torch.nn.Sequential``. + + Children are registered with string keys ``"0"``, ``"1"``, etc., just like + ``ModuleList``. The ``forward`` method passes the output of each child as + the input to the next. + + Example:: + + class SiLU(Module): + def forward(self, op, x): + return op.Mul(x, op.Sigmoid(x)) + + # Produces parameter names: "mod.0.weight", "mod.0.bias" + # SiLU at index 0 has no parameters. + mod = Sequential([SiLU(), Linear(4, 4)]) + + # Calling mod(op, x) is equivalent to: + # x = silu(op, x) + # x = linear(op, x) + """ + + def _set_name(self, name: str) -> None: + """Set this container's name. Children keep simple ``"0"``, ``"1"`` names. + + Unlike ``ModuleList._set_name`` which fully qualifies children (used + when ModuleList is iterated externally), Sequential is called via + ``__call__`` which already pushes its own name onto the builder stack. + Children must keep simple keys to avoid double-prefixing. + """ + object.__setattr__(self, "_name", name) + for key, child in self._modules.items(): + child._set_name(key) + + def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: + """Run each child module sequentially, passing output to the next.""" + if len(self) == 0: + raise RuntimeError("Sequential is empty") + for i, module in enumerate(self): + if i == 0: + args = (module(op, *args, **kwargs),) + kwargs = {} + else: + args = (module(op, *args),) + return args[0] + + def __repr__(self) -> str: + lines = ["Sequential("] + for name, module in self._modules.items(): + mod_repr = repr(module).replace("\n", "\n ") + lines.append(f" ({name}): {mod_repr}") + lines.append(")") + return "\n".join(lines) From 1eb5c6d4014af56dba5756f022c0d359f49a11a0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 23 Feb 2026 16:58:43 -0800 Subject: [PATCH 2/7] Lint Signed-off-by: Justin Chu --- onnxscript/nn/_module.py | 8 ++++---- onnxscript/nn/_module_list.py | 4 ++-- onnxscript/nn/_module_test.py | 16 +--------------- onnxscript/nn/_parameter.py | 4 ++-- onnxscript/nn/_sequential.py | 9 ++++----- 5 files changed, 13 insertions(+), 28 deletions(-) diff --git a/onnxscript/nn/_module.py b/onnxscript/nn/_module.py index 084703703e..277ee3b6e6 100644 --- a/onnxscript/nn/_module.py +++ b/onnxscript/nn/_module.py @@ -7,7 +7,7 @@ import onnx_ir as ir -from onnxscript._internal.builder import GraphBuilder, OpBuilder +from onnxscript._internal import builder as _builder from onnxscript.nn._parameter import Parameter @@ -70,8 +70,8 @@ def __setattr__(self, name: str, value: Any) -> None: else: object.__setattr__(self, name, value) - def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: - builder: GraphBuilder = op.builder + def __call__(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: + builder = op.builder module_name = self._name or "" class_name = type(self).__qualname__ builder.push_module(module_name, class_name) @@ -85,7 +85,7 @@ def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: builder.pop_module() return result - def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: + def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: """Define the computation performed by this module. Must be overridden by subclasses. Receives an ``OpBuilder`` as the diff --git a/onnxscript/nn/_module_list.py b/onnxscript/nn/_module_list.py index e71ab2a855..65d374551b 100644 --- a/onnxscript/nn/_module_list.py +++ b/onnxscript/nn/_module_list.py @@ -5,7 +5,7 @@ from typing import Any, Iterator, overload -from onnxscript._internal.builder import OpBuilder +from onnxscript._internal import builder as _builder from onnxscript.nn._module import Module @@ -89,7 +89,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[Module]: return iter(self._modules.values()) - def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: + def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError( "ModuleList is not callable directly. " "Iterate over its children and call them individually." diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py index 96b697dbc2..f9b0024b15 100644 --- a/onnxscript/nn/_module_test.py +++ b/onnxscript/nn/_module_test.py @@ -9,7 +9,7 @@ import onnx_ir as ir from onnxscript._internal.builder import GraphBuilder, OpBuilder -from onnxscript.nn import Module, ModuleList, Parameter +from onnxscript.nn import Module, ModuleList, Parameter, Sequential def _create_graph_and_op() -> tuple[ir.Graph, OpBuilder]: @@ -839,8 +839,6 @@ def __init__(self): def forward(self, op, x): return op.Add(x, op.Constant(value_float=1.0)) - from onnxscript.nn._sequential import Sequential - graph, op = _create_graph_and_op() x = ir.Value( name="input", @@ -868,8 +866,6 @@ def __init__(self, in_f, out_f): def forward(self, op, x): return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0])) - from onnxscript.nn._sequential import Sequential - class Model(Module): def __init__(self): super().__init__("model") @@ -907,8 +903,6 @@ def __init__(self, size): def forward(self, op, x): return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0])) - from onnxscript.nn._sequential import Sequential - seq = Sequential([SiLU(), Linear(4)]) named = dict(seq.named_parameters()) # SiLU at index 0 has no params; Linear at index 1 has weight @@ -917,20 +911,12 @@ def forward(self, op, x): def test_sequential_empty_raises(self): """Empty Sequential raises RuntimeError on forward.""" - from onnxscript.nn._sequential import Sequential seq = Sequential() _, op = _create_graph_and_op() with self.assertRaises(RuntimeError): seq(op, None) - def test_sequential_import_from_nn(self): - """Sequential is importable from onnxscript.nn.""" - from onnxscript.nn import Sequential as Seq - from onnxscript.nn._sequential import Sequential - - self.assertIs(Seq, Sequential) - if __name__ == "__main__": unittest.main() diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py index d319868e89..f217292e58 100644 --- a/onnxscript/nn/_parameter.py +++ b/onnxscript/nn/_parameter.py @@ -7,7 +7,7 @@ import onnx_ir as ir -from onnxscript._internal.builder import GraphBuilder +from onnxscript._internal import builder as _builder class Parameter(ir.Value): @@ -57,7 +57,7 @@ def dtype(self) -> ir.DataType | None: # type: ignore[override] """Return the element data type of this parameter.""" return self.type.dtype if self.type is not None else None - def _realize(self, builder: GraphBuilder) -> Parameter: + def _realize(self, builder: _builder.GraphBuilder) -> Parameter: """Qualify the name and register as a graph initializer. Uses direct assignment to ``graph.initializers[...]`` to skip the diff --git a/onnxscript/nn/_sequential.py b/onnxscript/nn/_sequential.py index a03c422292..236bcc32a5 100644 --- a/onnxscript/nn/_sequential.py +++ b/onnxscript/nn/_sequential.py @@ -5,12 +5,11 @@ from typing import Any -from onnxscript._internal.builder import OpBuilder -from onnxscript.nn._module import Module -from onnxscript.nn._module_list import ModuleList +from onnxscript._internal import builder as _builder +from onnxscript.nn import _module_list -class Sequential(ModuleList): +class Sequential(_module_list.ModuleList): """A sequential container that calls children in order, mirroring ``torch.nn.Sequential``. Children are registered with string keys ``"0"``, ``"1"``, etc., just like @@ -44,7 +43,7 @@ def _set_name(self, name: str) -> None: for key, child in self._modules.items(): child._set_name(key) - def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: + def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: """Run each child module sequentially, passing output to the next.""" if len(self) == 0: raise RuntimeError("Sequential is empty") From 0dc305418b47f3e9b95e5d94752e11f7b73c3806 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 23 Feb 2026 17:00:55 -0800 Subject: [PATCH 3/7] Update onnxscript/nn/_sequential.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/nn/_sequential.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/nn/_sequential.py b/onnxscript/nn/_sequential.py index 236bcc32a5..6872614d48 100644 --- a/onnxscript/nn/_sequential.py +++ b/onnxscript/nn/_sequential.py @@ -46,7 +46,7 @@ def _set_name(self, name: str) -> None: def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: """Run each child module sequentially, passing output to the next.""" if len(self) == 0: - raise RuntimeError("Sequential is empty") + raise RuntimeError("Cannot call forward on an empty Sequential container") for i, module in enumerate(self): if i == 0: args = (module(op, *args, **kwargs),) From dddcbcc11532da2f2cf834715d057a11504efd9d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 23 Feb 2026 19:03:59 -0800 Subject: [PATCH 4/7] Fix ModuleList._register_child to qualify names for late-appended children When children are appended to a ModuleList after the parent Module has already called _set_name (e.g. via __setattr__), _register_child now qualifies the child name with the ModuleList's own name. This ensures that children appended after registration produce correct ONNX initializer names like 'mid_block.attentions.0.weight' instead of 'mid_block.0.weight'. Sequential overrides _register_child to keep simple index names ('0', '1') since Sequential.__call__ already pushes its own scope. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/nn/_module_list.py | 6 +- onnxscript/nn/_module_test.py | 101 ++++++++++++++++++++++++++++++++++ onnxscript/nn/_sequential.py | 12 ++++ 3 files changed, 118 insertions(+), 1 deletion(-) diff --git a/onnxscript/nn/_module_list.py b/onnxscript/nn/_module_list.py index 65d374551b..04cf81c505 100644 --- a/onnxscript/nn/_module_list.py +++ b/onnxscript/nn/_module_list.py @@ -47,7 +47,11 @@ def _set_name(self, name: str) -> None: def _register_child(self, key: str, module: Module) -> None: """Register a child module under the given string key.""" if module._name is None: # pylint: disable=protected-access - object.__setattr__(module, "_name", key) + # Qualify with parent name if already set (e.g. after append) + if self._name is not None: + module._set_name(f"{self._name}.{key}") + else: + object.__setattr__(module, "_name", key) self._modules[key] = module object.__setattr__(self, key, module) diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py index f9b0024b15..1cd9643fdb 100644 --- a/onnxscript/nn/_module_test.py +++ b/onnxscript/nn/_module_test.py @@ -917,6 +917,107 @@ def test_sequential_empty_raises(self): with self.assertRaises(RuntimeError): seq(op, None) + def test_sequential_append_produces_correct_initializer_names(self): + """Sequential with append (after parent registration) gets correct names. + + This tests the pattern where a Sequential is created empty, registered + on a parent Module, and then children are appended. The children should + produce initializer names like ``parent.seq.0.weight``, not + ``parent.seq.seq.0.weight`` (double-prefixed). + """ + + class Linear(Module): + def __init__(self, size): + super().__init__() + self.weight = Parameter([size, size], name="weight") + + def forward(self, op, x): + return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0])) + + class Model(Module): + def __init__(self): + super().__init__("model") + self.blocks = Sequential([]) + # Append AFTER __setattr__ has set Sequential._name = "blocks" + self.blocks.append(Linear(4)) + self.blocks.append(Linear(4)) + + def forward(self, op, x): + return self.blocks(op, x) + + graph, op = _create_graph_and_op() + x = ir.Value( + name="input", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape([1, 4]), + ) + graph.inputs.append(x) + + m = Model() + m(op, x) + + self.assertIn("model.blocks.0.weight", graph.initializers) + self.assertIn("model.blocks.1.weight", graph.initializers) + self.assertEqual(len(graph.initializers), 2) + + def test_modulelist_append_produces_correct_initializer_names(self): + """ModuleList with append after parent registration gets correct names. + + Tests the interleaved ModuleList pattern (like a mid_block with separate + resnets and attentions lists). Children appended after parent registration + should produce names like ``parent.resnets.1.weight``. + """ + + class Linear(Module): + def __init__(self, size): + super().__init__() + self.weight = Parameter([size, size], name="weight") + + def forward(self, op, x): + return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0])) + + class MidBlock(Module): + def __init__(self): + super().__init__() + self.resnets = ModuleList([Linear(4)]) + self.attentions = ModuleList([]) + # Appends after parent has set _name on the ModuleLists + self.attentions.append(Linear(4)) + self.resnets.append(Linear(4)) + + def forward(self, op, x): + x = self.resnets[0](op, x) + for i in range(len(self.attentions)): + x = self.attentions[i](op, x) + x = self.resnets[i + 1](op, x) + return x + + class Model(Module): + def __init__(self): + super().__init__("model") + self.mid = MidBlock() + + def forward(self, op, x): + return self.mid(op, x) + + graph, op = _create_graph_and_op() + x = ir.Value( + name="input", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape([1, 4]), + ) + graph.inputs.append(x) + + m = Model() + m(op, x) + + expected = { + "model.mid.resnets.0.weight", + "model.mid.resnets.1.weight", + "model.mid.attentions.0.weight", + } + self.assertEqual(set(graph.initializers.keys()), expected) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/nn/_sequential.py b/onnxscript/nn/_sequential.py index 6872614d48..799f5300c8 100644 --- a/onnxscript/nn/_sequential.py +++ b/onnxscript/nn/_sequential.py @@ -43,6 +43,18 @@ def _set_name(self, name: str) -> None: for key, child in self._modules.items(): child._set_name(key) + def _register_child(self, key: str, module: _module_list.Module) -> None: + """Register a child module under the given string key. + + Unlike ``ModuleList._register_child`` which qualifies the child name + with the parent name, Sequential keeps children with simple index + names because ``__call__`` already pushes the Sequential's own name. + """ + if module._name is None: # pylint: disable=protected-access + object.__setattr__(module, "_name", key) + self._modules[key] = module + object.__setattr__(self, key, module) + def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: """Run each child module sequentially, passing output to the next.""" if len(self) == 0: From 50689d2e814aa152876110ac90e6a8ecd061cda7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 24 Feb 2026 10:39:46 -0800 Subject: [PATCH 5/7] Update Sequential Signed-off-by: Justin Chu --- onnxscript/nn/_sequential.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxscript/nn/_sequential.py b/onnxscript/nn/_sequential.py index 799f5300c8..579462879a 100644 --- a/onnxscript/nn/_sequential.py +++ b/onnxscript/nn/_sequential.py @@ -61,11 +61,14 @@ def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: raise RuntimeError("Cannot call forward on an empty Sequential container") for i, module in enumerate(self): if i == 0: - args = (module(op, *args, **kwargs),) - kwargs = {} + result = module(op, *args, **kwargs) else: - args = (module(op, *args),) - return args[0] + result = module(op, *args) + if not isinstance(result, (list, tuple)): + args = (result,) + else: + args = result + return result def __repr__(self) -> str: lines = ["Sequential("] From 91f7f564d3aa5f34abd8fea51f7308b5455cc69d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 24 Feb 2026 10:53:53 -0800 Subject: [PATCH 6/7] Update signature Signed-off-by: Justin Chu --- onnxscript/nn/_module_test.py | 216 +++++++++++++++++++++++++--------- onnxscript/nn/_sequential.py | 21 ++-- 2 files changed, 172 insertions(+), 65 deletions(-) diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py index 1cd9643fdb..933460ad04 100644 --- a/onnxscript/nn/_module_test.py +++ b/onnxscript/nn/_module_test.py @@ -829,33 +829,172 @@ def test_modulelist_not_directly_callable(self): class SequentialTest(unittest.TestCase): - def test_sequential_chains_forward_calls(self): - """Sequential calls children in order, passing output to next.""" - - class AddOne(Module): - def __init__(self): - super().__init__() - - def forward(self, op, x): - return op.Add(x, op.Constant(value_float=1.0)) - + def _make_input( + self, name: str = "input", shape: list[int] | None = None + ) -> tuple[ir.Graph, OpBuilder, ir.Value]: + """Create a graph, OpBuilder, and a FLOAT input value.""" + if shape is None: + shape = [3] graph, op = _create_graph_and_op() x = ir.Value( - name="input", + name=name, type=ir.TensorType(ir.DataType.FLOAT), - shape=ir.Shape([3]), + shape=ir.Shape(shape), ) graph.inputs.append(x) + return graph, op, x + + # -- basic forward chaining -- - seq = Sequential([AddOne(), AddOne()]) + def test_chains_forward_calls(self): + """Sequential calls children in order, each receiving one input.""" + + class AddOne(Module): + def forward(self, op, x): + return op.Add(x, op.Constant(value_float=1.0)) + + graph, op, x = self._make_input() + seq = Sequential([AddOne(), AddOne(), AddOne()]) result = seq(op, x) self.assertIsInstance(result, ir.Value) op_types = [node.op_type for node in graph] - # Two Add ops (one Constant + Add per child) - self.assertEqual(op_types.count("Add"), 2) + self.assertEqual(op_types.count("Add"), 3) - def test_sequential_parameter_naming(self): + def test_single_module_passthrough(self): + """Single-module Sequential passes input through correctly.""" + + class PassThrough(Module): + def forward(self, op, x): + return op.Identity(x) + + _, op, x = self._make_input() + seq = Sequential([PassThrough()]) + result = seq(op, x) + self.assertIsInstance(result, ir.Value) + + def test_empty_raises(self): + """Empty Sequential raises RuntimeError on forward.""" + seq = Sequential() + _, op = _create_graph_and_op() + with self.assertRaises(RuntimeError): + seq(op, None) + + # -- tuple / list passthrough (PyTorch-aligned: no unpacking) -- + + def test_tuple_result_passed_as_single_arg(self): + """When a module returns a tuple, the next module receives it as one argument.""" + + class SplitTwo(Module): + """Returns two outputs as a tuple.""" + + def forward(self, op, x): + return (op.Identity(x), op.Identity(x)) + + class UnpackAndAdd(Module): + """Accepts a single tuple argument and unpacks it.""" + + def forward(self, op, pair): + a, b = pair + return op.Add(a, b) + + graph, op, x = self._make_input() + seq = Sequential([SplitTwo(), UnpackAndAdd()]) + result = seq(op, x) + + self.assertIsInstance(result, ir.Value) + op_types = [node.op_type for node in graph] + self.assertEqual(op_types.count("Identity"), 2) + self.assertEqual(op_types.count("Add"), 1) + + def test_list_result_passed_as_single_arg(self): + """When a module returns a list, the next module receives it as one argument.""" + + class SplitTwoList(Module): + def forward(self, op, x): + return [op.Identity(x), op.Identity(x)] + + class UnpackAndAdd(Module): + def forward(self, op, pair): + a, b = pair + return op.Add(a, b) + + _, op, x = self._make_input() + seq = Sequential([SplitTwoList(), UnpackAndAdd()]) + result = seq(op, x) + self.assertIsInstance(result, ir.Value) + + def test_tuple_passthrough_chain(self): + """Tuple output passes through identity-like modules as a single arg.""" + + class ReturnPair(Module): + def forward(self, op, x): + return (op.Identity(x), op.Identity(x)) + + class TupleIdentity(Module): + """Receives a tuple, returns it as-is.""" + + def forward(self, op, pair): + return pair + + _, op, x = self._make_input() + seq = Sequential([ReturnPair(), TupleIdentity()]) + result = seq(op, x) + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + + def test_single_module_returns_tuple_unchanged(self): + """A single-module Sequential that returns a tuple passes it through.""" + + class ReturnPair(Module): + def forward(self, op, x): + return (op.Identity(x), op.Identity(x)) + + _, op, x = self._make_input() + seq = Sequential([ReturnPair()]) + result = seq(op, x) + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + + def test_last_module_tuple_return_is_preserved(self): + """The final module's tuple return is preserved as the Sequential result.""" + + class Identity(Module): + def forward(self, op, x): + return op.Identity(x) + + class SplitThree(Module): + def forward(self, op, x): + return (op.Identity(x), op.Identity(x), op.Identity(x)) + + _, op, x = self._make_input() + seq = Sequential([Identity(), SplitThree()]) + result = seq(op, x) + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 3) + + def test_none_passthrough(self): + """Module returning None passes None to the next module.""" + + class ReturnNone(Module): + def forward(self, op, x): # pylint: disable=unused-argument + return None + + class AcceptNone(Module): + def forward(self, op, x): + self.received = x + return x + + _, op = _create_graph_and_op() + accept = AcceptNone() + seq = Sequential([ReturnNone(), accept]) + result = seq(op, "anything") + self.assertIsNone(result) + self.assertIsNone(accept.received) + + # -- parameter naming -- + + def test_parameter_naming(self): """Sequential produces numeric-indexed parameter names like ModuleList.""" class Linear(Module): @@ -874,21 +1013,14 @@ def __init__(self): def forward(self, op, x): return self.layers(op, x) - graph, op = _create_graph_and_op() - x = ir.Value( - name="input", - type=ir.TensorType(ir.DataType.FLOAT), - shape=ir.Shape([1, 4]), - ) - graph.inputs.append(x) - + graph, op, x = self._make_input(shape=[1, 4]) m = Model() m(op, x) self.assertIn("model.layers.0.weight", graph.initializers) self.assertIn("model.layers.1.weight", graph.initializers) - def test_sequential_with_parameterless_modules(self): + def test_parameterless_modules(self): """Sequential works with mixed param/no-param children (like SiLU + Linear).""" class SiLU(Module): @@ -909,18 +1041,10 @@ def forward(self, op, x): self.assertIn("1.weight", named) self.assertEqual(len(named), 1) - def test_sequential_empty_raises(self): - """Empty Sequential raises RuntimeError on forward.""" + def test_append_produces_correct_initializer_names(self): + """Children appended after parent registration get correct names. - seq = Sequential() - _, op = _create_graph_and_op() - with self.assertRaises(RuntimeError): - seq(op, None) - - def test_sequential_append_produces_correct_initializer_names(self): - """Sequential with append (after parent registration) gets correct names. - - This tests the pattern where a Sequential is created empty, registered + Tests the pattern where a Sequential is created empty, registered on a parent Module, and then children are appended. The children should produce initializer names like ``parent.seq.0.weight``, not ``parent.seq.seq.0.weight`` (double-prefixed). @@ -945,14 +1069,7 @@ def __init__(self): def forward(self, op, x): return self.blocks(op, x) - graph, op = _create_graph_and_op() - x = ir.Value( - name="input", - type=ir.TensorType(ir.DataType.FLOAT), - shape=ir.Shape([1, 4]), - ) - graph.inputs.append(x) - + graph, op, x = self._make_input(shape=[1, 4]) m = Model() m(op, x) @@ -987,7 +1104,7 @@ def __init__(self): def forward(self, op, x): x = self.resnets[0](op, x) - for i in range(len(self.attentions)): + for i in range(len(self.attentions)): # pylint: disable=consider-using-enumerate x = self.attentions[i](op, x) x = self.resnets[i + 1](op, x) return x @@ -1000,14 +1117,7 @@ def __init__(self): def forward(self, op, x): return self.mid(op, x) - graph, op = _create_graph_and_op() - x = ir.Value( - name="input", - type=ir.TensorType(ir.DataType.FLOAT), - shape=ir.Shape([1, 4]), - ) - graph.inputs.append(x) - + graph, op, x = self._make_input(shape=[1, 4]) m = Model() m(op, x) diff --git a/onnxscript/nn/_sequential.py b/onnxscript/nn/_sequential.py index 579462879a..840dff8b7b 100644 --- a/onnxscript/nn/_sequential.py +++ b/onnxscript/nn/_sequential.py @@ -55,20 +55,17 @@ def _register_child(self, key: str, module: _module_list.Module) -> None: self._modules[key] = module object.__setattr__(self, key, module) - def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any: - """Run each child module sequentially, passing output to the next.""" + def forward(self, op: _builder.OpBuilder, input: Any) -> Any: # pylint: disable=redefined-builtin + """Run each child module sequentially, passing output to the next. + + Mirrors ``torch.nn.Sequential.forward``: each child receives exactly + one positional argument (the output of the previous child). + """ if len(self) == 0: raise RuntimeError("Cannot call forward on an empty Sequential container") - for i, module in enumerate(self): - if i == 0: - result = module(op, *args, **kwargs) - else: - result = module(op, *args) - if not isinstance(result, (list, tuple)): - args = (result,) - else: - args = result - return result + for module in self: + input = module(op, input) + return input def __repr__(self) -> str: lines = ["Sequential("] From 9961725b28ee44a8b9b5d0a935dfb9924543bae9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 24 Feb 2026 10:55:47 -0800 Subject: [PATCH 7/7] lint Signed-off-by: Justin Chu --- onnxscript/nn/_module_list.py | 2 +- onnxscript/nn/_sequential.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/nn/_module_list.py b/onnxscript/nn/_module_list.py index 04cf81c505..06850018b1 100644 --- a/onnxscript/nn/_module_list.py +++ b/onnxscript/nn/_module_list.py @@ -49,7 +49,7 @@ def _register_child(self, key: str, module: Module) -> None: if module._name is None: # pylint: disable=protected-access # Qualify with parent name if already set (e.g. after append) if self._name is not None: - module._set_name(f"{self._name}.{key}") + module._set_name(f"{self._name}.{key}") # pylint: disable=protected-access else: object.__setattr__(module, "_name", key) self._modules[key] = module diff --git a/onnxscript/nn/_sequential.py b/onnxscript/nn/_sequential.py index 840dff8b7b..3f59230cdb 100644 --- a/onnxscript/nn/_sequential.py +++ b/onnxscript/nn/_sequential.py @@ -41,7 +41,7 @@ def _set_name(self, name: str) -> None: """ object.__setattr__(self, "_name", name) for key, child in self._modules.items(): - child._set_name(key) + child._set_name(key) # pylint: disable=protected-access def _register_child(self, key: str, module: _module_list.Module) -> None: """Register a child module under the given string key.