Releases: e3nn/e3nn-jax
2022-02-01
```python
import jax.numpy as jnp
import e3nn_jax as e3nn

# Coefficients of "0e + 1o": one scalar (l=0) followed by three l=1 components.
coeffs = e3nn.IrrepsArray("0e + 1o", jnp.array([1, 2, 0, 0.0]))
# Sample the signal on a 50 x 69 grid of the sphere using Gauss-Legendre quadrature.
signal = e3nn.to_s2grid(coeffs, 50, 69, quadrature="gausslegendre")

import plotly.graph_objects as go

go.Figure([go.Surface(signal.plotly_surface())])
```
Added
- `e3nn.SphericalSignal` class to represent signals on the sphere
- "Signal on the Sphere" section in the documentation
- `e3nn.Irreps.D_from_log_coordinates`
- `rotation_angle_from_*` functions
- `e3nn.to_s2point` function
Changed
- Wigner D matrices are computed from the log coordinates, which requires 1 instead of 3 calls to `expm`.
- [BREAKING] `e3nn.util.assert_output_dtype` renamed to `e3nn.util.assert_output_dtype_matches_input_dtype`
- [BREAKING] Update `experimental.point_convolution` to use the latest changes.
- [BREAKING] Changed the `e3nn.to_s2grid` and `e3nn.from_s2grid` signature and default normalization.
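The log-coordinate change can be illustrated on the l = 1 (vector) representation. The sketch below is a NumPy stand-in and not e3nn-jax code. `expm`, `GENERATORS`, `d_from_log_coordinates` and `d_from_euler_angles` are hypothetical helpers. It assumes log coordinates mean "rotation axis times angle" and uses a Y-X-Y Euler convention. With log coordinates, the generators are combined first and exponentiated once. An Euler-angle rotation is a product of three separate exponentials.

```python
import numpy as np


def expm(a: np.ndarray, terms: int = 30) -> np.ndarray:
    """Matrix exponential via scaling and squaring plus a truncated Taylor series."""
    norm = np.linalg.norm(a, ord=1)
    # Scale so the norm is at most 1/2, where the Taylor series converges quickly.
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    scaled = a / 2**s
    result = np.eye(a.shape[0])
    term = np.eye(a.shape[0])
    for k in range(1, terms):
        term = term @ scaled / k
        result = result + term
    # Undo the scaling: exp(A) = exp(A / 2^s) ** (2^s).
    for _ in range(s):
        result = result @ result
    return result


# so(3) generators acting on 3D vectors: (L_i)_{jk} = -epsilon_{ijk} (assumed basis x, y, z).
GENERATORS = np.array([
    [[0, 0, 0], [0, 0, -1], [0, 1, 0]],   # L_x
    [[0, 0, 1], [0, 0, 0], [-1, 0, 0]],   # L_y
    [[0, -1, 0], [1, 0, 0], [0, 0, 0]],   # L_z
], dtype=float)


def d_from_log_coordinates(log_coords: np.ndarray) -> np.ndarray:
    """Rotation matrix from log coordinates (axis times angle): one expm call."""
    return expm(np.einsum("i,ijk->jk", log_coords, GENERATORS))


def d_from_euler_angles(alpha: float, beta: float, gamma: float) -> np.ndarray:
    """Rotation matrix from Y-X-Y Euler angles (assumed convention): three expm calls."""
    l_x, l_y, _ = GENERATORS
    return expm(alpha * l_y) @ expm(beta * l_x) @ expm(gamma * l_y)
```

For example, `d_from_log_coordinates(np.array([0.0, 0.0, np.pi / 2]))` is a quarter turn about z and maps the x axis onto the y axis.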
Removed
- [BREAKING] All the `haiku` modules from the main module. They are now in the `e3nn.haiku` submodule.
- [BREAKING] `e3nn.wigner_D` in favor of `e3nn.Irrep.D_from_*`
Fixed
- Removed the `jax.jit` decorator from `Irreps.D_from_*` that was causing a bug.
2022-01-20
```python
e3nn.reduced_symmetric_tensor_product_basis("0e + 1o + 2e + 3o + 4e + 5o", 3, keep_ir="0e")
```
This call now takes only 10 seconds.
Added
- `e3nn.s2grid_vectors` and `e3nn.pad_to_plot_on_s2grid` to help plotting signals on the sphere
- `e3nn.util.assert_output_dtype` to check the output dtype of a function
- `e3nn.s2_irreps` is a function to create the irreps of the coefficients of a signal on the sphere
- `e3nn.reduced_antisymmetric_tensor_product_basis` to compute the basis of the reduced antisymmetric tensor product
- `IrrepsArray * scalar` is supported if the number of scalars matches the number of irreps
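The `IrrepsArray * scalar` rule can be sketched with NumPy on a flat feature vector. This is a hypothetical helper, not the library's implementation. It assumes that "number of irreps" counts every copy of an irrep (the sum of the multiplicities), so each copy is scaled by its own scalar. Parity is ignored because multiplying by scalars does not change it.

```python
import numpy as np


def mul_per_irrep(irreps, array, scalars):
    """Scale each irrep copy of a flat feature vector by its own scalar.

    irreps: list of (mul, l) pairs, e.g. [(2, 0), (1, 1)] for "2x0e + 1x1o" (hypothetical format).
    array: shape [..., dim] with dim = sum(mul * (2l + 1)).
    scalars: shape [..., num_irreps] with num_irreps = sum(mul).
    """
    # Dimension of each individual irrep copy, in order: 2l + 1 per copy.
    copy_dims = [2 * l + 1 for mul, l in irreps for _ in range(mul)]
    assert scalars.shape[-1] == len(copy_dims), "need one scalar per irrep copy"
    # Repeat each scalar over the components of its irrep copy, then multiply elementwise.
    factors = np.repeat(scalars, copy_dims, axis=-1)
    return array * factors
```

With `"2x0e + 1x1o"`, the three scalars act on the two scalar channels and on the whole three-component vector, respectively.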
Changed
- Optimize the `reduced_symmetric_tensor_product`. It is now up to 100x faster than the previous implementation.
- `e3nn.from_s2grid` and `e3nn.to_s2grid` are now more flexible with input and output irreps: you can skip some l's and have them in any order.
- [BREAKING] `e3nn.from_s2grid` requires an `irreps` argument instead of an `lmax` argument
Fixed
- Increase robustness of `e3nn.spherical_harmonics` towards `nan` when `normalize=True`
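Normalizing a zero vector divides 0 by 0 and produces `nan`. One common guard is to clamp the squared norm before taking the square root. The NumPy sketch below illustrates that idea only. It is not necessarily how `e3nn.spherical_harmonics` implements it, and `safe_normalize` and `eps` are hypothetical names.

```python
import numpy as np


def safe_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Normalize vectors along the last axis without producing nan for zero vectors."""
    sq_norm = np.sum(x**2, axis=-1, keepdims=True)
    # Clamp the squared norm before the square root so 0 / 0 never happens.
    norm = np.sqrt(np.maximum(sq_norm, eps))
    return x / norm
```

A zero input then maps to a zero vector instead of `nan`. Vectors whose norm is much larger than `sqrt(eps)` are normalized as usual.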
Full Changelog: 0.14.0...0.15.0
2022-12-16
Added
- `IrrepsArray.astype` to cast the underlying array
- `e3nn.flax.MultiLayerPerceptron` and `e3nn.haiku.MultiLayerPerceptron`
- `e3nn.IrrepsArray.from_list(..., dtype)`
- Add sparse tensor product as an option in `e3nn.tensor_product` and related functions. It sparsifies the Clebsch-Gordan coefficients. It has more impact when `fused=True`. It is disabled by default because no improvement was observed in the benchmarks.
- Add `log_coordinates` alongside the other parameterizations of SO(3): `e3nn.log_coordinates_to_matrix`, `e3nn.rand_log_coordinates`, etc.
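"Sparsifying" the Clebsch-Gordan coefficients means storing only their nonzero entries. A NumPy sketch shows the idea for the 1 x 1 -> 1 coupling. In a Cartesian basis, that coupling is proportional to the Levi-Civita symbol, so only 6 of its 27 entries are nonzero. The basis and the names `dense`, `couple_dense` and `couple_sparse` are assumptions for illustration. e3nn uses its own basis and implementation.

```python
import numpy as np

# Coupling tensor for 1 x 1 -> 1 in a Cartesian basis: normalized Levi-Civita symbol.
dense = np.zeros((3, 3, 3))
for i, j, k in [(0, 1, 2), (1, 2, 0), (2, 0, 1)]:
    dense[i, j, k] = 1 / np.sqrt(2)   # even permutations
    dense[j, i, k] = -1 / np.sqrt(2)  # odd permutations

# Sparse (coordinate) form: keep only the nonzero entries and their indices.
idx_i, idx_j, idx_k = np.nonzero(dense)
values = dense[idx_i, idx_j, idx_k]


def couple_dense(x, y):
    """Contract against every one of the 27 entries."""
    return np.einsum("ijk,i,j->k", dense, x, y)


def couple_sparse(x, y):
    """Contract against the 6 stored entries only."""
    out = np.zeros(3)
    # Each stored entry adds value * x_i * y_j to output component k.
    np.add.at(out, idx_k, values * x[idx_i] * y[idx_j])
    return out
```

Both functions return the cross product divided by `sqrt(2)`. Whether the sparse path is actually faster depends on the backend, which is consistent with the option being disabled by default.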
Fixed
- Set the dtype for all `jnp.zeros(..., dtype)` calls in the codebase
- Set the dtype for all `jnp.ones(..., dtype)` calls in the codebase
Removed
- [BREAKING] `e3nn.full_tensor_product` in favor of `e3nn.tensor_product`
- [BREAKING] `e3nn.FunctionalTensorSquare` in favor of `e3nn.tensor_square`
- [BREAKING] `e3nn.TensorSquare` in favor of `e3nn.tensor_square`
- [BREAKING] `e3nn.IrrepsArray.cat` in favor of `e3nn.concatenate`
- [BREAKING] `e3nn.IrrepsArray.randn` in favor of `e3nn.normal`
- [BREAKING] `e3nn.Irreps.randn` in favor of `e3nn.normal`
- [BREAKING] `e3nn.Irreps.transform_by_*` in favor of `e3nn.IrrepsArray.transform_by_*`
Changed
- Move `BatchNorm` and `Dropout` to the `e3nn.haiku` submodule; they will be removed from the main module in the future.
- Move `e3nn.haiku.FullyConnectedTensorProduct` into the `haiku` submodule. Undeprecate it because it's faster than `e3nn.tensor_product` followed by `e3nn.Linear`. This is because `opt_einsum` optimizes the contraction of the two operations.
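A fused contraction can beat two separate steps. This is easiest to see for scalar inputs (0e x 0e -> 0e), where the Clebsch-Gordan coefficient is just 1. The two-step version materializes every pairwise product before the linear layer. The fused version is a single `einsum`, and NumPy's `optimize=True` searches for a contraction order much like `opt_einsum` does. This is a simplified illustration with made-up sizes, not the implementation of `FullyConnectedTensorProduct`.

```python
import numpy as np

rng = np.random.default_rng(0)
u, v, w = 4, 5, 3
x = rng.normal(size=u)                # first input: u scalar channels
y = rng.normal(size=v)                # second input: v scalar channels
weights = rng.normal(size=(u, v, w))  # linear layer acting on the u * v products

# Two steps: build the full tensor product, then apply the linear layer.
tp = np.einsum("i,j->ij", x, y)                 # [u, v]
two_step = np.einsum("ij,ijk->k", tp, weights)  # [w]

# One step: a single contraction whose evaluation order the optimizer may choose.
fused = np.einsum("i,j,ijk->k", x, y, weights, optimize=True)  # [w]
```

Both give the same result. The fused form lets the optimizer, for example, contract `x` with `weights` first and never build the `[u, v]` intermediate.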
2022-12-14
Introduce `flax` and `haiku` submodules.
- Port `Linear`
- Port all modules (`Dropout`, `BatchNorm`, ...) to the submodules
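Example with `Linear` in flax:
```python
import jax
import e3nn_jax as e3nn

input = e3nn.normal("2x0e + 3x1e")  # random input features
linear = e3nn.flax.Linear("3x0e + 1e")  # output irreps of the layer
w = linear.init(jax.random.PRNGKey(0), input)  # initialize the parameters
linear.apply(w, input)  # apply the layer
```
Added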
- `e3nn.scatter_sum` to replace `e3nn.index_add`. `e3nn.index_add` is deprecated.
- Add `flax` and `haiku` submodules. We plan to migrate all modules to `flax` and `haiku` in the future.
- Implement `e3nn.flax.Linear` and move `e3nn.Linear` to `e3nn.haiku.Linear`.
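A scatter sum adds up all rows that share the same target index, for example summing edge messages into their receiving nodes. The NumPy sketch below shows these semantics only. The hypothetical signature `scatter_sum(data, segment_ids, num_segments)` is an assumption and may differ from the real `e3nn.scatter_sum`.

```python
import numpy as np


def scatter_sum(data, segment_ids, num_segments):
    """Sum the rows of `data` that share the same segment id (e.g. messages per receiving node)."""
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    # np.add.at accumulates correctly even when an index appears several times.
    np.add.at(out, segment_ids, data)
    return out
```

For messages on four edges with receivers `[0, 2, 0, 1]`, node 0 receives the sum of the first and third messages.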
2022-12-07
```python
import e3nn_jax as e3nn

irreps = e3nn.Irreps("0e + 1o")
print(3 * irreps)  # prints 3x0e+3x1o
```
Changed
- [BREAKING] `3 * e3nn.Irreps("0e + 1o")` now returns `3x0e + 3x1o` instead of `1x0e + 1x1o + 1x0e + 1x1o + 1x0e + 1x1o`
- [BREAKING] In `Linear`, renamed `num_weights` to `num_indexed_weights` because the old name was confusing.
Added
- `e3nn.Irreps("3x0e + 6x1o") // 3` returns `1x0e + 2x1o`
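Both operators act on the multiplicities only. A pure-Python sketch on a hypothetical list of `(mul, irrep)` pairs reproduces the two examples above. The requirement in `floordiv` that every multiplicity be divisible by `n` is an assumption about the library's behavior.

```python
def times(n, irreps):
    """n * irreps: multiply every multiplicity by n."""
    return [(mul * n, ir) for mul, ir in irreps]


def floordiv(irreps, n):
    """irreps // n: divide every multiplicity by n (assumed to require exact divisibility)."""
    assert all(mul % n == 0 for mul, _ in irreps), "multiplicities must be divisible by n"
    return [(mul // n, ir) for mul, ir in irreps]


def to_str(irreps):
    """Format as e.g. '3x0e+3x1o'."""
    return "+".join(f"{mul}x{ir}" for mul, ir in irreps)
```

The old behavior of `3 * irreps` (repeating the whole list three times) would correspond to `irreps * 3` on a plain Python list.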
Fixed
- `s2grid` is now jittable
2022-11-16
The new method `regroup` aims to replace `.sort` and `.simplify` in most cases.
```python
import e3nn_jax as e3nn

e3nn.Irreps("1e + 0e + 1e + 0x2e").regroup()  # 1x0e+2x1e
```
New default behavior of `tensor_product`:
```python
e3nn.tensor_product("0e + 1o + 0e", "1o + 1o")  # version<0.12.0
# 1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x1e+1x1e+1x2e+1x2e
e3nn.tensor_product("0e + 1o + 0e", "1o + 1o")  # version==0.12.0
# 2x0e+4x1o+2x1e+2x2e
e3nn.tensor_product("0e + 1o + 0e", "1o + 1o", regroup_output=False)  # version==0.12.0
# 1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x1e+1x1e+1x2e+1x2e
```
Added
- `e3nn.Irreps.regroup` and `e3nn.IrrepsArray.regroup` to regroup irreps. Equivalent to `sort` followed by `simplify`.
- Add a `regroup_output` parameter to `e3nn.tensor_product` and `e3nn.tensor_square` to regroup the output irreps.
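"Sort followed by simplify" can be sketched on irreps strings in plain Python. The sort orders irreps by `(l, parity)`, with odd before even for the same `l`, which matches the `2x0e+4x1o+2x1e+2x2e` output shown above. The simplify step merges equal neighbours and drops zero multiplicities. `parse` and `regroup` are hypothetical helpers that only understand the simple `"MxLp"` syntax, not the library's parser.

```python
def parse(irreps: str):
    """Parse a string like '1e + 0e + 0x2e' into (mul, (l, parity)) pairs."""
    items = []
    for term in irreps.replace(" ", "").split("+"):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        parity = 1 if ir[-1] == "e" else -1
        items.append((int(mul), (int(ir[:-1]), parity)))
    return items


def regroup(irreps: str) -> str:
    """Sort by (l, parity), then merge equal neighbours and drop zero multiplicities."""
    items = sorted(parse(irreps), key=lambda item: item[1])  # the "sort" step
    merged = []
    for mul, ir in items:  # the "simplify" step
        if merged and merged[-1][1] == ir:
            merged[-1] = (merged[-1][0] + mul, ir)
        else:
            merged.append((mul, ir))
    return "+".join(
        f"{mul}x{l}{'e' if p == 1 else 'o'}" for mul, (l, p) in merged if mul > 0
    )
```

Applied to the pre-0.12.0 tensor product output, this sketch yields the new default output string.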
Changed
- `e3nn.IrrepsArray.convert` is now private (`e3nn.IrrepsArray._convert`) because other methods are recommended instead.
- breaking change: use `input.regroup()` in `e3nn.Linear`, which can change the structure of the parameters dictionary.
- breaking change: `regroup_output` is `True` by default in `e3nn.tensor_product` and `e3nn.tensor_square`.
- To facilitate debugging, if no `key` is provided to `e3nn.normal`, it will use the hash of the irreps.
- breaking change: changed the normalization of `e3nn.tensor_square` in the case of `normalized_input=True`
Removed
- Deprecate `e3nn.TensorSquare`
2022-11-13
`e3nn.Linear` can create different weights that are then selected by an index.
```python
z = jnp.array([0, 0, 1, 3])  # [num_nodes] index of the weight set used by each node
x = e3nn.IrrepsArray("8x0e + 8x1o", _)  # [num_nodes, irreps]
e3nn.Linear("16x0e + 16x1o", num_weights=4)(z, x)
```
Added
- `e3nn.Linear` now supports integer "weights" inputs.
- `e3nn.Linear` now supports a `name` argument.
- Add `.dtype` to `IrrepsArray` to get the dtype of the underlying array.
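Index-selected weights amount to a gather followed by a batched matrix product. The NumPy sketch below uses scalar (0e) features only and made-up sizes. It is not the `e3nn.Linear` implementation, which handles every irrep block with its own weights and never mixes different `l`.

```python
import numpy as np

rng = np.random.default_rng(0)
num_weights, dim_in, dim_out = 4, 8, 16
# One independent weight matrix per index value.
weights = rng.normal(size=(num_weights, dim_in, dim_out))

z = np.array([0, 0, 1, 3])             # [num_nodes] which weight set each node uses
x = rng.normal(size=(len(z), dim_in))  # [num_nodes, dim_in] scalar features only

# Gather each node's weight matrix, then apply it to that node's features.
out = np.einsum("ni,nio->no", x, weights[z])  # [num_nodes, dim_out]
```

Nodes 0 and 1 share weight set 0, and weight set 2 is simply unused in this batch.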
Changed
- `e3nn.MultiLayerPerceptron` names its layers `linear_0`, `linear_1`, etc.
2022-11-08
This release allows `e3nn.Linear` to take some scalars as input that are blended with the parameters.
This allows a clean and straightforward implementation of depth-wise message passing (uvu convolution).
```python
def message_passing_convolution(
    node_feats: e3nn.IrrepsArray,  # [n_nodes, irreps]
    edge_attrs: e3nn.IrrepsArray,  # [n_edges, irreps]
    edge_feats: e3nn.IrrepsArray,  # [n_edges, irreps]
    ...
    target_irreps: e3nn.Irreps,
) -> e3nn.IrrepsArray:
    messages = e3nn.Linear(target_irreps)(
        e3nn.MultiLayerPerceptron(3 * [64], activation)(edge_feats),  # [n_edges, 64]
        e3nn.tensor_product(node_feats[senders], edge_attrs),  # [n_edges, irreps]
    )  # [n_edges, irreps]

    zeros = e3nn.IrrepsArray.zeros(messages.irreps, (node_feats.shape[0],))
    node_feats = zeros.at[receivers].add(messages) / jnp.sqrt(
        avg_num_neighbors
    )  # [n_nodes, irreps]

    return node_feats


def depthwise_convolution(
    node_feats: e3nn.IrrepsArray,  # [n_nodes, channel_in, irreps]
    ...
    channel_out: int,
    target_irreps: e3nn.Irreps,
) -> e3nn.IrrepsArray:
    node_feats = e3nn.Linear(node_feats.irreps, channel_out)(node_feats)

    node_feats = hk.vmap(
        lambda x: message_passing_convolution(
            x,
            edge_attrs,
            edge_feats,
            senders,
            receivers,
            avg_num_neighbors,
            target_irreps,
            activation,
        ),
        in_axes=1,
        out_axes=1,
        split_rng=False,
    )(node_feats)

    node_feats = e3nn.Linear(target_irreps, channel_out)(node_feats)

    return node_feats
```
Added
- s2grid: `e3nn.from_s2grid` and `e3nn.to_s2grid`, thanks to @songk42 for the contribution
- Argument `max_order: int` to the function `reduced_tensor_product_basis` to be able to limit the polynomial order of the basis
- `MultiLayerPerceptron` accepts `IrrepsArray` as input and output
- `e3nn.Linear` accepts optional weights as arguments that will be internally mixed with the free parameters. Very useful to implement the depthwise convolution
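Mixing external weights with the free parameters can be sketched as a per-edge linear blend. Each edge's scalars (for example an MLP output) select a combination of learned weight matrices, and that combination is applied to the edge's features. The exact mixing and normalization inside `e3nn.Linear` are not described here. This NumPy sketch assumes a plain linear blend over scalar features, and all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_edges, n_scalars, dim_in, dim_out = 5, 3, 4, 6
free_params = rng.normal(size=(n_scalars, dim_in, dim_out))  # learned parameters
edge_scalars = rng.normal(size=(n_edges, n_scalars))         # e.g. MLP output per edge
x = rng.normal(size=(n_edges, dim_in))                       # per-edge input features

# Blend: each edge gets its own weight matrix, a combination of the free parameters.
per_edge_weights = np.einsum("es,sio->eio", edge_scalars, free_params)  # [n_edges, dim_in, dim_out]
out = np.einsum("ei,eio->eo", x, per_edge_weights)                     # [n_edges, dim_out]
```

The same two contractions can also be fused into one `einsum`, which avoids materializing one weight matrix per edge.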
Changed
- breaking change: `e3nn.normal` has a new argument to get normalized vectors.
- breaking change: `e3nn.tensor_square` now distinguishes between `normalization=norm` and `normalized_input=True`.
2022-10-24
Added
- `e3nn.SymmetricTensorProduct` operation: a parameterized version of `x + x^2 + x^3 + ...`
- `e3nn.soft_envelope`: a smooth `C^inf` envelope radial function
- `e3nn.tensor_square`
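A smooth (`C^inf`) envelope is a radial function that equals 1 at the origin and decays to exactly 0 at a cutoff, with all derivatives vanishing there. The sketch below uses the classic bump function `exp(1 - 1 / (1 - u^2))`. This is an illustrative choice, not necessarily the formula used by `e3nn.soft_envelope`, and `smooth_cutoff` and `x_max` are hypothetical names.

```python
import numpy as np


def smooth_cutoff(x, x_max=1.0):
    """C^inf envelope: 1 at x = 0, exactly 0 for |x| >= x_max."""
    u = np.asarray(x, dtype=float) / x_max
    inside = np.abs(u) < 1
    # Replace outside points by 0 before dividing so 1 / (1 - u^2) never blows up.
    u_safe = np.where(inside, u, 0.0)
    return np.where(inside, np.exp(1.0 - 1.0 / (1.0 - u_safe**2)), 0.0)
```

Because `exp(-1/t)` goes to zero faster than any power of `t`, radial features multiplied by this envelope vanish smoothly at the cutoff.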
2022-10-05
Added
- `Irrep.generators` and `Irreps.generators` functions to get the generators of the representations
- `e3nn.bessel` function
- `slice_by_mul`, `slice_by_dim` and `slice_by_chunk` functions to `Irreps` and `IrrepsArray`
Changed
- breaking change: `e3nn.soft_one_hot_linspace` does not support `bessel` anymore. Use `e3nn.bessel` instead.
- `e3nn.gate` is now more flexible about the input format; see the examples in the docstring.
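A radial Bessel basis of the kind popularized by DimeNet is `sqrt(2 / c) * sin(k * pi * x / c) / x` for `k = 1..n` and cutoff `c`. The NumPy sketch below assumes `e3nn.bessel` follows this form. The exact normalization and signature of the library function are assumptions, and `bessel_basis` is a hypothetical name. The `x = 0` case uses the analytic limit `k * pi / c`.

```python
import numpy as np


def bessel_basis(x, n, x_max=1.0):
    """Radial basis sqrt(2 / x_max) * sin(k * pi * x / x_max) / x for k = 1..n."""
    x = np.asarray(x, dtype=float)[..., None]  # [..., 1]
    k = np.arange(1, n + 1)  # [n]
    safe_x = np.where(x == 0, 1.0, x)  # avoid dividing by zero
    values = np.sin(k * np.pi * x / x_max) / safe_x
    # Limit at x = 0: sin(k * pi * x / x_max) / x -> k * pi / x_max.
    values = np.where(x == 0, k * np.pi / x_max, values)
    return np.sqrt(2.0 / x_max) * values  # [..., n]
```

Every basis function vanishes at `x = x_max`, so this basis already goes to zero at the cutoff.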
Removed
- breaking change: `IrrepsArray.split`
