Skip to content

Commit d53cbc3

Browse files
committed
update repo organization
1 parent a403be8 commit d53cbc3

1 file changed

Lines changed: 4 additions & 7 deletions

File tree

olmoearth_pretrain/model_loader.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ def load_model(model_id: ModelID, load_weights: bool = True) -> torch.nn.Module:
5151
"""
5252
# We ignore bandit warnings here since we are just downloading config and weights,
5353
# not any code.
54-
config_fname = hf_hub_download(
55-
repo_id="allenai/olmoearth_pretrain", filename=f"{model_id.value}-config.json"
56-
) # nosec
54+
repo_id = f"allenai/{model_id.value}"
55+
config_fname = hf_hub_download(repo_id=repo_id, filename="config.json") # nosec
5756
with open(config_fname) as f:
5857
config_dict = json.load(f)
5958
model_config = Config.from_dict(config_dict["model"])
@@ -63,9 +62,7 @@ def load_model(model_id: ModelID, load_weights: bool = True) -> torch.nn.Module:
6362
if not load_weights:
6463
return model
6564

66-
pth_fname = hf_hub_download(
67-
repo_id="allenai/olmoearth_pretrain", filename=f"{model_id.value}.pth"
68-
) # nosec
69-
state_dict = torch.load(pth_fname)
65+
pth_fname = hf_hub_download(repo_id=repo_id, filename="weights.pth") # nosec
66+
state_dict = torch.load(pth_fname, map_location="cpu")
7067
model.load_state_dict(state_dict)
7168
return model

0 commit comments

Comments
 (0)