Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions bioindex/api/bio.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ async def api_query_index(index: str, q: str, req: fastapi.Request, fmt='row', l
CONFIG,
engine,
i,
qs,
[qs],
restricted=restricted,
)

Expand All @@ -364,6 +364,51 @@ async def api_query_index(index: str, q: str, req: fastapi.Request, fmt='row', l
raise fastapi.HTTPException(status_code=400, detail=str(e))


@router.post('/multiquery', response_class=fastapi.responses.ORJSONResponse)
async def api_multi_query_index(req: fastapi.Request, fmt='row'):
    """
    Query the database for records matching multiple queries and read the
    records from s3.

    Expects a JSON body of the form: {"index": <name>, "queries": [q, ...]}.
    All queries must have the same arity (number of parameters), since they
    are run against the same index schema.

    Responds 400 on an unknown index or malformed queries, 413 when the
    result set would exceed the response size limit.
    """
    global INDEXES

    req_body = await req.json()
    index = req_body['index']
    queries = req_body['queries']
    try:
        qss = [_parse_query(q, required=True) for q in queries]

        # All queries must have the same arity. NOTE: this was previously an
        # assert, which is stripped under `python -O` and surfaces as an
        # AssertionError (HTTP 500); raising ValueError maps it to a 400 via
        # the handler below, and also covers an empty `queries` list.
        if len({len(qs) for qs in qss}) != 1:
            raise ValueError('All queries must have the same arity')

        # in the event we've added a new index
        if (index, len(qss[0])) not in INDEXES:
            INDEXES = _load_indexes()
        i = INDEXES[(index, len(qss[0]))]

        # discover what the user doesn't have access to see
        restricted, auth_s = profile(restricted_keywords, portal, req) if portal else (None, 0)

        # lookup the schema for this index and perform the query
        reader, query_s = profile(
            query.fetch,
            CONFIG,
            engine,
            i,
            qss,
            restricted=restricted,
        )

        # with no limit, will this request exceed the limit?
        if reader.bytes_total > RESPONSE_LIMIT_MAX:
            raise fastapi.HTTPException(status_code=413)

        # the results of the query
        return _fetch_records(reader, index, qss, fmt, query_s=auth_s + query_s)

    except KeyError:
        raise fastapi.HTTPException(status_code=400, detail=f'Invalid index: {index}')
    except ValueError as e:
        raise fastapi.HTTPException(status_code=400, detail=str(e))


@router.get('/schema', response_class=fastapi.responses.PlainTextResponse)
async def api_schema(req: fastapi.Request):
"""
Expand Down Expand Up @@ -436,7 +481,7 @@ async def api_test_index(index: str, q: str, req: fastapi.Request):
i = INDEXES[(index, len(qs))]

# lookup the schema for this index and perform the query
reader, query_s = profile(query.fetch, engine, CONFIG.s3_bucket, i, qs)
reader, query_s = profile(query.fetch, engine, CONFIG.s3_bucket, i, [qs])

return fastapi.Response(
headers={'Content-Length': str(reader.bytes_total)})
Expand Down
2 changes: 1 addition & 1 deletion bioindex/lib/ql.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ async def resolver(parent, info, **kwargs):
#q.append(build_region_str(**kwargs['locus']))

# execute the query, get the resulting reader
reader = fetch(config, engine, index, q)
reader = fetch(config, engine, index, [q])

# materialize all the records
return list(reader.records)
Expand Down
152 changes: 70 additions & 82 deletions bioindex/lib/query.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,21 @@
import concurrent.futures
import re

from sqlalchemy import text

from .locus import Locus, parse_region_string
from .reader import MultiRecordReader, RecordReader, RecordSource
from .reader import RecordReader, RecordSource
from .s3 import list_objects


def fetch(config, engine, index, q, restricted=None):
"""
Use the table schema to determine the type of query to execute. Returns
a RecordReader of all the results.
"""
if len(q) != index.schema.arity:
raise ValueError(f'Arity mismatch for index schema "{index.schema}"')

# execute the query and fetch the records from s3
return _run_query(config, engine, index, q, restricted)


def fetch_multi(executor, config, engine, index, queries, restricted=None):
def fetch(config, engine, index, qss, restricted=None):
    """
    Run one or more queries against an index and return a single
    RecordReader over all the matching records in s3.

    `qss` is a list of queries; each query is a list of parameter values
    whose length must match the arity of the index schema.

    Raises ValueError if `qss` is empty or any query has the wrong arity.
    """
    # Validate every query up front. The previous check inspected only
    # qss[0], which raised IndexError on an empty list and silently let
    # later queries with the wrong arity through to the SQL layer.
    if not qss or any(len(qs) != index.schema.arity for qs in qss):
        raise ValueError(f'Arity mismatch for index schema "{index.schema}"')

    # execute the queries and fetch the records from s3
    return _run_queries(config, engine, index, qss, restricted)


def fetch_all(config, index, restricted=None, key_limit=None):
Expand All @@ -48,7 +31,7 @@ def fetch_all(config, index, restricted=None, key_limit=None):
s3_objects = [o[1] for o in zip(range(key_limit), s3_objects)]

# create a RecordSource for each object
sources = [RecordSource.from_s3_object(obj) for obj in s3_objects]
sources = [RecordSource.from_s3_object(obj, record_filter=None) for obj in s3_objects]

# create the reader object, begin reading the records
return RecordReader(config, sources, index, restricted=restricted)
Expand Down Expand Up @@ -80,7 +63,7 @@ def count(config, engine, index, q):
"""
Estimate the number of records that will be returned by a query.
"""
reader = fetch_all(config, index) if len(q) == 0 else _run_query(config, engine, index, q, None)
reader = fetch_all(config, index) if len(q) == 0 else _run_queries(config, engine, index, [q], None)

# read a couple hundred records to get the total bytes read
records = list(zip(range(500), reader.records))
Expand Down Expand Up @@ -145,70 +128,75 @@ def match(config, engine, index, q):
prev_key = r[0]


def _run_query(config, engine, index, q, restricted):
def _run_queries(config, engine, index, qss, restricted):
    """
    Resolve every query in `qss` to its s3 record sources and wrap them
    all in a single RecordReader.
    """
    record_sources = _get_sources(config, engine, index, qss)
    reader = RecordReader(config, record_sources, index, restricted=restricted)
    return reader


def _get_sources(config, engine, index, qss):
    """
    Construct a SQL query to fetch S3 objects and byte offsets. Run it for
    each query in `qss` and return a list of RecordSources to the results.

    Raises ValueError if the index has not been built.
    """
    # validate the index once; it doesn't change per query
    if not index.built:
        raise ValueError(f'Index "{index.name}" is not built')

    # build the query once; it is the same for every q in qss
    sql = (
        f'SELECT `__Keys`.`key`, MIN(`start_offset`), MAX(`end_offset`) '
        f'FROM `{index.table}` '
        f'INNER JOIN `__Keys` '
        f'ON `__Keys`.`id` = `{index.table}`.`key` '
        f'WHERE {index.schema.sql_filters} '
        f'GROUP BY `key` '
        f'ORDER BY `key` ASC'
    )

    escaped_column_names = [col.replace("|", "_") for col in index.schema.schema_columns]

    sources = []
    for q in qss:
        record_filter = None

        # query parameter list
        query_params = q

        # if the schema has a locus, parse the query parameter
        if index.schema.has_locus:
            if index.schema.locus_is_template:
                chromosome, start, stop = index.schema.locus_class(q[-1]).region()
            else:
                chromosome, start, stop = parse_region_string(q[-1], config)

            # positions are stepped, and need to be between stepped ranges
            step_start = (start // Locus.LOCUS_STEP) * Locus.LOCUS_STEP
            step_stop = (stop // Locus.LOCUS_STEP) * Locus.LOCUS_STEP

            # replace the last query parameter with the locus
            query_params = dict(zip(escaped_column_names, q[:-1]))
            query_params.update({"chromosome": chromosome, "start_pos": step_start, "end_pos": step_stop})

            # Match templated locus or overlapping loci. The captured values
            # are bound as defaults: records are read AFTER this loop ends,
            # so a plain closure over the loop variables (q, chromosome,
            # start, stop) would evaluate every filter against the LAST
            # query's locus (late-binding closure bug).
            def overlaps(row, q=q, chromosome=chromosome, start=start, stop=stop):
                if index.schema.locus_is_template:
                    return row[index.schema.locus_columns[0]] == q[-1]

                return index.schema.locus_of_row(row).overlaps(chromosome, start, stop)

            # filter records read by locus
            record_filter = overlaps

        with engine.connect() as conn:
            if isinstance(query_params, list):
                query_params = dict(zip(escaped_column_names, query_params))
            cursor = conn.execute(text(sql), query_params)
            rows = cursor.fetchall()

        # create a RecordSource for each entry in the database
        sources += [RecordSource(*row, record_filter) for row in rows]

    return sources
Loading