def test_case_8():
    """Cast buffers with a positional dtype: module.to(torch.float64)."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(torch.float64)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_9():
    """Cast buffers down with a positional dtype: module.to(torch.float16)."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(torch.float16)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_10():
    """Cast buffers with the dtype keyword: module.to(dtype=torch.float64)."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(dtype=torch.float64)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_11():
    """Move to a device given as a positional string: module.to('cpu')."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to('cpu')
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_12():
    """Mix a positional device with a keyword dtype: to('cpu', dtype=...)."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to('cpu', dtype=torch.float64)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_13():
    """Infer dtype/device from a reference tensor passed positionally."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        ref = torch.tensor([1.0], dtype=torch.float64)
        module.to(ref)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_14():
    """to() returns the module itself, so attribute access can be chained."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        result = module.to(torch.float64).buf
        """
    )
    obj.run(code, ["result"])


def test_case_15():
    """Floating-only casting: an integer buffer must NOT change dtype."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('float_buf', torch.tensor([1.0, 2.0, 3.0]))
        module.register_buffer('int_buf', torch.tensor([1, 2, 3]))
        module.to(torch.float64)
        result = module.int_buf
        """
    )
    obj.run(code, ["result"])


def test_case_16():
    """Floating-only casting: a floating-point buffer IS cast."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('float_buf', torch.tensor([1.0, 2.0, 3.0]))
        module.register_buffer('int_buf', torch.tensor([1, 2, 3]))
        module.to(torch.float64)
        result = module.float_buf
        """
    )
    obj.run(code, ["result"])


def test_case_17():
    """Positional dtype combined with the non_blocking keyword."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(torch.float64, non_blocking=False)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_18():
    """Both device and dtype supplied as keywords, device first."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(device='cpu', dtype=torch.float64)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_19():
    """Keyword arguments in reversed order: dtype before device."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(dtype=torch.float64, device='cpu')
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_20():
    """Two successive to() calls; the last dtype wins."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(torch.float64)
        module.to(torch.float32)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_21():
    """to() recurses into registered sub-modules and casts their buffers."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        sub = torch.nn.Module()
        sub.register_buffer('sub_buf', torch.tensor([4.0, 5.0, 6.0]))
        module.add_module('sub', sub)
        module.to(torch.float64)
        result = module.sub.sub_buf
        """
    )
    obj.run(code, ["result"])


def test_case_22():
    """Calling to() with no arguments is a no-op that returns self."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        result = module.to().buf
        """
    )
    obj.run(code, ["result"])