Skip to content

Commit 30c5499

Browse files
committed
rf: Update check-github-sync to be async, skip bad results
1 parent ae223fd commit 30c5499

File tree

1 file changed

+153
-42
lines changed

1 file changed

+153
-42
lines changed

scripts/check-github-sync

Lines changed: 153 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99
# ]
1010
# ///
1111

12-
import subprocess
12+
import asyncio
13+
from dataclasses import dataclass
1314
from functools import cache
1415
from operator import itemgetter
15-
from typing import Iterator
16+
from typing import NamedTuple
1617

1718
import httpx
1819
import stamina
1920
import structlog
21+
import gql
2022
from gql import Client, gql as gql_query
21-
from gql.transport.httpx import HTTPXTransport
23+
from gql.transport.httpx import HTTPXAsyncTransport
2224
from rich.progress import Progress, TextColumn, BarColumn, MofNCompleteColumn
2325

2426
logger = structlog.get_logger()
@@ -52,76 +54,185 @@ query DatasetsWithLatestSnapshots($count: Int, $after: String) {
5254
""")
5355

5456

55-
@cache
56-
def get_client(url: str) -> Client:
57-
return Client(transport=HTTPXTransport(url=url))
@dataclass
class Dataset:
    """Latest-snapshot metadata for one OpenNeuro dataset.

    Carries just the fields needed to verify the GitHub mirror:
    the dataset accession, the newest snapshot tag, its creation
    timestamp, and the commit hash that tag should point at.
    """

    # Dataset accession number (e.g. "ds000001").
    id: str
    # Tag name of the latest snapshot.
    tag: str
    # Snapshot creation timestamp (as returned by the API).
    created: str
    # Git commit hash the snapshot tag should resolve to.
    hexsha: str
5863

5964

6065
@stamina.retry(on=httpx.HTTPError)
async def get_page(client: Client, count: int, after: str | None) -> dict:
    """Fetch one page of datasets from the GraphQL endpoint.

    Transient HTTP failures are retried via stamina. ``after`` is the
    pagination cursor (``None`` for the first page); ``count`` is the
    page size.
    """
    variables = {"count": count, "after": after}
    return await client.execute_async(QUERY, variable_values=variables)
6570

6671

async def get_dataset_count(client: Client) -> int:
    """Return the total number of datasets reported by the API.

    Requests a zero-item page, which still includes the overall count
    in ``pageInfo``.
    """
    page = await get_page(client, 0, None)
    return page["datasets"]["pageInfo"]["count"]
7075

7176

async def dataset_producer(
    client: Client, queue: asyncio.Queue[Dataset | None], progress: Progress, task_id
) -> None:
    """Producer that fetches datasets from the GraphQL API and puts them in the queue.

    Pages through the datasets connection 100 at a time. A partial GraphQL
    failure (``TransportQueryError`` carrying partial data) is logged and the
    partial page is used; null edges within a page are skipped. A ``None``
    sentinel is always enqueued on exit so the consumer can shut down.
    """
    page_info = {"hasNextPage": True, "endCursor": None}

    try:
        while page_info["hasNextPage"]:
            try:
                result = await get_page(client, 100, page_info["endCursor"])
            except gql.transport.exceptions.TransportQueryError as e:
                logger.error("GraphQL query error")
                if e.data is None:
                    # BUG FIX: previously `result` was left unbound (first
                    # iteration) or stale (later iterations), which could
                    # raise NameError or re-fetch the same cursor forever.
                    # With no partial data there is nothing to salvage.
                    break
                result = e.data

            edges, page_info = itemgetter("edges", "pageInfo")(result["datasets"])

            for edge in edges:
                # Errored nodes arrive as null edges; skip them.
                if edge is None:
                    continue
                dataset_id, latest_snapshot = itemgetter("id", "latestSnapshot")(
                    edge["node"]
                )
                dataset = Dataset(
                    id=dataset_id,
                    tag=latest_snapshot["tag"],
                    created=latest_snapshot["created"],
                    hexsha=latest_snapshot["hexsha"],
                )
                await queue.put(dataset)
                progress.update(task_id, advance=1, dataset=dataset_id)

    finally:
        # Signal that we're done producing
        await queue.put(None)
112+
113+
114+
async def check_remote(dataset: Dataset) -> bool | None:
    """Check if the git remote has the expected tag and commit hash.

    Returns True when the mirror's tag resolves to ``dataset.hexsha``,
    False when the tag is missing, empty, or points elsewhere, and None
    when the repository itself does not exist (callers treat None as
    "skip", not failure).
    """
    log = logger.bind(dataset=dataset.id, tag=dataset.tag)
    repo = f"https://github.com/OpenNeuroDatasets/{dataset.id}.git"

    proc = await asyncio.create_subprocess_exec(
        "git",
        "ls-remote",
        "--exit-code",
        repo,
        dataset.tag,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    stdout, stderr = await proc.communicate()

    if proc.returncode:
        stderr_text = stderr.decode()
        if "Repository not found" in stderr_text:
            log.error("Missing repository")
            return None
        log.error("Missing latest tag")
        return False

    stdout_text = stdout.decode("utf-8").strip()
    if not stdout_text:
        log.error("Empty response from git ls-remote")
        return False

    # BUG FIX: for an annotated tag, `git ls-remote` prints TWO lines —
    # the tag ref and a peeled "refs/tags/<tag>^{}" line pointing at the
    # commit — so a bare two-value unpack of split() raised ValueError.
    # Parse every line and prefer the peeled (commit) sha, since the
    # snapshot's hexsha is a commit hash.
    refs: dict[str, str] = {}
    for line in stdout_text.splitlines():
        shasum, ref = line.split()
        refs[ref] = shasum

    tag_ref = f"refs/tags/{dataset.tag}"
    shasum = refs.get(f"{tag_ref}^{{}}", refs.get(tag_ref))
    if shasum is None:
        # --exit-code plus a returned line should guarantee a match,
        # but be defensive about unexpected ref names.
        log.error("Missing latest tag")
        return False

    if shasum != dataset.hexsha:
        log.warning(f"mismatch: {shasum[:7]}({dataset.tag}) != {dataset.hexsha[:7]}")
        return False

    return True
108151

109152

110-
if __name__ == "__main__":
111-
count = get_dataset_count(ENDPOINT)
153+
async def dataset_consumer(
    queue: asyncio.Queue[Dataset | None],
    progress: Progress,
    task_id,
    semaphore: asyncio.Semaphore,
    results: list[bool | None],
) -> None:
    """Consumer that checks git remotes for datasets from the queue.

    Drains the queue until the ``None`` sentinel arrives, spawning one
    task per dataset (concurrency bounded by ``semaphore``), then gathers
    every outcome into ``results``. Exceptions are logged and recorded
    as False.
    """

    async def _bounded_check(dataset: Dataset) -> bool | None:
        # The semaphore caps the number of simultaneous git subprocesses.
        async with semaphore:
            outcome = await check_remote(dataset)
            progress.update(task_id, advance=1, dataset=dataset.id)
            return outcome

    pending: list[asyncio.Task] = []

    # Walrus-driven drain: a None item is the producer's "done" signal.
    while (dataset := await queue.get()) is not None:
        pending.append(asyncio.create_task(_bounded_check(dataset)))

    if not pending:
        return

    # Collect every outcome; exceptions are surfaced as values so one
    # bad dataset cannot abort the rest.
    for outcome in await asyncio.gather(*pending, return_exceptions=True):
        if isinstance(outcome, Exception):
            logger.error("Error checking dataset", exc_info=outcome)
            results.append(False)
        else:
            results.append(outcome)
189+
190+
191+
async def main() -> int:
    """Run the fetch/check pipeline; return 1 if any mirror check failed."""
    client = Client(transport=HTTPXAsyncTransport(url=ENDPOINT))
    total = await get_dataset_count(client)

    # Bounded queue applies backpressure to the producer.
    queue: asyncio.Queue[Dataset | None] = asyncio.Queue(maxsize=200)

    # Cap concurrent git operations.
    git_semaphore = asyncio.Semaphore(20)  # Adjust based on your needs

    # Outcomes appended by the consumer.
    results: list[bool | None] = []

    with Progress(
        TextColumn(
            "[progress.description]{task.description} {task.fields[dataset]:8s}"
        ),
        BarColumn(),
        MofNCompleteColumn(),
    ) as progress:
        fetch_task = progress.add_task("Fetching", total=total, dataset="...")
        check_task = progress.add_task("Checking", total=total, dataset="...")

        # Run producer and consumer concurrently until both finish.
        await asyncio.gather(
            dataset_producer(client, queue, progress, fetch_task),
            dataset_consumer(queue, progress, check_task, git_semaphore, results),
        )

        # Only an explicit False marks failure; None (missing repository)
        # is deliberately ignored.
        return 1 if any(outcome is False for outcome in results) else 0
124234

125-
retcode |= not check_remote(dataset_id, tag, hexsha)
126235

236+
if __name__ == "__main__":
    # Exit with main()'s return code (1 on any mirror mismatch).
    raise SystemExit(asyncio.run(main()))

0 commit comments

Comments
 (0)