Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion onnxscript/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
from onnxscript.nn._module import Module
from onnxscript.nn._module_list import ModuleList
from onnxscript.nn._parameter import Parameter
from onnxscript.nn._sequential import Sequential

__all__ = ["Module", "ModuleList", "Parameter"]
__all__ = ["Module", "ModuleList", "Parameter", "Sequential"]
8 changes: 4 additions & 4 deletions onnxscript/nn/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import onnx_ir as ir

from onnxscript._internal.builder import GraphBuilder, OpBuilder
from onnxscript._internal import builder as _builder
from onnxscript.nn._parameter import Parameter


Expand Down Expand Up @@ -70,8 +70,8 @@ def __setattr__(self, name: str, value: Any) -> None:
else:
object.__setattr__(self, name, value)

def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any:
builder: GraphBuilder = op.builder
def __call__(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any:
builder = op.builder
module_name = self._name or ""
class_name = type(self).__qualname__
builder.push_module(module_name, class_name)
Expand All @@ -85,7 +85,7 @@ def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any:
builder.pop_module()
return result

def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any:
def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any:
"""Define the computation performed by this module.

Must be overridden by subclasses. Receives an ``OpBuilder`` as the
Expand Down
10 changes: 7 additions & 3 deletions onnxscript/nn/_module_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import Any, Iterator, overload

from onnxscript._internal.builder import OpBuilder
from onnxscript._internal import builder as _builder
from onnxscript.nn._module import Module


Expand Down Expand Up @@ -47,7 +47,11 @@ def _set_name(self, name: str) -> None:
def _register_child(self, key: str, module: Module) -> None:
"""Register a child module under the given string key."""
if module._name is None: # pylint: disable=protected-access
object.__setattr__(module, "_name", key)
# Qualify with parent name if already set (e.g. after append)
if self._name is not None:
module._set_name(f"{self._name}.{key}") # pylint: disable=protected-access
else:
object.__setattr__(module, "_name", key)
self._modules[key] = module
object.__setattr__(self, key, module)

Expand Down Expand Up @@ -89,7 +93,7 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())

def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any:
def forward(self, op: _builder.OpBuilder, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError(
"ModuleList is not callable directly. "
"Iterate over its children and call them individually."
Expand Down
303 changes: 302 additions & 1 deletion onnxscript/nn/_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import onnx_ir as ir

from onnxscript._internal.builder import GraphBuilder, OpBuilder
from onnxscript.nn import Module, ModuleList, Parameter
from onnxscript.nn import Module, ModuleList, Parameter, Sequential


def _create_graph_and_op() -> tuple[ir.Graph, OpBuilder]:
Expand Down Expand Up @@ -828,5 +828,306 @@ def test_modulelist_not_directly_callable(self):
ml(op)


class SequentialTest(unittest.TestCase):
def _make_input(
self, name: str = "input", shape: list[int] | None = None
) -> tuple[ir.Graph, OpBuilder, ir.Value]:
"""Create a graph, OpBuilder, and a FLOAT input value."""
if shape is None:
shape = [3]
graph, op = _create_graph_and_op()
x = ir.Value(
name=name,
type=ir.TensorType(ir.DataType.FLOAT),
shape=ir.Shape(shape),
)
graph.inputs.append(x)
return graph, op, x

# -- basic forward chaining --

def test_chains_forward_calls(self):
"""Sequential calls children in order, each receiving one input."""

class AddOne(Module):
def forward(self, op, x):
return op.Add(x, op.Constant(value_float=1.0))

graph, op, x = self._make_input()
seq = Sequential([AddOne(), AddOne(), AddOne()])
result = seq(op, x)

self.assertIsInstance(result, ir.Value)
op_types = [node.op_type for node in graph]
self.assertEqual(op_types.count("Add"), 3)

def test_single_module_passthrough(self):
"""Single-module Sequential passes input through correctly."""

class PassThrough(Module):
def forward(self, op, x):
return op.Identity(x)

_, op, x = self._make_input()
seq = Sequential([PassThrough()])
result = seq(op, x)
self.assertIsInstance(result, ir.Value)

def test_empty_raises(self):
"""Empty Sequential raises RuntimeError on forward."""
seq = Sequential()
_, op = _create_graph_and_op()
with self.assertRaises(RuntimeError):
seq(op, None)

# -- tuple / list passthrough (PyTorch-aligned: no unpacking) --

def test_tuple_result_passed_as_single_arg(self):
"""When a module returns a tuple, the next module receives it as one argument."""

class SplitTwo(Module):
"""Returns two outputs as a tuple."""

def forward(self, op, x):
return (op.Identity(x), op.Identity(x))

class UnpackAndAdd(Module):
"""Accepts a single tuple argument and unpacks it."""

def forward(self, op, pair):
a, b = pair
return op.Add(a, b)

graph, op, x = self._make_input()
seq = Sequential([SplitTwo(), UnpackAndAdd()])
result = seq(op, x)

self.assertIsInstance(result, ir.Value)
op_types = [node.op_type for node in graph]
self.assertEqual(op_types.count("Identity"), 2)
self.assertEqual(op_types.count("Add"), 1)

def test_list_result_passed_as_single_arg(self):
"""When a module returns a list, the next module receives it as one argument."""

class SplitTwoList(Module):
def forward(self, op, x):
return [op.Identity(x), op.Identity(x)]

class UnpackAndAdd(Module):
def forward(self, op, pair):
a, b = pair
return op.Add(a, b)

_, op, x = self._make_input()
seq = Sequential([SplitTwoList(), UnpackAndAdd()])
result = seq(op, x)
self.assertIsInstance(result, ir.Value)

def test_tuple_passthrough_chain(self):
"""Tuple output passes through identity-like modules as a single arg."""

class ReturnPair(Module):
def forward(self, op, x):
return (op.Identity(x), op.Identity(x))

class TupleIdentity(Module):
"""Receives a tuple, returns it as-is."""

def forward(self, op, pair):
return pair

_, op, x = self._make_input()
seq = Sequential([ReturnPair(), TupleIdentity()])
result = seq(op, x)
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 2)

def test_single_module_returns_tuple_unchanged(self):
"""A single-module Sequential that returns a tuple passes it through."""

class ReturnPair(Module):
def forward(self, op, x):
return (op.Identity(x), op.Identity(x))

_, op, x = self._make_input()
seq = Sequential([ReturnPair()])
result = seq(op, x)
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 2)

def test_last_module_tuple_return_is_preserved(self):
"""The final module's tuple return is preserved as the Sequential result."""

class Identity(Module):
def forward(self, op, x):
return op.Identity(x)

class SplitThree(Module):
def forward(self, op, x):
return (op.Identity(x), op.Identity(x), op.Identity(x))

_, op, x = self._make_input()
seq = Sequential([Identity(), SplitThree()])
result = seq(op, x)
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 3)

def test_none_passthrough(self):
"""Module returning None passes None to the next module."""

class ReturnNone(Module):
def forward(self, op, x): # pylint: disable=unused-argument
return None

class AcceptNone(Module):
def forward(self, op, x):
self.received = x
return x

_, op = _create_graph_and_op()
accept = AcceptNone()
seq = Sequential([ReturnNone(), accept])
result = seq(op, "anything")
self.assertIsNone(result)
self.assertIsNone(accept.received)

# -- parameter naming --

def test_parameter_naming(self):
"""Sequential produces numeric-indexed parameter names like ModuleList."""

class Linear(Module):
def __init__(self, in_f, out_f):
super().__init__()
self.weight = Parameter([out_f, in_f], name="weight")

def forward(self, op, x):
return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))

class Model(Module):
def __init__(self):
super().__init__("model")
self.layers = Sequential([Linear(4, 4), Linear(4, 4)])

def forward(self, op, x):
return self.layers(op, x)

graph, op, x = self._make_input(shape=[1, 4])
m = Model()
m(op, x)

self.assertIn("model.layers.0.weight", graph.initializers)
self.assertIn("model.layers.1.weight", graph.initializers)

def test_parameterless_modules(self):
"""Sequential works with mixed param/no-param children (like SiLU + Linear)."""

class SiLU(Module):
def forward(self, op, x):
return op.Mul(x, op.Sigmoid(x))

class Linear(Module):
def __init__(self, size):
super().__init__()
self.weight = Parameter([size, size], name="weight")

def forward(self, op, x):
return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))

seq = Sequential([SiLU(), Linear(4)])
named = dict(seq.named_parameters())
# SiLU at index 0 has no params; Linear at index 1 has weight
self.assertIn("1.weight", named)
self.assertEqual(len(named), 1)

def test_append_produces_correct_initializer_names(self):
"""Children appended after parent registration get correct names.

Tests the pattern where a Sequential is created empty, registered
on a parent Module, and then children are appended. The children should
produce initializer names like ``parent.seq.0.weight``, not
``parent.seq.seq.0.weight`` (double-prefixed).
"""

class Linear(Module):
def __init__(self, size):
super().__init__()
self.weight = Parameter([size, size], name="weight")

def forward(self, op, x):
return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))

class Model(Module):
def __init__(self):
super().__init__("model")
self.blocks = Sequential([])
# Append AFTER __setattr__ has set Sequential._name = "blocks"
self.blocks.append(Linear(4))
self.blocks.append(Linear(4))

def forward(self, op, x):
return self.blocks(op, x)

graph, op, x = self._make_input(shape=[1, 4])
m = Model()
m(op, x)

self.assertIn("model.blocks.0.weight", graph.initializers)
self.assertIn("model.blocks.1.weight", graph.initializers)
self.assertEqual(len(graph.initializers), 2)

def test_modulelist_append_produces_correct_initializer_names(self):
"""ModuleList with append after parent registration gets correct names.

Tests the interleaved ModuleList pattern (like a mid_block with separate
resnets and attentions lists). Children appended after parent registration
should produce names like ``parent.resnets.1.weight``.
"""

class Linear(Module):
def __init__(self, size):
super().__init__()
self.weight = Parameter([size, size], name="weight")

def forward(self, op, x):
return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))

class MidBlock(Module):
def __init__(self):
super().__init__()
self.resnets = ModuleList([Linear(4)])
self.attentions = ModuleList([])
# Appends after parent has set _name on the ModuleLists
self.attentions.append(Linear(4))
self.resnets.append(Linear(4))

def forward(self, op, x):
x = self.resnets[0](op, x)
for i in range(len(self.attentions)): # pylint: disable=consider-using-enumerate
x = self.attentions[i](op, x)
x = self.resnets[i + 1](op, x)
return x

class Model(Module):
def __init__(self):
super().__init__("model")
self.mid = MidBlock()

def forward(self, op, x):
return self.mid(op, x)

graph, op, x = self._make_input(shape=[1, 4])
m = Model()
m(op, x)

expected = {
"model.mid.resnets.0.weight",
"model.mid.resnets.1.weight",
"model.mid.attentions.0.weight",
}
self.assertEqual(set(graph.initializers.keys()), expected)


if __name__ == "__main__":
unittest.main()
Loading
Loading