Skip to content

Commit c6af1f0

Browse files
Merge pull request #109 from jverswijver/update_dynamic_api
Update dynamic api to support attribute route, and paging sorting filtering ect.
2 parents ef173c4 + 7875944 commit c6af1f0

File tree

8 files changed

+148
-80
lines changed

8 files changed

+148
-80
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ build
99
docs/_build
1010
*.tar.gz
1111
pharus/dynamic_api.py
12-
pharus/dynamic_api_spec.yaml
12+
/specs

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
44

5+
## [0.2.0] - 2021-11-02
6+
### Added
7+
- Dynamic api generation from spec sheet.(#103, #104, #105, #107, #108, #110) PR #106, #109
8+
- `dynamic_api_gen.py` Python script that generates `dynamic_api.py`.
9+
- Add Tests for the new dynamic api.
10+
- `server.py` now loads the routes generated dynamically from `dynamic_api.py` when it is present.
11+
512
## [0.1.0] - 2021-03-31
613
### Added
714
- Local database instance pre-populated with sample data for `dev` Docker Compose environment. PR #99
@@ -69,6 +76,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
6976
- Support for DataJoint attribute types: `varchar`, `int`, `float`, `datetime`, `date`, `time`, `decimal`, `uuid`.
7077
- Check dependency utility to determine child table references.
7178

79+
[0.2.0]: https://github.com/datajoint/pharus/compare/0.1.0...0.2.0
7280
[0.1.0]: https://github.com/datajoint/pharus/compare/0.1.0b2...0.1.0
7381
[0.1.0b2]: https://github.com/datajoint/pharus/compare/0.1.0b0...0.1.0b2
7482
[0.1.0b0]: https://github.com/datajoint/pharus/compare/0.1.0a5...0.1.0b0

pharus/dynamic_api_gen.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,12 @@
55

66

77
def populate_api():
8-
header_template = """
9-
# Auto-generated rest api
8+
header_template = """# Auto-generated rest api
109
from .server import app, protected_route
1110
from .interface import _DJConnector, dj
12-
import json
13-
import numpy as np
14-
15-
16-
class NumpyEncoder(json.JSONEncoder):
17-
def default(self, obj):
18-
if isinstance(obj, np.ndarray):
19-
return obj.tolist()
20-
return json.JSONEncoder.default(self, obj)
11+
from flask import request
12+
from json import loads
13+
from base64 import b64decode
2114
"""
2215
route_template = """
2316
@@ -26,10 +19,44 @@ def default(self, obj):
2619
def {method_name}(jwt_payload: dict) -> dict:
2720
2821
{query}
29-
djconn = _DJConnector._set_datajoint_config(jwt_payload)
30-
vm_dict = {{s: dj.VirtualModule(s, s, connection=djconn) for s in dj.list_schemas()}}
31-
query, fetch_args = dj_query(vm_dict)
32-
return json.dumps(query.fetch(**fetch_args), cls=NumpyEncoder)
22+
{restriction}
23+
if request.method in {{'GET'}}:
24+
try:
25+
djconn = _DJConnector._set_datajoint_config(jwt_payload)
26+
vm_dict = {{s: dj.VirtualModule(s, s, connection=djconn)
27+
for s in dj.list_schemas()}}
28+
query, fetch_args = dj_query(vm_dict)
29+
query = query & restriction()
30+
record_header, table_tuples, total_count = _DJConnector._fetch_records(
31+
query=query,
32+
**{{k: (int(v) if k in ('limit', 'page')
33+
else (v.split(',') if k == 'order'
34+
else loads(b64decode(v.encode('utf-8')).decode('utf-8'))))
35+
for k, v in request.args.items()}},
36+
)
37+
return dict(recordHeader=record_header, records=table_tuples,
38+
totalCount=total_count)
39+
except Exception as e:
40+
return str(e), 500
41+
42+
43+
@app.route('{route}/attributes', methods=['GET'])
44+
@protected_route
45+
def {method_name}_attributes(jwt_payload: dict) -> dict:
46+
47+
{query}
48+
if request.method in {{'GET'}}:
49+
try:
50+
djconn = _DJConnector._set_datajoint_config(jwt_payload)
51+
vm_dict = {{s: dj.VirtualModule(s, s, connection=djconn)
52+
for s in dj.list_schemas()}}
53+
query, fetch_args = dj_query(vm_dict)
54+
attributes_meta = _DJConnector._get_attributes(query)
55+
56+
return dict(attributeHeaders=attributes_meta['attribute_headers'],
57+
attributes=attributes_meta['attributes'])
58+
except Exception as e:
59+
return str(e), 500
3360
"""
3461

3562
spec_path = os.environ.get('API_SPEC_PATH')
@@ -45,4 +72,5 @@ def {method_name}(jwt_payload: dict) -> dict:
4572
for comp in grid['components'].values():
4673
f.write(route_template.format(route=comp['route'],
4774
method_name=comp['route'].replace('/', ''),
48-
query=indent(comp['dj_query'], ' ')))
75+
query=indent(comp['dj_query'], ' '),
76+
restriction=indent(comp['restriction'], ' ')))

pharus/interface.py

Lines changed: 22 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,14 @@ def _list_tables(jwt_payload: dict, schema_name: str) -> dict:
9898
return tables_dict_list
9999

100100
@staticmethod
101-
def _fetch_records(jwt_payload: dict, schema_name: str, table_name: str,
101+
def _fetch_records(query,
102102
restriction: list = [], limit: int = 1000, page: int = 1,
103103
order=['KEY ASC']) -> tuple:
104104
"""
105-
Get records from table.
105+
Get records from query.
106106
107-
:param jwt_payload: Dictionary containing databaseAddress, username, and password
108-
strings
109-
:type jwt_payload: dict
110-
:param schema_name: Name of schema
111-
:type schema_name: str
112-
:param table_name: Table name under the given schema; must be in camel case
113-
:type table_name: str
107+
:param query: any datajoint object related to QueryExpression
108+
:type query: datajoint ``QueryExpression`` or related object
114109
:param restriction: Sequence of filters as ``dict`` with ``attributeName``,
115110
``operation``, ``value`` keys defined, defaults to ``[]``
116111
:type restriction: list, optional
@@ -125,20 +120,17 @@ def _fetch_records(jwt_payload: dict, schema_name: str, table_name: str,
125120
can be paged
126121
:rtype: tuple
127122
"""
128-
_DJConnector._set_datajoint_config(jwt_payload)
129-
130-
schema_virtual_module = dj.create_virtual_module(schema_name, schema_name)
131123

132124
# Get table object from name
133-
table = _DJConnector._get_table_object(schema_virtual_module, table_name)
134-
attributes = table.heading.attributes
125+
attributes = query.heading.attributes
135126
# Fetch tuples without blobs as dict to be used to create a
136127
# list of tuples for returning
137-
query = table & dj.AndList([
128+
query_restricted = query & dj.AndList([
138129
_DJConnector._filter_to_restriction(f, attributes[f['attributeName']].type)
139130
for f in restriction])
140-
non_blobs_rows = query.fetch(*table.heading.non_blobs, as_dict=True, limit=limit,
141-
offset=(page-1)*limit, order_by=order)
131+
non_blobs_rows = query_restricted.fetch(*query.heading.non_blobs, as_dict=True,
132+
limit=limit, offset=(page-1)*limit,
133+
order_by=order)
142134

143135
# Buffer list to be return
144136
rows = []
@@ -181,45 +173,34 @@ def _fetch_records(jwt_payload: dict, schema_name: str, table_name: str,
181173

182174
# Add the row list to tuples
183175
rows.append(row)
184-
return list(attributes.keys()), rows, len(query)
176+
return list(attributes.keys()), rows, len(query_restricted)
185177

186178
@staticmethod
187-
def _get_table_attributes(jwt_payload: dict, schema_name: str, table_name: str) -> dict:
179+
def _get_attributes(query) -> dict:
188180
"""
189-
Method to get primary and secondary attributes of a table.
181+
Method to get primary and secondary attributes of a query.
190182
191-
:param jwt_payload: Dictionary containing databaseAddress, username, and password
192-
strings
193-
:type jwt_payload: dict
194-
:param schema_name: Name of schema to list all tables from
195-
:type schema_name: str
196-
:param table_name: Table name under the given schema; must be in camel case
197-
:type table_name: str
183+
:param query: any datajoint object related to QueryExpression
184+
:type query: datajoint ``QueryExpression`` or related object
198185
:return: Dict with keys ``attribute_headers`` and ``attributes`` containing
199186
``primary``, ``secondary`` which each contain a
200187
``list`` of ``tuples`` specifying: ``attribute_name``, ``type``, ``nullable``,
201188
``default``, ``autoincrement``.
202189
:rtype: dict
203190
"""
204-
_DJConnector._set_datajoint_config(jwt_payload)
205-
local_values = locals()
206-
local_values[schema_name] = dj.VirtualModule(schema_name, schema_name)
207-
208-
# Get table object from name
209-
table = _DJConnector._get_table_object(local_values[schema_name], table_name)
210191

211-
table_attributes = dict(primary=[], secondary=[])
212-
for attribute_name, attribute_info in table.heading.attributes.items():
192+
query_attributes = dict(primary=[], secondary=[])
193+
for attribute_name, attribute_info in query.heading.attributes.items():
213194
if attribute_info.in_key:
214-
table_attributes['primary'].append((
195+
query_attributes['primary'].append((
215196
attribute_name,
216197
attribute_info.type,
217198
attribute_info.nullable,
218199
attribute_info.default,
219200
attribute_info.autoincrement
220201
))
221202
else:
222-
table_attributes['secondary'].append((
203+
query_attributes['secondary'].append((
223204
attribute_name,
224205
attribute_info.type,
225206
attribute_info.nullable,
@@ -229,7 +210,7 @@ def _get_table_attributes(jwt_payload: dict, schema_name: str, table_name: str)
229210

230211
return dict(attribute_headers=['name', 'type', 'nullable',
231212
'default', 'autoincrement'],
232-
attributes=table_attributes)
213+
attributes=query_attributes)
233214

234215
@staticmethod
235216
def _get_table_definition(jwt_payload: dict, schema_name: str, table_name: str) -> str:
@@ -270,7 +251,7 @@ def _insert_tuple(jwt_payload: dict, schema_name: str, table_name: str,
270251
"""
271252
_DJConnector._set_datajoint_config(jwt_payload)
272253

273-
schema_virtual_module = dj.create_virtual_module(schema_name, schema_name)
254+
schema_virtual_module = dj.VirtualModule(schema_name, schema_name)
274255
getattr(schema_virtual_module, table_name).insert(tuple_to_insert)
275256

276257
@staticmethod
@@ -326,7 +307,7 @@ def _update_tuple(jwt_payload: dict, schema_name: str, table_name: str,
326307
"""
327308
conn = _DJConnector._set_datajoint_config(jwt_payload)
328309

329-
schema_virtual_module = dj.create_virtual_module(schema_name, schema_name)
310+
schema_virtual_module = dj.VirtualModule(schema_name, schema_name)
330311
with conn.transaction:
331312
[getattr(schema_virtual_module, table_name).update1(t) for t in tuple_to_update]
332313

@@ -351,7 +332,7 @@ def _delete_records(jwt_payload: dict, schema_name: str, table_name: str,
351332
"""
352333
_DJConnector._set_datajoint_config(jwt_payload)
353334

354-
schema_virtual_module = dj.create_virtual_module(schema_name, schema_name)
335+
schema_virtual_module = dj.VirtualModule(schema_name, schema_name)
355336

356337
# Get table object from name
357338
table = _DJConnector._get_table_object(schema_virtual_module, table_name)

pharus/server.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Exposed REST API."""
22
from os import environ
3-
from .interface import _DJConnector
3+
from .interface import _DJConnector, dj
44
from . import __version__ as version
55
from typing import Callable
66
from functools import wraps
@@ -605,10 +605,15 @@ def record(jwt_payload: dict, schema_name: str, table_name: str) -> Union[dict,
605605
""")
606606
if request.method in {'GET', 'HEAD'}:
607607
try:
608+
_DJConnector._set_datajoint_config(jwt_payload)
609+
610+
schema_virtual_module = dj.VirtualModule(schema_name, schema_name)
611+
612+
# Get table object from name
613+
dj_table = _DJConnector._get_table_object(schema_virtual_module, table_name)
614+
608615
record_header, table_tuples, total_count = _DJConnector._fetch_records(
609-
jwt_payload=jwt_payload,
610-
schema_name=schema_name,
611-
table_name=table_name,
616+
query=dj_table,
612617
**{k: (int(v) if k in ('limit', 'page')
613618
else (v.split(',') if k == 'order' else loads(
614619
b64decode(v.encode('utf-8')).decode('utf-8'))))
@@ -895,8 +900,14 @@ def attribute(jwt_payload: dict, schema_name: str, table_name: str) -> dict:
895900
"""
896901
if request.method in {'GET', 'HEAD'}:
897902
try:
898-
attributes_meta = _DJConnector._get_table_attributes(jwt_payload, schema_name,
899-
table_name)
903+
_DJConnector._set_datajoint_config(jwt_payload)
904+
local_values = locals()
905+
local_values[schema_name] = dj.VirtualModule(schema_name, schema_name)
906+
907+
# Get table object from name
908+
dj_table = _DJConnector._get_table_object(local_values[schema_name], table_name)
909+
910+
attributes_meta = _DJConnector._get_attributes(dj_table)
900911
return dict(attributeHeaders=attributes_meta['attribute_headers'],
901912
attributes=attributes_meta['attributes'])
902913
except Exception as e:

pharus/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""Package metadata."""
2-
__version__ = '0.1.0'
2+
__version__ = '0.2.0'

tests/init/test_dynamic_api_spec.yaml

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,30 @@ SciViz: # top level tab
1212
- session_number
1313
hidden: true
1414
grids:
15-
grid2:
15+
grid1:
1616
components:
1717
component1:
1818
route: /query1
1919
row_span: 0
2020
column_span: 0
2121
type: plot:png
2222
restriction: >
23-
def restriction(**args):
24-
return dict(**args)
23+
def restriction(**kwargs):
24+
return dict(**kwargs)
2525
dj_query: >
2626
def dj_query(vms):
2727
TableA, TableB = (vms['test_group1_simple'].TableA, vms['test_group1_simple'].TableB)
2828
return TableA * TableB, dict(order_by='b_number')
29-
grid1:
29+
grid2:
3030
components:
3131
component1:
3232
route: /query2
3333
row_span: 0
3434
column_span: 0
3535
type: plot:png
3636
restriction: >
37-
def restriction(**args):
38-
return dict(**args)
37+
def restriction(**kwargs):
38+
return dict(**kwargs)
3939
dj_query: >
4040
def dj_query(vms):
4141
TableA, TableB = (vms['test_group1_simple'].TableA, vms['test_group1_simple'].TableB)
@@ -60,8 +60,8 @@ SciViz: # top level tab
6060
column_span: 0
6161
type: plot:png
6262
restriction: >
63-
def restriction(**args):
64-
return dict(**args)
63+
def restriction(**kwargs):
64+
return dict(**kwargs)
6565
dj_query: >
6666
def dj_query(vms):
6767
TableA, TableB = (vms['test_group1_simple'].TableA, vms['test_group1_simple'].TableB)
@@ -72,12 +72,28 @@ SciViz: # top level tab
7272
column_span: 1
7373
type: plot:plotly1
7474
restriction: >
75-
def restriction(**args):
76-
return dict(**args)
75+
def restriction(**kwargs):
76+
return dict(**kwargs)
77+
dj_query: >
78+
def dj_query(vms):
79+
TableA, TableB = (vms['test_group1_simple'].TableA, vms['test_group1_simple'].TableB)
80+
return TableA * TableB, dict(order_by='b_number')
81+
diff_checker: >
82+
def diff_checker(**args):
83+
return TrainingStatsPlotly.proj(hash='trial_mean_hash')
84+
component3:
85+
route: /query5
86+
row_span: 0
87+
column_span: 1
88+
type: plot:plotly1
89+
restriction: >
90+
def restriction(**kwargs):
91+
return dict(a_id=0, **kwargs)
7792
dj_query: >
7893
def dj_query(vms):
7994
TableA, TableB = (vms['test_group1_simple'].TableA, vms['test_group1_simple'].TableB)
8095
return TableA * TableB, dict(order_by='b_number')
8196
diff_checker: >
8297
def diff_checker(**args):
83-
return TrainingStatsPlotly.proj(hash='trial_mean_hash')
98+
return TrainingStatsPlotly.proj(hash='trial_mean_hash')
99+

0 commit comments

Comments
 (0)