1010from openequivariance .implementations .TensorProduct import TensorProduct
1111from openequivariance import TPProblem
1212
13+
1314class TensorProductConv (torch .nn .Module , LoopUnrollConv ):
1415 """
15- Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`,
16- :math:`y_1...y_{|E|}`, and weights :math:`W_1...W_{|E|}`, computes
16+ Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`,
17+ :math:`y_1...y_{|E|}`, and weights :math:`W_1...W_{|E|}`, computes
1718
1819 .. math::
1920 z_i = \sum_{(i, j, e) \in \mathcal{N}(i)} W_e (x_j \otimes_{\\ textrm{CG}} y_e)
20- where :math:`(i, j, e) \in \mathcal{N}(i)` indicates that node :math:`i` is connected to node :math:`j`
21- via the edge indexed :math:`e`.
21+ where :math:`(i, j, e) \in \mathcal{N}(i)` indicates that node :math:`i` is connected to node :math:`j`
22+ via the edge indexed :math:`e`.
2223
2324 This class offers multiple options to perform the summation: an atomic algorithm and a deterministic algorithm
2425 that relies on a sorted adjacency matrix input. If you use the determinstic algorithm, you must also supply
25- a permutation to transpose the adjacency matrix.
26+ a permutation to transpose the adjacency matrix.
2627
2728 :param problem: Specification of the tensor product.
2829 :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
2930 fixup-based algorithm. `Default`: ``False``.
3031 :param kahan: if ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
31- the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
32-
32+ the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
33+
3334 """
3435
3536 def __init__ (
3637 self ,
3738 problem : TPProblem ,
38- deterministic : bool = False ,
39- kahan : bool = False ,
40- torch_op = True
39+ deterministic : bool = False ,
40+ kahan : bool = False ,
41+ torch_op = True ,
4142 ):
4243 torch .nn .Module .__init__ (self )
4344 LoopUnrollConv .__init__ (
@@ -64,26 +65,26 @@ def forward(
6465 cols : torch .Tensor ,
6566 sender_perm : Optional [torch .Tensor ] = None ,
6667 ) -> torch .Tensor :
67- '''
68- Computes the fused CG tensor product + convolution.
68+ """
69+ Computes the fused CG tensor product + convolution.
6970
7071 :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
7172 :param Y: Tensor of shape ``[|E|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
72- :param W: Tensor of datatype ``problem.weight_dtype`` and shape
73+ :param W: Tensor of datatype ``problem.weight_dtype`` and shape
7374
7475 * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
7576 * ``[problem.weight_numel]`` if ``problem.shared_weights=True``
76-
77- :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
77+
78+ :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
7879 datatype ``torch.int64``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
7980 :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
80- datatype ``torch.int64``.
81- :param sender_perm: Tensor of shape ``[|E|]`` and ``torch.int64`` datatype containing a
81+ datatype ``torch.int64``.
82+ :param sender_perm: Tensor of shape ``[|E|]`` and ``torch.int64`` datatype containing a
8283 permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
8384 Must be provided when ``deterministic=True``.
84-
85+
8586 :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
86- '''
87+ """
8788 if sender_perm is None :
8889 return torch .ops .libtorch_tp_jit .jit_conv_forward (
8990 self .internal ,
0 commit comments