To be run once (around Nov 2024), likely not needed after that. See also #385, #392.
"""
import logging
+import json
+import os
+import re
import string
from http import HTTPStatus
+from pathlib import Path

from sqlalchemy import select
from database.session import DbSession, EngineSingleton
...
import database.setup

import requests
+
+
+def fetch_huggingface_metadata() -> list[dict]:
+    next_url = "https://huggingface.co/api/datasets"
+    datasets = []
+    while next_url:
+        logging.info(f"Fetched {len(datasets)} datasets so far.")
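+        # Authenticate when a HUGGINGFACE_TOKEN is configured; anonymous
+        # requests also work but may be rate limited more aggressively.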
+        if token := os.environ.get("HUGGINGFACE_TOKEN"):
+            headers = {"Authorization": f"Bearer {token}"}
+        else:
+            headers = {}
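+        # Fetch one page of up to 1000 entries; "full": "False" requests the
+        # lightweight listing without extended per-dataset metadata.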
+        response = requests.get(
+            next_url,
+            params={"limit": 1000, "full": "False"},
+            headers=headers,
+            timeout=20,
+        )
+        if response.status_code != HTTPStatus.OK:
+            logging.warning(f"Stopping iteration: {response.status_code}, {response.json()}")
+            break
+
+        datasets.extend(response.json())
+
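+        # The API sends the next page URL in the "Link" response header,
+        # wrapped in angle brackets ('<url>; rel="next"'); stop when absent.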
+        next_info = response.headers.get("Link", "")
+        if next_url_match := re.search(r"<([^>]+)>", next_info):
+            next_url = next_url_match.group(1)
+        else:
+            next_url = None
+    return datasets
+
+
+def load_id_map() -> dict[str, str]:
+    HF_DATA_FILE = Path(__file__).parent / "hf_metadata.json"
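+    # Cache the full listing on disk so reruns of this one-off script do not
+    # have to page through the entire HF API again.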
+    if HF_DATA_FILE.exists():
+        logging.info(f"Loading HF data from {HF_DATA_FILE}.")
+        with open(HF_DATA_FILE, "r") as fh:
+            hf_data = json.load(fh)
+    else:
+        logging.info("Fetching HF data from Hugging Face.")
+        hf_data = fetch_huggingface_metadata()
+        with open(HF_DATA_FILE, "w") as fh:
+            json.dump(hf_data, fh)
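+    # Map each repo name ("id", which can change) to the persistent "_id".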
+    id_map = {data["id"]: data["_id"] for data in hf_data}
+    return id_map


def main():
+    logging.basicConfig(level=logging.INFO)
    AIoDConcept.metadata.create_all(EngineSingleton().engine, checkfirst=True)
+    id_map = load_id_map()
+
    with DbSession() as session:
        datasets_query = select(Dataset).where(Dataset.platform == PlatformName.huggingface)
        datasets = session.scalars(datasets_query).all()

+    logging.info(f"Found {len(datasets)} huggingface datasets.")
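+    # New-style identifiers are the persistent "_id" values from the HF API,
+    # which are hex strings; old-style repo names contain other characters.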
+    def is_old_style_identifier(identifier: str) -> bool:
+        return any(char not in string.hexdigits for char in identifier)
+
+    datasets = [
+        dataset
+        for dataset in datasets
+        if is_old_style_identifier(dataset.platform_resource_identifier)
+    ]
+    logging.info(f"Found {len(datasets)} huggingface datasets that need an update.")
+
+    with DbSession() as session:
        for dataset in datasets:
-            if all(c in string.hexdigits for c in dataset.platform_resource_identifier):
-                continue # entry already updated to use new-style id
-
-            response = requests.get(
-                f"https://huggingface.co/api/datasets/{dataset.name}",
-                params={"full": "False"},
-                headers={},
-                timeout=10,
-            )
-            if response.status_code != HTTPStatus.OK:
-                logging.warning(f"Dataset {dataset.name} could not be retrieved.")
-                continue
-
-            dataset_json = response.json()
-            if dataset.platform_resource_identifier != dataset_json["id"]:
-                logging.info(
-                    f"Dataset {dataset.platform_resource_identifier} moved to {dataset_json['id']}"
-                    "Deleting the old entry. The new entry either already exists or"
-                    "will be added on a later synchronization invocation."
-                )
+            if new_id := id_map.get(dataset.platform_resource_identifier):
+                dataset.platform_resource_identifier = new_id
+                session.add(dataset)
+            else:
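+                # Identifier not in the current HF listing: drop the stale
+                # entry. If the dataset still exists under a new name, a later
+                # synchronization run will add it again.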
                session.delete(dataset)
-                continue
-
-            persistent_id = dataset_json["_id"]
-            logging.info(
-                f"Setting platform id of {dataset.platform_resource_identifier} to {persistent_id}"
-            )
-            dataset.platform_resource_identifier = persistent_id
        session.commit()
+    logging.info("Done updating entries.")


if __name__ == "__main__":