Skip to content

Commit 20a63f2

Browse files
Merge pull request #115 from jverswijver/update_dynamic_api
Update dynamic api to support plots
2 parents c0ce77c + 15a8be5 commit 20a63f2

File tree

7 files changed

+113
-18
lines changed

7 files changed

+113
-18
lines changed

CHANGELOG.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
# Changelog
22

33
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
4+
## [0.2.3] - 2021-11-18
5+
### Added
6+
- Support for plot component PR #155
7+
- Fetch argument specification in `dj_query` PR #155
48

59
## [0.2.2] - 2021-11-10
610
### Fixed
7-
- Optimize dynamic api virtual modules.
11+
- Optimize dynamic api virtual modules. PR #113
812

913
## [0.2.1] - 2021-11-08
1014
### Fixed
11-
- Error with retrieving the module's installation root path.
15+
- Error with retrieving the module's installation root path. PR #112
1216

1317
## [0.2.0] - 2021-11-02
1418
### Added
@@ -84,6 +88,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
8488
- Support for DataJoint attribute types: `varchar`, `int`, `float`, `datetime`, `date`, `time`, `decimal`, `uuid`.
8589
- Check dependency utility to determine child table references.
8690

91+
[0.2.3]: https://github.com/datajoint/pharus/compare/0.2.2...0.2.3
8792
[0.2.2]: https://github.com/datajoint/pharus/compare/0.2.1...0.2.2
8893
[0.2.1]: https://github.com/datajoint/pharus/compare/0.2.0...0.2.1
8994
[0.2.0]: https://github.com/datajoint/pharus/compare/0.1.0...0.2.0

pharus/dynamic_api_gen.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ def {method_name}(jwt_payload: dict) -> dict:
2727
djconn = _DJConnector._set_datajoint_config(jwt_payload)
2828
vm_list = [dj.VirtualModule(s, s, connection=djconn)
2929
for s in inspect.getfullargspec(dj_query).args]
30-
query, fetch_args = dj_query(*vm_list)
31-
query = query & restriction()
30+
djdict = dj_query(*vm_list)
31+
djdict['query'] = djdict['query'] & restriction()
3232
record_header, table_tuples, total_count = _DJConnector._fetch_records(
33-
query=query,
33+
query=djdict['query'], fetch_args=djdict['fetch_args'],
3434
**{{k: (int(v) if k in ('limit', 'page')
3535
else (v.split(',') if k == 'order'
3636
else loads(b64decode(v.encode('utf-8')).decode('utf-8'))))
@@ -52,15 +52,37 @@ def {method_name}_attributes(jwt_payload: dict) -> dict:
5252
djconn = _DJConnector._set_datajoint_config(jwt_payload)
5353
vm_list = [dj.VirtualModule(s, s, connection=djconn)
5454
for s in inspect.getfullargspec(dj_query).args]
55-
query, fetch_args = dj_query(*vm_list)
56-
attributes_meta = _DJConnector._get_attributes(query)
55+
djdict = dj_query(*vm_list)
56+
attributes_meta = _DJConnector._get_attributes(djdict['query'])
5757
5858
return dict(attributeHeaders=attributes_meta['attribute_headers'],
5959
attributes=attributes_meta['attributes'])
6060
except Exception as e:
6161
return str(e), 500
6262
"""
6363

64+
plot_route_template = '''
65+
66+
@app.route('{route}', methods=['GET'])
67+
@protected_route
68+
def {method_name}(jwt_payload: dict) -> dict:
69+
70+
{query}
71+
{restriction}
72+
if request.method in {{'GET'}}:
73+
try:
74+
djconn = _DJConnector._set_datajoint_config(jwt_payload)
75+
vm_list = [dj.VirtualModule(s, s, connection=djconn)
76+
for s in inspect.getfullargspec(dj_query).args]
77+
djdict = dj_query(*vm_list)
78+
djdict['query'] = djdict['query'] & restriction()
79+
record_header, table_tuples, total_count = _DJConnector._fetch_records(
80+
fetch_args=djdict['fetch_args'], query=djdict['query'], fetch_blobs=True)
81+
return dict(table_tuples[0][0])
82+
except Exception as e:
83+
return str(e), 500
84+
'''
85+
6486
pharus_root = f"{pkg_resources.get_distribution('pharus').module_path}/pharus"
6587
api_path = f'{pharus_root}/dynamic_api.py'
6688
spec_path = os.environ.get('API_SPEC_PATH')
@@ -79,3 +101,8 @@ def {method_name}_attributes(jwt_payload: dict) -> dict:
79101
method_name=comp['route'].replace('/', ''),
80102
query=indent(comp['dj_query'], ' '),
81103
restriction=indent(comp['restriction'], ' ')))
104+
if comp['type'] == 'plot:plotly:stored_json':
105+
f.write(plot_route_template.format(route=comp['route'],
106+
method_name=comp['route'].replace('/', ''),
107+
query=indent(comp['dj_query'], ' '),
108+
restriction=indent(comp['restriction'], ' ')))

pharus/interface.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def _list_tables(jwt_payload: dict, schema_name: str) -> dict:
100100
@staticmethod
101101
def _fetch_records(query,
102102
restriction: list = [], limit: int = 1000, page: int = 1,
103-
order=['KEY ASC']) -> tuple:
103+
order=['KEY ASC'], fetch_blobs=False, fetch_args=[]) -> tuple:
104104
"""
105105
Get records from query.
106106
@@ -128,7 +128,15 @@ def _fetch_records(query,
128128
query_restricted = query & dj.AndList([
129129
_DJConnector._filter_to_restriction(f, attributes[f['attributeName']].type)
130130
for f in restriction])
131-
non_blobs_rows = query_restricted.fetch(*query.heading.non_blobs, as_dict=True,
131+
132+
if fetch_blobs and not fetch_args:
133+
fetch_args = [*query.heading.attributes]
134+
elif not fetch_args:
135+
fetch_args = query.heading.non_blobs
136+
else:
137+
attributes = {k: v for k, v in attributes.items() if k in fetch_args}
138+
139+
non_blobs_rows = query_restricted.fetch(*fetch_args, as_dict=True,
132140
limit=limit, offset=(page-1)*limit,
133141
order_by=order)
134142

@@ -169,8 +177,8 @@ def _fetch_records(query,
169177
row.append(non_blobs_row[attribute_name])
170178
else:
171179
# Attribute is blob type thus fill it in string instead
172-
row.append('=BLOB=')
173-
180+
(row.append(non_blobs_row[attribute_name])
181+
if fetch_blobs else row.append('=BLOB='))
174182
# Add the row list to tuples
175183
rows.append(row)
176184
return list(attributes.keys()), rows, len(query_restricted)

pharus/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""Package metadata."""
2-
__version__ = '0.2.2'
2+
__version__ = '0.2.3'

tests/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,23 @@ class TableC(dj.Lookup):
107107
"""
108108
contents = [(0, 10, 100, -8), (0, 11, 200, -9,), (0, 11, 300, -7,)]
109109

110+
@group1_simple
111+
class PlotlyTable(dj.Lookup):
112+
definition = """
113+
p_id: int
114+
---
115+
plot: longblob
116+
"""
117+
contents = [(2, dict(data=[dict(x=[1, 2, 3],
118+
y=[2, 6, 3],
119+
type='scatter',
120+
mode='lines+markers',
121+
marker=dict(color='red')),
122+
dict(type='bar',
123+
x=[1, 2, 3],
124+
y=[2, 5, 3])],
125+
layout=dict(title='A Fancy Plot')))]
126+
110127
yield group1_simple, group2_simple
111128

112129
group2_simple.drop()

tests/init/test_dynamic_api_spec.yaml

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ SciViz: # top level tab
2525
dj_query: >
2626
def dj_query(test_group1_simple):
2727
TableA, TableB = (test_group1_simple.TableA, test_group1_simple.TableB)
28-
return TableA * TableB, dict(order_by='b_number')
28+
q = TableA * TableB
29+
f = []
30+
return dict(query=q, fetch_args=f)
2931
grid2:
3032
components:
3133
component1:
@@ -39,7 +41,9 @@ SciViz: # top level tab
3941
dj_query: >
4042
def dj_query(test_group1_simple):
4143
TableA, TableB = (test_group1_simple.TableA, test_group1_simple.TableB)
42-
return TableA * TableB, dict(order_by='b_number')
44+
q = TableA * TableB
45+
f = []
46+
return dict(query=q, fetch_args=f)
4347
4448
page1:
4549
route: /session2 # dev, be careful of name collisions
@@ -65,7 +69,9 @@ SciViz: # top level tab
6569
dj_query: >
6670
def dj_query(test_group1_simple):
6771
TableA, TableB = (test_group1_simple.TableA, test_group1_simple.TableB)
68-
return TableA * TableB, dict(order_by='b_number')
72+
q = TableA * TableB
73+
f = []
74+
return dict(query=q, fetch_args=f)
6975
component2:
7076
route: /query4
7177
row_span: 0
@@ -77,7 +83,9 @@ SciViz: # top level tab
7783
dj_query: >
7884
def dj_query(test_group1_simple):
7985
TableA, TableB = (test_group1_simple.TableA, test_group1_simple.TableB)
80-
return TableA * TableB, dict(order_by='b_number')
86+
q = TableA * TableB
87+
f = []
88+
return dict(query=q, fetch_args=f)
8189
diff_checker: >
8290
def diff_checker(**args):
8391
return TrainingStatsPlotly.proj(hash='trial_mean_hash')
@@ -92,8 +100,24 @@ SciViz: # top level tab
92100
dj_query: >
93101
def dj_query(test_group1_simple):
94102
TableA, TableB = (test_group1_simple.TableA, test_group1_simple.TableB)
95-
return TableA * TableB, dict(order_by='b_number')
103+
q = TableA * TableB
104+
f = []
105+
return dict(query=q, fetch_args=f)
96106
diff_checker: >
97107
def diff_checker(**args):
98108
return TrainingStatsPlotly.proj(hash='trial_mean_hash')
99-
109+
plot_test:
110+
route: /plot1
111+
type: plot:plotly:stored_json
112+
x: 0
113+
y: 0
114+
height: 1
115+
width: 1
116+
restriction: >
117+
def restriction(**kwargs):
118+
return dict(**kwargs, p_id=2)
119+
dj_query: >
120+
def dj_query(test_group1_simple):
121+
PlotlyTable = test_group1_simple.PlotlyTable
122+
return dict(query=PlotlyTable(), fetch_args=['plot'])
123+

tests/test_api_gen.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,20 @@ def test_auto_generated_route(token, client, schemas_simple):
2222
assert expected_json == json.dumps(REST_response4.get_json(force=True), sort_keys=True)
2323

2424

25+
def test_get_full_plot(token, client, schemas_simple):
26+
REST_response1 = client.get('/plot1', headers=dict(Authorization=f'Bearer {token}'))
27+
expected_json = json.dumps(dict(data=[dict(x=[1, 2, 3],
28+
y=[2, 6, 3],
29+
type='scatter',
30+
mode='lines+markers',
31+
marker=dict(color='red')),
32+
dict(type='bar',
33+
x=[1, 2, 3],
34+
y=[2, 5, 3])],
35+
layout=dict(title='A Fancy Plot')), sort_keys=True)
36+
assert expected_json == json.dumps(REST_response1.get_json(force=True), sort_keys=True)
37+
38+
2539
def test_get_attributes(token, client, schemas_simple):
2640
REST_response = client.get('/query1/attributes',
2741
headers=dict(Authorization=f'Bearer {token}'))

0 commit comments

Comments
 (0)