@router.post('/multiquery', response_class=fastapi.responses.ORJSONResponse)
async def api_multi_query_index(req: fastapi.Request, fmt='row'):
    """
    Run several queries against one index and return the combined records
    read from s3.

    The request body must be a JSON object with two members:

    * ``index``   - name of the index to query
    * ``queries`` - non-empty list of queries, all with the same arity

    Responds with 400 when the body is malformed, the index is unknown, or
    a query is invalid, and with 413 when the matching data is larger than
    RESPONSE_LIMIT_MAX.
    """
    global INDEXES

    # a body that is not valid JSON is the client's fault, not a server error
    try:
        req_body = await req.json()
    except ValueError as e:
        raise fastapi.HTTPException(status_code=400, detail=f'Invalid JSON body: {e}')

    if not isinstance(req_body, dict):
        raise fastapi.HTTPException(status_code=400, detail='Request body must be a JSON object')

    # validate the members explicitly; a bare req_body[...] outside of the
    # try block below would surface as a 500 instead of a 400
    index = req_body.get('index')
    queries = req_body.get('queries')

    if not isinstance(index, str) or not index:
        raise fastapi.HTTPException(status_code=400, detail='Missing "index"')
    if not isinstance(queries, list) or not queries:
        raise fastapi.HTTPException(status_code=400, detail='Missing or empty "queries"')

    try:
        qss = [_parse_query(q, required=True) for q in queries]

        # All queries have to have the same arity. This is input validation,
        # so raise instead of assert: asserts vanish under -O and an
        # AssertionError would not be translated into a 400 below.
        arity = len(qss[0])
        if any(len(qs) != arity for qs in qss):
            raise ValueError('All queries must have the same arity')

        # in the event we've added a new index
        if (index, arity) not in INDEXES:
            INDEXES = _load_indexes()
        i = INDEXES[(index, arity)]

        # discover what the user doesn't have access to see
        restricted, auth_s = profile(restricted_keywords, portal, req) if portal else (None, 0)

        # lookup the schema for this index and perform the query
        reader, query_s = profile(
            query.fetch,
            CONFIG,
            engine,
            i,
            qss,
            restricted=restricted,
        )

        # with no limit, will this request exceed the limit?
        if reader.bytes_total > RESPONSE_LIMIT_MAX:
            raise fastapi.HTTPException(status_code=413)

        # the results of the query
        return _fetch_records(reader, index, qss, fmt, query_s=auth_s + query_s)

    except KeyError:
        raise fastapi.HTTPException(status_code=400, detail=f'Invalid index: {index}')
    except ValueError as e:
        raise fastapi.HTTPException(status_code=400, detail=str(e))
def fetch(config, engine, index, qss, restricted=None):
    """
    Run one or more queries against an index and return a single
    RecordReader over all of the matching records.

    `qss` is a list of queries; every query must have exactly as many
    parameters as the arity of the index schema.

    Raises ValueError if `qss` is empty or if any query has the wrong
    arity (not just the first one).
    """
    if not qss:
        raise ValueError('At least one query is required')

    arity = index.schema.arity
    if any(len(q) != arity for q in qss):
        raise ValueError(f'Arity mismatch for index schema "{index.schema}"')

    return _run_queries(config, engine, index, qss, restricted)


def _run_queries(config, engine, index, qss, restricted):
    """
    Look up the sources for every query and wrap them in one RecordReader.
    """
    sources = _get_sources(config, engine, index, qss)
    return RecordReader(
        config,
        sources,
        index,
        restricted=restricted,
    )


def _make_locus_filter(schema, locus_query, chromosome, start, stop):
    """
    Build the record filter for ONE query.

    The values are bound through the parameters of this factory. A closure
    defined directly inside the loop over the queries would look its
    variables up when it is called - after the loop has finished - and so
    every filter would silently use the locus of the last query.
    """
    def overlaps(row):
        # match templated locus or overlapping loci
        if schema.locus_is_template:
            return row[schema.locus_columns[0]] == locus_query

        return schema.locus_of_row(row).overlaps(chromosome, start, stop)

    return overlaps


def _get_sources(config, engine, index, qss):
    """
    Construct a SQL query to fetch S3 objects and byte offsets. Run it for
    each query in `qss` and return a list of RecordSources to the results,
    in query order. Each source carries the record filter of the query
    that produced it.

    Raises ValueError if the index has not been built.
    """
    if not qss:
        return []

    # validate the index
    if not index.built:
        raise ValueError(f'Index "{index.name}" is not built')

    schema = index.schema

    # the statement and the column names are the same for every query, so
    # build them once; only the bound parameters differ
    statement = text(
        f'SELECT `__Keys`.`key`, MIN(`start_offset`), MAX(`end_offset`) '
        f'FROM `{index.table}` '
        f'INNER JOIN `__Keys` '
        f'ON `__Keys`.`id` = `{index.table}`.`key` '
        f'WHERE {schema.sql_filters} '
        f'GROUP BY `key` '
        f'ORDER BY `key` ASC'
    )
    escaped_column_names = [col.replace("|", "_") for col in schema.schema_columns]

    sources = []

    # one connection serves all of the queries
    with engine.connect() as conn:
        for q in qss:
            record_filter = None

            # if the schema has a locus, parse the query parameter
            if schema.has_locus:
                if schema.locus_is_template:
                    chromosome, start, stop = schema.locus_class(q[-1]).region()
                else:
                    chromosome, start, stop = parse_region_string(q[-1], config)

                # positions are stepped, and need to be between stepped ranges
                step_start = (start // Locus.LOCUS_STEP) * Locus.LOCUS_STEP
                step_stop = (stop // Locus.LOCUS_STEP) * Locus.LOCUS_STEP

                # replace the last query parameter with the locus
                query_params = dict(zip(escaped_column_names, q[:-1]))
                query_params.update({"chromosome": chromosome, "start_pos": step_start, "end_pos": step_stop})

                # filter records read by locus (bound to THIS query)
                record_filter = _make_locus_filter(schema, q[-1], chromosome, start, stop)
            else:
                query_params = dict(zip(escaped_column_names, q))

            rows = conn.execute(statement, query_params).fetchall()

            # create a RecordSource for each entry in the database
            sources.extend(RecordSource(*row, record_filter) for row in rows)

    return sources
# Methods of RecordReader (bioindex/lib/reader.py). The class header is not
# part of the diff, so they are shown dedented; indent them one level when
# placing them back into the class body.

def _readall(self):
    """
    A generator that lazily reads the records from S3, one source after
    another, in the order of self.sources.
    """
    for source in self.sources:
        yield from self._readsource(source)


def _readparallel(self):
    """
    A generator that downloads all of the sources concurrently and then
    yields their records in the order of self.sources.

    Nothing is yielded until every source has been read in full, so this
    is only meant for small result sets. An exception raised while reading
    any source is re-raised here.
    """
    counted_before = self.count

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
        # map() hands the results back in the order of self.sources and
        # re-raises a worker's exception when its result is consumed
        batches = list(pool.map(self._readfull, self.sources))

    # The workers already bumped self.count for every record, all before
    # the first record is handed to the consumer (and from several threads
    # at once, which is not an atomic update). A limit that compares
    # self.count with the number of records handed out so far would see
    # the grand total immediately and cut the stream off at the first
    # record. Rewind the counter and count again while yielding, the same
    # way the sequential reader does.
    # NOTE(review): relies on _readsource incrementing self.count exactly
    # once per record it yields - confirm against the full method.
    self.count = counted_before

    for _, records in batches:
        for record in records:
            self.count += 1
            yield record


def _readfull(self, source):
    """
    Read every record of one source into memory. Returns the source
    together with the list of its records.
    """
    return source, list(self._readsource(source))
- """ - return sum(r.bytes_total for r in self.readers) - - @property - def bytes_read(self): - """ - Total bytes read. - """ - return sum(r.bytes_read for r in self.readers) - - @property - def count(self): - """ - Total number of records read. - """ - return sum(r.count for r in self.readers) - - @property - def restricted_count(self): - """ - Total number of restricted records read. - """ - return sum(r.restricted_count for r in self.readers) - - @property - def at_end(self): - """ - True if all records have been read. - """ - return all(r.at_end for r in self.readers) - - def set_limit(self, limit): - """ - Apply a limit to the number of records that will be read. - """ - for r in self.readers: - r.set_limit(limit) diff --git a/bioindex/lib/source.py b/bioindex/lib/source.py index 5c5dd54..49ae4ca 100644 --- a/bioindex/lib/source.py +++ b/bioindex/lib/source.py @@ -19,7 +19,7 @@ def query(self, q, table): self.engine, self.config.s3_bucket, self.indexes[table], - q.split(','), + [q.split(',')], restricted=self.restricted, ) diff --git a/bioindex/main.py b/bioindex/main.py index 8faf17f..50c961d 100644 --- a/bioindex/main.py +++ b/bioindex/main.py @@ -407,7 +407,7 @@ def cli_query(cfg, index_name, q): i = index.Index.lookup(engine, index_name, len(q)) # query the index - reader = query.fetch(cfg, engine, i, q) + reader = query.fetch(cfg, engine, i, [q]) # dump all the records for record in reader.records: