|
1 | 1 | """Library for interfaces into DataJoint pipelines.""" |
2 | 2 | import datajoint as dj |
| 3 | +from datajoint import DataJointError |
3 | 4 | from datajoint.utils import to_camel_case |
4 | 5 | from datajoint.user_tables import UserTable |
5 | 6 | from datajoint import VirtualModule |
6 | 7 | import datetime |
7 | 8 | import numpy as np |
8 | 9 | import re |
9 | | -from .error import InvalidRestriction, UnsupportedTableType |
| 10 | +from .error import ( |
| 11 | + InvalidRestriction, |
| 12 | + UnsupportedTableType, |
| 13 | + SchemaNotFound, |
| 14 | + TableNotFound, |
| 15 | +) |
10 | 16 |
|
11 | 17 | DAY = 24 * 60 * 60 |
12 | 18 | DEFAULT_FETCH_LIMIT = 1000 # Stop gap measure to deal with super large tables |
@@ -60,9 +66,11 @@ def _list_tables( |
60 | 66 | """ |
61 | 67 |
|
62 | 68 | # Get list of tables names |
63 | | - tables_name = dj.Schema( |
64 | | - schema_name, create_schema=False, connection=connection |
65 | | - ).list_tables() |
| 69 | + try: |
| 70 | + schema = dj.Schema(schema_name, create_schema=False, connection=connection) |
| 71 | + except DataJointError: |
| 72 | + raise SchemaNotFound("Schema does not exist") |
| 73 | + tables_name = schema.list_tables() |
66 | 74 | # Dict to store list of table name for each type |
67 | 75 | tables_dict_list = dict(manual=[], lookup=[], computed=[], imported=[], part=[]) |
68 | 76 | # Loop through each table name to figure out what type it is and add them to |
@@ -452,12 +460,16 @@ def _get_table_object( |
452 | 460 |
|
453 | 461 | # Split the table name by '.' for dealing with part tables |
454 | 462 | table_name_parts = table_name.split(".") |
455 | | - if len(table_name_parts) == 2: |
456 | | - return getattr( |
457 | | - getattr(schema_virtual_module, table_name_parts[0]), table_name_parts[1] |
458 | | - ) |
459 | | - else: |
460 | | - return getattr(schema_virtual_module, table_name_parts[0]) |
| 463 | + try: |
| 464 | + if len(table_name_parts) == 2: |
| 465 | + return getattr( |
| 466 | + getattr(schema_virtual_module, table_name_parts[0]), |
| 467 | + table_name_parts[1], |
| 468 | + ) |
| 469 | + else: |
| 470 | + return getattr(schema_virtual_module, table_name_parts[0]) |
| 471 | + except AttributeError: |
| 472 | + raise TableNotFound("Table does not exist") |
461 | 473 |
|
462 | 474 | @staticmethod |
463 | 475 | def _filter_to_restriction(attribute_filter: dict, attribute_type: str) -> str: |
|
0 commit comments