diff --git a/CHANGELOG.md b/CHANGELOG.md index 14cd1a7..81e7084 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,30 @@ ## Latest Changes +### v0.6.0 (2025-02-23) +OpenEquivariance v0.6.0 brings long-needed improvements to the +PyTorch frontend. We strongly encourage all users to upgrade +to PyTorch 2.10 and OEQ v0.6.0. + +**Added**: +- OpenEquivariance triggers a build of the CUDA extension module + at `pip` install time and will use this precompiled extension if + the user has PyTorch >=2.10 installed. If PyTorch <2.10 is installed, + the JIT-compiled extension is used instead. +- PyTorch ABI support for C++ backend, using new features in PyTorch + 2.10 to support stable, forward-compatible ahead-of-time + extensions. +- Dropped support for TorchBind classes in favor of a new kernel cache, + which greatly improves flexibility for automatic mixed precision + and AOTI compilation. An inference test in C++ is included. +- `openequivariance_extjax` has a version number that synchronizes with + the main `openequivariance` package; ensure the two packages stay in sync. + +**Fixed**: +- `Module.to()` now correctly converts `TensorProduct` + or `TensorProductConv` when either is a submodule of another PyTorch + module. + + ### v0.5.4 (2025-02-01) Improvements to JAX frontend. 
diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 7fc0b0f..35c4b54 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -4,6 +4,7 @@ import numpy as np from pathlib import Path +import warnings from importlib.metadata import version from openequivariance.core.e3nn_lite import ( @@ -80,6 +81,15 @@ def torch_ext_so_path(): try: import openequivariance_extjax import openequivariance.jax as jax + + # TODO-someday: enable + # extjax_version = version("openequivariance_extjax") + # if extjax_version != __version__: + # warnings.warn( + # f"openequivariance_extjax version {extjax_version} does not match " + # f"openequivariance version {__version__}. Ensure both versions match." + # ) + except Exception as e: error = e diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 35640b2..254da41 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -2,9 +2,13 @@ from openequivariance import TPProblem from openequivariance._torch import extlib import torch -from openequivariance.core.utils import torch_to_oeq_dtype +from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum from openequivariance.benchmark.logging_utils import getLogger -from openequivariance._torch.utils import reorder_torch, string_to_tensor +from openequivariance._torch.utils import ( + reorder_torch, + string_to_tensor, + enum_to_torch_dtype, +) from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin import numpy as np @@ -66,6 +70,27 @@ def to(self, *args, **kwargs): torch.nn.Module.to(self, *args, **kwargs) return self + def _apply(self, fn, recurse=True): + if getattr(self, "_applying", False): + return super()._apply(fn, recurse) + + problem: TPProblem = 
self.input_args["problem"] + irrep_dtype = problem.irrep_dtype + + if irrep_dtype in dtype_to_enum: + irrep_dtype = dtype_to_enum[irrep_dtype] + + current_dtype = enum_to_torch_dtype[irrep_dtype] + dummy = torch.tensor(0.0, dtype=current_dtype) + result = fn(dummy) + + if result.dtype != current_dtype: + self._applying = True + self.to(result.dtype) + self._applying = False + + return super()._apply(fn, recurse) + def __getstate__(self): return self.input_args diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index e1c0e74..3093115 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -16,10 +16,11 @@ from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance._torch.TensorProduct import TensorProduct from openequivariance import TPProblem -from openequivariance.core.utils import torch_to_oeq_dtype +from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum from openequivariance._torch.utils import ( reorder_torch, string_to_tensor, + enum_to_torch_dtype, ) from openequivariance.benchmark.logging_utils import getLogger @@ -109,6 +110,27 @@ def to(self, *args, **kwargs): torch.nn.Module.to(self, *args, **kwargs) return self + def _apply(self, fn, recurse=True): + if getattr(self, "_applying", False): + return super()._apply(fn, recurse) + + problem: TPProblem = self.input_args["problem"] + irrep_dtype = problem.irrep_dtype + + if irrep_dtype in dtype_to_enum: + irrep_dtype = dtype_to_enum[irrep_dtype] + + current_dtype = enum_to_torch_dtype[irrep_dtype] + dummy = torch.tensor(0.0, dtype=current_dtype) + result = fn(dummy) + + if result.dtype != current_dtype: + self._applying = True + self.to(result.dtype) + self._applying = False + + return super()._apply(fn, recurse) + def __getstate__(self): return self.input_args diff --git 
a/tests/batch_test.py b/tests/batch_test.py index 788950a..55396de 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -19,6 +19,11 @@ import torch +@pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module") +def dtype(request): + return request.param + + class TPCorrectness: def thresh(self, direction): return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction] @@ -31,18 +36,10 @@ def check_result(self, result, fieldname): f"{fieldname} observed error={error:.5f} >= {thresh}" ) - @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class") - def dtype(self, request): - return request.param - @pytest.fixture(scope="class") def extra_tp_constructor_args(self): return {} - @pytest.fixture(scope="class") - def with_jax(self, request): - return request.config.getoption("--jax") - @pytest.fixture(scope="class") def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): cls = oeq.TensorProduct @@ -274,3 +271,85 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): } tp.to(switch_map[problem.irrep_dtype]) return tp, tp.config + + +class TestTorchToSubmodule: + """Test that TensorProduct works correctly as a submodule when parent's .to() is called""" + + @pytest.fixture(scope="class") + def parent_module_and_problem(self, dtype, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + + problem = mace_problems()[0].clone() + problem.irrep_dtype, problem.weight_dtype = dtype, dtype + + class ParentModule(torch.nn.Module): + def __init__(self, problem): + super().__init__() + self.tp = oeq.TensorProduct(problem) + + def forward(self, x, y, w): + return self.tp(x, y, w) + + parent = ParentModule(problem) + return parent, problem + + def _problem_dtype(self, problem): + return torch.float32 if problem.irrep_dtype == np.float32 else torch.float64 + + def _make_inputs(self, problem, batch_size, rng, dtype, device): + in1 = torch.tensor( + rng.uniform(size=(batch_size, 
problem.irreps_in1.dim)), + dtype=dtype, + device=device, + ) + in2 = torch.tensor( + rng.uniform(size=(batch_size, problem.irreps_in2.dim)), + dtype=dtype, + device=device, + ) + weights_size = ( + (problem.weight_numel,) + if problem.shared_weights + else (batch_size, problem.weight_numel) + ) + weights = torch.tensor( + rng.uniform(size=weights_size), + dtype=dtype, + device=device, + ) + return in1, in2, weights + + def test_submodule_dtype_conversion(self, parent_module_and_problem): + """Test that calling .to() on parent module properly converts TensorProduct submodule""" + parent, problem = parent_module_and_problem + + batch_size = 10 + rng = np.random.default_rng(12345) + device = "cuda" + input_dtype = self._problem_dtype(problem) + in1, in2, weights = self._make_inputs( + problem, batch_size, rng, input_dtype, device + ) + + output1 = parent(in1, in2, weights) + assert output1.dtype == in1.dtype, ( + f"Expected output dtype {in1.dtype}, got {output1.dtype}" + ) + + switch_map = { + np.float32: torch.float64, + np.float64: torch.float32, + } + target_dtype = switch_map[problem.irrep_dtype] + parent.to(target_dtype) + + in1_new, in2_new, weights_new = self._make_inputs( + problem, batch_size, rng, target_dtype, device + ) + + output2 = parent(in1_new, in2_new, weights_new) + assert output2.dtype == target_dtype, ( + f"Expected output dtype {target_dtype}, got {output2.dtype}" + ) diff --git a/tests/conftest.py b/tests/conftest.py index 323de86..4a51566 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +import pytest os.environ["JAX_ENABLE_X64"] = "True" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" @@ -12,3 +13,8 @@ def pytest_addoption(parser): default=False, help="Test the JAX frontend instead of PyTorch", ) + + +@pytest.fixture(scope="session") +def with_jax(request): + return request.config.getoption("--jax") diff --git a/tests/conv_test.py b/tests/conv_test.py index 9c6bb4c..50b6376 100644 --- 
a/tests/conv_test.py +++ b/tests/conv_test.py @@ -16,6 +16,30 @@ ) +@pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(params=["1drf_radius3.5.pickle"], ids=["1drf"], scope="module") +def graph(request): + download_prefix = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" + filename = request.param + + graph = None + with tempfile.NamedTemporaryFile() as temp_file: + urllib.request.urlretrieve(download_prefix + filename, temp_file.name) + graph = load_graph(temp_file.name) + + # graph = load_graph("data/1drf_radius3.5.pickle") + return graph + + +@pytest.fixture(scope="module") +def with_jax(request): + return request.config.getoption("--jax") + + class ConvCorrectness: def thresh(self, direction): return {"fwd": 3e-4, "bwd": 3e-4, "double_bwd": 3e-4}[direction] @@ -28,33 +52,10 @@ def check_result(self, result, fieldname): f"{fieldname} observed error={error:.5f} >= {thresh}" ) - @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class") - def dtype(self, request): - return request.param - - @pytest.fixture(params=["1drf_radius3.5.pickle"], ids=["1drf"], scope="class") - def graph(self, request): - download_prefix = ( - "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" - ) - filename = request.param - - graph = None - with tempfile.NamedTemporaryFile() as temp_file: - urllib.request.urlretrieve(download_prefix + filename, temp_file.name) - graph = load_graph(temp_file.name) - - # graph = load_graph("data/1drf_radius3.5.pickle") - return graph - @pytest.fixture(scope="class") def extra_conv_constructor_args(self): return {} - @pytest.fixture(scope="class") - def with_jax(self, request): - return request.config.getoption("--jax") - @pytest.fixture(params=["atomic", "deterministic", "kahan"], scope="class") def conv_object(self, request, problem, extra_conv_constructor_args, with_jax): cls = oeq.TensorProductConv @@ 
-281,3 +282,103 @@ def conv_object(self, request, problem, extra_conv_constructor_args): **extra_conv_constructor_args, ) return module.to(switch_map[problem.irrep_dtype]) + + +class TestTorchToSubmodule: + """Test that TensorProductConv works as a submodule when parent's .to() is called""" + + @pytest.fixture(params=["atomic", "deterministic"], scope="class") + def parent_module_and_problem(self, request, dtype, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + + problem = mace_problems()[0].clone() + problem.irrep_dtype, problem.weight_dtype = dtype, dtype + deterministic = request.param == "deterministic" + + class ParentModule(torch.nn.Module): + def __init__(self, problem, deterministic): + super().__init__() + self.conv = oeq.TensorProductConv(problem, deterministic=deterministic) + + def forward(self, x, y, w, rows, cols, sender_perm=None): + return self.conv(x, y, w, rows, cols, sender_perm) + + parent = ParentModule(problem, deterministic) + return parent, problem + + def _problem_dtype(self, problem): + return torch.float32 if problem.irrep_dtype == np.float32 else torch.float64 + + def _make_inputs(self, problem, graph, rng, dtype, device, deterministic): + node_count = graph.node_count + nnz = graph.nnz + + in1 = torch.tensor( + rng.uniform(size=(node_count, problem.irreps_in1.dim)), + dtype=dtype, + device=device, + ) + in2 = torch.tensor( + rng.uniform(size=(nnz, problem.irreps_in2.dim)), + dtype=dtype, + device=device, + ) + weights_size = ( + (problem.weight_numel,) + if problem.shared_weights + else (nnz, problem.weight_numel) + ) + weights = torch.tensor( + rng.uniform(size=weights_size), + dtype=dtype, + device=device, + ) + + rows = torch.tensor(graph.rows, device=device) + cols = torch.tensor(graph.cols, device=device) + sender_perm = ( + torch.tensor(graph.transpose_perm, device=device) if deterministic else None + ) + return in1, in2, weights, rows, cols, sender_perm + + def test_submodule_dtype_conversion(self, 
parent_module_and_problem, graph): + parent, problem = parent_module_and_problem + device = "cuda" + + rng = np.random.default_rng(12345) + input_dtype = self._problem_dtype(problem) + in1, in2, weights, rows, cols, sender_perm = self._make_inputs( + problem, + graph, + rng, + input_dtype, + device, + parent.conv.deterministic, + ) + + output1 = parent(in1, in2, weights, rows, cols, sender_perm) + assert output1.dtype == input_dtype, ( + f"Expected output dtype {input_dtype}, got {output1.dtype}" + ) + + switch_map = { + np.float32: torch.float64, + np.float64: torch.float32, + } + target_dtype = switch_map[problem.irrep_dtype] + parent.to(target_dtype) + + in1_new, in2_new, weights_new, rows, cols, sender_perm = self._make_inputs( + problem, + graph, + rng, + target_dtype, + device, + parent.conv.deterministic, + ) + + output2 = parent(in1_new, in2_new, weights_new, rows, cols, sender_perm) + assert output2.dtype == target_dtype, ( + f"Expected output dtype {target_dtype}, got {output2.dtype}" + ) diff --git a/tests/example_test.py b/tests/example_test.py index e8d23cb..bf51d3e 100644 --- a/tests/example_test.py +++ b/tests/example_test.py @@ -2,11 +2,6 @@ import os -@pytest.fixture -def with_jax(request): - return request.config.getoption("--jax") - - def test_tutorial_torch(with_jax): if with_jax: pytest.skip("Skipping PyTorch tutorial when testing JAX") diff --git a/tests/vmap_test.py b/tests/vmap_test.py index e165cf2..4aad894 100644 --- a/tests/vmap_test.py +++ b/tests/vmap_test.py @@ -2,11 +2,6 @@ import os -@pytest.fixture -def with_jax(request): - return request.config.getoption("--jax") - - @pytest.fixture def ctx(with_jax): if not with_jax: