Skip to content

Commit 7b26bfe

Browse files
authored
feat: server side prefetch (#41)
* feat: server-side prefetch (wip) * fix: pass test * fix: prefetcher thread cleanup * feat: use 202 * fix: log
1 parent 291da02 commit 7b26bfe

22 files changed

+487
-1592
lines changed

Makefile

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,7 @@ all: openapi-python-client ui
55
openapi-python-client:
66
openapi-python-client generate --url http://localhost:8000/openapi.json --config openapi.config.yaml --overwrite --meta poetry
77
sed -i -e 's/response_200 = File(payload=BytesIO(response.json()))/response_200 = File(payload=BytesIO(response.content))/g' openapi-lavender-data-rest/openapi_lavender_data_rest/api/iterations/get_next_iterations_iteration_id_next_get.py
8-
sed -i -e 's/response_200 = File(payload=BytesIO(response.json()))/response_200 = File(payload=BytesIO(response.content))/g' openapi-lavender-data-rest/openapi_lavender_data_rest/api/iterations/get_submitted_result_iterations_iteration_id_next_cache_key_get.py
98
rm openapi-lavender-data-rest/openapi_lavender_data_rest/api/iterations/get_next_iterations_iteration_id_next_get.py-e 2> /dev/null
10-
rm openapi-lavender-data-rest/openapi_lavender_data_rest/api/iterations/get_submitted_result_iterations_iteration_id_next_cache_key_get.py-e 2> /dev/null
119

1210
ui:
1311
cd ./ui && pnpm build && cd ../

lavender_data/client/api.py

Lines changed: 23 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,6 @@
2525
from openapi_lavender_data_rest.api.iterations import (
2626
create_iteration_iterations_post,
2727
get_next_iterations_iteration_id_next_get,
28-
submit_next_iterations_iteration_id_next_post,
29-
get_submitted_result_iterations_iteration_id_next_cache_key_get,
3028
get_iteration_iterations_iteration_id_get,
3129
get_iterations_iterations_get,
3230
complete_index_iterations_iteration_id_complete_index_post,
@@ -304,10 +302,15 @@ def create_iteration(
304302
categorizer: Optional[IterationCategorizer] = None,
305303
collater: Optional[IterationCollater] = None,
306304
preprocessors: Optional[list[IterationPreprocessor]] = None,
305+
max_retry_count: int = 0,
307306
rank: int = 0,
308307
world_size: Optional[int] = None,
309308
wait_participant_threshold: Optional[float] = None,
310-
cluster_sync: bool = False,
309+
no_cache: Optional[bool] = None,
310+
num_workers: Optional[int] = None,
311+
prefetch_factor: Optional[int] = None,
312+
in_order: Optional[bool] = None,
313+
cluster_sync: Optional[bool] = None,
311314
):
312315
with self._get_client() as client:
313316
response = create_iteration_iterations_post.sync_detailed(
@@ -324,9 +327,14 @@ def create_iteration(
324327
collater=collater,
325328
preprocessors=preprocessors,
326329
replication_pg=replication_pg,
330+
max_retry_count=max_retry_count,
327331
rank=rank,
328332
world_size=world_size,
329333
wait_participant_threshold=wait_participant_threshold,
334+
no_cache=no_cache,
335+
num_workers=num_workers,
336+
prefetch_factor=prefetch_factor,
337+
in_order=in_order,
330338
cluster_sync=cluster_sync,
331339
),
332340
)
@@ -366,61 +374,22 @@ def get_next_item(
366374
self,
367375
iteration_id: str,
368376
rank: int = 0,
369-
no_cache: bool = False,
370-
max_retry_count: int = 0,
371377
client: Optional[Client] = None,
372378
):
373379
with self._get_client() if client is None else nullcontext() as _client:
374380
response = get_next_iterations_iteration_id_next_get.sync_detailed(
375381
client=client or _client,
376382
iteration_id=iteration_id,
377383
rank=rank,
378-
no_cache=no_cache,
379-
max_retry_count=max_retry_count,
380384
)
381-
382385
try:
383386
current = int(response.headers.get("X-Lavender-Data-Sample-Current"))
384387
except TypeError:
385388
current = None
386-
return self._check_response(response).payload.read(), current
387-
388-
def submit_next_item(
389-
self,
390-
iteration_id: str,
391-
rank: int = 0,
392-
no_cache: bool = False,
393-
max_retry_count: int = 0,
394-
client: Optional[Client] = None,
395-
):
396-
with self._get_client() if client is None else nullcontext() as _client:
397-
response = submit_next_iterations_iteration_id_next_post.sync_detailed(
398-
client=client or _client,
399-
iteration_id=iteration_id,
400-
rank=rank,
401-
no_cache=no_cache,
402-
max_retry_count=max_retry_count,
403-
)
404-
return self._check_response(response)
405389

406-
def get_submitted_result(
407-
self,
408-
iteration_id: str,
409-
cache_key: str,
410-
client: Optional[Client] = None,
411-
):
412-
with self._get_client() if client is None else nullcontext() as _client:
413-
response = get_submitted_result_iterations_iteration_id_next_cache_key_get.sync_detailed(
414-
client=client or _client,
415-
iteration_id=iteration_id,
416-
cache_key=cache_key,
417-
)
418390
if response.status_code == 202:
419391
raise LavenderDataApiError(response.content.decode("utf-8"))
420-
try:
421-
current = int(response.headers.get("X-Lavender-Data-Sample-Current"))
422-
except TypeError:
423-
current = None
392+
424393
return self._check_response(response).payload.read(), current
425394

426395
def complete_index(self, iteration_id: str, index: int):
@@ -602,10 +571,15 @@ def create_iteration(
602571
categorizer: Optional[IterationCategorizer] = None,
603572
collater: Optional[IterationCollater] = None,
604573
preprocessors: Optional[list[IterationPreprocessor]] = None,
574+
max_retry_count: int = 0,
605575
rank: int = 0,
606576
world_size: Optional[int] = None,
607577
wait_participant_threshold: Optional[float] = None,
608-
cluster_sync: bool = False,
578+
no_cache: Optional[bool] = None,
579+
num_workers: Optional[int] = None,
580+
prefetch_factor: Optional[int] = None,
581+
in_order: Optional[bool] = None,
582+
cluster_sync: Optional[bool] = None,
609583
):
610584
return _client_instance.create_iteration(
611585
dataset_id=dataset_id,
@@ -619,9 +593,14 @@ def create_iteration(
619593
categorizer=categorizer,
620594
collater=collater,
621595
preprocessors=preprocessors,
596+
max_retry_count=max_retry_count,
622597
rank=rank,
623598
world_size=world_size,
624599
wait_participant_threshold=wait_participant_threshold,
600+
no_cache=no_cache,
601+
num_workers=num_workers,
602+
prefetch_factor=prefetch_factor,
603+
in_order=in_order,
625604
cluster_sync=cluster_sync,
626605
)
627606

@@ -644,36 +623,10 @@ def get_iteration(iteration_id: str):
644623
def get_next_item(
645624
iteration_id: str,
646625
rank: int = 0,
647-
no_cache: bool = False,
648-
max_retry_count: int = 0,
649626
):
650627
return _client_instance.get_next_item(
651628
iteration_id=iteration_id,
652629
rank=rank,
653-
no_cache=no_cache,
654-
max_retry_count=max_retry_count,
655-
)
656-
657-
658-
@ensure_client()
659-
def submit_next_item(
660-
iteration_id: str,
661-
rank: int = 0,
662-
no_cache: bool = False,
663-
max_retry_count: int = 0,
664-
):
665-
return _client_instance.submit_next_item(
666-
iteration_id=iteration_id,
667-
rank=rank,
668-
no_cache=no_cache,
669-
max_retry_count=max_retry_count,
670-
)
671-
672-
673-
@ensure_client()
674-
def get_submitted_result(iteration_id: str, cache_key: str):
675-
return _client_instance.get_submitted_result(
676-
iteration_id=iteration_id, cache_key=cache_key
677630
)
678631

679632

lavender_data/client/cli/__init__.py

Lines changed: 0 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,6 @@
1515
get_iterations,
1616
get_iteration,
1717
get_next_item,
18-
submit_next_item,
19-
get_submitted_result,
2018
complete_index,
2119
pushback,
2220
get_progress,
@@ -92,22 +90,6 @@ def __init__(self, parent_parser: Optional[argparse.ArgumentParser] = None):
9290
self.iterations_next.add_argument("--no-cache", action="store_true")
9391
self.iterations_next.add_argument("--max-retry-count", type=int, default=0)
9492

95-
self.iterations_submit_next_item = self.iterations_command_parser.add_parser(
96-
"async-next"
97-
)
98-
self.iterations_submit_next_item.add_argument("id", type=str)
99-
self.iterations_submit_next_item.add_argument("--rank", type=int, default=0)
100-
self.iterations_submit_next_item.add_argument("--no-cache", action="store_true")
101-
self.iterations_submit_next_item.add_argument(
102-
"--max-retry-count", type=int, default=0
103-
)
104-
105-
self.iterations_async_result = self.iterations_command_parser.add_parser(
106-
"async-result"
107-
)
108-
self.iterations_async_result.add_argument("id", type=str)
109-
self.iterations_async_result.add_argument("key", type=str)
110-
11193
self.iterations_complete_index = self.iterations_command_parser.add_parser(
11294
"complete-index"
11395
)
@@ -194,21 +176,6 @@ def main(self, args: Optional[argparse.Namespace] = None):
194176
args.api_key,
195177
args.id,
196178
args.rank,
197-
args.no_cache,
198-
args.max_retry_count,
199-
)
200-
elif args.command == "async-next":
201-
result = submit_next_item(
202-
args.api_url,
203-
args.api_key,
204-
args.id,
205-
args.rank,
206-
args.no_cache,
207-
args.max_retry_count,
208-
)
209-
elif args.command == "async-result":
210-
result = get_submitted_result(
211-
args.api_url, args.api_key, args.id, args.key
212179
)
213180
elif args.command == "complete-index":
214181
result = complete_index(args.api_url, args.api_key, args.id, args.index)

lavender_data/client/cli/api_call.py

Lines changed: 0 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -75,39 +75,11 @@ def get_next_item(
7575
api_key: str,
7676
iteration_id: str,
7777
rank: int,
78-
no_cache: bool,
79-
max_retry_count: int,
8078
):
8179
return deserialize_sample(
8280
_api(api_url=api_url, api_key=api_key).get_next_item(
8381
iteration_id=iteration_id,
8482
rank=rank,
85-
no_cache=no_cache,
86-
max_retry_count=max_retry_count,
87-
)[0]
88-
)
89-
90-
91-
def submit_next_item(
92-
api_url: str,
93-
api_key: str,
94-
iteration_id: str,
95-
rank: int,
96-
no_cache: bool,
97-
max_retry_count: int,
98-
):
99-
return _api(api_url=api_url, api_key=api_key).submit_next_item(
100-
iteration_id=iteration_id,
101-
rank=rank,
102-
no_cache=no_cache,
103-
max_retry_count=max_retry_count,
104-
)
105-
106-
107-
def get_submitted_result(api_url: str, api_key: str, iteration_id: str, cache_key: str):
108-
return deserialize_sample(
109-
_api(api_url=api_url, api_key=api_key).get_submitted_result(
110-
iteration_id=iteration_id, cache_key=cache_key
11183
)[0]
11284
)
11385

0 commit comments

Comments
 (0)