diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml deleted file mode 100644 index 1f61e62a..00000000 --- a/.github/workflows/deploy-docs.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: Deploy docs - -on: - push: - branches: [main] - paths: - - 'docs/**' - -jobs: - deploy-docs: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - - uses: aws-actions/configure-aws-credentials@v4 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ secrets.AWS_REGION }} - - - uses: pnpm/action-setup@v4 - with: - version: 9 - run_install: false - - - name: Install Node.js - uses: actions/setup-node@v4 - with: - node-version: 22 - cache: 'pnpm' - - - name: Install dependencies - run: | - pnpm i - - - name: Build docs - run: | - cd docs - pnpm build - - - name: Deploy docs - run: | - aws s3 sync ./docs/dist s3://${{ secrets.AWS_DOCS_S3_BUCKET }}/ --delete - - - name: Invalidate CloudFront - run: | - aws cloudfront create-invalidation --distribution-id ${{ secrets.AWS_DOCS_CLOUDFRONT_DISTRIBUTION_ID }} --paths '/*' diff --git a/lavender_data/client/api.py b/lavender_data/client/api.py index 6ae01eb6..f5212e2b 100644 --- a/lavender_data/client/api.py +++ b/lavender_data/client/api.py @@ -377,8 +377,7 @@ def get_next_item( no_cache=no_cache, max_retry_count=max_retry_count, ) - content = self._check_response(response).payload.read() - return deserialize_sample(content) + return self._check_response(response).payload.read() def submit_next_item( self, @@ -406,8 +405,7 @@ def get_submitted_result(self, iteration_id: str, cache_key: str): ) if response.status_code == 202: raise LavenderDataApiError(response.content.decode("utf-8")) - content = self._check_response(response).payload.read() - return deserialize_sample(content) + return self._check_response(response).payload.read() def complete_index(self, iteration_id: str, index: int): with self._get_client() as client: diff --git a/lavender_data/client/cli/api_call.py b/lavender_data/client/cli/api_call.py index e863ec11..fe069562 100644 --- a/lavender_data/client/cli/api_call.py +++ b/lavender_data/client/cli/api_call.py @@ -1,4 +1,5 @@ from lavender_data.client import LavenderDataClient +from lavender_data.serialize import deserialize_sample def _api(api_url: str, api_key: str): @@ -77,11 +78,13 @@ def get_next_item( no_cache: bool, max_retry_count: int, ): - return _api(api_url=api_url, api_key=api_key).get_next_item( - iteration_id=iteration_id, - rank=rank, - no_cache=no_cache, - max_retry_count=max_retry_count, + return deserialize_sample( + _api(api_url=api_url, api_key=api_key).get_next_item( + iteration_id=iteration_id, + rank=rank, + no_cache=no_cache, + max_retry_count=max_retry_count, + ) ) @@ -102,8 +105,10 @@ def submit_next_item( def get_submitted_result(api_url: str, api_key: str, iteration_id: str, cache_key: str): - return _api(api_url=api_url, api_key=api_key).get_submitted_result( - iteration_id=iteration_id, cache_key=cache_key + return deserialize_sample( + _api(api_url=api_url, api_key=api_key).get_submitted_result( + iteration_id=iteration_id, cache_key=cache_key + ) ) diff --git a/lavender_data/client/iteration.py b/lavender_data/client/iteration.py index 8692e1fd..d39602a2 100644 --- a/lavender_data/client/iteration.py +++ b/lavender_data/client/iteration.py @@ -1,8 +1,9 @@ import time -import json +import threading +import warnings from typing import Optional, Union, Literal -from concurrent.futures import ThreadPoolExecutor, Future, as_completed +from lavender_data.serialize import deserialize_sample, DeserializeException from lavender_data.client.api import ( get_client, LavenderDataClient, @@ -130,7 +131,8 @@ def __init__( self._iteration_id = iteration_response.id self._total = iteration_response.total - self._last_indices = None + self._last_indices = set() + self._completed_indices = set() self._no_cache = no_cache self._max_retry_count = max_retry_count self._skip_on_failure = skip_on_failure @@ -140,8 +142,16 @@ def __init__( self._api_key = api_key self._api = _api(self._api_url, self._api_key) + self._bytes = 0 + self._stopped = False + self.id = self._iteration_id + self._complete_thread = threading.Thread( + target=self._keep_complete_indices, daemon=True + ) + self._complete_thread.start() + def torch( self, pin_memory: bool = False, @@ -200,34 +210,43 @@ def pushback(self): def __len__(self): return self._total + def _keep_complete_indices(self): + while not self._stopped: + if len(self._completed_indices) == 0: + time.sleep(0.1) + continue + + index = self._completed_indices.pop() + self.complete(index) + def _complete_last_indices(self): - if self._last_indices is not None: - for index in self._last_indices: - self.complete(index) - self._last_indices = None + self._completed_indices.update(self._last_indices) + self._last_indices = set() def _set_last_indices(self, sample_or_batch): indices = sample_or_batch["_lavender_data_indices"] if isinstance(indices, list): - self._last_indices = indices + self._last_indices = set(indices) else: - self._last_indices = [indices] + self._last_indices = {indices} def _get_next_item(self): try: - sample_or_batch = self._api.get_next_item( + serialized = self._api.get_next_item( iteration_id=self._iteration_id, rank=self._rank, no_cache=self._no_cache, max_retry_count=self._max_retry_count, ) + self._bytes += len(serialized) + return deserialize_sample(serialized) except LavenderDataApiError as e: if "No more indices to pop" in str(e): raise StopIteration else: raise e - - return sample_or_batch + except DeserializeException as e: + raise ValueError(f"Failed to deserialize sample: {e}") def _submit_next_item(self) -> str: cache_key = self._api.submit_next_item( @@ -240,10 +259,12 @@ def _submit_next_item(self) -> str: def _get_submitted_result(self, cache_key: str): try: - return self._api.get_submitted_result( + serialized = self._api.get_submitted_result( iteration_id=self._iteration_id, cache_key=cache_key, ) + self._bytes += len(serialized) + return deserialize_sample(serialized) except LavenderDataApiError as e: if "Data is still being processed" in str(e): return None @@ -253,6 +274,12 @@ def _get_submitted_result(self, cache_key: str): raise StopIteration else: raise e + except DeserializeException as e: + raise ValueError(f"Failed to deserialize sample: {e}") + + def _stop(self): + self._stopped = True + self._complete_thread.join() def __next__(self): self._complete_last_indices() @@ -271,7 +298,13 @@ def __next__(self): return sample_or_batch def __iter__(self): - return self + try: + while True: + yield next(self) + except StopIteration: + pass + finally: + self._stop() def __getitem__(self, index: int): return next(self) @@ -288,113 +321,121 @@ def __init__( if prefetch_factor < 1: raise ValueError("prefetch_factor must be greater than 0") - self.dl = dl - self.prefetch_factor = prefetch_factor - self.poll_interval = poll_interval - self.in_order = in_order - self.executor: Optional[ThreadPoolExecutor] = None # to be serializable - self.futures: list[Future] = [] - self.arrived: list[tuple[int, dict]] = [] - self.current = -1 - self.stopped = False + self._dl = dl + self._prefetch_factor = prefetch_factor + self._poll_interval = poll_interval + self._in_order = in_order + self._arrived: list[tuple[int, dict]] = [] + self._current = 0 + self._stopped = False + self._fetch_threads: list[threading.Thread] = [] + self._joined_fetch_threads = 0 + self._error: Optional[Exception] = None + + def _stop(self): + self._stopped = True + for thread in self._fetch_threads: + thread.join() + self._dl._stop() def _get_submitted_result(self, cache_key: str): while True: - data = self.dl._get_submitted_result(cache_key) + data = self._dl._get_submitted_result(cache_key) if data is not None: return data else: - time.sleep(self.poll_interval) + time.sleep(self._poll_interval) - def _submit_next(self): - if self.stopped: - return + def _fetch_one(self): + try: + cache_key = self._dl._submit_next_item() + except LavenderDataApiError as e: + if "No more indices to pop" in str(e): + raise StopIteration + else: + raise e - queue_size = self.prefetch_factor - if self.executor is None: - self.executor = ThreadPoolExecutor(queue_size) + data = None + arrived_index = None + while arrived_index is None and not self._stopped: + time.sleep(self._poll_interval) - while len(self.futures) < queue_size: try: - cache_key = self.dl._submit_next_item() - except LavenderDataApiError as e: - if "No more indices to pop" in str(e): - self.stopped = True - return + data = self._dl._get_submitted_result(cache_key) + if data is not None: + arrived_index = data["_lavender_data_current"] + except LavenderDataSampleProcessingError as e: + if self._dl._skip_on_failure: + arrived_index = e.current else: raise e - future = self.executor.submit( - self._get_submitted_result, cache_key=cache_key - ) - self.futures.append(future) + + return arrived_index, data + + def _keep_fetching(self): + while not self._stopped: + try: + fetched = self._fetch_one() + arrived_index, data = fetched + self._arrived.append((arrived_index, data)) + except StopIteration: + break + except Exception as e: + self._error = e + self._stopped = True + + self._joined_fetch_threads += 1 + + def _start_fetch_threads(self): + for _ in range(self._prefetch_factor): + thread = threading.Thread(target=self._keep_fetching) + thread.start() + self._fetch_threads.append(thread) def __len__(self): - return len(self.dl) + return len(self._dl) def __next__(self): - if len(self.futures) == 0: - self._submit_next() + if len(self._fetch_threads) == 0: + self._start_fetch_threads() - self.dl._complete_last_indices() + self._dl._complete_last_indices() data = None - next_index = self.current + 1 - while True: - # check if the data has already arrived during the previous iteration - already_arrived = ( - [data for i, data in self.arrived if i == next_index] - if self.in_order - else self.arrived - ) - if len(already_arrived) > 0: - data = already_arrived[0] - self.arrived = [a for a in self.arrived if a[0] != next_index] - self.current = next_index - next_index = self.current + 1 - if data is not None: - break - else: - continue + while data is None: + if self._stopped: + raise self._error or StopIteration - # if the iteration is stopped and there are no more futures to be waited, stop iteration - if self.stopped and len(self.futures) == 0: + if self._joined_fetch_threads == len(self._fetch_threads): raise StopIteration - try: - # wait for the data to arrive - future = next(as_completed(self.futures)) - self.futures.remove(future) - data = future.result() - arrived_index = data["_lavender_data_current"] - except StopIteration: - # it means one of the workers detected that the iteration is stopped - # but it's not guaranteed that all data from the other workers has returned - self.stopped = True - continue - except LavenderDataSampleProcessingError as e: - if self.dl._skip_on_failure: - data = None - arrived_index = e.current - else: - raise e + arrived = next( + ( + (i, data) + for i, data in self._arrived + if i == self._current or not self._in_order + ), + None, + ) - self._submit_next() + if arrived is None: + continue - if not self.in_order or arrived_index == next_index: - # if arrived index is the next index, return the data - self.current = next_index - next_index = self.current + 1 - if data is not None: - break - else: - # if arrived index is not the next index, add the data to the list - self.arrived.append((arrived_index, data)) + arrived_index, data = arrived + self._arrived = [a for a in self._arrived if a[0] != arrived_index] + self._current += 1 - self.dl._set_last_indices(data) + self._dl._set_last_indices(data) return data def __iter__(self): - return self + try: + while True: + yield next(self) + except StopIteration: + pass + finally: + self._stop() def __getitem__(self, index: int): return next(self) diff --git a/lavender_data/serialize.py b/lavender_data/serialize.py index 6a95ce05..d3346f81 100644 --- a/lavender_data/serialize.py +++ b/lavender_data/serialize.py @@ -1,6 +1,7 @@ import io import numpy as np import ujson as json +import warnings try: import torch @@ -127,7 +128,11 @@ def serialize_sample(sample: dict): return header + body -def deserialize_sample(content: bytes): +class DeserializeException(Exception): + pass + + +def deserialize_sample(content: bytes, strict: bool = True): header_length, current = detach_length(content) keys = json.loads(current[:header_length].decode("utf-8")) values = [] @@ -136,9 +141,26 @@ def deserialize_sample(content: bytes): if signature != b"sa": raise ValueError(f"Unknown signature: {signature}") current = current[2:] - while current: + i = 0 + while current and i < len(keys): value_length, value = detach_length(current) current_value = value[:value_length] - values.append(deserialize_item(current_value)) + try: + values.append(deserialize_item(current_value)) + except Exception as e: + msg = ( + f"Failed to deserialize item {keys[i]}: {e}\n" + f"Remaining {len(value)} bytes, current item {len(current_value)} bytes, length {value_length}" + ) + if not strict: + warnings.warn(msg) + values.append(None) + else: + raise DeserializeException(msg) current = value[value_length:] + i += 1 + + if len(current) > 0: + warnings.warn(f"Remaining {len(current)} bytes") + return dict(zip(keys, values)) diff --git a/lavender_data/server/app.py b/lavender_data/server/app.py index f10fd1f3..2b148f4b 100644 --- a/lavender_data/server/app.py +++ b/lavender_data/server/app.py @@ -1,8 +1,8 @@ +import time import re -import logging from contextlib import asynccontextmanager -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles @@ -106,15 +106,37 @@ async def lifespan(app: FastAPI): pass -class EndpointFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - # Disable logging for polling requests - return not re.match(r".*GET.*/iterations/.*/next/.* 202.*", record.getMessage()) +app = FastAPI(lifespan=lifespan) -logging.getLogger("uvicorn.access").addFilter(EndpointFilter()) +def log_filter(request: Request, response): + if ( + re.match(r"/iterations/.*/next/.*", request.url.path) + and response.status_code == 202 + ): + return False + return True + + +@app.middleware("http") +async def add_process_time_header(request: Request, call_next): + logger = get_logger(__name__) + start_time = time.perf_counter() + response = await call_next(request) + process_time = time.perf_counter() - start_time + + if log_filter(request, response): + path = ( + f"{request.url.path}?{request.url.query}" + if request.url.query + else request.url.path + ) + logger.info( + f"{request.client.host}:{request.client.port} - {request.method} {path} {response.status_code} {process_time:.2f}s" + ) + + return response -app = FastAPI(lifespan=lifespan) app.mount("/files", StaticFiles(directory=files_dir), name="files") app.add_middleware( diff --git a/lavender_data/server/background_worker/memory.py b/lavender_data/server/background_worker/memory.py index fab9b66c..6f0e39f3 100644 --- a/lavender_data/server/background_worker/memory.py +++ b/lavender_data/server/background_worker/memory.py @@ -5,6 +5,8 @@ import hashlib from lavender_data.logging import get_logger +EOF_SIGNATURE = b"EOF" + class SharedMemory: def __init__(self): @@ -62,7 +64,11 @@ def expire(self, name: str, ex: int): def set(self, name: str, value: Union[bytes, str], ex: Optional[int] = None): _value = self._ensure_bytes(value) - _value = len(_value).to_bytes(length=8, byteorder="big", signed=False) + _value + _value = ( + len(_value).to_bytes(length=8, byteorder="big", signed=False) + + _value + + EOF_SIGNATURE + ) try: memory = self._create_shared_memory(name, len(_value)) @@ -81,6 +87,8 @@ def get(self, name: str) -> bytes: memory = self._get_shared_memory(name) b = memory.buf.tobytes() length = int.from_bytes(b[:8], byteorder="big", signed=False) + if b[length + 8 : length + 8 + len(EOF_SIGNATURE)] != EOF_SIGNATURE: + return None return b[8 : length + 8] except FileNotFoundError: return None diff --git a/lavender_data/server/background_worker/process_pool.py b/lavender_data/server/background_worker/process_pool.py index 980be7cd..cb499d58 100644 --- a/lavender_data/server/background_worker/process_pool.py +++ b/lavender_data/server/background_worker/process_pool.py @@ -347,6 +347,7 @@ def cancel(self, *work_ids: str): def shutdown(self): self._kill_switch.set() + self._spawner_thread.join() _clear_queue(self._call_queue) for p in self._processes: @@ -357,10 +358,10 @@ def shutdown(self): self._result_queue.put(STOP_SIGN) self._result_queue.close() + self._manager_thread.join() + for p in self._processes: p.join() - self._spawner_thread.join() - self._manager_thread.join() _tasks = {} diff --git a/lavender_data/server/cli/run.py b/lavender_data/server/cli/run.py index d4cfcccb..cf8084fa 100644 --- a/lavender_data/server/cli/run.py +++ b/lavender_data/server/cli/run.py @@ -34,7 +34,7 @@ def run(env_file: str = ".env", init: bool = False): server = uvicorn.Server(config) get_logger("uvicorn", clear_handlers=True) - get_logger("uvicorn.access", clear_handlers=True) + get_logger("uvicorn.access", clear_handlers=True).disabled = True try: server.run() diff --git a/lavender_data/server/dataset/preview.py b/lavender_data/server/dataset/preview.py index 6897040c..30997ea5 100644 --- a/lavender_data/server/dataset/preview.py +++ b/lavender_data/server/dataset/preview.py @@ -127,7 +127,8 @@ def _read_dataset( uid_column_type=uid_column_type, main_shard=main_shard, feature_shards=feature_shards, - ) + ), + join="left", ) diff --git a/lavender_data/server/iteration/iteration_state/default.py b/lavender_data/server/iteration/iteration_state/default.py index 2b1e730b..2887b71d 100644 --- a/lavender_data/server/iteration/iteration_state/default.py +++ b/lavender_data/server/iteration/iteration_state/default.py @@ -22,6 +22,7 @@ ShardInfo, MainShardInfo, GlobalSampleIndex, + InnerJoinSampleInsufficient, ) from lavender_data.server.registries import ( FilterRegistry, @@ -479,13 +480,17 @@ def get_next_samples( filters = self._filters() categorizer = self._categorizer() + current = int(self.cache.incr(self._key("batch_count"), 1)) - 1 global_sample_indices = [] samples = [] while len(samples) < max(batch_size, 1): next_item = self.next_item(rank) try: - sample = reader.get_sample(next_item) + sample = reader.get_sample(next_item, join="inner") + except InnerJoinSampleInsufficient: + self.filtered(next_item.index) + continue except Exception as e: self.failed(next_item.index) msg = f"Failed to read sample {next_item.index} (sample {next_item.main_shard.sample_index} of shard {next_item.main_shard.index}): {e.__class__.__name__}({str(e)})" @@ -538,7 +543,6 @@ def get_next_samples( samples.append(sample) cache_key = self._cache_key([i.index for i in global_sample_indices]) - current = int(self.cache.incr(self._key("batch_count"), 1)) - 1 return cache_key, ProcessNextSamplesParams( current=current, global_sample_indices=global_sample_indices, diff --git a/lavender_data/server/iteration/process.py b/lavender_data/server/iteration/process.py index 8fc153c7..c7b58e5e 100644 --- a/lavender_data/server/iteration/process.py +++ b/lavender_data/server/iteration/process.py @@ -1,6 +1,7 @@ import ujson as json from typing import Optional import traceback +import time from fastapi import HTTPException from pydantic import BaseModel @@ -93,7 +94,10 @@ def _decollate(batch: dict) -> dict: _batch = {} for k, v in batch.items(): if torch is not None and isinstance(v, torch.Tensor): - _batch[k] = v.item() + try: + _batch[k] = v.item() + except RuntimeError: + _batch[k] = v[0] elif isinstance(v, list) and len(v) == 1: _batch[k] = v[0] elif isinstance(v, dict): @@ -114,7 +118,7 @@ def _process_next_samples(params: ProcessNextSamplesParams) -> dict: batch_size = params.batch_size if samples is None: - samples = [reader.get_sample(i) for i in global_sample_indices] + samples = [reader.get_sample(i, join="left") for i in global_sample_indices] batch = ( CollaterRegistry.get(collater["name"]).collate(samples) @@ -161,6 +165,21 @@ def process_next_samples( raise error +def _format_number(number: int): + if number < 1000: + return f"{number} " + elif number < 1000**2: + return f"{number/1000:.2f} K" + elif number < 1000**3: + return f"{number/1000**2:.2f} M" + else: + return f"{number/1000**3:.2f} G" + + +def _ms(seconds: float): + return f"{seconds*1000:.2f} ms" + + @pool_task() def process_next_samples_and_store( params: ProcessNextSamplesParams, @@ -170,10 +189,22 @@ def process_next_samples_and_store( *, shared_memory: SharedMemory, ): + logger = get_logger(__name__) try: + _start = time.perf_counter() batch = process_next_samples(params, max_retry_count) + _process_time = time.perf_counter() content = serialize_sample(batch) + _serialize_time = time.perf_counter() shared_memory.set(cache_key, content, ex=cache_ttl) + _store_time = time.perf_counter() + logger.debug( + f"Done processing {cache_key} in {_ms(_store_time - _start)}, " + f"process: {_ms(_process_time - _start)}, " + f"serialize: {_ms(_serialize_time - _process_time)}, " + f"store: {_ms(_store_time - _serialize_time)}, " + f"size: {_format_number(len(content))}B" + ) except ProcessNextSamplesException as e: shared_memory.set(cache_key, f"processing_error:{e.json()}", ex=cache_ttl) except Exception as e: diff --git a/lavender_data/server/reader/__init__.py b/lavender_data/server/reader/__init__.py index 7c58c7c6..94fb2833 100644 --- a/lavender_data/server/reader/__init__.py +++ b/lavender_data/server/reader/__init__.py @@ -1,6 +1,6 @@ import os import hashlib -from typing import Annotated, Optional +from typing import Annotated, Optional, Literal import numpy as np from fastapi import Depends @@ -43,10 +43,21 @@ def _default_null_type(t: str) -> str: return b"" elif t.startswith("bool"): return np.nan + elif t.startswith("list"): + return [] + elif t.startswith("dict"): + return {} else: return None +JoinMethod = Literal["inner", "left"] + + +class InnerJoinSampleInsufficient(Exception): + pass + + class ServerSideReader: reader_cache: dict[str, Reader] = {} @@ -149,7 +160,11 @@ def clear_cache(self, *shards: list[ShardInfo]): self.reader_cache[cache_key].clear() del self.reader_cache[cache_key] - def _get_sample(self, index: GlobalSampleIndex): + def _get_sample( + self, + index: GlobalSampleIndex, + join: JoinMethod, + ): reader = self.get_reader( index.main_shard, index.uid_column_name, index.uid_column_type ) @@ -175,20 +190,34 @@ def _get_sample(self, index: GlobalSampleIndex): try: sample_partial = reader.get_item_by_uid(sample_uid) except KeyError: - msg = f'Failed to read sample with uid "{sample_uid}" from shard {feature_shard.location} ({index.main_shard.sample_index} of {index.main_shard.location})' - sample_partial = { - k: _default_null_type(t) for k, t in feature_shard.columns.items() - } + if join == "inner": + raise InnerJoinSampleInsufficient( + f'Failed to read sample with uid "{sample_uid}" from shard {feature_shard.location} ({index.main_shard.sample_index} of {index.main_shard.location})' + ) + else: + sample_partial = { + k: _default_null_type(t) + for k, t in feature_shard.columns.items() + } for k, v in sample_partial.items(): if k == index.uid_column_name: continue + if v is None: + sample[k] = _default_null_type(feature_shard.columns[k]) + continue sample[k] = v return sample - def get_sample(self, index: GlobalSampleIndex): + def get_sample( + self, + index: GlobalSampleIndex, + join: JoinMethod = "inner", + ): try: - return self._get_sample(index) + return self._get_sample(index, join) + except InnerJoinSampleInsufficient: + raise except Exception as e: self.clear_cache(index.main_shard, *index.feature_shards) raise e diff --git a/lavender_data/server/registries/abc.py b/lavender_data/server/registries/abc.py index 1a36db1b..f00269bb 100644 --- a/lavender_data/server/registries/abc.py +++ b/lavender_data/server/registries/abc.py @@ -17,7 +17,11 @@ class FuncSpec(BaseModel): def _get_md5(_class: type[T]) -> str: - return hashlib.md5(inspect.getsource(_class).encode()).hexdigest() + try: + source = inspect.getsource(_class) + except Exception as e: + source = _class.__name__ + return hashlib.md5(source.encode()).hexdigest() class Registry(ABC, Generic[T]): diff --git a/tests/test_reader.py b/tests/test_reader.py index f0a53aa9..91742aa9 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -9,6 +9,7 @@ GlobalSampleIndex, MainShardInfo, ShardInfo, + InnerJoinSampleInsufficient, ) from tests.utils.shards import create_test_shard @@ -131,10 +132,13 @@ def test_get_sample(self): ), ], ) - sample = self.reader.get_sample(index) + sample = self.reader.get_sample(index, join="left") self.assertEqual(sample["id"], 0) self.assertEqual(sample["image_url"], "https://example.com/image-0.jpg") self.assertEqual(sample["caption"], "Caption for image 0") self.assertTrue(np.isnan(sample["score"])) + with self.assertRaises(InnerJoinSampleInsufficient): + self.reader.get_sample(index, join="inner") + # TODO cache size test