From f788b24556002589a44a4c077b5f75928b263190 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Feb 2026 22:31:57 -0800 Subject: [PATCH 1/5] Submodule to() tests. --- .../openequivariance/_torch/TensorProduct.py | 29 +++- .../_torch/TensorProductConv.py | 24 ++- tests/batch_test.py | 94 ++++++++++- tests/conftest.py | 6 + tests/conv_test.py | 154 +++++++++++++++--- tests/vmap_test.py | 5 - 6 files changed, 273 insertions(+), 39 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 35640b25..254da414 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -2,9 +2,13 @@ from openequivariance import TPProblem from openequivariance._torch import extlib import torch -from openequivariance.core.utils import torch_to_oeq_dtype +from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum from openequivariance.benchmark.logging_utils import getLogger -from openequivariance._torch.utils import reorder_torch, string_to_tensor +from openequivariance._torch.utils import ( + reorder_torch, + string_to_tensor, + enum_to_torch_dtype, +) from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin import numpy as np @@ -66,6 +70,27 @@ def to(self, *args, **kwargs): torch.nn.Module.to(self, *args, **kwargs) return self + def _apply(self, fn, recurse=True): + if getattr(self, "_applying", False): + return super()._apply(fn, recurse) + + problem: TPProblem = self.input_args["problem"] + irrep_dtype = problem.irrep_dtype + + if irrep_dtype in dtype_to_enum: + irrep_dtype = dtype_to_enum[irrep_dtype] + + current_dtype = enum_to_torch_dtype[irrep_dtype] + dummy = torch.tensor(0.0, dtype=current_dtype) + result = fn(dummy) + + if result.dtype != current_dtype: + self._applying = True + self.to(result.dtype) + self._applying = False + + return super()._apply(fn, recurse) + def __getstate__(self): return self.input_args diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index e1c0e742..30931151 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -16,10 +16,11 @@ from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance._torch.TensorProduct import TensorProduct from openequivariance import TPProblem -from openequivariance.core.utils import torch_to_oeq_dtype +from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum from openequivariance._torch.utils import ( reorder_torch, string_to_tensor, + enum_to_torch_dtype, ) from openequivariance.benchmark.logging_utils import getLogger @@ -109,6 +110,27 @@ def to(self, *args, **kwargs): torch.nn.Module.to(self, *args, **kwargs) return self + def _apply(self, fn, recurse=True): + if getattr(self, "_applying", False): + return super()._apply(fn, recurse) + + problem: TPProblem = self.input_args["problem"] + irrep_dtype = problem.irrep_dtype + + if irrep_dtype in dtype_to_enum: + irrep_dtype = dtype_to_enum[irrep_dtype] + + current_dtype = enum_to_torch_dtype[irrep_dtype] + dummy = torch.tensor(0.0, dtype=current_dtype) + result = fn(dummy) + + if result.dtype != current_dtype: + self._applying = True + self.to(result.dtype) + self._applying = False + + return super()._apply(fn, recurse) + def __getstate__(self): return self.input_args diff --git a/tests/batch_test.py b/tests/batch_test.py index 788950ab..dbc9b2df 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -18,6 +18,9 @@ from itertools import product import torch +@pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module") +def dtype(request): + return request.param class TPCorrectness: def thresh(self, direction): @@ -31,18 +34,10 @@ def check_result(self, result, fieldname): f"{fieldname} observed error={error:.5f} >= {thresh}" ) - @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class") - def dtype(self, request): - return request.param - @pytest.fixture(scope="class") def extra_tp_constructor_args(self): return {} - @pytest.fixture(scope="class") - def with_jax(self, request): - return request.config.getoption("--jax") - @pytest.fixture(scope="class") def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): cls = oeq.TensorProduct @@ -274,3 +269,86 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): } tp.to(switch_map[problem.irrep_dtype]) return tp, tp.config + + +class TestTorchToSubmodule: + """Test that TensorProduct works correctly as a submodule when parent's .to() is called""" + + @pytest.fixture(scope="class") + def parent_module_and_problem(self, dtype, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + + problem = mace_problems()[0].clone() + problem.irrep_dtype, problem.weight_dtype = dtype, dtype + + class ParentModule(torch.nn.Module): + def __init__(self, problem): + super().__init__() + self.tp = oeq.TensorProduct(problem) + + def forward(self, x, y, w): + return self.tp(x, y, w) + + parent = ParentModule(problem) + return parent, problem + + def _problem_dtype(self, problem): + return torch.float32 if problem.irrep_dtype == np.float32 else torch.float64 + + def _make_inputs(self, problem, batch_size, rng, dtype, device): + in1 = torch.tensor( + rng.uniform(size=(batch_size, problem.irreps_in1.dim)), + dtype=dtype, + device=device, + ) + in2 = torch.tensor( + rng.uniform(size=(batch_size, problem.irreps_in2.dim)), + dtype=dtype, + device=device, + ) + weights_size = ( + (problem.weight_numel,) + if problem.shared_weights + else (batch_size, problem.weight_numel) + ) + weights = torch.tensor( + rng.uniform(size=weights_size), + dtype=dtype, + device=device, + ) + return in1, in2, weights + + def test_submodule_dtype_conversion(self, parent_module_and_problem): + """Test that calling .to() on parent module properly converts TensorProduct submodule""" + parent, problem = parent_module_and_problem + + # Generate test inputs with the original dtype + batch_size = 10 + rng = np.random.default_rng(12345) + device = "cuda" + input_dtype = self._problem_dtype(problem) + in1, in2, weights = self._make_inputs( + problem, batch_size, rng, input_dtype, device + ) + + # Run forward pass with original dtype + output1 = parent(in1, in2, weights) + assert output1.dtype == in1.dtype, f"Expected output dtype {in1.dtype}, got {output1.dtype}" + + # Convert parent module to different dtype + switch_map = { + np.float32: torch.float64, + np.float64: torch.float32, + } + target_dtype = switch_map[problem.irrep_dtype] + parent.to(target_dtype) + + # Generate new test inputs with the target dtype + in1_new, in2_new, weights_new = self._make_inputs( + problem, batch_size, rng, target_dtype, device + ) + + # This should work but will fail without _apply implementation + output2 = parent(in1_new, in2_new, weights_new) + assert output2.dtype == target_dtype, f"Expected output dtype {target_dtype}, got {output2.dtype}" diff --git a/tests/conftest.py b/tests/conftest.py index 323de863..4a515664 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +import pytest os.environ["JAX_ENABLE_X64"] = "True" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" @@ -12,3 +13,8 @@ def pytest_addoption(parser): default=False, help="Test the JAX frontend instead of PyTorch", ) + + +@pytest.fixture(scope="session") +def with_jax(request): + return request.config.getoption("--jax") diff --git a/tests/conv_test.py b/tests/conv_test.py index 9c6bb4c8..9efafe22 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -15,6 +15,28 @@ e3tools_problems, ) +@pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module") +def dtype(request): + return request.param + +@pytest.fixture(params=["1drf_radius3.5.pickle"], ids=["1drf"], scope="module") +def graph(request): + download_prefix = ( + "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" + ) + filename = request.param + + graph = None + with tempfile.NamedTemporaryFile() as temp_file: + urllib.request.urlretrieve(download_prefix + filename, temp_file.name) + graph = load_graph(temp_file.name) + + # graph = load_graph("data/1drf_radius3.5.pickle") + return graph + +@pytest.fixture(scope="module") +def with_jax(request): + return request.config.getoption("--jax") class ConvCorrectness: def thresh(self, direction): @@ -28,33 +50,10 @@ def check_result(self, result, fieldname): f"{fieldname} observed error={error:.5f} >= {thresh}" ) - @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class") - def dtype(self, request): - return request.param - - @pytest.fixture(params=["1drf_radius3.5.pickle"], ids=["1drf"], scope="class") - def graph(self, request): - download_prefix = ( - "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" - ) - filename = request.param - - graph = None - with tempfile.NamedTemporaryFile() as temp_file: - urllib.request.urlretrieve(download_prefix + filename, temp_file.name) - graph = load_graph(temp_file.name) - - # graph = load_graph("data/1drf_radius3.5.pickle") - return graph - @pytest.fixture(scope="class") def extra_conv_constructor_args(self): return {} - @pytest.fixture(scope="class") - def with_jax(self, request): - return request.config.getoption("--jax") - @pytest.fixture(params=["atomic", "deterministic", "kahan"], scope="class") def conv_object(self, request, problem, extra_conv_constructor_args, with_jax): cls = oeq.TensorProductConv @@ -281,3 +280,112 @@ def conv_object(self, request, problem, extra_conv_constructor_args): **extra_conv_constructor_args, ) return module.to(switch_map[problem.irrep_dtype]) + +class TestTorchToSubmodule: + """Test that TensorProductConv works as a submodule when parent's .to() is called""" + + @pytest.fixture(params=["atomic", "deterministic"], scope="class") + def parent_module_and_problem( + self, + request, + dtype, + with_jax + ): + if with_jax: + pytest.skip("N/A for JAX") + + problem = mace_problems()[0].clone() + problem.irrep_dtype, problem.weight_dtype = dtype, dtype + deterministic = request.param == "deterministic" + + class ParentModule(torch.nn.Module): + def __init__(self, problem, deterministic): + super().__init__() + self.conv = oeq.TensorProductConv( + problem, + deterministic=deterministic + ) + + def forward(self, x, y, w, rows, cols, sender_perm=None): + return self.conv(x, y, w, rows, cols, sender_perm) + + parent = ParentModule(problem, deterministic) + return parent, problem + + def _problem_dtype(self, problem): + return torch.float32 if problem.irrep_dtype == np.float32 else torch.float64 + + def _make_inputs(self, problem, graph, rng, dtype, device, deterministic): + node_count = graph.node_count + nnz = graph.nnz + + in1 = torch.tensor( + rng.uniform(size=(node_count, problem.irreps_in1.dim)), + dtype=dtype, + device=device, + ) + in2 = torch.tensor( + rng.uniform(size=(nnz, problem.irreps_in2.dim)), + dtype=dtype, + device=device, + ) + weights_size = ( + (problem.weight_numel,) + if problem.shared_weights + else (nnz, problem.weight_numel) + ) + weights = torch.tensor( + rng.uniform(size=weights_size), + dtype=dtype, + device=device, + ) + + rows = torch.tensor(graph.rows, device=device) + cols = torch.tensor(graph.cols, device=device) + sender_perm = ( + torch.tensor(graph.transpose_perm, device=device) + if deterministic + else None + ) + return in1, in2, weights, rows, cols, sender_perm + + def test_submodule_dtype_conversion(self, parent_module_and_problem, graph): + parent, problem = parent_module_and_problem + device = "cuda" + + rng = np.random.default_rng(12345) + input_dtype = self._problem_dtype(problem) + in1, in2, weights, rows, cols, sender_perm = self._make_inputs( + problem, + graph, + rng, + input_dtype, + device, + parent.conv.deterministic, + ) + + output1 = parent(in1, in2, weights, rows, cols, sender_perm) + assert output1.dtype == input_dtype, ( + f"Expected output dtype {input_dtype}, got {output1.dtype}" + ) + + switch_map = { + np.float32: torch.float64, + np.float64: torch.float32, + } + target_dtype = switch_map[problem.irrep_dtype] + parent.to(target_dtype) + + in1_new, in2_new, weights_new, rows, cols, sender_perm = self._make_inputs( + problem, + graph, + rng, + target_dtype, + device, + parent.conv.deterministic, + ) + + output2 = parent(in1_new, in2_new, weights_new, rows, cols, sender_perm) + assert output2.dtype == target_dtype, ( + f"Expected output dtype {target_dtype}, got {output2.dtype}" + ) diff --git a/tests/vmap_test.py b/tests/vmap_test.py index e165cf25..4aad8942 100644 --- a/tests/vmap_test.py +++ b/tests/vmap_test.py @@ -2,11 +2,6 @@ import os -@pytest.fixture -def with_jax(request): - return request.config.getoption("--jax") - - @pytest.fixture def ctx(with_jax): if not with_jax: From 449fd3c05e87a6bfa5f104e9ad84764c5912ae9d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Feb 2026 22:50:55 -0800 Subject: [PATCH 2/5] Updated changelog. --- CHANGELOG.md | 23 ++++++++++++++++ openequivariance/openequivariance/__init__.py | 5 ++++ tests/batch_test.py | 26 ++++++++++++------- tests/conv_test.py | 25 +++++++----------- 4 files changed, 53 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14cd1a72..cb0e6772 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,28 @@ ## Latest Changes +### v0.6.0 (2025-02-23) +OpenEquivariance v0.6.0 brings long-needed improvements to the +PyTorch frontend. We strongly encourage all users to upgrade +to PyTorch 2.10 and OEQ v0.6.0. + +**Added**: +- OpenEquivariance triggers a build of the CUDA extension module + at `pip` install time and will use this precompiled extension if + the user has PyTorch 2.10 installed. If PyTorch 2.10 is not installed, + the JIT-compiled extension is used instead. +- PyTorch stable ABI support for C++ backend, using new features in PyTorch + 2.10 to support stable, forward-compatible ahead-of-time + extensions. +- Dropped support for TorchBind classes and a new kernel cache in its + place, which greatly improves flexibility for automatic mixed precision + and AOTI compilation. An inference test in C++ is included. + +**Fixed**: +- `torch.to()` is now called when either `TensorProduct` + or `TensorProductConv` is a submodule of another PyTorch + module. + + ### v0.5.4 (2025-02-01) Improvements to JAX frontend. diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 7fc0b0f9..24ad924f 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -80,6 +80,11 @@ def torch_ext_so_path(): try: import openequivariance_extjax import openequivariance.jax as jax + + assert openequivariance_extjax.__version__ == __version__, ( + f"openequivariance_extjax version {openequivariance_extjax.__version__} does not match " + f"openequivariance version {__version__}. Ensure both are the same." + ) except Exception as e: error = e diff --git a/tests/batch_test.py b/tests/batch_test.py index dbc9b2df..30e20f31 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -18,10 +18,12 @@ from itertools import product import torch + @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module") def dtype(request): return request.param + class TPCorrectness: def thresh(self, direction): return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction] @@ -281,15 +283,15 @@ def parent_module_and_problem(self, dtype, with_jax): problem = mace_problems()[0].clone() problem.irrep_dtype, problem.weight_dtype = dtype, dtype - + class ParentModule(torch.nn.Module): def __init__(self, problem): super().__init__() self.tp = oeq.TensorProduct(problem) - + def forward(self, x, y, w): return self.tp(x, y, w) - + parent = ParentModule(problem) return parent, problem @@ -322,7 +324,7 @@ def _make_inputs(self, problem, batch_size, rng, dtype, device): def test_submodule_dtype_conversion(self, parent_module_and_problem): """Test that calling .to() on parent module properly converts TensorProduct submodule""" parent, problem = parent_module_and_problem - + # Generate test inputs with the original dtype batch_size = 10 rng = np.random.default_rng(12345) @@ -331,11 +333,13 @@ def test_submodule_dtype_conversion(self, parent_module_and_problem): in1, in2, weights = self._make_inputs( problem, batch_size, rng, input_dtype, device ) - + # Run forward pass with original dtype output1 = parent(in1, in2, weights) - assert output1.dtype == in1.dtype, f"Expected output dtype {in1.dtype}, got {output1.dtype}" - + assert output1.dtype == in1.dtype, ( + f"Expected output dtype {in1.dtype}, got {output1.dtype}" + ) + # Convert parent module to different dtype switch_map = { np.float32: torch.float64, @@ -343,12 +347,14 @@ def test_submodule_dtype_conversion(self, parent_module_and_problem): } target_dtype = switch_map[problem.irrep_dtype] parent.to(target_dtype) - + # Generate new test inputs with the target dtype in1_new, in2_new, weights_new = self._make_inputs( problem, batch_size, rng, target_dtype, device ) - + # This should work but will fail without _apply implementation output2 = parent(in1_new, in2_new, weights_new) - assert output2.dtype == target_dtype, f"Expected output dtype {target_dtype}, got {output2.dtype}" + assert output2.dtype == target_dtype, ( + f"Expected output dtype {target_dtype}, got {output2.dtype}" + ) diff --git a/tests/conv_test.py b/tests/conv_test.py index 9efafe22..50b6376b 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -15,15 +15,15 @@ e3tools_problems, ) + @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module") def dtype(request): return request.param + @pytest.fixture(params=["1drf_radius3.5.pickle"], ids=["1drf"], scope="module") def graph(request): - download_prefix = ( - "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" - ) + download_prefix = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" filename = request.param graph = None @@ -34,10 +34,12 @@ def graph(request): # graph = load_graph("data/1drf_radius3.5.pickle") return graph + @pytest.fixture(scope="module") def with_jax(request): return request.config.getoption("--jax") + class ConvCorrectness: def thresh(self, direction): return {"fwd": 3e-4, "bwd": 3e-4, "double_bwd": 3e-4}[direction] @@ -281,16 +283,12 @@ def conv_object(self, request, problem, extra_conv_constructor_args): ) return module.to(switch_map[problem.irrep_dtype]) + class TestTorchToSubmodule: """Test that TensorProductConv works as a submodule when parent's .to() is called""" @pytest.fixture(params=["atomic", "deterministic"], scope="class") - def parent_module_and_problem( - self, - request, - dtype, - with_jax - ): + def parent_module_and_problem(self, request, dtype, with_jax): if with_jax: pytest.skip("N/A for JAX") @@ -301,10 +299,7 @@ def parent_module_and_problem( class ParentModule(torch.nn.Module): def __init__(self, problem, deterministic): super().__init__() - self.conv = oeq.TensorProductConv( - problem, - deterministic=deterministic - ) + self.conv = oeq.TensorProductConv(problem, deterministic=deterministic) def forward(self, x, y, w, rows, cols, sender_perm=None): return self.conv(x, y, w, rows, cols, sender_perm) @@ -343,9 +338,7 @@ def _make_inputs(self, problem, graph, rng, dtype, device, deterministic): rows = torch.tensor(graph.rows, device=device) cols = torch.tensor(graph.cols, device=device) sender_perm = ( - torch.tensor(graph.transpose_perm, device=device) - if deterministic - else None + torch.tensor(graph.transpose_perm, device=device) if deterministic else None ) return in1, in2, weights, rows, cols, sender_perm From 34a8e77322c8a9790e8c37763430a01c0e353e1c Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Feb 2026 22:52:09 -0800 Subject: [PATCH 3/5] Minor fixes to changelog. --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb0e6772..15bfe20e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,9 @@ to PyTorch 2.10 and OEQ v0.6.0. **Added**: - OpenEquivariance triggers a build of the CUDA extension module at `pip` install time and will use this precompiled extension if - the user has PyTorch 2.10 installed. If PyTorch 2.10 is not installed, + the user has PyTorch >=2.10 installed. If PyTorch <2.10 is installed, the JIT-compiled extension is used instead. -- PyTorch stable ABI support for C++ backend, using new features in PyTorch +- PyTorch ABI support for C++ backend, using new features in PyTorch 2.10 to support stable, forward-compatible ahead-of-time extensions. - Dropped support for TorchBind classes and a new kernel cache in its From a8d376d84e01035063c28d27c5e0ac7c87f74348 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Feb 2026 23:11:22 -0800 Subject: [PATCH 4/5] Removed version mismatch error. --- CHANGELOG.md | 2 ++ openequivariance/openequivariance/__init__.py | 13 +++++++++---- tests/example_test.py | 5 ----- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 15bfe20e..81e7084d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ to PyTorch 2.10 and OEQ v0.6.0. - Dropped support for TorchBind classes and a new kernel cache in its place, which greatly improves flexibility for automatic mixed precision and AOTI compilation. An inference test in C++ is included. +- `openequivariance_extjax` has a version number that synchronizes with + the main `openequivariance` package; ensure the two packages stay in sync. **Fixed**: - `torch.to()` is now called when either `TensorProduct` diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 24ad924f..35c4b542 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -4,6 +4,7 @@ import numpy as np from pathlib import Path +import warnings from importlib.metadata import version from openequivariance.core.e3nn_lite import ( @@ -81,10 +82,14 @@ def torch_ext_so_path(): import openequivariance_extjax import openequivariance.jax as jax - assert openequivariance_extjax.__version__ == __version__, ( - f"openequivariance_extjax version {openequivariance_extjax.__version__} does not match " - f"openequivariance version {__version__}. Ensure both are the same." - ) + # TODO-someday: enable + # extjax_version = version("openequivariance_extjax") + # if extjax_version != __version__: + # warnings.warn( + # f"openequivariance_extjax version {extjax_version} does not match " + # f"openequivariance version {__version__}. Ensure both versions match." + # ) + except Exception as e: error = e diff --git a/tests/example_test.py b/tests/example_test.py index e8d23cb7..bf51d3ed 100644 --- a/tests/example_test.py +++ b/tests/example_test.py @@ -2,11 +2,6 @@ import os -@pytest.fixture -def with_jax(request): - return request.config.getoption("--jax") - - def test_tutorial_torch(with_jax): if with_jax: pytest.skip("Skipping PyTorch tutorial when testing JAX") From d78274cbf643b06a82c82526af561e8dca0f726c Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Feb 2026 23:13:10 -0800 Subject: [PATCH 5/5] Removed some extraneous comments. --- tests/batch_test.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/batch_test.py b/tests/batch_test.py index 30e20f31..55396ded 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -325,7 +325,6 @@ def test_submodule_dtype_conversion(self, parent_module_and_problem): """Test that calling .to() on parent module properly converts TensorProduct submodule""" parent, problem = parent_module_and_problem - # Generate test inputs with the original dtype batch_size = 10 rng = np.random.default_rng(12345) device = "cuda" @@ -334,13 +333,11 @@ def test_submodule_dtype_conversion(self, parent_module_and_problem): problem, batch_size, rng, input_dtype, device ) - # Run forward pass with original dtype output1 = parent(in1, in2, weights) assert output1.dtype == in1.dtype, ( f"Expected output dtype {in1.dtype}, got {output1.dtype}" ) - # Convert parent module to different dtype switch_map = { np.float32: torch.float64, np.float64: torch.float32, @@ -348,12 +345,10 @@ def test_submodule_dtype_conversion(self, parent_module_and_problem): target_dtype = switch_map[problem.irrep_dtype] parent.to(target_dtype) - # Generate new test inputs with the target dtype in1_new, in2_new, weights_new = self._make_inputs( problem, batch_size, rng, target_dtype, device ) - # This should work but will fail without _apply implementation output2 = parent(in1_new, in2_new, weights_new) assert output2.dtype == target_dtype, ( f"Expected output dtype {target_dtype}, got {output2.dtype}"