diff --git a/CHANGELOG.md b/CHANGELOG.md index dd42bd9..f824f68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and ### Removed - `InvalidDeleteRequest` exception is no longer available as it is now allowed to delete more than 1 record at a time. PR #99 +### Fixed +- `uuid` types not properly restricted on `GET /record`, `DELETE /record`, and `GET /dependency`. PR #102 + ## [0.1.0b2] - 2021-03-12 ### Fixed diff --git a/pharus/interface.py b/pharus/interface.py index ea9f263..979ebd5 100644 --- a/pharus/interface.py +++ b/pharus/interface.py @@ -131,11 +131,12 @@ def _fetch_records(jwt_payload: dict, schema_name: str, table_name: str, # Get table object from name table = _DJConnector._get_table_object(schema_virtual_module, table_name) - + attributes = table.heading.attributes # Fetch tuples without blobs as dict to be used to create a # list of tuples for returning - query = table & dj.AndList([_DJConnector._filter_to_restriction(f) - for f in restriction]) + query = table & dj.AndList([ + _DJConnector._filter_to_restriction(f, attributes[f['attributeName']].type) + for f in restriction]) non_blobs_rows = query.fetch(*table.heading.non_blobs, as_dict=True, limit=limit, offset=(page-1)*limit, order_by=order) @@ -149,7 +150,7 @@ def _fetch_records(jwt_payload: dict, schema_name: str, table_name: str, row = [] # Loop through each attributes, append to the tuple_to_return with specific # modification based on data type - for attribute_name, attribute_info in table.heading.attributes.items(): + for attribute_name, attribute_info in attributes.items(): if not attribute_info.is_blob: if non_blobs_row[attribute_name] is None: # If it is none then just append None @@ -180,7 +181,7 @@ def _fetch_records(jwt_payload: dict, schema_name: str, table_name: str, # Add the row list to tuples rows.append(row) - return list(table.heading.attributes.keys()), rows, len(query) + return list(attributes.keys()), rows, len(query) @staticmethod def _get_table_attributes(jwt_payload: dict, schema_name: str, table_name: str) -> dict: @@ -295,12 +296,14 @@ def _record_dependency(jwt_payload: dict, schema_name: str, table_name: str, _DJConnector._set_datajoint_config(jwt_payload) virtual_module = dj.VirtualModule(schema_name, schema_name) table = getattr(virtual_module, table_name) + attributes = table.heading.attributes # Retrieve dependencies of related to retricted dependencies = [dict(schema=descendant.database, table=descendant.table_name, accessible=True, count=len( (table if descendant.full_table_name == table.full_table_name else descendant * table) & dj.AndList([ - _DJConnector._filter_to_restriction(f) + _DJConnector._filter_to_restriction( + f, attributes[f['attributeName']].type) for f in restriction]))) for descendant in table().descendants(as_objects=True)] return dependencies @@ -352,8 +355,10 @@ def _delete_records(jwt_payload: dict, schema_name: str, table_name: str, # Get table object from name table = _DJConnector._get_table_object(schema_virtual_module, table_name) - - restrictions = [_DJConnector._filter_to_restriction(f) for f in restriction] + attributes = table.heading.attributes + restrictions = [ + _DJConnector._filter_to_restriction(f, attributes[f['attributeName']].type) + for f in restriction] # Compute restriction query = table & dj.AndList(restrictions) @@ -385,13 +390,15 @@ def _get_table_object(schema_virtual_module: VirtualModule, table_name: str) -> return getattr(schema_virtual_module, table_name_parts[0]) @staticmethod - def _filter_to_restriction(attribute_filter: dict) -> str: + def _filter_to_restriction(attribute_filter: dict, attribute_type: str) -> str: """ Convert attribute filter to a restriction. :param attribute_filter: A filter as ``dict`` with ``attributeName``, ``operation``, ``value`` keys defined, defaults to ``[]`` :type attribute_filter: dict + :param attribute_type: Attribute type + :type attribute_type: str :return: DataJoint-compatible restriction :rtype: str """ @@ -405,7 +412,8 @@ def _filter_to_restriction(attribute_filter: dict) -> str: if (isinstance(attribute_filter['value'], str) and not attribute_filter['value'].isnumeric()): - value = f"'{attribute_filter['value']}'" + value = (f"X'{attribute_filter['value'].replace('-', '')}'" + if attribute_type == 'uuid' else f"'{attribute_filter['value']}'") else: value = ('NULL' if attribute_filter['value'] is None else attribute_filter['value']) diff --git a/tests/__init__.py b/tests/__init__.py index ccbe004..67b2afe 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,6 @@ import pytest from pharus.server import app +from uuid import UUID from os import getenv import datajoint as dj from datetime import date @@ -136,6 +137,22 @@ class Student(dj.Lookup): Student.drop() +@pytest.fixture +def Computer(schema_main): + """Computer table for testing.""" + @schema_main + class Computer(dj.Lookup): + definition = """ + computer_id: uuid + --- + computer_brand: enum('HP', 'DELL') + """ + contents = [(UUID('ffffffff-86d5-4af7-a013-89bde75528bd'), 'HP'), + (UUID('aaaaaaaa-86d5-4af7-a013-89bde75528bd'), 'DELL')] + yield Computer + Computer.drop() + + @pytest.fixture def Int(schema_main): """Integer basic table for testing.""" diff --git a/tests/test_delete.py b/tests/test_delete.py index 8ba76d2..f68c6d7 100644 --- a/tests/test_delete.py +++ b/tests/test_delete.py @@ -1,8 +1,9 @@ -from . import SCHEMA_PREFIX, token, client, connection, schemas_simple +from . import SCHEMA_PREFIX, token, client, connection, schemas_simple, schema_main, Computer import datajoint as dj from json import dumps from base64 import b64encode from urllib.parse import urlencode +from uuid import UUID def test_delete_dependent_with_cascade(token, client, connection, schemas_simple): @@ -72,3 +73,18 @@ def test_delete_invalid(token, client, connection, schemas_simple): assert REST_response.status_code == 500 assert b'Nothing to delete' in REST_response.data assert len(getattr(vm, table_name)()) == 3 + + +def test_delete_uuid_primary(token, client, Computer): + """Verify can delete if restricting by UUID.""" + uuid_val = 'aaaaaaaa-86d5-4af7-a013-89bde75528bd' + restriction = [dict(attributeName='computer_id', operation='=', + value=uuid_val)] + encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8') + q = dict(limit=10, page=1, order='computer_id DESC', + restriction=encoded_restriction) + REST_response = client.delete( + f'/schema/{Computer.database}/table/{"Computer"}/record?{urlencode(q)}', + headers=dict(Authorization=f'Bearer {token}')) + assert REST_response.status_code == 200 + assert len(Computer() & dict(computer_id=UUID(uuid_val))) == 0 diff --git a/tests/test_filter.py b/tests/test_filter.py index 3279c08..7e31d33 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -2,7 +2,7 @@ from base64 import b64encode from urllib.parse import urlencode from datetime import date, datetime -from . import token, client, connection, schema_main, Student +from . import token, client, connection, schema_main, Student, Computer def test_filters(token, client, Student): @@ -15,8 +15,8 @@ def test_filters(token, client, Student): q = dict(limit=10, page=1, order='student_enroll_date DESC', restriction=encoded_restriction) REST_records = client.get( - f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}', - headers=dict(Authorization=f'Bearer {token}')).json['records'] + f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}', + headers=dict(Authorization=f'Bearer {token}')).json['records'] assert len(REST_records) == 10 assert REST_records[0][3] == datetime(2021, 1, 16).timestamp() # 'equal' null @@ -25,8 +25,8 @@ def test_filters(token, client, Student): q = dict(limit=10, page=2, order='student_id ASC', restriction=encoded_restriction) REST_records = client.get( - f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}', - headers=dict(Authorization=f'Bearer {token}')).json['records'] + f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}', + headers=dict(Authorization=f'Bearer {token}')).json['records'] assert len(REST_records) == 10 assert all([r[5] is None for r in REST_records]) assert REST_records[0][0] == 34 @@ -36,8 +36,8 @@ def test_filters(token, client, Student): q = dict(limit=10, page=1, order='student_id ASC', restriction=encoded_restriction) REST_records = client.get( - f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}', - headers=dict(Authorization=f'Bearer {token}')).json['records'] + f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}', + headers=dict(Authorization=f'Bearer {token}')).json['records'] assert len(REST_records) == 10 assert all([r[0] != 2 for r in REST_records]) assert REST_records[-1][0] == 10 @@ -48,8 +48,22 @@ def test_filters(token, client, Student): q = dict(limit=10, page=1, order='student_id ASC', restriction=encoded_restriction) REST_records = client.get( - f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}', - headers=dict(Authorization=f'Bearer {token}')).json['records'] + f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}', + headers=dict(Authorization=f'Bearer {token}')).json['records'] assert len(REST_records) == 1 assert REST_records[0][1] == 'Norma Fisher' assert REST_records[0][6] == 0 + + +def test_uuid_filter(token, client, Computer): + """Verify UUID can be properly restricted.""" + restriction = [dict(attributeName='computer_id', operation='=', + value='aaaaaaaa-86d5-4af7-a013-89bde75528bd')] + encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8') + q = dict(limit=10, page=1, order='computer_id DESC', + restriction=encoded_restriction) + REST_records = client.get( + f'/schema/{Computer.database}/table/{"Computer"}/record?{urlencode(q)}', + headers=dict(Authorization=f'Bearer {token}')).json['records'] + assert len(REST_records) == 1 + assert REST_records[0][1] == 'DELL'