Skip to content

Commit c17ee2c

Browse files
authored
Merge pull request #126 from A-Baji/multi-table-insert
Multi table insert
2 parents c3f5853 + 4745fe6 commit c17ee2c

File tree

6 files changed

+685
-57
lines changed

6 files changed

+685
-57
lines changed

docker-compose-dev.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ services:
3232
command: pharus
3333
fakeservices.datajoint.io:
3434
<<: *net
35-
image: datajoint/nginx:v0.0.18
35+
image: datajoint/nginx:v0.2.2
3636
environment:
3737
- ADD_pharus_TYPE=REST
3838
- ADD_pharus_ENDPOINT=pharus:5000

pharus/component_interface.py

Lines changed: 165 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import types
1313
import io
1414
import numpy as np
15+
from functools import reduce
1516

1617

1718
class NumpyEncoder(json.JSONEncoder):
@@ -44,9 +45,7 @@ def dumps(cls, obj):
4445
return json.dumps(obj, cls=cls)
4546

4647

47-
class QueryComponent:
48-
attributes_route_format = None
49-
48+
class FetchComponent:
5049
def __init__(self, name, component_config, static_config, jwt_payload: dict):
5150
lcls = locals()
5251
self.name = name
@@ -64,10 +63,6 @@ def __init__(self, name, component_config, static_config, jwt_payload: dict):
6463
self.route = component_config["route"]
6564
exec(component_config["dj_query"], globals(), lcls)
6665
self.dj_query = lcls["dj_query"]
67-
if self.attributes_route_format:
68-
self.attribute_route = self.attributes_route_format.format(
69-
route=component_config["route"]
70-
)
7166
if "restriction" in component_config:
7267
exec(component_config["restriction"], globals(), lcls)
7368
self.dj_restriction = lcls["restriction"]
@@ -114,8 +109,164 @@ def restriction(self):
114109
]
115110
)
116111

112+
def dj_query_route(self):
113+
fetch_metadata = self.fetch_metadata
114+
record_header, table_records, total_count = _DJConnector._fetch_records(
115+
query=fetch_metadata["query"] & self.restriction,
116+
fetch_args=fetch_metadata["fetch_args"],
117+
)
118+
return dict(
119+
recordHeader=record_header, records=table_records, totalCount=total_count
120+
)
121+
122+
123+
class InsertComponent:
124+
fields_route_format = "{route}/fields"
125+
126+
def __init__(
127+
self, name, component_config, static_config, payload, jwt_payload: dict
128+
):
129+
self.name = name
130+
self.payload = payload
131+
if static_config:
132+
self.static_variables = types.MappingProxyType(static_config)
133+
if not all(k in component_config for k in ("x", "y", "height", "width")):
134+
self.mode = "dynamic"
135+
else:
136+
self.mode = "fixed"
137+
self.x = component_config["x"]
138+
self.y = component_config["y"]
139+
self.height = component_config["height"]
140+
self.width = component_config["width"]
141+
self.type = component_config["type"]
142+
self.route = component_config["route"]
143+
self.connection = dj.conn(
144+
host=jwt_payload["databaseAddress"],
145+
user=jwt_payload["username"],
146+
password=jwt_payload["password"],
147+
reset=True,
148+
)
149+
self.fields_map = component_config.get("map")
150+
self.tables = [
151+
getattr(
152+
dj.VirtualModule(
153+
s,
154+
s,
155+
connection=self.connection,
156+
),
157+
t,
158+
)
159+
for s, t in (_.split(".") for _ in component_config["tables"])
160+
]
161+
self.parents = sorted(
162+
set(
163+
[
164+
p
165+
for t in self.tables
166+
for p in t.parents(as_objects=True)
167+
if p.full_table_name not in (t.full_table_name for t in self.tables)
168+
]
169+
),
170+
key=lambda p: p.full_table_name,
171+
)
172+
173+
def dj_query_route(self):
174+
with self.connection.transaction:
175+
destination_lookup = reduce(
176+
lambda m0, m1: dict(
177+
m0,
178+
**(
179+
{
180+
m_t["input"]
181+
if "input" in m_t
182+
else m_t["destination"]: m_t["destination"]
183+
for m_t in m1["map"]
184+
}
185+
if m1["type"] == "table"
186+
else {
187+
m1["input"]
188+
if "input" in m1
189+
else m1["destination"]: m1["destination"]
190+
}
191+
),
192+
),
193+
self.fields_map or [],
194+
{},
195+
)
196+
for t in self.tables:
197+
t.insert(
198+
[
199+
{
200+
a: v
201+
for k, v in r.items()
202+
if (a := destination_lookup.get(k, k))
203+
in t.heading.attributes
204+
}
205+
for r in self.payload["submissions"]
206+
]
207+
)
208+
return "Insert successful"
209+
210+
def fields_route(self):
211+
parent_attributes = sorted(set(sum([p.primary_key for p in self.parents], [])))
212+
source_fields = {
213+
**{
214+
(p_name := f"{p.database}.{dj.utils.to_camel_case(p.table_name)}"): {
215+
"values": p.fetch("KEY"),
216+
"type": "table",
217+
"name": p_name,
218+
}
219+
for p in self.parents
220+
},
221+
**{
222+
a: {"datatype": v.type, "type": "attribute", "name": v.name}
223+
for t in self.tables
224+
for a, v in t.heading.attributes.items()
225+
if a not in parent_attributes
226+
},
227+
}
228+
229+
if not self.fields_map:
230+
return dict(fields=list(source_fields.values()))
231+
return dict(
232+
fields=[
233+
dict(
234+
(field := source_fields.pop(m["destination"])),
235+
name=m["input" if "input" in m else "destination"],
236+
**(
237+
{
238+
"values": field["values"]
239+
if "map" not in m
240+
else [
241+
{
242+
input_lookup[k]: v
243+
for k, v in r.items()
244+
if k
245+
in (
246+
input_lookup := {
247+
table_m["destination"]: table_m[
248+
"input"
249+
if "input" in table_m
250+
else "destination"
251+
]
252+
for table_m in m["map"]
253+
}
254+
)
255+
}
256+
for r in field["values"]
257+
]
258+
}
259+
if m["type"] == "table"
260+
else {}
261+
),
262+
)
263+
for m in self.fields_map
264+
]
265+
+ list(source_fields.values())
266+
)
267+
117268

118-
class TableComponent(QueryComponent):
269+
class TableComponent(FetchComponent):
119270
attributes_route_format = "{route}/attributes"
120271

121272
def __init__(self, *args, **kwargs):
@@ -235,7 +386,7 @@ def dj_query_route(self):
235386
)
236387

237388

238-
class PlotPlotlyStoredjsonComponent(QueryComponent):
389+
class PlotPlotlyStoredjsonComponent(FetchComponent):
239390
def __init__(self, *args, **kwargs):
240391
super().__init__(*args, **kwargs)
241392
self.frontend_map = {
@@ -264,39 +415,7 @@ def dj_query_route(self):
264415
)
265416

266417

267-
class BasicQuery(QueryComponent):
268-
def __init__(self, *args, **kwargs):
269-
super().__init__(*args, **kwargs)
270-
self.frontend_map = {
271-
"source": "sci-viz/src/Components/Plots/FullPlotly.tsx",
272-
"target": "FullPlotly",
273-
}
274-
self.response_examples = {
275-
"dj_query_route": {
276-
"recordHeader": ["subject_uuid", "session_start_time", "session_uuid"],
277-
"records": [
278-
[
279-
"00778394-c956-408d-8a6c-ca3b05a611d5",
280-
1565436299.0,
281-
"fb9bdf18-76be-452b-ac4e-21d5de3a6f9f",
282-
]
283-
],
284-
"totalCount": 1,
285-
},
286-
}
287-
288-
def dj_query_route(self):
289-
fetch_metadata = self.fetch_metadata
290-
record_header, table_records, total_count = _DJConnector._fetch_records(
291-
query=fetch_metadata["query"] & self.restriction,
292-
fetch_args=fetch_metadata["fetch_args"],
293-
)
294-
return dict(
295-
recordHeader=record_header, records=table_records, totalCount=total_count
296-
)
297-
298-
299-
class FileImageAttachComponent(QueryComponent):
418+
class FileImageAttachComponent(FetchComponent):
300419
def __init__(self, *args, **kwargs):
301420
super().__init__(*args, **kwargs)
302421
self.frontend_map = {
@@ -319,11 +438,12 @@ def dj_query_route(self):
319438

320439

321440
type_map = {
322-
"basicquery": BasicQuery,
441+
"basicquery": FetchComponent,
323442
"plot:plotly:stored_json": PlotPlotlyStoredjsonComponent,
324443
"table": TableComponent,
325444
"metadata": MetadataComponent,
326445
"file:image:attach": FileImageAttachComponent,
327-
"slider": BasicQuery,
328-
"dropdown-query": BasicQuery,
446+
"slider": FetchComponent,
447+
"dropdown-query": FetchComponent,
448+
"form": InsertComponent,
329449
}

pharus/dynamic_api_gen.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import json
66
import re
77

8+
from pharus.component_interface import InsertComponent, TableComponent
9+
810

911
def populate_api():
1012
header_template = """# Auto-generated rest api
@@ -24,25 +26,26 @@ def populate_api():
2426
"""
2527
route_template = """
2628
27-
@app.route('{route}', methods=['GET'])
29+
@app.route('{route}', methods=['{rest_verb}'])
2830
@protected_route
2931
def {method_name}(jwt_payload: dict) -> dict:
3032
31-
if request.method in {{'GET'}}:
33+
if request.method in ['{rest_verb}']:
3234
try:
3335
component_instance = type_map['{component_type}'](name='{component_name}',
3436
component_config={component},
3537
static_config={static_config},
36-
jwt_payload=jwt_payload)
38+
jwt_payload=jwt_payload,
39+
{payload})
3740
return component_instance.{method_name_type}()
3841
except Exception as e:
3942
return traceback.format_exc(), 500
4043
"""
4144
route_template_nologin = """
4245
43-
@app.route('{route}', methods=['GET'])
46+
@app.route('{route}', methods=['{rest_verb}'])
4447
def {method_name}() -> dict:
45-
if request.method in {{'GET'}}:
48+
if request.method in ['{rest_verb}']:
4649
jwt_payload = dict(
4750
databaseAddress=os.environ["PHARUS_HOST"],
4851
username=os.environ["PHARUS_USER"],
@@ -52,7 +55,8 @@ def {method_name}() -> dict:
5255
component_instance = type_map['{component_type}'](name='{component_name}',
5356
component_config={component},
5457
static_config={static_config},
55-
jwt_payload=jwt_payload)
58+
jwt_payload=jwt_payload,
59+
{payload})
5660
return component_instance.{method_name_type}()
5761
except Exception as e:
5862
return traceback.format_exc(), 500
@@ -99,11 +103,13 @@ def {method_name}() -> dict:
99103
f.write(
100104
(active_route_template).format(
101105
route=grid["route"],
106+
rest_verb="GET",
102107
method_name=grid["route"].replace("/", ""),
103108
component_type="basicquery",
104109
component_name="dynamicgrid",
105110
component=json.dumps(grid),
106111
static_config=static_config,
112+
payload="",
107113
method_name_type="dj_query_route",
108114
)
109115
)
@@ -114,32 +120,55 @@ def {method_name}() -> dict:
114120
else grid["components"]
115121
).items():
116122
if re.match(
117-
r"^(table|metadata|plot|file|slider|dropdown-query).*$",
123+
r"^(table|metadata|plot|file|slider|dropdown-query|form).*$",
118124
comp["type"],
119125
):
120126
f.write(
121127
(active_route_template).format(
122128
route=comp["route"],
129+
rest_verb="POST" if comp["type"] == "form" else "GET",
123130
method_name=comp["route"].replace("/", ""),
124131
component_type=comp["type"],
125132
component_name=comp_name,
126133
component=json.dumps(comp),
127134
static_config=static_config,
135+
payload="payload=request.get_json()"
136+
if comp["type"] == "form"
137+
else "",
128138
method_name_type="dj_query_route",
129139
)
130140
)
131-
if type_map[comp["type"]].attributes_route_format:
141+
if issubclass(type_map[comp["type"]], InsertComponent):
142+
fields_route = type_map[
143+
comp["type"]
144+
].fields_route_format.format(route=comp["route"])
145+
f.write(
146+
(active_route_template).format(
147+
route=fields_route,
148+
rest_verb="GET",
149+
method_name=fields_route.replace("/", ""),
150+
component_type=comp["type"],
151+
component_name=comp_name,
152+
component=json.dumps(comp),
153+
static_config=static_config,
154+
payload="payload=None",
155+
method_name_type="fields_route",
156+
)
157+
)
158+
elif issubclass(type_map[comp["type"]], TableComponent):
132159
attributes_route = type_map[
133160
comp["type"]
134161
].attributes_route_format.format(route=comp["route"])
135162
f.write(
136163
(active_route_template).format(
137164
route=attributes_route,
165+
rest_verb="GET",
138166
method_name=attributes_route.replace("/", ""),
139167
component_type=comp["type"],
140168
component_name=comp_name,
141169
component=json.dumps(comp),
142170
static_config=static_config,
171+
payload="",
143172
method_name_type="attributes_route",
144173
)
145174
)

0 commit comments

Comments
 (0)