Skip to content

Commit aefba0c

Browse files
authored
Merge pull request #407 from allenai/favyen/20251023-model-loader
Add convenience function to load models from Hugging Face.
2 parents 29c305d + d53cbc3 commit aefba0c

1 file changed

Lines changed: 68 additions & 0 deletions

File tree

olmoearth_pretrain/model_loader.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Load the OlmoEarth models from Hugging Face.
2+
3+
The weights were converted from a distributed checkpoint to a .pth file like this:
4+
5+
import json
6+
from pathlib import Path
7+
8+
import torch
9+
10+
from olmo_core.config import Config
11+
from olmo_core.distributed.checkpoint import load_model_and_optim_state
12+
13+
checkpoint_path = Path("/weka/dfive-default/helios/checkpoints/joer/nano_lr0.001_wd0.002/step370000")
14+
with (checkpoint_path / "config.json").open() as f:
15+
config_dict = json.load(f)
16+
model_config = Config.from_dict(config_dict["model"])
17+
18+
model = model_config.build()
19+
20+
train_module_dir = checkpoint_path / "model_and_optim"
21+
load_model_and_optim_state(str(train_module_dir), model)
22+
torch.save(model.state_dict(), "OlmoEarth-v1-Nano.pth")
23+
"""
24+
25+
import json
26+
from enum import StrEnum
27+
28+
import torch
29+
from huggingface_hub import hf_hub_download
30+
from olmo_core.config import Config
31+
32+
33+
class ModelID(StrEnum):
34+
"""OlmoEarth pre-trained model ID."""
35+
36+
OLMOEARTH_V1_NANO = "OlmoEarth-v1-Nano"
37+
OLMOEARTH_V1_TINY = "OlmoEarth-v1-Tiny"
38+
OLMOEARTH_V1_BASE = "OlmoEarth-v1-Base"
39+
40+
41+
def load_model(model_id: ModelID | str, load_weights: bool = True) -> torch.nn.Module:
    """Initialize the specified model and optionally load its pre-trained weights.

    The config (and, if requested, the weights) are downloaded from the
    ``allenai/<model_id>`` repository on Hugging Face.

    Args:
        model_id: the model ID to load, either a ModelID member or its string value.
        load_weights: whether to load the weights. Set false to skip downloading the
            weights from Hugging Face and leave them randomly initialized. Note that
            the config.json will still be downloaded from Hugging Face.

    Returns:
        the model built from the downloaded config, with the downloaded state dict
        loaded into it unless load_weights is false.

    Raises:
        ValueError: if model_id is not a valid ModelID value.
    """
    # Accept plain strings as well as enum members; unknown IDs fail here with a
    # clear ValueError instead of an AttributeError on `.value` further down.
    model_id = ModelID(model_id)

    # We ignore bandit warnings here since we are just downloading config and weights,
    # not any code.
    repo_id = f"allenai/{model_id.value}"
    config_fname = hf_hub_download(repo_id=repo_id, filename="config.json")  # nosec
    # JSON is UTF-8 by specification; do not depend on the platform default encoding.
    with open(config_fname, encoding="utf-8") as f:
        config_dict = json.load(f)
    model_config = Config.from_dict(config_dict["model"])

    model: torch.nn.Module = model_config.build()

    if not load_weights:
        return model

    pth_fname = hf_hub_download(repo_id=repo_id, filename="weights.pth")  # nosec
    # The file is a plain state dict, so restrict unpickling to tensors and
    # primitive containers rather than allowing arbitrary pickled objects to run.
    state_dict = torch.load(pth_fname, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    return model

0 commit comments

Comments
 (0)