|
| 1 | +""" |
| 2 | +Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +
|
| 4 | +This source code is licensed under the MIT license found in the |
| 5 | +LICENSE file in the root directory of this source tree. |
| 6 | +
|
| 7 | +""" |
| 8 | + |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +import argparse |
| 12 | +import os |
| 13 | + |
| 14 | +import numpy as np |
| 15 | +import torch |
| 16 | + |
| 17 | +from fairchem.core.calculate.ase_calculator import FAIRChemCalculator |
| 18 | + |
| 19 | + |
def add_omat_rattle_support(checkpoint):
    """Stage 1: give the checkpoint an ``omat_rattle`` dataset embedding.

    The new embedding weight is an independent clone of the existing ``omat``
    embedding weight, added to both the regular and the EMA state dict (the
    EMA keys carry a ``module.`` prefix). The model id is then set to
    ``"UMA-S-1.2"``.

    Only the two state dicts and ``model_config["model_id"]`` are modified;
    no dataset list or dataset mapping entry of the config is rewritten here.

    Args:
        checkpoint: Object exposing ``model_state_dict``, ``ema_state_dict``
            and ``model_config`` (all dict-like).

    Returns:
        The same ``checkpoint`` object, mutated in place.

    Raises:
        KeyError: If the ``omat`` embedding weight is missing from either
            state dict.
    """
    # (state-dict attribute, existing omat key, new omat_rattle key)
    embedding_copies = (
        (
            "model_state_dict",
            "backbone.dataset_embedding.dataset_emb_dict.omat.weight",
            "backbone.dataset_embedding.dataset_emb_dict.omat_rattle.weight",
        ),
        (
            "ema_state_dict",
            "module.backbone.dataset_embedding.dataset_emb_dict.omat.weight",
            "module.backbone.dataset_embedding.dataset_emb_dict.omat_rattle.weight",
        ),
    )
    for attr_name, source_key, alias_key in embedding_copies:
        state = getattr(checkpoint, attr_name)
        # clone() so the new entry does not share memory with the omat weight.
        state[alias_key] = state[source_key].clone()

    checkpoint.model_config["model_id"] = "UMA-S-1.2"
    return checkpoint
| 55 | + |
| 56 | + |
def remove_omat_rattle(checkpoint):
    """Stage 2: Remove all omat_rattle mentions from config and weights.

    Key insight: In stage1, omat_rattle maps to 'omat' (shares omat's expert).
    The unique targets in stage1 are:
    ['oc20', 'oc22', 'oc25', 'odac', 'omat', 'omc', 'omol']
    These map to expert indices 0-6 via _build_expert_mapping (sorted unique
    values).

    However, the original checkpoint weights have 8 expert slots (indices
    0-7), where index 7 is UNUSED/untrained. When we remove omat_rattle and
    have only 7 datasets, the model expects 7 expert slots, so we must resize
    weights by removing the unused index 7.

    Besides the weight surgery this also drops every task whose ``datasets``
    list contains ``"omat_rattle"`` and sets a few config fields
    (``supports_single_atoms``, ``model_id``, ``model_version``,
    ``moe_layer_type``).

    Args:
        checkpoint: Object exposing ``model_state_dict``, ``ema_state_dict``,
            ``model_config`` (dict-like) and ``tasks_config`` (iterable of
            dict-like task entries).

    Returns:
        The same ``checkpoint`` object, mutated in place.

    Raises:
        KeyError: If the omat_rattle embedding weights or any of the head
            expert weights are missing from the state dicts.
    """
    dataset_mapping = {
        "oc20": "oc20",
        "oc22": "oc22",
        "oc25": "oc25",
        "omol": "omol",
        "omat": "omat",
        "odac": "odac",
        "omc": "omc",
    }
    # One expert slot per unique mapping target (7 for the mapping above).
    num_experts = len(set(dataset_mapping.values()))

    checkpoint.model_config["backbone"]["dataset_mapping"] = dataset_mapping
    # Give the head its own dict so the two config entries are not aliases of
    # one mutable object.
    checkpoint.model_config["heads"]["energyandforcehead"]["dataset_mapping"] = dict(
        dataset_mapping
    )

    # Remove backbone embeddings for omat_rattle
    del checkpoint.model_state_dict[
        "backbone.dataset_embedding.dataset_emb_dict.omat_rattle.weight"
    ]
    del checkpoint.ema_state_dict[
        "module.backbone.dataset_embedding.dataset_emb_dict.omat_rattle.weight"
    ]

    # Resize head expert weights from 8 to 7 by removing the UNUSED last slot.
    # Stage1 mapping: 0=oc20, 1=oc22, 2=oc25, 3=odac, 4=omat(+omat_rattle),
    #                 5=omc, 6=omol, 7=UNUSED
    # Stage2 mapping: same indices 0-6, just without the omat_rattle entry.
    head_keys = [
        "output_heads.energyandforcehead.head.energy_block.0.weights",
        "output_heads.energyandforcehead.head.energy_block.2.weights",
        "output_heads.energyandforcehead.head.energy_block.4.weights",
    ]
    for key in head_keys:
        for state, state_key in (
            (checkpoint.model_state_dict, key),
            (checkpoint.ema_state_dict, "module." + key),
        ):
            # A bare slice is a view that still owns the full 8-slot storage:
            # torch.save would serialize all of it and the result would alias
            # the original tensor. detach().clone() yields a compact,
            # independent tensor holding only the kept slots.
            state[state_key] = state[state_key][:num_experts].detach().clone()

    # Remove omat_rattle tasks from tasks_config
    checkpoint.tasks_config = [
        task
        for task in checkpoint.tasks_config
        if "omat_rattle" not in task.get("datasets", [])
    ]

    # Add single atom support
    checkpoint.model_config["supports_single_atoms"] = True
    checkpoint.model_config["model_id"] = "UMA-S-1.2"
    checkpoint.model_config["backbone"]["model_version"] = 1.21
    checkpoint.model_config["backbone"]["moe_layer_type"] = "pytorch"
    return checkpoint
| 120 | + |
| 121 | + |
def create_test_systems():
    """Build the two H2O structures used for checkpoint comparison.

    Returns:
        A dict with two ASE ``Atoms`` objects:

        * ``"nopbc"``: H2O from ``ase.build.molecule`` with
          ``info["charge"] = 0`` and ``info["spin"] = 1``.
        * ``"pbc"``: the same molecule with its cell set to
          ``[10.0, 10.0, 10.0]`` and periodic boundaries enabled; no
          charge/spin info is attached to this one.
    """
    from ase.build import molecule

    # Aperiodic water molecule, tagged with charge and spin.
    isolated = molecule("H2O")
    isolated.info.update({"charge": 0, "spin": 1})

    # Periodic water molecule.
    periodic = molecule("H2O")
    periodic.set_cell([10.0, 10.0, 10.0])
    periodic.set_pbc(True)

    test_systems = {}
    test_systems["nopbc"] = isolated
    test_systems["pbc"] = periodic
    return test_systems
| 137 | + |
| 138 | + |
def compare_checkpoints(
    stage1_path: str,
    stage2_path: str,
    *,
    tasks: tuple[str, ...] = ("oc20", "oc22", "oc25", "omat", "odac", "omc", "omol"),
    atol: float = 1e-5,
) -> bool:
    """Compare stage1 and stage2 checkpoints across all tasks.

    For every task a calculator is built from each checkpoint, and energy and
    forces are evaluated on each structure from ``create_test_systems()``.
    One table row per (task, system) pair is printed with both energies, the
    absolute/relative energy difference and the largest absolute/relative
    force difference.

    Args:
        stage1_path: Path of the reference checkpoint.
        stage2_path: Path of the checkpoint compared against the reference.
        tasks: Task names to evaluate; defaults to the seven tasks that were
            previously hard-coded.
        atol: Absolute tolerance applied to both the energy difference and
            the maximum force-component difference.

    Returns:
        ``True`` if every row is within ``atol`` for both energy and forces,
        ``False`` otherwise.
    """
    systems = create_test_systems()

    print(
        f"{'Task':<8} {'System':<8} {'E1':<14} {'E2':<14} {'E_abs':<12} {'E_rel':<12} {'F_max_abs':<12} {'F_max_rel':<12}"
    )
    print("-" * 100)

    all_match = True
    for task in tasks:
        calc1 = FAIRChemCalculator.from_model_checkpoint(stage1_path, task_name=task)
        calc2 = FAIRChemCalculator.from_model_checkpoint(stage2_path, task_name=task)

        for sys_name, atoms in systems.items():
            # Work on copies so the shared template structures never carry a
            # calculator from a previous iteration.
            atoms1 = atoms.copy()
            atoms2 = atoms.copy()
            atoms1.calc = calc1
            atoms2.calc = calc2

            e1 = atoms1.get_potential_energy()
            e2 = atoms2.get_potential_energy()
            f1 = atoms1.get_forces()
            f2 = atoms2.get_forces()

            e_abs = abs(e1 - e2)
            # Relative errors are reported as 0 when the reference is exactly
            # zero, to avoid dividing by zero.
            e_rel = e_abs / abs(e1) if e1 != 0 else 0
            f_abs = np.abs(f1 - f2).max()
            f1_max = np.abs(f1).max()
            f_rel = f_abs / f1_max if f1_max != 0 else 0

            print(
                f"{task:<8} {sys_name:<8} {e1:<14.6f} {e2:<14.6f} {e_abs:<12.2e} {e_rel:<12.2e} {f_abs:<12.2e} {f_rel:<12.2e}"
            )

            if not (e_abs < atol and f_abs < atol):
                all_match = False

    return all_match
| 180 | + |
| 181 | + |
def uma_1p2_surgery(checkpoint_path: str, output_dir: str) -> tuple[str, str]:
    """Perform checkpoint surgery and output stage1 and stage2 checkpoints.

    The input checkpoint is loaded once, passed through
    ``add_omat_rattle_support`` and saved as ``inference_ckpt_stage1.pt``;
    the same in-memory object is then passed through ``remove_omat_rattle``
    and saved as ``inference_ckpt_stage2.pt``.

    Args:
        checkpoint_path: Path of the checkpoint to operate on.
        output_dir: Directory receiving both output files; created if it
            does not exist.

    Returns:
        ``(stage1_path, stage2_path)``: the paths of the two written files.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Stage 1: Add omat_rattle support
    # map_location="cpu" lets the surgery run on hosts without the device the
    # checkpoint was saved from.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only run this on checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    checkpoint = add_omat_rattle_support(checkpoint)
    stage1_path = os.path.join(output_dir, "inference_ckpt_stage1.pt")
    torch.save(checkpoint, stage1_path)
    print(f"Stage 1 saved to {stage1_path}")

    # Stage 2: Remove omat_rattle entirely. This mutates the object already
    # written for stage 1, which is fine because that file is on disk.
    checkpoint = remove_omat_rattle(checkpoint)
    stage2_path = os.path.join(output_dir, "inference_ckpt_stage2.pt")
    torch.save(checkpoint, stage2_path)
    print(f"Stage 2 saved to {stage2_path}")

    return stage1_path, stage2_path
| 200 | + |
| 201 | + |
if __name__ == "__main__":
    # Command-line entry point: run the surgery, then verify the outputs.
    cli = argparse.ArgumentParser(description="Perform checkpoint surgery on UMA 1.2")
    for flag in ("--checkpoint-in", "--output-dir"):
        cli.add_argument(flag, type=str, required=True)
    options = cli.parse_args()

    stage1_path, stage2_path = uma_1p2_surgery(
        options.checkpoint_in, options.output_dir
    )

    # Baseline: comparing a checkpoint with itself.
    print("\n=== Self-test: stage1 vs stage1 ===")
    self_match = compare_checkpoints(stage1_path, stage1_path)
    print(f"Self-test match: {self_match}")

    # Actual check: compare the stage 1 and stage 2 checkpoints.
    print("\n=== Comparing stage1 vs stage2 ===")
    all_match = compare_checkpoints(stage1_path, stage2_path)
    print(f"\nAll outputs match: {all_match}")
0 commit comments