Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,27 @@ def _hamming_window(self, node: fx.Node) -> relax.Var:
relax.op.hamming_window(window_size, periodic, alpha, beta, dtype)
)

def _randn(self, node: fx.Node) -> relax.Var:
    """Lower ``torch.randn`` by materializing a random constant.

    The values are sampled once at translation time with NumPy and embedded
    as a ``relax.const`` — the result is NOT re-sampled on each invocation
    of the compiled function.

    Parameters
    ----------
    node : fx.Node
        The FX node for ``randn``; ``args[0]`` is the requested size (a
        list/tuple for ``randn(5, 3)`` or a bare int for ``randn(5)``) and
        ``kwargs`` may carry ``dtype``.

    Returns
    -------
    relax.Var
        The emitted constant tensor.
    """
    import numpy as np

    args = self.retrieve_args(node)
    size = args[0] if isinstance(args[0], (list, tuple)) else (args[0],)
    # FX graphs frequently carry an explicit ``dtype=None`` kwarg; a plain
    # ``get`` default would then hand ``None`` to ``_convert_data_type``.
    # Use ``or`` so both "absent" and "explicitly None" fall back to the
    # torch default — same pattern as ``_randn_like`` below.
    dtype = self._convert_data_type(
        node.kwargs.get("dtype", None) or torch.get_default_dtype(), self.env
    )
    # NOTE(review): ``int(s)`` inside np.random.randn raises TypeError for
    # symbolic (tir.Var) dims — dynamic shapes are intentionally unsupported.
    data = np.random.randn(*size).astype(dtype)
    return self.block_builder.emit(relax.const(data, dtype))
Comment on lines +994 to +1003
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This implementation of _randn will fail if size contains symbolic dimensions (i.e., tir.Var from torch.export.Dim), as np.random.randn requires concrete integer shapes. To prevent a crash with dynamic shapes, it's better to add an explicit check and raise a NotImplementedError.

Suggested change
def _randn(self, node: fx.Node) -> relax.Var:
import numpy as np
args = self.retrieve_args(node)
size = args[0] if isinstance(args[0], (list, tuple)) else (args[0],)
dtype = self._convert_data_type(
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
)
data = np.random.randn(*size).astype(dtype)
return self.block_builder.emit(relax.const(data, dtype))
def _randn(self, node: fx.Node) -> relax.Var:
import numpy as np
args = self.retrieve_args(node)
size = args[0] if isinstance(args[0], (list, tuple)) else (args[0],)
if any(isinstance(s, tir.Var) for s in size):
raise NotImplementedError("torch.randn with dynamic shapes is not supported yet.")
dtype = self._convert_data_type(
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
)
data = np.random.randn(*size).astype(dtype)
return self.block_builder.emit(relax.const(data, dtype))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int(s) will already raise a TypeError with a clear message. Adding explicit checks here would be inconsistent with the rest of the codebase.


def _randn_like(self, node: fx.Node) -> relax.Var:
    """Lower ``torch.randn_like`` as a random constant matching the input.

    Shape and (unless overridden via the ``dtype`` kwarg) dtype are taken
    from the input tensor's struct info; values are sampled once at
    translation time and baked into a ``relax.const``.
    """
    import numpy as np

    tensor = self.env[node.args[0]]
    sinfo = tensor.struct_info
    # Concrete dims only: int(...) raises for symbolic shapes by design.
    shape = tuple(int(dim) for dim in sinfo.shape)
    requested = node.kwargs.get("dtype", None)
    dtype = self._convert_data_type(requested or sinfo.dtype, self.env)
    values = np.random.randn(*shape).astype(dtype)
    return self.block_builder.emit(relax.const(values, dtype))

def _zeros(self, node: fx.Node) -> relax.Var:
Comment on lines +1005 to 1015
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to _randn, this implementation for _randn_like does not support dynamic shapes. The line shape = [int(s) for s in x_sinfo.shape] will raise a TypeError if any shape dimension s is a symbolic tir.Var. Please add a check to handle this case gracefully.

Suggested change
def _randn_like(self, node: fx.Node) -> relax.Var:
import numpy as np
x = self.env[node.args[0]]
x_sinfo = x.struct_info
shape = [int(s) for s in x_sinfo.shape]
dtype = self._convert_data_type(node.kwargs.get("dtype", None) or x_sinfo.dtype, self.env)
data = np.random.randn(*shape).astype(dtype)
return self.block_builder.emit(relax.const(data, dtype))
def _zeros(self, node: fx.Node) -> relax.Var:
def _randn_like(self, node: fx.Node) -> relax.Var:
import numpy as np
x = self.env[node.args[0]]
x_sinfo = x.struct_info
if any(isinstance(s, tir.Var) for s in x_sinfo.shape):
raise NotImplementedError("torch.randn_like with dynamic shapes is not supported yet.")
shape = [int(s) for s in x_sinfo.shape]
dtype = self._convert_data_type(
node.kwargs.get("dtype", None) or x_sinfo.dtype, self.env
)
data = np.random.randn(*shape).astype(dtype)
return self.block_builder.emit(relax.const(data, dtype))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

args = self.retrieve_args(node)
size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
Expand Down Expand Up @@ -1484,6 +1505,8 @@ def create_convert_map(
"new_zeros.default": self._new_zeros,
"one_hot.default": self._one_hot,
"ones.default": self._ones,
"randn.default": self._randn,
"randn_like.default": self._randn_like,
"ones_like.default": lambda node: self.block_builder.emit(
relax.op.ones_like(self.env[node.args[0]])
),
Expand Down
30 changes: 30 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -6796,6 +6796,36 @@ def main(input: R.Tensor((128, 128), dtype="float32")) -> R.Tuple(
verify_model(ZerosLike(), example_args, {}, Expected)


def test_randn():
    """torch.randn lowers to a constant; check output shape/dtype survive."""

    class Randn(Module):
        def forward(self, input):
            return input + torch.randn(5, 3)

    example_args = (torch.rand(5, 3, dtype=torch.float32),)
    ep = export(Randn(), args=example_args)
    mod = from_exported_program(ep)
    out_sinfo = mod["main"].ret_struct_info.fields[0]
    assert out_sinfo.shape[0] == 5 and out_sinfo.shape[1] == 3
    assert out_sinfo.dtype == "float32"


def test_randn_like():
    """torch.randn_like mirrors the input tensor's shape and dtype."""

    class RandnLike(Module):
        def forward(self, input):
            return input + torch.randn_like(input)

    example_args = (torch.rand(4, 6, dtype=torch.float32),)
    ep = export(RandnLike(), args=example_args)
    mod = from_exported_program(ep)
    out_sinfo = mod["main"].ret_struct_info.fields[0]
    assert out_sinfo.shape[0] == 4 and out_sinfo.shape[1] == 6
    assert out_sinfo.dtype == "float32"


def test_type_as():
class TypeAs(Module):
def forward(self, input, other):
Expand Down
Loading