Skip to content

Commit efc2149

Browse files
committed
Reformatted.
1 parent 89eb51b commit efc2149

6 files changed

Lines changed: 52 additions & 49 deletions

File tree

docs/conf.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,30 @@
99
# -- Project information -----------------------------------------------------
1010
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
1111

12-
project = 'OpenEquivariance'
13-
copyright = '2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory.'
14-
author = 'Vivek Bharadwaj, Austin Glover, Aydin Buluc, James Demmel'
12+
project = "OpenEquivariance"
13+
copyright = "2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory."
14+
author = "Vivek Bharadwaj, Austin Glover, Aydin Buluc, James Demmel"
1515

1616
# -- General configuration ---------------------------------------------------
1717
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
1818

1919
extensions = []
2020

21-
templates_path = ['_templates']
22-
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
23-
21+
templates_path = ["_templates"]
22+
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
2423

2524

2625
# -- Options for HTML output -------------------------------------------------
2726
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
2827

29-
html_theme = 'furo'
30-
html_static_path = ['_static']
28+
html_theme = "furo"
29+
html_static_path = ["_static"]
3130

3231
extensions = [
33-
'sphinx.ext.autodoc',
32+
"sphinx.ext.autodoc",
3433
]
3534

36-
sys.path.insert(0, str(Path('..').resolve()))
35+
sys.path.insert(0, str(Path("..").resolve()))
3736

38-
autodoc_mock_imports = ['torch', 'openequivariance.extlib', 'jinja2']
37+
autodoc_mock_imports = ["torch", "openequivariance.extlib", "jinja2"]
3938
autodoc_typehints = "description"

openequivariance/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
)
1212
from openequivariance.implementations.utils import torch_to_oeq_dtype
1313

14-
__version__ = None
14+
__version__ = None
1515
try:
1616
__version__ = version("openequivariance")
1717
except Exception as e:
1818
print(f"Warning: Could not determine oeq version: {e}", file=sys.stderr)
1919

20+
2021
def _check_package_editable():
2122
import json
2223
from importlib.metadata import Distribution
@@ -29,10 +30,10 @@ def _check_package_editable():
2930

3031

3132
def torch_ext_so_path():
32-
'''
33+
"""
3334
:returns: Path to a ``.so`` file that must be linked to use OpenEquivariance
34-
from the PyTorch C++ Interface.
35-
'''
35+
from the PyTorch C++ Interface.
36+
"""
3637
return openequivariance.extlib.torch_module.__file__
3738

3839

openequivariance/implementations/TensorProduct.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ class TensorProduct(torch.nn.Module, LoopUnrollTP):
99
backward, and double-backward passes using JIT-compiled kernels. Initialization
1010
fails if:
1111
12-
* There are no visible GPUs.
13-
* The provided tensor product specification is unsupported.
14-
12+
* There are no visible GPUs.
13+
* The provided tensor product specification is unsupported.
14+
1515
:param problem: Specification of the tensor product.
1616
"""
17+
1718
def __init__(self, problem: TPProblem, torch_op=True):
1819
torch.nn.Module.__init__(self)
1920
LoopUnrollTP.__init__(self, problem, torch_op)
@@ -26,19 +27,19 @@ def name():
2627
def forward(
2728
self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor
2829
) -> torch.Tensor:
29-
'''
30+
"""
3031
Computes :math:`W (x \otimes_{\\textrm{CG}} y)`, identical to
31-
``o3.TensorProduct.forward``.
32+
``o3.TensorProduct.forward``.
3233
3334
:param x: Tensor of shape ``[batch_size, problem.irreps_in1.dim()]``, datatype
3435
``problem.irrep_dtype``.
3536
:param y: Tensor of shape ``[batch_size, problem.irreps_in2.dim()]``, datatype
3637
``problem.irrep_dtype``.
37-
:param W: Tensor of datatype ``problem.weight_dtype`` and shape
38+
:param W: Tensor of datatype ``problem.weight_dtype`` and shape
3839
3940
* ``[batch_size, problem.weight_numel]`` if ``problem.shared_weights=False``
4041
* ``[problem.weight_numel]`` if ``problem.shared_weights=True``
4142
4243
:return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
43-
'''
44+
"""
4445
return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, x, y, W)

openequivariance/implementations/convolution/TensorProductConv.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,35 @@
1010
from openequivariance.implementations.TensorProduct import TensorProduct
1111
from openequivariance import TPProblem
1212

13+
1314
class TensorProductConv(torch.nn.Module, LoopUnrollConv):
1415
"""
15-
Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`,
16-
:math:`y_1...y_{|E|}`, and weights :math:`W_1...W_{|E|}`, computes
16+
Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`,
17+
:math:`y_1...y_{|E|}`, and weights :math:`W_1...W_{|E|}`, computes
1718
1819
.. math::
1920
z_i = \sum_{(i, j, e) \in \mathcal{N}(i)} W_e (x_j \otimes_{\\textrm{CG}} y_e)
20-
where :math:`(i, j, e) \in \mathcal{N}(i)` indicates that node :math:`i` is connected to node :math:`j`
21-
via the edge indexed :math:`e`.
21+
where :math:`(i, j, e) \in \mathcal{N}(i)` indicates that node :math:`i` is connected to node :math:`j`
22+
via the edge indexed :math:`e`.
2223
2324
This class offers multiple options to perform the summation: an atomic algorithm and a deterministic algorithm
2425
that relies on a sorted adjacency matrix input. If you use the determinstic algorithm, you must also supply
25-
a permutation to transpose the adjacency matrix.
26+
a permutation to transpose the adjacency matrix.
2627
2728
:param problem: Specification of the tensor product.
2829
:param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
2930
fixup-based algorithm. `Default`: ``False``.
3031
:param kahan: if ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
31-
the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
32-
32+
the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
33+
3334
"""
3435

3536
def __init__(
3637
self,
3738
problem: TPProblem,
38-
deterministic: bool=False,
39-
kahan: bool =False,
40-
torch_op=True
39+
deterministic: bool = False,
40+
kahan: bool = False,
41+
torch_op=True,
4142
):
4243
torch.nn.Module.__init__(self)
4344
LoopUnrollConv.__init__(
@@ -64,26 +65,26 @@ def forward(
6465
cols: torch.Tensor,
6566
sender_perm: Optional[torch.Tensor] = None,
6667
) -> torch.Tensor:
67-
'''
68-
Computes the fused CG tensor product + convolution.
68+
"""
69+
Computes the fused CG tensor product + convolution.
6970
7071
:param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
7172
:param Y: Tensor of shape ``[|E|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
72-
:param W: Tensor of datatype ``problem.weight_dtype`` and shape
73+
:param W: Tensor of datatype ``problem.weight_dtype`` and shape
7374
7475
* ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
7576
* ``[problem.weight_numel]`` if ``problem.shared_weights=True``
76-
77-
:param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
77+
78+
:param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
7879
datatype ``torch.int64``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
7980
:param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
80-
datatype ``torch.int64``.
81-
:param sender_perm: Tensor of shape ``[|E|]`` and ``torch.int64`` datatype containing a
81+
datatype ``torch.int64``.
82+
:param sender_perm: Tensor of shape ``[|E|]`` and ``torch.int64`` datatype containing a
8283
permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
8384
Must be provided when ``deterministic=True``.
84-
85+
8586
:return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
86-
'''
87+
"""
8788
if sender_perm is None:
8889
return torch.ops.libtorch_tp_jit.jit_conv_forward(
8990
self.internal,

openequivariance/implementations/e3nn_lite.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -359,23 +359,23 @@ class Instruction(NamedTuple):
359359

360360

361361
class TPProblem:
362-
'''
363-
Specification for a CG tensor product. All parameters from
362+
"""
363+
Specification for a CG tensor product. All parameters from
364364
e3nn's ``o3.TensorProduct`` are available, along with additional
365365
parameters for the types of weights and irreps.
366366
367367
:param irreps_in1: Irreps for the first CG argument
368368
:param irreps_in2: Irreps for the second CG argument
369-
:param irreps_out: Irreps for the output
370-
:param instructions: A list of 5-tuples, each of
369+
:param irreps_out: Irreps for the output
370+
:param instructions: A list of 5-tuples, each of
371371
the form ``(i_in1, i_in2, i_out, has_weight, path_weight)``.
372372
``i_in1``, ``i_in2``, and ``i_out`` each index
373373
an Irrep from ``irreps_in1``, ``irreps_in2``, and
374374
``irreps_in3``, respectively. ``has_weight`` (True / False)
375375
controls whether trainable weights are included for the
376-
instruction, and ``path_weight`` controls output normalization.
376+
instruction, and ``path_weight`` controls output normalization.
377377
:param irrep_dtype: Datatype of irrep inputs; one of ``np.float32`` or ``np.float64``.
378-
*Default*: ``np.float32``.
378+
*Default*: ``np.float32``.
379379
:param weight_dtype: Datatype of weights; one of ``np.float32`` or ``np.float64``.
380380
*Default*: ``np.float32``.
381381
:param label: A name for this problem specification (useful for testing / benchmarking).
@@ -384,7 +384,8 @@ class TPProblem:
384384
:param internal_weights: Must be False; OpenEquivariance does not support internal weights. *Default*: False.
385385
:param irrep_normalization: One of ``["component", "norm", "none"]``. *Default*: "component".
386386
:param path_normalization: One of ``["element", "path", "none"]``. *Default*: "element".
387-
'''
387+
"""
388+
388389
instructions: List[Any]
389390
shared_weights: bool
390391
internal_weights: bool

openequivariance/implementations/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def filter_and_analyze_problem(problem):
7979
f"irrep_dtype and weight_dtype must be the same, got {problem.irrep_dtype} and {problem.weight_dtype}"
8080
)
8181

82-
assert not problem.internal_weights, (
82+
assert not problem.internal_weights, (
8383
f"Openequivariance does not support internal weights, got {problem.internal_weights}"
8484
)
8585

0 commit comments

Comments
 (0)