Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion LNX-docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ services:
volumes:
- .:/main
- ./test_requirements.txt:/tmp/pip_requirements.txt
- ./dj_gui_api_server:/opt/conda/lib/python3.8/site-packages/dj_gui_api_server
user: ${HOST_UID}:anaconda
working_dir: /main
command:
Expand All @@ -42,7 +43,7 @@ services:
echo "------ UNIT TESTS ------"
pytest -sv --cov-report term-missing --cov=dj_gui_api_server /main/tests
echo "------ STYLE TESTS ------"
flake8 dj_gui_api_server --count --max-complexity=10 --max-line-length=95 \
flake8 dj_gui_api_server --count --max-complexity=20 --max-line-length=95 \
--statistics
else
echo "=== Running ==="
Expand Down
45 changes: 39 additions & 6 deletions dj_gui_api_server/DJConnector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import datajoint as dj
import datetime
import numpy as np
from functools import reduce
from datajoint.errors import AccessError
import re
from .errors import InvalidDeleteRequest, InvalidRestriction, UnsupportedTableType
Expand Down Expand Up @@ -102,7 +103,9 @@ def list_tables(jwt_payload: dict, schema_name: str):
return tables_dict_list

@staticmethod
def fetch_tuples(jwt_payload: dict, schema_name: str, table_name: str):
def fetch_tuples(jwt_payload: dict, schema_name: str, table_name: str,
restriction: list = [], limit: int = 1000, page: int = 1,
order=['KEY ASC']) -> tuple:
"""
Get records as tuples from table
:param jwt_payload: Dictionary containing databaseAddress, username and password
Expand All @@ -112,9 +115,37 @@ def fetch_tuples(jwt_payload: dict, schema_name: str, table_name: str):
:type schema_name: str
:param table_name: Table name under the given schema; must be in camel case
:type table_name: str
:return: List of tuples in dict form
:rtype: list
:param restriction: Sequence of filter cards with attribute_name, operation, value
defined, defaults to []
:type restriction: list, optional
:param limit: Max number of records to return, defaults to 1000
:type limit: int, optional
:param page: Page number to return, defaults to 1
:type page: int, optional
:param order: Sequence to order records, defaults to ['KEY ASC'].
See :class:`datajoint.fetch.Fetch` for more info.
:type order: list, optional
:return: Records in dict form and the total number of records that can be paged
:rtype: tuple
"""
def filter_to_restriction(attribute_filter: dict) -> str:
if attribute_filter['operation'] in ('>', '<', '>=', '<='):
operation = attribute_filter['operation']
elif attribute_filter['value'] is None:
operation = (' IS ' if attribute_filter['operation'] == '='
else ' IS NOT ')
else:
operation = attribute_filter['operation']

if (isinstance(attribute_filter['value'], str) and
not attribute_filter['value'].isnumeric()):
value = f"'{attribute_filter['value']}'"
else:
value = ('NULL' if attribute_filter['value'] is None
else attribute_filter['value'])

return f"{attribute_filter['attributeName']}{operation}{value}"

DJConnector.set_datajoint_config(jwt_payload)

schema_virtual_module = dj.create_virtual_module(schema_name, schema_name)
Expand All @@ -124,8 +155,10 @@ def fetch_tuples(jwt_payload: dict, schema_name: str, table_name: str):

# Fetch tuples without blobs as dict to be used to create a
# list of tuples for returning
non_blobs_rows = table.fetch(*table.heading.non_blobs, as_dict=True,
limit=DEFAULT_FETCH_LIMIT)
query = reduce(lambda q1, q2: q1 & q2, [table()] + [filter_to_restriction(f)
for f in restriction])
non_blobs_rows = query.fetch(*table.heading.non_blobs, as_dict=True, limit=limit,
offset=(page-1)*limit, order_by=order)

# Buffer list to be return
rows = []
Expand Down Expand Up @@ -168,7 +201,7 @@ def fetch_tuples(jwt_payload: dict, schema_name: str, table_name: str):

# Add the row list to tuples
rows.append(row)
return rows
return rows, len(query)

@staticmethod
def get_table_attributes(jwt_payload: dict, schema_name: str, table_name: str):
Expand Down
16 changes: 12 additions & 4 deletions dj_gui_api_server/DJGUIAPIServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def fetch_tuples(jwt_payload: dict):
"""
Route to fetch all records for a given table. Expects:
(html:GET:Authorization): Must include in format of: bearer <JWT-Token>
(html:query_params): {"limit": <limit>, "page": <page>, "order": <order>,
"restriction": <Base64 encoded restriction as JSONArray>}
(html:POST:JSON): {"schemaName": <schema_name>, "tableName": <table_name>}
NOTE: Table name must be in CamalCase
:param jwt_payload: Dictionary containing databaseAddress, username and password
Expand All @@ -142,10 +144,16 @@ def fetch_tuples(jwt_payload: dict):
:rtype: dict
"""
try:
table_tuples = DJConnector.fetch_tuples(jwt_payload,
request.json["schemaName"],
request.json["tableName"])
return dict(tuples=table_tuples)
table_tuples, total_count = DJConnector.fetch_tuples(
jwt_payload=jwt_payload,
schema_name=request.json["schemaName"],
table_name=request.json["tableName"],
**{k: (int(v) if k in ('limit', 'page')
else (v.split(',') if k == 'order' else loads(
b64decode(v.encode('utf-8')).decode('utf-8'))))
for k, v in request.args.items()},
)
return dict(tuples=table_tuples, total_count=total_count)
except Exception as e:
return str(e), 500

Expand Down
1 change: 1 addition & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest
pytest-cov
flake8
Faker
112 changes: 112 additions & 0 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from os import getenv
import pytest
from dj_gui_api_server.DJGUIAPIServer import app
import datajoint as dj
from json import dumps
from base64 import b64encode
from urllib.parse import urlencode
from datetime import date, datetime
from random import randint, choice, seed, getrandbits
from faker import Faker
seed('lock') # Pin down randomizer between runs
faker = Faker()
Faker.seed(0) # Pin down randomizer between runs


@pytest.fixture
def client():
with app.test_client() as client:
yield client


@pytest.fixture
def token(client):
yield client.post('/api/login', json=dict(databaseAddress=getenv('TEST_DB_SERVER'),
username=getenv('TEST_DB_USER'),
password=getenv('TEST_DB_PASS'))).json['jwt']


@pytest.fixture
def virtual_module():
dj.config['safemode'] = False
connection = dj.conn(host=getenv('TEST_DB_SERVER'),
user=getenv('TEST_DB_USER'),
password=getenv('TEST_DB_PASS'), reset=True)
schema = dj.Schema('filter')

@schema
class Student(dj.Lookup):
definition = """
student_id: int
---
student_name: varchar(50)
student_ssn: varchar(20)
student_enroll_date: datetime
student_balance: float
student_parking_lot=null : varchar(20)
student_out_of_state: bool
"""
contents = [(i, faker.name(), faker.ssn(), faker.date_between_dates(
date_start=date(2021, 1, 1), date_end=date(2021, 1, 31)),
round(randint(1000, 3000), 2),
choice([None, 'LotA', 'LotB', 'LotC']),
bool(getrandbits(1))) for i in range(100)]

yield dj.VirtualModule('filter', 'filter')
schema.drop()
connection.close()
dj.config['safemode'] = True


def test_filters(token, client, virtual_module):
# 'between' dates
restriction = [dict(attributeName='student_enroll_date', operation='>',
value='2021-01-07'),
dict(attributeName='student_enroll_date', operation='<',
value='2021-01-17')]
encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8')
q = dict(limit=10, page=1, order='student_enroll_date DESC',
restriction=encoded_restriction)
REST_records = client.post(f'/api/fetch_tuples?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}'),
json=dict(schemaName='filter',
tableName='Student')).json['tuples']
assert len(REST_records) == 10
assert REST_records[0][3] == datetime(2021, 1, 16).timestamp()
# 'equal' null
restriction = [dict(attributeName='student_parking_lot', operation='=', value=None)]
encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8')
q = dict(limit=10, page=2, order='student_id ASC',
restriction=encoded_restriction)
REST_records = client.post(f'/api/fetch_tuples?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}'),
json=dict(schemaName='filter',
tableName='Student')).json['tuples']
assert len(REST_records) == 10
assert all([r[5] is None for r in REST_records])
assert REST_records[0][0] == 34
# not equal int
restriction = [dict(attributeName='student_id', operation='!=', value='2')]
encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8')
q = dict(limit=10, page=1, order='student_id ASC',
restriction=encoded_restriction)
REST_records = client.post(f'/api/fetch_tuples?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}'),
json=dict(schemaName='filter',
tableName='Student')).json['tuples']
assert len(REST_records) == 10
assert all([r[0] != 2 for r in REST_records])
assert REST_records[-1][0] == 10
# equal 'Norma Fisher' and in_state student (bool)
restriction = [dict(attributeName='student_name', operation='=', value='Norma Fisher'),
dict(attributeName='student_out_of_state', operation='=', value='0')]
encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8')
q = dict(limit=10, page=1, order='student_id ASC',
restriction=encoded_restriction)
REST_records = client.post(f'/api/fetch_tuples?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}'),
json=dict(schemaName='filter',
tableName='Student')).json['tuples']
assert len(REST_records) == 1
assert REST_records[0][1] == 'Norma Fisher'
assert REST_records[0][6] == 0