Skip to content

Commit 7250cf2

Browse files
committed
Updated database access code
1 parent a82a043 commit 7250cf2

File tree

16 files changed

+89
-104
lines changed

16 files changed

+89
-104
lines changed

.idea/inspectionProfiles/Project_Default.xml

Lines changed: 0 additions & 38 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

services/column-footprint-editor/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import logging
2+
from dotenv import load_dotenv
3+
24

35
disable_loggers = ["macrostrat.database.utils"]
46

@@ -8,3 +10,5 @@ def pytest_configure():
810
for logger_name in disable_loggers:
911
logger = logging.getLogger(logger_name)
1012
logger.disabled = True
13+
14+
load_dotenv()

services/column-footprint-editor/macrostrat/column_footprint_editor/api/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .home import HomePage
99
from .project import ProjectsAPI
1010
from .voronoi import VoronoiTesselator
11+
from ..settings import DATABASE
1112

1213
middleware = [
1314
Middleware(
@@ -36,6 +37,5 @@
3637
async def startup_event():
3738
from ..database import Database
3839

39-
# TODO: don't create tables on startup
40-
db = Database()
40+
db = Database(DATABASE)
4141
db.create_project_table()

services/column-footprint-editor/macrostrat/column_footprint_editor/api/column_groups.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from starlette.endpoints import HTTPEndpoint
55
from starlette.responses import JSONResponse
66

7-
from ..database import Database
87
from ..project import Project
8+
from ..settings import DATABASE
99

1010

1111
class ColumnGroups(HTTPEndpoint):
@@ -20,7 +20,7 @@ async def get(self, request):
2020
"""
2121

2222
project_id = request.path_params["project_id"]
23-
project = Project(project_id)
23+
project = Project(DATABASE, project_id)
2424

2525
if "id" in request.query_params:
2626
id_ = request.query_params["id"]
@@ -29,7 +29,7 @@ async def get(self, request):
2929
)
3030

3131
try:
32-
df = Database().exec_query(sql)
32+
df = project.db.exec_query(sql)
3333
col_groups = df.to_dict(orient="records")
3434

3535
return JSONResponse({"status": "success", "data": col_groups})
@@ -53,7 +53,7 @@ async def post(self, request):
5353
:col_group,:col_group_name,:color
5454
) """
5555

56-
project = Project(request.path_params["project_id"])
56+
project = Project(DATABASE, request.path_params["project_id"])
5757

5858
res = await request.json()
5959

@@ -82,7 +82,7 @@ async def put(self, request):
8282
color = :color
8383
WHERE cg.id = :col_group_id
8484
"""
85-
project = Project(request.path_params["project_id"])
85+
project = Project(DATABASE, request.path_params["project_id"])
8686

8787
res = await request.json()
8888

services/column-footprint-editor/macrostrat/column_footprint_editor/api/geometries.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import json
2-
import simplejson
32
from pathlib import Path
3+
4+
import simplejson
45
from starlette.endpoints import HTTPEndpoint
56
from starlette.responses import JSONResponse, PlainTextResponse
67

78
from .utils import clean_change_set
89
from ..database import Database
910
from ..project import Project
1011
from ..project.importer import ProjectImporter
12+
from ..settings import DATABASE
1113

1214
here = Path(__file__).parent / ".."
1315
procedures = here / "database" / "procedures"
@@ -18,8 +20,7 @@
1820
class Lines(HTTPEndpoint):
1921
async def get(self, request):
2022
project_id = request.path_params["project_id"]
21-
project = Project(project_id)
22-
db = Database(project)
23+
project = Project(DATABASE, project_id)
2324

2425
if "id" in request.query_params:
2526
id_ = request.query_params["id"]
@@ -41,7 +42,7 @@ async def get(self, request):
4142
q = queries / "get-linework.sql"
4243
sql = open(q).read()
4344

44-
df = db.exec_query(sql)
45+
df = project.db.exec_query(sql)
4546

4647
lines = []
4748
for i in range(0, len(df["lines"])):
@@ -72,7 +73,7 @@ async def put(self, request):
7273
data = await request.json()
7374

7475
project_id = data["project_id"]
75-
project = Project(project_id)
76+
project = Project(DATABASE, project_id)
7677
db = Database(project)
7778

7879
new_change_set = clean_change_set(data["change_set"])
@@ -107,13 +108,12 @@ async def put(self, request):
107108
class Points(HTTPEndpoint):
108109
async def get(self, request):
109110
project_id = request.path_params["project_id"]
110-
project = Project(project_id)
111-
db = Database(project)
111+
project = Project(DATABASE, project_id)
112112

113113
q = queries / "get-points.sql"
114114
sql = open(q).read()
115115

116-
df = db.exec_query(sql)
116+
df = project.db.exec_query(sql)
117117
cols = df.to_dict(orient="records")
118118
cols = json.loads(simplejson.dumps(cols, ignore_nan=True))
119119

@@ -159,7 +159,7 @@ async def get_line(request):
159159

160160
sql = f"SELECT ST_AsGeoJSON(((ST_Dump(ST_Boundary({location_parser}))).geom))"
161161

162-
db = Database()
162+
db = Database(DATABASE)
163163
df = db.exec_query(sql, params={"location": data})
164164
location = json.loads(df.to_dict(orient="records")[0]["st_asgeojson"])
165165

@@ -169,13 +169,12 @@ async def get_line(request):
169169
async def geometries(request):
170170

171171
project_id = request.path_params["project_id"]
172-
project = Project(project_id)
173-
db = Database(project)
172+
project = Project(DATABASE, project_id)
174173

175174
q = queries / "get-topology-columns.sql"
176175
sql = open(q).read()
177176

178-
df = db.exec_query(sql)
177+
df = project.db.exec_query(sql)
179178
cols = df.to_dict(orient="records")
180179
cols = json.loads(simplejson.dumps(cols, ignore_nan=True))
181180

services/column-footprint-editor/macrostrat/column_footprint_editor/api/project.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from os import error
2-
31
import json
2+
from os import error
43
from pathlib import Path
4+
55
from starlette.endpoints import HTTPEndpoint
66
from starlette.responses import JSONResponse
77

88
from ..database import Database
99
from ..project import Project
10+
from ..settings import DATABASE
1011

1112
here = Path(__file__).parent / ".."
1213
procedures = here / "database" / "procedures"
@@ -18,7 +19,7 @@ class ProjectsAPI(HTTPEndpoint):
1819

1920
async def get(self, request):
2021
"""endpoint to get availble projects"""
21-
db = Database()
22+
db = Database(DATABASE)
2223

2324
project_data = db.get_project_info()
2425

@@ -58,7 +59,7 @@ async def put(self, request):
5859
async def post(self, request):
5960
res = await request.json()
6061

61-
db = Database()
62+
db = Database(DATABASE)
6263
next_id = db.get_next_project_id()
6364

6465
params = res["data"]

services/column-footprint-editor/macrostrat/column_footprint_editor/api/voronoi.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from collections import defaultdict
2-
31
import json
2+
from collections import defaultdict
43
from pathlib import Path
4+
55
from starlette.endpoints import HTTPEndpoint
66
from starlette.responses import JSONResponse
77

88
from ..database import Database
99
from ..project import Project
10+
from ..settings import DATABASE
1011

1112
here = Path(__file__).parent / ".."
1213
procedures = here / "database" / "procedures"
@@ -112,30 +113,28 @@ def tesselate(self, db: Database, points, radius, quad_segs):
112113

113114
async def put(self, request):
114115
project_id = request.path_params["project_id"]
115-
project = Project(project_id)
116-
db = Database(project)
116+
project = Project(DATABASE, project_id)
117117

118118
data = await request.json()
119119
points = data["points"]
120120
radius = data["radius"]
121121
quad_segs = data["quad_segs"]
122122

123-
polygons = self.tesselate(db, points, radius, quad_segs)
123+
polygons = self.tesselate(project.db, points, radius, quad_segs)
124124

125125
return JSONResponse({"Status": "Success", "polygons": polygons})
126126

127127
async def post(self, request):
128128
project_id = request.path_params["project_id"]
129-
project = Project(project_id)
130-
db = Database(project)
129+
project = Project(DATABASE, project_id)
131130

132131
data = await request.json()
133132
points = data["points"]
134133

135134
radius = data["radius"]
136135
quad_segs = data["quad_segs"]
137136

138-
polygons = self.tesselate(db, points, radius, quad_segs)
137+
polygons = self.tesselate(project.db, points, radius, quad_segs)
139138
## dump each polygon to multilinestring and insert!!
140139
sql = open(self.dump_voronoi_to_lines).read()
141140

services/column-footprint-editor/macrostrat/column_footprint_editor/database/__init__.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from macrostrat.database import Database as BaseDatabase
22
from macrostrat.database.utils import get_sql_text
33
from pathlib import Path
4-
from sqlalchemy import create_engine
54
from sqlalchemy.orm import sessionmaker
65

76
from .sql_formatter import SqlFormatter
8-
from ..settings import DATABASE
97
from ..utils import run_topology_command, delete_config
108

119
here = Path(__file__).parent
@@ -30,11 +28,10 @@ class Database(BaseDatabase):
3028
Database class with built in SQL Formatter
3129
"""
3230

33-
def __init__(self, project=None):
31+
def __init__(self, url, project=None):
3432
self.project_id = getattr(project, "id", None)
35-
super().__init__(DATABASE, echo_sql=True)
33+
super().__init__(url)
3634

37-
self.engine = create_engine(DATABASE, echo=True)
3835
self.Session = sessionmaker(bind=self.engine)
3936
# self.config = config_check(project)
4037
self.formatter = SqlFormatter(self.project_id)
@@ -110,7 +107,7 @@ def redump_linework_from_edge(self):
110107
self.run_sql_file(redump_linework_sql)
111108

112109
def remove_project(self, params={}):
113-
run_topology_command(self.project_id, "delete") # delete topology
110+
run_topology_command(self, self.project_id, "delete") # delete topology
114111
self.run_sql_file(remove_project_schema, params={"project_id": self.project_id})
115112
delete_config(self.project_id) # remove config file
116113

services/column-footprint-editor/macrostrat/column_footprint_editor/project/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import requests
22
from pathlib import Path
3+
from psycopg2.sql import Identifier
34

45
from mapboard.topology_manager.database import _get_instance_params
56
from ..database import Database
@@ -12,24 +13,27 @@
1213
class Project:
1314
"""Helper class to pass around project attributes"""
1415

15-
def __init__(self, id_: int = None, name: str = "", description: str = "") -> None:
16+
def __init__(
17+
self, db_url: str, id_: int = None, *, name: str = "", description: str = ""
18+
) -> None:
1619
self.id = id_
1720
self.name = name
1821
self.description = description
19-
self.db = Database(self)
22+
self.db = Database(db_url, project=self)
2023

2124
params = _get_instance_params(
2225
data_schema=f"project_{self.id}_data",
2326
topo_schema=f"project_{self.id}_topology",
2427
tolerance=0.0001,
2528
)
26-
params["project_schema"] = f"project_{self.id}"
29+
params["project_schema"] = Identifier(f"project_{self.id}")
2730

2831
self.db.instance_params = params
2932

3033
self.base_url = IMPORTER_API
3134

3235
def create_new_project(self):
36+
self.db.create_project_table()
3337
if not self.project_in_db():
3438
self.id = self.db.get_next_project_id()
3539
self.insert_project_info()

services/column-footprint-editor/macrostrat/column_footprint_editor/project/importer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import json
2+
23
import requests
34

45
from . import Project
5-
from ..settings import IMPORTER_API
6+
from ..settings import IMPORTER_API, DATABASE
67

78

89
class ProjectImporter:
@@ -22,7 +23,7 @@ class ProjectImporter:
2223
"""
2324

2425
def __init__(self, project_id: int, name: str, description: str):
25-
self.project = Project(project_id, name, description)
26+
self.project = Project(DATABASE, project_id, name=name, description=description)
2627
self.db = self.project.db
2728
self.base_url = IMPORTER_API
2829

0 commit comments

Comments
 (0)