Skip to content

Commit a897809

Browse files
committed
throw appropriate errors for 404
1 parent bc855cd commit a897809

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

pharus/error.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,15 @@ class InvalidRestriction(Exception):
1111
"""Exception raised when restrictions result in no records when expected at least one."""
1212

1313
pass
14+
15+
16+
class SchemaNotFound(Exception):
17+
"""Exception raised when a given schema is not found to exist"""
18+
19+
pass
20+
21+
22+
class TableNotFound(Exception):
23+
"""Exception raised when a given table is not found to exist"""
24+
25+
pass

pharus/interface.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
"""Library for interfaces into DataJoint pipelines."""
22
import datajoint as dj
3+
from datajoint import DataJointError
34
from datajoint.utils import to_camel_case
45
from datajoint.user_tables import UserTable
56
from datajoint import VirtualModule
67
import datetime
78
import numpy as np
89
import re
9-
from .error import InvalidRestriction, UnsupportedTableType
10+
from .error import (
11+
InvalidRestriction,
12+
UnsupportedTableType,
13+
SchemaNotFound,
14+
TableNotFound,
15+
)
1016

1117
DAY = 24 * 60 * 60
1218
DEFAULT_FETCH_LIMIT = 1000 # Stop gap measure to deal with super large tables
@@ -60,9 +66,11 @@ def _list_tables(
6066
"""
6167

6268
# Get list of tables names
63-
tables_name = dj.Schema(
64-
schema_name, create_schema=False, connection=connection
65-
).list_tables()
69+
try:
70+
schema = dj.Schema(schema_name, create_schema=False, connection=connection)
71+
except DataJointError:
72+
raise SchemaNotFound("Schema does not exist")
73+
tables_name = schema.list_tables()
6674
# Dict to store list of table name for each type
6775
tables_dict_list = dict(manual=[], lookup=[], computed=[], imported=[], part=[])
6876
# Loop through each table name to figure out what type it is and add them to
@@ -452,12 +460,16 @@ def _get_table_object(
452460

453461
# Split the table name by '.' for dealing with part tables
454462
table_name_parts = table_name.split(".")
455-
if len(table_name_parts) == 2:
456-
return getattr(
457-
getattr(schema_virtual_module, table_name_parts[0]), table_name_parts[1]
458-
)
459-
else:
460-
return getattr(schema_virtual_module, table_name_parts[0])
463+
try:
464+
if len(table_name_parts) == 2:
465+
return getattr(
466+
getattr(schema_virtual_module, table_name_parts[0]),
467+
table_name_parts[1],
468+
)
469+
else:
470+
return getattr(schema_virtual_module, table_name_parts[0])
471+
except AttributeError:
472+
raise TableNotFound("Table does not exist")
461473

462474
@staticmethod
463475
def _filter_to_restriction(attribute_filter: dict, attribute_type: str) -> str:

0 commit comments

Comments
 (0)