Skip to content
3 changes: 2 additions & 1 deletion mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path
from typing import Union

import dill
import numpy as np
import torch
from ase.calculators.calculator import Calculator, all_changes
Expand Down Expand Up @@ -127,7 +128,7 @@ def __init__(

# Load models from files
self.models = [
torch.load(f=model_path, map_location=device)
torch.load(f=model_path, map_location=device, pickle_module=dill)
for model_path in model_paths
]

Expand Down
2 changes: 2 additions & 0 deletions mace/cli/create_lammps_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse

import dill
import torch
from e3nn.util import jit

Expand Down Expand Up @@ -64,6 +65,7 @@ def main():
model = torch.load(
model_path,
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
pickle_module=dill,
)
if args.dtype == "float64":
model = model.double().to("cpu")
Expand Down
7 changes: 4 additions & 3 deletions mace/cli/eval_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import ase.data
import ase.io
import dill
import numpy as np
import torch

Expand Down Expand Up @@ -58,7 +59,7 @@ def parse_args() -> argparse.Namespace:
help="Model head used for evaluation",
type=str,
required=False,
default=None
default=None,
)
return parser.parse_args()

Expand All @@ -73,7 +74,7 @@ def run(args: argparse.Namespace) -> None:
device = torch_tools.init_device(args.device)

# Load model
model = torch.load(f=args.model, map_location=args.device)
model = torch.load(f=args.model, map_location=args.device, pickle_module=dill)
model = model.to(
args.device
) # shouldn't be necessary but seems to help with CUDA problems
Expand All @@ -94,7 +95,7 @@ def run(args: argparse.Namespace) -> None:
heads = model.heads
except AttributeError:
heads = None

data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
Expand Down
9 changes: 5 additions & 4 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pathlib import Path
from typing import List, Optional

import dill
import torch.distributed
import torch.nn.functional
from e3nn.util import jit
Expand Down Expand Up @@ -142,7 +143,7 @@ def run(args: argparse.Namespace) -> None:
model_foundation = calc.models[0]
else:
model_foundation = torch.load(
args.foundation_model, map_location=args.device
args.foundation_model, map_location=args.device, pickle_module=dill
)
logging.info(
f"Using foundation model {args.foundation_model} as initial checkpoint."
Expand Down Expand Up @@ -731,7 +732,7 @@ def run(args: argparse.Namespace) -> None:
logging.info(f"Saving model to {model_path}")
if args.save_cpu:
model = model.to("cpu")
torch.save(model, model_path)
torch.save(model, model_path, pickle_module=dill)
extra_files = {
"commit.txt": commit.encode("utf-8") if commit is not None else b"",
"config.yaml": json.dumps(
Expand All @@ -740,7 +741,7 @@ def run(args: argparse.Namespace) -> None:
}
if swa_eval:
torch.save(
model, Path(args.model_dir) / (args.name + "_stagetwo.model")
model, Path(args.model_dir) / (args.name + "_stagetwo.model"), pickle_module=dill
)
try:
path_complied = Path(args.model_dir) / (
Expand All @@ -756,7 +757,7 @@ def run(args: argparse.Namespace) -> None:
except Exception as e: # pylint: disable=W0703
pass
else:
torch.save(model, Path(args.model_dir) / (args.name + ".model"))
torch.save(model, Path(args.model_dir) / (args.name + ".model"), pickle_module=dill)
try:
path_complied = Path(args.model_dir) / (
args.name + "_compiled.model"
Expand Down
4 changes: 4 additions & 0 deletions mace/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
AtomicEnergiesBlock,
EquivariantProductBasisBlock,
InteractionBlock,
KANNonLinearReadoutBlock,
KANReadoutBlock,
LinearDipoleReadoutBlock,
LinearNodeEmbeddingBlock,
LinearReadoutBlock,
Expand Down Expand Up @@ -77,6 +79,8 @@
"ZBLBasis",
"LinearNodeEmbeddingBlock",
"LinearReadoutBlock",
"KANReadoutBlock",
"KANNonLinearReadoutBlock",
"EquivariantProductBasisBlock",
"ScaleShiftBlock",
"LinearDipoleReadoutBlock",
Expand Down
91 changes: 91 additions & 0 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from e3nn.util.jit import compile_mode

from mace.tools.compile import simplify_if_compile
from mace.tools.MultKAN_jit import MultKAN
from mace.tools.scatter import scatter_sum

from .irreps_tools import (
Expand Down Expand Up @@ -59,6 +60,96 @@ def forward(
return self.linear(x) # [n_nodes, 1]


@compile_mode("trace")
class KANReadoutBlock(torch.nn.Module):
    """Readout that adds a KAN branch to a direct linear branch.

    The input features are projected by an ``o3.Linear`` onto ``MLP_irreps``
    and fed to a ``MultKAN`` whose layer widths are
    ``[d, d // 2, d // 4, irrep_out.dim]`` with ``d = MLP_irreps.dim``.
    A second ``o3.Linear`` maps the input features directly to ``irrep_out``;
    the outputs of the two branches are summed.

    Args:
        irreps_in: Irreps of the incoming node features.
        MLP_irreps: Irreps of the hidden projection fed to the KAN. Its
            dimension must be at least 8.
        irrep_out: Irreps of the readout output.

    Raises:
        ValueError: If ``MLP_irreps.dim`` is smaller than 8.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        MLP_irreps: o3.Irreps,
        irrep_out: o3.Irreps = o3.Irreps("0e"),
    ):
        super().__init__()
        # Normalise every irreps argument the same way as ``irreps_in`` so
        # that ``.dim`` is available even if a string spec is passed.
        MLP_irreps = o3.Irreps(MLP_irreps)
        irrep_out = o3.Irreps(irrep_out)
        # Explicit exception instead of ``assert``: asserts are removed when
        # Python runs with -O, which would silently skip this validation.
        if MLP_irreps.dim < 8:
            raise ValueError(
                f"MLP_irreps must have dimension of at least 8, got {MLP_irreps.dim}"
            )

        # Module construction order is kept as linear -> linear_2 -> kan so
        # that parameter initialisation consumes the RNG in the same order.
        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=MLP_irreps)
        self.linear_2 = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out)
        self.irreps_in = o3.Irreps(irreps_in)
        self.hidden_irreps = MLP_irreps

        widths = [
            MLP_irreps.dim,
            MLP_irreps.dim // 2,
            MLP_irreps.dim // 4,
            irrep_out.dim,
        ]
        self.kan = MultKAN(
            width=widths,
            grid=3,
            k=3,
            mult_arity=2,
            symbolic_enabled=False,
            auto_save=False,
            save_act=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        heads: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    ) -> torch.Tensor:  # [n_nodes, irreps] # [..., ]
        """Return ``kan(linear(x)) + linear_2(x)``; ``heads`` is ignored."""
        hidden = self.linear(x)
        return self.kan(hidden) + self.linear_2(x)  # [n_nodes, irrep_out.dim]

    def _make_tracing_inputs(self, n: int):
        """Build ``n`` example inputs (6 random rows, ``heads=None``) for tracing."""
        return [
            {"forward": (torch.randn(6, self.irreps_in.dim), None)}
            for _ in range(n)
        ]

    def __repr__(self):
        return f"{self.__class__.__name__}(dim=[{self.kan.width}])"


@compile_mode("trace")
class KANNonLinearReadoutBlock(torch.nn.Module):
    """Readout that projects the features linearly and applies a KAN.

    The input features are mapped by an ``o3.Linear`` onto ``MLP_irreps`` and
    passed through a ``MultKAN`` whose layer widths are
    ``[d, d // 2, d // 4, irrep_out.dim]`` with ``d = MLP_irreps.dim``.
    No direct linear skip branch is added to the output.

    Args:
        irreps_in: Irreps of the incoming node features.
        MLP_irreps: Irreps of the hidden projection fed to the KAN. Its
            dimension must be at least 8.
        irrep_out: Irreps of the readout output.
        num_heads: Number of heads; when greater than 1 and ``heads`` is
            given to ``forward``, the input is masked with ``mask_head``.

    Raises:
        ValueError: If ``MLP_irreps.dim`` is smaller than 8.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        MLP_irreps: o3.Irreps,
        irrep_out: o3.Irreps = o3.Irreps("0e"),
        num_heads: int = 1,
    ):
        super().__init__()
        # Normalise every irreps argument the same way as ``irreps_in`` so
        # that ``.dim`` is available even if a string spec is passed.
        MLP_irreps = o3.Irreps(MLP_irreps)
        irrep_out = o3.Irreps(irrep_out)
        # Explicit exception instead of ``assert``: asserts are removed when
        # Python runs with -O, which would silently skip this validation.
        if MLP_irreps.dim < 8:
            raise ValueError(
                f"MLP_irreps must have dimension of at least 8, got {MLP_irreps.dim}"
            )

        self.irreps_in = o3.Irreps(irreps_in)
        self.hidden_irreps = MLP_irreps
        self.num_heads = num_heads
        # Module construction order is kept as linear_1 -> kan so that
        # parameter initialisation consumes the RNG in the same order.
        self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps)

        widths = [
            MLP_irreps.dim,
            MLP_irreps.dim // 2,
            MLP_irreps.dim // 4,
            irrep_out.dim,
        ]
        self.kan = MultKAN(
            width=widths,
            grid=3,
            k=3,
            mult_arity=2,
            symbolic_enabled=False,
            auto_save=False,
            save_act=False,
        )

    def forward(
        self, x: torch.Tensor, heads: Optional[torch.Tensor] = None
    ) -> torch.Tensor:  # [n_nodes, irreps] # [..., ]
        """Optionally mask ``x`` per head, then return ``kan(linear_1(x))``."""
        # The ``hasattr`` guard tolerates instances that lack ``num_heads``.
        # NOTE(review): the mask is applied to the input features, before the
        # linear projection — confirm this is the intended place for it.
        if hasattr(self, "num_heads"):
            if self.num_heads > 1 and heads is not None:
                x = mask_head(x, heads, self.num_heads)
        hidden = self.linear_1(x)
        return self.kan(hidden)  # [n_nodes, irrep_out.dim]

    def _make_tracing_inputs(self, n: int):
        """Build ``n`` example inputs (6 random rows, ``heads=None``) for tracing."""
        return [
            {"forward": (torch.randn(6, self.irreps_in.dim), None)}
            for _ in range(n)
        ]

    def __repr__(self):
        return f"{self.__class__.__name__}(dim=[{self.kan.width}])"


@simplify_if_compile
@compile_mode("script")
class NonLinearReadoutBlock(torch.nn.Module):
Expand Down
57 changes: 43 additions & 14 deletions mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
AtomicEnergiesBlock,
EquivariantProductBasisBlock,
InteractionBlock,
KANNonLinearReadoutBlock,
KANReadoutBlock,
LinearDipoleReadoutBlock,
LinearNodeEmbeddingBlock,
LinearReadoutBlock,
Expand Down Expand Up @@ -62,6 +64,7 @@ def __init__(
radial_MLP: Optional[List[int]] = None,
radial_type: Optional[str] = "bessel",
heads: Optional[List[str]] = None,
KAN_readout: bool = False,
):
super().__init__()
self.register_buffer(
Expand Down Expand Up @@ -135,9 +138,18 @@ def __init__(
self.products = torch.nn.ModuleList([prod])

self.readouts = torch.nn.ModuleList()
self.readouts.append(
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)
self.KAN_readout = KAN_readout

if KAN_readout:
self.readouts.append(
KANReadoutBlock(
hidden_irreps, MLP_irreps, o3.Irreps(f"{len(heads)}x0e")
)
)
else:
self.readouts.append(
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)

for i in range(num_interactions - 1):
if i == num_interactions - 2:
Expand Down Expand Up @@ -166,19 +178,36 @@ def __init__(
)
self.products.append(prod)
if i == num_interactions - 2:
self.readouts.append(
NonLinearReadoutBlock(
hidden_irreps_out,
(len(heads) * MLP_irreps).simplify(),
gate,
o3.Irreps(f"{len(heads)}x0e"),
len(heads),
if KAN_readout:
self.readouts.append(
KANNonLinearReadoutBlock(
hidden_irreps_out,
(len(heads) * MLP_irreps).simplify(),
o3.Irreps(f"{len(heads)}x0e"),
len(heads),
)
)
else:
self.readouts.append(
NonLinearReadoutBlock(
hidden_irreps_out,
(len(heads) * MLP_irreps).simplify(),
gate,
o3.Irreps(f"{len(heads)}x0e"),
len(heads),
)
)
)
else:
self.readouts.append(
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)
if KAN_readout:
self.readouts.append(
KANReadoutBlock(
hidden_irreps, MLP_irreps, o3.Irreps(f"{len(heads)}x0e")
)
)
else:
self.readouts.append(
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)

def forward(
self,
Expand Down
Loading
Loading