77import logging
88import warnings
99from pymysql import OperationalError , InternalError , IntegrityError
10- from . import config , DataJointError
10+ from . import config
1111from .declare import declare
1212from .relational_operand import RelationalOperand
1313from .blob import pack
1414from .utils import user_choice
1515from .heading import Heading
16- from .settings import server_error_codes
16+ from .errors import server_error_codes , DataJointError , DuplicateError
1717from . import __version__ as version
1818
1919logger = logging .getLogger (__name__ )
@@ -42,7 +42,12 @@ def heading(self):
4242 if self ._heading is None :
4343 self ._heading = Heading () # instance-level heading
4444 if not self ._heading : # lazy loading of heading
45- self ._heading .init_from_database (self .connection , self .database , self .table_name )
45+ if self .connection is None :
46+ raise DataJointError (
47+ 'DataJoint class is missing a database connection. '
48+ 'Missing schema decorator on the class? (e.g. @schema)' )
49+ else :
50+ self ._heading .init_from_database (self .connection , self .database , self .table_name )
4651 return self ._heading
4752
4853 @property
@@ -172,7 +177,8 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
172177 fields = '`' + '`,`' .join (fields ) + '`' ,
173178 table = self .full_table_name ,
174179 select = rows .make_sql (select_fields = fields ),
175- duplicate = (' ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`' .format (pk = self .primary_key [0 ])
180+ duplicate = (' ON DUPLICATE KEY UPDATE `{pk}`={table}.`{pk}`' .format (
181+ table = self .full_table_name , pk = self .primary_key [0 ])
176182 if skip_duplicates else '' ))
177183 self .connection .query (query )
178184 return
@@ -282,10 +288,12 @@ def check_fields(fields):
282288 elif err .args [0 ] == server_error_codes ['unknown column' ]:
283289 # args[1] -> Unknown column 'extra' in 'field list'
284290 raise DataJointError (
285- '{} : To ignore extra fields, set ignore_extra_fields=True in insert.' .format (err .args [1 ])) from None
291+ '{} : To ignore extra fields, set ignore_extra_fields=True in insert.' .format (err .args [1 ])
292+ ) from None
286293 elif err .args [0 ] == server_error_codes ['duplicate entry' ]:
287- raise DataJointError (
288- '{} : To ignore duplicate entries, set skip_duplicates=True in insert.' .format (err .args [1 ])) from None
294+ raise DuplicateError (
295+ '{} : To ignore duplicate entries, set skip_duplicates=True in insert.' .format (err .args [1 ])
296+ ) from None
289297 else :
290298 raise
291299
@@ -434,11 +442,15 @@ def show_definition(self):
434442 logger .warning ('show_definition is deprecated. Use describe instead.' )
435443 return self .describe ()
436444
437- def describe (self , printout = True ):
445+ def describe (self , context = None , printout = True ):
438446 """
439447 :return: the definition string for the relation using DataJoint DDL.
440448 This does not yet work for aliased foreign keys.
441449 """
450+ if context is None :
451+ frame = inspect .currentframe ().f_back
452+ context = dict (frame .f_globals , ** frame .f_locals )
453+ del frame
442454 if self .full_table_name not in self .connection .dependencies :
443455 self .connection .dependencies .load ()
444456 parents = self .parents ()
@@ -460,14 +472,14 @@ def describe(self, printout=True):
460472 parents .pop (parent_name )
461473 if not parent_name .isdigit ():
462474 definition += '-> {class_name}\n ' .format (
463- class_name = lookup_class_name (parent_name , self . context ) or parent_name )
475+ class_name = lookup_class_name (parent_name , context ) or parent_name )
464476 else :
465477 # aliased foreign key
466478 parent_name = list (self .connection .dependencies .in_edges (parent_name ))[0 ][0 ]
467479 lst = [(attr , ref ) for attr , ref in fk_props ['attr_map' ].items () if ref != attr ]
468480 definition += '({attr_list}) -> {class_name}{ref_list}\n ' .format (
469481 attr_list = ',' .join (r [0 ] for r in lst ),
470- class_name = lookup_class_name (parent_name , self . context ) or parent_name ,
482+ class_name = lookup_class_name (parent_name , context ) or parent_name ,
471483 ref_list = ('' if len (attributes_thus_far ) - len (attributes_declared ) == 1
472484 else '(%s)' % ',' .join (r [1 ] for r in lst )))
473485 attributes_declared .update (fk_props ['attr_map' ])
@@ -540,25 +552,26 @@ def lookup_class_name(name, context, depth=3):
540552 while nodes :
541553 node = nodes .pop (0 )
542554 for member_name , member in node ['context' ].items ():
543- if inspect .isclass (member ) and issubclass (member , BaseRelation ):
544- if member .full_table_name == name : # found it!
545- return '.' .join ([node ['context_name' ], member_name ]).lstrip ('.' )
546- try : # look for part tables
547- parts = member ._ordered_class_members
548- except AttributeError :
549- pass # not a UserRelation -- cannot have part tables.
550- else :
551- for part in (getattr (member , p ) for p in parts if p [0 ].isupper () and hasattr (member , p )):
552- if inspect .isclass (part ) and issubclass (part , BaseRelation ) and part .full_table_name == name :
553- return '.' .join ([node ['context_name' ], member_name , part .__name__ ]).lstrip ('.' )
554- elif node ['depth' ] > 0 and inspect .ismodule (member ) and member .__name__ != 'datajoint' :
555- try :
556- nodes .append (
557- dict (context = dict (inspect .getmembers (member )),
558- context_name = node ['context_name' ] + '.' + member_name ,
559- depth = node ['depth' ]- 1 ))
560- except ImportError :
561- pass # could not import, so do not attempt
555+ if not member_name .startswith ('_' ): # skip IPython's implicit variables
556+ if inspect .isclass (member ) and issubclass (member , BaseRelation ):
557+ if member .full_table_name == name : # found it!
558+ return '.' .join ([node ['context_name' ], member_name ]).lstrip ('.' )
559+ try : # look for part tables
560+ parts = member ._ordered_class_members
561+ except AttributeError :
562+ pass # not a UserRelation -- cannot have part tables.
563+ else :
564+ for part in (getattr (member , p ) for p in parts if p [0 ].isupper () and hasattr (member , p )):
565+ if inspect .isclass (part ) and issubclass (part , BaseRelation ) and part .full_table_name == name :
566+ return '.' .join ([node ['context_name' ], member_name , part .__name__ ]).lstrip ('.' )
567+ elif node ['depth' ] > 0 and inspect .ismodule (member ) and member .__name__ != 'datajoint' :
568+ try :
569+ nodes .append (
570+ dict (context = dict (inspect .getmembers (member )),
571+ context_name = node ['context_name' ] + '.' + member_name ,
572+ depth = node ['depth' ]- 1 ))
573+ except ImportError :
574+ pass # could not import, so do not attempt
562575 return None
563576
564577
0 commit comments