
Commit 1bc917a

Improve hf migration (#398)
* Update the migration script to be more robust against failure
* Update HF script to instead use bulk fetches and updates
* Add a timeout so pre-commit passes
* Remove hard-coded token
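The last bullet refers to reading the API token from the environment rather than embedding it in the script. A minimal sketch of that part of the new fetch logic, assuming the caller exports HUGGINGFACE_TOKEN (the variable name comes from the script below; without a token the request is simply made anonymously):

    import os

    # Only attach an Authorization header when HUGGINGFACE_TOKEN is set in the
    # environment; otherwise the Hugging Face API is queried anonymously.
    token = os.environ.get("HUGGINGFACE_TOKEN")
    headers = {"Authorization": f"Bearer {token}"} if token else {}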
1 parent adf4e0c commit 1bc917a

File tree

1 file changed: +71 -27 lines changed

scripts/migrate_hf.py

+71 -27
@@ -8,8 +8,11 @@
 To be run once (around sometime Nov 2024), likely not needed after that. See also #385, 392.
 """
 import logging
+import os
 import string
 from http import HTTPStatus
+import time
+from pathlib import Path
 
 from sqlalchemy import select
 from database.session import DbSession, EngineSingleton
@@ -22,44 +25,85 @@
 import database.setup
 
 import requests
+import json
+
+import re
+from http import HTTPStatus
+
+
+def fetch_huggingface_metadata() -> list[dict]:
+    next_url = "https://huggingface.co/api/datasets"
+    datasets = []
+    while next_url:
+        logging.info(f"Counted {len(datasets)} so far.")
+        if token := os.environ.get("HUGGINGFACE_TOKEN"):
+            headers = {"Authorization": f"Bearer {token}"}
+        else:
+            headers = {}
+        response = requests.get(
+            next_url,
+            params={"limit": 1000, "full": "False"},
+            headers=headers,
+            timeout=20,
+        )
+        if response.status_code != HTTPStatus.OK:
+            logging.info("Stopping iteration", response.status_code, response.json())
+            break
+
+        datasets.extend(response.json())
+
+        next_info = response.headers.get("Link", "")
+        if next_url_match := re.search(r"<([^>]+)>", next_info):
+            next_url = next_url_match.group()[1:-1]
+        else:
+            next_url = None
+    return datasets
+
+
+def load_id_map():
+    HF_DATA_FILE = Path(__file__).parent / "hf_metadata.json"
+    if HF_DATA_FILE.exists():
+        logging.info(f"Loading HF data from {HF_DATA_FILE}.")
+        with open(HF_DATA_FILE, "r") as fh:
+            hf_data = json.load(fh)
+    else:
+        logging.info("Fetching HF data from Hugging Face.")
+        hf_data = fetch_huggingface_metadata()
+        with open(HF_DATA_FILE, "w") as fh:
+            json.dump(hf_data, fh)
+    id_map = {data["id"]: data["_id"] for data in hf_data}
+    return id_map
 
 
 def main():
+    logging.basicConfig(level=logging.INFO)
     AIoDConcept.metadata.create_all(EngineSingleton().engine, checkfirst=True)
+    id_map = load_id_map()
+
     with DbSession() as session:
         datasets_query = select(Dataset).where(Dataset.platform == PlatformName.huggingface)
         datasets = session.scalars(datasets_query).all()
 
+    logging.info(f"Found {len(datasets)} huggingface datasets.")
+    is_old_style_identifier = lambda identifier: any(
+        char not in string.hexdigits for char in identifier
+    )
+    datasets = [
+        dataset
+        for dataset in datasets
+        if is_old_style_identifier(dataset.platform_resource_identifier)
+    ]
+    logging.info(f"Found {len(datasets)} huggingface datasets that need an update.")
+
+    with DbSession() as session:
         for dataset in datasets:
-            if all(c in string.hexdigits for c in dataset.platform_resource_identifier):
-                continue  # entry already updated to use new-style id
-
-            response = requests.get(
-                f"https://huggingface.co/api/datasets/{dataset.name}",
-                params={"full": "False"},
-                headers={},
-                timeout=10,
-            )
-            if response.status_code != HTTPStatus.OK:
-                logging.warning(f"Dataset {dataset.name} could not be retrieved.")
-                continue
-
-            dataset_json = response.json()
-            if dataset.platform_resource_identifier != dataset_json["id"]:
-                logging.info(
-                    f"Dataset {dataset.platform_resource_identifier} moved to {dataset_json['id']}"
-                    "Deleting the old entry. The new entry either already exists or"
-                    "will be added on a later synchronization invocation."
-                )
+            if new_id := id_map.get(dataset.platform_resource_identifier):
+                dataset.platform_resource_identifier = new_id
+                session.add(dataset)
+            else:
                 session.delete(dataset)
-                continue
-
-            persistent_id = dataset_json["_id"]
-            logging.info(
-                f"Setting platform id of {dataset.platform_resource_identifier} to {persistent_id}"
-            )
-            dataset.platform_resource_identifier = persistent_id
         session.commit()
+    logging.info("Done updating entries.")
 
 
 if __name__ == "__main__":
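The old per-dataset GET requests are replaced by one paginated bulk fetch whose result is turned into an id -> _id lookup map, so each row only needs a dictionary lookup instead of an HTTP call. Pagination follows the Link response header; a small self-contained sketch of how the regex in fetch_huggingface_metadata pulls out the next page URL (the header value below is illustrative, not a recorded API response):

    import re

    # Illustrative Link header in the style used for cursor-based pagination.
    link_header = '<https://huggingface.co/api/datasets?cursor=abc123&limit=1000>; rel="next"'

    # Same regex as in the script: take the URL between the angle brackets.
    match = re.search(r"<([^>]+)>", link_header)
    next_url = match.group(1) if match else None
    print(next_url)  # https://huggingface.co/api/datasets?cursor=abc123&limit=1000

Rows still needing migration are selected with the is_old_style_identifier check: a new-style identifier is the hexadecimal _id assigned by Hugging Face, so any non-hex character marks an old-style, name-based identifier. A sketch with made-up identifiers:

    import string

    def is_old_style_identifier(identifier: str) -> bool:
        # Any character outside 0-9a-fA-F means this is still an old, name-based id.
        return any(char not in string.hexdigits for char in identifier)

    print(is_old_style_identifier("acme/some-dataset"))         # True: old-style name (made up)
    print(is_old_style_identifier("0123456789abcdef01234567"))  # False: new-style hex id (made up)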
