Skip to content

Commit 831cafd

Browse files
committed
Add model loader convenience function.
1 parent f576734 commit 831cafd

1 file changed

Lines changed: 71 additions & 0 deletions

File tree

olmoearth_pretrain/model_loader.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Load the OlmoEarth models from Hugging Face.
2+
3+
The weights are converted to pth file from distributed checkpoint like this:
4+
5+
import json
6+
from pathlib import Path
7+
8+
import torch
9+
10+
from olmo_core.config import Config
11+
from olmo_core.distributed.checkpoint import load_model_and_optim_state
12+
13+
checkpoint_path = Path("/weka/dfive-default/helios/checkpoints/joer/nano_lr0.001_wd0.002/step370000")
14+
with (checkpoint_path / "config.json").open() as f:
15+
config_dict = json.load(f)
16+
model_config = Config.from_dict(config_dict["model"])
17+
18+
model = model_config.build()
19+
20+
train_module_dir = checkpoint_path / "model_and_optim"
21+
load_model_and_optim_state(str(train_module_dir), model)
22+
torch.save(model.state_dict(), "OlmoEarth-v1-Nano.pth")
23+
"""
24+
25+
import json
26+
from enum import StrEnum
27+
28+
import torch
29+
from huggingface_hub import hf_hub_download
30+
from olmo_core.config import Config
31+
32+
33+
class ModelID(StrEnum):
34+
"""OlmoEarth pre-trained model ID."""
35+
36+
OLMOEARTH_V1_NANO = "OlmoEarth-v1-Nano"
37+
OLMOEARTH_V1_TINY = "OlmoEarth-v1-Tiny"
38+
OLMOEARTH_V1_BASE = "OlmoEarth-v1-Base"
39+
40+
41+
def load_model(model_id: ModelID, load_weights: bool = True) -> torch.nn.Module:
42+
"""Initialize and load the weights for the specified model ID.
43+
44+
The config and weights will be downloaded from Hugging Face.
45+
46+
Args:
47+
model_id: the model ID to load.
48+
load_weights: whether to load the weights. Set false to skip downloading the
49+
weights from Hugging Face and leave them randomly initialized. Note that
50+
the config.json will still be downloaded from Hugging Face.
51+
"""
52+
# We ignore bandit warnings here since we are just downloading config and weights,
53+
# not any code.
54+
config_fname = hf_hub_download(
55+
repo_id="allenai/olmoearth_pretrain", filename=f"{model_id.value}-config.json"
56+
) # nosec
57+
pth_fname = hf_hub_download(
58+
repo_id="allenai/olmoearth_pretrain", filename=f"{model_id.value}.pth"
59+
) # nosec
60+
with open(config_fname) as f:
61+
config_dict = json.load(f)
62+
model_config = Config.from_dict(config_dict["model"])
63+
64+
model: torch.nn.Module = model_config.build()
65+
66+
if not load_weights:
67+
return model
68+
69+
state_dict = torch.load(pth_fname)
70+
model.load_state_dict(state_dict)
71+
return model

0 commit comments

Comments
 (0)