Skip to content

Commit 9bb07cd

Browse files
authored
Merge pull request #164 from microsoft/add_mp_20_checkpoint
Add mp 20 checkpoint
2 parents ec029d1 + 9e1d34c commit 9bb07cd

File tree

5 files changed

+192
-1
lines changed

5 files changed

+192
-1
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ git lfs install
6161
6262
## Get started with a pre-trained model
6363
We provide checkpoints of an unconditional base version of MatterGen as well as fine-tuned models for these properties:
64-
* `mattergen_base`: unconditional base model
64+
* `mattergen_base`: unconditional base model trained on Alex-MP-20
65+
* `mp_20_base`: unconditional base model trained on MP-20
6566
* `chemical_system`: fine-tuned model conditioned on chemical system
6667
* `space_group`: fine-tuned model conditioned on space group
6768
* `dft_mag_density`: fine-tuned model conditioned on magnetic density from DFT

checkpoints/.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ ml_bulk_modulus/checkpoints/last.ckpt filter=lfs diff=lfs merge=lfs -text
55
space_group/checkpoints/last.ckpt filter=lfs diff=lfs merge=lfs -text
66
chemical_system_energy_above_hull/checkpoints/last.ckpt filter=lfs diff=lfs merge=lfs -text
77
dft_mag_density_hhi_score/checkpoints/last.ckpt filter=lfs diff=lfs merge=lfs -text
8+
mp_20_base/checkpoints/last.ckpt filter=lfs diff=lfs merge=lfs -text

checkpoints/mp_20_base/checkpoints/last.ckpt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:ffb80e4425a6f99f479a67b8cd111885d45117234e8947ff77eb3a55df420b9a
3+
size 461369442

checkpoints/mp_20_base/config.yaml

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
auto_resume: true
2+
checkpoint_path: null
3+
data_module:
4+
_recursive_: true
5+
_target_: mattergen.common.data.datamodule.CrystDataModule
6+
average_density: 0.05771451654022283
7+
batch_size:
8+
test: 8
9+
train: 64
10+
val: 8
11+
dataset_transforms:
12+
- _partial_: true
13+
_target_: mattergen.common.data.dataset_transform.filter_sparse_properties
14+
max_epochs: 900
15+
num_workers:
16+
test: 0
17+
train: 0
18+
val: 0
19+
properties: []
20+
root_dir: /mnt/data_cache/mattergen-release-cache/mp_20
21+
test_dataset:
22+
_target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
23+
cache_path: /mnt/data_cache/mattergen-release-cache/mp_20/test
24+
dataset_transforms:
25+
- _partial_: true
26+
_target_: mattergen.common.data.dataset_transform.filter_sparse_properties
27+
properties: []
28+
transforms:
29+
- _partial_: true
30+
_target_: mattergen.common.data.transform.symmetrize_lattice
31+
- _partial_: true
32+
_target_: mattergen.common.data.transform.set_chemical_system_string
33+
train_dataset:
34+
_target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
35+
cache_path: /mnt/data_cache/mattergen-release-cache/mp_20/train
36+
dataset_transforms:
37+
- _partial_: true
38+
_target_: mattergen.common.data.dataset_transform.filter_sparse_properties
39+
properties: []
40+
transforms:
41+
- _partial_: true
42+
_target_: mattergen.common.data.transform.symmetrize_lattice
43+
- _partial_: true
44+
_target_: mattergen.common.data.transform.set_chemical_system_string
45+
transforms:
46+
- _partial_: true
47+
_target_: mattergen.common.data.transform.symmetrize_lattice
48+
- _partial_: true
49+
_target_: mattergen.common.data.transform.set_chemical_system_string
50+
val_dataset:
51+
_target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
52+
cache_path: /mnt/data_cache/mattergen-release-cache/mp_20/val
53+
dataset_transforms:
54+
- _partial_: true
55+
_target_: mattergen.common.data.dataset_transform.filter_sparse_properties
56+
properties: []
57+
transforms:
58+
- _partial_: true
59+
_target_: mattergen.common.data.transform.symmetrize_lattice
60+
- _partial_: true
61+
_target_: mattergen.common.data.transform.set_chemical_system_string
62+
lightning_module:
63+
_target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
64+
diffusion_module:
65+
_target_: mattergen.diffusion.diffusion_module.DiffusionModule
66+
corruption:
67+
_target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
68+
discrete_corruptions:
69+
atomic_numbers:
70+
_target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
71+
d3pm:
72+
_target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
73+
dim: 101
74+
schedule:
75+
_target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
76+
kind: standard
77+
num_steps: 1000
78+
offset: 1
79+
sdes:
80+
cell:
81+
_target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
82+
vpsde_config:
83+
beta_max: 20
84+
beta_min: 0.1
85+
limit_density: 0.05771451654022283
86+
limit_var_scaling_constant: 0.25
87+
pos:
88+
_target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
89+
limit_info_key: num_atoms
90+
sigma_max: 5.0
91+
wrapping_boundary: 1.0
92+
loss_fn:
93+
_target_: mattergen.common.loss.MaterialsLoss
94+
d3pm_hybrid_lambda: 0.01
95+
include_atomic_numbers: true
96+
include_cell: true
97+
include_pos: true
98+
reduce: sum
99+
weights:
100+
atomic_numbers: 1.0
101+
cell: 1.0
102+
pos: 0.1
103+
model:
104+
_target_: mattergen.denoiser.GemNetTDenoiser
105+
atom_type_diffusion: mask
106+
denoise_atom_types: true
107+
gemnet:
108+
_target_: mattergen.common.gemnet.gemnet.GemNetT
109+
atom_embedding:
110+
_target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
111+
emb_size: 512
112+
with_mask_type: true
113+
cutoff: 7.0
114+
emb_size_atom: 512
115+
emb_size_edge: 512
116+
latent_dim: 512
117+
max_cell_images_per_dim: 5
118+
max_neighbors: 50
119+
num_blocks: 4
120+
num_targets: 1
121+
otf_graph: true
122+
regress_stress: true
123+
scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
124+
hidden_dim: 512
125+
property_embeddings: {}
126+
property_embeddings_adapt: {}
127+
pre_corruption_fn:
128+
_target_: mattergen.property_embeddings.SetEmbeddingType
129+
dropout_fields_iid: false
130+
p_unconditional: 0.2
131+
optimizer_partial:
132+
_partial_: true
133+
_target_: torch.optim.Adam
134+
lr: 0.0001
135+
scheduler_partials:
136+
- frequency: 1
137+
interval: epoch
138+
monitor: loss_train
139+
scheduler:
140+
_partial_: true
141+
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
142+
factor: 0.6
143+
min_lr: 1.0e-06
144+
patience: 100
145+
verbose: true
146+
strict: true
147+
load_original: false
148+
params: {}
149+
trainer:
150+
_target_: pytorch_lightning.Trainer
151+
accelerator: gpu
152+
accumulate_grad_batches: 1
153+
callbacks:
154+
- _target_: pytorch_lightning.callbacks.LearningRateMonitor
155+
log_momentum: false
156+
logging_interval: step
157+
- _target_: pytorch_lightning.callbacks.ModelCheckpoint
158+
every_n_epochs: 1
159+
filename: '{epoch}-{loss_val:.2f}'
160+
mode: min
161+
monitor: loss_val
162+
save_last: true
163+
save_top_k: 1
164+
verbose: false
165+
- _target_: pytorch_lightning.callbacks.TQDMProgressBar
166+
refresh_rate: 50
167+
- _target_: mattergen.common.data.callback.SetPropertyScalers
168+
check_val_every_n_epoch: 5
169+
devices: 8
170+
gradient_clip_algorithm: value
171+
gradient_clip_val: 0.5
172+
logger:
173+
_target_: pytorch_lightning.loggers.WandbLogger
174+
job_type: train
175+
project: crystal-generation
176+
settings:
177+
_save_requirements: false
178+
_target_: wandb.Settings
179+
start_method: fork
180+
max_epochs: 900
181+
num_nodes: 1
182+
precision: 32
183+
strategy:
184+
_target_: pytorch_lightning.strategies.ddp.DDPStrategy
185+
find_unused_parameters: true

mattergen/common/utils/data_classes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"ml_bulk_modulus",
2323
"dft_mag_density_hhi_score",
2424
"chemical_system_energy_above_hull",
25+
"mp_20_base",
2526
]
2627

2728

0 commit comments

Comments
 (0)