From bdf8da802825219ef2a7654ed515e1634c4d60e9 Mon Sep 17 00:00:00 2001 From: Guan-Ming Chiu Date: Tue, 24 Feb 2026 02:46:39 +0800 Subject: [PATCH] [Relax][PyTroch] Add randn.default and randn_like.default support Signed-off-by: Guan-Ming Chiu --- .../torch/exported_program_translator.py | 23 ++++++++++++++ .../test_frontend_from_exported_program.py | 30 +++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 233cba8df93d..39595a9f0043 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -991,6 +991,27 @@ def _hamming_window(self, node: fx.Node) -> relax.Var: relax.op.hamming_window(window_size, periodic, alpha, beta, dtype) ) + def _randn(self, node: fx.Node) -> relax.Var: + import numpy as np + + args = self.retrieve_args(node) + size = args[0] if isinstance(args[0], (list, tuple)) else (args[0],) + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) + data = np.random.randn(*size).astype(dtype) + return self.block_builder.emit(relax.const(data, dtype)) + + def _randn_like(self, node: fx.Node) -> relax.Var: + import numpy as np + + x = self.env[node.args[0]] + x_sinfo = x.struct_info + shape = [int(s) for s in x_sinfo.shape] + dtype = self._convert_data_type(node.kwargs.get("dtype", None) or x_sinfo.dtype, self.env) + data = np.random.randn(*shape).astype(dtype) + return self.block_builder.emit(relax.const(data, dtype)) + def _zeros(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) @@ -1484,6 +1505,8 @@ def create_convert_map( "new_zeros.default": self._new_zeros, "one_hot.default": self._one_hot, "ones.default": self._ones, + "randn.default": self._randn, + "randn_like.default": self._randn_like, "ones_like.default": lambda node: self.block_builder.emit( relax.op.ones_like(self.env[node.args[0]]) ), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 0b5cb1f77775..6bab158c08eb 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -6796,6 +6796,36 @@ def main(input: R.Tensor((128, 128), dtype="float32")) -> R.Tuple( verify_model(ZerosLike(), example_args, {}, Expected) +def test_randn(): + class Randn(Module): + def forward(self, input): + return input + torch.randn(5, 3) + + example_args = (torch.rand(5, 3, dtype=torch.float32),) + exported_program = export(Randn(), args=example_args) + mod = from_exported_program(exported_program) + func = mod["main"] + ret_sinfo = func.ret_struct_info + assert ret_sinfo.fields[0].shape[0] == 5 + assert ret_sinfo.fields[0].shape[1] == 3 + assert ret_sinfo.fields[0].dtype == "float32" + + +def test_randn_like(): + class RandnLike(Module): + def forward(self, input): + return input + torch.randn_like(input) + + example_args = (torch.rand(4, 6, dtype=torch.float32),) + exported_program = export(RandnLike(), args=example_args) + mod = from_exported_program(exported_program) + func = mod["main"] + ret_sinfo = func.ret_struct_info + assert ret_sinfo.fields[0].shape[0] == 4 + assert ret_sinfo.fields[0].shape[1] == 6 + assert ret_sinfo.fields[0].dtype == "float32" + + def test_type_as(): class TypeAs(Module): def forward(self, input, other):