Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions tests/test_nn_Module_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,218 @@ def test_case_7():
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_8():
    """Cast floating-point buffers with a positional dtype: to(torch.float64)."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(torch.float64)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_9():
    """Cast floating-point buffers with a positional dtype: to(torch.float16)."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(torch.float16)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_10():
    """Cast buffers when the dtype is passed by keyword: to(dtype=torch.float64)."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(dtype=torch.float64)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_11():
    """Move the module with a positional device string: to('cpu')."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to('cpu')
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_12():
    """Combine a positional device with a keyword dtype: to('cpu', dtype=torch.float64)."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to('cpu', dtype=torch.float64)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_13():
    """Infer dtype/device from a positional reference tensor: to(some_tensor)."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        ref = torch.tensor([1.0], dtype=torch.float64)
        module.to(ref)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_14():
    """Chain off the return value: module.to(torch.float64).buf."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        result = module.to(torch.float64).buf
        """
    )
    obj.run(code, ["result"])


def test_case_15():
    """Floating-only semantics: an integer buffer must keep its dtype after to()."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('float_buf', torch.tensor([1.0, 2.0, 3.0]))
        module.register_buffer('int_buf', torch.tensor([1, 2, 3]))
        module.to(torch.float64)
        result = module.int_buf
        """
    )
    obj.run(code, ["result"])


def test_case_16():
    """Floating-only semantics: a float buffer must be cast by to()."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('float_buf', torch.tensor([1.0, 2.0, 3.0]))
        module.register_buffer('int_buf', torch.tensor([1, 2, 3]))
        module.to(torch.float64)
        result = module.float_buf
        """
    )
    obj.run(code, ["result"])


def test_case_17():
    """Accept the non_blocking keyword alongside a positional dtype."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(torch.float64, non_blocking=False)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_18():
    """Pass device and dtype together as keywords."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(device='cpu', dtype=torch.float64)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_19():
    """Keyword order must not matter: dtype before device."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(dtype=torch.float64, device='cpu')
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_20():
    """Two to() calls in sequence; the last dtype wins."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        module.to(torch.float64)
        module.to(torch.float32)
        result = module.buf
        """
    )
    obj.run(code, ["result"])


def test_case_21():
    """to() must recurse into registered submodules and cast their buffers."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        sub = torch.nn.Module()
        sub.register_buffer('sub_buf', torch.tensor([4.0, 5.0, 6.0]))
        module.add_module('sub', sub)
        module.to(torch.float64)
        result = module.sub.sub_buf
        """
    )
    obj.run(code, ["result"])


def test_case_22():
    """Calling to() with no arguments is a no-op that returns self."""
    code = textwrap.dedent(
        """
        import torch
        module = torch.nn.Module()
        module.register_buffer('buf', torch.tensor([1.0, 2.0, 3.0]))
        result = module.to().buf
        """
    )
    obj.run(code, ["result"])