From 1eac81f84cad2a3afc317d83cfbc8faa1054e252 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 5 Mar 2026 20:00:08 -0800 Subject: [PATCH 01/27] More progress. --- .../core/ComputationSchedule.py | 27 ++++++++++++++++--- .../openequivariance/core/e3nn_lite.py | 4 ++- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index c9765d0d..6b386b28 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -17,8 +17,9 @@ class IrrepMapping: Maps irreps from a source to a destination set. """ - def __init__(self, src_irreps, idxs): + def __init__(self, src_irreps, src_views, idxs): self.src_irreps = src_irreps + self.src_views = src_views self.idxs = sorted(list(idxs)) self.dst_irreps = Irreps([src_irreps[idx] for idx in self.idxs]) self.src_dst_map = {idx: i for i, idx in enumerate(self.idxs)} @@ -193,6 +194,11 @@ class ChildInstruction: def __init__(self, instruction_tup, parent_idx): self.instruction_tup, self.parent_idx = instruction_tup, parent_idx + class ChildView: + layout: str + ir_mul_offset: int + ir_mul_stride: int + def __init__(self, input, mult_threshold): self.input = input self.mult_threshold = mult_threshold @@ -201,6 +207,7 @@ def __init__(self, input, mult_threshold): child_reps = [[], [], []] self.irrep_maps = {} # Maps a (input_rep_idx #, mul_ir_idx) to a lst[ir_idx] + self.irrep_views = [[], [], []] # View for input_rep_idx, input_rep in enumerate(input_reps): # Loop over L1, L2, L3 for mul_ir_idx, mul_ir in enumerate( @@ -209,12 +216,26 @@ def __init__(self, input, mult_threshold): self.irrep_maps[input_rep_idx, mul_ir_idx] = [] for mul_start in range(0, mul_ir.mul, mult_threshold): mul = min(mult_threshold, mul_ir.mul - mul_start) - child_reps[input_rep_idx] += [ + child_reps[input_rep_idx].append( self.ChildIrrep((mul, 
mul_ir.ir), input_rep_idx, mul_start) - ] + ) self.irrep_maps[input_rep_idx, mul_ir_idx].append( len(child_reps[input_rep_idx]) - 1 ) + if input.layout == "mul_ir": + self.irrep_views.append( + self.ChildView( + layout="mul_ir", + ir_mul_offset=-1, + ir_mul_stride=-1 + )) + elif input.layout == "ir_mul": + self.irrep_views.append( + self.ChildView( + layout="ir_mul", + ir_mul_offset=input_rep.slices()[mul_ir_idx].start + mul_start, + ir_mul_stride=mul_ir.mul + )) new_instructions = [] diff --git a/openequivariance/openequivariance/core/e3nn_lite.py b/openequivariance/openequivariance/core/e3nn_lite.py index d569ad1c..a9b25b92 100644 --- a/openequivariance/openequivariance/core/e3nn_lite.py +++ b/openequivariance/openequivariance/core/e3nn_lite.py @@ -392,9 +392,9 @@ class TPProblem: internal_weights: bool weight_numel: int label: str - _profiling_str: str _in1_dim: int _in2_dim: int + layout: str def __init__( self, @@ -412,12 +412,14 @@ def __init__( label: Optional[str] = None, irrep_dtype: type[np.generic] = np.float32, weight_dtype: type[np.generic] = np.float32, + layout: str = "mul_ir" ) -> None: # === Setup === super().__init__() assert irrep_normalization in ["component", "norm", "none"] assert path_normalization in ["element", "path", "none"] + assert layout in ["mul_ir", "ir_mul"] assert issubclass(irrep_dtype, np.generic) assert issubclass(weight_dtype, np.generic) From 33ed045c0f71e290a691dc37447ad59e0aee467d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 20 Mar 2026 22:53:24 -0700 Subject: [PATCH 02/27] More diffs. 
--- .../_torch/NPDoubleBackwardMixin.py | 98 +++++++++++++---- .../openequivariance/_torch/TensorProduct.py | 67 +++++++++--- .../core/ComputationSchedule.py | 45 ++++---- .../openequivariance/core/utils.py | 63 +++++++++-- .../openequivariance/templates/macros.jinja | 100 +++++++++++++----- tests/batch_test.py | 27 +++-- 6 files changed, 305 insertions(+), 95 deletions(-) diff --git a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py index caf94268..1356b38f 100644 --- a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py +++ b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py @@ -1,5 +1,7 @@ import torch +from openequivariance.core.utils import IrrepLayoutUtils + class NumpyDoubleBackwardMixin: """ @@ -13,12 +15,30 @@ def double_backward_cpu( ): assert self.torch_op - in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) - in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True) + layout = self.config.layout + + in1_kernel = IrrepLayoutUtils.transpose_irrep_layout( + in1, self.config.irreps_in1, layout, "mul_ir" + ) + in2_kernel = IrrepLayoutUtils.transpose_irrep_layout( + in2, self.config.irreps_in2, layout, "mul_ir" + ) + out_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout( + out_grad, self.config.irreps_out, layout, "mul_ir" + ) + in1_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout( + in1_dgrad, self.config.irreps_in1, layout, "mul_ir" + ) + in2_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout( + in2_dgrad, self.config.irreps_in2, layout, "mul_ir" + ) + + in1_torch = torch.tensor(in1_kernel).to("cuda").requires_grad_(True) + in2_torch = torch.tensor(in2_kernel).to("cuda").requires_grad_(True) weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) - out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) - in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda") - in2_dgrad_torch = 
torch.tensor(in2_dgrad).to("cuda") + out_grad_torch = torch.tensor(out_grad_kernel).to("cuda").requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad_kernel).to("cuda") + in2_dgrad_torch = torch.tensor(in2_dgrad_kernel).to("cuda") weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda") out_torch = self.forward(in1_torch, in2_torch, weights_torch) @@ -36,12 +56,22 @@ def double_backward_cpu( grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], ) - return ( - a.detach().cpu().numpy(), - b.detach().cpu().numpy(), - c.detach().cpu().numpy(), - d.detach().cpu().numpy(), + a_np = a.detach().cpu().numpy() + b_np = b.detach().cpu().numpy() + c_np = c.detach().cpu().numpy() + d_np = d.detach().cpu().numpy() + + a_np = IrrepLayoutUtils.transpose_irrep_layout( + a_np, self.config.irreps_in1, "mul_ir", layout ) + b_np = IrrepLayoutUtils.transpose_irrep_layout( + b_np, self.config.irreps_in2, "mul_ir", layout + ) + d_np = IrrepLayoutUtils.transpose_irrep_layout( + d_np, self.config.irreps_out, "mul_ir", layout + ) + + return (a_np, b_np, c_np, d_np) class NumpyDoubleBackwardMixinConv: @@ -54,12 +84,30 @@ def double_backward_cpu( ): assert self.torch_op - in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) - in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True) + layout = self.config.layout + + in1_kernel = IrrepLayoutUtils.transpose_irrep_layout( + in1, self.config.irreps_in1, layout, "mul_ir" + ) + in2_kernel = IrrepLayoutUtils.transpose_irrep_layout( + in2, self.config.irreps_in2, layout, "mul_ir" + ) + out_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout( + out_grad, self.config.irreps_out, layout, "mul_ir" + ) + in1_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout( + in1_dgrad, self.config.irreps_in1, layout, "mul_ir" + ) + in2_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout( + in2_dgrad, self.config.irreps_in2, layout, "mul_ir" + ) + + in1_torch = torch.tensor(in1_kernel).to("cuda").requires_grad_(True) + 
in2_torch = torch.tensor(in2_kernel).to("cuda").requires_grad_(True) weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) - out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) - in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda") - in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda") + out_grad_torch = torch.tensor(out_grad_kernel).to("cuda").requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad_kernel).to("cuda") + in2_dgrad_torch = torch.tensor(in2_dgrad_kernel).to("cuda") weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda") torch_rows = torch.tensor(graph.rows, device="cuda") @@ -89,9 +137,19 @@ def double_backward_cpu( grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], ) - return ( - a.detach().cpu().numpy(), - b.detach().cpu().numpy(), - c.detach().cpu().numpy(), - d.detach().cpu().numpy(), + a_np = a.detach().cpu().numpy() + b_np = b.detach().cpu().numpy() + c_np = c.detach().cpu().numpy() + d_np = d.detach().cpu().numpy() + + a_np = IrrepLayoutUtils.transpose_irrep_layout( + a_np, self.config.irreps_in1, "mul_ir", layout + ) + b_np = IrrepLayoutUtils.transpose_irrep_layout( + b_np, self.config.irreps_in2, "mul_ir", layout ) + d_np = IrrepLayoutUtils.transpose_irrep_layout( + d_np, self.config.irreps_out, "mul_ir", layout + ) + + return (a_np, b_np, c_np, d_np) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 254da414..4c38feae 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -1,17 +1,21 @@ -from openequivariance.core.LoopUnrollTP import LoopUnrollTP +import numpy as np +import torch + from openequivariance import TPProblem from openequivariance._torch import extlib -import torch -from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum -from openequivariance.benchmark.logging_utils import 
getLogger +from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin from openequivariance._torch.utils import ( + enum_to_torch_dtype, reorder_torch, string_to_tensor, - enum_to_torch_dtype, ) -from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin - -import numpy as np +from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.LoopUnrollTP import LoopUnrollTP +from openequivariance.core.utils import ( + IrrepLayoutUtils, + dtype_to_enum, + torch_to_oeq_dtype, +) logger = getLogger() @@ -146,12 +150,24 @@ def forward_cpu( weights, not self.config.shared_weights ) - torch_L1_in = torch.tensor(L1_in, device="cuda") - torch_L2_in = torch.tensor(L2_in, device="cuda") + layout = self.config.layout + + L1_in_kernel = IrrepLayoutUtils.transpose_irrep_layout( + L1_in, self.config.irreps_in1, layout, "mul_ir" + ) + L2_in_kernel = IrrepLayoutUtils.transpose_irrep_layout( + L2_in, self.config.irreps_in2, layout, "mul_ir" + ) + + torch_L1_in = torch.tensor(L1_in_kernel, device="cuda") + torch_L2_in = torch.tensor(L2_in_kernel, device="cuda") torch_weights = torch.tensor(weights_chunked, device="cuda") torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) - L3_out[:] = torch_L3_out.numpy(force=True) + L3_kernel = torch_L3_out.numpy(force=True) + L3_out[:] = IrrepLayoutUtils.transpose_irrep_layout( + L3_kernel, self.config.irreps_out, "mul_ir", layout + ) def backward_cpu( self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad @@ -160,18 +176,37 @@ def backward_cpu( weights, not self.config.shared_weights ) - torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") - torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") + layout = self.config.layout + + L1_in_kernel = IrrepLayoutUtils.transpose_irrep_layout( + L1_in, self.config.irreps_in1, layout, "mul_ir" + ) + L2_in_kernel = IrrepLayoutUtils.transpose_irrep_layout( + L2_in, 
self.config.irreps_in2, layout, "mul_ir" + ) + L3_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout( + L3_grad, self.config.irreps_out, layout, "mul_ir" + ) + + torch_L1_in = torch.tensor(L1_in_kernel, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in_kernel, requires_grad=True, device="cuda") torch_weights = torch.tensor(weights_chunked, requires_grad=True, device="cuda") torch_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) - torch_L3_grad_in = torch.tensor(L3_grad, device="cuda") + torch_L3_grad_in = torch.tensor(L3_grad_kernel, device="cuda") torch_out.backward(gradient=torch_L3_grad_in) - L1_grad[:] = torch_L1_in.grad.numpy(force=True) - L2_grad[:] = torch_L2_in.grad.numpy(force=True) + L1_grad_kernel = torch_L1_in.grad.numpy(force=True) + L2_grad_kernel = torch_L2_in.grad.numpy(force=True) + + L1_grad[:] = IrrepLayoutUtils.transpose_irrep_layout( + L1_grad_kernel, self.config.irreps_in1, "mul_ir", layout + ) + L2_grad[:] = IrrepLayoutUtils.transpose_irrep_layout( + L2_grad_kernel, self.config.irreps_in2, "mul_ir", layout + ) weights_grad[:] = torch_weights.grad.numpy(force=True) weights_grad[:] = self.reorder_weights_to_e3nn( diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index 6b386b28..da4016ac 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -1,7 +1,9 @@ -import numpy as np -from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j from itertools import accumulate + +import numpy as np + from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j logger = getLogger() @@ -27,10 +29,13 @@ def __init__(self, src_irreps, src_views, idxs): src_ranges = [src_irreps.slices()[idx] for idx in self.src_dst_map] dst_ranges = [self.dst_irreps.slices()[i] 
for i in self.src_dst_map.values()] + if src_views[0].layout == "ir_mul": + return + + # Merge adjacent src and dst ranges self.original_src_ranges = src_ranges self.original_dst_ranges = dst_ranges - # Merge adjacent src and dst ranges self.src_ranges = [] self.dst_ranges = [] @@ -195,9 +200,10 @@ def __init__(self, instruction_tup, parent_idx): self.instruction_tup, self.parent_idx = instruction_tup, parent_idx class ChildView: - layout: str - ir_mul_offset: int - ir_mul_stride: int + def __init__(self, layout: str, ir_mul_offset: int, ir_mul_stride: int): + self.layout = layout + self.ir_mul_offset = ir_mul_offset + self.ir_mul_stride = ir_mul_stride def __init__(self, input, mult_threshold): self.input = input @@ -207,7 +213,7 @@ def __init__(self, input, mult_threshold): child_reps = [[], [], []] self.irrep_maps = {} # Maps a (input_rep_idx #, mul_ir_idx) to a lst[ir_idx] - self.irrep_views = [[], [], []] # View + self.irrep_views = [[], [], []] # View for input_rep_idx, input_rep in enumerate(input_reps): # Loop over L1, L2, L3 for mul_ir_idx, mul_ir in enumerate( @@ -223,19 +229,20 @@ def __init__(self, input, mult_threshold): len(child_reps[input_rep_idx]) - 1 ) if input.layout == "mul_ir": - self.irrep_views.append( + self.irrep_views[input_rep_idx].append( self.ChildView( - layout="mul_ir", - ir_mul_offset=-1, - ir_mul_stride=-1 - )) + layout="mul_ir", ir_mul_offset=-1, ir_mul_stride=-1 + ) + ) elif input.layout == "ir_mul": - self.irrep_views.append( + self.irrep_views[input_rep_idx].append( self.ChildView( layout="ir_mul", - ir_mul_offset=input_rep.slices()[mul_ir_idx].start + mul_start, - ir_mul_stride=mul_ir.mul - )) + ir_mul_offset=input_rep.slices()[mul_ir_idx].start + + mul_start, + ir_mul_stride=mul_ir.mul, + ) + ) new_instructions = [] @@ -564,9 +571,9 @@ def calculate_backward_smem( for i in range(len(self.segments)): L1_idxs, L2_idxs, L3_idxs, inst_idxs = self.segments[i] - L1Map = IrrepMapping(self.L1, L1_idxs) - L2Map = IrrepMapping(self.L2, 
L2_idxs) - L3Map = IrrepMapping(self.L3, L3_idxs) + L1Map = IrrepMapping(self.L1, self.problem_splitter.irrep_views[0], L1_idxs) + L2Map = IrrepMapping(self.L2, self.problem_splitter.irrep_views[1], L2_idxs) + L3Map = IrrepMapping(self.L3, self.problem_splitter.irrep_views[2], L3_idxs) instructions = [ ( diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 1950013d..af44fb43 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -1,15 +1,13 @@ import functools +import hashlib +import json import math +import tempfile +from enum import IntEnum import numpy as np -from openequivariance.core.e3nn_lite import Instruction, TPProblem, wigner_3j - -import json -import tempfile -import hashlib - -from enum import IntEnum +from openequivariance.core.e3nn_lite import Instruction, Irreps, TPProblem, wigner_3j class DTypeEnum(IntEnum): @@ -168,7 +166,7 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): time_millis[i] = timer.stop_clock_get_elapsed() else: - from torch.profiler import profile, record_function, ProfilerActivity + from torch.profiler import ProfilerActivity, profile, record_function trace_file = tempfile.NamedTemporaryFile().name @@ -204,3 +202,52 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): def hash_str_64(s: str) -> int: return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big") + + +class IrrepLayoutUtils: + @staticmethod + def transpose_irrep_layout( + array: np.ndarray, + irreps: Irreps, + src_layout: str, + dst_layout: str, + ) -> np.ndarray: + """ + Transpose irrep-packed feature arrays between `mul_ir` and `ir_mul` layouts. + + Expected input shape is `[..., irreps.dim]`. A new array is returned. + If `src_layout == dst_layout`, this returns a copy. 
+ """ + if src_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported src_layout: {src_layout}") + if dst_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported dst_layout: {dst_layout}") + + x = np.asarray(array) + out = np.empty_like(x) + + if src_layout == dst_layout: + out[...] = x + return out + + slices = irreps.slices() + for ir_idx, mul_ir in enumerate(irreps): + mul = mul_ir.mul + dim = mul_ir.ir.dim + seg = slices[ir_idx] + block = x[..., seg.start : seg.stop] + + if src_layout == "ir_mul" and dst_layout == "mul_ir": + out[..., seg.start : seg.stop] = block.reshape( + *block.shape[:-1], dim, mul + ).reshape(*block.shape[:-1], mul * dim) + elif src_layout == "mul_ir" and dst_layout == "ir_mul": + out[..., seg.start : seg.stop] = block.reshape( + *block.shape[:-1], mul, dim + ).reshape(*block.shape[:-1], dim * mul) + else: + raise ValueError( + f"Unsupported layout transpose: {src_layout} -> {dst_layout}" + ) + + return out diff --git a/openequivariance/openequivariance/templates/macros.jinja b/openequivariance/openequivariance/templates/macros.jinja index 9b0d3c28..ea6110e9 100644 --- a/openequivariance/openequivariance/templates/macros.jinja +++ b/openequivariance/openequivariance/templates/macros.jinja @@ -54,41 +54,93 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. 
{%- for name in segment.smem %} {%- if name != "total" %} {%- set smem_rng = segment.smem[name] %} - {{ smem_rng["dtype"] }}* {{name}}_smem = ({{smem_rng["dtype"]}}*) ({{smem_base}} + {{smem_rng["offset"]}}); + {{ smem_rng["dtype"] }}* {{name}}_smem = ({{smem_rng["dtype"]}}*) ({{smem_base}} + {{smem_rng["offset"]}}); {%- endif %} {%- endfor %} {%- endmacro %} {%- macro load_ir_segments(map, glb_ptr_shft, smem_ptr, loop_var) %} {%- if not map.persist_load %} - {%- for (src_rng, dst_rng) in map.copy_ranges %} - {%- set range_len = src_rng.stop - src_rng.start %} - ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];) - {%- endfor %} + {%- if map.src_views[0].layout == "mul_ir" %} + {%- for (src_rng, dst_rng) in map.copy_ranges %} + {%- set range_len = src_rng.stop - src_rng.start %} + ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];) + {%- endfor %} + {%- elif map.src_views[0].layout == "ir_mul" %} + {%- for idx in map.idxs %} + {%- set src_view = map.src_views[idx] %} + {%- set src_mul_ir = map.src_irreps[idx] %} + {%- set dst_idx = map.src_dst_map[idx] %} + {%- set dst_rng = map.dst_irreps.slices()[dst_idx] %} + {%- set dim = src_mul_ir.ir.dim %} + {%- set mul = src_mul_ir.mul %} + {%- for i in range(dim) %} + ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) + {%- endfor %} + {%- endfor %} + {%- endif %} {%- endif %} {%- endmacro %} {%- macro load_ir_segments_force(map, glb_ptr_shft, smem_ptr, loop_var) %} - {%- for (src_rng, dst_rng) in map.copy_ranges %} - {%- set range_len = src_rng.stop - src_rng.start %} - ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = 
{{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];) - {%- endfor %} + {%- if map.src_views[0].layout == "mul_ir" %} + {%- for (src_rng, dst_rng) in map.copy_ranges %} + {%- set range_len = src_rng.stop - src_rng.start %} + ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];) + {%- endfor %} + {%- elif map.src_views[0].layout == "ir_mul" %} + {%- for idx in map.idxs %} + {%- set src_view = map.src_views[idx] %} + {%- set src_mul_ir = map.src_irreps[idx] %} + {%- set dst_idx = map.src_dst_map[idx] %} + {%- set dst_rng = map.dst_irreps.slices()[dst_idx] %} + {%- set dim = src_mul_ir.ir.dim %} + {%- set mul = src_mul_ir.mul %} + {%- for i in range(dim) %} + ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) + {%- endfor %} + {%- endfor %} + {%- endif %} {%- endmacro %} {%- macro store_ir_segments(map, glb_ptr_shft, smem_ptr, loop_var) %} {%- if not map.persist_store %} - {%- for i, src_rng in enumerate(map.original_src_ranges) %} - {%- set idx = map.idxs[i] %} - {%- set dst_rng = map.original_dst_ranges[i] %} - {%- set range_len = src_rng.stop - src_rng.start %} - {%- if map.storeback_procedure[idx] == "write" %} - ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] = {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];) - {%- elif map.storeback_procedure[idx] == "accumulate" %} - ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] += {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];) - {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %} - ROW_OPERATION({{range_len}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_rng.start}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start}} + lane_id + {{loop_var}}]);) - {%- endif %} 
- {%- endfor %} + {%- if map.src_views[0].layout == "mul_ir" %} + {%- for i, src_rng in enumerate(map.original_src_ranges) %} + {%- set idx = map.idxs[i] %} + {%- set dst_rng = map.original_dst_ranges[i] %} + {%- set range_len = src_rng.stop - src_rng.start %} + {%- if map.storeback_procedure[idx] == "write" %} + ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] = {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];) + {%- elif map.storeback_procedure[idx] == "accumulate" %} + ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] += {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];) + {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %} + ROW_OPERATION({{range_len}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_rng.start}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start}} + lane_id + {{loop_var}}]);) + {%- endif %} + {%- endfor %} + {%- elif map.src_views[0].layout == "ir_mul" %} + {%- for idx in map.idxs %} + {%- set src_view = map.src_views[idx] %} + {%- set src_mul_ir = map.src_irreps[idx] %} + {%- set dst_idx = map.src_dst_map[idx] %} + {%- set dst_rng = map.dst_irreps.slices()[dst_idx] %} + {%- set dim = src_mul_ir.ir.dim %} + {%- set mul = src_mul_ir.mul %} + {%- if map.storeback_procedure[idx] == "write" %} + {%- for i in range(dim) %} + ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];) + {%- endfor %} + {%- elif map.storeback_procedure[idx] == "accumulate" %} + {%- for i in range(dim) %} + ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];) + {%- endfor %} + {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %} + {%- for i in range(dim) %} + 
ROW_OPERATION({{mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + lane_id + {{loop_var}}]);) + {%- endfor %} + {%- endif %} + {%- endfor %} + {%- endif %} {% endif %} {%- endmacro %} @@ -122,7 +174,7 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. {{smem_ptr}}[{{offset}} + lane_id + {{i * mul}}] = t_regs[{{i}}]; {%- endfor %} } - } {%- endfor %} + } {%- endfor %} {%- endmacro %} {%- macro transpose_smem_B(irreps, smem_ptr) %} @@ -134,14 +186,14 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. if(lane_id < {{mul}}) { {%- set offset = slices[i].start %} {%- for i in range(dim) %} - t_regs[{{i}}] = {{smem_ptr}}[{{offset}} + lane_id + {{i * mul}}]; + t_regs[{{i}}] = {{smem_ptr}}[{{offset}} + lane_id + {{i * mul}}]; {%- endfor %} __syncwarp(); {%- for i in range(dim) %} {{smem_ptr}}[{{offset}} + lane_id * {{dim}} + {{i}}] = t_regs[{{i}}]; {%- endfor %} } - } {%- endfor %} + } {%- endfor %} {%- endmacro %} {%- macro reg_load(mul, dim, smem, offset, reg) %} diff --git a/tests/batch_test.py b/tests/batch_test.py index 55396ded..2fbb84cb 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -1,22 +1,22 @@ -import pytest -from pytest_check import check +from itertools import product import numpy as np -import openequivariance as oeq +import pytest +import torch from openequivariance.benchmark.correctness_utils import ( - correctness_forward, correctness_backward, correctness_double_backward, + correctness_forward, ) - from openequivariance.benchmark.problems import ( - e3nn_torch_tetris_poly_problems, diffdock_problems, + e3nn_torch_tetris_poly_problems, mace_problems, nequip_problems, ) -from itertools import product -import torch +from pytest_check import check + +import openequivariance as oeq @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module") @@ 
-273,6 +273,17 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): return tp, tp.config +class TestMulIrLayoutMACE(TPCorrectness): + production_model_tpps = mace_problems() + + @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") + def problem(self, request, dtype): + problem = request.param.clone() + problem.irrep_dtype, problem.weight_dtype = dtype, dtype + problem.layout = "mul_ir" + return problem + + class TestTorchToSubmodule: """Test that TensorProduct works correctly as a submodule when parent's .to() is called""" From c78d48ff6cb8888d81754583990b4da982b9d5f1 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Mar 2026 18:22:25 -0700 Subject: [PATCH 03/27] Avoided transposing irreps once the shared memory load is complete. --- .../openequivariance/core/LoopUnrollTP.py | 22 ++++++++++------ .../openequivariance/core/e3nn_lite.py | 14 +++++----- .../templates/loop_unroll_tp.cuh | 20 +++++++------- .../openequivariance/templates/macros.jinja | 26 +++++++++++++++---- 4 files changed, 53 insertions(+), 29 deletions(-) diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 41354e5f..43c9e244 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -1,16 +1,20 @@ -import numpy as np import json -from openequivariance.templates.jinja_utils import get_jinja_environment -from openequivariance.core.ComputationSchedule import ComputationSchedule -from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.core.utils import dtype_to_enum, hash_str_64 +import numpy as np +from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.ComputationSchedule import ( + ComputationSchedule, + SMEMCapacityException, +) +from 
openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.utils import ( - filter_and_analyze_problem, count_cg_non_zero, + dtype_to_enum, + filter_and_analyze_problem, + hash_str_64, ) +from openequivariance.templates.jinja_utils import get_jinja_environment logger = getLogger() @@ -80,12 +84,14 @@ def generate_double_backward_schedule(warps_per_block): try: generate_schedule(warp_count) break - except Exception: + except SMEMCapacityException: warp_count -= 2 if warp_count == 0: raise RuntimeError( "Tensor product schedule generation failed, shared memory inadequate!" ) + except Exception: + raise self.jit_kernel = postprocess_kernel( template.render( diff --git a/openequivariance/openequivariance/core/e3nn_lite.py b/openequivariance/openequivariance/core/e3nn_lite.py index a9b25b92..e0afe83c 100644 --- a/openequivariance/openequivariance/core/e3nn_lite.py +++ b/openequivariance/openequivariance/core/e3nn_lite.py @@ -36,14 +36,15 @@ SOFTWARE. """ -import itertools -from typing import Tuple, NamedTuple, Union, List, Any, Optional -from math import sqrt, prod import collections +import copy +import functools +import itertools +from math import prod, sqrt +from typing import Any, List, NamedTuple, Optional, Tuple, Union + import numpy as np import numpy.linalg as la -import functools -import copy def perm_inverse(p): @@ -412,7 +413,7 @@ def __init__( label: Optional[str] = None, irrep_dtype: type[np.generic] = np.float32, weight_dtype: type[np.generic] = np.float32, - layout: str = "mul_ir" + layout: str = "mul_ir", ) -> None: # === Setup === super().__init__() @@ -434,6 +435,7 @@ def __init__( self.irrep_normalization = irrep_normalization self.path_normalization = path_normalization self.label = label if label is not None else "" + self.layout = layout del irreps_in1, irreps_in2, irreps_out instructions = [x if len(x) == 6 else x + (1.0,) for x in instructions] diff --git 
a/openequivariance/openequivariance/templates/loop_unroll_tp.cuh b/openequivariance/openequivariance/templates/loop_unroll_tp.cuh index eab95c57..a56fb4b1 100644 --- a/openequivariance/openequivariance/templates/loop_unroll_tp.cuh +++ b/openequivariance/openequivariance/templates/loop_unroll_tp.cuh @@ -1,4 +1,4 @@ -{%- from 'macros.jinja' import transpose_load, transpose_store, reg_store with context %} +{%- from 'macros.jinja' import layout_load, layout_store with context %} {%- from 'wmm.cuh' import generate_matmul %} {%- macro generate_segment_kernel_forward(id, segment, warp_size) %} @@ -36,7 +36,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ {%- if k == 0 or interactions[k][0] != interactions[k-1][0] %} offset = {{ L1.slices()[u].start}}; - {{transpose_load(L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}} + {{layout_load(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}} {%- endif %} #pragma unroll @@ -72,7 +72,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ // ----------------- CORE CALCULATION ----------------- {%- if problem.instructions[k].connection_mode == "uvw" %} - {{transpose_store(L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_vec', '=', '1.0')}} + {{layout_store(problem.layout, L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_vec', '=', '1.0')}} __syncwarp(); offset = {{ L3.slices()[w].start}}; matmul_fwd_{{id}}_{{k}}(weights_smem, scratch, L3_smem + offset); @@ -85,7 +85,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ {%- if problem.instructions[k].connection_mode != "uvw" %} offset = {{ L3.slices()[w].start}}; - {{transpose_store(L3[w].mul, L3[w].ir.dim, 'L3_smem', 'offset', 'l3_vec', '+=', '1.0')}} + {{layout_store(problem.layout, L3[w].mul, L3[w].ir.dim, 'L3_smem', 'offset', 'l3_vec', '+=', '1.0')}} {%- if L2[v].mul > 1%} #pragma unroll @@ -168,15 +168,15 @@ __device__ __forceinline__ void 
forward_loop_unroll_{{id}}(IRREP_T* __restrict__ {%- if k == 0 or interactions[k][0] != interactions[k-1][0] %} offset = {{ L1.slices()[u].start}}; - {{transpose_load(L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}} - {{transpose_load(L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad')}} + {{layout_load(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}} + {{layout_load(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad')}} {%- endif %} {%- if problem.instructions[k].connection_mode != "uvw" %} {%- if k == 0 or interactions[k][2] != interactions[k-1][2] %} offset = {{ L3.slices()[w].start}}; - {{transpose_load(L3[w].mul, L3[w].ir.dim, 'L3_grad_smem', 'offset', 'l3_grad')}} + {{layout_load(problem.layout, L3[w].mul, L3[w].ir.dim, 'L3_grad_smem', 'offset', 'l3_grad')}} {%- endif %} {%- endif %} @@ -225,7 +225,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ {{matmul_basename}}A_{{id}}_{{k}}(weights_smem, L3_grad_smem + offset, scratch); __syncwarp(); - {{transpose_load(L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_grad')}} + {{layout_load(problem.layout, L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_grad')}} {%- for i in range(tensor.nnz) %} {%- set coord1, coord2, coord3, value = tensor.tuples[i] %} @@ -248,7 +248,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ {%- endif %} {%- endfor %} - {{ reg_store(L1[u].mul, L3[w].ir.dim, "scratch", "0", "l3_grad", "=", 1.0) }} + {{ layout_store(problem.layout, L1[u].mul, L3[w].ir.dim, "scratch", "0", "l3_grad", "=", 1.0) }} __syncwarp(); {{matmul_basename}}B_{{id}}_{{k}}(L3_grad_smem + offset, scratch, weights_smem); @@ -305,7 +305,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ // Storeback {%- if k == num_interact - 1 or interactions[k][0] != interactions[k+1][0] %} offset = {{ L1.slices()[u].start}}; - {{transpose_store(L1[u].mul, L1[u].ir.dim, 
'L1_grad_smem', 'offset', 'l1_grad', '=', '1.0')}} + {{layout_store(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad', '=', '1.0')}} {%- endif %} {%- endfor %} diff --git a/openequivariance/openequivariance/templates/macros.jinja b/openequivariance/openequivariance/templates/macros.jinja index ea6110e9..b4a28158 100644 --- a/openequivariance/openequivariance/templates/macros.jinja +++ b/openequivariance/openequivariance/templates/macros.jinja @@ -50,6 +50,22 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. } {%- endmacro %} +{%- macro layout_load(layout, mul, dim, smem, offset, reg) %} + {%- if layout == "ir_mul" %} + {{ reg_load(mul, dim, smem, offset, reg) }} + {%- else %} + {{ transpose_load(mul, dim, smem, offset, reg) }} + {%- endif %} +{%- endmacro %} + +{%- macro layout_store(layout, mul, dim, smem, offset, reg, op, coeff) %} + {%- if layout == "ir_mul" %} + {{ reg_store(mul, dim, smem, offset, reg, op, coeff) }} + {%- else %} + {{ transpose_store(mul, dim, smem, offset, reg, op, coeff) }} + {%- endif %} +{%- endmacro %} + {%- macro declare_smem_variables(segment, smem_base) %} {%- for name in segment.smem %} {%- if name != "total" %} @@ -75,7 +91,7 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. {%- set dim = src_mul_ir.ir.dim %} {%- set mul = src_mul_ir.mul %} {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) + ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) {%- endfor %} {%- endfor %} {%- endif %} @@ -97,7 +113,7 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. 
{%- set dim = src_mul_ir.ir.dim %} {%- set mul = src_mul_ir.mul %} {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) + ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) {%- endfor %} {%- endfor %} {%- endif %} @@ -128,15 +144,15 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. {%- set mul = src_mul_ir.mul %} {%- if map.storeback_procedure[idx] == "write" %} {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];) + ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id];) {%- endfor %} {%- elif map.storeback_procedure[idx] == "accumulate" %} {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];) + ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id];) {%- endfor %} {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %} {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + lane_id + {{loop_var}}]);) + ROW_OPERATION({{mul}}, {{loop_var}}, 
atomicAdd({{glb_ptr_shft}} + {{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id]);) {%- endfor %} {%- endif %} {%- endfor %} From 85e988f634d97682f1daab292ce1fcfe4a5316cb Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Mar 2026 18:47:28 -0700 Subject: [PATCH 04/27] Made more progress. --- .../_torch/NPDoubleBackwardMixin.py | 78 ++--------- .../openequivariance/_torch/TensorProduct.py | 53 ++------ .../benchmark/correctness_utils.py | 124 +++++++++++++++--- 3 files changed, 123 insertions(+), 132 deletions(-) diff --git a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py index 1356b38f..18a8e873 100644 --- a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py +++ b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py @@ -1,7 +1,5 @@ import torch -from openequivariance.core.utils import IrrepLayoutUtils - class NumpyDoubleBackwardMixin: """ @@ -15,30 +13,12 @@ def double_backward_cpu( ): assert self.torch_op - layout = self.config.layout - - in1_kernel = IrrepLayoutUtils.transpose_irrep_layout( - in1, self.config.irreps_in1, layout, "mul_ir" - ) - in2_kernel = IrrepLayoutUtils.transpose_irrep_layout( - in2, self.config.irreps_in2, layout, "mul_ir" - ) - out_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout( - out_grad, self.config.irreps_out, layout, "mul_ir" - ) - in1_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout( - in1_dgrad, self.config.irreps_in1, layout, "mul_ir" - ) - in2_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout( - in2_dgrad, self.config.irreps_in2, layout, "mul_ir" - ) - - in1_torch = torch.tensor(in1_kernel).to("cuda").requires_grad_(True) - in2_torch = torch.tensor(in2_kernel).to("cuda").requires_grad_(True) + in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) + in2_torch = 
torch.tensor(in2).to("cuda").requires_grad_(True) weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) - out_grad_torch = torch.tensor(out_grad_kernel).to("cuda").requires_grad_(True) - in1_dgrad_torch = torch.tensor(in1_dgrad_kernel).to("cuda") - in2_dgrad_torch = torch.tensor(in2_dgrad_kernel).to("cuda") + out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda") + in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda") weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda") out_torch = self.forward(in1_torch, in2_torch, weights_torch) @@ -61,16 +41,6 @@ def double_backward_cpu( c_np = c.detach().cpu().numpy() d_np = d.detach().cpu().numpy() - a_np = IrrepLayoutUtils.transpose_irrep_layout( - a_np, self.config.irreps_in1, "mul_ir", layout - ) - b_np = IrrepLayoutUtils.transpose_irrep_layout( - b_np, self.config.irreps_in2, "mul_ir", layout - ) - d_np = IrrepLayoutUtils.transpose_irrep_layout( - d_np, self.config.irreps_out, "mul_ir", layout - ) - return (a_np, b_np, c_np, d_np) @@ -84,30 +54,12 @@ def double_backward_cpu( ): assert self.torch_op - layout = self.config.layout - - in1_kernel = IrrepLayoutUtils.transpose_irrep_layout( - in1, self.config.irreps_in1, layout, "mul_ir" - ) - in2_kernel = IrrepLayoutUtils.transpose_irrep_layout( - in2, self.config.irreps_in2, layout, "mul_ir" - ) - out_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout( - out_grad, self.config.irreps_out, layout, "mul_ir" - ) - in1_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout( - in1_dgrad, self.config.irreps_in1, layout, "mul_ir" - ) - in2_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout( - in2_dgrad, self.config.irreps_in2, layout, "mul_ir" - ) - - in1_torch = torch.tensor(in1_kernel).to("cuda").requires_grad_(True) - in2_torch = torch.tensor(in2_kernel).to("cuda").requires_grad_(True) + in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) + in2_torch = 
torch.tensor(in2).to("cuda").requires_grad_(True) weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) - out_grad_torch = torch.tensor(out_grad_kernel).to("cuda").requires_grad_(True) - in1_dgrad_torch = torch.tensor(in1_dgrad_kernel).to("cuda") - in2_dgrad_torch = torch.tensor(in2_dgrad_kernel).to("cuda") + out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda") + in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda") weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda") torch_rows = torch.tensor(graph.rows, device="cuda") @@ -142,14 +94,4 @@ def double_backward_cpu( c_np = c.detach().cpu().numpy() d_np = d.detach().cpu().numpy() - a_np = IrrepLayoutUtils.transpose_irrep_layout( - a_np, self.config.irreps_in1, "mul_ir", layout - ) - b_np = IrrepLayoutUtils.transpose_irrep_layout( - b_np, self.config.irreps_in2, "mul_ir", layout - ) - d_np = IrrepLayoutUtils.transpose_irrep_layout( - d_np, self.config.irreps_out, "mul_ir", layout - ) - return (a_np, b_np, c_np, d_np) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 4c38feae..2087f03c 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -11,11 +11,7 @@ ) from openequivariance.benchmark.logging_utils import getLogger from openequivariance.core.LoopUnrollTP import LoopUnrollTP -from openequivariance.core.utils import ( - IrrepLayoutUtils, - dtype_to_enum, - torch_to_oeq_dtype, -) +from openequivariance.core.utils import dtype_to_enum, torch_to_oeq_dtype logger = getLogger() @@ -150,24 +146,12 @@ def forward_cpu( weights, not self.config.shared_weights ) - layout = self.config.layout - - L1_in_kernel = IrrepLayoutUtils.transpose_irrep_layout( - L1_in, self.config.irreps_in1, layout, "mul_ir" - ) - L2_in_kernel = IrrepLayoutUtils.transpose_irrep_layout( - 
L2_in, self.config.irreps_in2, layout, "mul_ir" - ) - - torch_L1_in = torch.tensor(L1_in_kernel, device="cuda") - torch_L2_in = torch.tensor(L2_in_kernel, device="cuda") + torch_L1_in = torch.tensor(L1_in, device="cuda") + torch_L2_in = torch.tensor(L2_in, device="cuda") torch_weights = torch.tensor(weights_chunked, device="cuda") torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) - L3_kernel = torch_L3_out.numpy(force=True) - L3_out[:] = IrrepLayoutUtils.transpose_irrep_layout( - L3_kernel, self.config.irreps_out, "mul_ir", layout - ) + L3_out[:] = torch_L3_out.numpy(force=True) def backward_cpu( self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad @@ -176,37 +160,18 @@ def backward_cpu( weights, not self.config.shared_weights ) - layout = self.config.layout - - L1_in_kernel = IrrepLayoutUtils.transpose_irrep_layout( - L1_in, self.config.irreps_in1, layout, "mul_ir" - ) - L2_in_kernel = IrrepLayoutUtils.transpose_irrep_layout( - L2_in, self.config.irreps_in2, layout, "mul_ir" - ) - L3_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout( - L3_grad, self.config.irreps_out, layout, "mul_ir" - ) - - torch_L1_in = torch.tensor(L1_in_kernel, requires_grad=True, device="cuda") - torch_L2_in = torch.tensor(L2_in_kernel, requires_grad=True, device="cuda") + torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") torch_weights = torch.tensor(weights_chunked, requires_grad=True, device="cuda") torch_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) - torch_L3_grad_in = torch.tensor(L3_grad_kernel, device="cuda") + torch_L3_grad_in = torch.tensor(L3_grad, device="cuda") torch_out.backward(gradient=torch_L3_grad_in) - L1_grad_kernel = torch_L1_in.grad.numpy(force=True) - L2_grad_kernel = torch_L2_in.grad.numpy(force=True) - - L1_grad[:] = IrrepLayoutUtils.transpose_irrep_layout( - L1_grad_kernel, self.config.irreps_in1, "mul_ir", layout - ) - 
L2_grad[:] = IrrepLayoutUtils.transpose_irrep_layout( - L2_grad_kernel, self.config.irreps_in2, "mul_ir", layout - ) + L1_grad[:] = torch_L1_in.grad.numpy(force=True) + L2_grad[:] = torch_L2_in.grad.numpy(force=True) weights_grad[:] = torch_weights.grad.numpy(force=True) weights_grad[:] = self.reorder_weights_to_e3nn( diff --git a/openequivariance/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py index 788d209e..91dc4760 100644 --- a/openequivariance/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/openequivariance/benchmark/correctness_utils.py @@ -1,17 +1,18 @@ from typing import Optional, Union -from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.core.e3nn_lite import TPProblem +import numpy as np +import numpy.linalg as la + from openequivariance._torch.CUETensorProduct import CUETensorProduct +from openequivariance.benchmark.logging_utils import bcolors, getLogger from openequivariance.benchmark.random_buffer_utils import ( - get_random_buffers_forward, get_random_buffers_backward, get_random_buffers_double_backward, + get_random_buffers_forward, ) - -from openequivariance.benchmark.logging_utils import getLogger, bcolors -import numpy as np -import numpy.linalg as la +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.utils import IrrepLayoutUtils logger = getLogger() @@ -81,7 +82,7 @@ def correctness_forward( in1, in2, weights, out = get_random_buffers_forward(problem, batch_size, prng_seed) - # run reference + # run reference (always in mul_ir) ref_tp = reference_implementation(problem) ref_out = out.copy() @@ -93,13 +94,31 @@ def correctness_forward( if problem.shared_weights and test_implementation == CUETensorProduct: weights_copy = weights[np.newaxis, :] - # run test + # run test (may require ir_mul conversion) test_tp = 
instantiate_implementation(test_implementation, problem) + test_layout = getattr(test_tp.config, "layout", "mul_ir") + + test_in1 = in1.copy() + test_in2 = in2.copy() test_out = out.copy() + + if test_layout == "ir_mul": + test_in1 = IrrepLayoutUtils.transpose_irrep_layout( + test_in1, problem.irreps_in1, "mul_ir", "ir_mul" + ) + test_in2 = IrrepLayoutUtils.transpose_irrep_layout( + test_in2, problem.irreps_in2, "mul_ir", "ir_mul" + ) + test_tp.forward_cpu( - L1_in=in1.copy(), L2_in=in2.copy(), L3_out=test_out, weights=weights_copy + L1_in=test_in1, L2_in=test_in2, L3_out=test_out, weights=weights_copy ) + if test_layout == "ir_mul": + test_out = IrrepLayoutUtils.transpose_irrep_layout( + test_out, problem.irreps_out, "ir_mul", "mul_ir" + ) + for name, to_check, ground_truth in [("output", ref_out, test_out)]: result[name] = check_similiarity( name, to_check, ground_truth, correctness_threshold @@ -144,7 +163,7 @@ def correctness_backward( weights_grad=ref_weights_grad, ) - # run test version + # run test version (may require ir_mul conversion) test_weights_grad = weights_grad.copy() test_in1_grad = in1_grad.copy() test_in2_grad = in2_grad.copy() @@ -156,16 +175,41 @@ def correctness_backward( test_weights_grad = test_weights_grad[np.newaxis, :] test_tp = instantiate_implementation(test_implementation, problem) + test_layout = getattr(test_tp.config, "layout", "mul_ir") + + test_in1 = in1.copy() + test_in2 = in2.copy() + test_L3_grad = out_grad.copy() + + if test_layout == "ir_mul": + test_in1 = IrrepLayoutUtils.transpose_irrep_layout( + test_in1, problem.irreps_in1, "mul_ir", "ir_mul" + ) + test_in2 = IrrepLayoutUtils.transpose_irrep_layout( + test_in2, problem.irreps_in2, "mul_ir", "ir_mul" + ) + test_L3_grad = IrrepLayoutUtils.transpose_irrep_layout( + test_L3_grad, problem.irreps_out, "mul_ir", "ir_mul" + ) + test_tp.backward_cpu( - L1_in=in1.copy(), + L1_in=test_in1, L1_grad=test_in1_grad, - L2_in=in2.copy(), + L2_in=test_in2, L2_grad=test_in2_grad, - 
L3_grad=out_grad.copy(), + L3_grad=test_L3_grad, weights=weights_copy, weights_grad=test_weights_grad, ) + if test_layout == "ir_mul": + test_in1_grad = IrrepLayoutUtils.transpose_irrep_layout( + test_in1_grad, problem.irreps_in1, "ir_mul", "mul_ir" + ) + test_in2_grad = IrrepLayoutUtils.transpose_irrep_layout( + test_in2_grad, problem.irreps_in2, "ir_mul", "mul_ir" + ) + weight_threshold = ( correctness_threshold * batch_size if problem.shared_weights @@ -210,7 +254,9 @@ def correctness_double_backward( result = {"thresh": correctness_threshold, "batch_size": batch_size} tensors = [] - for _, impl in enumerate([test_implementation, reference_implementation]): + for is_test_impl, impl in enumerate( + [test_implementation, reference_implementation] + ): tp = instantiate_implementation(impl, problem) weights_reordered = tp.reorder_weights_from_e3nn( weights, has_batch_dim=not problem.shared_weights @@ -222,15 +268,53 @@ def correctness_double_backward( if impl == CUETensorProduct and problem.shared_weights: weights_reordered = weights_reordered[np.newaxis, :] + tp_layout = getattr(tp.config, "layout", "mul_ir") + apply_test_layout = is_test_impl == 0 and tp_layout == "ir_mul" + + db_in1 = in1 + db_in2 = in2 + db_out_grad = out_grad + db_in1_dgrad = in1_dgrad + db_in2_dgrad = in2_dgrad + + if apply_test_layout: + db_in1 = IrrepLayoutUtils.transpose_irrep_layout( + in1, problem.irreps_in1, "mul_ir", "ir_mul" + ) + db_in2 = IrrepLayoutUtils.transpose_irrep_layout( + in2, problem.irreps_in2, "mul_ir", "ir_mul" + ) + db_out_grad = IrrepLayoutUtils.transpose_irrep_layout( + out_grad, problem.irreps_out, "mul_ir", "ir_mul" + ) + db_in1_dgrad = IrrepLayoutUtils.transpose_irrep_layout( + in1_dgrad, problem.irreps_in1, "mul_ir", "ir_mul" + ) + db_in2_dgrad = IrrepLayoutUtils.transpose_irrep_layout( + in2_dgrad, problem.irreps_in2, "mul_ir", "ir_mul" + ) + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( - in1, - in2, - out_grad, + db_in1, + db_in2, + 
db_out_grad, weights_reordered, weights_dgrad_reordered, - in1_dgrad, - in2_dgrad, + db_in1_dgrad, + db_in2_dgrad, ) + + if apply_test_layout: + out_dgrad = IrrepLayoutUtils.transpose_irrep_layout( + out_dgrad, problem.irreps_out, "ir_mul", "mul_ir" + ) + in1_grad = IrrepLayoutUtils.transpose_irrep_layout( + in1_grad, problem.irreps_in1, "ir_mul", "mul_ir" + ) + in2_grad = IrrepLayoutUtils.transpose_irrep_layout( + in2_grad, problem.irreps_in2, "ir_mul", "mul_ir" + ) + tensors.append( ( out_dgrad, From 379fd28f33ce336dff64487829992e58aba1db84 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Mar 2026 22:07:58 -0700 Subject: [PATCH 05/27] Convolution test is failing. --- .../core/ComputationSchedule.py | 9 +- .../openequivariance/core/ConvolutionBase.py | 119 +++++++++++++++--- .../openequivariance/core/LoopUnrollConv.py | 14 ++- .../openequivariance/templates/macros.jinja | 10 +- tests/conv_test.py | 11 ++ 5 files changed, 130 insertions(+), 33 deletions(-) diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index da4016ac..8cfad757 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -29,6 +29,10 @@ def __init__(self, src_irreps, src_views, idxs): src_ranges = [src_irreps.slices()[idx] for idx in self.src_dst_map] dst_ranges = [self.dst_irreps.slices()[i] for i in self.src_dst_map.values()] + self.storeback_procedure = {idx: "write" for idx in self.idxs} + self.persist_load = False + self.persist_store = False + if src_views[0].layout == "ir_mul": return @@ -55,11 +59,6 @@ def __init__(self, src_irreps, src_views, idxs): self.dst_ranges.append(slice(dst_start, dst_end)) self.copy_ranges = list(zip(self.src_ranges, self.dst_ranges)) - self.persist_load = False - self.persist_store = False - - self.storeback_procedure = {idx: "write" for idx in self.idxs} - class CGTensor: def 
__init__(self, l1, l2, l3, normalization_factor, dtype): diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index a06b2c79..ec5f5905 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -1,15 +1,16 @@ import copy + import numpy as np + +from openequivariance.benchmark.correctness_utils import check_similiarity +from openequivariance.benchmark.logging_utils import bcolors, getLogger from openequivariance.benchmark.random_buffer_utils import ( - get_random_buffers_forward_conv, get_random_buffers_backward_conv, get_random_buffers_double_backward_conv, + get_random_buffers_forward_conv, ) - -from openequivariance.benchmark.logging_utils import getLogger, bcolors -from openequivariance.benchmark.correctness_utils import check_similiarity from openequivariance.core.e3nn_lite import wigner_3j -from openequivariance.core.utils import benchmark +from openequivariance.core.utils import IrrepLayoutUtils, benchmark logger = getLogger() @@ -143,6 +144,13 @@ def test_correctness_forward( check_reproducible=True, high_precision_ref=False, ): + def maybe_transpose_input_for_test_impl(x, irreps): + if self.config.layout == "ir_mul": + return IrrepLayoutUtils.transpose_irrep_layout( + x, irreps, "mul_ir", "ir_mul" + ) + return x + if reference_implementation is None: from openequivariance._torch.E3NNConv import E3NNConv @@ -186,13 +194,22 @@ def test_correctness_forward( test_out = out.copy() self.forward_cpu( - L1_in=in1.copy(), - L2_in=in2.copy(), + L1_in=maybe_transpose_input_for_test_impl( + in1.copy(), self.config.irreps_in1 + ), + L2_in=maybe_transpose_input_for_test_impl( + in2.copy(), self.config.irreps_in2 + ), weights=weights.copy(), L3_out=test_out, graph=graph, ) + if self.config.layout == "ir_mul": + test_out = IrrepLayoutUtils.transpose_irrep_layout( + test_out, self.config.irreps_out, "ir_mul", "mul_ir" + 
) + for name, to_check, ground_truth in [("output", ref_out, test_out)]: result[name] = check_similiarity(name, to_check, ground_truth, thresh) @@ -205,13 +222,22 @@ def test_correctness_forward( for i in range(num_trials): repeated_run = out.copy() self.forward_cpu( - L1_in=in1.copy(), - L2_in=in2.copy(), + L1_in=maybe_transpose_input_for_test_impl( + in1.copy(), self.config.irreps_in1 + ), + L2_in=maybe_transpose_input_for_test_impl( + in2.copy(), self.config.irreps_in2 + ), weights=weights.copy(), L3_out=repeated_run, graph=graph, ) + if self.config.layout == "ir_mul": + repeated_run = IrrepLayoutUtils.transpose_irrep_layout( + repeated_run, self.config.irreps_out, "ir_mul", "mul_ir" + ) + for name, to_check, ground_truth in [ ("output", repeated_run, test_out) ]: @@ -387,6 +413,13 @@ def test_correctness_backward( reference_implementation=None, high_precision_ref=False, ): + def maybe_transpose_input_for_test_impl(x, irreps): + if self.config.layout == "ir_mul": + return IrrepLayoutUtils.transpose_irrep_layout( + x, irreps, "mul_ir", "ir_mul" + ) + return x + if reference_implementation is None: from openequivariance._torch.E3NNConv import E3NNConv @@ -436,17 +469,35 @@ def test_correctness_backward( test_in1_grad = in1_grad.copy() test_in2_grad = in2_grad.copy() + test_L3_grad = out_grad.copy() + if self.config.layout == "ir_mul": + test_L3_grad = IrrepLayoutUtils.transpose_irrep_layout( + test_L3_grad, self.config.irreps_out, "mul_ir", "ir_mul" + ) + self.backward_cpu( - L1_in=in1.copy(), + L1_in=maybe_transpose_input_for_test_impl( + in1.copy(), self.config.irreps_in1 + ), L1_grad=test_in1_grad, - L2_in=in2.copy(), + L2_in=maybe_transpose_input_for_test_impl( + in2.copy(), self.config.irreps_in2 + ), L2_grad=test_in2_grad, - L3_grad=out_grad.copy(), + L3_grad=test_L3_grad, weights=weights.copy(), weights_grad=test_weights_grad, graph=graph, ) + if self.config.layout == "ir_mul": + test_in1_grad = IrrepLayoutUtils.transpose_irrep_layout( + test_in1_grad, 
self.config.irreps_in1, "ir_mul", "mul_ir" + ) + test_in2_grad = IrrepLayoutUtils.transpose_irrep_layout( + test_in2_grad, self.config.irreps_in2, "ir_mul", "mul_ir" + ) + for name, to_check, ground_truth, threshold in [ ("weight_grad", test_weights_grad, ref_weights_grad, thresh), ("in1_grad", test_in1_grad, ref_in1_grad, thresh), @@ -464,6 +515,13 @@ def test_correctness_double_backward( reference_implementation=None, high_precision_ref=False, ): + def maybe_transpose_input_for_test_impl(tp, x, irreps): + if tp is self and tp.config.layout == "ir_mul": + return IrrepLayoutUtils.transpose_irrep_layout( + x, irreps, "mul_ir", "ir_mul" + ) + return x + buffers = get_random_buffers_double_backward_conv( self.config, graph.node_count, graph.nnz, prng_seed ) @@ -500,17 +558,44 @@ def test_correctness_double_backward( weights_dgrad, not tp.config.shared_weights ) + db_in1 = maybe_transpose_input_for_test_impl(tp, in1, tp.config.irreps_in1) + db_in2 = maybe_transpose_input_for_test_impl(tp, in2, tp.config.irreps_in2) + db_out_grad = out_grad + db_in1_dgrad = in1_dgrad + db_in2_dgrad = in2_dgrad + if tp is self and tp.config.layout == "ir_mul": + db_out_grad = IrrepLayoutUtils.transpose_irrep_layout( + out_grad, tp.config.irreps_out, "mul_ir", "ir_mul" + ) + db_in1_dgrad = IrrepLayoutUtils.transpose_irrep_layout( + in1_dgrad, tp.config.irreps_in1, "mul_ir", "ir_mul" + ) + db_in2_dgrad = IrrepLayoutUtils.transpose_irrep_layout( + in2_dgrad, tp.config.irreps_in2, "mul_ir", "ir_mul" + ) + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( - in1, - in2, - out_grad, + db_in1, + db_in2, + db_out_grad, weights_reordered, weights_dgrad_reordered, - in1_dgrad, - in2_dgrad, + db_in1_dgrad, + db_in2_dgrad, graph, ) + if tp is self and tp.config.layout == "ir_mul": + out_dgrad = IrrepLayoutUtils.transpose_irrep_layout( + out_dgrad, tp.config.irreps_out, "ir_mul", "mul_ir" + ) + in1_grad = IrrepLayoutUtils.transpose_irrep_layout( + in1_grad, tp.config.irreps_in1, 
"ir_mul", "mul_ir" + ) + in2_grad = IrrepLayoutUtils.transpose_irrep_layout( + in2_grad, tp.config.irreps_in2, "ir_mul", "mul_ir" + ) + tensors.append( ( out_dgrad, diff --git a/openequivariance/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py index ca8b4bdd..b3c63e66 100644 --- a/openequivariance/openequivariance/core/LoopUnrollConv.py +++ b/openequivariance/openequivariance/core/LoopUnrollConv.py @@ -1,18 +1,18 @@ -import numpy as np import json -from openequivariance.core.ConvolutionBase import ConvolutionBase +import numpy as np + from openequivariance.core.ComputationSchedule import ( ComputationSchedule, SMEMCapacityException, ) - -from openequivariance.templates.jinja_utils import get_jinja_environment +from openequivariance.core.ConvolutionBase import ConvolutionBase from openequivariance.core.utils import ( - filter_and_analyze_problem, dtype_to_enum, + filter_and_analyze_problem, hash_str_64, ) +from openequivariance.templates.jinja_utils import get_jinja_environment class LoopUnrollConv(ConvolutionBase): @@ -114,9 +114,11 @@ def generate_double_backward_schedule(warps_per_block): except SMEMCapacityException: warp_count -= 1 if warp_count == 0: - raise SMEMCapacityException( + raise RuntimeError( "Tensor product schedule generation failed, shared memory inadequate!" ) + except Exception: + raise if not deterministic: for segment in self.forward_schedule.segments: diff --git a/openequivariance/openequivariance/templates/macros.jinja b/openequivariance/openequivariance/templates/macros.jinja index b4a28158..8622b656 100644 --- a/openequivariance/openequivariance/templates/macros.jinja +++ b/openequivariance/openequivariance/templates/macros.jinja @@ -91,7 +91,7 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. 
{%- set dim = src_mul_ir.ir.dim %} {%- set mul = src_mul_ir.mul %} {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) + ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) {%- endfor %} {%- endfor %} {%- endif %} @@ -113,7 +113,7 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. {%- set dim = src_mul_ir.ir.dim %} {%- set mul = src_mul_ir.mul %} {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) + ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) {%- endfor %} {%- endfor %} {%- endif %} @@ -144,15 +144,15 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. 
{%- set mul = src_mul_ir.mul %} {%- if map.storeback_procedure[idx] == "write" %} {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id];) + ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];) {%- endfor %} {%- elif map.storeback_procedure[idx] == "accumulate" %} {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id];) + ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];) {%- endfor %} {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %} {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id]);) + ROW_OPERATION({{mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id]);) {%- endfor %} {%- endif %} {%- endfor %} diff --git a/tests/conv_test.py b/tests/conv_test.py index 50b6376b..8164a9f9 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -284,6 +284,17 @@ def conv_object(self, request, problem, extra_conv_constructor_args): return module.to(switch_map[problem.irrep_dtype]) +class TestIrMulLayout(ConvCorrectness): + production_model_tpps = mace_problems() + + @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, 
scope="class") + def problem(self, request, dtype): + problem = request.param.clone() + problem.irrep_dtype, problem.weight_dtype = dtype, dtype + problem.layout = "ir_mul" + return problem + + class TestTorchToSubmodule: """Test that TensorProductConv works as a submodule when parent's .to() is called""" From 7dcd8871c38e2897d5d3fb5dfcd01e325611146a Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Mar 2026 22:39:59 -0700 Subject: [PATCH 06/27] More bugfixes. --- openequivariance/openequivariance/core/ComputationSchedule.py | 2 ++ openequivariance/openequivariance/core/utils.py | 4 ++-- tests/batch_test.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index 8cfad757..461ac565 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -301,6 +301,7 @@ def __init__(self, input, mult_threshold): path_normalization="none", internal_weights=False, shared_weights=input.shared_weights, + layout=input.layout, ) assert self.output.weight_numel == input.weight_numel @@ -595,6 +596,7 @@ def calculate_backward_smem( path_normalization="none", internal_weights=False, shared_weights=config.shared_weights, + layout=config.layout, ) weight_offset = 0 diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index af44fb43..2f8d398e 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -240,11 +240,11 @@ def transpose_irrep_layout( if src_layout == "ir_mul" and dst_layout == "mul_ir": out[..., seg.start : seg.stop] = block.reshape( *block.shape[:-1], dim, mul - ).reshape(*block.shape[:-1], mul * dim) + ).swapaxes(-1, -2).reshape(*block.shape[:-1], mul * dim) elif src_layout == "mul_ir" and dst_layout == "ir_mul": out[..., 
seg.start : seg.stop] = block.reshape( *block.shape[:-1], mul, dim - ).reshape(*block.shape[:-1], dim * mul) + ).swapaxes(-1, -2).reshape(*block.shape[:-1], dim * mul) else: raise ValueError( f"Unsupported layout transpose: {src_layout} -> {dst_layout}" diff --git a/tests/batch_test.py b/tests/batch_test.py index 2fbb84cb..069cbf92 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -273,14 +273,14 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): return tp, tp.config -class TestMulIrLayoutMACE(TPCorrectness): +class TestIrMulLayoutMACE(TPCorrectness): production_model_tpps = mace_problems() @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): problem = request.param.clone() problem.irrep_dtype, problem.weight_dtype = dtype, dtype - problem.layout = "mul_ir" + problem.layout = "ir_mul" return problem From 4806748eaa02313dd7491b0cb8018014c9dc1d60 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Mar 2026 22:59:36 -0700 Subject: [PATCH 07/27] Fixed more stuff. 
--- .../openequivariance/core/utils.py | 5 +++ tests/batch_test.py | 42 ++++++++++++++++++- tests/conv_test.py | 25 ++++++++++- 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 2f8d398e..f574ae4c 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -96,6 +96,11 @@ def filter_and_analyze_problem(problem): f"Connection mode must be 'uvu' or 'uvw', got {problem.instructions[0].connection_mode}" ) + if problem.layout == "ir_mul": + assert problem.instructions[0].connection_mode == "uvu", ( + "layout='ir_mul' is only supported for pure 'uvu' problems" + ) + assert problem.irrep_dtype == problem.weight_dtype, ( f"irrep_dtype and weight_dtype must be the same, got {problem.irrep_dtype} and {problem.weight_dtype}" ) diff --git a/tests/batch_test.py b/tests/batch_test.py index 069cbf92..fd00c65e 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -273,8 +273,31 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): return tp, tp.config +def ir_mul_representative_uvu_problems(): + return [ + oeq.TPProblem( + "5x5e", + "1x3e", + "5x5e", + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + label="ir_mul_repr_5x1x5_l535", + ), + oeq.TPProblem( + "13x5e", + "1x3e", + "13x5e", + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + label="ir_mul_repr_13x1x13_l535", + ), + ] + + class TestIrMulLayoutMACE(TPCorrectness): - production_model_tpps = mace_problems() + production_model_tpps = mace_problems() + ir_mul_representative_uvu_problems() @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): @@ -284,6 +307,23 @@ def problem(self, request, dtype): return problem +def test_ir_mul_rejects_uvw_problem(dtype): + problem = oeq.TPProblem( + "5x5e", + "1x3e", + 
"5x5e", + [(0, 0, 0, "uvw", True)], + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + layout="ir_mul", + ) + + with pytest.raises(AssertionError, match="layout='ir_mul'"): + oeq.TensorProduct(problem) + + class TestTorchToSubmodule: """Test that TensorProduct works correctly as a submodule when parent's .to() is called""" diff --git a/tests/conv_test.py b/tests/conv_test.py index 8164a9f9..6ef0a36a 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -284,8 +284,31 @@ def conv_object(self, request, problem, extra_conv_constructor_args): return module.to(switch_map[problem.irrep_dtype]) +def ir_mul_representative_uvu_problems(): + return [ + oeq.TPProblem( + "5x5e", + "1x3e", + "5x5e", + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + label="ir_mul_repr_5x1x5_l535", + ), + oeq.TPProblem( + "13x5e", + "1x3e", + "13x5e", + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + label="ir_mul_repr_13x1x13_l535", + ), + ] + + class TestIrMulLayout(ConvCorrectness): - production_model_tpps = mace_problems() + production_model_tpps = mace_problems() + ir_mul_representative_uvu_problems() @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): From 44ecf0efff8ae498fe961e49bf831edae621d315 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Mar 2026 23:06:33 -0700 Subject: [PATCH 08/27] Compacted everything. 
--- .../_torch/NPDoubleBackwardMixin.py | 24 +++++++++---------- .../openequivariance/_torch/TensorProduct.py | 16 ++++++------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py index 18a8e873..caf94268 100644 --- a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py +++ b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py @@ -36,12 +36,12 @@ def double_backward_cpu( grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], ) - a_np = a.detach().cpu().numpy() - b_np = b.detach().cpu().numpy() - c_np = c.detach().cpu().numpy() - d_np = d.detach().cpu().numpy() - - return (a_np, b_np, c_np, d_np) + return ( + a.detach().cpu().numpy(), + b.detach().cpu().numpy(), + c.detach().cpu().numpy(), + d.detach().cpu().numpy(), + ) class NumpyDoubleBackwardMixinConv: @@ -89,9 +89,9 @@ def double_backward_cpu( grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], ) - a_np = a.detach().cpu().numpy() - b_np = b.detach().cpu().numpy() - c_np = c.detach().cpu().numpy() - d_np = d.detach().cpu().numpy() - - return (a_np, b_np, c_np, d_np) + return ( + a.detach().cpu().numpy(), + b.detach().cpu().numpy(), + c.detach().cpu().numpy(), + d.detach().cpu().numpy(), + ) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 2087f03c..254da414 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -1,17 +1,17 @@ -import numpy as np -import torch - +from openequivariance.core.LoopUnrollTP import LoopUnrollTP from openequivariance import TPProblem from openequivariance._torch import extlib -from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin +import torch +from openequivariance.core.utils import 
torch_to_oeq_dtype, dtype_to_enum +from openequivariance.benchmark.logging_utils import getLogger from openequivariance._torch.utils import ( - enum_to_torch_dtype, reorder_torch, string_to_tensor, + enum_to_torch_dtype, ) -from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.core.LoopUnrollTP import LoopUnrollTP -from openequivariance.core.utils import dtype_to_enum, torch_to_oeq_dtype +from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin + +import numpy as np logger = getLogger() From 49477816ecdccac5e2a0da38e5471cc0f233f24b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 14:01:35 -0700 Subject: [PATCH 09/27] compaction. --- .../benchmark/correctness_utils.py | 251 ++++++++---------- 1 file changed, 111 insertions(+), 140 deletions(-) diff --git a/openequivariance/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py index 91dc4760..07b523db 100644 --- a/openequivariance/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/openequivariance/benchmark/correctness_utils.py @@ -79,47 +79,39 @@ def correctness_forward( reference_implementation = E3NNTensorProduct result = {"thresh": correctness_threshold, "batch_size": batch_size} - in1, in2, weights, out = get_random_buffers_forward(problem, batch_size, prng_seed) + outputs = [] - # run reference (always in mul_ir) - ref_tp = reference_implementation(problem) - - ref_out = out.copy() - ref_tp.forward_cpu( - L1_in=in1.copy(), L2_in=in2.copy(), L3_out=ref_out, weights=weights.copy() - ) - - weights_copy = weights.copy() - if problem.shared_weights and test_implementation == CUETensorProduct: - weights_copy = weights[np.newaxis, :] - - # run test (may require ir_mul conversion) - test_tp = instantiate_implementation(test_implementation, problem) - test_layout = getattr(test_tp.config, "layout", "mul_ir") - - test_in1 = in1.copy() - test_in2 = in2.copy() - 
test_out = out.copy() - - if test_layout == "ir_mul": - test_in1 = IrrepLayoutUtils.transpose_irrep_layout( - test_in1, problem.irreps_in1, "mul_ir", "ir_mul" - ) - test_in2 = IrrepLayoutUtils.transpose_irrep_layout( - test_in2, problem.irreps_in2, "mul_ir", "ir_mul" - ) - - test_tp.forward_cpu( - L1_in=test_in1, L2_in=test_in2, L3_out=test_out, weights=weights_copy - ) + for i, impl in enumerate([test_implementation, reference_implementation]): + is_test_impl = (i == 0) + tp = instantiate_implementation(impl, problem) + uses_cue = impl == CUETensorProduct or isinstance(tp, CUETensorProduct) + run_in1, run_in2, run_weights, run_out = [ buf.copy() for buf in (in1, in2, weights, out) ] + + if problem.shared_weights and uses_cue: + run_weights = run_weights[np.newaxis, :] + + # Transpose inputs, if necessary, for the test implementation + if is_test_impl: + run_in1, run_in2 = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", tp.config.layout + ) for arr, irreps in zip( + (run_in1, run_in2), + (problem.irreps_in1, problem.irreps_in2) + ) + ] + + tp.forward_cpu(L1_in=run_in1, L2_in=run_in2, L3_out=run_out, weights=run_weights) + + if is_test_impl: + run_out = IrrepLayoutUtils.transpose_irrep_layout( + run_out, problem.irreps_out, tp.config.layout, "mul_ir" + ) - if test_layout == "ir_mul": - test_out = IrrepLayoutUtils.transpose_irrep_layout( - test_out, problem.irreps_out, "ir_mul", "mul_ir" - ) + outputs.append(run_out) - for name, to_check, ground_truth in [("output", ref_out, test_out)]: + for name, to_check, ground_truth in [("output", outputs[0], outputs[1])]: result[name] = check_similiarity( name, to_check, ground_truth, correctness_threshold ) @@ -142,73 +134,61 @@ def correctness_backward( result = {"thresh": correctness_threshold, "batch_size": batch_size} - # run reference in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = ( get_random_buffers_backward(problem, batch_size, prng_seed) ) - ref_tp = 
reference_implementation(problem) - - ref_weights_grad = weights_grad.copy() - ref_in1_grad = in1_grad.copy() - ref_in2_grad = in2_grad.copy() - - ref_tp.backward_cpu( - L1_in=in1.copy(), - L1_grad=ref_in1_grad, - L2_in=in2.copy(), - L2_grad=ref_in2_grad, - L3_grad=out_grad.copy(), - weights=weights.copy(), - weights_grad=ref_weights_grad, - ) - - # run test version (may require ir_mul conversion) - test_weights_grad = weights_grad.copy() - test_in1_grad = in1_grad.copy() - test_in2_grad = in2_grad.copy() - - weights_copy = weights.copy() - - if problem.shared_weights and test_implementation == CUETensorProduct: - weights_copy = weights[np.newaxis, :] - test_weights_grad = test_weights_grad[np.newaxis, :] - - test_tp = instantiate_implementation(test_implementation, problem) - test_layout = getattr(test_tp.config, "layout", "mul_ir") - - test_in1 = in1.copy() - test_in2 = in2.copy() - test_L3_grad = out_grad.copy() + grads = [] + for i, impl in enumerate([test_implementation, reference_implementation]): + is_test_impl = i == 0 + tp = instantiate_implementation(impl, problem) - if test_layout == "ir_mul": - test_in1 = IrrepLayoutUtils.transpose_irrep_layout( - test_in1, problem.irreps_in1, "mul_ir", "ir_mul" - ) - test_in2 = IrrepLayoutUtils.transpose_irrep_layout( - test_in2, problem.irreps_in2, "mul_ir", "ir_mul" - ) - test_L3_grad = IrrepLayoutUtils.transpose_irrep_layout( - test_L3_grad, problem.irreps_out, "mul_ir", "ir_mul" + run_in1, run_in2, run_L3_grad, run_weights, run_weights_grad, run_in1_grad, run_in2_grad = [ + buf.copy() + for buf in (in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad) + ] + + uses_cue = impl == CUETensorProduct or isinstance(tp, CUETensorProduct) + if problem.shared_weights and uses_cue: + run_weights = run_weights[np.newaxis, :] + run_weights_grad = run_weights_grad[np.newaxis, :] + + if is_test_impl: + run_in1, run_in2, run_L3_grad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", tp.config.layout + 
) + for arr, irreps in zip( + (run_in1, run_in2, run_L3_grad), + (problem.irreps_in1, problem.irreps_in2, problem.irreps_out), + ) + ] + + tp.backward_cpu( + L1_in=run_in1, + L1_grad=run_in1_grad, + L2_in=run_in2, + L2_grad=run_in2_grad, + L3_grad=run_L3_grad, + weights=run_weights, + weights_grad=run_weights_grad, ) - test_tp.backward_cpu( - L1_in=test_in1, - L1_grad=test_in1_grad, - L2_in=test_in2, - L2_grad=test_in2_grad, - L3_grad=test_L3_grad, - weights=weights_copy, - weights_grad=test_weights_grad, - ) + if is_test_impl: + run_in1_grad, run_in2_grad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, tp.config.layout, "mul_ir" + ) + for arr, irreps in zip( + (run_in1_grad, run_in2_grad), + (problem.irreps_in1, problem.irreps_in2), + ) + ] - if test_layout == "ir_mul": - test_in1_grad = IrrepLayoutUtils.transpose_irrep_layout( - test_in1_grad, problem.irreps_in1, "ir_mul", "mul_ir" - ) - test_in2_grad = IrrepLayoutUtils.transpose_irrep_layout( - test_in2_grad, problem.irreps_in2, "ir_mul", "mul_ir" - ) + if problem.shared_weights: + run_weights_grad = run_weights_grad.squeeze() + + grads.append((run_weights_grad, run_in1_grad, run_in2_grad)) weight_threshold = ( correctness_threshold * batch_size @@ -216,13 +196,10 @@ def correctness_backward( else correctness_threshold ) - if problem.shared_weights: - test_weights_grad = test_weights_grad.squeeze() - for name, to_check, ground_truth, threshold in [ - ("weight_grad", test_weights_grad, ref_weights_grad, weight_threshold), - ("in1_grad", test_in1_grad, ref_in1_grad, correctness_threshold), - ("in2_grad", test_in2_grad, ref_in2_grad, correctness_threshold), + ("weight_grad", grads[0][0], grads[1][0], weight_threshold), + ("in1_grad", grads[0][1], grads[1][1], correctness_threshold), + ("in2_grad", grads[0][2], grads[1][2], correctness_threshold), ]: result[name] = check_similiarity(name, to_check, ground_truth, threshold) @@ -254,9 +231,8 @@ def correctness_double_backward( result = {"thresh": 
correctness_threshold, "batch_size": batch_size} tensors = [] - for is_test_impl, impl in enumerate( - [test_implementation, reference_implementation] - ): + for i, impl in enumerate([test_implementation, reference_implementation]): + is_test_impl = i == 0 tp = instantiate_implementation(impl, problem) weights_reordered = tp.reorder_weights_from_e3nn( weights, has_batch_dim=not problem.shared_weights @@ -268,31 +244,26 @@ def correctness_double_backward( if impl == CUETensorProduct and problem.shared_weights: weights_reordered = weights_reordered[np.newaxis, :] - tp_layout = getattr(tp.config, "layout", "mul_ir") - apply_test_layout = is_test_impl == 0 and tp_layout == "ir_mul" - - db_in1 = in1 - db_in2 = in2 - db_out_grad = out_grad - db_in1_dgrad = in1_dgrad - db_in2_dgrad = in2_dgrad - - if apply_test_layout: - db_in1 = IrrepLayoutUtils.transpose_irrep_layout( - in1, problem.irreps_in1, "mul_ir", "ir_mul" - ) - db_in2 = IrrepLayoutUtils.transpose_irrep_layout( - in2, problem.irreps_in2, "mul_ir", "ir_mul" - ) - db_out_grad = IrrepLayoutUtils.transpose_irrep_layout( - out_grad, problem.irreps_out, "mul_ir", "ir_mul" - ) - db_in1_dgrad = IrrepLayoutUtils.transpose_irrep_layout( - in1_dgrad, problem.irreps_in1, "mul_ir", "ir_mul" - ) - db_in2_dgrad = IrrepLayoutUtils.transpose_irrep_layout( - in2_dgrad, problem.irreps_in2, "mul_ir", "ir_mul" - ) + db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ + buf.copy() for buf in (in1, in2, out_grad, in1_dgrad, in2_dgrad) + ] + + if is_test_impl: + db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", tp.config.layout + ) + for arr, irreps in zip( + (db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad), + ( + problem.irreps_in1, + problem.irreps_in2, + problem.irreps_out, + problem.irreps_in1, + problem.irreps_in2, + ), + ) + ] in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( db_in1, @@ -304,16 +275,16 @@ def 
correctness_double_backward( db_in2_dgrad, ) - if apply_test_layout: - out_dgrad = IrrepLayoutUtils.transpose_irrep_layout( - out_dgrad, problem.irreps_out, "ir_mul", "mul_ir" - ) - in1_grad = IrrepLayoutUtils.transpose_irrep_layout( - in1_grad, problem.irreps_in1, "ir_mul", "mul_ir" - ) - in2_grad = IrrepLayoutUtils.transpose_irrep_layout( - in2_grad, problem.irreps_in2, "ir_mul", "mul_ir" - ) + if is_test_impl: + out_dgrad, in1_grad, in2_grad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, tp.config.layout, "mul_ir" + ) + for arr, irreps in zip( + (out_dgrad, in1_grad, in2_grad), + (problem.irreps_out, problem.irreps_in1, problem.irreps_in2), + ) + ] tensors.append( ( From 596874576f4bf6fdd54c5b44d55d6e298d1c5a14 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 14:12:15 -0700 Subject: [PATCH 10/27] File renaming. --- openequivariance/openequivariance/_torch/CUETensorProduct.py | 2 +- openequivariance/openequivariance/_torch/E3NNTensorProduct.py | 2 +- openequivariance/openequivariance/_torch/TensorProduct.py | 2 +- openequivariance/openequivariance/_torch/TensorProductConv.py | 2 +- openequivariance/openequivariance/_torch/extlib/__init__.py | 2 +- .../openequivariance/benchmark/ConvBenchmarkSuite.py | 2 +- .../openequivariance/benchmark/TestBenchmarkSuite.py | 4 ++-- .../openequivariance/benchmark/benchmark_utils.py | 2 +- .../benchmark/{correctness_utils.py => correctness.py} | 2 +- .../benchmark/{logging_utils.py => logging.py} | 0 .../openequivariance/benchmark/perf_metrics_utils.py | 2 +- openequivariance/openequivariance/core/ComputationSchedule.py | 2 +- openequivariance/openequivariance/core/ConvolutionBase.py | 4 ++-- openequivariance/openequivariance/core/LoopUnrollTP.py | 2 +- openequivariance/openequivariance/core/TensorProductBase.py | 2 +- openequivariance/openequivariance/jax/TensorProductConv.py | 2 +- tests/batch_test.py | 2 +- tests/benchmark.py | 2 +- 18 files changed, 19 insertions(+), 19 deletions(-) rename 
openequivariance/openequivariance/benchmark/{correctness_utils.py => correctness.py} (99%) rename openequivariance/openequivariance/benchmark/{logging_utils.py => logging.py} (100%) diff --git a/openequivariance/openequivariance/_torch/CUETensorProduct.py b/openequivariance/openequivariance/_torch/CUETensorProduct.py index 33b8db12..241b1df6 100644 --- a/openequivariance/openequivariance/_torch/CUETensorProduct.py +++ b/openequivariance/openequivariance/_torch/CUETensorProduct.py @@ -6,7 +6,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.benchmark.tpp_creation_utils import ( ChannelwiseTPP, FullyConnectedTPProblem, diff --git a/openequivariance/openequivariance/_torch/E3NNTensorProduct.py b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py index 067a7e6b..cacbc017 100644 --- a/openequivariance/openequivariance/_torch/E3NNTensorProduct.py +++ b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py @@ -11,7 +11,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 254da414..c00530d5 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -3,7 +3,7 @@ from openequivariance._torch import extlib import torch from openequivariance.core.utils import torch_to_oeq_dtype, 
dtype_to_enum -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance._torch.utils import ( reorder_torch, string_to_tensor, diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index 30931151..e522c0e8 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -23,7 +23,7 @@ enum_to_torch_dtype, ) -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv logger = getLogger() diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index be4113ec..81f243cf 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -8,7 +8,7 @@ import torch -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger oeq_root = str(Path(__file__).parent.parent.parent) diff --git a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py index debcc65b..2e9a329e 100644 --- a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -6,7 +6,7 @@ import numpy as np import openequivariance as oeq -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.core.ConvolutionBase import CoordGraph from openequivariance.benchmark.benchmark_utils import NpEncoder diff --git 
a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py index 37d20c46..c18bde00 100644 --- a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py @@ -10,9 +10,9 @@ from openequivariance._torch.extlib import DeviceProp from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging_utils import getLogger, bcolors +from openequivariance.benchmark.logging import getLogger, bcolors from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.correctness_utils import ( +from openequivariance.benchmark.correctness import ( correctness_forward, correctness_backward, correctness_double_backward, diff --git a/openequivariance/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py index 68dc6f9f..31eee40c 100644 --- a/openequivariance/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py @@ -15,7 +15,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem from openequivariance._torch.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.logging_utils import getLogger, bcolors +from openequivariance.benchmark.logging import getLogger, bcolors logger = getLogger() diff --git a/openequivariance/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness.py similarity index 99% rename from openequivariance/openequivariance/benchmark/correctness_utils.py rename to openequivariance/openequivariance/benchmark/correctness.py index 07b523db..79f93814 100644 --- a/openequivariance/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/openequivariance/benchmark/correctness.py @@ -4,7 
+4,7 @@ import numpy.linalg as la from openequivariance._torch.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.logging_utils import bcolors, getLogger +from openequivariance.benchmark.logging import bcolors, getLogger from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_backward, get_random_buffers_double_backward, diff --git a/openequivariance/openequivariance/benchmark/logging_utils.py b/openequivariance/openequivariance/benchmark/logging.py similarity index 100% rename from openequivariance/openequivariance/benchmark/logging_utils.py rename to openequivariance/openequivariance/benchmark/logging.py diff --git a/openequivariance/openequivariance/benchmark/perf_metrics_utils.py b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py index 212f05f4..01bd5836 100644 --- a/openequivariance/openequivariance/benchmark/perf_metrics_utils.py +++ b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py @@ -6,7 +6,7 @@ ) from openequivariance.core.e3nn_lite import TPProblem, wigner_3j -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger import numpy as np logger = getLogger() diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index 461ac565..57094706 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -2,7 +2,7 @@ import numpy as np -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j logger = getLogger() diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index ec5f5905..79bc3759 100644 --- 
a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -2,8 +2,8 @@ import numpy as np -from openequivariance.benchmark.correctness_utils import check_similiarity -from openequivariance.benchmark.logging_utils import bcolors, getLogger +from openequivariance.benchmark.correctness import check_similiarity +from openequivariance.benchmark.logging import bcolors, getLogger from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_backward_conv, get_random_buffers_double_backward_conv, diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 43c9e244..e07387b0 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -2,7 +2,7 @@ import numpy as np -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.core.ComputationSchedule import ( ComputationSchedule, SMEMCapacityException, diff --git a/openequivariance/openequivariance/core/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py index b5d3831f..fa065b7b 100644 --- a/openequivariance/openequivariance/core/TensorProductBase.py +++ b/openequivariance/openequivariance/core/TensorProductBase.py @@ -1,7 +1,7 @@ import numpy as np from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.core.utils import benchmark logger = getLogger() diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index c14637a1..4645ae86 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ 
-9,7 +9,7 @@ from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance.jax.utils import reorder_jax -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.jax.jvp import conv_prim from openequivariance.jax.vjp import conv_func diff --git a/tests/batch_test.py b/tests/batch_test.py index fd00c65e..dcec47a4 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -3,7 +3,7 @@ import numpy as np import pytest import torch -from openequivariance.benchmark.correctness_utils import ( +from openequivariance.benchmark.correctness import ( correctness_backward, correctness_double_backward, correctness_forward, diff --git a/tests/benchmark.py b/tests/benchmark.py index 829cc46c..441a3ca0 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -10,7 +10,7 @@ import numpy as np -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance._torch.extlib import DeviceProp from openequivariance._torch.E3NNTensorProduct import ( E3NNTensorProduct, From 161a1b649ecffb5e3e1d34a14054b0856816caed Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 14:44:55 -0700 Subject: [PATCH 11/27] Revert "File renaming." This reverts commit 596874576f4bf6fdd54c5b44d55d6e298d1c5a14. 
--- openequivariance/openequivariance/_torch/CUETensorProduct.py | 2 +- openequivariance/openequivariance/_torch/E3NNTensorProduct.py | 2 +- openequivariance/openequivariance/_torch/TensorProduct.py | 2 +- openequivariance/openequivariance/_torch/TensorProductConv.py | 2 +- openequivariance/openequivariance/_torch/extlib/__init__.py | 2 +- .../openequivariance/benchmark/ConvBenchmarkSuite.py | 2 +- .../openequivariance/benchmark/TestBenchmarkSuite.py | 4 ++-- .../openequivariance/benchmark/benchmark_utils.py | 2 +- .../benchmark/{correctness.py => correctness_utils.py} | 2 +- .../benchmark/{logging.py => logging_utils.py} | 0 .../openequivariance/benchmark/perf_metrics_utils.py | 2 +- openequivariance/openequivariance/core/ComputationSchedule.py | 2 +- openequivariance/openequivariance/core/ConvolutionBase.py | 4 ++-- openequivariance/openequivariance/core/LoopUnrollTP.py | 2 +- openequivariance/openequivariance/core/TensorProductBase.py | 2 +- openequivariance/openequivariance/jax/TensorProductConv.py | 2 +- tests/batch_test.py | 2 +- tests/benchmark.py | 2 +- 18 files changed, 19 insertions(+), 19 deletions(-) rename openequivariance/openequivariance/benchmark/{correctness.py => correctness_utils.py} (99%) rename openequivariance/openequivariance/benchmark/{logging.py => logging_utils.py} (100%) diff --git a/openequivariance/openequivariance/_torch/CUETensorProduct.py b/openequivariance/openequivariance/_torch/CUETensorProduct.py index 241b1df6..33b8db12 100644 --- a/openequivariance/openequivariance/_torch/CUETensorProduct.py +++ b/openequivariance/openequivariance/_torch/CUETensorProduct.py @@ -6,7 +6,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger from openequivariance.benchmark.tpp_creation_utils import ( ChannelwiseTPP, FullyConnectedTPProblem, diff --git 
a/openequivariance/openequivariance/_torch/E3NNTensorProduct.py b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py index cacbc017..067a7e6b 100644 --- a/openequivariance/openequivariance/_torch/E3NNTensorProduct.py +++ b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py @@ -11,7 +11,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index c00530d5..254da414 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -3,7 +3,7 @@ from openequivariance._torch import extlib import torch from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger from openequivariance._torch.utils import ( reorder_torch, string_to_tensor, diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index e522c0e8..30931151 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -23,7 +23,7 @@ enum_to_torch_dtype, ) -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv logger = getLogger() diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py 
b/openequivariance/openequivariance/_torch/extlib/__init__.py index 81f243cf..be4113ec 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -8,7 +8,7 @@ import torch -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger oeq_root = str(Path(__file__).parent.parent.parent) diff --git a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py index 2e9a329e..debcc65b 100644 --- a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -6,7 +6,7 @@ import numpy as np import openequivariance as oeq -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger from openequivariance.core.ConvolutionBase import CoordGraph from openequivariance.benchmark.benchmark_utils import NpEncoder diff --git a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py index c18bde00..37d20c46 100644 --- a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py @@ -10,9 +10,9 @@ from openequivariance._torch.extlib import DeviceProp from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging import getLogger, bcolors +from openequivariance.benchmark.logging_utils import getLogger, bcolors from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.correctness import ( +from openequivariance.benchmark.correctness_utils import ( correctness_forward, correctness_backward, correctness_double_backward, diff --git a/openequivariance/openequivariance/benchmark/benchmark_utils.py 
b/openequivariance/openequivariance/benchmark/benchmark_utils.py index 31eee40c..68dc6f9f 100644 --- a/openequivariance/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py @@ -15,7 +15,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem from openequivariance._torch.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.logging import getLogger, bcolors +from openequivariance.benchmark.logging_utils import getLogger, bcolors logger = getLogger() diff --git a/openequivariance/openequivariance/benchmark/correctness.py b/openequivariance/openequivariance/benchmark/correctness_utils.py similarity index 99% rename from openequivariance/openequivariance/benchmark/correctness.py rename to openequivariance/openequivariance/benchmark/correctness_utils.py index 79f93814..07b523db 100644 --- a/openequivariance/openequivariance/benchmark/correctness.py +++ b/openequivariance/openequivariance/benchmark/correctness_utils.py @@ -4,7 +4,7 @@ import numpy.linalg as la from openequivariance._torch.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.logging import bcolors, getLogger +from openequivariance.benchmark.logging_utils import bcolors, getLogger from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_backward, get_random_buffers_double_backward, diff --git a/openequivariance/openequivariance/benchmark/logging.py b/openequivariance/openequivariance/benchmark/logging_utils.py similarity index 100% rename from openequivariance/openequivariance/benchmark/logging.py rename to openequivariance/openequivariance/benchmark/logging_utils.py diff --git a/openequivariance/openequivariance/benchmark/perf_metrics_utils.py b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py index 01bd5836..212f05f4 100644 --- 
a/openequivariance/openequivariance/benchmark/perf_metrics_utils.py +++ b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py @@ -6,7 +6,7 @@ ) from openequivariance.core.e3nn_lite import TPProblem, wigner_3j -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger import numpy as np logger = getLogger() diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index 57094706..461ac565 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -2,7 +2,7 @@ import numpy as np -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j logger = getLogger() diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index 79bc3759..ec5f5905 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -2,8 +2,8 @@ import numpy as np -from openequivariance.benchmark.correctness import check_similiarity -from openequivariance.benchmark.logging import bcolors, getLogger +from openequivariance.benchmark.correctness_utils import check_similiarity +from openequivariance.benchmark.logging_utils import bcolors, getLogger from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_backward_conv, get_random_buffers_double_backward_conv, diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index e07387b0..43c9e244 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -2,7 +2,7 @@ import numpy 
as np -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger from openequivariance.core.ComputationSchedule import ( ComputationSchedule, SMEMCapacityException, diff --git a/openequivariance/openequivariance/core/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py index fa065b7b..b5d3831f 100644 --- a/openequivariance/openequivariance/core/TensorProductBase.py +++ b/openequivariance/openequivariance/core/TensorProductBase.py @@ -1,7 +1,7 @@ import numpy as np from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger from openequivariance.core.utils import benchmark logger = getLogger() diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 4645ae86..c14637a1 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -9,7 +9,7 @@ from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance.jax.utils import reorder_jax -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger from openequivariance.jax.jvp import conv_prim from openequivariance.jax.vjp import conv_func diff --git a/tests/batch_test.py b/tests/batch_test.py index dcec47a4..fd00c65e 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -3,7 +3,7 @@ import numpy as np import pytest import torch -from openequivariance.benchmark.correctness import ( +from openequivariance.benchmark.correctness_utils import ( correctness_backward, correctness_double_backward, correctness_forward, diff --git a/tests/benchmark.py b/tests/benchmark.py index 441a3ca0..829cc46c 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -10,7 +10,7 @@ import 
numpy as np -from openequivariance.benchmark.logging import getLogger +from openequivariance.benchmark.logging_utils import getLogger from openequivariance._torch.extlib import DeviceProp from openequivariance._torch.E3NNTensorProduct import ( E3NNTensorProduct, From 7fd7ab6dbd32c49494c7c2236342e41bf3757ea4 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 15:26:41 -0700 Subject: [PATCH 12/27] Fixed a regression. --- .../openequivariance/templates/loop_unroll_tp.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openequivariance/openequivariance/templates/loop_unroll_tp.cuh b/openequivariance/openequivariance/templates/loop_unroll_tp.cuh index a56fb4b1..e4c8dabe 100644 --- a/openequivariance/openequivariance/templates/loop_unroll_tp.cuh +++ b/openequivariance/openequivariance/templates/loop_unroll_tp.cuh @@ -1,4 +1,4 @@ -{%- from 'macros.jinja' import layout_load, layout_store with context %} +{%- from 'macros.jinja' import layout_load, layout_store, reg_store with context %} {%- from 'wmm.cuh' import generate_matmul %} {%- macro generate_segment_kernel_forward(id, segment, warp_size) %} @@ -248,7 +248,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ {%- endif %} {%- endfor %} - {{ layout_store(problem.layout, L1[u].mul, L3[w].ir.dim, "scratch", "0", "l3_grad", "=", 1.0) }} + {{ reg_store(L1[u].mul, L3[w].ir.dim, "scratch", "0", "l3_grad", "=", 1.0) }} __syncwarp(); {{matmul_basename}}B_{{id}}_{{k}}(L3_grad_smem + offset, scratch, weights_smem); From 3c9ed29d917414f09e768fe9d27ec82db8d18398 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 15:47:21 -0700 Subject: [PATCH 13/27] More compaction. 
--- .../openequivariance/core/ConvolutionBase.py | 170 +++++++++--------- tests/conv_test.py | 9 +- 2 files changed, 89 insertions(+), 90 deletions(-) diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index ec5f5905..84f4445c 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -144,13 +144,6 @@ def test_correctness_forward( check_reproducible=True, high_precision_ref=False, ): - def maybe_transpose_input_for_test_impl(x, irreps): - if self.config.layout == "ir_mul": - return IrrepLayoutUtils.transpose_irrep_layout( - x, irreps, "mul_ir", "ir_mul" - ) - return x - if reference_implementation is None: from openequivariance._torch.E3NNConv import E3NNConv @@ -192,23 +185,29 @@ def maybe_transpose_input_for_test_impl(x, irreps): ref_out[:] = ref_tp.forward(**args).cpu().numpy() - test_out = out.copy() + run_in1, run_in2, run_weights, test_out = [ + buf.copy() for buf in (in1, in2, weights, out) + ] + run_in1, run_in2 = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", self.config.layout + ) + for arr, irreps in zip( + (run_in1, run_in2), + (self.config.irreps_in1, self.config.irreps_in2), + ) + ] self.forward_cpu( - L1_in=maybe_transpose_input_for_test_impl( - in1.copy(), self.config.irreps_in1 - ), - L2_in=maybe_transpose_input_for_test_impl( - in2.copy(), self.config.irreps_in2 - ), - weights=weights.copy(), + L1_in=run_in1, + L2_in=run_in2, + weights=run_weights, L3_out=test_out, graph=graph, ) - if self.config.layout == "ir_mul": - test_out = IrrepLayoutUtils.transpose_irrep_layout( - test_out, self.config.irreps_out, "ir_mul", "mul_ir" - ) + test_out = IrrepLayoutUtils.transpose_irrep_layout( + test_out, self.config.irreps_out, self.config.layout, "mul_ir" + ) for name, to_check, ground_truth in [("output", ref_out, test_out)]: result[name] = check_similiarity(name, to_check, 
ground_truth, thresh) @@ -221,22 +220,29 @@ def maybe_transpose_input_for_test_impl(x, irreps): for i in range(num_trials): repeated_run = out.copy() + rep_in1, rep_in2, rep_weights = [ + buf.copy() for buf in (in1, in2, weights) + ] + rep_in1, rep_in2 = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", self.config.layout + ) + for arr, irreps in zip( + (rep_in1, rep_in2), + (self.config.irreps_in1, self.config.irreps_in2), + ) + ] self.forward_cpu( - L1_in=maybe_transpose_input_for_test_impl( - in1.copy(), self.config.irreps_in1 - ), - L2_in=maybe_transpose_input_for_test_impl( - in2.copy(), self.config.irreps_in2 - ), - weights=weights.copy(), + L1_in=rep_in1, + L2_in=rep_in2, + weights=rep_weights, L3_out=repeated_run, graph=graph, ) - if self.config.layout == "ir_mul": - repeated_run = IrrepLayoutUtils.transpose_irrep_layout( - repeated_run, self.config.irreps_out, "ir_mul", "mul_ir" - ) + repeated_run = IrrepLayoutUtils.transpose_irrep_layout( + repeated_run, self.config.irreps_out, self.config.layout, "mul_ir" + ) for name, to_check, ground_truth in [ ("output", repeated_run, test_out) @@ -413,13 +419,6 @@ def test_correctness_backward( reference_implementation=None, high_precision_ref=False, ): - def maybe_transpose_input_for_test_impl(x, irreps): - if self.config.layout == "ir_mul": - return IrrepLayoutUtils.transpose_irrep_layout( - x, irreps, "mul_ir", "ir_mul" - ) - return x - if reference_implementation is None: from openequivariance._torch.E3NNConv import E3NNConv @@ -469,20 +468,23 @@ def maybe_transpose_input_for_test_impl(x, irreps): test_in1_grad = in1_grad.copy() test_in2_grad = in2_grad.copy() - test_L3_grad = out_grad.copy() - if self.config.layout == "ir_mul": - test_L3_grad = IrrepLayoutUtils.transpose_irrep_layout( - test_L3_grad, self.config.irreps_out, "mul_ir", "ir_mul" + test_in1, test_in2, test_L3_grad = [ + buf.copy() for buf in (in1, in2, out_grad) + ] + test_in1, test_in2, test_L3_grad = [ + 
IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", self.config.layout + ) + for arr, irreps in zip( + (test_in1, test_in2, test_L3_grad), + (self.config.irreps_in1, self.config.irreps_in2, self.config.irreps_out), ) + ] self.backward_cpu( - L1_in=maybe_transpose_input_for_test_impl( - in1.copy(), self.config.irreps_in1 - ), + L1_in=test_in1, L1_grad=test_in1_grad, - L2_in=maybe_transpose_input_for_test_impl( - in2.copy(), self.config.irreps_in2 - ), + L2_in=test_in2, L2_grad=test_in2_grad, L3_grad=test_L3_grad, weights=weights.copy(), @@ -490,13 +492,15 @@ def maybe_transpose_input_for_test_impl(x, irreps): graph=graph, ) - if self.config.layout == "ir_mul": - test_in1_grad = IrrepLayoutUtils.transpose_irrep_layout( - test_in1_grad, self.config.irreps_in1, "ir_mul", "mul_ir" + test_in1_grad, test_in2_grad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, self.config.layout, "mul_ir" ) - test_in2_grad = IrrepLayoutUtils.transpose_irrep_layout( - test_in2_grad, self.config.irreps_in2, "ir_mul", "mul_ir" + for arr, irreps in zip( + (test_in1_grad, test_in2_grad), + (self.config.irreps_in1, self.config.irreps_in2), ) + ] for name, to_check, ground_truth, threshold in [ ("weight_grad", test_weights_grad, ref_weights_grad, thresh), @@ -515,13 +519,6 @@ def test_correctness_double_backward( reference_implementation=None, high_precision_ref=False, ): - def maybe_transpose_input_for_test_impl(tp, x, irreps): - if tp is self and tp.config.layout == "ir_mul": - return IrrepLayoutUtils.transpose_irrep_layout( - x, irreps, "mul_ir", "ir_mul" - ) - return x - buffers = get_random_buffers_double_backward_conv( self.config, graph.node_count, graph.nnz, prng_seed ) @@ -542,6 +539,7 @@ def maybe_transpose_input_for_test_impl(tp, x, irreps): result = {"thresh": thresh} tensors = [] for i, tp in enumerate([self, reference_tp]): + is_test_impl = i == 0 buffers_copy = [buf.copy() for buf in buffers] if i == 1 and high_precision_ref: @@ -558,21 +556,25 @@ def 
maybe_transpose_input_for_test_impl(tp, x, irreps): weights_dgrad, not tp.config.shared_weights ) - db_in1 = maybe_transpose_input_for_test_impl(tp, in1, tp.config.irreps_in1) - db_in2 = maybe_transpose_input_for_test_impl(tp, in2, tp.config.irreps_in2) - db_out_grad = out_grad - db_in1_dgrad = in1_dgrad - db_in2_dgrad = in2_dgrad - if tp is self and tp.config.layout == "ir_mul": - db_out_grad = IrrepLayoutUtils.transpose_irrep_layout( - out_grad, tp.config.irreps_out, "mul_ir", "ir_mul" - ) - db_in1_dgrad = IrrepLayoutUtils.transpose_irrep_layout( - in1_dgrad, tp.config.irreps_in1, "mul_ir", "ir_mul" - ) - db_in2_dgrad = IrrepLayoutUtils.transpose_irrep_layout( - in2_dgrad, tp.config.irreps_in2, "mul_ir", "ir_mul" - ) + db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ + buf.copy() for buf in (in1, in2, out_grad, in1_dgrad, in2_dgrad) + ] + if is_test_impl: + db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", tp.config.layout + ) + for arr, irreps in zip( + (db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad), + ( + tp.config.irreps_in1, + tp.config.irreps_in2, + tp.config.irreps_out, + tp.config.irreps_in1, + tp.config.irreps_in2, + ), + ) + ] in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( db_in1, @@ -585,16 +587,16 @@ def maybe_transpose_input_for_test_impl(tp, x, irreps): graph, ) - if tp is self and tp.config.layout == "ir_mul": - out_dgrad = IrrepLayoutUtils.transpose_irrep_layout( - out_dgrad, tp.config.irreps_out, "ir_mul", "mul_ir" - ) - in1_grad = IrrepLayoutUtils.transpose_irrep_layout( - in1_grad, tp.config.irreps_in1, "ir_mul", "mul_ir" - ) - in2_grad = IrrepLayoutUtils.transpose_irrep_layout( - in2_grad, tp.config.irreps_in2, "ir_mul", "mul_ir" - ) + if is_test_impl: + out_dgrad, in1_grad, in2_grad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, tp.config.layout, "mul_ir" + ) + for arr, irreps in zip( + (out_dgrad, in1_grad, 
in2_grad), + (tp.config.irreps_out, tp.config.irreps_in1, tp.config.irreps_in2), + ) + ] tensors.append( ( diff --git a/tests/conv_test.py b/tests/conv_test.py index 6ef0a36a..d61605d0 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -284,8 +284,9 @@ def conv_object(self, request, problem, extra_conv_constructor_args): return module.to(switch_map[problem.irrep_dtype]) -def ir_mul_representative_uvu_problems(): - return [ +class TestIrMulLayout(ConvCorrectness): + production_model_tpps = mace_problems() + \ + [ oeq.TPProblem( "5x5e", "1x3e", @@ -306,10 +307,6 @@ def ir_mul_representative_uvu_problems(): ), ] - -class TestIrMulLayout(ConvCorrectness): - production_model_tpps = mace_problems() + ir_mul_representative_uvu_problems() - @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): problem = request.param.clone() From 95783ef6950b91dbf57982a169ff43f696a8ab3c Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 15:51:07 -0700 Subject: [PATCH 14/27] More test cleaning. 
--- tests/batch_test.py | 29 ++++------------------------- tests/input_validation_test.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/tests/batch_test.py b/tests/batch_test.py index fd00c65e..1df1fb6b 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -272,9 +272,9 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): tp.to(switch_map[problem.irrep_dtype]) return tp, tp.config - -def ir_mul_representative_uvu_problems(): - return [ +class TestIrMulLayoutMACE(TPCorrectness): + production_model_tpps = mace_problems() + \ + [ oeq.TPProblem( "5x5e", "1x3e", @@ -293,11 +293,7 @@ def ir_mul_representative_uvu_problems(): internal_weights=False, label="ir_mul_repr_13x1x13_l535", ), - ] - - -class TestIrMulLayoutMACE(TPCorrectness): - production_model_tpps = mace_problems() + ir_mul_representative_uvu_problems() + ] @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): @@ -307,23 +303,6 @@ def problem(self, request, dtype): return problem -def test_ir_mul_rejects_uvw_problem(dtype): - problem = oeq.TPProblem( - "5x5e", - "1x3e", - "5x5e", - [(0, 0, 0, "uvw", True)], - shared_weights=False, - internal_weights=False, - irrep_dtype=dtype, - weight_dtype=dtype, - layout="ir_mul", - ) - - with pytest.raises(AssertionError, match="layout='ir_mul'"): - oeq.TensorProduct(problem) - - class TestTorchToSubmodule: """Test that TensorProduct works correctly as a submodule when parent's .to() is called""" diff --git a/tests/input_validation_test.py b/tests/input_validation_test.py index 8db1d2a4..bbf39e73 100644 --- a/tests/input_validation_test.py +++ b/tests/input_validation_test.py @@ -138,3 +138,20 @@ def test_cpp_checks_forward_dtype(executable_and_buffers, subtests): with pytest.raises(RuntimeError, match=r"Dtype mismatch"): buffers[i] = buffers[i].to(dtype=torch.bfloat16) executable(*buffers) + + +def test_ir_mul_rejects_uvw_problem(dtype): 
+ problem = TPProblem( + "5x5e", + "1x3e", + "5x5e", + [(0, 0, 0, "uvw", True)], + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + layout="ir_mul", + ) + + with pytest.raises(AssertionError, match="layout='ir_mul'"): + TensorProduct(problem) \ No newline at end of file From 9de117e94a1520121c992ac47343e0baca40205a Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 16:58:50 -0700 Subject: [PATCH 15/27] More refactoring. --- .../_torch/CUETensorProduct.py | 2 +- .../_torch/E3NNTensorProduct.py | 2 +- .../openequivariance/_torch/TensorProduct.py | 2 +- .../_torch/TensorProductConv.py | 2 +- .../_torch/extlib/__init__.py | 2 +- .../benchmark/ConvBenchmarkSuite.py | 18 +- .../benchmark/TestBenchmarkSuite.py | 4 +- .../benchmark/benchmark_utils.py | 2 +- .../{correctness_utils.py => correctness.py} | 332 +++++++++++++++++- .../{logging_utils.py => logging.py} | 0 .../benchmark/perf_metrics_utils.py | 2 +- .../core/ComputationSchedule.py | 2 +- .../openequivariance/core/ConvolutionBase.py | 330 +---------------- .../openequivariance/core/LoopUnrollTP.py | 2 +- .../core/TensorProductBase.py | 2 +- .../openequivariance/jax/TensorProductConv.py | 2 +- .../openequivariance/templates/macros.jinja | 53 ++- tests/batch_test.py | 2 +- tests/benchmark.py | 2 +- tests/conv_test.py | 14 +- 20 files changed, 396 insertions(+), 381 deletions(-) rename openequivariance/openequivariance/benchmark/{correctness_utils.py => correctness.py} (50%) rename openequivariance/openequivariance/benchmark/{logging_utils.py => logging.py} (100%) diff --git a/openequivariance/openequivariance/_torch/CUETensorProduct.py b/openequivariance/openequivariance/_torch/CUETensorProduct.py index 33b8db12..241b1df6 100644 --- a/openequivariance/openequivariance/_torch/CUETensorProduct.py +++ b/openequivariance/openequivariance/_torch/CUETensorProduct.py @@ -6,7 +6,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from 
openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.benchmark.tpp_creation_utils import ( ChannelwiseTPP, FullyConnectedTPProblem, diff --git a/openequivariance/openequivariance/_torch/E3NNTensorProduct.py b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py index 067a7e6b..cacbc017 100644 --- a/openequivariance/openequivariance/_torch/E3NNTensorProduct.py +++ b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py @@ -11,7 +11,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 254da414..c00530d5 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -3,7 +3,7 @@ from openequivariance._torch import extlib import torch from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance._torch.utils import ( reorder_torch, string_to_tensor, diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index 30931151..e522c0e8 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -23,7 +23,7 @@ enum_to_torch_dtype, ) -from 
openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv logger = getLogger() diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index be4113ec..81f243cf 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -8,7 +8,7 @@ import torch -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger oeq_root = str(Path(__file__).parent.parent.parent) diff --git a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py index debcc65b..7f602e04 100644 --- a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -6,7 +6,12 @@ import numpy as np import openequivariance as oeq -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.correctness import ( + correctness_backward_conv, + correctness_double_backward_conv, + correctness_forward_conv, +) +from openequivariance.benchmark.logging import getLogger from openequivariance.core.ConvolutionBase import CoordGraph from openequivariance.benchmark.benchmark_utils import NpEncoder @@ -90,7 +95,8 @@ def run( if direction == "forward": if correctness: - correctness = conv.test_correctness_forward( + correctness = correctness_forward_conv( + conv, graph, thresh=self.correctness_threshold, prng_seed=self.prng_seed, @@ -105,7 +111,8 @@ def run( if direction == "backward": if correctness: - correctness = conv.test_correctness_backward( + correctness = correctness_backward_conv( + conv, graph, thresh=self.correctness_threshold, prng_seed=self.prng_seed, @@ -120,8 
+127,9 @@ def run( if direction == "double_backward": if correctness: - correctness = conv.test_correctness_double_backward( - self.graph, + correctness = correctness_double_backward_conv( + conv, + graph, thresh=self.correctness_threshold, prng_seed=self.prng_seed, reference_implementation=self.reference_impl, diff --git a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py index 37d20c46..c18bde00 100644 --- a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py @@ -10,9 +10,9 @@ from openequivariance._torch.extlib import DeviceProp from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging_utils import getLogger, bcolors +from openequivariance.benchmark.logging import getLogger, bcolors from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.correctness_utils import ( +from openequivariance.benchmark.correctness import ( correctness_forward, correctness_backward, correctness_double_backward, diff --git a/openequivariance/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py index 68dc6f9f..31eee40c 100644 --- a/openequivariance/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py @@ -15,7 +15,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem from openequivariance._torch.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.logging_utils import getLogger, bcolors +from openequivariance.benchmark.logging import getLogger, bcolors logger = getLogger() diff --git a/openequivariance/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness.py similarity index 
50% rename from openequivariance/openequivariance/benchmark/correctness_utils.py rename to openequivariance/openequivariance/benchmark/correctness.py index 07b523db..2fe270a3 100644 --- a/openequivariance/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/openequivariance/benchmark/correctness.py @@ -1,13 +1,17 @@ +import copy from typing import Optional, Union import numpy as np import numpy.linalg as la from openequivariance._torch.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.logging_utils import bcolors, getLogger +from openequivariance.benchmark.logging import bcolors, getLogger from openequivariance.benchmark.random_buffer_utils import ( + get_random_buffers_backward_conv, get_random_buffers_backward, + get_random_buffers_double_backward_conv, get_random_buffers_double_backward, + get_random_buffers_forward_conv, get_random_buffers_forward, ) from openequivariance.core.e3nn_lite import TPProblem @@ -308,3 +312,329 @@ def correctness_double_backward( ) return result + + +def correctness_forward_conv( + conv, + graph, + thresh, + prng_seed, + reference_implementation=None, + check_reproducible=True, + high_precision_ref=False, +): + global torch + import torch + + if reference_implementation is None: + from openequivariance._torch.E3NNConv import E3NNConv + + reference_implementation = E3NNConv + + result = {"thresh": thresh} + + in1, in2, weights, out = get_random_buffers_forward_conv( + conv.config, graph.node_count, graph.nnz, prng_seed + ) + reference_config = conv.config + if high_precision_ref: + reference_config = copy.deepcopy(conv.config) + reference_config.irrep_dtype = np.float64 + reference_config.weight_dtype = np.float64 + outputs = [] + + for i, impl in enumerate([conv, reference_implementation]): + is_test_impl = i == 0 + tp = impl if is_test_impl else impl(reference_config) + + run_in1, run_in2, run_weights, run_out = [ + buf.copy() for buf in (in1, in2, weights, out) + ] + + if not is_test_impl 
and high_precision_ref: + run_in1, run_in2, run_weights, run_out = [ + np.array(el, dtype=np.float64) + for el in (run_in1, run_in2, run_weights, run_out) + ] + + if is_test_impl: + run_in1, run_in2 = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", conv.config.layout + ) + for arr, irreps in zip( + (run_in1, run_in2), + (conv.config.irreps_in1, conv.config.irreps_in2), + ) + ] + conv.forward_cpu( + L1_in=run_in1, + L2_in=run_in2, + weights=run_weights, + L3_out=run_out, + graph=graph, + ) + + run_out = IrrepLayoutUtils.transpose_irrep_layout( + run_out, conv.config.irreps_out, conv.config.layout, "mul_ir" + ) + else: + args = { + "L1_in": run_in1, + "L2_in": run_in2, + "weights": run_weights, + "rows": graph.rows, + "cols": graph.cols, + } + + if tp.deterministic: + args["transpose_perm"] = graph.transpose_perm + + for key in args: + args[key] = torch.tensor(args[key], device="cuda") + + run_out[:] = tp.forward(**args).cpu().numpy() + + outputs.append(run_out) + + test_out, ref_out = outputs[0], outputs[1] + + for name, to_check, ground_truth in [("output", ref_out, test_out)]: + result[name] = check_similiarity(name, to_check, ground_truth, thresh) + + if check_reproducible: + num_trials = 5 + for name in ["output"]: + result[name]["num_reproducibility_trials"] = num_trials + result[name]["bitwise_reproducible"] = True + + for _ in range(num_trials): + repeated_run = out.copy() + rep_in1, rep_in2, rep_weights = [ + buf.copy() for buf in (in1, in2, weights) + ] + rep_in1, rep_in2 = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", conv.config.layout + ) + for arr, irreps in zip( + (rep_in1, rep_in2), + (conv.config.irreps_in1, conv.config.irreps_in2), + ) + ] + conv.forward_cpu( + L1_in=rep_in1, + L2_in=rep_in2, + weights=rep_weights, + L3_out=repeated_run, + graph=graph, + ) + + repeated_run = IrrepLayoutUtils.transpose_irrep_layout( + repeated_run, conv.config.irreps_out, conv.config.layout, "mul_ir" + ) + + 
result["output"]["bitwise_reproducible"] = bool( + result["output"]["bitwise_reproducible"] + and np.all(repeated_run == test_out) + ) + + return result + + +def correctness_backward_conv( + conv, + graph, + thresh, + prng_seed, + reference_implementation=None, + high_precision_ref=False, +): + if reference_implementation is None: + from openequivariance._torch.E3NNConv import E3NNConv + + reference_implementation = E3NNConv + + result = {"thresh": thresh} + + buffers = get_random_buffers_backward_conv( + conv.config, graph.node_count, graph.nnz, prng_seed + ) + reference_problem = conv.config + + if high_precision_ref: + reference_problem = copy.deepcopy(conv.config) + reference_problem.irrep_dtype = np.float64 + reference_problem.weight_dtype = np.float64 + grads = [] + for i, impl in enumerate([conv, reference_implementation]): + is_test_impl = i == 0 + tp = impl if is_test_impl else impl(reference_problem) + + run_in1, run_in2, run_out_grad, run_weights, run_weights_grad, run_in1_grad, run_in2_grad = [ + buf.copy() for buf in buffers + ] + + if not is_test_impl and high_precision_ref: + ( + run_in1, + run_in2, + run_out_grad, + run_weights, + run_weights_grad, + run_in1_grad, + run_in2_grad, + ) = [np.array(el, dtype=np.float64) for el in buffers] + + if is_test_impl: + run_in1, run_in2, run_out_grad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", conv.config.layout + ) + for arr, irreps in zip( + (run_in1, run_in2, run_out_grad), + (conv.config.irreps_in1, conv.config.irreps_in2, conv.config.irreps_out), + ) + ] + + tp.backward_cpu( + L1_in=run_in1, + L1_grad=run_in1_grad, + L2_in=run_in2, + L2_grad=run_in2_grad, + L3_grad=run_out_grad, + weights=run_weights, + weights_grad=run_weights_grad, + graph=graph, + ) + + if is_test_impl: + run_in1_grad, run_in2_grad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, conv.config.layout, "mul_ir" + ) + for arr, irreps in zip( + (run_in1_grad, run_in2_grad), + (conv.config.irreps_in1, 
conv.config.irreps_in2), + ) + ] + + grads.append((run_weights_grad, run_in1_grad, run_in2_grad)) + + for name, to_check, ground_truth, threshold in [ + ("weight_grad", grads[0][0], grads[1][0], thresh), + ("in1_grad", grads[0][1], grads[1][1], thresh), + ("in2_grad", grads[0][2], grads[1][2], thresh), + ]: + result[name] = check_similiarity(name, to_check, ground_truth, threshold) + + return result + + +def correctness_double_backward_conv( + conv, + graph, + thresh, + prng_seed, + reference_implementation=None, + high_precision_ref=False, +): + buffers = get_random_buffers_double_backward_conv( + conv.config, graph.node_count, graph.nnz, prng_seed + ) + + if reference_implementation is None: + from openequivariance._torch.E3NNConv import E3NNConv + + reference_implementation = E3NNConv + + reference_problem = conv.config + if high_precision_ref: + reference_problem = copy.deepcopy(conv.config) + reference_problem.irrep_dtype = np.float64 + reference_problem.weight_dtype = np.float64 + + reference_tp = reference_implementation(reference_problem, torch_op=True) + + result = {"thresh": thresh} + tensors = [] + for i, tp in enumerate([conv, reference_tp]): + is_test_impl = i == 0 + buffers_copy = [buf.copy() for buf in buffers] + + if i == 1 and high_precision_ref: + buffers_copy = [np.array(el, dtype=np.float64) for el in buffers] + + in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = ( + buffers_copy + ) + + weights_reordered = tp.reorder_weights_from_e3nn( + weights, not tp.config.shared_weights + ) + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, not tp.config.shared_weights + ) + + db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ + buf.copy() for buf in (in1, in2, out_grad, in1_dgrad, in2_dgrad) + ] + if is_test_impl: + db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, "mul_ir", tp.config.layout + ) + for arr, irreps in zip( + (db_in1, 
db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad), + ( + tp.config.irreps_in1, + tp.config.irreps_in2, + tp.config.irreps_out, + tp.config.irreps_in1, + tp.config.irreps_in2, + ), + ) + ] + + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( + db_in1, + db_in2, + db_out_grad, + weights_reordered, + weights_dgrad_reordered, + db_in1_dgrad, + db_in2_dgrad, + graph, + ) + + if is_test_impl: + out_dgrad, in1_grad, in2_grad = [ + IrrepLayoutUtils.transpose_irrep_layout( + arr, irreps, tp.config.layout, "mul_ir" + ) + for arr, irreps in zip( + (out_dgrad, in1_grad, in2_grad), + (tp.config.irreps_out, tp.config.irreps_in1, tp.config.irreps_in2), + ) + ] + + tensors.append( + ( + out_dgrad, + in1_grad, + in2_grad, + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not conv.config.shared_weights + ), + ) + ) + + for name, to_check, ground_truth in [ + ("output_grad", tensors[0][0], tensors[1][0]), + ("in1_grad", tensors[0][1], tensors[1][1]), + ("in2_grad", tensors[0][2], tensors[1][2]), + ("weights_grad", tensors[0][3], tensors[1][3]), + ]: + result[name] = check_similiarity(name, to_check, ground_truth, thresh) + + return result diff --git a/openequivariance/openequivariance/benchmark/logging_utils.py b/openequivariance/openequivariance/benchmark/logging.py similarity index 100% rename from openequivariance/openequivariance/benchmark/logging_utils.py rename to openequivariance/openequivariance/benchmark/logging.py diff --git a/openequivariance/openequivariance/benchmark/perf_metrics_utils.py b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py index 212f05f4..01bd5836 100644 --- a/openequivariance/openequivariance/benchmark/perf_metrics_utils.py +++ b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py @@ -6,7 +6,7 @@ ) from openequivariance.core.e3nn_lite import TPProblem, wigner_3j -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger import numpy 
as np logger = getLogger() diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index 461ac565..57094706 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -2,7 +2,7 @@ import numpy as np -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j logger = getLogger() diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index 84f4445c..92914efc 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -2,8 +2,7 @@ import numpy as np -from openequivariance.benchmark.correctness_utils import check_similiarity -from openequivariance.benchmark.logging_utils import bcolors, getLogger +from openequivariance.benchmark.logging import bcolors, getLogger from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_backward_conv, get_random_buffers_double_backward_conv, @@ -135,125 +134,6 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): def name(): raise NotImplementedError() - def test_correctness_forward( - self, - graph, - thresh, - prng_seed, - reference_implementation=None, - check_reproducible=True, - high_precision_ref=False, - ): - if reference_implementation is None: - from openequivariance._torch.E3NNConv import E3NNConv - - reference_implementation = E3NNConv - - result = {"thresh": thresh} - - in1, in2, weights, out = get_random_buffers_forward_conv( - self.config, graph.node_count, graph.nnz, prng_seed - ) - ref_in1, ref_in2, ref_weights, ref_out = [ - buf.copy() for buf in [in1, in2, weights, out] - ] - - reference_config = self.config - if high_precision_ref: 
- reference_config = copy.deepcopy(self.config) - reference_config.irrep_dtype = np.float64 - reference_config.weight_dtype = np.float64 - ref_in1, ref_in2, ref_weights, ref_out = [ - np.array(el, dtype=np.float64) - for el in [ref_in1, ref_in2, ref_weights, ref_out] - ] - - args = { - "L1_in": ref_in1, - "L2_in": ref_in2, - "weights": ref_weights, - "rows": graph.rows, - "cols": graph.cols, - } - - ref_tp = reference_implementation(reference_config) - if ref_tp.deterministic: - args["transpose_perm"] = graph.transpose_perm - - for key in args: - args[key] = torch.tensor(args[key], device="cuda") - - ref_out[:] = ref_tp.forward(**args).cpu().numpy() - - run_in1, run_in2, run_weights, test_out = [ - buf.copy() for buf in (in1, in2, weights, out) - ] - run_in1, run_in2 = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", self.config.layout - ) - for arr, irreps in zip( - (run_in1, run_in2), - (self.config.irreps_in1, self.config.irreps_in2), - ) - ] - self.forward_cpu( - L1_in=run_in1, - L2_in=run_in2, - weights=run_weights, - L3_out=test_out, - graph=graph, - ) - - test_out = IrrepLayoutUtils.transpose_irrep_layout( - test_out, self.config.irreps_out, self.config.layout, "mul_ir" - ) - - for name, to_check, ground_truth in [("output", ref_out, test_out)]: - result[name] = check_similiarity(name, to_check, ground_truth, thresh) - - if check_reproducible: - num_trials = 5 - for name in ["output"]: - result[name]["num_reproducibility_trials"] = num_trials - result[name]["bitwise_reproducible"] = True - - for i in range(num_trials): - repeated_run = out.copy() - rep_in1, rep_in2, rep_weights = [ - buf.copy() for buf in (in1, in2, weights) - ] - rep_in1, rep_in2 = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", self.config.layout - ) - for arr, irreps in zip( - (rep_in1, rep_in2), - (self.config.irreps_in1, self.config.irreps_in2), - ) - ] - self.forward_cpu( - L1_in=rep_in1, - L2_in=rep_in2, - weights=rep_weights, - 
L3_out=repeated_run, - graph=graph, - ) - - repeated_run = IrrepLayoutUtils.transpose_irrep_layout( - repeated_run, self.config.irreps_out, self.config.layout, "mul_ir" - ) - - for name, to_check, ground_truth in [ - ("output", repeated_run, test_out) - ]: - result[name]["bitwise_reproducible"] = bool( - result[name]["bitwise_reproducible"] - and np.all(repeated_run == test_out) - ) - - return result - def benchmark_forward( self, num_warmup, num_iter, graph, prng_seed=12345, kernel_names=["forward"] ): @@ -411,214 +291,6 @@ def calculate_bench_stats( ) return result - def test_correctness_backward( - self, - graph, - thresh, - prng_seed, - reference_implementation=None, - high_precision_ref=False, - ): - if reference_implementation is None: - from openequivariance._torch.E3NNConv import E3NNConv - - reference_implementation = E3NNConv - - result = {"thresh": thresh} - - buffers = get_random_buffers_backward_conv( - self.config, graph.node_count, graph.nnz, prng_seed - ) - reference_buffers = [buf.copy() for buf in buffers] - reference_problem = self.config - - if high_precision_ref: - reference_problem = copy.deepcopy(self.config) - reference_problem.irrep_dtype = np.float64 - reference_problem.weight_dtype = np.float64 - reference_buffers = [ - np.array(el, dtype=np.float64) for el in reference_buffers - ] - - ref_tp = reference_implementation(reference_problem) - in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = buffers - ( - ref_in1, - ref_in2, - ref_out_grad, - ref_weights, - ref_weights_grad, - ref_in1_grad, - ref_in2_grad, - ) = reference_buffers - - ref_tp.backward_cpu( - L1_in=ref_in1, - L1_grad=ref_in1_grad, - L2_in=ref_in2, - L2_grad=ref_in2_grad, - L3_grad=ref_out_grad, - weights=ref_weights, - weights_grad=ref_weights_grad, - graph=graph, - ) - - # run test version - test_weights_grad = weights_grad.copy() - test_in1_grad = in1_grad.copy() - test_in2_grad = in2_grad.copy() - - test_in1, test_in2, test_L3_grad = [ - buf.copy() for buf in 
(in1, in2, out_grad) - ] - test_in1, test_in2, test_L3_grad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", self.config.layout - ) - for arr, irreps in zip( - (test_in1, test_in2, test_L3_grad), - (self.config.irreps_in1, self.config.irreps_in2, self.config.irreps_out), - ) - ] - - self.backward_cpu( - L1_in=test_in1, - L1_grad=test_in1_grad, - L2_in=test_in2, - L2_grad=test_in2_grad, - L3_grad=test_L3_grad, - weights=weights.copy(), - weights_grad=test_weights_grad, - graph=graph, - ) - - test_in1_grad, test_in2_grad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, self.config.layout, "mul_ir" - ) - for arr, irreps in zip( - (test_in1_grad, test_in2_grad), - (self.config.irreps_in1, self.config.irreps_in2), - ) - ] - - for name, to_check, ground_truth, threshold in [ - ("weight_grad", test_weights_grad, ref_weights_grad, thresh), - ("in1_grad", test_in1_grad, ref_in1_grad, thresh), - ("in2_grad", test_in2_grad, ref_in2_grad, thresh), - ]: - result[name] = check_similiarity(name, to_check, ground_truth, threshold) - - return result - - def test_correctness_double_backward( - self, - graph, - thresh, - prng_seed, - reference_implementation=None, - high_precision_ref=False, - ): - buffers = get_random_buffers_double_backward_conv( - self.config, graph.node_count, graph.nnz, prng_seed - ) - - if reference_implementation is None: - from openequivariance._torch.E3NNConv import E3NNConv - - reference_implementation = E3NNConv - - reference_problem = self.config - if high_precision_ref: - reference_problem = copy.deepcopy(self.config) - reference_problem.irrep_dtype = np.float64 - reference_problem.weight_dtype = np.float64 - - reference_tp = reference_implementation(reference_problem, torch_op=True) - - result = {"thresh": thresh} - tensors = [] - for i, tp in enumerate([self, reference_tp]): - is_test_impl = i == 0 - buffers_copy = [buf.copy() for buf in buffers] - - if i == 1 and high_precision_ref: - buffers_copy = [np.array(el, 
dtype=np.float64) for el in buffers] - - in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = ( - buffers_copy - ) - - weights_reordered = tp.reorder_weights_from_e3nn( - weights, not tp.config.shared_weights - ) - weights_dgrad_reordered = tp.reorder_weights_from_e3nn( - weights_dgrad, not tp.config.shared_weights - ) - - db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ - buf.copy() for buf in (in1, in2, out_grad, in1_dgrad, in2_dgrad) - ] - if is_test_impl: - db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", tp.config.layout - ) - for arr, irreps in zip( - (db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad), - ( - tp.config.irreps_in1, - tp.config.irreps_in2, - tp.config.irreps_out, - tp.config.irreps_in1, - tp.config.irreps_in2, - ), - ) - ] - - in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( - db_in1, - db_in2, - db_out_grad, - weights_reordered, - weights_dgrad_reordered, - db_in1_dgrad, - db_in2_dgrad, - graph, - ) - - if is_test_impl: - out_dgrad, in1_grad, in2_grad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, tp.config.layout, "mul_ir" - ) - for arr, irreps in zip( - (out_dgrad, in1_grad, in2_grad), - (tp.config.irreps_out, tp.config.irreps_in1, tp.config.irreps_in2), - ) - ] - - tensors.append( - ( - out_dgrad, - in1_grad, - in2_grad, - tp.reorder_weights_to_e3nn( - weights_grad, has_batch_dim=not self.config.shared_weights - ), - ) - ) - - for name, to_check, ground_truth in [ - ("output_grad", tensors[0][0], tensors[1][0]), - ("in1_grad", tensors[0][1], tensors[1][1]), - ("in2_grad", tensors[0][2], tensors[1][2]), - ("weights_grad", tensors[0][3], tensors[1][3]), - ]: - result[name] = check_similiarity(name, to_check, ground_truth, thresh) - - return result - def scatter_add_wrapper(messages, rows, target_dim): L3_dim = messages.size(1) diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py 
b/openequivariance/openequivariance/core/LoopUnrollTP.py index 43c9e244..e07387b0 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -2,7 +2,7 @@ import numpy as np -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.core.ComputationSchedule import ( ComputationSchedule, SMEMCapacityException, diff --git a/openequivariance/openequivariance/core/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py index b5d3831f..fa065b7b 100644 --- a/openequivariance/openequivariance/core/TensorProductBase.py +++ b/openequivariance/openequivariance/core/TensorProductBase.py @@ -1,7 +1,7 @@ import numpy as np from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.core.utils import benchmark logger = getLogger() diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index c14637a1..4645ae86 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -9,7 +9,7 @@ from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance.jax.utils import reorder_jax -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance.jax.jvp import conv_prim from openequivariance.jax.vjp import conv_func diff --git a/openequivariance/openequivariance/templates/macros.jinja b/openequivariance/openequivariance/templates/macros.jinja index 8622b656..f9108822 100644 --- a/openequivariance/openequivariance/templates/macros.jinja +++ b/openequivariance/openequivariance/templates/macros.jinja @@ -84,14 +84,13 @@ 
Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. {%- endfor %} {%- elif map.src_views[0].layout == "ir_mul" %} {%- for idx in map.idxs %} - {%- set src_view = map.src_views[idx] %} - {%- set src_mul_ir = map.src_irreps[idx] %} - {%- set dst_idx = map.src_dst_map[idx] %} - {%- set dst_rng = map.dst_irreps.slices()[dst_idx] %} - {%- set dim = src_mul_ir.ir.dim %} - {%- set mul = src_mul_ir.mul %} - {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) + {%- set m = namespace( + src_view=map.src_views[idx], + src_mul_ir=map.src_irreps[idx], + dst_rng=map.dst_irreps.slices()[map.src_dst_map[idx]] + ) %} + {%- for i in range(m.src_mul_ir.ir.dim) %} + ROW_OPERATION({{m.src_mul_ir.mul}}, {{loop_var}}, {{smem_ptr}}[{{m.dst_rng.start + i * m.src_mul_ir.mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{m.src_view.ir_mul_offset + i * m.src_view.ir_mul_stride}} + {{loop_var}}];) {%- endfor %} {%- endfor %} {%- endif %} @@ -106,14 +105,13 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. 
{%- endfor %} {%- elif map.src_views[0].layout == "ir_mul" %} {%- for idx in map.idxs %} - {%- set src_view = map.src_views[idx] %} - {%- set src_mul_ir = map.src_irreps[idx] %} - {%- set dst_idx = map.src_dst_map[idx] %} - {%- set dst_rng = map.dst_irreps.slices()[dst_idx] %} - {%- set dim = src_mul_ir.ir.dim %} - {%- set mul = src_mul_ir.mul %} - {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];) + {%- set m = namespace( + src_view=map.src_views[idx], + src_mul_ir=map.src_irreps[idx], + dst_rng=map.dst_irreps.slices()[map.src_dst_map[idx]] + ) %} + {%- for i in range(m.src_mul_ir.ir.dim) %} + ROW_OPERATION({{m.src_mul_ir.mul}}, {{loop_var}}, {{smem_ptr}}[{{m.dst_rng.start + i * m.src_mul_ir.mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{m.src_view.ir_mul_offset + i * m.src_view.ir_mul_stride}} + {{loop_var}}];) {%- endfor %} {%- endfor %} {%- endif %} @@ -136,23 +134,22 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. 
{%- endfor %} {%- elif map.src_views[0].layout == "ir_mul" %} {%- for idx in map.idxs %} - {%- set src_view = map.src_views[idx] %} - {%- set src_mul_ir = map.src_irreps[idx] %} - {%- set dst_idx = map.src_dst_map[idx] %} - {%- set dst_rng = map.dst_irreps.slices()[dst_idx] %} - {%- set dim = src_mul_ir.ir.dim %} - {%- set mul = src_mul_ir.mul %} + {%- set m = namespace( + src_view=map.src_views[idx], + src_mul_ir=map.src_irreps[idx], + dst_rng=map.dst_irreps.slices()[map.src_dst_map[idx]] + ) %} {%- if map.storeback_procedure[idx] == "write" %} - {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];) + {%- for i in range(m.src_mul_ir.ir.dim) %} + ROW_OPERATION({{m.src_mul_ir.mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{m.src_view.ir_mul_offset + i * m.src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{m.dst_rng.start + i * m.src_mul_ir.mul}} + {{loop_var}} + lane_id];) {%- endfor %} {%- elif map.storeback_procedure[idx] == "accumulate" %} - {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];) + {%- for i in range(m.src_mul_ir.ir.dim) %} + ROW_OPERATION({{m.src_mul_ir.mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{m.src_view.ir_mul_offset + i * m.src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{m.dst_rng.start + i * m.src_mul_ir.mul}} + {{loop_var}} + lane_id];) {%- endfor %} {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %} - {%- for i in range(dim) %} - ROW_OPERATION({{mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id]);) + {%- for i in range(m.src_mul_ir.ir.dim) %} + 
ROW_OPERATION({{m.src_mul_ir.mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{m.src_view.ir_mul_offset + i * m.src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{m.dst_rng.start + i * m.src_mul_ir.mul}} + {{loop_var}} + lane_id]);) {%- endfor %} {%- endif %} {%- endfor %} diff --git a/tests/batch_test.py b/tests/batch_test.py index 1df1fb6b..8263cec0 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -3,7 +3,7 @@ import numpy as np import pytest import torch -from openequivariance.benchmark.correctness_utils import ( +from openequivariance.benchmark.correctness import ( correctness_backward, correctness_double_backward, correctness_forward, diff --git a/tests/benchmark.py b/tests/benchmark.py index 829cc46c..441a3ca0 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -10,7 +10,7 @@ import numpy as np -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.benchmark.logging import getLogger from openequivariance._torch.extlib import DeviceProp from openequivariance._torch.E3NNTensorProduct import ( E3NNTensorProduct, diff --git a/tests/conv_test.py b/tests/conv_test.py index d61605d0..100f6ac0 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -6,6 +6,11 @@ import numpy as np import openequivariance as oeq from openequivariance.benchmark.ConvBenchmarkSuite import load_graph +from openequivariance.benchmark.correctness import ( + correctness_backward_conv, + correctness_double_backward_conv, + correctness_forward_conv, +) from itertools import product import torch @@ -89,7 +94,8 @@ def test_tp_fwd(self, conv_object, graph): if conv_object is None: pytest.skip("'conv_object' fixture returned None, skipping") - result = conv_object.test_correctness_forward( + result = correctness_forward_conv( + conv_object, graph, thresh=self.thresh("fwd"), prng_seed=12345, @@ -102,7 +108,8 @@ def test_tp_bwd(self, conv_object, graph): if conv_object is None: pytest.skip("'conv_object' fixture returned None, 
skipping") - result = conv_object.test_correctness_backward( + result = correctness_backward_conv( + conv_object, graph, thresh=self.thresh("bwd"), prng_seed=12345, @@ -117,7 +124,8 @@ def test_tp_double_bwd(self, conv_object, graph): if conv_object is None: pytest.skip("'conv_object' fixture returned None, skipping") - result = conv_object.test_correctness_double_backward( + result = correctness_double_backward_conv( + conv_object, graph, thresh=self.thresh("double_bwd"), prng_seed=12345, From aa9bbf64ef1b8b016d42a4fe90286b500b6a9470 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 17:12:27 -0700 Subject: [PATCH 16/27] Ruff. --- .../openequivariance/benchmark/correctness.py | 100 +++++++++--------- .../openequivariance/core/ConvolutionBase.py | 3 +- .../openequivariance/core/utils.py | 92 ++++++++-------- tests/batch_test.py | 6 +- tests/conv_test.py | 3 +- tests/input_validation_test.py | 2 +- 6 files changed, 102 insertions(+), 104 deletions(-) diff --git a/openequivariance/openequivariance/benchmark/correctness.py b/openequivariance/openequivariance/benchmark/correctness.py index 2fe270a3..17bbd074 100644 --- a/openequivariance/openequivariance/benchmark/correctness.py +++ b/openequivariance/openequivariance/benchmark/correctness.py @@ -16,7 +16,7 @@ ) from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.core.utils import IrrepLayoutUtils +from openequivariance.core.utils import transpose_irrep_layout logger = getLogger() @@ -87,29 +87,31 @@ def correctness_forward( outputs = [] for i, impl in enumerate([test_implementation, reference_implementation]): - is_test_impl = (i == 0) + is_test_impl = i == 0 tp = instantiate_implementation(impl, problem) uses_cue = impl == CUETensorProduct or isinstance(tp, CUETensorProduct) - run_in1, run_in2, run_weights, run_out = [ buf.copy() for buf in (in1, in2, weights, out) ] + run_in1, run_in2, run_weights, 
run_out = [ + buf.copy() for buf in (in1, in2, weights, out) + ] if problem.shared_weights and uses_cue: run_weights = run_weights[np.newaxis, :] - # Transpose inputs, if necessary, for the test implementation + # Transpose inputs, if necessary, for the test implementation if is_test_impl: run_in1, run_in2 = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", tp.config.layout - ) for arr, irreps in zip( - (run_in1, run_in2), - (problem.irreps_in1, problem.irreps_in2) + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) + for arr, irreps in zip( + (run_in1, run_in2), (problem.irreps_in1, problem.irreps_in2) ) ] - tp.forward_cpu(L1_in=run_in1, L2_in=run_in2, L3_out=run_out, weights=run_weights) + tp.forward_cpu( + L1_in=run_in1, L2_in=run_in2, L3_out=run_out, weights=run_weights + ) if is_test_impl: - run_out = IrrepLayoutUtils.transpose_irrep_layout( + run_out = transpose_irrep_layout( run_out, problem.irreps_out, tp.config.layout, "mul_ir" ) @@ -147,7 +149,15 @@ def correctness_backward( is_test_impl = i == 0 tp = instantiate_implementation(impl, problem) - run_in1, run_in2, run_L3_grad, run_weights, run_weights_grad, run_in1_grad, run_in2_grad = [ + ( + run_in1, + run_in2, + run_L3_grad, + run_weights, + run_weights_grad, + run_in1_grad, + run_in2_grad, + ) = [ buf.copy() for buf in (in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad) ] @@ -159,9 +169,7 @@ def correctness_backward( if is_test_impl: run_in1, run_in2, run_L3_grad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", tp.config.layout - ) + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) for arr, irreps in zip( (run_in1, run_in2, run_L3_grad), (problem.irreps_in1, problem.irreps_in2, problem.irreps_out), @@ -180,9 +188,7 @@ def correctness_backward( if is_test_impl: run_in1_grad, run_in2_grad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, tp.config.layout, "mul_ir" - ) + transpose_irrep_layout(arr, irreps, 
tp.config.layout, "mul_ir") for arr, irreps in zip( (run_in1_grad, run_in2_grad), (problem.irreps_in1, problem.irreps_in2), @@ -254,9 +260,7 @@ def correctness_double_backward( if is_test_impl: db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", tp.config.layout - ) + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) for arr, irreps in zip( (db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad), ( @@ -281,9 +285,7 @@ def correctness_double_backward( if is_test_impl: out_dgrad, in1_grad, in2_grad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, tp.config.layout, "mul_ir" - ) + transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir") for arr, irreps in zip( (out_dgrad, in1_grad, in2_grad), (problem.irreps_out, problem.irreps_in1, problem.irreps_in2), @@ -359,9 +361,7 @@ def correctness_forward_conv( if is_test_impl: run_in1, run_in2 = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", conv.config.layout - ) + transpose_irrep_layout(arr, irreps, "mul_ir", conv.config.layout) for arr, irreps in zip( (run_in1, run_in2), (conv.config.irreps_in1, conv.config.irreps_in2), @@ -375,7 +375,7 @@ def correctness_forward_conv( graph=graph, ) - run_out = IrrepLayoutUtils.transpose_irrep_layout( + run_out = transpose_irrep_layout( run_out, conv.config.irreps_out, conv.config.layout, "mul_ir" ) else: @@ -410,13 +410,9 @@ def correctness_forward_conv( for _ in range(num_trials): repeated_run = out.copy() - rep_in1, rep_in2, rep_weights = [ - buf.copy() for buf in (in1, in2, weights) - ] + rep_in1, rep_in2, rep_weights = [buf.copy() for buf in (in1, in2, weights)] rep_in1, rep_in2 = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", conv.config.layout - ) + transpose_irrep_layout(arr, irreps, "mul_ir", conv.config.layout) for arr, irreps in zip( (rep_in1, rep_in2), (conv.config.irreps_in1, conv.config.irreps_in2), @@ -430,7 +426,7 @@ 
def correctness_forward_conv( graph=graph, ) - repeated_run = IrrepLayoutUtils.transpose_irrep_layout( + repeated_run = transpose_irrep_layout( repeated_run, conv.config.irreps_out, conv.config.layout, "mul_ir" ) @@ -471,9 +467,15 @@ def correctness_backward_conv( is_test_impl = i == 0 tp = impl if is_test_impl else impl(reference_problem) - run_in1, run_in2, run_out_grad, run_weights, run_weights_grad, run_in1_grad, run_in2_grad = [ - buf.copy() for buf in buffers - ] + ( + run_in1, + run_in2, + run_out_grad, + run_weights, + run_weights_grad, + run_in1_grad, + run_in2_grad, + ) = [buf.copy() for buf in buffers] if not is_test_impl and high_precision_ref: ( @@ -488,12 +490,14 @@ def correctness_backward_conv( if is_test_impl: run_in1, run_in2, run_out_grad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", conv.config.layout - ) + transpose_irrep_layout(arr, irreps, "mul_ir", conv.config.layout) for arr, irreps in zip( (run_in1, run_in2, run_out_grad), - (conv.config.irreps_in1, conv.config.irreps_in2, conv.config.irreps_out), + ( + conv.config.irreps_in1, + conv.config.irreps_in2, + conv.config.irreps_out, + ), ) ] @@ -510,9 +514,7 @@ def correctness_backward_conv( if is_test_impl: run_in1_grad, run_in2_grad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, conv.config.layout, "mul_ir" - ) + transpose_irrep_layout(arr, irreps, conv.config.layout, "mul_ir") for arr, irreps in zip( (run_in1_grad, run_in2_grad), (conv.config.irreps_in1, conv.config.irreps_in2), @@ -581,9 +583,7 @@ def correctness_double_backward_conv( ] if is_test_impl: db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, "mul_ir", tp.config.layout - ) + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) for arr, irreps in zip( (db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad), ( @@ -609,9 +609,7 @@ def correctness_double_backward_conv( if is_test_impl: out_dgrad, in1_grad, in2_grad 
= [ - IrrepLayoutUtils.transpose_irrep_layout( - arr, irreps, tp.config.layout, "mul_ir" - ) + transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir") for arr, irreps in zip( (out_dgrad, in1_grad, in2_grad), (tp.config.irreps_out, tp.config.irreps_in1, tp.config.irreps_in2), diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index 92914efc..7f36e4ce 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -5,11 +5,10 @@ from openequivariance.benchmark.logging import bcolors, getLogger from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_backward_conv, - get_random_buffers_double_backward_conv, get_random_buffers_forward_conv, ) from openequivariance.core.e3nn_lite import wigner_3j -from openequivariance.core.utils import IrrepLayoutUtils, benchmark +from openequivariance.core.utils import benchmark logger = getLogger() diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index f574ae4c..53638422 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -209,50 +209,52 @@ def hash_str_64(s: str) -> int: return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big") -class IrrepLayoutUtils: - @staticmethod - def transpose_irrep_layout( - array: np.ndarray, - irreps: Irreps, - src_layout: str, - dst_layout: str, - ) -> np.ndarray: - """ - Transpose irrep-packed feature arrays between `mul_ir` and `ir_mul` layouts. - - Expected input shape is `[..., irreps.dim]`. A new array is returned. - If `src_layout == dst_layout`, this returns a copy. 
- """ - if src_layout not in ("mul_ir", "ir_mul"): - raise ValueError(f"Unsupported src_layout: {src_layout}") - if dst_layout not in ("mul_ir", "ir_mul"): - raise ValueError(f"Unsupported dst_layout: {dst_layout}") - - x = np.asarray(array) - out = np.empty_like(x) - - if src_layout == dst_layout: - out[...] = x - return out - - slices = irreps.slices() - for ir_idx, mul_ir in enumerate(irreps): - mul = mul_ir.mul - dim = mul_ir.ir.dim - seg = slices[ir_idx] - block = x[..., seg.start : seg.stop] - - if src_layout == "ir_mul" and dst_layout == "mul_ir": - out[..., seg.start : seg.stop] = block.reshape( - *block.shape[:-1], dim, mul - ).swapaxes(-1, -2).reshape(*block.shape[:-1], mul * dim) - elif src_layout == "mul_ir" and dst_layout == "ir_mul": - out[..., seg.start : seg.stop] = block.reshape( - *block.shape[:-1], mul, dim - ).swapaxes(-1, -2).reshape(*block.shape[:-1], dim * mul) - else: - raise ValueError( - f"Unsupported layout transpose: {src_layout} -> {dst_layout}" - ) +def transpose_irrep_layout( + array: np.ndarray, + irreps: Irreps, + src_layout: str, + dst_layout: str, +) -> np.ndarray: + """ + Transpose irrep-packed feature arrays between `mul_ir` and `ir_mul` layouts. + + Expected input shape is `[..., irreps.dim]`. A new array is returned. + If `src_layout == dst_layout`, this returns a copy. + """ + if src_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported src_layout: {src_layout}") + if dst_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported dst_layout: {dst_layout}") + x = np.asarray(array) + out = np.empty_like(x) + + if src_layout == dst_layout: + out[...] 
= x return out + + slices = irreps.slices() + for ir_idx, mul_ir in enumerate(irreps): + mul = mul_ir.mul + dim = mul_ir.ir.dim + seg = slices[ir_idx] + block = x[..., seg.start : seg.stop] + + if src_layout == "ir_mul" and dst_layout == "mul_ir": + out[..., seg.start : seg.stop] = ( + block.reshape(*block.shape[:-1], dim, mul) + .swapaxes(-1, -2) + .reshape(*block.shape[:-1], mul * dim) + ) + elif src_layout == "mul_ir" and dst_layout == "ir_mul": + out[..., seg.start : seg.stop] = ( + block.reshape(*block.shape[:-1], mul, dim) + .swapaxes(-1, -2) + .reshape(*block.shape[:-1], dim * mul) + ) + else: + raise ValueError( + f"Unsupported layout transpose: {src_layout} -> {dst_layout}" + ) + + return out diff --git a/tests/batch_test.py b/tests/batch_test.py index 8263cec0..a08466c7 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -272,9 +272,9 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): tp.to(switch_map[problem.irrep_dtype]) return tp, tp.config + class TestIrMulLayoutMACE(TPCorrectness): - production_model_tpps = mace_problems() + \ - [ + production_model_tpps = mace_problems() + [ oeq.TPProblem( "5x5e", "1x3e", @@ -293,7 +293,7 @@ class TestIrMulLayoutMACE(TPCorrectness): internal_weights=False, label="ir_mul_repr_13x1x13_l535", ), - ] + ] @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): diff --git a/tests/conv_test.py b/tests/conv_test.py index 100f6ac0..8471e593 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -293,8 +293,7 @@ def conv_object(self, request, problem, extra_conv_constructor_args): class TestIrMulLayout(ConvCorrectness): - production_model_tpps = mace_problems() + \ - [ + production_model_tpps = mace_problems() + [ oeq.TPProblem( "5x5e", "1x3e", diff --git a/tests/input_validation_test.py b/tests/input_validation_test.py index bbf39e73..2a581314 100644 --- a/tests/input_validation_test.py +++ b/tests/input_validation_test.py 
@@ -154,4 +154,4 @@ def test_ir_mul_rejects_uvw_problem(dtype): ) with pytest.raises(AssertionError, match="layout='ir_mul'"): - TensorProduct(problem) \ No newline at end of file + TensorProduct(problem) From 5a5080a288ff5bb8892ac31c54855fd8b67692e6 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 18:34:36 -0700 Subject: [PATCH 17/27] Even more refactoring. --- .../_torch/CUETensorProduct.py | 56 +---- .../_torch/E3NNTensorProduct.py | 2 +- .../openequivariance/_torch/TensorProduct.py | 2 +- .../_torch/TensorProductConv.py | 2 +- .../_torch/extlib/__init__.py | 2 +- .../benchmark/ConvBenchmarkSuite.py | 2 +- .../benchmark/TestBenchmarkSuite.py | 2 +- .../benchmark/benchmark_utils.py | 85 ++------ .../openequivariance/benchmark/correctness.py | 4 +- .../{perf_metrics_utils.py => metrics.py} | 58 ++--- .../openequivariance/benchmark/problems.py | 204 +++++++++++++++++- ...random_buffer_utils.py => test_buffers.py} | 0 .../benchmark/tpp_creation_utils.py | 196 ----------------- .../core/ComputationSchedule.py | 2 +- .../openequivariance/core/ConvolutionBase.py | 4 +- .../openequivariance/core/LoopUnrollTP.py | 53 +---- .../core/TensorProductBase.py | 19 +- .../{benchmark => core}/logging.py | 0 .../openequivariance/jax/TensorProductConv.py | 2 +- tests/benchmark.py | 4 +- 20 files changed, 267 insertions(+), 432 deletions(-) rename openequivariance/openequivariance/benchmark/{perf_metrics_utils.py => metrics.py} (60%) rename openequivariance/openequivariance/benchmark/{random_buffer_utils.py => test_buffers.py} (100%) delete mode 100644 openequivariance/openequivariance/benchmark/tpp_creation_utils.py rename openequivariance/openequivariance/{benchmark => core}/logging.py (100%) diff --git a/openequivariance/openequivariance/_torch/CUETensorProduct.py b/openequivariance/openequivariance/_torch/CUETensorProduct.py index 241b1df6..1867e08a 100644 --- a/openequivariance/openequivariance/_torch/CUETensorProduct.py +++ 
b/openequivariance/openequivariance/_torch/CUETensorProduct.py @@ -6,13 +6,12 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging import getLogger -from openequivariance.benchmark.tpp_creation_utils import ( +from openequivariance.core.logging import getLogger +from openequivariance.benchmark.problems import ( ChannelwiseTPP, FullyConnectedTPProblem, SingleInstruction, ) -from openequivariance.core.utils import count_cg_non_zero os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1" @@ -235,57 +234,6 @@ def benchmark_backward( kernel_names=self.kernel_names, ) - # Copied over from loop unroller to match arithmetic intensity on roofline plots - def calculate_flops_forward(self, batch_size: int) -> dict: - if self.is_uvw: - return super().calculate_flops_forward(batch_size) - else: - tpp = self.config - flop_count = { - "CG_decomposition": 0, - "linear_combination": 0, - "outer_products": 0, - } - for ins in tpp.instructions: - l1, l2, l3 = ( - tpp.irreps_in1[ins.i_in1].ir.l, - tpp.irreps_in2[ins.i_in2].ir.l, - tpp.irreps_out[ins.i_out].ir.l, - ) - flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( - ins.path_shape[0] * ins.path_shape[1] - ) - flop_count["linear_combination"] += ( - (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 - ) - - flop_count["CG_decomposition"] *= 3 * batch_size - flop_count["linear_combination"] *= ( - batch_size # Weights do not require FMA here - ) - flop_count["total"] = sum(flop_count.values()) - return flop_count - - def calculate_flops_backward(self, batch_size: int) -> dict: - if self.is_uvw: - return super().calculate_flops_backward(batch_size) - else: - tpp = self.config - flop_count = {"backward": 0} - for ins in tpp.instructions: - l1, l2, l3 = ( - tpp.irreps_in1[ins.i_in1].ir.l, - tpp.irreps_in2[ins.i_in2].ir.l, - tpp.irreps_out[ins.i_out].ir.l, - ) - flop_count["backward"] += 
count_cg_non_zero(l1, l2, l3) * ( - ins.path_shape[0] * ins.path_shape[1] - ) - - flop_count["backward"] *= 9 * batch_size - flop_count["total"] = sum(flop_count.values()) - return flop_count - @staticmethod def name(): return "CUETensorProduct" diff --git a/openequivariance/openequivariance/_torch/E3NNTensorProduct.py b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py index cacbc017..df696ad6 100644 --- a/openequivariance/openequivariance/_torch/E3NNTensorProduct.py +++ b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py @@ -11,7 +11,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging import getLogger +from openequivariance.core.logging import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index c00530d5..c18b1231 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -3,7 +3,7 @@ from openequivariance._torch import extlib import torch from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum -from openequivariance.benchmark.logging import getLogger +from openequivariance.core.logging import getLogger from openequivariance._torch.utils import ( reorder_torch, string_to_tensor, diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index e522c0e8..c1087a63 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -23,7 +23,7 @@ enum_to_torch_dtype, ) -from openequivariance.benchmark.logging import getLogger +from 
openequivariance.core.logging import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv logger = getLogger() diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 81f243cf..ef17724b 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -8,7 +8,7 @@ import torch -from openequivariance.benchmark.logging import getLogger +from openequivariance.core.logging import getLogger oeq_root = str(Path(__file__).parent.parent.parent) diff --git a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py index 7f602e04..8e2ac98a 100644 --- a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -11,7 +11,7 @@ correctness_double_backward_conv, correctness_forward_conv, ) -from openequivariance.benchmark.logging import getLogger +from openequivariance.core.logging import getLogger from openequivariance.core.ConvolutionBase import CoordGraph from openequivariance.benchmark.benchmark_utils import NpEncoder diff --git a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py index c18bde00..72ada84d 100644 --- a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py @@ -10,7 +10,7 @@ from openequivariance._torch.extlib import DeviceProp from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging import getLogger, bcolors +from openequivariance.core.logging import getLogger, bcolors from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.correctness import ( 
correctness_forward, diff --git a/openequivariance/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py index 31eee40c..dafc5995 100644 --- a/openequivariance/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py @@ -1,21 +1,22 @@ import json import numpy as np -from openequivariance.benchmark.random_buffer_utils import ( +from openequivariance.benchmark.test_buffers import ( get_random_buffers_forward, get_random_buffers_backward, get_random_buffers_double_backward, ) -from openequivariance.benchmark.perf_metrics_utils import ( - calculate_minimum_flops_forward, - calculate_minimum_memory_streamed_forward, - calculate_minimum_memory_streamed_backward, +from openequivariance.benchmark.metrics import ( + flops_forward, + flops_backward, + memory_streamed_forward, + memory_streamed_backward, ) from openequivariance.core.utils import calculate_total_nnz from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem from openequivariance._torch.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.logging import getLogger, bcolors +from openequivariance.core.logging import getLogger, bcolors logger = getLogger() @@ -110,24 +111,12 @@ def benchmark_forward( time_millis = np.full(shape=num_iter, fill_value=-1) # FLOPS - try: - flops = tp.calculate_flops_forward(batch_size=batch_size) - except NotImplementedError: - logger.warning( - "Actual flop count not calculated, so minimum values are being used" - ) - flops = calculate_minimum_flops_forward(problem, batch_size=batch_size) + flops = flops_forward(problem, batch_size=batch_size) # DATA - try: - memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) - except NotImplementedError: - logger.warning( - "Actual memory streamed not calculated, so minimum values are being used" - ) - memory_streamed = 
calculate_minimum_memory_streamed_forward( - problem, batch_size=batch_size - ) + memory_streamed = memory_streamed_forward( + problem, batch_size=batch_size + ) result |= calculate_performance_statistics( problem=problem, @@ -181,29 +170,11 @@ def benchmark_backward( ) time_millis = np.full(shape=num_iter, fill_value=-1) - try: - flops = tp.calculate_flops_backward(batch_size=batch_size) - except NotImplementedError: - try: - flops = calculate_minimum_flops_forward(tpp=problem, batch_size=batch_size) - logger.warning( - "Actual flops was not calculated, so minimum values are being used" - ) - except NotImplementedError: - logger.warning( - "Minimum Backwards flops calculations are not implemented, -1 is a placeholder" - ) - flops = {"total": -1} + flops = flops_backward(tpp=problem, batch_size=batch_size) - try: - memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) - except NotImplementedError: - logger.warning( - "Actual memory streamed was not calculated, so minimum values are being" - ) - memory_streamed = calculate_minimum_memory_streamed_backward( - tpp=problem, batch_size=batch_size - ) + memory_streamed = memory_streamed_backward( + tpp=problem, batch_size=batch_size + ) result |= calculate_performance_statistics( problem=problem, @@ -258,29 +229,11 @@ def benchmark_double_backward( ) time_millis = np.full(shape=num_iter, fill_value=-1) - try: - flops = tp.calculate_flops_backward(batch_size=batch_size) - except NotImplementedError: - try: - flops = calculate_minimum_flops_forward(tpp=problem, batch_size=batch_size) - logger.warning( - "Actual flops was not calculated, so minimum values are being used" - ) - except NotImplementedError: - logger.warning( - "Minimum Backwards flops calculations are not implemented, -1 is a placeholder" - ) - flops = {"total": -1} + flops = flops_backward(tpp=problem, batch_size=batch_size) - try: - memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) - except 
NotImplementedError: - logger.warning( - "Actual memory streamed was not calculated, so minimum values are being" - ) - memory_streamed = calculate_minimum_memory_streamed_backward( - tpp=problem, batch_size=batch_size - ) + memory_streamed = memory_streamed_backward( + tpp=problem, batch_size=batch_size + ) result |= calculate_performance_statistics( problem=problem, diff --git a/openequivariance/openequivariance/benchmark/correctness.py b/openequivariance/openequivariance/benchmark/correctness.py index 17bbd074..45c45c4a 100644 --- a/openequivariance/openequivariance/benchmark/correctness.py +++ b/openequivariance/openequivariance/benchmark/correctness.py @@ -5,8 +5,8 @@ import numpy.linalg as la from openequivariance._torch.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.logging import bcolors, getLogger -from openequivariance.benchmark.random_buffer_utils import ( +from openequivariance.core.logging import bcolors, getLogger +from openequivariance.benchmark.test_buffers import ( get_random_buffers_backward_conv, get_random_buffers_backward, get_random_buffers_double_backward_conv, diff --git a/openequivariance/openequivariance/benchmark/perf_metrics_utils.py b/openequivariance/openequivariance/benchmark/metrics.py similarity index 60% rename from openequivariance/openequivariance/benchmark/perf_metrics_utils.py rename to openequivariance/openequivariance/benchmark/metrics.py index 01bd5836..81d3f123 100644 --- a/openequivariance/openequivariance/benchmark/perf_metrics_utils.py +++ b/openequivariance/openequivariance/benchmark/metrics.py @@ -1,18 +1,15 @@ -import math - from openequivariance.core.utils import ( count_cg_non_zero, - sparse_outer_product_work, ) -from openequivariance.core.e3nn_lite import TPProblem, wigner_3j -from openequivariance.benchmark.logging import getLogger +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance.core.logging import getLogger import numpy as np logger = getLogger() -def 
calculate_minimum_memory_streamed_forward( +def memory_streamed_forward( tpp: TPProblem, batch_size: int ) -> dict[str, int]: """ @@ -31,7 +28,7 @@ def calculate_minimum_memory_streamed_forward( return data_size -def calculate_minimum_memory_streamed_backward(tpp: TPProblem, batch_size: int) -> dict: +def memory_streamed_backward(tpp: TPProblem, batch_size: int) -> dict: """ This represents an absolute minimum amount of memory that could be streamed on an ideal machine It returns the number of bytes streamed total and from each source @@ -51,17 +48,12 @@ def calculate_minimum_memory_streamed_backward(tpp: TPProblem, batch_size: int) return data_size -def calculate_minimum_flops_forward(tpp: TPProblem, batch_size: int) -> dict: +def flops_forward(tpp: TPProblem, batch_size: int) -> dict: """ - This is not actually calcuating the minimum value. - Ideally you might share the outer product values between two inputs across multiple inputs. - This is assuming that you form those values and reuse them once per CG decomp. + Default FLOP estimate aligned with LoopUnrollTP's forward FLOP accounting. 
""" - logger.warning("Minimum flops Calculation is not the true minimum") - flops_count = {} - flops_count["outer_products"] = 0 - flops_count["CG_decomposition"] = 0 - flops_count["linear_combination"] = 0 + flops_count = {"CG_decomposition": 0, "linear_combination": 0, "outer_products": 0} + for ins in tpp.instructions: # type : Instruction l1, l2, l3 = ( tpp.irreps_in1[ins.i_in1].ir.l, @@ -69,28 +61,38 @@ def calculate_minimum_flops_forward(tpp: TPProblem, batch_size: int) -> dict: tpp.irreps_out[ins.i_out].ir.l, ) - flops_count["outer_products"] += sparse_outer_product_work( - wigner_3j(l1, l2, l3) - ) flops_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( ins.path_shape[0] * ins.path_shape[1] ) flops_count["linear_combination"] += ( - (2 * l3 + 1) * math.prod(ins.path_shape) if ins.has_weight else 0 + (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 ) - flops_count["outer_products"] *= batch_size - flops_count["CG_decomposition"] *= 2 * batch_size - flops_count["linear_combination"] *= 2 * batch_size + flops_count["CG_decomposition"] *= 3 * batch_size + flops_count["linear_combination"] *= ( + batch_size # Weights do not require FMA here + ) flops_count["total"] = sum(flops_count.values()) return flops_count -def calculate_minimum_flops_backward(tpp: TPProblem, batch_size: int) -> dict: +def flops_backward(tpp: TPProblem, batch_size: int) -> dict: """ - This is not actually calcuating the minumum value. - Ideally you might share the outer product values between two inputs across multiple inputs. - This is assuming that you form those values and reuse them once per CG decomp. + Default FLOP estimate aligned with LoopUnrollTP's backward FLOP accounting. 
""" - raise NotImplementedError("this needs to be implemented properly") + flops_count = {"backward": 0} + + for ins in tpp.instructions: # type : Instruction + l1, l2, l3 = ( + tpp.irreps_in1[ins.i_in1].ir.l, + tpp.irreps_in2[ins.i_in2].ir.l, + tpp.irreps_out[ins.i_out].ir.l, + ) + flops_count["backward"] += count_cg_non_zero(l1, l2, l3) * ( + ins.path_shape[0] * ins.path_shape[1] + ) + + flops_count["backward"] *= 9 * batch_size + flops_count["total"] = sum(flops_count.values()) + return flops_count diff --git a/openequivariance/openequivariance/benchmark/problems.py b/openequivariance/openequivariance/benchmark/problems.py index dd536e3a..b486941c 100644 --- a/openequivariance/openequivariance/benchmark/problems.py +++ b/openequivariance/openequivariance/benchmark/problems.py @@ -1,7 +1,203 @@ -from openequivariance.benchmark.tpp_creation_utils import ( - FullyConnectedTPProblem as FCTPP, -) -from openequivariance.benchmark.tpp_creation_utils import ChannelwiseTPP as CTPP +from typing import Iterator, Optional + +import numpy as np + +from openequivariance.core.e3nn_lite import Irrep, Irreps, TPProblem + +""" +This was taken from +https://github.com/e3nn/e3nn/blob/0.5.4/e3nn/o3/_tensor_product/_sub.py +Adapted to create TPPs to avoid torch dependence. 
+""" + + +class FullyConnectedTPProblem(TPProblem): + def __init__(self, irreps_in1, irreps_in2, irreps_out, **kwargs) -> None: + irreps_in1 = Irreps(irreps_in1) + irreps_in2 = Irreps(irreps_in2) + irreps_out = Irreps(irreps_out) + + instr = [ + (i_1, i_2, i_out, "uvw", True, 1.0) + for i_1, (_, ir_1) in enumerate(irreps_in1) + for i_2, (_, ir_2) in enumerate(irreps_in2) + for i_out, (_, ir_out) in enumerate(irreps_out) + if ir_out in ir_1 * ir_2 + ] + super().__init__( + irreps_in1, + irreps_in2, + irreps_out, + instr, + **kwargs, + ) + + +class ElementwiseTPProblem(TPProblem): + def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, **kwargs) -> None: + irreps_in1 = Irreps(irreps_in1).simplify() + irreps_in2 = Irreps(irreps_in2).simplify() + if filter_ir_out is not None: + try: + filter_ir_out = [Irrep(ir) for ir in filter_ir_out] + except ValueError as exc: + raise ValueError( + f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" + ) from exc + + assert irreps_in1.num_irreps == irreps_in2.num_irreps + + irreps_in1 = list(irreps_in1) + irreps_in2 = list(irreps_in2) + + i = 0 + while i < len(irreps_in1): + mul_1, ir_1 = irreps_in1[i] + mul_2, ir_2 = irreps_in2[i] + + if mul_1 < mul_2: + irreps_in2[i] = (mul_1, ir_2) + irreps_in2.insert(i + 1, (mul_2 - mul_1, ir_2)) + + if mul_2 < mul_1: + irreps_in1[i] = (mul_2, ir_1) + irreps_in1.insert(i + 1, (mul_1 - mul_2, ir_1)) + i += 1 + + out = [] + instr = [] + for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1, irreps_in2)): + assert mul == mul_2 + for ir in ir_1 * ir_2: + if filter_ir_out is not None and ir not in filter_ir_out: + continue + + i_out = len(out) + out.append((mul, ir)) + instr += [(i, i, i_out, "uuu", False)] + + super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) + + +class FullTPProblem(TPProblem): + def __init__( + self, + irreps_in1: Irreps, + irreps_in2: Irreps, + filter_ir_out: Iterator[Irrep] = None, + **kwargs, + ) -> None: + irreps_in1 = 
Irreps(irreps_in1).simplify() + irreps_in2 = Irreps(irreps_in2).simplify() + if filter_ir_out is not None: + try: + filter_ir_out = [Irrep(ir) for ir in filter_ir_out] + except ValueError as exc: + raise ValueError( + f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" + ) from exc + + out = [] + instr = [] + for i_1, (mul_1, ir_1) in enumerate(irreps_in1): + for i_2, (mul_2, ir_2) in enumerate(irreps_in2): + for ir_out in ir_1 * ir_2: + if filter_ir_out is not None and ir_out not in filter_ir_out: + continue + + i_out = len(out) + out.append((mul_1 * mul_2, ir_out)) + instr += [(i_1, i_2, i_out, "uvuv", False)] + + out = Irreps(out) + out, p, _ = out.sort() + + instr = [ + (i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instr + ] + + super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) + + +class ChannelwiseTPP(TPProblem): + """ + Modified from mace/mace/modules/irreps_tools.py. + """ + + def __init__( + self, + irreps_in1: Irreps, + irreps_in2: Irreps, + irreps_out: Irreps, + label: Optional[str] = None, + irrep_dtype=np.float32, + weight_dtype=np.float32, + ): + trainable = True + irreps1 = Irreps(irreps_in1) + irreps2 = Irreps(irreps_in2) + irreps_out = Irreps(irreps_out) + + irreps_out_list = [] + instructions = [] + for i, (mul, ir_in) in enumerate(irreps1): + for j, (_, ir_edge) in enumerate(irreps2): + for ir_out in ir_in * ir_edge: + if ir_out in irreps_out: + k = len(irreps_out_list) + irreps_out_list.append((mul, ir_out)) + instructions.append((i, j, k, "uvu", trainable)) + + irreps_out = Irreps(irreps_out_list) + irreps_out, permut, _ = irreps_out.sort() + + instructions = [ + (i_in1, i_in2, permut[i_out], mode, train) + for i_in1, i_in2, i_out, mode, train in instructions + ] + + instructions = sorted(instructions, key=lambda x: x[2]) + super().__init__( + irreps1, + irreps2, + irreps_out, + instructions, + internal_weights=False, + shared_weights=False, + label=label, + irrep_dtype=irrep_dtype, + 
weight_dtype=weight_dtype, + ) + + +class SingleInstruction(TPProblem): + def __init__( + self, + irreps_in1: Irreps, + irreps_in2: Irreps, + irreps_in3: Irreps, + mode: str, + label: Optional[str] = None, + ): + trainable = True + irreps1 = Irreps(irreps_in1) + irreps2 = Irreps(irreps_in2) + irreps3 = Irreps(irreps_in3) + instructions = [(0, 0, 0, mode, trainable)] + + super().__init__( + irreps1, + irreps2, + irreps3, + instructions, + internal_weights=False, + shared_weights=False, + label=label, + ) + + +FCTPP = FullyConnectedTPProblem +CTPP = ChannelwiseTPP # source: https://github.com/e3nn/e3nn/blob/main/examples/tetris.py # running tetris will output the layers. I've only extracted the fully connected layers here. diff --git a/openequivariance/openequivariance/benchmark/random_buffer_utils.py b/openequivariance/openequivariance/benchmark/test_buffers.py similarity index 100% rename from openequivariance/openequivariance/benchmark/random_buffer_utils.py rename to openequivariance/openequivariance/benchmark/test_buffers.py diff --git a/openequivariance/openequivariance/benchmark/tpp_creation_utils.py b/openequivariance/openequivariance/benchmark/tpp_creation_utils.py deleted file mode 100644 index 7637f412..00000000 --- a/openequivariance/openequivariance/benchmark/tpp_creation_utils.py +++ /dev/null @@ -1,196 +0,0 @@ -import numpy as np - -from typing import Iterator, Optional -from openequivariance.core.e3nn_lite import Irrep, Irreps, TPProblem - -""" -This was taken from -https://github.com/e3nn/e3nn/blob/0.5.4/e3nn/o3/_tensor_product/_sub.py -Adapted to create TPPs to avoid torch dependence. 
-""" - - -class FullyConnectedTPProblem(TPProblem): - def __init__(self, irreps_in1, irreps_in2, irreps_out, **kwargs) -> None: - irreps_in1 = Irreps(irreps_in1) - irreps_in2 = Irreps(irreps_in2) - irreps_out = Irreps(irreps_out) - - instr = [ - (i_1, i_2, i_out, "uvw", True, 1.0) - for i_1, (_, ir_1) in enumerate(irreps_in1) - for i_2, (_, ir_2) in enumerate(irreps_in2) - for i_out, (_, ir_out) in enumerate(irreps_out) - if ir_out in ir_1 * ir_2 - ] - super().__init__( - irreps_in1, - irreps_in2, - irreps_out, - instr, - **kwargs, - ) - - -class ElementwiseTPProblem(TPProblem): - def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, **kwargs) -> None: - irreps_in1 = Irreps(irreps_in1).simplify() - irreps_in2 = Irreps(irreps_in2).simplify() - if filter_ir_out is not None: - try: - filter_ir_out = [Irrep(ir) for ir in filter_ir_out] - except ValueError: - raise ValueError( - f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" - ) - - assert irreps_in1.num_irreps == irreps_in2.num_irreps - - irreps_in1 = list(irreps_in1) - irreps_in2 = list(irreps_in2) - - i = 0 - while i < len(irreps_in1): - mul_1, ir_1 = irreps_in1[i] - mul_2, ir_2 = irreps_in2[i] - - if mul_1 < mul_2: - irreps_in2[i] = (mul_1, ir_2) - irreps_in2.insert(i + 1, (mul_2 - mul_1, ir_2)) - - if mul_2 < mul_1: - irreps_in1[i] = (mul_2, ir_1) - irreps_in1.insert(i + 1, (mul_1 - mul_2, ir_1)) - i += 1 - - out = [] - instr = [] - for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1, irreps_in2)): - assert mul == mul_2 - for ir in ir_1 * ir_2: - if filter_ir_out is not None and ir not in filter_ir_out: - continue - - i_out = len(out) - out.append((mul, ir)) - instr += [(i, i, i_out, "uuu", False)] - - super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) - - -class FullTPProblem(TPProblem): - def __init__( - self, - irreps_in1: Irreps, - irreps_in2: Irreps, - filter_ir_out: Iterator[Irrep] = None, - **kwargs, - ) -> None: - irreps_in1 = 
Irreps(irreps_in1).simplify() - irreps_in2 = Irreps(irreps_in2).simplify() - if filter_ir_out is not None: - try: - filter_ir_out = [Irrep(ir) for ir in filter_ir_out] - except ValueError: - raise ValueError( - f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" - ) - - out = [] - instr = [] - for i_1, (mul_1, ir_1) in enumerate(irreps_in1): - for i_2, (mul_2, ir_2) in enumerate(irreps_in2): - for ir_out in ir_1 * ir_2: - if filter_ir_out is not None and ir_out not in filter_ir_out: - continue - - i_out = len(out) - out.append((mul_1 * mul_2, ir_out)) - instr += [(i_1, i_2, i_out, "uvuv", False)] - - out = Irreps(out) - out, p, _ = out.sort() - - instr = [ - (i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instr - ] - - super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) - - -class ChannelwiseTPP(TPProblem): - """ - Modified from mace/mace/modules/irreps_tools.py. - """ - - def __init__( - self, - irreps_in1: Irreps, - irreps_in2: Irreps, - irreps_out: Irreps, - label: Optional[str] = None, - irrep_dtype=np.float32, - weight_dtype=np.float32, - ): - trainable = True - irreps1 = Irreps(irreps_in1) - irreps2 = Irreps(irreps_in2) - irreps_out = Irreps(irreps_out) - - # Collect possible irreps and their instructions - irreps_out_list = [] - instructions = [] - for i, (mul, ir_in) in enumerate(irreps1): - for j, (_, ir_edge) in enumerate(irreps2): - for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 - if ir_out in irreps_out: - k = len(irreps_out_list) # instruction index - irreps_out_list.append((mul, ir_out)) - instructions.append((i, j, k, "uvu", trainable)) - - irreps_out = Irreps(irreps_out_list) - irreps_out, permut, _ = irreps_out.sort() - - instructions = [ - (i_in1, i_in2, permut[i_out], mode, train) - for i_in1, i_in2, i_out, mode, train in instructions - ] - - instructions = sorted(instructions, key=lambda x: x[2]) - super().__init__( - irreps1, - irreps2, - irreps_out, - instructions, - 
internal_weights=False, - shared_weights=False, - label=label, - irrep_dtype=irrep_dtype, - weight_dtype=weight_dtype, - ) - - -class SingleInstruction(TPProblem): - def __init__( - self, - irreps_in1: Irreps, - irreps_in2: Irreps, - irreps_in3: Irreps, - mode: str, - label: Optional[str] = None, - ): - trainable = True - irreps1 = Irreps(irreps_in1) - irreps2 = Irreps(irreps_in2) - irreps3 = Irreps(irreps_in3) - instructions = [(0, 0, 0, mode, trainable)] - - super().__init__( - irreps1, - irreps2, - irreps3, - instructions, - internal_weights=False, - shared_weights=False, - label=label, - ) diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index 57094706..f9f10013 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -2,7 +2,7 @@ import numpy as np -from openequivariance.benchmark.logging import getLogger +from openequivariance.core.logging import getLogger from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j logger = getLogger() diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index 7f36e4ce..5ee4a613 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -2,8 +2,8 @@ import numpy as np -from openequivariance.benchmark.logging import bcolors, getLogger -from openequivariance.benchmark.random_buffer_utils import ( +from openequivariance.core.logging import bcolors, getLogger +from openequivariance.benchmark.test_buffers import ( get_random_buffers_backward_conv, get_random_buffers_forward_conv, ) diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index e07387b0..e9091b41 100644 --- 
a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -2,14 +2,13 @@ import numpy as np -from openequivariance.benchmark.logging import getLogger +from openequivariance.core.logging import getLogger from openequivariance.core.ComputationSchedule import ( ComputationSchedule, SMEMCapacityException, ) from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.utils import ( - count_cg_non_zero, dtype_to_enum, filter_and_analyze_problem, hash_str_64, @@ -129,53 +128,3 @@ def generate_double_backward_schedule(warps_per_block): ) self.hash = hash_str_64(self.kernel_string) logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") - - def calculate_flops_forward(self, batch_size: int) -> dict: - if self.is_uvw: - return super().calculate_flops_forward(batch_size) - else: - tpp = self.config - flop_count = { - "CG_decomposition": 0, - "linear_combination": 0, - "outer_products": 0, - } - for ins in tpp.instructions: - l1, l2, l3 = ( - tpp.irreps_in1[ins.i_in1].ir.l, - tpp.irreps_in2[ins.i_in2].ir.l, - tpp.irreps_out[ins.i_out].ir.l, - ) - flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( - ins.path_shape[0] * ins.path_shape[1] - ) - flop_count["linear_combination"] += ( - (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 - ) - - flop_count["CG_decomposition"] *= 3 * batch_size - flop_count["linear_combination"] *= ( - batch_size # Weights do not require FMA here - ) - flop_count["total"] = sum(flop_count.values()) - return flop_count - - def calculate_flops_backward(self, batch_size: int) -> dict: - if self.is_uvw: - return super().calculate_flops_backward(batch_size) - else: - tpp = self.config - flop_count = {"backward": 0} - for ins in tpp.instructions: - l1, l2, l3 = ( - tpp.irreps_in1[ins.i_in1].ir.l, - tpp.irreps_in2[ins.i_in2].ir.l, - tpp.irreps_out[ins.i_out].ir.l, - ) - flop_count["backward"] += count_cg_non_zero(l1, 
l2, l3) * ( - ins.path_shape[0] * ins.path_shape[1] - ) - - flop_count["backward"] *= 9 * batch_size - flop_count["total"] = sum(flop_count.values()) - return flop_count diff --git a/openequivariance/openequivariance/core/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py index fa065b7b..d955e47e 100644 --- a/openequivariance/openequivariance/core/TensorProductBase.py +++ b/openequivariance/openequivariance/core/TensorProductBase.py @@ -1,7 +1,7 @@ import numpy as np from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.benchmark.logging import getLogger +from openequivariance.core.logging import getLogger from openequivariance.core.utils import benchmark logger = getLogger() @@ -181,20 +181,3 @@ def benchmark_double_backward( kernel_names=kernel_names, ) - def calculate_memory_streamed_forward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_memory_streamed_backward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_memory_streamed_double_backward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_flops_forward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_flops_backward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_flops_double_backward(self, batch_size: int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") diff --git a/openequivariance/openequivariance/benchmark/logging.py b/openequivariance/openequivariance/core/logging.py similarity index 100% rename from openequivariance/openequivariance/benchmark/logging.py rename to openequivariance/openequivariance/core/logging.py diff --git 
a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 4645ae86..7101fa00 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -9,7 +9,7 @@ from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance.jax.utils import reorder_jax -from openequivariance.benchmark.logging import getLogger +from openequivariance.core.logging import getLogger from openequivariance.jax.jvp import conv_prim from openequivariance.jax.vjp import conv_func diff --git a/tests/benchmark.py b/tests/benchmark.py index 441a3ca0..1763a565 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -10,7 +10,7 @@ import numpy as np -from openequivariance.benchmark.logging import getLogger +from openequivariance.core.logging import getLogger from openequivariance._torch.extlib import DeviceProp from openequivariance._torch.E3NNTensorProduct import ( E3NNTensorProduct, @@ -24,7 +24,7 @@ TestDefinition, Direction, ) -from openequivariance.benchmark.tpp_creation_utils import ( +from openequivariance.benchmark.problems import ( ChannelwiseTPP, FullyConnectedTPProblem, SingleInstruction, From f5b7a26b52e03a2db1e53f54a71f39e674837d08 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 21:11:48 -0700 Subject: [PATCH 18/27] Added include for automatic string conversion. 
--- .../openequivariance/extension/libtorch_tp_jit_stable.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 5ddcb7f5..440b7640 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -89,6 +89,7 @@ Stream get_current_stream() { #ifdef INCLUDE_NB_EXTENSION #include "nanobind/nanobind.h" + #include "nanobind/stl/string.h" namespace nb = nanobind; NB_MODULE(EXTENSION_NAME, m) { nb::class_(m, "DeviceProp") From b80d15cd53beabce2a1912e3523297394cfb14b0 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 21:44:11 -0700 Subject: [PATCH 19/27] Completed benchmarking. --- .../benchmark/plotting/__init__.py | 2 + .../benchmark/plotting/plot_layout.py | 119 ++++++++++++++++++ tests/benchmark.py | 70 +++++++++++ 3 files changed, 191 insertions(+) create mode 100644 openequivariance/openequivariance/benchmark/plotting/plot_layout.py diff --git a/openequivariance/openequivariance/benchmark/plotting/__init__.py b/openequivariance/openequivariance/benchmark/plotting/__init__.py index 3c0bf032..6aac5e14 100644 --- a/openequivariance/openequivariance/benchmark/plotting/__init__.py +++ b/openequivariance/openequivariance/benchmark/plotting/__init__.py @@ -5,6 +5,7 @@ from openequivariance.benchmark.plotting.plot_double_backward import ( plot_double_backward, ) +from openequivariance.benchmark.plotting.plot_layout import plot_layout __all__ = [ "plot_uvu", @@ -12,4 +13,5 @@ "plot_roofline", "plot_convolution", "plot_double_backward", + "plot_layout", ] diff --git a/openequivariance/openequivariance/benchmark/plotting/plot_layout.py b/openequivariance/openequivariance/benchmark/plotting/plot_layout.py new file mode 100644 index 00000000..652c0c4e --- /dev/null +++ 
b/openequivariance/openequivariance/benchmark/plotting/plot_layout.py @@ -0,0 +1,119 @@ +import pathlib + +import matplotlib.pyplot as plt +import numpy as np + +from openequivariance.benchmark.plotting.plotting_utils import ( + calculate_tp_per_sec, + grouped_barchart, + load_benchmarks, + set_grid, +) + + +def _parse_layout_label(label: str): + if label.endswith("[mul_ir]"): + return label[: -len(" [mul_ir]")], "mul_ir" + if label.endswith("[ir_mul]"): + return label[: -len(" [ir_mul]")], "ir_mul" + return label, None + + +def plot_layout(data_folder): + data_folder = pathlib.Path(data_folder) + benchmarks, _ = load_benchmarks(data_folder) + + grouped = {} + dtype_order = [] + for benchmark in benchmarks: + dtype = benchmark["benchmark results"]["rep_dtype"] + if dtype not in dtype_order: + dtype_order.append(dtype) + + direction = benchmark["direction"] + base_label, layout = _parse_layout_label(benchmark["config_label"]) + if layout is None: + continue + + grouped.setdefault(dtype, {}).setdefault(direction, {}).setdefault( + base_label, {"mul_ir": 0.0, "ir_mul": 0.0} + ) + grouped[dtype][direction][base_label][layout] = calculate_tp_per_sec(benchmark) + + def _dtype_sort_key(dtype_name: str) -> int: + if "float32" in dtype_name: + return 0 + if "float64" in dtype_name: + return 1 + return 2 + + dtype_order = sorted(dtype_order, key=_dtype_sort_key) + + directions = [d for d in ["forward", "backward"] if any(d in grouped[x] for x in grouped)] + if not directions: + raise ValueError("No forward/backward layout benchmark entries found to plot.") + + fig = plt.figure(figsize=(7, 7)) + gs = fig.add_gridspec(len(directions), max(1, len(dtype_order))) + axs = gs.subplots(sharex="col") + + if len(directions) == 1 and len(dtype_order) == 1: + axs = np.array([[axs]]) + elif len(directions) == 1: + axs = np.array([axs]) + elif len(dtype_order) == 1: + axs = np.array([[ax] for ax in axs]) + + colormap = {"mul_ir": "#1f77b4", "ir_mul": "#2ca02c"} + + for row, direction in 
enumerate(directions): + for col, dtype in enumerate(dtype_order): + axis = axs[row][col] + source = grouped.get(dtype, {}).get(direction, {}) + data = { + label: { + "mul_ir": vals["mul_ir"], + "ir_mul": vals["ir_mul"], + } + for label, vals in source.items() + } + grouped_barchart( + data, + axis, + bar_height_fontsize=0, + colormap=colormap, + group_spacing=6.0, + xticklabel=(row == len(directions) - 1), + ) + set_grid(axis) + + if row == 0: + axis.set_title(dtype.replace("", "")) + if col == 0: + axis.set_ylabel(direction.capitalize()) + if row < len(directions) - 1: + axis.tick_params(axis="x", labelbottom=False) + + fig.supylabel("Throughput (# tensor products / s)", x=0.03, y=0.56) + fig.supxlabel("Problem") + + handles, labels = axs[0][0].get_legend_handles_labels() + unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]] + if unique: + axs[0][0].legend(*zip(*unique)) + + fig.tight_layout(rect=(0.03, 0.03, 1.0, 1.0)) + fig.savefig(str(data_folder / "layout_throughput_comparison.pdf")) + + print("Layout speedups (ir_mul / mul_ir):") + print("\t".join(["dtype", "direction", "min", "mean", "median", "max"])) + for dtype in dtype_order: + for direction in directions: + ratios = [] + for _, values in grouped.get(dtype, {}).get(direction, {}).items(): + if values["mul_ir"] > 0: + ratios.append(values["ir_mul"] / values["mul_ir"]) + if ratios: + stats = [np.min(ratios), np.mean(ratios), np.median(ratios), np.max(ratios)] + stats_fmt = [f"{val:.3f}" for val in stats] + print("\t".join([dtype, direction] + stats_fmt)) diff --git a/tests/benchmark.py b/tests/benchmark.py index 1763a565..9d230cf1 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -385,6 +385,47 @@ def benchmark_kahan_accuracy(params): ) +def benchmark_layouts(params): + base_problems = mace_problems() + nequip_problems() + directions = params.directions + dtypes = [datatype_map[dtype] for dtype in params.datatypes] + + tests = [] + for dtype in dtypes: + for 
base_problem in base_problems: + for layout in ["mul_ir", "ir_mul"]: + layout_problem = copy.deepcopy(base_problem) + layout_problem.layout = layout + base_label = layout_problem.label or "layout_problem" + layout_problem.label = f"{base_label} [{layout}]" + layout_problem.irrep_dtype = dtype + layout_problem.weight_dtype = dtype + + for direction in directions: + tests.append( + TestDefinition( + TensorProduct, + layout_problem, + direction, + correctness=False, + benchmark=True, + ) + ) + + bench_suite = TestBenchmarkSuite( + num_warmup=100, + num_iter=100, + bench_batch_size=params.batch_size, + prng_seed=11111, + test_name="layouts", + ) + + data_folder = bench_suite.run(tests, params.output_folder) + + if params.plot: + plot({"data_folder": data_folder}) + + def plot(params): import openequivariance.benchmark.plotting as plotting @@ -402,6 +443,8 @@ def plot(params): plotting.plot_uvu(data_folder) elif test_name == "uvw": plotting.plot_uvw(data_folder) + elif test_name == "layouts": + plotting.plot_layout(data_folder) elif test_name == "roofline": plotting.plot_roofline(data_folder) elif test_name == "convolution": @@ -532,6 +575,33 @@ def plot(params): parser_uvw.add_argument("--plot", action="store_true", help="Plot the results.") parser_uvw.set_defaults(func=benchmark_uvw) + parser_layouts = subparsers.add_parser( + "layouts", help="Run benchmark comparing mul_ir vs ir_mul layouts" + ) + parser_layouts.add_argument( + "--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark" + ) + parser_layouts.add_argument( + "--directions", + "-d", + type=str, + nargs="+", + default=["forward", "backward"], + help="Directions to benchmark", + choices=["forward", "backward"], + ) + parser_layouts.add_argument( + "--datatypes", + "-t", + type=str, + nargs="+", + default=["float32", "float64"], + help="Data types to benchmark", + choices=["float32", "float64"], + ) + parser_layouts.add_argument("--plot", action="store_true", help="Plot the results.") + 
parser_layouts.set_defaults(func=benchmark_layouts) + parser_double_bwd = subparsers.add_parser( "double_backward", help="Run the higher derivative kernel benchmark" ) From a5d0fe56dbb8b6c66efe5b0903ad2484a57c17ad Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 22:43:39 -0700 Subject: [PATCH 20/27] Modified changelog before release. --- CHANGELOG.md | 7 ++ .../benchmark/benchmark_utils.py | 12 +-- .../openequivariance/benchmark/metrics.py | 8 +- .../benchmark/plotting/plot_layout.py | 15 +++- .../openequivariance/core/ConvolutionBase.py | 2 - .../openequivariance/core/LoopUnrollTP.py | 2 - .../core/TensorProductBase.py | 1 - openequivariance/pyproject.toml | 2 +- openequivariance_extjax/pyproject.toml | 2 +- tests/input_validation_test.py | 4 +- tests/torch_determinism_test.py | 73 ------------------- 11 files changed, 27 insertions(+), 101 deletions(-) delete mode 100644 tests/torch_determinism_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f6e4d94..6aa11d05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ ## Latest Changes +### v0.6.5 (2026-03-22) +This release brings `ir_mul` layout support for +OpenEquivariance. Pass the parameter +`layout='ir_mul'` to any `TPProblem` instance to use +a transposed layout for the input and output +irreps. + ### v0.6.4 (2026-03-05) Bugfix: added missing MLIR lowerings for a pair of JAX primitives (thanks @teddykoker!) 
diff --git a/openequivariance/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py index dafc5995..62687dbb 100644 --- a/openequivariance/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py @@ -114,9 +114,7 @@ def benchmark_forward( flops = flops_forward(problem, batch_size=batch_size) # DATA - memory_streamed = memory_streamed_forward( - problem, batch_size=batch_size - ) + memory_streamed = memory_streamed_forward(problem, batch_size=batch_size) result |= calculate_performance_statistics( problem=problem, @@ -172,9 +170,7 @@ def benchmark_backward( flops = flops_backward(tpp=problem, batch_size=batch_size) - memory_streamed = memory_streamed_backward( - tpp=problem, batch_size=batch_size - ) + memory_streamed = memory_streamed_backward(tpp=problem, batch_size=batch_size) result |= calculate_performance_statistics( problem=problem, @@ -231,9 +227,7 @@ def benchmark_double_backward( flops = flops_backward(tpp=problem, batch_size=batch_size) - memory_streamed = memory_streamed_backward( - tpp=problem, batch_size=batch_size - ) + memory_streamed = memory_streamed_backward(tpp=problem, batch_size=batch_size) result |= calculate_performance_statistics( problem=problem, diff --git a/openequivariance/openequivariance/benchmark/metrics.py b/openequivariance/openequivariance/benchmark/metrics.py index 81d3f123..0fc87f1f 100644 --- a/openequivariance/openequivariance/benchmark/metrics.py +++ b/openequivariance/openequivariance/benchmark/metrics.py @@ -9,9 +9,7 @@ logger = getLogger() -def memory_streamed_forward( - tpp: TPProblem, batch_size: int -) -> dict[str, int]: +def memory_streamed_forward(tpp: TPProblem, batch_size: int) -> dict[str, int]: """ This represents an absolute minimum amount of memory that could be streamed on an ideal machine It returns the number of bytes streamed total and from each source @@ -69,9 +67,7 @@ def flops_forward(tpp: 
TPProblem, batch_size: int) -> dict: ) flops_count["CG_decomposition"] *= 3 * batch_size - flops_count["linear_combination"] *= ( - batch_size # Weights do not require FMA here - ) + flops_count["linear_combination"] *= batch_size # Weights do not require FMA here flops_count["total"] = sum(flops_count.values()) return flops_count diff --git a/openequivariance/openequivariance/benchmark/plotting/plot_layout.py b/openequivariance/openequivariance/benchmark/plotting/plot_layout.py index 652c0c4e..5d42ac32 100644 --- a/openequivariance/openequivariance/benchmark/plotting/plot_layout.py +++ b/openequivariance/openequivariance/benchmark/plotting/plot_layout.py @@ -49,7 +49,9 @@ def _dtype_sort_key(dtype_name: str) -> int: dtype_order = sorted(dtype_order, key=_dtype_sort_key) - directions = [d for d in ["forward", "backward"] if any(d in grouped[x] for x in grouped)] + directions = [ + d for d in ["forward", "backward"] if any(d in grouped[x] for x in grouped) + ] if not directions: raise ValueError("No forward/backward layout benchmark entries found to plot.") @@ -98,7 +100,9 @@ def _dtype_sort_key(dtype_name: str) -> int: fig.supxlabel("Problem") handles, labels = axs[0][0].get_legend_handles_labels() - unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]] + unique = [ + (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] + ] if unique: axs[0][0].legend(*zip(*unique)) @@ -114,6 +118,11 @@ def _dtype_sort_key(dtype_name: str) -> int: if values["mul_ir"] > 0: ratios.append(values["ir_mul"] / values["mul_ir"]) if ratios: - stats = [np.min(ratios), np.mean(ratios), np.median(ratios), np.max(ratios)] + stats = [ + np.min(ratios), + np.mean(ratios), + np.median(ratios), + np.max(ratios), + ] stats_fmt = [f"{val:.3f}" for val in stats] print("\t".join([dtype, direction] + stats_fmt)) diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py 
index 5ee4a613..116a21b3 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -1,5 +1,3 @@ -import copy - import numpy as np from openequivariance.core.logging import bcolors, getLogger diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index e9091b41..36801405 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -1,7 +1,5 @@ import json -import numpy as np - from openequivariance.core.logging import getLogger from openequivariance.core.ComputationSchedule import ( ComputationSchedule, diff --git a/openequivariance/openequivariance/core/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py index d955e47e..c6fc83f8 100644 --- a/openequivariance/openequivariance/core/TensorProductBase.py +++ b/openequivariance/openequivariance/core/TensorProductBase.py @@ -180,4 +180,3 @@ def benchmark_double_backward( mode=mode, kernel_names=kernel_names, ) - diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 1d0e4394..40856819 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance" -version = "0.6.4" +version = "0.6.5" authors = [ { name="Austin Glover" }, { name="Vivek Bharadwaj" }, diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index 4a53ff70..ba313d23 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance_extjax" -version = "0.6.4" +version = "0.6.5" authors = [ { name="Austin Glover" }, diff --git a/tests/input_validation_test.py b/tests/input_validation_test.py index 2a581314..9b38d55e 
100644 --- a/tests/input_validation_test.py +++ b/tests/input_validation_test.py @@ -140,7 +140,7 @@ def test_cpp_checks_forward_dtype(executable_and_buffers, subtests): executable(*buffers) -def test_ir_mul_rejects_uvw_problem(dtype): +def test_ir_mul_rejects_uvw_problem(): problem = TPProblem( "5x5e", "1x3e", @@ -148,8 +148,6 @@ def test_ir_mul_rejects_uvw_problem(dtype): [(0, 0, 0, "uvw", True)], shared_weights=False, internal_weights=False, - irrep_dtype=dtype, - weight_dtype=dtype, layout="ir_mul", ) diff --git a/tests/torch_determinism_test.py b/tests/torch_determinism_test.py deleted file mode 100644 index c8fa83ef..00000000 --- a/tests/torch_determinism_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import pytest -import torch - -from openequivariance import TPProblem, TensorProductConv - -from e3nn import o3 -from torch_geometric import EdgeIndex - - -@pytest.fixture -def gen(): - return torch.Generator(device="cuda") - - -@pytest.fixture -def edge_index(): - return EdgeIndex( - data=[ - [0, 1, 1, 2], # Receiver - [1, 0, 2, 1], # Sender - ], - sparse_size=(3, 4), - device="cuda", - dtype=torch.long, - ) - - -@pytest.fixture -def tpp(): - X_ir = o3.Irreps("1x2e") - Y_ir = o3.Irreps("1x3e") - Z_ir = o3.Irreps("1x2e") - instructions = [(0, 0, 0, "uvu", True)] - return TPProblem( - X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False - ) - - -@pytest.fixture -def conv_buffers(edge_index, tpp, gen): - X = torch.rand( - edge_index.num_rows, tpp.irreps_in1.dim, device="cuda", generator=gen - ) - Y = torch.rand( - edge_index.num_cols, tpp.irreps_in2.dim, device="cuda", generator=gen - ) - W = torch.rand(edge_index.num_cols, tpp.weight_numel, device="cuda", generator=gen) - return (X, Y, W, edge_index[0], edge_index[1]) - - -@pytest.fixture -def tp_conv(tpp): - return TensorProductConv(tpp, deterministic=False) - - -def test_no_response(tp_conv, conv_buffers): - torch.use_deterministic_algorithms(False) - tp_conv(*conv_buffers) - - -def 
test_warning(tp_conv, conv_buffers, capfd): - torch.use_deterministic_algorithms(True, warn_only=True) - tp_conv(*conv_buffers) - - captured = capfd.readouterr() - assert "Warning" in captured.err - assert "does not have a deterministic implementation" in captured.err - - -def test_error(tp_conv, conv_buffers): - torch.use_deterministic_algorithms(True, warn_only=False) - with pytest.raises(RuntimeError): - tp_conv(*conv_buffers) From b5866a21e8edb509c478c3a0e80e4c4e679dd773 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 22 Mar 2026 23:38:19 -0700 Subject: [PATCH 21/27] Began adding a pair of robust transpose functions. --- .../openequivariance/_torch/utils.py | 70 +++++++++++++++++ .../openequivariance/core/e3nn_lite.py | 1 + .../openequivariance/jax/__init__.py | 77 ++++++++++++++++++- 3 files changed, 147 insertions(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index 74d5a010..b3bc986b 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -2,6 +2,7 @@ import numpy as np from types import MappingProxyType from openequivariance.core.utils import DTypeEnum +from openequivariance.core.e3nn_lite import Irreps def reorder_helper(schedule, weights_in, direction, has_batch_dim): @@ -75,3 +76,72 @@ def string_to_tensor(text: str) -> torch.Tensor: result = torch.tensor(np_bytes, device="cpu") result.requires_grad = False return result + + +def transpose_irreps( + array: torch.Tensor, + irreps: Irreps, + src_layout: str, + dst_layout: str, +) -> torch.Tensor: + r""" + Transpose irrep-packed feature tensors between ``mul_ir`` and ``ir_mul`` layouts. + + The function operates on the trailing feature dimension and preserves all leading + batch dimensions. It uses only differentiable PyTorch tensor operations, so gradients + propagate through the transpose. 
+ + :param array: Input feature tensor with shape ``[..., irreps.dim]``. + :param irreps: Irreps specification describing how the trailing feature dimension + is partitioned into irrep blocks. + :param src_layout: Source layout. Must be either ``"mul_ir"`` or ``"ir_mul"``. + :param dst_layout: Destination layout. Must be either ``"mul_ir"`` or ``"ir_mul"``. + + + :returns: Tensor in ``dst_layout`` with the same shape, dtype, and device as ``array``. + If ``src_layout == dst_layout``, returns a clone of ``array``. + + + :raises TypeError: If ``array`` is not a ``torch.Tensor``. + :raises ValueError: If ``src_layout`` or ``dst_layout`` is not one of + ``"mul_ir"`` or ``"ir_mul"``. + """ + if src_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported src_layout: {src_layout}") + if dst_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported dst_layout: {dst_layout}") + + if not isinstance(array, torch.Tensor): + raise TypeError(f"Expected torch.Tensor, got {type(array)}") + + out = torch.empty_like(array) + + if src_layout == dst_layout: + out.copy_(array) + return out + + slices = irreps.slices() + for ir_idx, mul_ir in enumerate(irreps): + mul = mul_ir.mul + dim = mul_ir.ir.dim + seg = slices[ir_idx] + block = array[..., seg.start : seg.stop] + + if src_layout == "ir_mul" and dst_layout == "mul_ir": + out[..., seg.start : seg.stop] = ( + block.reshape(*block.shape[:-1], dim, mul) + .transpose(-1, -2) + .reshape(*block.shape[:-1], mul * dim) + ) + elif src_layout == "mul_ir" and dst_layout == "ir_mul": + out[..., seg.start : seg.stop] = ( + block.reshape(*block.shape[:-1], mul, dim) + .transpose(-1, -2) + .reshape(*block.shape[:-1], dim * mul) + ) + else: + raise ValueError( + f"Unsupported layout transpose: {src_layout} -> {dst_layout}" + ) + + return out diff --git a/openequivariance/openequivariance/core/e3nn_lite.py b/openequivariance/openequivariance/core/e3nn_lite.py index e0afe83c..41cc6e59 100644 --- 
a/openequivariance/openequivariance/core/e3nn_lite.py +++ b/openequivariance/openequivariance/core/e3nn_lite.py @@ -386,6 +386,7 @@ class TPProblem: :param internal_weights: Must be False; OpenEquivariance does not support internal weights. *Default*: False. :param irrep_normalization: One of ``["component", "norm", "none"]``. *Default*: "component". :param path_normalization: One of ``["element", "path", "none"]``. *Default*: "element". + :param layout: One of ``["mul_ir", "ir_mul"]``, giving the layout of irreps for all inputs and outputs. *Default*: "mul_ir". """ instructions: List[Any] diff --git a/openequivariance/openequivariance/jax/__init__.py b/openequivariance/openequivariance/jax/__init__.py index 410e5dbf..c26606b6 100644 --- a/openequivariance/openequivariance/jax/__init__.py +++ b/openequivariance/openequivariance/jax/__init__.py @@ -1,6 +1,81 @@ +import jax +import jax.numpy as jnp + +from openequivariance.core.e3nn_lite import Irreps from openequivariance.jax.TensorProduct import TensorProduct as TensorProduct from openequivariance.jax.TensorProductConv import ( TensorProductConv as TensorProductConv, ) -__all__ = ["TensorProduct", "TensorProductConv"] + +def transpose_irreps( + array: jax.Array, + irreps: Irreps, + src_layout: str, + dst_layout: str, +) -> jax.Array: + r""" + Transpose irrep-packed feature arrays between ``mul_ir`` and ``ir_mul`` layouts. + + The function operates on the trailing feature dimension and preserves all leading + batch dimensions. It uses differentiable JAX operations, so gradients propagate + through the transpose. + + :param array: Input feature array with shape ``[..., irreps.dim]``. + :type array: jax.Array + :param irreps: Irreps specification describing how the trailing feature dimension + is partitioned into irrep blocks. + :type irreps: Irreps + :param src_layout: Source layout. Must be either ``"mul_ir"`` or ``"ir_mul"``. + :type src_layout: str + :param dst_layout: Destination layout. 
Must be either ``"mul_ir"`` or ``"ir_mul"``. + :type dst_layout: str + + :returns: Array in ``dst_layout`` with the same shape, dtype, and device as ``array``. + If ``src_layout == dst_layout``, returns a copy of ``array``. + :rtype: jax.Array + + :raises ValueError: If ``src_layout`` or ``dst_layout`` is not one of + ``"mul_ir"`` or ``"ir_mul"``. + """ + if src_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported src_layout: {src_layout}") + if dst_layout not in ("mul_ir", "ir_mul"): + raise ValueError(f"Unsupported dst_layout: {dst_layout}") + + x = jnp.asarray(array) + if src_layout == dst_layout: + return jnp.array(x, copy=True) + + out = jnp.empty_like(x) + slices = irreps.slices() + + for ir_idx, mul_ir in enumerate(irreps): + mul = mul_ir.mul + dim = mul_ir.ir.dim + seg = slices[ir_idx] + block = x[..., seg.start : seg.stop] + + if src_layout == "ir_mul" and dst_layout == "mul_ir": + transposed = ( + block.reshape(*block.shape[:-1], dim, mul) + .swapaxes(-1, -2) + .reshape(*block.shape[:-1], mul * dim) + ) + elif src_layout == "mul_ir" and dst_layout == "ir_mul": + transposed = ( + block.reshape(*block.shape[:-1], mul, dim) + .swapaxes(-1, -2) + .reshape(*block.shape[:-1], dim * mul) + ) + else: + raise ValueError( + f"Unsupported layout transpose: {src_layout} -> {dst_layout}" + ) + + out = out.at[..., seg.start : seg.stop].set(transposed) + + return out + + +__all__ = ["TensorProduct", "TensorProductConv", "transpose_irreps"] From 9544e14f8755af79d006e7a6ff747ac116199d89 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Mar 2026 21:33:58 -0700 Subject: [PATCH 22/27] Wrote a compact test for the transpose functions. 
--- tests/batch_test.py | 50 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/tests/batch_test.py b/tests/batch_test.py index a08466c7..fcf179e2 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -273,8 +273,8 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): return tp, tp.config -class TestIrMulLayoutMACE(TPCorrectness): - production_model_tpps = mace_problems() + [ +class TestIrMul(TPCorrectness): + tpps = mace_problems() + [ oeq.TPProblem( "5x5e", "1x3e", @@ -295,13 +295,57 @@ class TestIrMulLayoutMACE(TPCorrectness): ), ] - @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") + @pytest.fixture(params=tpps, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): problem = request.param.clone() problem.irrep_dtype, problem.weight_dtype = dtype, dtype problem.layout = "ir_mul" return problem + @pytest.fixture(params=["native", "transpose_wrapper"], scope="class") + def tp_and_problem(self, request, problem, extra_tp_constructor_args, with_jax): + mode = request.param + + if mode == "native": + cls = oeq.TensorProduct + if with_jax: + import openequivariance.jax.TensorProduct as jax_tp + + cls = jax_tp + tp = cls(problem, **extra_tp_constructor_args) + return tp, problem + else: + if with_jax: + import openequivariance.jax.TensorProduct as jax_tp + from openequivariance.jax import transpose_irreps + + tp_base_cls = jax_tp + else: + from openequivariance._torch.utils import transpose_irreps + + tp_base_cls = oeq.TensorProduct + + class TransposeWrapperTensorProduct(tp_base_cls): + def forward(self, x, y, W): + x_t = transpose_irreps( + x, self.config.irreps_in1, "ir_mul", "mul_ir" + ) + y_t = transpose_irreps( + y, self.config.irreps_in2, "ir_mul", "mul_ir" + ) + out_mul_ir = super().forward(x_t, y_t, W) + return transpose_irreps( + out_mul_ir, self.config.irreps_out, "mul_ir", "ir_mul" + ) + + wrapped_problem = 
problem.clone() + wrapped_problem.layout = "mul_ir" + tp = TransposeWrapperTensorProduct( + wrapped_problem, **extra_tp_constructor_args + ) + tp.config.layout = "ir_mul" + return tp, problem + class TestTorchToSubmodule: """Test that TensorProduct works correctly as a submodule when parent's .to() is called""" From 5440d551f8c0ba16a37cdd269b3357bdf9d0ade4 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Mar 2026 21:41:15 -0700 Subject: [PATCH 23/27] Compacted diff further. --- tests/batch_test.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/batch_test.py b/tests/batch_test.py index fcf179e2..805af938 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -274,6 +274,10 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): class TestIrMul(TPCorrectness): + ''' + Tests both the ir_mul layout and the transpose_irreps functions + via a wrapper. + ''' tpps = mace_problems() + [ oeq.TPProblem( "5x5e", @@ -306,25 +310,20 @@ def problem(self, request, dtype): def tp_and_problem(self, request, problem, extra_tp_constructor_args, with_jax): mode = request.param - if mode == "native": - cls = oeq.TensorProduct - if with_jax: - import openequivariance.jax.TensorProduct as jax_tp + if with_jax: + import openequivariance.jax.TensorProduct as jax_tp + from openequivariance.jax import transpose_irreps - cls = jax_tp - tp = cls(problem, **extra_tp_constructor_args) - return tp, problem + tp_base_cls = jax_tp else: - if with_jax: - import openequivariance.jax.TensorProduct as jax_tp - from openequivariance.jax import transpose_irreps + from openequivariance._torch.utils import transpose_irreps - tp_base_cls = jax_tp - else: - from openequivariance._torch.utils import transpose_irreps - - tp_base_cls = oeq.TensorProduct + tp_base_cls = oeq.TensorProduct + if mode == "native": + tp = tp_base_cls(problem, **extra_tp_constructor_args) + return tp, problem + else: class 
TransposeWrapperTensorProduct(tp_base_cls): def forward(self, x, y, W): x_t = transpose_irreps( From 45f236a07521b7c3a788b3ae02025d4467747feb Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Mar 2026 21:48:52 -0700 Subject: [PATCH 24/27] Compacted tests. --- tests/batch_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/batch_test.py b/tests/batch_test.py index 805af938..431e9b7d 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -275,8 +275,7 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): class TestIrMul(TPCorrectness): ''' - Tests both the ir_mul layout and the transpose_irreps functions - via a wrapper. + Tests both the ir_mul layout and the transpose_irreps functions. ''' tpps = mace_problems() + [ oeq.TPProblem( From 26b746bf27dd7df1d222dbada35bc49e8e679da3 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Mar 2026 21:57:12 -0700 Subject: [PATCH 25/27] Almost there. --- tests/batch_test.py | 50 +++++++++++++++++++-------------------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/tests/batch_test.py b/tests/batch_test.py index 431e9b7d..2d376379 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -8,6 +8,7 @@ correctness_double_backward, correctness_forward, ) +from openequivariance.benchmark.test_buffers import get_random_buffers_forward from openequivariance.benchmark.problems import ( diffdock_problems, e3nn_torch_tetris_poly_problems, @@ -274,9 +275,10 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): class TestIrMul(TPCorrectness): - ''' - Tests both the ir_mul layout and the transpose_irreps functions. - ''' + """ + Tests both the ir_mul layout and the transpose_irreps functions. 
+ """ + tpps = mace_problems() + [ oeq.TPProblem( "5x5e", @@ -323,6 +325,7 @@ def tp_and_problem(self, request, problem, extra_tp_constructor_args, with_jax): tp = tp_base_cls(problem, **extra_tp_constructor_args) return tp, problem else: + class TransposeWrapperTensorProduct(tp_base_cls): def forward(self, x, y, W): x_t = transpose_irreps( @@ -370,40 +373,29 @@ def forward(self, x, y, w): def _problem_dtype(self, problem): return torch.float32 if problem.irrep_dtype == np.float32 else torch.float64 - def _make_inputs(self, problem, batch_size, rng, dtype, device): - in1 = torch.tensor( - rng.uniform(size=(batch_size, problem.irreps_in1.dim)), - dtype=dtype, - device=device, - ) - in2 = torch.tensor( - rng.uniform(size=(batch_size, problem.irreps_in2.dim)), - dtype=dtype, - device=device, - ) - weights_size = ( - (problem.weight_numel,) - if problem.shared_weights - else (batch_size, problem.weight_numel) - ) - weights = torch.tensor( - rng.uniform(size=weights_size), - dtype=dtype, - device=device, + def _make_inputs(self, problem, batch_size, dtype, device, prng_seed=12345): + dtype_map = {torch.float32: np.float32, torch.float64: np.float64} + buffer_problem = problem.clone() + buffer_problem.irrep_dtype = dtype_map[dtype] + buffer_problem.weight_dtype = dtype_map[dtype] + + in1_np, in2_np, weights_np, _ = get_random_buffers_forward( + buffer_problem, batch_size=batch_size, prng_seed=prng_seed ) - return in1, in2, weights + + return [ + torch.tensor(arr, dtype=dtype, device=device) + for arr in [in1_np, in2_np, weights_np] + ] def test_submodule_dtype_conversion(self, parent_module_and_problem): """Test that calling .to() on parent module properly converts TensorProduct submodule""" parent, problem = parent_module_and_problem batch_size = 10 - rng = np.random.default_rng(12345) device = "cuda" input_dtype = self._problem_dtype(problem) - in1, in2, weights = self._make_inputs( - problem, batch_size, rng, input_dtype, device - ) + in1, in2, weights = 
self._make_inputs(problem, batch_size, input_dtype, device) output1 = parent(in1, in2, weights) assert output1.dtype == in1.dtype, ( @@ -418,7 +410,7 @@ def test_submodule_dtype_conversion(self, parent_module_and_problem): parent.to(target_dtype) in1_new, in2_new, weights_new = self._make_inputs( - problem, batch_size, rng, target_dtype, device + problem, batch_size, target_dtype, device, prng_seed=23456 ) output2 = parent(in1_new, in2_new, weights_new) From aa0488d05d253062d9971cc68b2c95bbb2bf06f9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Mar 2026 22:41:52 -0700 Subject: [PATCH 26/27] Changes to get the documentation to build. --- CHANGELOG.md | 4 +++- docs/api.rst | 6 +++++- docs/conf.py | 2 ++ openequivariance/openequivariance/__init__.py | 2 ++ 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6aa11d05..723857d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,9 @@ This release brings `ir_mul` layout support for OpenEquivariance. Pass the parameter `layout='ir_mul'` to any `TPProblem` instance to use a transposed layout for the input and output -irreps. +irreps. To transpose input and output irreps use +`oeq.transpose_irreps` or `oeq.jax.transpose_irreps`; +see our API page for usage details. ### v0.6.4 (2026-03-05) Bugfix: added missing MLIR lowerings for diff --git a/docs/api.rst b/docs/api.rst index c21b918f..15b1aec9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -30,6 +30,8 @@ PyTorch API :undoc-members: :exclude-members: name +.. autofunction:: openequivariance.transpose_irreps + .. autofunction:: openequivariance.torch_to_oeq_dtype .. autofunction:: openequivariance.torch_ext_so_path @@ -54,7 +56,9 @@ breaking the PyTorch version of OpenEquivariance. .. autoclass:: openequivariance.jax.TensorProductConv :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn :undoc-members: - :exclude-members: + :exclude-members: + +.. 
autofunction:: openequivariance.jax.transpose_irreps Common API --------------------- diff --git a/docs/conf.py b/docs/conf.py index 540cf37e..70dd285e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -38,6 +38,8 @@ "openequivariance._torch.extlib", "openequivariance.jax.extlib", "openequivariance_extjax", + "openequivariance.jax.jvp.tp_prim", + "openequivariance.jax.jvp.conv_prim", "jinja2", "numpy", ] diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 35c4b542..1a3681f5 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -37,6 +37,7 @@ def _check_package_editable(): from openequivariance._torch.TensorProduct import TensorProduct from openequivariance._torch.TensorProductConv import TensorProductConv + from openequivariance._torch.utils import transpose_irreps from openequivariance._torch.extlib import ( torch_ext_so_path as torch_ext_so_path_internal, @@ -111,4 +112,5 @@ def TensorProductConv(*args, **kwargs): "_check_package_editable", "torch_ext_so_path", "jax", + "transpose_irreps", ] From a0e36041489ea719831a5248c1bc0eb1608f5d3d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 23 Mar 2026 22:42:29 -0700 Subject: [PATCH 27/27] Minor change to test. --- tests/batch_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/batch_test.py b/tests/batch_test.py index 2d376379..ff1cd1ce 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -317,7 +317,7 @@ def tp_and_problem(self, request, problem, extra_tp_constructor_args, with_jax): tp_base_cls = jax_tp else: - from openequivariance._torch.utils import transpose_irreps + from openequivariance import transpose_irreps tp_base_cls = oeq.TensorProduct