diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py index 933460ad04..583f4e9399 100644 --- a/onnxscript/nn/_module_test.py +++ b/onnxscript/nn/_module_test.py @@ -854,7 +854,7 @@ def forward(self, op, x): return op.Add(x, op.Constant(value_float=1.0)) graph, op, x = self._make_input() - seq = Sequential([AddOne(), AddOne(), AddOne()]) + seq = Sequential(AddOne(), AddOne(), AddOne()) result = seq(op, x) self.assertIsInstance(result, ir.Value) @@ -869,7 +869,7 @@ def forward(self, op, x): return op.Identity(x) _, op, x = self._make_input() - seq = Sequential([PassThrough()]) + seq = Sequential(PassThrough()) result = seq(op, x) self.assertIsInstance(result, ir.Value) @@ -899,7 +899,7 @@ def forward(self, op, pair): return op.Add(a, b) graph, op, x = self._make_input() - seq = Sequential([SplitTwo(), UnpackAndAdd()]) + seq = Sequential(SplitTwo(), UnpackAndAdd()) result = seq(op, x) self.assertIsInstance(result, ir.Value) @@ -920,7 +920,7 @@ def forward(self, op, pair): return op.Add(a, b) _, op, x = self._make_input() - seq = Sequential([SplitTwoList(), UnpackAndAdd()]) + seq = Sequential(SplitTwoList(), UnpackAndAdd()) result = seq(op, x) self.assertIsInstance(result, ir.Value) @@ -938,7 +938,7 @@ def forward(self, op, pair): return pair _, op, x = self._make_input() - seq = Sequential([ReturnPair(), TupleIdentity()]) + seq = Sequential(ReturnPair(), TupleIdentity()) result = seq(op, x) self.assertIsInstance(result, tuple) self.assertEqual(len(result), 2) @@ -951,7 +951,7 @@ def forward(self, op, x): return (op.Identity(x), op.Identity(x)) _, op, x = self._make_input() - seq = Sequential([ReturnPair()]) + seq = Sequential(ReturnPair()) result = seq(op, x) self.assertIsInstance(result, tuple) self.assertEqual(len(result), 2) @@ -968,7 +968,7 @@ def forward(self, op, x): return (op.Identity(x), op.Identity(x), op.Identity(x)) _, op, x = self._make_input() - seq = Sequential([Identity(), SplitThree()]) + seq = Sequential(Identity(), SplitThree()) result = seq(op, x) self.assertIsInstance(result, tuple) self.assertEqual(len(result), 3) @@ -987,7 +987,7 @@ def forward(self, op, x): _, op = _create_graph_and_op() accept = AcceptNone() - seq = Sequential([ReturnNone(), accept]) + seq = Sequential(ReturnNone(), accept) result = seq(op, "anything") self.assertIsNone(result) self.assertIsNone(accept.received) @@ -1008,7 +1008,7 @@ def forward(self, op, x): class Model(Module): def __init__(self): super().__init__("model") - self.layers = Sequential([Linear(4, 4), Linear(4, 4)]) + self.layers = Sequential(Linear(4, 4), Linear(4, 4)) def forward(self, op, x): return self.layers(op, x) @@ -1035,7 +1035,7 @@ def __init__(self, size): def forward(self, op, x): return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0])) - seq = Sequential([SiLU(), Linear(4)]) + seq = Sequential(SiLU(), Linear(4)) named = dict(seq.named_parameters()) # SiLU at index 0 has no params; Linear at index 1 has weight self.assertIn("1.weight", named) @@ -1061,7 +1061,7 @@ def forward(self, op, x): class Model(Module): def __init__(self): super().__init__("model") - self.blocks = Sequential([]) + self.blocks = Sequential() # Append AFTER __setattr__ has set Sequential._name = "blocks" self.blocks.append(Linear(4)) self.blocks.append(Linear(4)) diff --git a/onnxscript/nn/_sequential.py b/onnxscript/nn/_sequential.py index 3f59230cdb..2aaba3a3cf 100644 --- a/onnxscript/nn/_sequential.py +++ b/onnxscript/nn/_sequential.py @@ -24,13 +24,16 @@ def forward(self, op, x): # Produces parameter names: "mod.0.weight", "mod.0.bias" # SiLU at index 0 has no parameters. - mod = Sequential([SiLU(), Linear(4, 4)]) + mod = Sequential(SiLU(), Linear(4, 4)) # Calling mod(op, x) is equivalent to: # x = silu(op, x) # x = linear(op, x) """ + def __init__(self, *modules: _module_list.Module) -> None: + super().__init__(modules) + def _set_name(self, name: str) -> None: """Set this container's name. Children keep simple ``"0"``, ``"1"`` names.