diff --git a/LNX-docker-compose.yaml b/LNX-docker-compose.yaml index 8fb6b98..eb48b61 100644 --- a/LNX-docker-compose.yaml +++ b/LNX-docker-compose.yaml @@ -29,6 +29,7 @@ services: volumes: - .:/main - ./test_requirements.txt:/tmp/pip_requirements.txt + - ./dj_gui_api_server:/opt/conda/lib/python3.8/site-packages/dj_gui_api_server user: ${HOST_UID}:anaconda working_dir: /main command: @@ -42,7 +43,7 @@ services: echo "------ UNIT TESTS ------" pytest -sv --cov-report term-missing --cov=dj_gui_api_server /main/tests echo "------ STYLE TESTS ------" - flake8 dj_gui_api_server --count --max-complexity=10 --max-line-length=95 \ + flake8 dj_gui_api_server --count --max-complexity=20 --max-line-length=95 \ --statistics else echo "=== Running ===" diff --git a/dj_gui_api_server/DJConnector.py b/dj_gui_api_server/DJConnector.py index e630292..8c76ad1 100644 --- a/dj_gui_api_server/DJConnector.py +++ b/dj_gui_api_server/DJConnector.py @@ -2,6 +2,7 @@ import datajoint as dj import datetime import numpy as np +from functools import reduce from datajoint.errors import AccessError import re from .errors import InvalidDeleteRequest, InvalidRestriction, UnsupportedTableType @@ -102,7 +103,9 @@ def list_tables(jwt_payload: dict, schema_name: str): return tables_dict_list @staticmethod - def fetch_tuples(jwt_payload: dict, schema_name: str, table_name: str): + def fetch_tuples(jwt_payload: dict, schema_name: str, table_name: str, + restriction: list = [], limit: int = 1000, page: int = 1, + order=['KEY ASC']) -> tuple: """ Get records as tuples from table :param jwt_payload: Dictionary containing databaseAddress, username and password @@ -112,9 +115,37 @@ def fetch_tuples(jwt_payload: dict, schema_name: str, table_name: str): :type schema_name: str :param table_name: Table name under the given schema; must be in camel case :type table_name: str - :return: List of tuples in dict form - :rtype: list + :param restriction: Sequence of filter cards with attribute_name, operation, value + defined, defaults to [] + :type restriction: list, optional + :param limit: Max number of records to return, defaults to 1000 + :type limit: int, optional + :param page: Page number to return, defaults to 1 + :type page: int, optional + :param order: Sequence to order records, defaults to ['KEY ASC']. + See :class:`datajoint.fetch.Fetch` for more info. + :type order: list, optional + :return: Records in dict form and the total number of records that can be paged + :rtype: tuple """ + def filter_to_restriction(attribute_filter: dict) -> str: + if attribute_filter['operation'] in ('>', '<', '>=', '<='): + operation = attribute_filter['operation'] + elif attribute_filter['value'] is None: + operation = (' IS ' if attribute_filter['operation'] == '=' + else ' IS NOT ') + else: + operation = attribute_filter['operation'] + + if (isinstance(attribute_filter['value'], str) and + not attribute_filter['value'].isnumeric()): + value = f"'{attribute_filter['value']}'" + else: + value = ('NULL' if attribute_filter['value'] is None + else attribute_filter['value']) + + return f"{attribute_filter['attributeName']}{operation}{value}" + DJConnector.set_datajoint_config(jwt_payload) schema_virtual_module = dj.create_virtual_module(schema_name, schema_name) @@ -124,8 +155,10 @@ def fetch_tuples(jwt_payload: dict, schema_name: str, table_name: str): # Fetch tuples without blobs as dict to be used to create a # list of tuples for returning - non_blobs_rows = table.fetch(*table.heading.non_blobs, as_dict=True, - limit=DEFAULT_FETCH_LIMIT) + query = reduce(lambda q1, q2: q1 & q2, [table()] + [filter_to_restriction(f) + for f in restriction]) + non_blobs_rows = query.fetch(*table.heading.non_blobs, as_dict=True, limit=limit, + offset=(page-1)*limit, order_by=order) # Buffer list to be return rows = [] @@ -168,7 +201,7 @@ def fetch_tuples(jwt_payload: dict, schema_name: str, table_name: str): # Add the row list to tuples rows.append(row) - return rows + return rows, len(query) @staticmethod def get_table_attributes(jwt_payload: dict, schema_name: str, table_name: str): diff --git a/dj_gui_api_server/DJGUIAPIServer.py b/dj_gui_api_server/DJGUIAPIServer.py index 0bff967..49b23b1 100644 --- a/dj_gui_api_server/DJGUIAPIServer.py +++ b/dj_gui_api_server/DJGUIAPIServer.py @@ -133,6 +133,8 @@ def fetch_tuples(jwt_payload: dict): """ Route to fetch all records for a given table. Expects: (html:GET:Authorization): Must include in format of: bearer + (html:query_params): {"limit": , "page": , "order": , + "restriction": } (html:POST:JSON): {"schemaName": , "tableName": } NOTE: Table name must be in CamalCase :param jwt_payload: Dictionary containing databaseAddress, username and password @@ -142,10 +144,16 @@ def fetch_tuples(jwt_payload: dict): :rtype: dict """ try: - table_tuples = DJConnector.fetch_tuples(jwt_payload, - request.json["schemaName"], - request.json["tableName"]) - return dict(tuples=table_tuples) + table_tuples, total_count = DJConnector.fetch_tuples( + jwt_payload=jwt_payload, + schema_name=request.json["schemaName"], + table_name=request.json["tableName"], + **{k: (int(v) if k in ('limit', 'page') + else (v.split(',') if k == 'order' else loads( + b64decode(v.encode('utf-8')).decode('utf-8')))) + for k, v in request.args.items()}, + ) + return dict(tuples=table_tuples, total_count=total_count) except Exception as e: return str(e), 500 diff --git a/test_requirements.txt b/test_requirements.txt index f604537..bec8baf 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,3 +1,4 @@ pytest pytest-cov flake8 +Faker \ No newline at end of file diff --git a/tests/test_filter.py b/tests/test_filter.py new file mode 100644 index 0000000..028048f --- /dev/null +++ b/tests/test_filter.py @@ -0,0 +1,112 @@ +from os import getenv +import pytest +from dj_gui_api_server.DJGUIAPIServer import app +import datajoint as dj +from json import dumps +from base64 import b64encode +from urllib.parse import urlencode +from datetime import date, datetime +from random import randint, choice, seed, getrandbits +from faker import Faker +seed('lock') # Pin down randomizer between runs +faker = Faker() +Faker.seed(0) # Pin down randomizer between runs + + +@pytest.fixture +def client(): + with app.test_client() as client: + yield client + + +@pytest.fixture +def token(client): + yield client.post('/api/login', json=dict(databaseAddress=getenv('TEST_DB_SERVER'), + username=getenv('TEST_DB_USER'), + password=getenv('TEST_DB_PASS'))).json['jwt'] + + +@pytest.fixture +def virtual_module(): + dj.config['safemode'] = False + connection = dj.conn(host=getenv('TEST_DB_SERVER'), + user=getenv('TEST_DB_USER'), + password=getenv('TEST_DB_PASS'), reset=True) + schema = dj.Schema('filter') + + @schema + class Student(dj.Lookup): + definition = """ + student_id: int + --- + student_name: varchar(50) + student_ssn: varchar(20) + student_enroll_date: datetime + student_balance: float + student_parking_lot=null : varchar(20) + student_out_of_state: bool + """ + contents = [(i, faker.name(), faker.ssn(), faker.date_between_dates( + date_start=date(2021, 1, 1), date_end=date(2021, 1, 31)), + round(randint(1000, 3000), 2), + choice([None, 'LotA', 'LotB', 'LotC']), + bool(getrandbits(1))) for i in range(100)] + + yield dj.VirtualModule('filter', 'filter') + schema.drop() + connection.close() + dj.config['safemode'] = True + + +def test_filters(token, client, virtual_module): + # 'between' dates + restriction = [dict(attributeName='student_enroll_date', operation='>', + value='2021-01-07'), + dict(attributeName='student_enroll_date', operation='<', + value='2021-01-17')] + encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8') + q = dict(limit=10, page=1, order='student_enroll_date DESC', + restriction=encoded_restriction) + REST_records = client.post(f'/api/fetch_tuples?{urlencode(q)}', + headers=dict(Authorization=f'Bearer {token}'), + json=dict(schemaName='filter', + tableName='Student')).json['tuples'] + assert len(REST_records) == 10 + assert REST_records[0][3] == datetime(2021, 1, 16).timestamp() + # 'equal' null + restriction = [dict(attributeName='student_parking_lot', operation='=', value=None)] + encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8') + q = dict(limit=10, page=2, order='student_id ASC', + restriction=encoded_restriction) + REST_records = client.post(f'/api/fetch_tuples?{urlencode(q)}', + headers=dict(Authorization=f'Bearer {token}'), + json=dict(schemaName='filter', + tableName='Student')).json['tuples'] + assert len(REST_records) == 10 + assert all([r[5] is None for r in REST_records]) + assert REST_records[0][0] == 34 + # not equal int + restriction = [dict(attributeName='student_id', operation='!=', value='2')] + encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8') + q = dict(limit=10, page=1, order='student_id ASC', + restriction=encoded_restriction) + REST_records = client.post(f'/api/fetch_tuples?{urlencode(q)}', + headers=dict(Authorization=f'Bearer {token}'), + json=dict(schemaName='filter', + tableName='Student')).json['tuples'] + assert len(REST_records) == 10 + assert all([r[0] != 2 for r in REST_records]) + assert REST_records[-1][0] == 10 + # equal 'Norma Fisher' and in_state student (bool) + restriction = [dict(attributeName='student_name', operation='=', value='Norma Fisher'), + dict(attributeName='student_out_of_state', operation='=', value='0')] + encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8') + q = dict(limit=10, page=1, order='student_id ASC', + restriction=encoded_restriction) + REST_records = client.post(f'/api/fetch_tuples?{urlencode(q)}', + headers=dict(Authorization=f'Bearer {token}'), + json=dict(schemaName='filter', + tableName='Student')).json['tuples'] + assert len(REST_records) == 1 + assert REST_records[0][1] == 'Norma Fisher' + assert REST_records[0][6] == 0