[Relax][PyTroch] Add randn.default and randn_like.default support#18815
Conversation
Summary of ChangesHello @guan404ming, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical issue where PyTorch models utilizing Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
Signed-off-by: Guan-Ming Chiu <guanmingchiu@gmail.com>
a2d36fc to
bdf8da8
Compare
There was a problem hiding this comment.
Code Review
This pull request adds support for torch.randn and torch.randn_like in the PyTorch frontend by emitting constant tensors. This is a solid approach for handling these operations during model import. The accompanying tests correctly verify the output shape and data type for static shapes.
My main feedback concerns the handling of dynamic shapes. The current implementations for _randn and _randn_like will fail if they encounter symbolic dimensions, as they rely on numpy which requires concrete shapes. I've suggested adding explicit checks to raise a NotImplementedError in such cases to prevent crashes and make the limitation clear.
| def _randn(self, node: fx.Node) -> relax.Var: | ||
| import numpy as np | ||
|
|
||
| args = self.retrieve_args(node) | ||
| size = args[0] if isinstance(args[0], (list, tuple)) else (args[0],) | ||
| dtype = self._convert_data_type( | ||
| node.kwargs.get("dtype", torch.get_default_dtype()), self.env | ||
| ) | ||
| data = np.random.randn(*size).astype(dtype) | ||
| return self.block_builder.emit(relax.const(data, dtype)) |
There was a problem hiding this comment.
This implementation of _randn will fail if size contains symbolic dimensions (i.e., tir.Var from torch.export.Dim), as np.random.randn requires concrete integer shapes. To prevent a crash with dynamic shapes, it's better to add an explicit check and raise a NotImplementedError.
| def _randn(self, node: fx.Node) -> relax.Var: | |
| import numpy as np | |
| args = self.retrieve_args(node) | |
| size = args[0] if isinstance(args[0], (list, tuple)) else (args[0],) | |
| dtype = self._convert_data_type( | |
| node.kwargs.get("dtype", torch.get_default_dtype()), self.env | |
| ) | |
| data = np.random.randn(*size).astype(dtype) | |
| return self.block_builder.emit(relax.const(data, dtype)) | |
| def _randn(self, node: fx.Node) -> relax.Var: | |
| import numpy as np | |
| args = self.retrieve_args(node) | |
| size = args[0] if isinstance(args[0], (list, tuple)) else (args[0],) | |
| if any(isinstance(s, tir.Var) for s in size): | |
| raise NotImplementedError("torch.randn with dynamic shapes is not supported yet.") | |
| dtype = self._convert_data_type( | |
| node.kwargs.get("dtype", torch.get_default_dtype()), self.env | |
| ) | |
| data = np.random.randn(*size).astype(dtype) | |
| return self.block_builder.emit(relax.const(data, dtype)) |
There was a problem hiding this comment.
int(s) will already raise a TypeError with a clear message. Adding explicit checks here would be inconsistent with the rest of the codebase.
| def _randn_like(self, node: fx.Node) -> relax.Var: | ||
| import numpy as np | ||
|
|
||
| x = self.env[node.args[0]] | ||
| x_sinfo = x.struct_info | ||
| shape = [int(s) for s in x_sinfo.shape] | ||
| dtype = self._convert_data_type(node.kwargs.get("dtype", None) or x_sinfo.dtype, self.env) | ||
| data = np.random.randn(*shape).astype(dtype) | ||
| return self.block_builder.emit(relax.const(data, dtype)) | ||
|
|
||
| def _zeros(self, node: fx.Node) -> relax.Var: |
There was a problem hiding this comment.
Similar to _randn, this implementation for _randn_like does not support dynamic shapes. The line shape = [int(s) for s in x_sinfo.shape] will raise a TypeError if any shape dimension s is a symbolic tir.Var. Please add a check to handle this case gracefully.
| def _randn_like(self, node: fx.Node) -> relax.Var: | |
| import numpy as np | |
| x = self.env[node.args[0]] | |
| x_sinfo = x.struct_info | |
| shape = [int(s) for s in x_sinfo.shape] | |
| dtype = self._convert_data_type(node.kwargs.get("dtype", None) or x_sinfo.dtype, self.env) | |
| data = np.random.randn(*shape).astype(dtype) | |
| return self.block_builder.emit(relax.const(data, dtype)) | |
| def _zeros(self, node: fx.Node) -> relax.Var: | |
| def _randn_like(self, node: fx.Node) -> relax.Var: | |
| import numpy as np | |
| x = self.env[node.args[0]] | |
| x_sinfo = x.struct_info | |
| if any(isinstance(s, tir.Var) for s in x_sinfo.shape): | |
| raise NotImplementedError("torch.randn_like with dynamic shapes is not supported yet.") | |
| shape = [int(s) for s in x_sinfo.shape] | |
| dtype = self._convert_data_type( | |
| node.kwargs.get("dtype", None) or x_sinfo.dtype, self.env | |
| ) | |
| data = np.random.randn(*shape).astype(dtype) | |
| return self.block_builder.emit(relax.const(data, dtype)) |
|
Thanks! |
Why
PyTorch models using torch.randn() or torch.randn_like() fail to convert via from_exported_program (part of #18476).
How