|
 # ]
 # ///
 
-import subprocess
+import asyncio
+from dataclasses import dataclass
 from functools import cache
 from operator import itemgetter
-from typing import Iterator
 
 import httpx
 import stamina
 import structlog
 from gql import Client, gql as gql_query
-from gql.transport.httpx import HTTPXTransport
-from rich.progress import Progress, TextColumn, BarColumn, MofNCompleteColumn
+from gql.transport.exceptions import TransportQueryError
+from gql.transport.httpx import HTTPXAsyncTransport
+from rich.progress import Progress, TaskID, TextColumn, BarColumn, MofNCompleteColumn
 
 logger = structlog.get_logger()
@@ -52,76 +54,185 @@ query DatasetsWithLatestSnapshots($count: Int, $after: String) {
 """)
 
 
-@cache
-def get_client(url: str) -> Client:
-    return Client(transport=HTTPXTransport(url=url))
+@dataclass
+class Dataset:
+    """Latest-snapshot metadata for a single OpenNeuro dataset."""
+
+    id: str
+    tag: str
+    created: str
+    hexsha: str
 
 
 @stamina.retry(on=httpx.HTTPError)
-def get_page(url: str, count: int, after: str | None) -> dict:
-    return get_client(url).execute(
+async def get_page(client: Client, count: int, after: str | None) -> dict:
+    return await client.execute_async(
         QUERY, variable_values={"count": count, "after": after}
     )
 
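
`stamina.retry` wraps async callables transparently: the decorated coroutine is retried with exponential backoff whenever the transport raises `httpx.HTTPError`. A minimal sketch of the same pattern in isolation (the URL and the `attempts` cap are illustrative, not taken from this script):

```python
import asyncio

import httpx
import stamina


@stamina.retry(on=httpx.HTTPError, attempts=3)
async def fetch_json(url: str) -> dict:
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        # 4xx/5xx raises httpx.HTTPStatusError, an httpx.HTTPError
        # subclass, which triggers the retry.
        response.raise_for_status()
        return response.json()


# asyncio.run(fetch_json("https://example.com/api"))  # illustrative URL
```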
 
-def get_dataset_count(url: str) -> int:
-    response = get_page(url, 0, None)
+async def get_dataset_count(client: Client) -> int:
+    response = await get_page(client, 0, None)
     return response["datasets"]["pageInfo"]["count"]
 
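
Requesting a page of zero items is a cheap way to read the total: the server still fills in `pageInfo` without serializing any edges. Roughly, assuming the connection shape used by the query above (the count value is a placeholder):

```python
# Illustrative response shape for get_page(client, 0, None)
response = {
    "datasets": {
        "edges": [],
        "pageInfo": {"count": 1234, "hasNextPage": True, "endCursor": None},
    }
}
assert response["datasets"]["pageInfo"]["count"] == 1234
```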
 
 
-def dataset_iterator(url: str) -> Iterator[tuple[str, str, str, str]]:
+async def dataset_producer(
+    client: Client,
+    queue: asyncio.Queue[Dataset | None],
+    progress: Progress,
+    task_id: TaskID,
+) -> None:
+    """Producer that fetches datasets from the GraphQL API and enqueues them."""
     page_info = {"hasNextPage": True, "endCursor": None}
 
-    while page_info["hasNextPage"]:
-        result = get_page(url, 100, page_info["endCursor"])
-
-        edges, page_info = itemgetter("edges", "pageInfo")(result["datasets"])
-
-        for edge in edges:
-            dataset_id, latest_snapshot = itemgetter("id", "latestSnapshot")(
-                edge["node"]
-            )
-            yield (dataset_id, *itemgetter("tag", "created", "hexsha")(latest_snapshot))
+    try:
+        while page_info["hasNextPage"]:
+            try:
+                result = await get_page(client, 100, page_info["endCursor"])
+            except TransportQueryError as exc:
+                logger.error("GraphQL query error", error=str(exc))
+                if exc.data is None:
+                    # No partial data to fall back on; re-raise rather than
+                    # read an unbound ``result`` below.
+                    raise
+                result = exc.data
+
+            edges, page_info = itemgetter("edges", "pageInfo")(result["datasets"])
+
+            for edge in edges:
+                if edge is None:
+                    # Partial results may contain null edges; skip them
+                    continue
+                dataset_id, latest_snapshot = itemgetter("id", "latestSnapshot")(
+                    edge["node"]
+                )
+                dataset = Dataset(
+                    id=dataset_id,
+                    tag=latest_snapshot["tag"],
+                    created=latest_snapshot["created"],
+                    hexsha=latest_snapshot["hexsha"],
+                )
+                await queue.put(dataset)
+                progress.update(task_id, advance=1, dataset=dataset_id)
+    finally:
+        # Always signal end-of-stream so the consumer can finish
+        await queue.put(None)
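
Putting the `None` sentinel in a `finally` block guarantees the consumer is released even if the producer fails mid-pagination; otherwise the consumer's `queue.get()` would wait forever. The handshake in miniature (names here are illustrative):

```python
import asyncio


async def producer(queue: asyncio.Queue[int | None]) -> None:
    try:
        for item in range(3):
            await queue.put(item)
    finally:
        await queue.put(None)  # sentinel: end of stream


async def consumer(queue: asyncio.Queue[int | None]) -> None:
    while (item := await queue.get()) is not None:
        print("got", item)


async def demo() -> None:
    queue: asyncio.Queue[int | None] = asyncio.Queue(maxsize=2)
    await asyncio.gather(producer(queue), consumer(queue))


asyncio.run(demo())
```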
+
+
+async def check_remote(dataset: Dataset) -> bool | None:
+    """Check whether the git remote has the expected tag and commit hash."""
+    log = logger.bind(dataset=dataset.id, tag=dataset.tag)
+    repo = f"https://github.com/OpenNeuroDatasets/{dataset.id}.git"
+
+    proc = await asyncio.create_subprocess_exec(
+        "git",
+        "ls-remote",
+        "--exit-code",
+        repo,
+        dataset.tag,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.PIPE,
+    )
 
+    stdout, stderr = await proc.communicate()
 
-def check_remote(dataset_id: str, tag: str, hexsha: str) -> bool | None:
-    log = logger.bind(dataset=dataset_id, tag=tag)
-    repo = f"https://github.com/OpenNeuroDatasets/{dataset_id}.git"
-    result = subprocess.run(
-        ["git", "ls-remote", "--exit-code", repo, tag],
-        capture_output=True,
-    )
-    if result.returncode:
-        if "Repository not found" in result.stderr.decode():
+    if proc.returncode:
+        stderr_text = stderr.decode()
+        if "Repository not found" in stderr_text:
             log.error("Missing repository")
             return None
         log.error("Missing latest tag")
         return False
 
-    shasum, ref = result.stdout.decode("utf-8").strip().split()
+    stdout_text = stdout.decode("utf-8").strip()
+    if not stdout_text:
+        log.error("Empty response from git ls-remote")
+        return False
+
+    # An annotated tag can yield two lines (the tag object plus its peeled
+    # ``^{}`` commit); compare against the first entry.
+    shasum, ref = stdout_text.splitlines()[0].split()
 
-    if shasum != hexsha:
-        log.warning(f"mismatch: {shasum[:7]}({ref[10:]}) != {hexsha[:7]}")
+    if shasum != dataset.hexsha:
+        log.warning(f"mismatch: {shasum[:7]}({ref[10:]}) != {dataset.hexsha[:7]}")
         return False
 
-    return ref == f"refs/tags/{tag}"
+    return ref == f"refs/tags/{dataset.tag}"
 
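
`git ls-remote --exit-code <repo> <pattern>` exits with status 2 when no ref matches, so a nonzero return code covers both a missing tag and a missing repository, and only the stderr text tells the two apart. Parsing the success case in isolation (repository and tag below are hypothetical):

```python
import subprocess

result = subprocess.run(
    ["git", "ls-remote", "--exit-code",
     "https://github.com/example/repo.git", "v1.0.0"],  # hypothetical
    capture_output=True,
    text=True,
)
if result.returncode == 0:
    # Each line is "<sha>\t<ref>"; an annotated tag may add a peeled
    # "<ref>^{}" line pointing at the underlying commit.
    for line in result.stdout.splitlines():
        sha, ref = line.split()
        print(sha[:7], ref)
```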
 
-if __name__ == "__main__":
-    count = get_dataset_count(ENDPOINT)
+async def dataset_consumer(
+    queue: asyncio.Queue[Dataset | None],
+    progress: Progress,
+    task_id: TaskID,
+    semaphore: asyncio.Semaphore,
+    results: list[bool | None],
+) -> None:
+    """Consumer that checks git remotes for datasets from the queue."""
 
-    retcode = 0
+    async def check_single_dataset(dataset: Dataset) -> bool | None:
+        async with semaphore:
+            result = await check_remote(dataset)
+            progress.update(task_id, advance=1, dataset=dataset.id)
+            return result
+
+    tasks: list[asyncio.Task[bool | None]] = []
+
+    while True:
+        dataset = await queue.get()
+        if dataset is None:
+            # Producer is done; stop accepting work
+            break
+
+        # Spawn a concurrent check for this dataset
+        tasks.append(asyncio.create_task(check_single_dataset(dataset)))
+
+    # Wait for the outstanding checks and record their outcomes
+    for result in await asyncio.gather(*tasks, return_exceptions=True):
+        if isinstance(result, BaseException):
+            logger.error("Error checking dataset", exc_info=result)
+            results.append(False)
+        else:
+            results.append(result)
+
+
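
The semaphore caps how many `git ls-remote` subprocesses run at once, while the bounded queue in `main` applies backpressure to the producer. The limiting pattern in miniature (the limit and workload are illustrative):

```python
import asyncio


async def limited_work(semaphore: asyncio.Semaphore, n: int) -> int:
    async with semaphore:
        # At most three of these bodies run concurrently.
        await asyncio.sleep(0.1)
        return n * n


async def demo() -> None:
    semaphore = asyncio.Semaphore(3)
    results = await asyncio.gather(*(limited_work(semaphore, n) for n in range(10)))
    print(results)


asyncio.run(demo())
```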
+async def main() -> int:
+    client = Client(transport=HTTPXAsyncTransport(url=ENDPOINT))
+    count = await get_dataset_count(client)
+
+    # Bounded queue passes datasets from producer to consumer with backpressure
+    queue: asyncio.Queue[Dataset | None] = asyncio.Queue(maxsize=200)
+
+    # Semaphore limits concurrent git subprocesses; adjust to taste
+    git_semaphore = asyncio.Semaphore(20)
+
+    # Collected check outcomes
+    results: list[bool | None] = []
 
     with Progress(
-        TextColumn("[progress.description]{task.description} {task.fields[dataset]:8s}"),
+        TextColumn(
+            "[progress.description]{task.description} {task.fields[dataset]:8s}"
+        ),
         BarColumn(),
         MofNCompleteColumn(),
     ) as progress:
-        task = progress.add_task("Checking", total=count, dataset="...")
+        # One progress bar per pipeline stage
+        fetch_task = progress.add_task("Fetching", total=count, dataset="...")
+        check_task = progress.add_task("Checking", total=count, dataset="...")
+
+        # Run the producer and consumer concurrently and wait for both
+        producer_task = asyncio.create_task(
+            dataset_producer(client, queue, progress, fetch_task)
+        )
+        consumer_task = asyncio.create_task(
+            dataset_consumer(queue, progress, check_task, git_semaphore, results)
+        )
+        await asyncio.gather(producer_task, consumer_task)
 
-        for dataset_id, tag, created, hexsha in dataset_iterator(ENDPOINT):
-            progress.update(task, advance=1, dataset=dataset_id)
+    # Only False marks a failure; None (missing repository) is logged, not fatal
+    return 1 if any(result is False for result in results) else 0
 
-            retcode |= not check_remote(dataset_id, tag, hexsha)
 
+if __name__ == "__main__":
+    retcode = asyncio.run(main())
     raise SystemExit(retcode)
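
Because the file's header closes a PEP 723 inline-metadata block (the `# ///` marker at the top), a metadata-aware runner can resolve the script's dependencies on the fly, e.g. `uv run check_tags.py` (the filename here is assumed for illustration).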