88from  .expression  import  QueryExpression , AndList 
99from  .errors  import  DataJointError , LostConnectionError 
1010import  signal 
11+ import  multiprocessing  as  mp 
1112
1213# noinspection PyExceptionInherit,PyCallingNonCallable 
1314
1415logger  =  logging .getLogger (__name__ )
1516
1617
18+ # --- helper functions for multiprocessing -- 
19+ 
def _initialize_populate(table, jobs, populate_kwargs):
    """
    Initialize a worker process for multiprocessing.

    Saves the unpickled copy of the table, the jobs object, and the populate
    keyword arguments as attributes of the current process object, then
    reconnects by calling ``table.connection.connect()``.

    When used as a ``multiprocessing.Pool`` initializer, ``initargs`` must
    supply all three arguments in this order.

    :param table: table object whose ``_populate1`` is called in this process
    :param jobs: jobs object forwarded to ``_populate1``
    :param populate_kwargs: dict of keyword arguments forwarded to ``_populate1``
    """
    process = mp.current_process()
    # stash per-process state on the process object; read back by _call_populate1
    process.table = table
    process.jobs = jobs
    process.populate_kwargs = populate_kwargs
    table.connection.connect()  # reconnect
30+ 
31+ 
def _call_populate1(key):
    """
    Call ``_populate1`` on the table stored on the current process.

    Reads the ``table``, ``jobs`` and ``populate_kwargs`` attributes from the
    current process object (they are set by ``_initialize_populate``).

    :param key: a dict specifying the job to compute
    :return: the value returned by ``table._populate1``: (key, error) if an
        error was suppressed, otherwise None
    """
    process = mp.current_process()
    return process.table._populate1(key, process.jobs, **process.populate_kwargs)
40+ 
41+ 
1742class  AutoPopulate :
1843    """ 
1944    AutoPopulate is a mixin class that adds the method populate() to a Relation class. 
@@ -26,18 +51,20 @@ class AutoPopulate:
2651    @property  
2752    def  key_source (self ):
2853        """ 
29-         :return: the relation whose primary key values are passed, sequentially, to the 
30-                 ``make`` method when populate() is called. 
31-                 The default value is the join of the parent relations. 
32-                 Users may override to change the granularity or the scope of populate() calls. 
54+         :return: the query expression that yields primary key values to be passed, 
55+         sequentially, to the ``make`` method when populate() is called. 
56+         The default value is the join of the parent tables referenced from the primary key. 
57+         Subclasses may override the key_source to change the scope or the granularity 
58+         of the make calls. 
3359        """ 
3460        def  _rename_attributes (table , props ):
3561            return  (table .proj (
3662                ** {attr : ref  for  attr , ref  in  props ['attr_map' ].items () if  attr  !=  ref })
37-                 if  props ['aliased' ] else  table )
63+                 if  props ['aliased' ] else  table . proj () )
3864
3965        if  self ._key_source  is  None :
40-             parents  =  self .target .parents (primary = True , as_objects = True , foreign_key_info = True )
66+             parents  =  self .target .parents (
67+                 primary = True , as_objects = True , foreign_key_info = True )
4168            if  not  parents :
4269                raise  DataJointError ('A table must have dependencies ' 
4370                                     'from its primary key for auto-populate to work' )
@@ -48,17 +75,19 @@ def _rename_attributes(table, props):
4875
4976    def  make (self , key ):
5077        """ 
51-         Derived classes must implement method `make` that fetches data from tables that are  
52-         above them in the dependency hierarchy, restricting by the given key, computes dependent  
53-         attributes, and inserts the new tuples into self. 
78+         Derived classes must implement method `make` that fetches data from tables 
79+         above them in the dependency hierarchy, restricting by the given key, 
80+         computes secondary  attributes, and inserts the new tuples into self. 
5481        """ 
55-         raise  NotImplementedError ('Subclasses of AutoPopulate must implement the method `make`' )
82+         raise  NotImplementedError (
83+             'Subclasses of AutoPopulate must implement the method `make`' )
5684
5785    @property  
5886    def  target (self ):
5987        """ 
6088        :return: table to be populated. 
61-         In the typical case, dj.AutoPopulate is mixed into a dj.Table class by inheritance and the target is self. 
89+         In the typical case, dj.AutoPopulate is mixed into a dj.Table class by 
90+         inheritance and the target is self. 
6291        """ 
6392        return  self 
6493
@@ -85,41 +114,50 @@ def _jobs_to_do(self, restrictions):
85114
86115        if  not  isinstance (todo , QueryExpression ):
87116            raise  DataJointError ('Invalid key_source value' )
88-          # check if target lacks any attributes from the primary key of key_source 
117+ 
89118        try :
119+             # check if target lacks any attributes from the primary key of key_source 
90120            raise  DataJointError (
91-                 'The populate target lacks attribute %s from the primary key of key_source'  %  next (
92-                     name  for  name  in  todo .heading .primary_key  if  name  not  in   self .target .heading ))
121+                 'The populate target lacks attribute %s ' 
122+                 'from the primary key of key_source'  %  next (
123+                     name  for  name  in  todo .heading .primary_key 
124+                     if  name  not  in   self .target .heading ))
93125        except  StopIteration :
94126            pass 
95127        return  (todo  &  AndList (restrictions )).proj ()
96128
97129    def  populate (self , * restrictions , suppress_errors = False , return_exception_objects = False ,
98130                 reserve_jobs = False , order = "original" , limit = None , max_calls = None ,
99-                  display_progress = False , make_kwargs = None ):
131+                  display_progress = False , processes = 1 ,  make_kwargs = None ):
100132        """ 
101-         rel.populate() calls rel.make(key) for every primary key in self.key_source 
102-         for which there is not already a tuple in rel. 
103-         :param restrictions: a list of restrictions each restrict (rel.key_source - target.proj()) 
133+         ``table.populate()`` calls ``table.make(key)`` for every primary key in  
134+         ``self.key_source`` for which there is not already a tuple in table. 
135+          
136+         :param restrictions: a list of restrictions each restrict  
137+             (table.key_source - target.proj()) 
104138        :param suppress_errors: if True, do not terminate execution. 
105139        :param return_exception_objects: return error objects instead of just error messages 
106-         :param reserve_jobs: if true, reserves job  to populate in asynchronous fashion 
140+         :param reserve_jobs: if True, reserve jobs  to populate in asynchronous fashion 
107141        :param order: "original"|"reverse"|"random"  - the order of execution 
142+         :param limit: if not None, check at most this many keys 
143+         :param max_calls: if not None, populate at most this many keys 
108144        :param display_progress: if True, report progress_bar 
109-         :param limit: if not None, checks at most that many keys 
110-         :param max_calls: if not None, populates at max that many keys 
111-         :param make_kwargs: optional dict containing keyword arguments that will be passed down to each make() call 
145+         :param processes: number of processes to use. When greater than 1, the value is 
146+             capped at the number of keys to populate and at the number of CPU cores 
147+         :param make_kwargs: Keyword arguments which do not affect the result of computation  
148+             to be passed down to each ``make()`` call. Computation arguments should be  
149+             specified within the pipeline e.g. using a `dj.Lookup` table. 
150+         :type make_kwargs: dict, optional 
112151        """ 
113152        if  self .connection .in_transaction :
114153            raise  DataJointError ('Populate cannot be called during a transaction.' )
115154
116155        valid_order  =  ['original' , 'reverse' , 'random' ]
117156        if  order  not  in   valid_order :
118157            raise  DataJointError ('The order argument must be one of %s'  %  str (valid_order ))
119-         error_list  =  [] if  suppress_errors  else  None 
120158        jobs  =  self .connection .schemas [self .target .database ].jobs  if  reserve_jobs  else  None 
121159
122-         # define and setup  signal handler for SIGTERM 
160+         # define and set up  signal handler for SIGTERM:  
123161        if  reserve_jobs :
124162            def  handler (signum , frame ):
125163                logger .info ('Populate terminated by SIGTERM' )
@@ -132,60 +170,100 @@ def handler(signum, frame):
132170        elif  order  ==  "random" :
133171            random .shuffle (keys )
134172
135-         call_count  =  0 
136173        logger .info ('Found %d keys to populate'  %  len (keys ))
137174
138-         make  =  self ._make_tuples  if  hasattr (self , '_make_tuples' ) else  self .make 
175+         keys  =  keys [:max_calls ]
176+         nkeys  =  len (keys )
139177
140-         for  key  in  (tqdm (keys , desc = self .__class__ .__name__ ) if  display_progress  else  keys ):
141-             if  max_calls  is  not   None  and  call_count  >=  max_calls :
142-                 break 
143-             if  not  reserve_jobs  or  jobs .reserve (self .target .table_name , self ._job_key (key )):
144-                 self .connection .start_transaction ()
145-                 if  key  in  self .target :  # already populated 
146-                     self .connection .cancel_transaction ()
147-                     if  reserve_jobs :
148-                         jobs .complete (self .target .table_name , self ._job_key (key ))
178+         if  processes  >  1 :
179+             processes  =  min (processes , nkeys , mp .cpu_count ())
180+ 
181+         error_list  =  []
182+         populate_kwargs  =  dict (
183+             suppress_errors = suppress_errors ,
184+             return_exception_objects = return_exception_objects ,
185+             make_kwargs = make_kwargs )
186+ 
187+         if  processes  ==  1 :
188+             for  key  in  tqdm (keys , desc = self .__class__ .__name__ ) if  display_progress  else  keys :
189+                 error  =  self ._populate1 (key , jobs , ** populate_kwargs )
190+                 if  error  is  not   None :
191+                     error_list .append (error )
192+         else :
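193+             # spawn multiple processes. NOTE(review): the initargs tuple below has two items, but _initialize_populate takes (table, jobs, populate_kwargs) -- `jobs` appears to be missing 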
194+             self .connection .close ()  # disconnect parent process from MySQL server 
195+             del  self .connection ._conn .ctx   # SSLContext is not pickleable 
196+             with  mp .Pool (processes , _initialize_populate , (self , populate_kwargs )) as  pool :
197+                 if  display_progress :
198+                     with  tqdm (desc = "Processes: " , total = nkeys ) as  pbar :
199+                         for  error  in  pool .imap (_call_populate1 , keys , chunksize = 1 ):
200+                             if  error  is  not   None :
201+                                 error_list .append (error )
202+                             pbar .update ()
149203                else :
150-                     logger .info ('Populating: '  +  str (key ))
151-                     call_count  +=  1 
152-                     self .__class__ ._allow_insert  =  True 
153-                     try :
154-                         make (dict (key ), ** (make_kwargs  or  {}))
155-                     except  (KeyboardInterrupt , SystemExit , Exception ) as  error :
156-                         try :
157-                             self .connection .cancel_transaction ()
158-                         except  LostConnectionError :
159-                             pass 
160-                         error_message  =  '{exception}{msg}' .format (
161-                             exception = error .__class__ .__name__ ,
162-                             msg = ': '  +  str (error ) if  str (error ) else  '' )
163-                         if  reserve_jobs :
164-                             # show error name and error message (if any) 
165-                             jobs .error (
166-                                 self .target .table_name , self ._job_key (key ),
167-                                 error_message = error_message , error_stack = traceback .format_exc ())
168-                         if  not  suppress_errors  or  isinstance (error , SystemExit ):
169-                             raise 
170-                         else :
171-                             logger .error (error )
172-                             error_list .append ((key , error  if  return_exception_objects  else  error_message ))
173-                     else :
174-                         self .connection .commit_transaction ()
175-                         if  reserve_jobs :
176-                             jobs .complete (self .target .table_name , self ._job_key (key ))
177-                     finally :
178-                         self .__class__ ._allow_insert  =  False 
204+                     for  error  in  pool .imap (_call_populate1 , keys ):
205+                         if  error  is  not   None :
206+                             error_list .append (error )
207+             self .connection .connect ()  # reconnect parent process to MySQL server 
179208
180-         # place back the  original signal handler 
209+         # restore  original signal handler:  
181210        if  reserve_jobs :
182211            signal .signal (signal .SIGTERM , old_handler )
183-         return  error_list 
212+ 
213+         if  suppress_errors :
214+             return  error_list 
215+ 
216+     def  _populate1 (self , key , jobs , suppress_errors , return_exception_objects , make_kwargs = None ):
217+         """ 
218+         populates table for one source key, calling self.make inside a transaction. 
219+         :param jobs: the jobs table or None if not reserve_jobs 
220+         :param key: dict specifying job to populate 
221+         :param suppress_errors: bool if errors should be suppressed and returned 
222+         :param return_exception_objects: if True, errors must be returned as objects 
223+         :return: (key, error) when suppress_errors=True, otherwise None 
224+         """ 
225+         make  =  self ._make_tuples  if  hasattr (self , '_make_tuples' ) else  self .make 
226+ 
227+         if  jobs  is  None  or  jobs .reserve (self .target .table_name , self ._job_key (key )):
228+             self .connection .start_transaction ()
229+             if  key  in  self .target :  # already populated 
230+                 self .connection .cancel_transaction ()
231+                 if  jobs  is  not   None :
232+                     jobs .complete (self .target .table_name , self ._job_key (key ))
233+             else :
234+                 logger .info ('Populating: '  +  str (key ))
235+                 self .__class__ ._allow_insert  =  True 
236+                 try :
237+                     make (dict (key ), ** (make_kwargs  or  {}))
238+                 except  (KeyboardInterrupt , SystemExit , Exception ) as  error :
239+                     try :
240+                         self .connection .cancel_transaction ()
241+                     except  LostConnectionError :
242+                         pass 
243+                     error_message  =  '{exception}{msg}' .format (
244+                         exception = error .__class__ .__name__ ,
245+                         msg = ': '  +  str (error ) if  str (error ) else  '' )
246+                     if  jobs  is  not   None :
247+                         # show error name and error message (if any) 
248+                         jobs .error (
249+                             self .target .table_name , self ._job_key (key ),
250+                             error_message = error_message , error_stack = traceback .format_exc ())
251+                     if  not  suppress_errors  or  isinstance (error , SystemExit ):
252+                         raise 
253+                     else :
254+                         logger .error (error )
255+                         return  key , error  if  return_exception_objects  else  error_message 
256+                 else :
257+                     self .connection .commit_transaction ()
258+                     if  jobs  is  not   None :
259+                         jobs .complete (self .target .table_name , self ._job_key (key ))
260+                 finally :
261+                     self .__class__ ._allow_insert  =  False 
184262
185263    def  progress (self , * restrictions , display = True ):
186264        """ 
187-         report  progress of populating the table 
188-         :return: remaining, total -- tuples to be populated 
265+         Report the  progress of populating the table.  
266+         :return: ( remaining, total)  -- numbers of  tuples to be populated 
189267        """ 
190268        todo  =  self ._jobs_to_do (restrictions )
191269        total  =  len (todo )
@@ -194,5 +272,6 @@ def progress(self, *restrictions, display=True):
194272            print ('%-20s'  %  self .__class__ .__name__ ,
195273                  'Completed %d of %d (%2.1f%%)   %s'  %  (
196274                      total  -  remaining , total , 100  -  100  *  remaining  /  (total + 1e-12 ),
197-                       datetime .datetime .strftime (datetime .datetime .now (), '%Y-%m-%d %H:%M:%S' )), flush = True )
275+                       datetime .datetime .strftime (datetime .datetime .now (),
276+                                                  '%Y-%m-%d %H:%M:%S' )), flush = True )
198277        return  remaining , total 
0 commit comments