diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f6e4d94..723857d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ ## Latest Changes +### v0.6.5 (2026-03-22) +This release brings `ir_mul` layout support for +OpenEquivariance. Pass the parameter +`layout='ir_mul'` to any `TPProblem` instance to use +a transposed layout for the input and output +irreps. To transpose input and output irreps use +`oeq.transpose_irreps` or `oeq.jax.transpose_irreps`; +see our API page for usage details. + ### v0.6.4 (2026-03-05) Bugfix: added missing MLIR lowerings for a pair of JAX primitives (thanks @teddykoker!) diff --git a/docs/api.rst b/docs/api.rst index c21b918f..15b1aec9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -30,6 +30,8 @@ PyTorch API :undoc-members: :exclude-members: name +.. autofunction:: openequivariance.transpose_irreps + .. autofunction:: openequivariance.torch_to_oeq_dtype .. autofunction:: openequivariance.torch_ext_so_path @@ -54,7 +56,9 @@ breaking the PyTorch version of OpenEquivariance. .. autoclass:: openequivariance.jax.TensorProductConv :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn :undoc-members: - :exclude-members: + :exclude-members: + +.. autofunction:: openequivariance.jax.transpose_irreps Common API --------------------- diff --git a/docs/conf.py b/docs/conf.py index 540cf37e..70dd285e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -38,6 +38,8 @@ "openequivariance._torch.extlib", "openequivariance.jax.extlib", "openequivariance_extjax", + "openequivariance.jax.jvp.tp_prim", + "openequivariance.jax.jvp.conv_prim", "jinja2", "numpy", ] diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 35c4b542..1a3681f5 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -37,6 +37,7 @@ def _check_package_editable(): from openequivariance._torch.TensorProduct import TensorProduct from openequivariance._torch.TensorProductConv import TensorProductConv + from openequivariance._torch.utils import transpose_irreps from openequivariance._torch.extlib import ( torch_ext_so_path as torch_ext_so_path_internal, @@ -111,4 +112,5 @@ def TensorProductConv(*args, **kwargs): "_check_package_editable", "torch_ext_so_path", "jax", + "transpose_irreps", ] diff --git a/openequivariance/openequivariance/_torch/CUETensorProduct.py b/openequivariance/openequivariance/_torch/CUETensorProduct.py index 33b8db12..1867e08a 100644 --- a/openequivariance/openequivariance/_torch/CUETensorProduct.py +++ b/openequivariance/openequivariance/_torch/CUETensorProduct.py @@ -6,13 +6,12 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.benchmark.tpp_creation_utils import ( +from openequivariance.core.logging import getLogger +from openequivariance.benchmark.problems import ( ChannelwiseTPP, FullyConnectedTPProblem, SingleInstruction, ) -from openequivariance.core.utils import count_cg_non_zero os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1" @@ -235,57 +234,6 @@ def benchmark_backward( kernel_names=self.kernel_names, ) - # Copied over from loop unroller to match arithmetic intensity on roofline plots - def calculate_flops_forward(self, batch_size: int) -> dict: - if self.is_uvw: - return super().calculate_flops_forward(batch_size) - else: - tpp = self.config - flop_count = { - "CG_decomposition": 0, - "linear_combination": 0, - "outer_products": 0, - } - for ins in tpp.instructions: - l1, l2, l3 = ( - tpp.irreps_in1[ins.i_in1].ir.l, - tpp.irreps_in2[ins.i_in2].ir.l, - tpp.irreps_out[ins.i_out].ir.l, - ) - flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( - ins.path_shape[0] * ins.path_shape[1] - ) - flop_count["linear_combination"] += ( - (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 - ) - - flop_count["CG_decomposition"] *= 3 * batch_size - flop_count["linear_combination"] *= ( - batch_size # Weights do not require FMA here - ) - flop_count["total"] = sum(flop_count.values()) - return flop_count - - def calculate_flops_backward(self, batch_size: int) -> dict: - if self.is_uvw: - return super().calculate_flops_backward(batch_size) - else: - tpp = self.config - flop_count = {"backward": 0} - for ins in tpp.instructions: - l1, l2, l3 = ( - tpp.irreps_in1[ins.i_in1].ir.l, - tpp.irreps_in2[ins.i_in2].ir.l, - tpp.irreps_out[ins.i_out].ir.l, - ) - flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * ( - ins.path_shape[0] * ins.path_shape[1] - ) - - flop_count["backward"] *= 9 * batch_size - flop_count["total"] = sum(flop_count.values()) - return flop_count - @staticmethod def name(): return "CUETensorProduct" diff --git a/openequivariance/openequivariance/_torch/E3NNTensorProduct.py b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py index 067a7e6b..df696ad6 100644 --- a/openequivariance/openequivariance/_torch/E3NNTensorProduct.py +++ b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py @@ -11,7 +11,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.logging import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 254da414..c18b1231 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -3,7 +3,7 @@ from openequivariance._torch import extlib import torch from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.logging import getLogger from openequivariance._torch.utils import ( reorder_torch, string_to_tensor, diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index 30931151..c1087a63 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -23,7 +23,7 @@ enum_to_torch_dtype, ) -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.logging import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv logger = getLogger() diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index be4113ec..ef17724b 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -8,7 +8,7 @@ import torch -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.logging import getLogger oeq_root = str(Path(__file__).parent.parent.parent) diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index 74d5a010..b3bc986b 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -2,6 +2,7 @@ import numpy as np from types import MappingProxyType from openequivariance.core.utils import DTypeEnum +from openequivariance.core.e3nn_lite import Irreps def reorder_helper(schedule, weights_in, direction, has_batch_dim): @@ -75,3 +76,72 @@ def string_to_tensor(text: str) -> torch.Tensor: result = torch.tensor(np_bytes, device="cpu") result.requires_grad = False return result + + +def transpose_irreps( + array: torch.Tensor, + irreps: Irreps, + src_layout: str, + dst_layout: str, +) -> torch.Tensor: + r""" + Transpose irrep-packed feature tensors between ``mul_ir`` and ``ir_mul`` layouts. + + The function operates on the trailing feature dimension and preserves all leading + batch dimensions. It uses only differentiable PyTorch tensor operations, so gradients + propagate through the transpose. + + :param array: Input feature tensor with shape ``[..., irreps.dim]``. + :param irreps: Irreps specification describing how the trailing feature dimension + is partitioned into irrep blocks. + :param src_layout: Source layout. Must be either ``"mul_ir"`` or ``"ir_mul"``. + :param dst_layout: Destination layout. Must be either ``"mul_ir"`` or ``"ir_mul"``. + + + :returns: Tensor in ``dst_layout`` with the same shape, dtype, and device as ``array``. + If ``src_layout == dst_layout``, returns a clone of ``array``. + + + :raises TypeError: If ``array`` is not a ``torch.Tensor``. + :raises ValueError: If ``src_layout`` or ``dst_layout`` is not one of + ``"mul_ir"`` or ``"ir_mul"``. + """ + if src_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported src_layout: {src_layout}") + if dst_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported dst_layout: {dst_layout}") + + if not isinstance(array, torch.Tensor): + raise TypeError(f"Expected torch.Tensor, got {type(array)}") + + out = torch.empty_like(array) + + if src_layout == dst_layout: + out.copy_(array) + return out + + slices = irreps.slices() + for ir_idx, mul_ir in enumerate(irreps): + mul = mul_ir.mul + dim = mul_ir.ir.dim + seg = slices[ir_idx] + block = array[..., seg.start : seg.stop] + + if src_layout == "ir_mul" and dst_layout == "mul_ir": + out[..., seg.start : seg.stop] = ( + block.reshape(*block.shape[:-1], dim, mul) + .transpose(-1, -2) + .reshape(*block.shape[:-1], mul * dim) + ) + elif src_layout == "mul_ir" and dst_layout == "ir_mul": + out[..., seg.start : seg.stop] = ( + block.reshape(*block.shape[:-1], mul, dim) + .transpose(-1, -2) + .reshape(*block.shape[:-1], dim * mul) + ) + else: + raise ValueError( + f"Unsupported layout transpose: {src_layout} -> {dst_layout}" + ) + + return out diff --git a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py index debcc65b..8e2ac98a 100644 --- a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -6,7 +6,12 @@ import numpy as np import openequivariance as oeq -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.correctness import ( + correctness_backward_conv, + correctness_double_backward_conv, + correctness_forward_conv, +) +from openequivariance.core.logging import getLogger from openequivariance.core.ConvolutionBase import CoordGraph from openequivariance.benchmark.benchmark_utils import NpEncoder @@ -90,7 +95,8 @@ def run( if direction == "forward": if correctness: - correctness = conv.test_correctness_forward( + correctness = correctness_forward_conv( + conv, graph, thresh=self.correctness_threshold, prng_seed=self.prng_seed, @@ -105,7 +111,8 @@ def run( if direction == "backward": if correctness: - correctness = conv.test_correctness_backward( + correctness = correctness_backward_conv( + conv, graph, thresh=self.correctness_threshold, prng_seed=self.prng_seed, @@ -120,8 +127,9 @@ def run( if direction == "double_backward": if correctness: - correctness = conv.test_correctness_double_backward( - self.graph, + correctness = correctness_double_backward_conv( + conv, + graph, thresh=self.correctness_threshold, prng_seed=self.prng_seed, reference_implementation=self.reference_impl, diff --git a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py index 37d20c46..72ada84d 100644 --- a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py @@ -10,9 +10,9 @@ from openequivariance._torch.extlib import DeviceProp from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging_utils import getLogger, bcolors +from openequivariance.core.logging import getLogger, bcolors from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.correctness_utils import ( +from openequivariance.benchmark.correctness import ( correctness_forward, correctness_backward, correctness_double_backward, diff --git a/openequivariance/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py index 68dc6f9f..62687dbb 100644 --- a/openequivariance/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py @@ -1,21 +1,22 @@ import json import numpy as np -from openequivariance.benchmark.random_buffer_utils import ( +from openequivariance.benchmark.test_buffers import ( get_random_buffers_forward, get_random_buffers_backward, get_random_buffers_double_backward, ) -from openequivariance.benchmark.perf_metrics_utils import ( - calculate_minimum_flops_forward, - calculate_minimum_memory_streamed_forward, - calculate_minimum_memory_streamed_backward, +from openequivariance.benchmark.metrics import ( + flops_forward, + flops_backward, + memory_streamed_forward, + memory_streamed_backward, ) from openequivariance.core.utils import calculate_total_nnz from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem from openequivariance._torch.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.logging_utils import getLogger, bcolors +from openequivariance.core.logging import getLogger, bcolors logger = getLogger() @@ -110,24 +111,10 @@ def benchmark_forward( time_millis = np.full(shape=num_iter, fill_value=-1) # FLOPS - try: - flops = tp.calculate_flops_forward(batch_size=batch_size) - except NotImplementedError: - logger.warning( - "Actual flop count not calculated, so minimum values are being used" - ) - flops = calculate_minimum_flops_forward(problem, batch_size=batch_size) + flops = flops_forward(problem, batch_size=batch_size) # DATA - try: - memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) - except NotImplementedError: - logger.warning( - "Actual memory streamed not calculated, so minimum values are being used" - ) - memory_streamed = calculate_minimum_memory_streamed_forward( - problem, batch_size=batch_size - ) + memory_streamed = memory_streamed_forward(problem, batch_size=batch_size) result |= calculate_performance_statistics( problem=problem, @@ -181,29 +168,9 @@ def benchmark_backward( ) time_millis = np.full(shape=num_iter, fill_value=-1) - try: - flops = tp.calculate_flops_backward(batch_size=batch_size) - except NotImplementedError: - try: - flops = calculate_minimum_flops_forward(tpp=problem, batch_size=batch_size) - logger.warning( - "Actual flops was not calculated, so minimum values are being used" - ) - except NotImplementedError: - logger.warning( - "Minimum Backwards flops calculations are not implemented, -1 is a placeholder" - ) - flops = {"total": -1} + flops = flops_backward(tpp=problem, batch_size=batch_size) - try: - memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) - except NotImplementedError: - logger.warning( - "Actual memory streamed was not calculated, so minimum values are being" - ) - memory_streamed = calculate_minimum_memory_streamed_backward( - tpp=problem, batch_size=batch_size - ) + memory_streamed = memory_streamed_backward(tpp=problem, batch_size=batch_size) result |= calculate_performance_statistics( problem=problem, @@ -258,29 +225,9 @@ def benchmark_double_backward( ) time_millis = np.full(shape=num_iter, fill_value=-1) - try: - flops = tp.calculate_flops_backward(batch_size=batch_size) - except NotImplementedError: - try: - flops = calculate_minimum_flops_forward(tpp=problem, batch_size=batch_size) - logger.warning( - "Actual flops was not calculated, so minimum values are being used" - ) - except NotImplementedError: - logger.warning( - "Minimum Backwards flops calculations are not implemented, -1 is a placeholder" - ) - flops = {"total": -1} + flops = flops_backward(tpp=problem, batch_size=batch_size) - try: - memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) - except NotImplementedError: - logger.warning( - "Actual memory streamed was not calculated, so minimum values are being" - ) - memory_streamed = calculate_minimum_memory_streamed_backward( - tpp=problem, batch_size=batch_size - ) + memory_streamed = memory_streamed_backward(tpp=problem, batch_size=batch_size) result |= calculate_performance_statistics( problem=problem, diff --git a/openequivariance/openequivariance/benchmark/correctness.py b/openequivariance/openequivariance/benchmark/correctness.py new file mode 100644 index 00000000..45c45c4a --- /dev/null +++ b/openequivariance/openequivariance/benchmark/correctness.py @@ -0,0 +1,638 @@ +import copy +from typing import Optional, Union + +import numpy as np +import numpy.linalg as la + +from openequivariance._torch.CUETensorProduct import CUETensorProduct +from openequivariance.core.logging import bcolors, getLogger +from openequivariance.benchmark.test_buffers import ( + get_random_buffers_backward_conv, + get_random_buffers_backward, + get_random_buffers_double_backward_conv, + get_random_buffers_double_backward, + get_random_buffers_forward_conv, + get_random_buffers_forward, +) +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.utils import transpose_irrep_layout + +logger = getLogger() + + +def check_similiarity( + name: str, + to_check: np.ndarray, + ground_truth: np.ndarray, + correctness_threshold: float, +): + result = {} + if to_check.shape != ground_truth.shape: + result["shape_match"] = False + result["diff_Linf_norm"] = np.inf + result["pass"] = False + logger.error( + f"{bcolors.FAIL}Ground truth {name} shape does not match input! {to_check.shape=}, {ground_truth.shape=} {bcolors.ENDC}" + ) + else: + result["shape_match"] = True + diff_Linf_norm = float(la.norm((ground_truth - to_check).flatten(), ord=np.inf)) + result["diff_Linf_norm"] = diff_Linf_norm + result["pass"] = bool(diff_Linf_norm < correctness_threshold) + if result["pass"]: + logger.info( + f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}" + ) + else: + logger.error( + f"{bcolors.FAIL}{name} correctness check fail! {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}" + ) + + return result + + +def instantiate_implementation( + implementation: Union[type[TensorProductBase], TensorProductBase], + problem: TPProblem, +): + if isinstance(implementation, type): + test_tp = implementation(problem) + else: + test_tp = implementation + + if not isinstance(test_tp, TensorProductBase): + raise TypeError( + f"test_implementation must be a TensorProductBase or a subclass, got {type(implementation)}" + ) + + return test_tp + + +def correctness_forward( + problem: TPProblem, + test_implementation: Union[type[TensorProductBase], TensorProductBase], + reference_implementation: Optional[type[TensorProductBase]], + batch_size: int, + correctness_threshold: float, + prng_seed: int, +) -> dict: + if reference_implementation is None: + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct + + reference_implementation = E3NNTensorProduct + + result = {"thresh": correctness_threshold, "batch_size": batch_size} + in1, in2, weights, out = get_random_buffers_forward(problem, batch_size, prng_seed) + outputs = [] + + for i, impl in enumerate([test_implementation, reference_implementation]): + is_test_impl = i == 0 + tp = instantiate_implementation(impl, problem) + uses_cue = impl == CUETensorProduct or isinstance(tp, CUETensorProduct) + run_in1, run_in2, run_weights, run_out = [ + buf.copy() for buf in (in1, in2, weights, out) + ] + + if problem.shared_weights and uses_cue: + run_weights = run_weights[np.newaxis, :] + + # Transpose inputs, if necessary, for the test implementation + if is_test_impl: + run_in1, run_in2 = [ + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) + for arr, irreps in zip( + (run_in1, run_in2), (problem.irreps_in1, problem.irreps_in2) + ) + ] + + tp.forward_cpu( + L1_in=run_in1, L2_in=run_in2, L3_out=run_out, weights=run_weights + ) + + if is_test_impl: + run_out = transpose_irrep_layout( + run_out, problem.irreps_out, tp.config.layout, "mul_ir" + ) + + outputs.append(run_out) + + for name, to_check, ground_truth in [("output", outputs[0], outputs[1])]: + result[name] = check_similiarity( + name, to_check, ground_truth, correctness_threshold + ) + + return result + + +def correctness_backward( + problem: TPProblem, + test_implementation: Union[type[TensorProductBase], TensorProductBase], + reference_implementation: Optional[type[TensorProductBase]], + batch_size: int, + correctness_threshold: float, + prng_seed: int, +) -> dict: + if reference_implementation is None: + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct + + reference_implementation = E3NNTensorProduct + + result = {"thresh": correctness_threshold, "batch_size": batch_size} + + in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = ( + get_random_buffers_backward(problem, batch_size, prng_seed) + ) + + grads = [] + for i, impl in enumerate([test_implementation, reference_implementation]): + is_test_impl = i == 0 + tp = instantiate_implementation(impl, problem) + + ( + run_in1, + run_in2, + run_L3_grad, + run_weights, + run_weights_grad, + run_in1_grad, + run_in2_grad, + ) = [ + buf.copy() + for buf in (in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad) + ] + + uses_cue = impl == CUETensorProduct or isinstance(tp, CUETensorProduct) + if problem.shared_weights and uses_cue: + run_weights = run_weights[np.newaxis, :] + run_weights_grad = run_weights_grad[np.newaxis, :] + + if is_test_impl: + run_in1, run_in2, run_L3_grad = [ + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) + for arr, irreps in zip( + (run_in1, run_in2, run_L3_grad), + (problem.irreps_in1, problem.irreps_in2, problem.irreps_out), + ) + ] + + tp.backward_cpu( + L1_in=run_in1, + L1_grad=run_in1_grad, + L2_in=run_in2, + L2_grad=run_in2_grad, + L3_grad=run_L3_grad, + weights=run_weights, + weights_grad=run_weights_grad, + ) + + if is_test_impl: + run_in1_grad, run_in2_grad = [ + transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir") + for arr, irreps in zip( + (run_in1_grad, run_in2_grad), + (problem.irreps_in1, problem.irreps_in2), + ) + ] + + if problem.shared_weights: + run_weights_grad = run_weights_grad.squeeze() + + grads.append((run_weights_grad, run_in1_grad, run_in2_grad)) + + weight_threshold = ( + correctness_threshold * batch_size + if problem.shared_weights + else correctness_threshold + ) + + for name, to_check, ground_truth, threshold in [ + ("weight_grad", grads[0][0], grads[1][0], weight_threshold), + ("in1_grad", grads[0][1], grads[1][1], correctness_threshold), + ("in2_grad", grads[0][2], grads[1][2], correctness_threshold), + ]: + result[name] = check_similiarity(name, to_check, ground_truth, threshold) + + return result + + +def correctness_double_backward( + problem: TPProblem, + test_implementation: Union[type[TensorProductBase], TensorProductBase], + reference_implementation: Optional[type[TensorProductBase]], + batch_size: int, + correctness_threshold: float, + prng_seed: int, +): + global torch + import torch + + in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = ( + get_random_buffers_double_backward( + problem, batch_size=batch_size, prng_seed=prng_seed + ) + ) + + if reference_implementation is None: + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct + + reference_implementation = E3NNTensorProduct + + result = {"thresh": correctness_threshold, "batch_size": batch_size} + + tensors = [] + for i, impl in enumerate([test_implementation, reference_implementation]): + is_test_impl = i == 0 + tp = instantiate_implementation(impl, problem) + weights_reordered = tp.reorder_weights_from_e3nn( + weights, has_batch_dim=not problem.shared_weights + ) + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, has_batch_dim=not problem.shared_weights + ) + + if impl == CUETensorProduct and problem.shared_weights: + weights_reordered = weights_reordered[np.newaxis, :] + + db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ + buf.copy() for buf in (in1, in2, out_grad, in1_dgrad, in2_dgrad) + ] + + if is_test_impl: + db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) + for arr, irreps in zip( + (db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad), + ( + problem.irreps_in1, + problem.irreps_in2, + problem.irreps_out, + problem.irreps_in1, + problem.irreps_in2, + ), + ) + ] + + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( + db_in1, + db_in2, + db_out_grad, + weights_reordered, + weights_dgrad_reordered, + db_in1_dgrad, + db_in2_dgrad, + ) + + if is_test_impl: + out_dgrad, in1_grad, in2_grad = [ + transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir") + for arr, irreps in zip( + (out_dgrad, in1_grad, in2_grad), + (problem.irreps_out, problem.irreps_in1, problem.irreps_in2), + ) + ] + + tensors.append( + ( + out_dgrad, + in1_grad, + in2_grad, + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not problem.shared_weights + ), + ) + ) + + for name, to_check, ground_truth in [ + ("output_double_grad", tensors[0][0], tensors[1][0]), + ("in1_grad", tensors[0][1], tensors[1][1]), + ("in2_grad", tensors[0][2], tensors[1][2]), + ("weights_grad", tensors[0][3], tensors[1][3]), + ]: + result[name] = check_similiarity( + name, to_check, ground_truth, correctness_threshold + ) + + return result + + +def correctness_forward_conv( + conv, + graph, + thresh, + prng_seed, + reference_implementation=None, + check_reproducible=True, + high_precision_ref=False, +): + global torch + import torch + + if reference_implementation is None: + from openequivariance._torch.E3NNConv import E3NNConv + + reference_implementation = E3NNConv + + result = {"thresh": thresh} + + in1, in2, weights, out = get_random_buffers_forward_conv( + conv.config, graph.node_count, graph.nnz, prng_seed + ) + reference_config = conv.config + if high_precision_ref: + reference_config = copy.deepcopy(conv.config) + reference_config.irrep_dtype = np.float64 + reference_config.weight_dtype = np.float64 + outputs = [] + + for i, impl in enumerate([conv, reference_implementation]): + is_test_impl = i == 0 + tp = impl if is_test_impl else impl(reference_config) + + run_in1, run_in2, run_weights, run_out = [ + buf.copy() for buf in (in1, in2, weights, out) + ] + + if not is_test_impl and high_precision_ref: + run_in1, run_in2, run_weights, run_out = [ + np.array(el, dtype=np.float64) + for el in (run_in1, run_in2, run_weights, run_out) + ] + + if is_test_impl: + run_in1, run_in2 = [ + transpose_irrep_layout(arr, irreps, "mul_ir", conv.config.layout) + for arr, irreps in zip( + (run_in1, run_in2), + (conv.config.irreps_in1, conv.config.irreps_in2), + ) + ] + conv.forward_cpu( + L1_in=run_in1, + L2_in=run_in2, + weights=run_weights, + L3_out=run_out, + graph=graph, + ) + + run_out = transpose_irrep_layout( + run_out, conv.config.irreps_out, conv.config.layout, "mul_ir" + ) + else: + args = { + "L1_in": run_in1, + "L2_in": run_in2, + "weights": run_weights, + "rows": graph.rows, + "cols": graph.cols, + } + + if tp.deterministic: + args["transpose_perm"] = graph.transpose_perm + + for key in args: + args[key] = torch.tensor(args[key], device="cuda") + + run_out[:] = tp.forward(**args).cpu().numpy() + + outputs.append(run_out) + + test_out, ref_out = outputs[0], outputs[1] + + for name, to_check, ground_truth in [("output", ref_out, test_out)]: + result[name] = check_similiarity(name, to_check, ground_truth, thresh) + + if check_reproducible: + num_trials = 5 + for name in ["output"]: + result[name]["num_reproducibility_trials"] = num_trials + result[name]["bitwise_reproducible"] = True + + for _ in range(num_trials): + repeated_run = out.copy() + rep_in1, rep_in2, rep_weights = [buf.copy() for buf in (in1, in2, weights)] + rep_in1, rep_in2 = [ + transpose_irrep_layout(arr, irreps, "mul_ir", conv.config.layout) + for arr, irreps in zip( + (rep_in1, rep_in2), + (conv.config.irreps_in1, conv.config.irreps_in2), + ) + ] + conv.forward_cpu( + L1_in=rep_in1, + L2_in=rep_in2, + weights=rep_weights, + L3_out=repeated_run, + graph=graph, + ) + + repeated_run = transpose_irrep_layout( + repeated_run, conv.config.irreps_out, conv.config.layout, "mul_ir" + ) + + result["output"]["bitwise_reproducible"] = bool( + result["output"]["bitwise_reproducible"] + and np.all(repeated_run == test_out) + ) + + return result + + +def correctness_backward_conv( + conv, + graph, + thresh, + prng_seed, + reference_implementation=None, + high_precision_ref=False, +): + if reference_implementation is None: + from openequivariance._torch.E3NNConv import E3NNConv + + reference_implementation = E3NNConv + + result = {"thresh": thresh} + + buffers = get_random_buffers_backward_conv( + conv.config, graph.node_count, graph.nnz, prng_seed + ) + reference_problem = conv.config + + if high_precision_ref: + reference_problem = copy.deepcopy(conv.config) + reference_problem.irrep_dtype = np.float64 + reference_problem.weight_dtype = np.float64 + grads = [] + for i, impl in enumerate([conv, reference_implementation]): + is_test_impl = i == 0 + tp = impl if is_test_impl else impl(reference_problem) + + ( + run_in1, + run_in2, + run_out_grad, + run_weights, + run_weights_grad, + run_in1_grad, + run_in2_grad, + ) = [buf.copy() for buf in buffers] + + if not is_test_impl and high_precision_ref: + ( + run_in1, + run_in2, + run_out_grad, + run_weights, + run_weights_grad, + run_in1_grad, + run_in2_grad, + ) = [np.array(el, dtype=np.float64) for el in buffers] + + if is_test_impl: + run_in1, run_in2, run_out_grad = [ + transpose_irrep_layout(arr, irreps, "mul_ir", conv.config.layout) + for arr, irreps in zip( + (run_in1, run_in2, run_out_grad), + ( + conv.config.irreps_in1, + conv.config.irreps_in2, + conv.config.irreps_out, + ), + ) + ] + + tp.backward_cpu( + L1_in=run_in1, + L1_grad=run_in1_grad, + L2_in=run_in2, + L2_grad=run_in2_grad, + L3_grad=run_out_grad, + weights=run_weights, + weights_grad=run_weights_grad, + graph=graph, + ) + + if is_test_impl: + run_in1_grad, run_in2_grad = [ + transpose_irrep_layout(arr, irreps, conv.config.layout, "mul_ir") + for arr, irreps in zip( + (run_in1_grad, run_in2_grad), + (conv.config.irreps_in1, conv.config.irreps_in2), + ) + ] + + grads.append((run_weights_grad, run_in1_grad, run_in2_grad)) + + for name, to_check, ground_truth, threshold in [ + ("weight_grad", grads[0][0], grads[1][0], thresh), + ("in1_grad", grads[0][1], grads[1][1], thresh), + ("in2_grad", grads[0][2], grads[1][2], thresh), + ]: + result[name] = check_similiarity(name, to_check, ground_truth, threshold) + + return result + + +def correctness_double_backward_conv( + conv, + graph, + thresh, + prng_seed, + reference_implementation=None, + high_precision_ref=False, +): + buffers = get_random_buffers_double_backward_conv( + conv.config, graph.node_count, graph.nnz, prng_seed + ) + + if reference_implementation is None: + from openequivariance._torch.E3NNConv import E3NNConv + + reference_implementation = E3NNConv + + reference_problem = conv.config + if high_precision_ref: + reference_problem = copy.deepcopy(conv.config) + reference_problem.irrep_dtype = np.float64 + reference_problem.weight_dtype = np.float64 + + reference_tp = reference_implementation(reference_problem, torch_op=True) + + result = {"thresh": thresh} + tensors = [] + for i, tp in enumerate([conv, reference_tp]): + is_test_impl = i == 0 + buffers_copy = [buf.copy() for buf in buffers] + + if i == 1 and high_precision_ref: + buffers_copy = [np.array(el, dtype=np.float64) for el in buffers] + + in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = ( + buffers_copy + ) + + weights_reordered = tp.reorder_weights_from_e3nn( + weights, not tp.config.shared_weights + ) + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, not tp.config.shared_weights + ) + + db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ + buf.copy() for buf in (in1, in2, out_grad, in1_dgrad, in2_dgrad) + ] + if is_test_impl: + db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) + for arr, irreps in zip( + (db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad), + ( + tp.config.irreps_in1, + tp.config.irreps_in2, + tp.config.irreps_out, + tp.config.irreps_in1, + tp.config.irreps_in2, + ), + ) + ] + + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( + db_in1, + db_in2, + db_out_grad, + weights_reordered, + weights_dgrad_reordered, + db_in1_dgrad, + db_in2_dgrad, + graph, + ) + + if is_test_impl: + out_dgrad, in1_grad, in2_grad = [ + transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir") + for arr, irreps in zip( + (out_dgrad, in1_grad, in2_grad), + (tp.config.irreps_out, tp.config.irreps_in1, tp.config.irreps_in2), + ) + ] + + tensors.append( + ( + out_dgrad, + in1_grad, + in2_grad, + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not conv.config.shared_weights + ), + ) + ) + + for name, to_check, ground_truth in [ + ("output_grad", tensors[0][0], tensors[1][0]), + ("in1_grad", tensors[0][1], tensors[1][1]), + ("in2_grad", tensors[0][2], tensors[1][2]), + ("weights_grad", tensors[0][3], tensors[1][3]), + ]: + result[name] = check_similiarity(name, to_check, ground_truth, thresh) + + return result diff --git a/openequivariance/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py deleted file mode 100644 index 788d209e..00000000 --- a/openequivariance/openequivariance/benchmark/correctness_utils.py +++ /dev/null @@ -1,255 +0,0 @@ -from typing import Optional, Union - -from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.core.e3nn_lite import TPProblem -from openequivariance._torch.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.random_buffer_utils import ( - get_random_buffers_forward, - get_random_buffers_backward, - get_random_buffers_double_backward, -) - -from openequivariance.benchmark.logging_utils import getLogger, bcolors -import numpy as np -import numpy.linalg as la - -logger = getLogger() - - -def check_similiarity( - name: str, - to_check: np.ndarray, - ground_truth: np.ndarray, - correctness_threshold: float, -): - result = {} - if to_check.shape != ground_truth.shape: - result["shape_match"] = False - result["diff_Linf_norm"] = np.inf - result["pass"] = False - logger.error( - f"{bcolors.FAIL}Ground truth {name} shape does not match input! {to_check.shape=}, {ground_truth.shape=} {bcolors.ENDC}" - ) - else: - result["shape_match"] = True - diff_Linf_norm = float(la.norm((ground_truth - to_check).flatten(), ord=np.inf)) - result["diff_Linf_norm"] = diff_Linf_norm - result["pass"] = bool(diff_Linf_norm < correctness_threshold) - if result["pass"]: - logger.info( - f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}" - ) - else: - logger.error( - f"{bcolors.FAIL}{name} correctness check fail! {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}" - ) - - return result - - -def instantiate_implementation( - implementation: Union[type[TensorProductBase], TensorProductBase], - problem: TPProblem, -): - if isinstance(implementation, type): - test_tp = implementation(problem) - else: - test_tp = implementation - - if not isinstance(test_tp, TensorProductBase): - raise TypeError( - f"test_implementation must be a TensorProductBase or a subclass, got {type(implementation)}" - ) - - return test_tp - - -def correctness_forward( - problem: TPProblem, - test_implementation: Union[type[TensorProductBase], TensorProductBase], - reference_implementation: Optional[type[TensorProductBase]], - batch_size: int, - correctness_threshold: float, - prng_seed: int, -) -> dict: - if reference_implementation is None: - from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct - - reference_implementation = E3NNTensorProduct - - result = {"thresh": correctness_threshold, "batch_size": batch_size} - - in1, in2, weights, out = get_random_buffers_forward(problem, batch_size, prng_seed) - - # run reference - ref_tp = reference_implementation(problem) - - ref_out = out.copy() - ref_tp.forward_cpu( - L1_in=in1.copy(), L2_in=in2.copy(), L3_out=ref_out, weights=weights.copy() - ) - - weights_copy = weights.copy() - if problem.shared_weights and test_implementation == CUETensorProduct: - weights_copy = weights[np.newaxis, :] - - # run test - test_tp = instantiate_implementation(test_implementation, problem) - test_out = out.copy() - test_tp.forward_cpu( - L1_in=in1.copy(), L2_in=in2.copy(), L3_out=test_out, weights=weights_copy - ) - - for name, to_check, ground_truth in [("output", ref_out, test_out)]: - result[name] = check_similiarity( - name, to_check, ground_truth, correctness_threshold - ) - - return result - - -def correctness_backward( - problem: TPProblem, - test_implementation: Union[type[TensorProductBase], TensorProductBase], - reference_implementation: Optional[type[TensorProductBase]], - batch_size: int, - correctness_threshold: float, - prng_seed: int, -) -> dict: - if reference_implementation is None: - from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct - - reference_implementation = E3NNTensorProduct - - result = {"thresh": correctness_threshold, "batch_size": batch_size} - - # run reference - in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = ( - get_random_buffers_backward(problem, batch_size, prng_seed) - ) - - ref_tp = reference_implementation(problem) - - ref_weights_grad = weights_grad.copy() - ref_in1_grad = in1_grad.copy() - ref_in2_grad = in2_grad.copy() - - ref_tp.backward_cpu( - L1_in=in1.copy(), - L1_grad=ref_in1_grad, - L2_in=in2.copy(), - L2_grad=ref_in2_grad, - L3_grad=out_grad.copy(), - weights=weights.copy(), - weights_grad=ref_weights_grad, - ) - - # run test version - test_weights_grad = weights_grad.copy() - test_in1_grad = in1_grad.copy() - test_in2_grad = in2_grad.copy() - - weights_copy = weights.copy() - - if problem.shared_weights and test_implementation == CUETensorProduct: - weights_copy = weights[np.newaxis, :] - test_weights_grad = test_weights_grad[np.newaxis, :] - - test_tp = instantiate_implementation(test_implementation, problem) - test_tp.backward_cpu( - L1_in=in1.copy(), - L1_grad=test_in1_grad, - L2_in=in2.copy(), - L2_grad=test_in2_grad, - L3_grad=out_grad.copy(), - weights=weights_copy, - weights_grad=test_weights_grad, - ) - - weight_threshold = ( - correctness_threshold * batch_size - if problem.shared_weights - else correctness_threshold - ) - - if problem.shared_weights: - test_weights_grad = test_weights_grad.squeeze() - - for name, to_check, ground_truth, threshold in [ - ("weight_grad", test_weights_grad, ref_weights_grad, weight_threshold), - ("in1_grad", test_in1_grad, ref_in1_grad, correctness_threshold), - ("in2_grad", test_in2_grad, ref_in2_grad, correctness_threshold), - ]: - result[name] = check_similiarity(name, to_check, ground_truth, threshold) - - return result - - -def correctness_double_backward( - problem: TPProblem, - test_implementation: Union[type[TensorProductBase], TensorProductBase], - reference_implementation: Optional[type[TensorProductBase]], - batch_size: int, - correctness_threshold: float, - prng_seed: int, -): - global torch - import torch - - in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = ( - get_random_buffers_double_backward( - problem, batch_size=batch_size, prng_seed=prng_seed - ) - ) - - if reference_implementation is None: - from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct - - reference_implementation = E3NNTensorProduct - - result = {"thresh": correctness_threshold, "batch_size": batch_size} - - tensors = [] - for _, impl in enumerate([test_implementation, reference_implementation]): - tp = instantiate_implementation(impl, problem) - weights_reordered = tp.reorder_weights_from_e3nn( - weights, has_batch_dim=not problem.shared_weights - ) - weights_dgrad_reordered = tp.reorder_weights_from_e3nn( - weights_dgrad, has_batch_dim=not problem.shared_weights - ) - - if impl == CUETensorProduct and problem.shared_weights: - weights_reordered = weights_reordered[np.newaxis, :] - - in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( - in1, - in2, - out_grad, - weights_reordered, - weights_dgrad_reordered, - in1_dgrad, - in2_dgrad, - ) - tensors.append( - ( - out_dgrad, - in1_grad, - in2_grad, - tp.reorder_weights_to_e3nn( - weights_grad, has_batch_dim=not problem.shared_weights - ), - ) - ) - - for name, to_check, ground_truth in [ - ("output_double_grad", tensors[0][0], tensors[1][0]), - ("in1_grad", tensors[0][1], tensors[1][1]), - ("in2_grad", tensors[0][2], tensors[1][2]), - ("weights_grad", tensors[0][3], tensors[1][3]), - ]: - result[name] = check_similiarity( - name, to_check, ground_truth, correctness_threshold - ) - - return result diff --git a/openequivariance/openequivariance/benchmark/perf_metrics_utils.py b/openequivariance/openequivariance/benchmark/metrics.py similarity index 58% rename from openequivariance/openequivariance/benchmark/perf_metrics_utils.py rename to openequivariance/openequivariance/benchmark/metrics.py index 212f05f4..0fc87f1f 100644 --- a/openequivariance/openequivariance/benchmark/perf_metrics_utils.py +++ b/openequivariance/openequivariance/benchmark/metrics.py @@ -1,20 +1,15 @@ -import math - from openequivariance.core.utils import ( count_cg_non_zero, - sparse_outer_product_work, ) -from openequivariance.core.e3nn_lite import TPProblem, wigner_3j -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance.core.logging import getLogger import numpy as np logger = getLogger() -def calculate_minimum_memory_streamed_forward( - tpp: TPProblem, batch_size: int -) -> dict[str, int]: +def memory_streamed_forward(tpp: TPProblem, batch_size: int) -> dict[str, int]: """ This represents an absolute minimum amount of memory that could be streamed on an ideal machine It returns the number of bytes streamed total and from each source @@ -31,7 +26,7 @@ def calculate_minimum_memory_streamed_forward( return data_size -def calculate_minimum_memory_streamed_backward(tpp: TPProblem, batch_size: int) -> dict: +def memory_streamed_backward(tpp: TPProblem, batch_size: int) -> dict: """ This represents an absolute minimum amount of memory that could be streamed on an ideal machine It returns the number of bytes streamed total and from each source @@ -51,17 +46,12 @@ def calculate_minimum_memory_streamed_backward(tpp: TPProblem, batch_size: int) return data_size -def calculate_minimum_flops_forward(tpp: TPProblem, batch_size: int) -> dict: +def flops_forward(tpp: TPProblem, batch_size: int) -> dict: """ - This is not actually calcuating the minimum value. - Ideally you might share the outer product values between two inputs across multiple inputs. - This is assuming that you form those values and reuse them once per CG decomp. + Default FLOP estimate aligned with LoopUnrollTP's forward FLOP accounting. """ - logger.warning("Minimum flops Calculation is not the true minimum") - flops_count = {} - flops_count["outer_products"] = 0 - flops_count["CG_decomposition"] = 0 - flops_count["linear_combination"] = 0 + flops_count = {"CG_decomposition": 0, "linear_combination": 0, "outer_products": 0} + for ins in tpp.instructions: # type : Instruction l1, l2, l3 = ( tpp.irreps_in1[ins.i_in1].ir.l, @@ -69,28 +59,36 @@ def calculate_minimum_flops_forward(tpp: TPProblem, batch_size: int) -> dict: tpp.irreps_out[ins.i_out].ir.l, ) - flops_count["outer_products"] += sparse_outer_product_work( - wigner_3j(l1, l2, l3) - ) flops_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( ins.path_shape[0] * ins.path_shape[1] ) flops_count["linear_combination"] += ( - (2 * l3 + 1) * math.prod(ins.path_shape) if ins.has_weight else 0 + (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 ) - flops_count["outer_products"] *= batch_size - flops_count["CG_decomposition"] *= 2 * batch_size - flops_count["linear_combination"] *= 2 * batch_size + flops_count["CG_decomposition"] *= 3 * batch_size + flops_count["linear_combination"] *= batch_size # Weights do not require FMA here flops_count["total"] = sum(flops_count.values()) return flops_count -def calculate_minimum_flops_backward(tpp: TPProblem, batch_size: int) -> dict: +def flops_backward(tpp: TPProblem, batch_size: int) -> dict: """ - This is not actually calcuating the minumum value. - Ideally you might share the outer product values between two inputs across multiple inputs. - This is assuming that you form those values and reuse them once per CG decomp. + Default FLOP estimate aligned with LoopUnrollTP's backward FLOP accounting. """ - raise NotImplementedError("this needs to be implemented properly") + flops_count = {"backward": 0} + + for ins in tpp.instructions: # type : Instruction + l1, l2, l3 = ( + tpp.irreps_in1[ins.i_in1].ir.l, + tpp.irreps_in2[ins.i_in2].ir.l, + tpp.irreps_out[ins.i_out].ir.l, + ) + flops_count["backward"] += count_cg_non_zero(l1, l2, l3) * ( + ins.path_shape[0] * ins.path_shape[1] + ) + + flops_count["backward"] *= 9 * batch_size + flops_count["total"] = sum(flops_count.values()) + return flops_count diff --git a/openequivariance/openequivariance/benchmark/plotting/__init__.py b/openequivariance/openequivariance/benchmark/plotting/__init__.py index 3c0bf032..6aac5e14 100644 --- a/openequivariance/openequivariance/benchmark/plotting/__init__.py +++ b/openequivariance/openequivariance/benchmark/plotting/__init__.py @@ -5,6 +5,7 @@ from openequivariance.benchmark.plotting.plot_double_backward import ( plot_double_backward, ) +from openequivariance.benchmark.plotting.plot_layout import plot_layout __all__ = [ "plot_uvu", @@ -12,4 +13,5 @@ "plot_roofline", "plot_convolution", "plot_double_backward", + "plot_layout", ] diff --git a/openequivariance/openequivariance/benchmark/plotting/plot_layout.py b/openequivariance/openequivariance/benchmark/plotting/plot_layout.py new file mode 100644 index 00000000..5d42ac32 --- /dev/null +++ b/openequivariance/openequivariance/benchmark/plotting/plot_layout.py @@ -0,0 +1,128 @@ +import pathlib + +import matplotlib.pyplot as plt +import numpy as np + +from openequivariance.benchmark.plotting.plotting_utils import ( + calculate_tp_per_sec, + grouped_barchart, + load_benchmarks, + set_grid, +) + + +def _parse_layout_label(label: str): + if label.endswith("[mul_ir]"): + return label[: -len(" [mul_ir]")], "mul_ir" + if label.endswith("[ir_mul]"): + return label[: -len(" [ir_mul]")], "ir_mul" + return label, None + + +def plot_layout(data_folder): + data_folder = pathlib.Path(data_folder) + benchmarks, _ = load_benchmarks(data_folder) + + grouped = {} + dtype_order = [] + for benchmark in benchmarks: + dtype = benchmark["benchmark results"]["rep_dtype"] + if dtype not in dtype_order: + dtype_order.append(dtype) + + direction = benchmark["direction"] + base_label, layout = _parse_layout_label(benchmark["config_label"]) + if layout is None: + continue + + grouped.setdefault(dtype, {}).setdefault(direction, {}).setdefault( + base_label, {"mul_ir": 0.0, "ir_mul": 0.0} + ) + grouped[dtype][direction][base_label][layout] = calculate_tp_per_sec(benchmark) + + def _dtype_sort_key(dtype_name: str) -> int: + if "float32" in dtype_name: + return 0 + if "float64" in dtype_name: + return 1 + return 2 + + dtype_order = sorted(dtype_order, key=_dtype_sort_key) + + directions = [ + d for d in ["forward", "backward"] if any(d in grouped[x] for x in grouped) + ] + if not directions: + raise ValueError("No forward/backward layout benchmark entries found to plot.") + + fig = plt.figure(figsize=(7, 7)) + gs = fig.add_gridspec(len(directions), max(1, len(dtype_order))) + axs = gs.subplots(sharex="col") + + if len(directions) == 1 and len(dtype_order) == 1: + axs = np.array([[axs]]) + elif len(directions) == 1: + axs = np.array([axs]) + elif len(dtype_order) == 1: + axs = np.array([[ax] for ax in axs]) + + colormap = {"mul_ir": "#1f77b4", "ir_mul": "#2ca02c"} + + for row, direction in enumerate(directions): + for col, dtype in enumerate(dtype_order): + axis = axs[row][col] + source = grouped.get(dtype, {}).get(direction, {}) + data = { + label: { + "mul_ir": vals["mul_ir"], + "ir_mul": vals["ir_mul"], + } + for label, vals in source.items() + } + grouped_barchart( + data, + axis, + bar_height_fontsize=0, + colormap=colormap, + group_spacing=6.0, + xticklabel=(row == len(directions) - 1), + ) + set_grid(axis) + + if row == 0: + axis.set_title(dtype.replace("", "")) + if col == 0: + axis.set_ylabel(direction.capitalize()) + if row < len(directions) - 1: + axis.tick_params(axis="x", labelbottom=False) + + fig.supylabel("Throughput (# tensor products / s)", x=0.03, y=0.56) + fig.supxlabel("Problem") + + handles, labels = axs[0][0].get_legend_handles_labels() + unique = [ + (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] + ] + if unique: + axs[0][0].legend(*zip(*unique)) + + fig.tight_layout(rect=(0.03, 0.03, 1.0, 1.0)) + fig.savefig(str(data_folder / "layout_throughput_comparison.pdf")) + + print("Layout speedups (ir_mul / mul_ir):") + print("\t".join(["dtype", "direction", "min", "mean", "median", "max"])) + for dtype in dtype_order: + for direction in directions: + ratios = [] + for _, values in grouped.get(dtype, {}).get(direction, {}).items(): + if values["mul_ir"] > 0: + ratios.append(values["ir_mul"] / values["mul_ir"]) + if ratios: + stats = [ + np.min(ratios), + np.mean(ratios), + np.median(ratios), + np.max(ratios), + ] + stats_fmt = [f"{val:.3f}" for val in stats] + print("\t".join([dtype, direction] + stats_fmt)) diff --git a/openequivariance/openequivariance/benchmark/problems.py b/openequivariance/openequivariance/benchmark/problems.py index dd536e3a..b486941c 100644 --- a/openequivariance/openequivariance/benchmark/problems.py +++ b/openequivariance/openequivariance/benchmark/problems.py @@ -1,7 +1,203 @@ -from openequivariance.benchmark.tpp_creation_utils import ( - FullyConnectedTPProblem as FCTPP, -) -from openequivariance.benchmark.tpp_creation_utils import ChannelwiseTPP as CTPP +from typing import Iterator, Optional + +import numpy as np + +from openequivariance.core.e3nn_lite import Irrep, Irreps, TPProblem + +""" +This was taken from +https://github.com/e3nn/e3nn/blob/0.5.4/e3nn/o3/_tensor_product/_sub.py +Adapted to create TPPs to avoid torch dependence. +""" + + +class FullyConnectedTPProblem(TPProblem): + def __init__(self, irreps_in1, irreps_in2, irreps_out, **kwargs) -> None: + irreps_in1 = Irreps(irreps_in1) + irreps_in2 = Irreps(irreps_in2) + irreps_out = Irreps(irreps_out) + + instr = [ + (i_1, i_2, i_out, "uvw", True, 1.0) + for i_1, (_, ir_1) in enumerate(irreps_in1) + for i_2, (_, ir_2) in enumerate(irreps_in2) + for i_out, (_, ir_out) in enumerate(irreps_out) + if ir_out in ir_1 * ir_2 + ] + super().__init__( + irreps_in1, + irreps_in2, + irreps_out, + instr, + **kwargs, + ) + + +class ElementwiseTPProblem(TPProblem): + def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, **kwargs) -> None: + irreps_in1 = Irreps(irreps_in1).simplify() + irreps_in2 = Irreps(irreps_in2).simplify() + if filter_ir_out is not None: + try: + filter_ir_out = [Irrep(ir) for ir in filter_ir_out] + except ValueError as exc: + raise ValueError( + f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" + ) from exc + + assert irreps_in1.num_irreps == irreps_in2.num_irreps + + irreps_in1 = list(irreps_in1) + irreps_in2 = list(irreps_in2) + + i = 0 + while i < len(irreps_in1): + mul_1, ir_1 = irreps_in1[i] + mul_2, ir_2 = irreps_in2[i] + + if mul_1 < mul_2: + irreps_in2[i] = (mul_1, ir_2) + irreps_in2.insert(i + 1, (mul_2 - mul_1, ir_2)) + + if mul_2 < mul_1: + irreps_in1[i] = (mul_2, ir_1) + irreps_in1.insert(i + 1, (mul_1 - mul_2, ir_1)) + i += 1 + + out = [] + instr = [] + for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1, irreps_in2)): + assert mul == mul_2 + for ir in ir_1 * ir_2: + if filter_ir_out is not None and ir not in filter_ir_out: + continue + + i_out = len(out) + out.append((mul, ir)) + instr += [(i, i, i_out, "uuu", False)] + + super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) + + +class FullTPProblem(TPProblem): + def __init__( + self, + irreps_in1: Irreps, + irreps_in2: Irreps, + filter_ir_out: Iterator[Irrep] = None, + **kwargs, + ) -> None: + irreps_in1 = Irreps(irreps_in1).simplify() + irreps_in2 = Irreps(irreps_in2).simplify() + if filter_ir_out is not None: + try: + filter_ir_out = [Irrep(ir) for ir in filter_ir_out] + except ValueError as exc: + raise ValueError( + f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" + ) from exc + + out = [] + instr = [] + for i_1, (mul_1, ir_1) in enumerate(irreps_in1): + for i_2, (mul_2, ir_2) in enumerate(irreps_in2): + for ir_out in ir_1 * ir_2: + if filter_ir_out is not None and ir_out not in filter_ir_out: + continue + + i_out = len(out) + out.append((mul_1 * mul_2, ir_out)) + instr += [(i_1, i_2, i_out, "uvuv", False)] + + out = Irreps(out) + out, p, _ = out.sort() + + instr = [ + (i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instr + ] + + super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) + + +class ChannelwiseTPP(TPProblem): + """ + Modified from mace/mace/modules/irreps_tools.py. + """ + + def __init__( + self, + irreps_in1: Irreps, + irreps_in2: Irreps, + irreps_out: Irreps, + label: Optional[str] = None, + irrep_dtype=np.float32, + weight_dtype=np.float32, + ): + trainable = True + irreps1 = Irreps(irreps_in1) + irreps2 = Irreps(irreps_in2) + irreps_out = Irreps(irreps_out) + + irreps_out_list = [] + instructions = [] + for i, (mul, ir_in) in enumerate(irreps1): + for j, (_, ir_edge) in enumerate(irreps2): + for ir_out in ir_in * ir_edge: + if ir_out in irreps_out: + k = len(irreps_out_list) + irreps_out_list.append((mul, ir_out)) + instructions.append((i, j, k, "uvu", trainable)) + + irreps_out = Irreps(irreps_out_list) + irreps_out, permut, _ = irreps_out.sort() + + instructions = [ + (i_in1, i_in2, permut[i_out], mode, train) + for i_in1, i_in2, i_out, mode, train in instructions + ] + + instructions = sorted(instructions, key=lambda x: x[2]) + super().__init__( + irreps1, + irreps2, + irreps_out, + instructions, + internal_weights=False, + shared_weights=False, + label=label, + irrep_dtype=irrep_dtype, + weight_dtype=weight_dtype, + ) + + +class SingleInstruction(TPProblem): + def __init__( + self, + irreps_in1: Irreps, + irreps_in2: Irreps, + irreps_in3: Irreps, + mode: str, + label: Optional[str] = None, + ): + trainable = True + irreps1 = Irreps(irreps_in1) + irreps2 = Irreps(irreps_in2) + irreps3 = Irreps(irreps_in3) + instructions = [(0, 0, 0, mode, trainable)] + + super().__init__( + irreps1, + irreps2, + irreps3, + instructions, + internal_weights=False, + shared_weights=False, + label=label, + ) + + +FCTPP = FullyConnectedTPProblem +CTPP = ChannelwiseTPP # source: https://github.com/e3nn/e3nn/blob/main/examples/tetris.py # running tetris will output the layers. I've only extracted the fully connected layers here. diff --git a/openequivariance/openequivariance/benchmark/random_buffer_utils.py b/openequivariance/openequivariance/benchmark/test_buffers.py similarity index 100% rename from openequivariance/openequivariance/benchmark/random_buffer_utils.py rename to openequivariance/openequivariance/benchmark/test_buffers.py diff --git a/openequivariance/openequivariance/benchmark/tpp_creation_utils.py b/openequivariance/openequivariance/benchmark/tpp_creation_utils.py deleted file mode 100644 index 7637f412..00000000 --- a/openequivariance/openequivariance/benchmark/tpp_creation_utils.py +++ /dev/null @@ -1,196 +0,0 @@ -import numpy as np - -from typing import Iterator, Optional -from openequivariance.core.e3nn_lite import Irrep, Irreps, TPProblem - -""" -This was taken from -https://github.com/e3nn/e3nn/blob/0.5.4/e3nn/o3/_tensor_product/_sub.py -Adapted to create TPPs to avoid torch dependence. -""" - - -class FullyConnectedTPProblem(TPProblem): - def __init__(self, irreps_in1, irreps_in2, irreps_out, **kwargs) -> None: - irreps_in1 = Irreps(irreps_in1) - irreps_in2 = Irreps(irreps_in2) - irreps_out = Irreps(irreps_out) - - instr = [ - (i_1, i_2, i_out, "uvw", True, 1.0) - for i_1, (_, ir_1) in enumerate(irreps_in1) - for i_2, (_, ir_2) in enumerate(irreps_in2) - for i_out, (_, ir_out) in enumerate(irreps_out) - if ir_out in ir_1 * ir_2 - ] - super().__init__( - irreps_in1, - irreps_in2, - irreps_out, - instr, - **kwargs, - ) - - -class ElementwiseTPProblem(TPProblem): - def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, **kwargs) -> None: - irreps_in1 = Irreps(irreps_in1).simplify() - irreps_in2 = Irreps(irreps_in2).simplify() - if filter_ir_out is not None: - try: - filter_ir_out = [Irrep(ir) for ir in filter_ir_out] - except ValueError: - raise ValueError( - f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" - ) - - assert irreps_in1.num_irreps == irreps_in2.num_irreps - - irreps_in1 = list(irreps_in1) - irreps_in2 = list(irreps_in2) - - i = 0 - while i < len(irreps_in1): - mul_1, ir_1 = irreps_in1[i] - mul_2, ir_2 = irreps_in2[i] - - if mul_1 < mul_2: - irreps_in2[i] = (mul_1, ir_2) - irreps_in2.insert(i + 1, (mul_2 - mul_1, ir_2)) - - if mul_2 < mul_1: - irreps_in1[i] = (mul_2, ir_1) - irreps_in1.insert(i + 1, (mul_1 - mul_2, ir_1)) - i += 1 - - out = [] - instr = [] - for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1, irreps_in2)): - assert mul == mul_2 - for ir in ir_1 * ir_2: - if filter_ir_out is not None and ir not in filter_ir_out: - continue - - i_out = len(out) - out.append((mul, ir)) - instr += [(i, i, i_out, "uuu", False)] - - super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) - - -class FullTPProblem(TPProblem): - def __init__( - self, - irreps_in1: Irreps, - irreps_in2: Irreps, - filter_ir_out: Iterator[Irrep] = None, - **kwargs, - ) -> None: - irreps_in1 = Irreps(irreps_in1).simplify() - irreps_in2 = Irreps(irreps_in2).simplify() - if filter_ir_out is not None: - try: - filter_ir_out = [Irrep(ir) for ir in filter_ir_out] - except ValueError: - raise ValueError( - f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" - ) - - out = [] - instr = [] - for i_1, (mul_1, ir_1) in enumerate(irreps_in1): - for i_2, (mul_2, ir_2) in enumerate(irreps_in2): - for ir_out in ir_1 * ir_2: - if filter_ir_out is not None and ir_out not in filter_ir_out: - continue - - i_out = len(out) - out.append((mul_1 * mul_2, ir_out)) - instr += [(i_1, i_2, i_out, "uvuv", False)] - - out = Irreps(out) - out, p, _ = out.sort() - - instr = [ - (i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instr - ] - - super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) - - -class ChannelwiseTPP(TPProblem): - """ - Modified from mace/mace/modules/irreps_tools.py. - """ - - def __init__( - self, - irreps_in1: Irreps, - irreps_in2: Irreps, - irreps_out: Irreps, - label: Optional[str] = None, - irrep_dtype=np.float32, - weight_dtype=np.float32, - ): - trainable = True - irreps1 = Irreps(irreps_in1) - irreps2 = Irreps(irreps_in2) - irreps_out = Irreps(irreps_out) - - # Collect possible irreps and their instructions - irreps_out_list = [] - instructions = [] - for i, (mul, ir_in) in enumerate(irreps1): - for j, (_, ir_edge) in enumerate(irreps2): - for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 - if ir_out in irreps_out: - k = len(irreps_out_list) # instruction index - irreps_out_list.append((mul, ir_out)) - instructions.append((i, j, k, "uvu", trainable)) - - irreps_out = Irreps(irreps_out_list) - irreps_out, permut, _ = irreps_out.sort() - - instructions = [ - (i_in1, i_in2, permut[i_out], mode, train) - for i_in1, i_in2, i_out, mode, train in instructions - ] - - instructions = sorted(instructions, key=lambda x: x[2]) - super().__init__( - irreps1, - irreps2, - irreps_out, - instructions, - internal_weights=False, - shared_weights=False, - label=label, - irrep_dtype=irrep_dtype, - weight_dtype=weight_dtype, - ) - - -class SingleInstruction(TPProblem): - def __init__( - self, - irreps_in1: Irreps, - irreps_in2: Irreps, - irreps_in3: Irreps, - mode: str, - label: Optional[str] = None, - ): - trainable = True - irreps1 = Irreps(irreps_in1) - irreps2 = Irreps(irreps_in2) - irreps3 = Irreps(irreps_in3) - instructions = [(0, 0, 0, mode, trainable)] - - super().__init__( - irreps1, - irreps2, - irreps3, - instructions, - internal_weights=False, - shared_weights=False, - label=label, - ) diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index c9765d0d..f9f10013 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -1,7 +1,9 @@ +from itertools import accumulate + import numpy as np + +from openequivariance.core.logging import getLogger from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j -from itertools import accumulate -from openequivariance.benchmark.logging_utils import getLogger logger = getLogger() @@ -17,8 +19,9 @@ class IrrepMapping: Maps irreps from a source to a destination set. """ - def __init__(self, src_irreps, idxs): + def __init__(self, src_irreps, src_views, idxs): self.src_irreps = src_irreps + self.src_views = src_views self.idxs = sorted(list(idxs)) self.dst_irreps = Irreps([src_irreps[idx] for idx in self.idxs]) self.src_dst_map = {idx: i for i, idx in enumerate(self.idxs)} @@ -26,10 +29,17 @@ def __init__(self, src_irreps, idxs): src_ranges = [src_irreps.slices()[idx] for idx in self.src_dst_map] dst_ranges = [self.dst_irreps.slices()[i] for i in self.src_dst_map.values()] + self.storeback_procedure = {idx: "write" for idx in self.idxs} + self.persist_load = False + self.persist_store = False + + if src_views[0].layout == "ir_mul": + return + + # Merge adjacent src and dst ranges self.original_src_ranges = src_ranges self.original_dst_ranges = dst_ranges - # Merge adjacent src and dst ranges self.src_ranges = [] self.dst_ranges = [] @@ -49,11 +59,6 @@ def __init__(self, src_irreps, idxs): self.dst_ranges.append(slice(dst_start, dst_end)) self.copy_ranges = list(zip(self.src_ranges, self.dst_ranges)) - self.persist_load = False - self.persist_store = False - - self.storeback_procedure = {idx: "write" for idx in self.idxs} - class CGTensor: def __init__(self, l1, l2, l3, normalization_factor, dtype): @@ -193,6 +198,12 @@ class ChildInstruction: def __init__(self, instruction_tup, parent_idx): self.instruction_tup, self.parent_idx = instruction_tup, parent_idx + class ChildView: + def __init__(self, layout: str, ir_mul_offset: int, ir_mul_stride: int): + self.layout = layout + self.ir_mul_offset = ir_mul_offset + self.ir_mul_stride = ir_mul_stride + def __init__(self, input, mult_threshold): self.input = input self.mult_threshold = mult_threshold @@ -201,6 +212,7 @@ def __init__(self, input, mult_threshold): child_reps = [[], [], []] self.irrep_maps = {} # Maps a (input_rep_idx #, mul_ir_idx) to a lst[ir_idx] + self.irrep_views = [[], [], []] # View for input_rep_idx, input_rep in enumerate(input_reps): # Loop over L1, L2, L3 for mul_ir_idx, mul_ir in enumerate( @@ -209,12 +221,27 @@ def __init__(self, input, mult_threshold): self.irrep_maps[input_rep_idx, mul_ir_idx] = [] for mul_start in range(0, mul_ir.mul, mult_threshold): mul = min(mult_threshold, mul_ir.mul - mul_start) - child_reps[input_rep_idx] += [ + child_reps[input_rep_idx].append( self.ChildIrrep((mul, mul_ir.ir), input_rep_idx, mul_start) - ] + ) self.irrep_maps[input_rep_idx, mul_ir_idx].append( len(child_reps[input_rep_idx]) - 1 ) + if input.layout == "mul_ir": + self.irrep_views[input_rep_idx].append( + self.ChildView( + layout="mul_ir", ir_mul_offset=-1, ir_mul_stride=-1 + ) + ) + elif input.layout == "ir_mul": + self.irrep_views[input_rep_idx].append( + self.ChildView( + layout="ir_mul", + ir_mul_offset=input_rep.slices()[mul_ir_idx].start + + mul_start, + ir_mul_stride=mul_ir.mul, + ) + ) new_instructions = [] @@ -274,6 +301,7 @@ def __init__(self, input, mult_threshold): path_normalization="none", internal_weights=False, shared_weights=input.shared_weights, + layout=input.layout, ) assert self.output.weight_numel == input.weight_numel @@ -543,9 +571,9 @@ def calculate_backward_smem( for i in range(len(self.segments)): L1_idxs, L2_idxs, L3_idxs, inst_idxs = self.segments[i] - L1Map = IrrepMapping(self.L1, L1_idxs) - L2Map = IrrepMapping(self.L2, L2_idxs) - L3Map = IrrepMapping(self.L3, L3_idxs) + L1Map = IrrepMapping(self.L1, self.problem_splitter.irrep_views[0], L1_idxs) + L2Map = IrrepMapping(self.L2, self.problem_splitter.irrep_views[1], L2_idxs) + L3Map = IrrepMapping(self.L3, self.problem_splitter.irrep_views[2], L3_idxs) instructions = [ ( @@ -568,6 +596,7 @@ def calculate_backward_smem( path_normalization="none", internal_weights=False, shared_weights=config.shared_weights, + layout=config.layout, ) weight_offset = 0 diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index a06b2c79..116a21b3 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -1,13 +1,10 @@ -import copy import numpy as np -from openequivariance.benchmark.random_buffer_utils import ( - get_random_buffers_forward_conv, + +from openequivariance.core.logging import bcolors, getLogger +from openequivariance.benchmark.test_buffers import ( get_random_buffers_backward_conv, - get_random_buffers_double_backward_conv, + get_random_buffers_forward_conv, ) - -from openequivariance.benchmark.logging_utils import getLogger, bcolors -from openequivariance.benchmark.correctness_utils import check_similiarity from openequivariance.core.e3nn_lite import wigner_3j from openequivariance.core.utils import benchmark @@ -134,94 +131,6 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): def name(): raise NotImplementedError() - def test_correctness_forward( - self, - graph, - thresh, - prng_seed, - reference_implementation=None, - check_reproducible=True, - high_precision_ref=False, - ): - if reference_implementation is None: - from openequivariance._torch.E3NNConv import E3NNConv - - reference_implementation = E3NNConv - - result = {"thresh": thresh} - - in1, in2, weights, out = get_random_buffers_forward_conv( - self.config, graph.node_count, graph.nnz, prng_seed - ) - ref_in1, ref_in2, ref_weights, ref_out = [ - buf.copy() for buf in [in1, in2, weights, out] - ] - - reference_config = self.config - if high_precision_ref: - reference_config = copy.deepcopy(self.config) - reference_config.irrep_dtype = np.float64 - reference_config.weight_dtype = np.float64 - ref_in1, ref_in2, ref_weights, ref_out = [ - np.array(el, dtype=np.float64) - for el in [ref_in1, ref_in2, ref_weights, ref_out] - ] - - args = { - "L1_in": ref_in1, - "L2_in": ref_in2, - "weights": ref_weights, - "rows": graph.rows, - "cols": graph.cols, - } - - ref_tp = reference_implementation(reference_config) - if ref_tp.deterministic: - args["transpose_perm"] = graph.transpose_perm - - for key in args: - args[key] = torch.tensor(args[key], device="cuda") - - ref_out[:] = ref_tp.forward(**args).cpu().numpy() - - test_out = out.copy() - self.forward_cpu( - L1_in=in1.copy(), - L2_in=in2.copy(), - weights=weights.copy(), - L3_out=test_out, - graph=graph, - ) - - for name, to_check, ground_truth in [("output", ref_out, test_out)]: - result[name] = check_similiarity(name, to_check, ground_truth, thresh) - - if check_reproducible: - num_trials = 5 - for name in ["output"]: - result[name]["num_reproducibility_trials"] = num_trials - result[name]["bitwise_reproducible"] = True - - for i in range(num_trials): - repeated_run = out.copy() - self.forward_cpu( - L1_in=in1.copy(), - L2_in=in2.copy(), - weights=weights.copy(), - L3_out=repeated_run, - graph=graph, - ) - - for name, to_check, ground_truth in [ - ("output", repeated_run, test_out) - ]: - result[name]["bitwise_reproducible"] = bool( - result[name]["bitwise_reproducible"] - and np.all(repeated_run == test_out) - ) - - return result - def benchmark_forward( self, num_warmup, num_iter, graph, prng_seed=12345, kernel_names=["forward"] ): @@ -379,159 +288,6 @@ def calculate_bench_stats( ) return result - def test_correctness_backward( - self, - graph, - thresh, - prng_seed, - reference_implementation=None, - high_precision_ref=False, - ): - if reference_implementation is None: - from openequivariance._torch.E3NNConv import E3NNConv - - reference_implementation = E3NNConv - - result = {"thresh": thresh} - - buffers = get_random_buffers_backward_conv( - self.config, graph.node_count, graph.nnz, prng_seed - ) - reference_buffers = [buf.copy() for buf in buffers] - reference_problem = self.config - - if high_precision_ref: - reference_problem = copy.deepcopy(self.config) - reference_problem.irrep_dtype = np.float64 - reference_problem.weight_dtype = np.float64 - reference_buffers = [ - np.array(el, dtype=np.float64) for el in reference_buffers - ] - - ref_tp = reference_implementation(reference_problem) - in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = buffers - ( - ref_in1, - ref_in2, - ref_out_grad, - ref_weights, - ref_weights_grad, - ref_in1_grad, - ref_in2_grad, - ) = reference_buffers - - ref_tp.backward_cpu( - L1_in=ref_in1, - L1_grad=ref_in1_grad, - L2_in=ref_in2, - L2_grad=ref_in2_grad, - L3_grad=ref_out_grad, - weights=ref_weights, - weights_grad=ref_weights_grad, - graph=graph, - ) - - # run test version - test_weights_grad = weights_grad.copy() - test_in1_grad = in1_grad.copy() - test_in2_grad = in2_grad.copy() - - self.backward_cpu( - L1_in=in1.copy(), - L1_grad=test_in1_grad, - L2_in=in2.copy(), - L2_grad=test_in2_grad, - L3_grad=out_grad.copy(), - weights=weights.copy(), - weights_grad=test_weights_grad, - graph=graph, - ) - - for name, to_check, ground_truth, threshold in [ - ("weight_grad", test_weights_grad, ref_weights_grad, thresh), - ("in1_grad", test_in1_grad, ref_in1_grad, thresh), - ("in2_grad", test_in2_grad, ref_in2_grad, thresh), - ]: - result[name] = check_similiarity(name, to_check, ground_truth, threshold) - - return result - - def test_correctness_double_backward( - self, - graph, - thresh, - prng_seed, - reference_implementation=None, - high_precision_ref=False, - ): - buffers = get_random_buffers_double_backward_conv( - self.config, graph.node_count, graph.nnz, prng_seed - ) - - if reference_implementation is None: - from openequivariance._torch.E3NNConv import E3NNConv - - reference_implementation = E3NNConv - - reference_problem = self.config - if high_precision_ref: - reference_problem = copy.deepcopy(self.config) - reference_problem.irrep_dtype = np.float64 - reference_problem.weight_dtype = np.float64 - - reference_tp = reference_implementation(reference_problem, torch_op=True) - - result = {"thresh": thresh} - tensors = [] - for i, tp in enumerate([self, reference_tp]): - buffers_copy = [buf.copy() for buf in buffers] - - if i == 1 and high_precision_ref: - buffers_copy = [np.array(el, dtype=np.float64) for el in buffers] - - in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = ( - buffers_copy - ) - - weights_reordered = tp.reorder_weights_from_e3nn( - weights, not tp.config.shared_weights - ) - weights_dgrad_reordered = tp.reorder_weights_from_e3nn( - weights_dgrad, not tp.config.shared_weights - ) - - in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( - in1, - in2, - out_grad, - weights_reordered, - weights_dgrad_reordered, - in1_dgrad, - in2_dgrad, - graph, - ) - - tensors.append( - ( - out_dgrad, - in1_grad, - in2_grad, - tp.reorder_weights_to_e3nn( - weights_grad, has_batch_dim=not self.config.shared_weights - ), - ) - ) - - for name, to_check, ground_truth in [ - ("output_grad", tensors[0][0], tensors[1][0]), - ("in1_grad", tensors[0][1], tensors[1][1]), - ("in2_grad", tensors[0][2], tensors[1][2]), - ("weights_grad", tensors[0][3], tensors[1][3]), - ]: - result[name] = check_similiarity(name, to_check, ground_truth, thresh) - - return result - def scatter_add_wrapper(messages, rows, target_dim): L3_dim = messages.size(1) diff --git a/openequivariance/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py index ca8b4bdd..b3c63e66 100644 --- a/openequivariance/openequivariance/core/LoopUnrollConv.py +++ b/openequivariance/openequivariance/core/LoopUnrollConv.py @@ -1,18 +1,18 @@ -import numpy as np import json -from openequivariance.core.ConvolutionBase import ConvolutionBase +import numpy as np + from openequivariance.core.ComputationSchedule import ( ComputationSchedule, SMEMCapacityException, ) - -from openequivariance.templates.jinja_utils import get_jinja_environment +from openequivariance.core.ConvolutionBase import ConvolutionBase from openequivariance.core.utils import ( - filter_and_analyze_problem, dtype_to_enum, + filter_and_analyze_problem, hash_str_64, ) +from openequivariance.templates.jinja_utils import get_jinja_environment class LoopUnrollConv(ConvolutionBase): @@ -114,9 +114,11 @@ def generate_double_backward_schedule(warps_per_block): except SMEMCapacityException: warp_count -= 1 if warp_count == 0: - raise SMEMCapacityException( + raise RuntimeError( "Tensor product schedule generation failed, shared memory inadequate!" ) + except Exception: + raise if not deterministic: for segment in self.forward_schedule.segments: diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 41354e5f..36801405 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -1,16 +1,17 @@ -import numpy as np import json -from openequivariance.templates.jinja_utils import get_jinja_environment -from openequivariance.core.ComputationSchedule import ComputationSchedule +from openequivariance.core.logging import getLogger +from openequivariance.core.ComputationSchedule import ( + ComputationSchedule, + SMEMCapacityException, +) from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.core.utils import dtype_to_enum, hash_str_64 - from openequivariance.core.utils import ( + dtype_to_enum, filter_and_analyze_problem, - count_cg_non_zero, + hash_str_64, ) +from openequivariance.templates.jinja_utils import get_jinja_environment logger = getLogger() @@ -80,12 +81,14 @@ def generate_double_backward_schedule(warps_per_block): try: generate_schedule(warp_count) break - except Exception: + except SMEMCapacityException: warp_count -= 2 if warp_count == 0: raise RuntimeError( "Tensor product schedule generation failed, shared memory inadequate!" ) + except Exception: + raise self.jit_kernel = postprocess_kernel( template.render( @@ -123,53 +126,3 @@ def generate_double_backward_schedule(warps_per_block): ) self.hash = hash_str_64(self.kernel_string) logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") - - def calculate_flops_forward(self, batch_size: int) -> dict: - if self.is_uvw: - return super().calculate_flops_forward(batch_size) - else: - tpp = self.config - flop_count = { - "CG_decomposition": 0, - "linear_combination": 0, - "outer_products": 0, - } - for ins in tpp.instructions: - l1, l2, l3 = ( - tpp.irreps_in1[ins.i_in1].ir.l, - tpp.irreps_in2[ins.i_in2].ir.l, - tpp.irreps_out[ins.i_out].ir.l, - ) - flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( - ins.path_shape[0] * ins.path_shape[1] - ) - flop_count["linear_combination"] += ( - (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 - ) - - flop_count["CG_decomposition"] *= 3 * batch_size - flop_count["linear_combination"] *= ( - batch_size # Weights do not require FMA here - ) - flop_count["total"] = sum(flop_count.values()) - return flop_count - - def calculate_flops_backward(self, batch_size: int) -> dict: - if self.is_uvw: - return super().calculate_flops_backward(batch_size) - else: - tpp = self.config - flop_count = {"backward": 0} - for ins in tpp.instructions: - l1, l2, l3 = ( - tpp.irreps_in1[ins.i_in1].ir.l, - tpp.irreps_in2[ins.i_in2].ir.l, - tpp.irreps_out[ins.i_out].ir.l, - ) - flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * ( - ins.path_shape[0] * ins.path_shape[1] - ) - - flop_count["backward"] *= 9 * batch_size - flop_count["total"] = sum(flop_count.values()) - return flop_count diff --git a/openequivariance/openequivariance/core/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py index b5d3831f..c6fc83f8 100644 --- a/openequivariance/openequivariance/core/TensorProductBase.py +++ b/openequivariance/openequivariance/core/TensorProductBase.py @@ -1,7 +1,7 @@ import numpy as np from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.logging import getLogger from openequivariance.core.utils import benchmark logger = getLogger() @@ -180,21 +180,3 @@ def benchmark_double_backward( mode=mode, kernel_names=kernel_names, ) - - def calculate_memory_streamed_forward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_memory_streamed_backward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_memory_streamed_double_backward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_flops_forward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_flops_backward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_flops_double_backward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") diff --git a/openequivariance/openequivariance/core/e3nn_lite.py b/openequivariance/openequivariance/core/e3nn_lite.py index d569ad1c..41cc6e59 100644 --- a/openequivariance/openequivariance/core/e3nn_lite.py +++ b/openequivariance/openequivariance/core/e3nn_lite.py @@ -36,14 +36,15 @@ SOFTWARE. """ -import itertools -from typing import Tuple, NamedTuple, Union, List, Any, Optional -from math import sqrt, prod import collections +import copy +import functools +import itertools +from math import prod, sqrt +from typing import Any, List, NamedTuple, Optional, Tuple, Union + import numpy as np import numpy.linalg as la -import functools -import copy def perm_inverse(p): @@ -385,6 +386,7 @@ class TPProblem: :param internal_weights: Must be False; OpenEquivariance does not support internal weights. *Default*: False. :param irrep_normalization: One of ``["component", "norm", "none"]``. *Default*: "component". :param path_normalization: One of ``["element", "path", "none"]``. *Default*: "element". + :param layout: One of ``["mul_ir", "ir_mul"]``, giving the layout of irreps for all inputs and outputs. *Default*: "mul_ir". """ instructions: List[Any] @@ -392,9 +394,9 @@ class TPProblem: internal_weights: bool weight_numel: int label: str - _profiling_str: str _in1_dim: int _in2_dim: int + layout: str def __init__( self, @@ -412,12 +414,14 @@ def __init__( label: Optional[str] = None, irrep_dtype: type[np.generic] = np.float32, weight_dtype: type[np.generic] = np.float32, + layout: str = "mul_ir", ) -> None: # === Setup === super().__init__() assert irrep_normalization in ["component", "norm", "none"] assert path_normalization in ["element", "path", "none"] + assert layout in ["mul_ir", "ir_mul"] assert issubclass(irrep_dtype, np.generic) assert issubclass(weight_dtype, np.generic) @@ -432,6 +436,7 @@ def __init__( self.irrep_normalization = irrep_normalization self.path_normalization = path_normalization self.label = label if label is not None else "" + self.layout = layout del irreps_in1, irreps_in2, irreps_out instructions = [x if len(x) == 6 else x + (1.0,) for x in instructions] diff --git a/openequivariance/openequivariance/benchmark/logging_utils.py b/openequivariance/openequivariance/core/logging.py similarity index 100% rename from openequivariance/openequivariance/benchmark/logging_utils.py rename to openequivariance/openequivariance/core/logging.py diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 1950013d..53638422 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -1,15 +1,13 @@ import functools +import hashlib +import json import math +import tempfile +from enum import IntEnum import numpy as np -from openequivariance.core.e3nn_lite import Instruction, TPProblem, wigner_3j - -import json -import tempfile -import hashlib - -from enum import IntEnum +from openequivariance.core.e3nn_lite import Instruction, Irreps, TPProblem, wigner_3j class DTypeEnum(IntEnum): @@ -98,6 +96,11 @@ def filter_and_analyze_problem(problem): f"Connection mode must be 'uvu' or 'uvw', got {problem.instructions[0].connection_mode}" ) + if problem.layout == "ir_mul": + assert problem.instructions[0].connection_mode == "uvu", ( + "layout='ir_mul' is only supported for pure 'uvu' problems" + ) + assert problem.irrep_dtype == problem.weight_dtype, ( f"irrep_dtype and weight_dtype must be the same, got {problem.irrep_dtype} and {problem.weight_dtype}" ) @@ -168,7 +171,7 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): time_millis[i] = timer.stop_clock_get_elapsed() else: - from torch.profiler import profile, record_function, ProfilerActivity + from torch.profiler import ProfilerActivity, profile, record_function trace_file = tempfile.NamedTemporaryFile().name @@ -204,3 +207,54 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): def hash_str_64(s: str) -> int: return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big") + + +def transpose_irrep_layout( + array: np.ndarray, + irreps: Irreps, + src_layout: str, + dst_layout: str, +) -> np.ndarray: + """ + Transpose irrep-packed feature arrays between `mul_ir` and `ir_mul` layouts. + + Expected input shape is `[..., irreps.dim]`. A new array is returned. + If `src_layout == dst_layout`, this returns a copy. + """ + if src_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported src_layout: {src_layout}") + if dst_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported dst_layout: {dst_layout}") + + x = np.asarray(array) + out = np.empty_like(x) + + if src_layout == dst_layout: + out[...] = x + return out + + slices = irreps.slices() + for ir_idx, mul_ir in enumerate(irreps): + mul = mul_ir.mul + dim = mul_ir.ir.dim + seg = slices[ir_idx] + block = x[..., seg.start : seg.stop] + + if src_layout == "ir_mul" and dst_layout == "mul_ir": + out[..., seg.start : seg.stop] = ( + block.reshape(*block.shape[:-1], dim, mul) + .swapaxes(-1, -2) + .reshape(*block.shape[:-1], mul * dim) + ) + elif src_layout == "mul_ir" and dst_layout == "ir_mul": + out[..., seg.start : seg.stop] = ( + block.reshape(*block.shape[:-1], mul, dim) + .swapaxes(-1, -2) + .reshape(*block.shape[:-1], dim * mul) + ) + else: + raise ValueError( + f"Unsupported layout transpose: {src_layout} -> {dst_layout}" + ) + + return out diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 5ddcb7f5..440b7640 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -89,6 +89,7 @@ Stream get_current_stream() { #ifdef INCLUDE_NB_EXTENSION #include "nanobind/nanobind.h" + #include "nanobind/stl/string.h" namespace nb = nanobind; NB_MODULE(EXTENSION_NAME, m) { nb::class_(m, "DeviceProp") diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index c14637a1..7101fa00 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -9,7 +9,7 @@ from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance.jax.utils import reorder_jax -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.logging import getLogger from openequivariance.jax.jvp import conv_prim from openequivariance.jax.vjp import conv_func diff --git a/openequivariance/openequivariance/jax/__init__.py b/openequivariance/openequivariance/jax/__init__.py index 410e5dbf..c26606b6 100644 --- a/openequivariance/openequivariance/jax/__init__.py +++ b/openequivariance/openequivariance/jax/__init__.py @@ -1,6 +1,81 @@ +import jax +import jax.numpy as jnp + +from openequivariance.core.e3nn_lite import Irreps from openequivariance.jax.TensorProduct import TensorProduct as TensorProduct from openequivariance.jax.TensorProductConv import ( TensorProductConv as TensorProductConv, ) -__all__ = ["TensorProduct", "TensorProductConv"] + +def transpose_irreps( + array: jax.Array, + irreps: Irreps, + src_layout: str, + dst_layout: str, +) -> jax.Array: + r""" + Transpose irrep-packed feature arrays between ``mul_ir`` and ``ir_mul`` layouts. + + The function operates on the trailing feature dimension and preserves all leading + batch dimensions. It uses differentiable JAX operations, so gradients propagate + through the transpose. + + :param array: Input feature array with shape ``[..., irreps.dim]``. + :type array: jax.Array + :param irreps: Irreps specification describing how the trailing feature dimension + is partitioned into irrep blocks. + :type irreps: Irreps + :param src_layout: Source layout. Must be either ``"mul_ir"`` or ``"ir_mul"``. + :type src_layout: str + :param dst_layout: Destination layout. Must be either ``"mul_ir"`` or ``"ir_mul"``. + :type dst_layout: str + + :returns: Array in ``dst_layout`` with the same shape, dtype, and device as ``array``. + If ``src_layout == dst_layout``, returns a copy of ``array``. + :rtype: jax.Array + + :raises ValueError: If ``src_layout`` or ``dst_layout`` is not one of + ``"mul_ir"`` or ``"ir_mul"``. + """ + if src_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported src_layout: {src_layout}") + if dst_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported dst_layout: {dst_layout}") + + x = jnp.asarray(array) + if src_layout == dst_layout: + return jnp.array(x, copy=True) + + out = jnp.empty_like(x) + slices = irreps.slices() + + for ir_idx, mul_ir in enumerate(irreps): + mul = mul_ir.mul + dim = mul_ir.ir.dim + seg = slices[ir_idx] + block = x[..., seg.start : seg.stop] + + if src_layout == "ir_mul" and dst_layout == "mul_ir": + transposed = ( + block.reshape(*block.shape[:-1], dim, mul) + .swapaxes(-1, -2) + .reshape(*block.shape[:-1], mul * dim) + ) + elif src_layout == "mul_ir" and dst_layout == "ir_mul": + transposed = ( + block.reshape(*block.shape[:-1], mul, dim) + .swapaxes(-1, -2) + .reshape(*block.shape[:-1], dim * mul) + ) + else: + raise ValueError( + f"Unsupported layout transpose: {src_layout} -> {dst_layout}" + ) + + out = out.at[..., seg.start : seg.stop].set(transposed) + + return out + + +__all__ = ["TensorProduct", "TensorProductConv", "transpose_irreps"] diff --git a/openequivariance/openequivariance/templates/loop_unroll_tp.cuh b/openequivariance/openequivariance/templates/loop_unroll_tp.cuh index eab95c57..e4c8dabe 100644 --- a/openequivariance/openequivariance/templates/loop_unroll_tp.cuh +++ b/openequivariance/openequivariance/templates/loop_unroll_tp.cuh @@ -1,4 +1,4 @@ -{%- from 'macros.jinja' import transpose_load, transpose_store, reg_store with context %} +{%- from 'macros.jinja' import layout_load, layout_store, reg_store with context %} {%- from 'wmm.cuh' import generate_matmul %} {%- macro generate_segment_kernel_forward(id, segment, warp_size) %} @@ -36,7 +36,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ {%- if k == 0 or interactions[k][0] != interactions[k-1][0] %} offset = {{ L1.slices()[u].start}}; - {{transpose_load(L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}} + {{layout_load(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}} {%- endif %} #pragma unroll @@ -72,7 +72,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ // ----------------- CORE CALCULATION ----------------- {%- if problem.instructions[k].connection_mode == "uvw" %} - {{transpose_store(L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_vec', '=', '1.0')}} + {{layout_store(problem.layout, L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_vec', '=', '1.0')}} __syncwarp(); offset = {{ L3.slices()[w].start}}; matmul_fwd_{{id}}_{{k}}(weights_smem, scratch, L3_smem + offset); @@ -85,7 +85,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ {%- if problem.instructions[k].connection_mode != "uvw" %} offset = {{ L3.slices()[w].start}}; - {{transpose_store(L3[w].mul, L3[w].ir.dim, 'L3_smem', 'offset', 'l3_vec', '+=', '1.0')}} + {{layout_store(problem.layout, L3[w].mul, L3[w].ir.dim, 'L3_smem', 'offset', 'l3_vec', '+=', '1.0')}} {%- if L2[v].mul > 1%} #pragma unroll @@ -168,15 +168,15 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ {%- if k == 0 or interactions[k][0] != interactions[k-1][0] %} offset = {{ L1.slices()[u].start}}; - {{transpose_load(L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}} - {{transpose_load(L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad')}} + {{layout_load(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}} + {{layout_load(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad')}} {%- endif %} {%- if problem.instructions[k].connection_mode != "uvw" %} {%- if k == 0 or interactions[k][2] != interactions[k-1][2] %} offset = {{ L3.slices()[w].start}}; - {{transpose_load(L3[w].mul, L3[w].ir.dim, 'L3_grad_smem', 'offset', 'l3_grad')}} + {{layout_load(problem.layout, L3[w].mul, L3[w].ir.dim, 'L3_grad_smem', 'offset', 'l3_grad')}} {%- endif %} {%- endif %} @@ -225,7 +225,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ {{matmul_basename}}A_{{id}}_{{k}}(weights_smem, L3_grad_smem + offset, scratch); __syncwarp(); - {{transpose_load(L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_grad')}} + {{layout_load(problem.layout, L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_grad')}} {%- for i in range(tensor.nnz) %} {%- set coord1, coord2, coord3, value = tensor.tuples[i] %} @@ -305,7 +305,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ // Storeback {%- if k == num_interact - 1 or interactions[k][0] != interactions[k+1][0] %} offset = {{ L1.slices()[u].start}}; - {{transpose_store(L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad', '=', '1.0')}} + {{layout_store(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad', '=', '1.0')}} {%- endif %} {%- endfor %} diff --git a/openequivariance/openequivariance/templates/macros.jinja b/openequivariance/openequivariance/templates/macros.jinja index 9b0d3c28..f9108822 100644 --- a/openequivariance/openequivariance/templates/macros.jinja +++ b/openequivariance/openequivariance/templates/macros.jinja @@ -50,45 +50,110 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. } {%- endmacro %} +{%- macro layout_load(layout, mul, dim, smem, offset, reg) %} + {%- if layout == "ir_mul" %} + {{ reg_load(mul, dim, smem, offset, reg) }} + {%- else %} + {{ transpose_load(mul, dim, smem, offset, reg) }} + {%- endif %} +{%- endmacro %} + +{%- macro layout_store(layout, mul, dim, smem, offset, reg, op, coeff) %} + {%- if layout == "ir_mul" %} + {{ reg_store(mul, dim, smem, offset, reg, op, coeff) }} + {%- else %} + {{ transpose_store(mul, dim, smem, offset, reg, op, coeff) }} + {%- endif %} +{%- endmacro %} + {%- macro declare_smem_variables(segment, smem_base) %} {%- for name in segment.smem %} {%- if name != "total" %} {%- set smem_rng = segment.smem[name] %} - {{ smem_rng["dtype"] }}* {{name}}_smem = ({{smem_rng["dtype"]}}*) ({{smem_base}} + {{smem_rng["offset"]}}); + {{ smem_rng["dtype"] }}* {{name}}_smem = ({{smem_rng["dtype"]}}*) ({{smem_base}} + {{smem_rng["offset"]}}); {%- endif %} {%- endfor %} {%- endmacro %} {%- macro load_ir_segments(map, glb_ptr_shft, smem_ptr, loop_var) %} {%- if not map.persist_load %} - {%- for (src_rng, dst_rng) in map.copy_ranges %} - {%- set range_len = src_rng.stop - src_rng.start %} - ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];) - {%- endfor %} + {%- if map.src_views[0].layout == "mul_ir" %} + {%- for (src_rng, dst_rng) in map.copy_ranges %} + {%- set range_len = src_rng.stop - src_rng.start %} + ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];) + {%- endfor %} + {%- elif map.src_views[0].layout == "ir_mul" %} + {%- for idx in map.idxs %} + {%- set m = namespace( + src_view=map.src_views[idx], + src_mul_ir=map.src_irreps[idx], + dst_rng=map.dst_irreps.slices()[map.src_dst_map[idx]] + ) %} + {%- for i in range(m.src_mul_ir.ir.dim) %} + ROW_OPERATION({{m.src_mul_ir.mul}}, {{loop_var}}, {{smem_ptr}}[{{m.dst_rng.start + i * m.src_mul_ir.mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{m.src_view.ir_mul_offset + i * m.src_view.ir_mul_stride}} + {{loop_var}}];) + {%- endfor %} + {%- endfor %} + {%- endif %} {%- endif %} {%- endmacro %} {%- macro load_ir_segments_force(map, glb_ptr_shft, smem_ptr, loop_var) %} - {%- for (src_rng, dst_rng) in map.copy_ranges %} - {%- set range_len = src_rng.stop - src_rng.start %} - ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];) - {%- endfor %} + {%- if map.src_views[0].layout == "mul_ir" %} + {%- for (src_rng, dst_rng) in map.copy_ranges %} + {%- set range_len = src_rng.stop - src_rng.start %} + ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];) + {%- endfor %} + {%- elif map.src_views[0].layout == "ir_mul" %} + {%- for idx in map.idxs %} + {%- set m = namespace( + src_view=map.src_views[idx], + src_mul_ir=map.src_irreps[idx], + dst_rng=map.dst_irreps.slices()[map.src_dst_map[idx]] + ) %} + {%- for i in range(m.src_mul_ir.ir.dim) %} + ROW_OPERATION({{m.src_mul_ir.mul}}, {{loop_var}}, {{smem_ptr}}[{{m.dst_rng.start + i * m.src_mul_ir.mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{m.src_view.ir_mul_offset + i * m.src_view.ir_mul_stride}} + {{loop_var}}];) + {%- endfor %} + {%- endfor %} + {%- endif %} {%- endmacro %} {%- macro store_ir_segments(map, glb_ptr_shft, smem_ptr, loop_var) %} {%- if not map.persist_store %} - {%- for i, src_rng in enumerate(map.original_src_ranges) %} - {%- set idx = map.idxs[i] %} - {%- set dst_rng = map.original_dst_ranges[i] %} - {%- set range_len = src_rng.stop - src_rng.start %} - {%- if map.storeback_procedure[idx] == "write" %} - ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] = {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];) - {%- elif map.storeback_procedure[idx] == "accumulate" %} - ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] += {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];) - {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %} - ROW_OPERATION({{range_len}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_rng.start}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start}} + lane_id + {{loop_var}}]);) - {%- endif %} - {%- endfor %} + {%- if map.src_views[0].layout == "mul_ir" %} + {%- for i, src_rng in enumerate(map.original_src_ranges) %} + {%- set idx = map.idxs[i] %} + {%- set dst_rng = map.original_dst_ranges[i] %} + {%- set range_len = src_rng.stop - src_rng.start %} + {%- if map.storeback_procedure[idx] == "write" %} + ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] = {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];) + {%- elif map.storeback_procedure[idx] == "accumulate" %} + ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] += {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];) + {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %} + ROW_OPERATION({{range_len}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_rng.start}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start}} + lane_id + {{loop_var}}]);) + {%- endif %} + {%- endfor %} + {%- elif map.src_views[0].layout == "ir_mul" %} + {%- for idx in map.idxs %} + {%- set m = namespace( + src_view=map.src_views[idx], + src_mul_ir=map.src_irreps[idx], + dst_rng=map.dst_irreps.slices()[map.src_dst_map[idx]] + ) %} + {%- if map.storeback_procedure[idx] == "write" %} + {%- for i in range(m.src_mul_ir.ir.dim) %} + ROW_OPERATION({{m.src_mul_ir.mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{m.src_view.ir_mul_offset + i * m.src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{m.dst_rng.start + i * m.src_mul_ir.mul}} + {{loop_var}} + lane_id];) + {%- endfor %} + {%- elif map.storeback_procedure[idx] == "accumulate" %} + {%- for i in range(m.src_mul_ir.ir.dim) %} + ROW_OPERATION({{m.src_mul_ir.mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{m.src_view.ir_mul_offset + i * m.src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{m.dst_rng.start + i * m.src_mul_ir.mul}} + {{loop_var}} + lane_id];) + {%- endfor %} + {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %} + {%- for i in range(m.src_mul_ir.ir.dim) %} + ROW_OPERATION({{m.src_mul_ir.mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{m.src_view.ir_mul_offset + i * m.src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{m.dst_rng.start + i * m.src_mul_ir.mul}} + {{loop_var}} + lane_id]);) + {%- endfor %} + {%- endif %} + {%- endfor %} + {%- endif %} {% endif %} {%- endmacro %} @@ -122,7 +187,7 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. {{smem_ptr}}[{{offset}} + lane_id + {{i * mul}}] = t_regs[{{i}}]; {%- endfor %} } - } {%- endfor %} + } {%- endfor %} {%- endmacro %} {%- macro transpose_smem_B(irreps, smem_ptr) %} @@ -134,14 +199,14 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. if(lane_id < {{mul}}) { {%- set offset = slices[i].start %} {%- for i in range(dim) %} - t_regs[{{i}}] = {{smem_ptr}}[{{offset}} + lane_id + {{i * mul}}]; + t_regs[{{i}}] = {{smem_ptr}}[{{offset}} + lane_id + {{i * mul}}]; {%- endfor %} __syncwarp(); {%- for i in range(dim) %} {{smem_ptr}}[{{offset}} + lane_id * {{dim}} + {{i}}] = t_regs[{{i}}]; {%- endfor %} } - } {%- endfor %} + } {%- endfor %} {%- endmacro %} {%- macro reg_load(mul, dim, smem, offset, reg) %} diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 1d0e4394..40856819 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance" -version = "0.6.4" +version = "0.6.5" authors = [ { name="Austin Glover" }, { name="Vivek Bharadwaj" }, diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index 4a53ff70..ba313d23 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance_extjax" -version = "0.6.4" +version = "0.6.5" authors = [ { name="Austin Glover" }, diff --git a/tests/batch_test.py b/tests/batch_test.py index 55396ded..ff1cd1ce 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -1,22 +1,23 @@ -import pytest -from pytest_check import check +from itertools import product import numpy as np -import openequivariance as oeq -from openequivariance.benchmark.correctness_utils import ( - correctness_forward, +import pytest +import torch +from openequivariance.benchmark.correctness import ( correctness_backward, correctness_double_backward, + correctness_forward, ) - +from openequivariance.benchmark.test_buffers import get_random_buffers_forward from openequivariance.benchmark.problems import ( - e3nn_torch_tetris_poly_problems, diffdock_problems, + e3nn_torch_tetris_poly_problems, mace_problems, nequip_problems, ) -from itertools import product -import torch +from pytest_check import check + +import openequivariance as oeq @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module") @@ -273,6 +274,80 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): return tp, tp.config +class TestIrMul(TPCorrectness): + """ + Tests both the ir_mul layout and the transpose_irreps functions. + """ + + tpps = mace_problems() + [ + oeq.TPProblem( + "5x5e", + "1x3e", + "5x5e", + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + label="ir_mul_repr_5x1x5_l535", + ), + oeq.TPProblem( + "13x5e", + "1x3e", + "13x5e", + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + label="ir_mul_repr_13x1x13_l535", + ), + ] + + @pytest.fixture(params=tpps, ids=lambda x: x.label, scope="class") + def problem(self, request, dtype): + problem = request.param.clone() + problem.irrep_dtype, problem.weight_dtype = dtype, dtype + problem.layout = "ir_mul" + return problem + + @pytest.fixture(params=["native", "transpose_wrapper"], scope="class") + def tp_and_problem(self, request, problem, extra_tp_constructor_args, with_jax): + mode = request.param + + if with_jax: + import openequivariance.jax.TensorProduct as jax_tp + from openequivariance.jax import transpose_irreps + + tp_base_cls = jax_tp + else: + from openequivariance import transpose_irreps + + tp_base_cls = oeq.TensorProduct + + if mode == "native": + tp = tp_base_cls(problem, **extra_tp_constructor_args) + return tp, problem + else: + + class TransposeWrapperTensorProduct(tp_base_cls): + def forward(self, x, y, W): + x_t = transpose_irreps( + x, self.config.irreps_in1, "ir_mul", "mul_ir" + ) + y_t = transpose_irreps( + y, self.config.irreps_in2, "ir_mul", "mul_ir" + ) + out_mul_ir = super().forward(x_t, y_t, W) + return transpose_irreps( + out_mul_ir, self.config.irreps_out, "mul_ir", "ir_mul" + ) + + wrapped_problem = problem.clone() + wrapped_problem.layout = "mul_ir" + tp = TransposeWrapperTensorProduct( + wrapped_problem, **extra_tp_constructor_args + ) + tp.config.layout = "ir_mul" + return tp, problem + + class TestTorchToSubmodule: """Test that TensorProduct works correctly as a submodule when parent's .to() is called""" @@ -298,40 +373,29 @@ def forward(self, x, y, w): def _problem_dtype(self, problem): return torch.float32 if problem.irrep_dtype == np.float32 else torch.float64 - def _make_inputs(self, problem, batch_size, rng, dtype, device): - in1 = torch.tensor( - rng.uniform(size=(batch_size, problem.irreps_in1.dim)), - dtype=dtype, - device=device, - ) - in2 = torch.tensor( - rng.uniform(size=(batch_size, problem.irreps_in2.dim)), - dtype=dtype, - device=device, - ) - weights_size = ( - (problem.weight_numel,) - if problem.shared_weights - else (batch_size, problem.weight_numel) - ) - weights = torch.tensor( - rng.uniform(size=weights_size), - dtype=dtype, - device=device, + def _make_inputs(self, problem, batch_size, dtype, device, prng_seed=12345): + dtype_map = {torch.float32: np.float32, torch.float64: np.float64} + buffer_problem = problem.clone() + buffer_problem.irrep_dtype = dtype_map[dtype] + buffer_problem.weight_dtype = dtype_map[dtype] + + in1_np, in2_np, weights_np, _ = get_random_buffers_forward( + buffer_problem, batch_size=batch_size, prng_seed=prng_seed ) - return in1, in2, weights + + return [ + torch.tensor(arr, dtype=dtype, device=device) + for arr in [in1_np, in2_np, weights_np] + ] def test_submodule_dtype_conversion(self, parent_module_and_problem): """Test that calling .to() on parent module properly converts TensorProduct submodule""" parent, problem = parent_module_and_problem batch_size = 10 - rng = np.random.default_rng(12345) device = "cuda" input_dtype = self._problem_dtype(problem) - in1, in2, weights = self._make_inputs( - problem, batch_size, rng, input_dtype, device - ) + in1, in2, weights = self._make_inputs(problem, batch_size, input_dtype, device) output1 = parent(in1, in2, weights) assert output1.dtype == in1.dtype, ( @@ -346,7 +410,7 @@ def test_submodule_dtype_conversion(self, parent_module_and_problem): parent.to(target_dtype) in1_new, in2_new, weights_new = self._make_inputs( - problem, batch_size, rng, target_dtype, device + problem, batch_size, target_dtype, device, prng_seed=23456 ) output2 = parent(in1_new, in2_new, weights_new) diff --git a/tests/benchmark.py b/tests/benchmark.py index 829cc46c..9d230cf1 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -10,7 +10,7 @@ import numpy as np -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.logging import getLogger from openequivariance._torch.extlib import DeviceProp from openequivariance._torch.E3NNTensorProduct import ( E3NNTensorProduct, @@ -24,7 +24,7 @@ TestDefinition, Direction, ) -from openequivariance.benchmark.tpp_creation_utils import ( +from openequivariance.benchmark.problems import ( ChannelwiseTPP, FullyConnectedTPProblem, SingleInstruction, @@ -385,6 +385,47 @@ def benchmark_kahan_accuracy(params): ) +def benchmark_layouts(params): + base_problems = mace_problems() + nequip_problems() + directions = params.directions + dtypes = [datatype_map[dtype] for dtype in params.datatypes] + + tests = [] + for dtype in dtypes: + for base_problem in base_problems: + for layout in ["mul_ir", "ir_mul"]: + layout_problem = copy.deepcopy(base_problem) + layout_problem.layout = layout + base_label = layout_problem.label or "layout_problem" + layout_problem.label = f"{base_label} [{layout}]" + layout_problem.irrep_dtype = dtype + layout_problem.weight_dtype = dtype + + for direction in directions: + tests.append( + TestDefinition( + TensorProduct, + layout_problem, + direction, + correctness=False, + benchmark=True, + ) + ) + + bench_suite = TestBenchmarkSuite( + num_warmup=100, + num_iter=100, + bench_batch_size=params.batch_size, + prng_seed=11111, + test_name="layouts", + ) + + data_folder = bench_suite.run(tests, params.output_folder) + + if params.plot: + plot({"data_folder": data_folder}) + + def plot(params): import openequivariance.benchmark.plotting as plotting @@ -402,6 +443,8 @@ def plot(params): plotting.plot_uvu(data_folder) elif test_name == "uvw": plotting.plot_uvw(data_folder) + elif test_name == "layouts": + plotting.plot_layout(data_folder) elif test_name == "roofline": plotting.plot_roofline(data_folder) elif test_name == "convolution": @@ -532,6 +575,33 @@ def plot(params): parser_uvw.add_argument("--plot", action="store_true", help="Plot the results.") parser_uvw.set_defaults(func=benchmark_uvw) + parser_layouts = subparsers.add_parser( + "layouts", help="Run benchmark comparing mul_ir vs ir_mul layouts" + ) + parser_layouts.add_argument( + "--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark" + ) + parser_layouts.add_argument( + "--directions", + "-d", + type=str, + nargs="+", + default=["forward", "backward"], + help="Directions to benchmark", + choices=["forward", "backward"], + ) + parser_layouts.add_argument( + "--datatypes", + "-t", + type=str, + nargs="+", + default=["float32", "float64"], + help="Data types to benchmark", + choices=["float32", "float64"], + ) + parser_layouts.add_argument("--plot", action="store_true", help="Plot the results.") + parser_layouts.set_defaults(func=benchmark_layouts) + parser_double_bwd = subparsers.add_parser( "double_backward", help="Run the higher derivative kernel benchmark" ) diff --git a/tests/conv_test.py b/tests/conv_test.py index 50b6376b..8471e593 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -6,6 +6,11 @@ import numpy as np import openequivariance as oeq from openequivariance.benchmark.ConvBenchmarkSuite import load_graph +from openequivariance.benchmark.correctness import ( + correctness_backward_conv, + correctness_double_backward_conv, + correctness_forward_conv, +) from itertools import product import torch @@ -89,7 +94,8 @@ def test_tp_fwd(self, conv_object, graph): if conv_object is None: pytest.skip("'conv_object' fixture returned None, skipping") - result = conv_object.test_correctness_forward( + result = correctness_forward_conv( + conv_object, graph, thresh=self.thresh("fwd"), prng_seed=12345, @@ -102,7 +108,8 @@ def test_tp_bwd(self, conv_object, graph): if conv_object is None: pytest.skip("'conv_object' fixture returned None, skipping") - result = conv_object.test_correctness_backward( + result = correctness_backward_conv( + conv_object, graph, thresh=self.thresh("bwd"), prng_seed=12345, @@ -117,7 +124,8 @@ def test_tp_double_bwd(self, conv_object, graph): if conv_object is None: pytest.skip("'conv_object' fixture returned None, skipping") - result = conv_object.test_correctness_double_backward( + result = correctness_double_backward_conv( + conv_object, graph, thresh=self.thresh("double_bwd"), prng_seed=12345, @@ -284,6 +292,36 @@ def conv_object(self, request, problem, extra_conv_constructor_args): return module.to(switch_map[problem.irrep_dtype]) +class TestIrMulLayout(ConvCorrectness): + production_model_tpps = mace_problems() + [ + oeq.TPProblem( + "5x5e", + "1x3e", + "5x5e", + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + label="ir_mul_repr_5x1x5_l535", + ), + oeq.TPProblem( + "13x5e", + "1x3e", + "13x5e", + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + label="ir_mul_repr_13x1x13_l535", + ), + ] + + @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") + def problem(self, request, dtype): + problem = request.param.clone() + problem.irrep_dtype, problem.weight_dtype = dtype, dtype + problem.layout = "ir_mul" + return problem + + class TestTorchToSubmodule: """Test that TensorProductConv works as a submodule when parent's .to() is called""" diff --git a/tests/input_validation_test.py b/tests/input_validation_test.py index 8db1d2a4..9b38d55e 100644 --- a/tests/input_validation_test.py +++ b/tests/input_validation_test.py @@ -138,3 +138,18 @@ def test_cpp_checks_forward_dtype(executable_and_buffers, subtests): with pytest.raises(RuntimeError, match=r"Dtype mismatch"): buffers[i] = buffers[i].to(dtype=torch.bfloat16) executable(*buffers) + + +def test_ir_mul_rejects_uvw_problem(): + problem = TPProblem( + "5x5e", + "1x3e", + "5x5e", + [(0, 0, 0, "uvw", True)], + shared_weights=False, + internal_weights=False, + layout="ir_mul", + ) + + with pytest.raises(AssertionError, match="layout='ir_mul'"): + TensorProduct(problem) diff --git a/tests/torch_determinism_test.py b/tests/torch_determinism_test.py deleted file mode 100644 index c8fa83ef..00000000 --- a/tests/torch_determinism_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import pytest -import torch - -from openequivariance import TPProblem, TensorProductConv - -from e3nn import o3 -from torch_geometric import EdgeIndex - - -@pytest.fixture -def gen(): - return torch.Generator(device="cuda") - - -@pytest.fixture -def edge_index(): - return EdgeIndex( - data=[ - [0, 1, 1, 2], # Receiver - [1, 0, 2, 1], # Sender - ], - sparse_size=(3, 4), - device="cuda", - dtype=torch.long, - ) - - -@pytest.fixture -def tpp(): - X_ir = o3.Irreps("1x2e") - Y_ir = o3.Irreps("1x3e") - Z_ir = o3.Irreps("1x2e") - instructions = [(0, 0, 0, "uvu", True)] - return TPProblem( - X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False - ) - - -@pytest.fixture -def conv_buffers(edge_index, tpp, gen): - X = torch.rand( - edge_index.num_rows, tpp.irreps_in1.dim, device="cuda", generator=gen - ) - Y = torch.rand( - edge_index.num_cols, tpp.irreps_in2.dim, device="cuda", generator=gen - ) - W = torch.rand(edge_index.num_cols, tpp.weight_numel, device="cuda", generator=gen) - return (X, Y, W, edge_index[0], edge_index[1]) - - -@pytest.fixture -def tp_conv(tpp): - return TensorProductConv(tpp, deterministic=False) - - -def test_no_response(tp_conv, conv_buffers): - torch.use_deterministic_algorithms(False) - tp_conv(*conv_buffers) - - -def test_warning(tp_conv, conv_buffers, capfd): - torch.use_deterministic_algorithms(True, warn_only=True) - tp_conv(*conv_buffers) - - captured = capfd.readouterr() - assert "Warning" in captured.err - assert "does not have a deterministic implementation" in captured.err - - -def test_error(tp_conv, conv_buffers): - torch.use_deterministic_algorithms(True, warn_only=False) - with pytest.raises(RuntimeError): - tp_conv(*conv_buffers)