Skip to content

Commit 82d91fa

Browse files
committed
uma 1p2p1 chkpt surgery
1 parent e054599 commit 82d91fa

1 file changed

Lines changed: 218 additions & 0 deletions

File tree

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
"""
2+
Copyright (c) Meta Platforms, Inc. and affiliates.
3+
4+
This source code is licensed under the MIT license found in the
5+
LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import argparse
12+
import os
13+
14+
import numpy as np
15+
import torch
16+
17+
from fairchem.core.calculate.ase_calculator import FAIRChemCalculator
18+
19+
20+
def add_omat_rattle_support(checkpoint):
21+
"""Stage 1: Add omat_rattle support (matches notebook)."""
22+
# dataset_mapping = {
23+
# "oc20": "oc20",
24+
# "oc22": "oc22",
25+
# "oc25": "oc25",
26+
# "omol": "omol",
27+
# "omat": "omat",
28+
# "omat_rattle": "omat",
29+
# "odac": "odac",
30+
# "omc": "omc",
31+
# }
32+
33+
# del checkpoint.model_config["backbone"]["dataset_list"]
34+
# checkpoint.model_config["backbone"]["dataset_mapping"] = dataset_mapping
35+
36+
# del checkpoint.model_config["heads"]["energyandforcehead"]["dataset_names"]
37+
# checkpoint.model_config["heads"]["energyandforcehead"][
38+
# "dataset_mapping"
39+
# ] = dataset_mapping
40+
41+
checkpoint.model_state_dict[
42+
"backbone.dataset_embedding.dataset_emb_dict.omat_rattle.weight"
43+
] = checkpoint.model_state_dict[
44+
"backbone.dataset_embedding.dataset_emb_dict.omat.weight"
45+
].clone()
46+
47+
checkpoint.ema_state_dict[
48+
"module.backbone.dataset_embedding.dataset_emb_dict.omat_rattle.weight"
49+
] = checkpoint.ema_state_dict[
50+
"module.backbone.dataset_embedding.dataset_emb_dict.omat.weight"
51+
].clone()
52+
53+
checkpoint.model_config["model_id"] = "UMA-S-1.2"
54+
return checkpoint
55+
56+
57+
def remove_omat_rattle(checkpoint):
58+
"""Stage 2: Remove all omat_rattle mentions from config and weights.
59+
60+
Key insight: In stage1, omat_rattle maps to 'omat' (shares omat's expert).
61+
The unique targets in stage1 are: ['oc20', 'oc22', 'oc25', 'odac', 'omat', 'omc', 'omol']
62+
These map to expert indices 0-6 via _build_expert_mapping (sorted unique values).
63+
64+
However, the original checkpoint weights have 8 expert slots (indices 0-7),
65+
where index 7 is UNUSED/untrained. When we remove omat_rattle and have only
66+
7 datasets, the model expects 7 expert slots, so we must resize weights by
67+
removing the unused index 7.
68+
"""
69+
dataset_mapping = {
70+
"oc20": "oc20",
71+
"oc22": "oc22",
72+
"oc25": "oc25",
73+
"omol": "omol",
74+
"omat": "omat",
75+
"odac": "odac",
76+
"omc": "omc",
77+
}
78+
79+
checkpoint.model_config["backbone"]["dataset_mapping"] = dataset_mapping
80+
checkpoint.model_config["heads"]["energyandforcehead"]["dataset_mapping"] = (
81+
dataset_mapping
82+
)
83+
84+
# Remove backbone embeddings for omat_rattle
85+
del checkpoint.model_state_dict[
86+
"backbone.dataset_embedding.dataset_emb_dict.omat_rattle.weight"
87+
]
88+
del checkpoint.ema_state_dict[
89+
"module.backbone.dataset_embedding.dataset_emb_dict.omat_rattle.weight"
90+
]
91+
92+
# Resize head expert weights from 8 to 7 by removing UNUSED index 7
93+
# The original checkpoint has 8 slots, with index 7 never used during training.
94+
# Stage1 mapping: 0=oc20, 1=oc22, 2=oc25, 3=odac, 4=omat(+omat_rattle), 5=omc, 6=omol, 7=UNUSED
95+
# Stage2 mapping: same indices 0-6, just without omat_rattle entry
96+
head_keys = [
97+
"output_heads.energyandforcehead.head.energy_block.0.weights",
98+
"output_heads.energyandforcehead.head.energy_block.2.weights",
99+
"output_heads.energyandforcehead.head.energy_block.4.weights",
100+
]
101+
for key in head_keys:
102+
w = checkpoint.model_state_dict[key]
103+
checkpoint.model_state_dict[key] = w[:7] # keep indices 0-6, remove index 7
104+
105+
ema_key = "module." + key
106+
w = checkpoint.ema_state_dict[ema_key]
107+
checkpoint.ema_state_dict[ema_key] = w[:7] # keep indices 0-6, remove index 7
108+
109+
# Remove omat_rattle tasks from tasks_config
110+
checkpoint.tasks_config = [
111+
t for t in checkpoint.tasks_config if "omat_rattle" not in t.get("datasets", [])
112+
]
113+
114+
# Add single atom support
115+
checkpoint.model_config["supports_single_atoms"] = True
116+
checkpoint.model_config["model_id"] = "UMA-S-1.2"
117+
checkpoint.model_config["backbone"]["model_version"] = 1.21
118+
checkpoint.model_config["backbone"]["moe_layer_type"] = "pytorch"
119+
return checkpoint
120+
121+
122+
def create_test_systems():
123+
"""Create PBC and non-PBC H2O test systems."""
124+
from ase.build import molecule
125+
126+
# Non-PBC (aperiodic) H2O
127+
h2o_nopbc = molecule("H2O")
128+
h2o_nopbc.info["charge"] = 0
129+
h2o_nopbc.info["spin"] = 1
130+
131+
# PBC H2O
132+
h2o_pbc = molecule("H2O")
133+
h2o_pbc.set_cell([10.0, 10.0, 10.0])
134+
h2o_pbc.set_pbc(True)
135+
136+
return {"nopbc": h2o_nopbc, "pbc": h2o_pbc}
137+
138+
139+
def compare_checkpoints(stage1_path: str, stage2_path: str):
140+
"""Compare stage1 and stage2 checkpoints across all tasks."""
141+
systems = create_test_systems()
142+
143+
print(
144+
f"{'Task':<8} {'System':<8} {'E1':<14} {'E2':<14} {'E_abs':<12} {'E_rel':<12} {'F_max_abs':<12} {'F_max_rel':<12}"
145+
)
146+
print("-" * 100)
147+
148+
all_match = True
149+
for task in ["oc20", "oc22", "oc25", "omat", "odac", "omc", "omol"]:
150+
calc1 = FAIRChemCalculator.from_model_checkpoint(stage1_path, task_name=task)
151+
calc2 = FAIRChemCalculator.from_model_checkpoint(stage2_path, task_name=task)
152+
153+
for sys_name, atoms in systems.items():
154+
atoms1 = atoms.copy()
155+
atoms2 = atoms.copy()
156+
atoms1.calc = calc1
157+
atoms2.calc = calc2
158+
159+
e1 = atoms1.get_potential_energy()
160+
e2 = atoms2.get_potential_energy()
161+
f1 = atoms1.get_forces()
162+
f2 = atoms2.get_forces()
163+
164+
e_abs = abs(e1 - e2)
165+
e_rel = e_abs / abs(e1) if e1 != 0 else 0
166+
f_abs = np.abs(f1 - f2).max()
167+
f_rel = f_abs / np.abs(f1).max() if np.abs(f1).max() != 0 else 0
168+
169+
e_match = e_abs < 1e-5
170+
f_match = f_abs < 1e-5
171+
172+
print(
173+
f"{task:<8} {sys_name:<8} {e1:<14.6f} {e2:<14.6f} {e_abs:<12.2e} {e_rel:<12.2e} {f_abs:<12.2e} {f_rel:<12.2e}"
174+
)
175+
176+
if not (e_match and f_match):
177+
all_match = False
178+
179+
return all_match
180+
181+
182+
def uma_1p2_surgery(checkpoint_path: str, output_dir: str) -> tuple[str, str]:
183+
"""Perform checkpoint surgery and output stage1 and stage2 checkpoints."""
184+
os.makedirs(output_dir, exist_ok=True)
185+
186+
# Stage 1: Add omat_rattle support
187+
checkpoint = torch.load(checkpoint_path, weights_only=False)
188+
checkpoint = add_omat_rattle_support(checkpoint)
189+
stage1_path = os.path.join(output_dir, "inference_ckpt_stage1.pt")
190+
torch.save(checkpoint, stage1_path)
191+
print(f"Stage 1 saved to {stage1_path}")
192+
193+
# Stage 2: Remove omat_rattle entirely
194+
checkpoint = remove_omat_rattle(checkpoint)
195+
stage2_path = os.path.join(output_dir, "inference_ckpt_stage2.pt")
196+
torch.save(checkpoint, stage2_path)
197+
print(f"Stage 2 saved to {stage2_path}")
198+
199+
return stage1_path, stage2_path
200+
201+
202+
if __name__ == "__main__":
203+
parser = argparse.ArgumentParser(
204+
description="Perform checkpoint surgery on UMA 1.2"
205+
)
206+
parser.add_argument("--checkpoint-in", type=str, required=True)
207+
parser.add_argument("--output-dir", type=str, required=True)
208+
209+
args = parser.parse_args()
210+
211+
stage1_path, stage2_path = uma_1p2_surgery(args.checkpoint_in, args.output_dir)
212+
print("\n=== Self-test: stage1 vs stage1 ===")
213+
self_match = compare_checkpoints(stage1_path, stage1_path)
214+
print(f"Self-test match: {self_match}")
215+
216+
print("\n=== Comparing stage1 vs stage2 ===")
217+
all_match = compare_checkpoints(stage1_path, stage2_path)
218+
print(f"\nAll outputs match: {all_match}")

0 commit comments

Comments
 (0)