Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
### Removed
- `InvalidDeleteRequest` exception is no longer available as it is now allowed to delete more than 1 record at a time. PR #99

### Fixed
- `uuid` types not properly restricted on `GET /record`, `DELETE /record`, and `GET /dependency`. PR #102

## [0.1.0b2] - 2021-03-12

### Fixed
Expand Down
28 changes: 18 additions & 10 deletions pharus/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ def _fetch_records(jwt_payload: dict, schema_name: str, table_name: str,

# Get table object from name
table = _DJConnector._get_table_object(schema_virtual_module, table_name)

attributes = table.heading.attributes
# Fetch tuples without blobs as dict to be used to create a
# list of tuples for returning
query = table & dj.AndList([_DJConnector._filter_to_restriction(f)
for f in restriction])
query = table & dj.AndList([
_DJConnector._filter_to_restriction(f, attributes[f['attributeName']].type)
for f in restriction])
non_blobs_rows = query.fetch(*table.heading.non_blobs, as_dict=True, limit=limit,
offset=(page-1)*limit, order_by=order)

Expand All @@ -149,7 +150,7 @@ def _fetch_records(jwt_payload: dict, schema_name: str, table_name: str,
row = []
# Loop through each attributes, append to the tuple_to_return with specific
# modification based on data type
for attribute_name, attribute_info in table.heading.attributes.items():
for attribute_name, attribute_info in attributes.items():
if not attribute_info.is_blob:
if non_blobs_row[attribute_name] is None:
# If it is none then just append None
Expand Down Expand Up @@ -180,7 +181,7 @@ def _fetch_records(jwt_payload: dict, schema_name: str, table_name: str,

# Add the row list to tuples
rows.append(row)
return list(table.heading.attributes.keys()), rows, len(query)
return list(attributes.keys()), rows, len(query)

@staticmethod
def _get_table_attributes(jwt_payload: dict, schema_name: str, table_name: str) -> dict:
Expand Down Expand Up @@ -295,12 +296,14 @@ def _record_dependency(jwt_payload: dict, schema_name: str, table_name: str,
_DJConnector._set_datajoint_config(jwt_payload)
virtual_module = dj.VirtualModule(schema_name, schema_name)
table = getattr(virtual_module, table_name)
attributes = table.heading.attributes
# Retrieve dependencies of related to retricted
dependencies = [dict(schema=descendant.database, table=descendant.table_name,
accessible=True, count=len(
(table if descendant.full_table_name == table.full_table_name
else descendant * table) & dj.AndList([
_DJConnector._filter_to_restriction(f)
_DJConnector._filter_to_restriction(
f, attributes[f['attributeName']].type)
for f in restriction])))
for descendant in table().descendants(as_objects=True)]
return dependencies
Expand Down Expand Up @@ -352,8 +355,10 @@ def _delete_records(jwt_payload: dict, schema_name: str, table_name: str,

# Get table object from name
table = _DJConnector._get_table_object(schema_virtual_module, table_name)

restrictions = [_DJConnector._filter_to_restriction(f) for f in restriction]
attributes = table.heading.attributes
restrictions = [
_DJConnector._filter_to_restriction(f, attributes[f['attributeName']].type)
for f in restriction]

# Compute restriction
query = table & dj.AndList(restrictions)
Expand Down Expand Up @@ -385,13 +390,15 @@ def _get_table_object(schema_virtual_module: VirtualModule, table_name: str) ->
return getattr(schema_virtual_module, table_name_parts[0])

@staticmethod
def _filter_to_restriction(attribute_filter: dict) -> str:
def _filter_to_restriction(attribute_filter: dict, attribute_type: str) -> str:
"""
Convert attribute filter to a restriction.

:param attribute_filter: A filter as ``dict`` with ``attributeName``, ``operation``,
``value`` keys defined, defaults to ``[]``
:type attribute_filter: dict
:param attribute_type: Attribute type
:type attribute_type: str
:return: DataJoint-compatible restriction
:rtype: str
"""
Expand All @@ -405,7 +412,8 @@ def _filter_to_restriction(attribute_filter: dict) -> str:

if (isinstance(attribute_filter['value'], str) and
not attribute_filter['value'].isnumeric()):
value = f"'{attribute_filter['value']}'"
value = (f"X'{attribute_filter['value'].replace('-', '')}'"
if attribute_type == 'uuid' else f"'{attribute_filter['value']}'")
else:
value = ('NULL' if attribute_filter['value'] is None
else attribute_filter['value'])
Expand Down
17 changes: 17 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from pharus.server import app
from uuid import UUID
from os import getenv
import datajoint as dj
from datetime import date
Expand Down Expand Up @@ -136,6 +137,22 @@ class Student(dj.Lookup):
Student.drop()


@pytest.fixture
def Computer(schema_main):
"""Computer table for testing."""
@schema_main
class Computer(dj.Lookup):
definition = """
computer_id: uuid
---
computer_brand: enum('HP', 'DELL')
"""
contents = [(UUID('ffffffff-86d5-4af7-a013-89bde75528bd'), 'HP'),
(UUID('aaaaaaaa-86d5-4af7-a013-89bde75528bd'), 'DELL')]
yield Computer
Computer.drop()


@pytest.fixture
def Int(schema_main):
"""Integer basic table for testing."""
Expand Down
18 changes: 17 additions & 1 deletion tests/test_delete.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from . import SCHEMA_PREFIX, token, client, connection, schemas_simple
from . import SCHEMA_PREFIX, token, client, connection, schemas_simple, schema_main, Computer
import datajoint as dj
from json import dumps
from base64 import b64encode
from urllib.parse import urlencode
from uuid import UUID


def test_delete_dependent_with_cascade(token, client, connection, schemas_simple):
Expand Down Expand Up @@ -72,3 +73,18 @@ def test_delete_invalid(token, client, connection, schemas_simple):
assert REST_response.status_code == 500
assert b'Nothing to delete' in REST_response.data
assert len(getattr(vm, table_name)()) == 3


def test_delete_uuid_primary(token, client, Computer):
"""Verify can delete if restricting by UUID."""
uuid_val = 'aaaaaaaa-86d5-4af7-a013-89bde75528bd'
restriction = [dict(attributeName='computer_id', operation='=',
value=uuid_val)]
encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8')
q = dict(limit=10, page=1, order='computer_id DESC',
restriction=encoded_restriction)
REST_response = client.delete(
f'/schema/{Computer.database}/table/{"Computer"}/record?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}'))
assert REST_response.status_code == 200
assert len(Computer() & dict(computer_id=UUID(uuid_val))) == 0
32 changes: 23 additions & 9 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from base64 import b64encode
from urllib.parse import urlencode
from datetime import date, datetime
from . import token, client, connection, schema_main, Student
from . import token, client, connection, schema_main, Student, Computer


def test_filters(token, client, Student):
Expand All @@ -15,8 +15,8 @@ def test_filters(token, client, Student):
q = dict(limit=10, page=1, order='student_enroll_date DESC',
restriction=encoded_restriction)
REST_records = client.get(
f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}')).json['records']
f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}')).json['records']
assert len(REST_records) == 10
assert REST_records[0][3] == datetime(2021, 1, 16).timestamp()
# 'equal' null
Expand All @@ -25,8 +25,8 @@ def test_filters(token, client, Student):
q = dict(limit=10, page=2, order='student_id ASC',
restriction=encoded_restriction)
REST_records = client.get(
f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}')).json['records']
f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}')).json['records']
assert len(REST_records) == 10
assert all([r[5] is None for r in REST_records])
assert REST_records[0][0] == 34
Expand All @@ -36,8 +36,8 @@ def test_filters(token, client, Student):
q = dict(limit=10, page=1, order='student_id ASC',
restriction=encoded_restriction)
REST_records = client.get(
f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}')).json['records']
f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}')).json['records']
assert len(REST_records) == 10
assert all([r[0] != 2 for r in REST_records])
assert REST_records[-1][0] == 10
Expand All @@ -48,8 +48,22 @@ def test_filters(token, client, Student):
q = dict(limit=10, page=1, order='student_id ASC',
restriction=encoded_restriction)
REST_records = client.get(
f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}')).json['records']
f'/schema/{Student.database}/table/{"Student"}/record?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}')).json['records']
assert len(REST_records) == 1
assert REST_records[0][1] == 'Norma Fisher'
assert REST_records[0][6] == 0


def test_uuid_filter(token, client, Computer):
"""Verify UUID can be properly restricted."""
restriction = [dict(attributeName='computer_id', operation='=',
value='aaaaaaaa-86d5-4af7-a013-89bde75528bd')]
encoded_restriction = b64encode(dumps(restriction).encode('utf-8')).decode('utf-8')
q = dict(limit=10, page=1, order='computer_id DESC',
restriction=encoded_restriction)
REST_records = client.get(
f'/schema/{Computer.database}/table/{"Computer"}/record?{urlencode(q)}',
headers=dict(Authorization=f'Bearer {token}')).json['records']
assert len(REST_records) == 1
assert REST_records[0][1] == 'DELL'