Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
1eac81f
More progress.
vbharadwaj-bk Mar 6, 2026
54f2ee8
Merge branch 'main' into ir_mul
vbharadwaj-bk Mar 21, 2026
33ed045
More diffs.
vbharadwaj-bk Mar 21, 2026
c78d48f
Avoided transposing irreps once the shared memory load is complete.
vbharadwaj-bk Mar 22, 2026
85e988f
Made more progress.
vbharadwaj-bk Mar 22, 2026
379fd28
Convolution test is failing.
vbharadwaj-bk Mar 22, 2026
7dcd887
More bugfixes.
vbharadwaj-bk Mar 22, 2026
4806748
Fixed more stuff.
vbharadwaj-bk Mar 22, 2026
44ecf0e
Compacted everything.
vbharadwaj-bk Mar 22, 2026
4947781
compaction.
vbharadwaj-bk Mar 22, 2026
5968745
File renaming.
vbharadwaj-bk Mar 22, 2026
161a1b6
Revert "File renaming."
vbharadwaj-bk Mar 22, 2026
7fd7ab6
Fixed a regression.
vbharadwaj-bk Mar 22, 2026
3c9ed29
More compaction.
vbharadwaj-bk Mar 22, 2026
95783ef
More test cleaning.
vbharadwaj-bk Mar 22, 2026
9de117e
More refactoring.
vbharadwaj-bk Mar 22, 2026
aa9bbf6
Ruff.
vbharadwaj-bk Mar 23, 2026
5a5080a
Even more refactoring.
vbharadwaj-bk Mar 23, 2026
f5b7a26
Added include for automatic string conversion.
vbharadwaj-bk Mar 23, 2026
b80d15c
Completed benchmarking.
vbharadwaj-bk Mar 23, 2026
a5d0fe5
Modified changelog before release.
vbharadwaj-bk Mar 23, 2026
b5866a2
Began adding a pair of robust transpose functions.
vbharadwaj-bk Mar 23, 2026
9544e14
Wrote a compact test for the transpose functions.
vbharadwaj-bk Mar 24, 2026
5440d55
Compacted diff further.
vbharadwaj-bk Mar 24, 2026
45f236a
Compacted tests.
vbharadwaj-bk Mar 24, 2026
26b746b
Almost there.
vbharadwaj-bk Mar 24, 2026
aa0488d
Changes to get the documentation to build.
vbharadwaj-bk Mar 24, 2026
a0e3604
Minor change to test.
vbharadwaj-bk Mar 24, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
## Latest Changes

### v0.6.5 (2026-03-22)
This release brings `ir_mul` layout support for
OpenEquivariance. Pass the parameter
`layout='ir_mul'` to any `TPProblem` instance to use
a transposed layout for the input and output
irreps. To transpose input and output irreps use
`oeq.transpose_irreps` or `oeq.jax.transpose_irreps`;
see our API page for usage details.

### v0.6.4 (2026-03-05)
Bugfix: added missing MLIR lowerings for
a pair of JAX primitives (thanks @teddykoker!)
Expand Down
6 changes: 5 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ PyTorch API
:undoc-members:
:exclude-members: name

.. autofunction:: openequivariance.transpose_irreps

.. autofunction:: openequivariance.torch_to_oeq_dtype

.. autofunction:: openequivariance.torch_ext_so_path
Expand All @@ -54,7 +56,9 @@ breaking the PyTorch version of OpenEquivariance.
.. autoclass:: openequivariance.jax.TensorProductConv
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
:undoc-members:
:exclude-members:
:exclude-members:

.. autofunction:: openequivariance.jax.transpose_irreps

Common API
---------------------
Expand Down
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"openequivariance._torch.extlib",
"openequivariance.jax.extlib",
"openequivariance_extjax",
"openequivariance.jax.jvp.tp_prim",
"openequivariance.jax.jvp.conv_prim",
"jinja2",
"numpy",
]
Expand Down
2 changes: 2 additions & 0 deletions openequivariance/openequivariance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _check_package_editable():

from openequivariance._torch.TensorProduct import TensorProduct
from openequivariance._torch.TensorProductConv import TensorProductConv
from openequivariance._torch.utils import transpose_irreps

from openequivariance._torch.extlib import (
torch_ext_so_path as torch_ext_so_path_internal,
Expand Down Expand Up @@ -111,4 +112,5 @@ def TensorProductConv(*args, **kwargs):
"_check_package_editable",
"torch_ext_so_path",
"jax",
"transpose_irreps",
]
56 changes: 2 additions & 54 deletions openequivariance/openequivariance/_torch/CUETensorProduct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

from openequivariance.core.TensorProductBase import TensorProductBase
from openequivariance.core.e3nn_lite import TPProblem
from openequivariance.benchmark.logging_utils import getLogger
from openequivariance.benchmark.tpp_creation_utils import (
from openequivariance.core.logging import getLogger
from openequivariance.benchmark.problems import (
ChannelwiseTPP,
FullyConnectedTPProblem,
SingleInstruction,
)
from openequivariance.core.utils import count_cg_non_zero

os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"

Expand Down Expand Up @@ -235,57 +234,6 @@ def benchmark_backward(
kernel_names=self.kernel_names,
)

# Copied over from loop unroller to match arithmetic intensity on roofline plots
def calculate_flops_forward(self, batch_size: int) -> dict:
if self.is_uvw:
return super().calculate_flops_forward(batch_size)
else:
tpp = self.config
flop_count = {
"CG_decomposition": 0,
"linear_combination": 0,
"outer_products": 0,
}
for ins in tpp.instructions:
l1, l2, l3 = (
tpp.irreps_in1[ins.i_in1].ir.l,
tpp.irreps_in2[ins.i_in2].ir.l,
tpp.irreps_out[ins.i_out].ir.l,
)
flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (
ins.path_shape[0] * ins.path_shape[1]
)
flop_count["linear_combination"] += (
(2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0
)

flop_count["CG_decomposition"] *= 3 * batch_size
flop_count["linear_combination"] *= (
batch_size # Weights do not require FMA here
)
flop_count["total"] = sum(flop_count.values())
return flop_count

def calculate_flops_backward(self, batch_size: int) -> dict:
if self.is_uvw:
return super().calculate_flops_backward(batch_size)
else:
tpp = self.config
flop_count = {"backward": 0}
for ins in tpp.instructions:
l1, l2, l3 = (
tpp.irreps_in1[ins.i_in1].ir.l,
tpp.irreps_in2[ins.i_in2].ir.l,
tpp.irreps_out[ins.i_out].ir.l,
)
flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * (
ins.path_shape[0] * ins.path_shape[1]
)

flop_count["backward"] *= 9 * batch_size
flop_count["total"] = sum(flop_count.values())
return flop_count

@staticmethod
def name():
    # Human-readable identifier for this implementation; matches the class name.
    return "CUETensorProduct"
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from openequivariance.core.TensorProductBase import TensorProductBase
from openequivariance.core.e3nn_lite import TPProblem
from openequivariance.benchmark.logging_utils import getLogger
from openequivariance.core.logging import getLogger
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin

TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning")
Expand Down
2 changes: 1 addition & 1 deletion openequivariance/openequivariance/_torch/TensorProduct.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from openequivariance._torch import extlib
import torch
from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum
from openequivariance.benchmark.logging_utils import getLogger
from openequivariance.core.logging import getLogger
from openequivariance._torch.utils import (
reorder_torch,
string_to_tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
enum_to_torch_dtype,
)

from openequivariance.benchmark.logging_utils import getLogger
from openequivariance.core.logging import getLogger
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv

logger = getLogger()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch

from openequivariance.benchmark.logging_utils import getLogger
from openequivariance.core.logging import getLogger

oeq_root = str(Path(__file__).parent.parent.parent)

Expand Down
70 changes: 70 additions & 0 deletions openequivariance/openequivariance/_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from types import MappingProxyType
from openequivariance.core.utils import DTypeEnum
from openequivariance.core.e3nn_lite import Irreps


def reorder_helper(schedule, weights_in, direction, has_batch_dim):
Expand Down Expand Up @@ -75,3 +76,72 @@ def string_to_tensor(text: str) -> torch.Tensor:
result = torch.tensor(np_bytes, device="cpu")
result.requires_grad = False
return result


def transpose_irreps(
    array: torch.Tensor,
    irreps: "Irreps",
    src_layout: str,
    dst_layout: str,
) -> torch.Tensor:
    r"""
    Transpose irrep-packed feature tensors between ``mul_ir`` and ``ir_mul`` layouts.

    The function operates on the trailing feature dimension and preserves all leading
    batch dimensions. It uses only differentiable, out-of-place PyTorch tensor
    operations, so gradients propagate through the transpose.

    :param array: Input feature tensor with shape ``[..., irreps.dim]``.
    :param irreps: Irreps specification describing how the trailing feature dimension
        is partitioned into irrep blocks.
    :param src_layout: Source layout. Must be either ``"mul_ir"`` or ``"ir_mul"``.
    :param dst_layout: Destination layout. Must be either ``"mul_ir"`` or ``"ir_mul"``.

    :returns: Tensor in ``dst_layout`` with the same shape, dtype, and device as
        ``array``. If ``src_layout == dst_layout``, returns a clone of ``array``.

    :raises TypeError: If ``array`` is not a ``torch.Tensor``.
    :raises ValueError: If ``src_layout`` or ``dst_layout`` is not one of
        ``"mul_ir"`` or ``"ir_mul"``.
    """
    if src_layout not in ("mul_ir", "ir_mul"):
        raise ValueError(f"Unsupported src_layout: {src_layout}")
    if dst_layout not in ("mul_ir", "ir_mul"):
        raise ValueError(f"Unsupported dst_layout: {dst_layout}")

    if not isinstance(array, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(array)}")

    if src_layout == dst_layout:
        # Nothing to permute; still return a fresh tensor so callers never
        # alias the input (matches the documented "clone" contract).
        return array.clone()

    lead = array.shape[:-1]
    blocks = []
    for mul_ir, seg in zip(irreps, irreps.slices()):
        mul = mul_ir.mul
        dim = mul_ir.ir.dim
        # A block holds mul * dim values. Viewing it as (dim, mul) when the
        # source is ir_mul — or (mul, dim) when it is mul_ir — and swapping
        # the last two axes converts between the two layouts in one step;
        # the previous per-direction branches were mirror images of this.
        src_shape = (dim, mul) if src_layout == "ir_mul" else (mul, dim)
        block = array[..., seg.start : seg.stop].reshape(*lead, *src_shape)
        blocks.append(block.transpose(-1, -2).reshape(*lead, mul * dim))

    if not blocks:
        # Empty irreps: the feature dimension is zero-width; return a clone.
        return array.clone()

    # Out-of-place concatenation avoids in-place writes into a freshly
    # allocated buffer and keeps the whole op differentiable.
    return torch.cat(blocks, dim=-1)
18 changes: 13 additions & 5 deletions openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import numpy as np

import openequivariance as oeq
from openequivariance.benchmark.logging_utils import getLogger
from openequivariance.benchmark.correctness import (
correctness_backward_conv,
correctness_double_backward_conv,
correctness_forward_conv,
)
from openequivariance.core.logging import getLogger
from openequivariance.core.ConvolutionBase import CoordGraph
from openequivariance.benchmark.benchmark_utils import NpEncoder

Expand Down Expand Up @@ -90,7 +95,8 @@ def run(

if direction == "forward":
if correctness:
correctness = conv.test_correctness_forward(
correctness = correctness_forward_conv(
conv,
graph,
thresh=self.correctness_threshold,
prng_seed=self.prng_seed,
Expand All @@ -105,7 +111,8 @@ def run(

if direction == "backward":
if correctness:
correctness = conv.test_correctness_backward(
correctness = correctness_backward_conv(
conv,
graph,
thresh=self.correctness_threshold,
prng_seed=self.prng_seed,
Expand All @@ -120,8 +127,9 @@ def run(

if direction == "double_backward":
if correctness:
correctness = conv.test_correctness_double_backward(
self.graph,
correctness = correctness_double_backward_conv(
conv,
graph,
thresh=self.correctness_threshold,
prng_seed=self.prng_seed,
reference_implementation=self.reference_impl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from openequivariance._torch.extlib import DeviceProp
from openequivariance.core.TensorProductBase import TensorProductBase

from openequivariance.benchmark.logging_utils import getLogger, bcolors
from openequivariance.core.logging import getLogger, bcolors
from openequivariance.core.e3nn_lite import TPProblem
from openequivariance.benchmark.correctness_utils import (
from openequivariance.benchmark.correctness import (
correctness_forward,
correctness_backward,
correctness_double_backward,
Expand Down
Loading
Loading