2020    from  ecephys_spike_sorting .scripts .create_input_json  import  createInputJson 
2121    from  ecephys_spike_sorting .scripts .helpers  import  SpikeGLX_utils 
2222except  Exception  as  e :
23-     print (f'Error in  loading "ecephys_spike_sorting" package - { str (e )}  ' )
23+     print (f'Warning: Failed  loading "ecephys_spike_sorting" package - { str (e )}  ' )
2424
2525# import pykilosort package 
2626try :
2727    import  pykilosort 
2828except  Exception  as  e :
29-     print (f'Error in  loading "pykilosort" package - { str (e )}  ' )
29+     print (f'Warning: Failed  loading "pykilosort" package - { str (e )}  ' )
3030
3131
3232class  SGLXKilosortPipeline :
@@ -67,7 +67,6 @@ def __init__(
6767        ni_present = False ,
6868        ni_extract_string = None ,
6969    ):
70- 
7170        self ._npx_input_dir  =  pathlib .Path (npx_input_dir )
7271
7372        self ._ks_output_dir  =  pathlib .Path (ks_output_dir )
@@ -85,6 +84,13 @@ def __init__(
8584        self ._json_directory  =  self ._ks_output_dir  /  "json_configs" 
8685        self ._json_directory .mkdir (parents = True , exist_ok = True )
8786
87+         self ._module_input_json  =  (
88+             self ._json_directory  /  f"{ self ._npx_input_dir .name }  -input.json" 
89+         )
90+         self ._module_logfile  =  (
91+             self ._json_directory  /  f"{ self ._npx_input_dir .name }  -run_modules-log.txt" 
92+         )
93+ 
8894        self ._CatGT_finished  =  False 
8995        self .ks_input_params  =  None 
9096        self ._modules_input_hash  =  None 
@@ -223,20 +229,20 @@ def generate_modules_input_json(self):
223229            ** params ,
224230        )
225231
226-         self ._modules_input_hash  =  dict_to_uuid (self .ks_input_params )
232+         self ._modules_input_hash  =  dict_to_uuid (dict ( self ._params ,  KS2ver = self . _KS2ver ) )
227233
228-     def  run_modules (self ):
234+     def  run_modules (self ,  modules_to_run = None ):
229235        if  self ._run_CatGT  and  not  self ._CatGT_finished :
230236            self .run_CatGT ()
231237
232238        print ("---- Running Modules ----" )
233239        self .generate_modules_input_json ()
234240        module_input_json  =  self ._module_input_json .as_posix ()
235-         module_logfile  =  module_input_json . replace ( 
236-              "-input.json" ,  "-run_modules-log.txt" 
237-         ) 
241+         module_logfile  =  self . _module_logfile . as_posix () 
242+ 
243+         modules   =   modules_to_run   or   self . _modules 
238244
239-         for  module  in  self . _modules :
245+         for  module  in  modules :
240246            module_status  =  self ._get_module_status (module )
241247            if  module_status ["completion_time" ] is  not   None :
242248                continue 
@@ -312,13 +318,11 @@ def _update_module_status(self, updated_module_status={}):
312318        else :
313319            # handle cases of processing rerun on different parameters (the hash changes) 
314320            # delete outdated files 
315-             outdated_files   =   [
316-                 f 
321+             [
322+                 f . unlink () 
317323                for  f  in  self ._json_directory .glob ("*" )
318324                if  f .is_file () and  f .name  !=  self ._module_input_json .name 
319325            ]
320-             for  f  in  outdated_files :
321-                 f .unlink ()
322326
323327            modules_status  =  {
324328                module : {"start_time" : None , "completion_time" : None , "duration" : None }
@@ -371,14 +375,26 @@ def _update_total_duration(self):
371375            for  k , v  in  modules_status .items ()
372376            if  k  not  in   ("cumulative_execution_duration" , "total_duration" )
373377        )
378+ 
379+         for  m  in  self ._modules :
380+             first_start_time  =  modules_status [m ]["start_time" ]
381+             if  first_start_time  is  not   None :
382+                 break 
383+ 
384+         for  m  in  self ._modules [::- 1 ]:
385+             last_completion_time  =  modules_status [m ]["completion_time" ]
386+             if  last_completion_time  is  not   None :
387+                 break 
388+ 
389+         if  first_start_time  is  None  or  last_completion_time  is  None :
390+             return 
391+ 
374392        total_duration  =  (
375393            datetime .strptime (
376-                 modules_status [ self . _modules [ - 1 ]][ "completion_time" ] ,
394+                 last_completion_time ,
377395                "%Y-%m-%d %H:%M:%S.%f" ,
378396            )
379-             -  datetime .strptime (
380-                 modules_status [self ._modules [0 ]]["start_time" ], "%Y-%m-%d %H:%M:%S.%f" 
381-             )
397+             -  datetime .strptime (first_start_time , "%Y-%m-%d %H:%M:%S.%f" )
382398        ).total_seconds ()
383399        self ._update_module_status (
384400            {
@@ -414,7 +430,6 @@ class OpenEphysKilosortPipeline:
414430    def  __init__ (
415431        self , npx_input_dir : str , ks_output_dir : str , params : dict , KS2ver : str 
416432    ):
417- 
418433        self ._npx_input_dir  =  pathlib .Path (npx_input_dir )
419434
420435        self ._ks_output_dir  =  pathlib .Path (ks_output_dir )
@@ -426,7 +441,13 @@ def __init__(
426441        self ._json_directory  =  self ._ks_output_dir  /  "json_configs" 
427442        self ._json_directory .mkdir (parents = True , exist_ok = True )
428443
429-         self ._median_subtraction_status  =  {}
444+         self ._module_input_json  =  (
445+             self ._json_directory  /  f"{ self ._npx_input_dir .name }  -input.json" 
446+         )
447+         self ._module_logfile  =  (
448+             self ._json_directory  /  f"{ self ._npx_input_dir .name }  -run_modules-log.txt" 
449+         )
450+ 
430451        self .ks_input_params  =  None 
431452        self ._modules_input_hash  =  None 
432453        self ._modules_input_hash_fp  =  None 
@@ -451,9 +472,6 @@ def make_chanmap_file(self):
451472
452473    def  generate_modules_input_json (self ):
453474        self .make_chanmap_file ()
454-         self ._module_input_json  =  (
455-             self ._json_directory  /  f"{ self ._npx_input_dir .name }  -input.json" 
456-         )
457475
458476        continuous_file  =  self ._get_raw_data_filepaths ()
459477
@@ -497,35 +515,37 @@ def generate_modules_input_json(self):
497515            ** params ,
498516        )
499517
500-         self ._modules_input_hash  =  dict_to_uuid (self .ks_input_params )
518+         self ._modules_input_hash  =  dict_to_uuid (dict ( self ._params ,  KS2ver = self . _KS2ver ) )
501519
502-     def  run_modules (self ):
520+     def  run_modules (self ,  modules_to_run = None ):
503521        print ("---- Running Modules ----" )
504522        self .generate_modules_input_json ()
505523        module_input_json  =  self ._module_input_json .as_posix ()
506-         module_logfile  =  module_input_json .replace (
507-             "-input.json" , "-run_modules-log.txt" 
508-         )
524+         module_logfile  =  self ._module_logfile .as_posix ()
509525
510-         for  module  in  self ._modules :
526+         modules  =  modules_to_run  or  self ._modules 
527+ 
528+         for  module  in  modules :
511529            module_status  =  self ._get_module_status (module )
512530            if  module_status ["completion_time" ] is  not   None :
513531                continue 
514532
515-             if  module  ==  "median_subtraction"  and  self ._median_subtraction_status :
516-                 median_subtraction_status  =  self ._get_module_status (
517-                     "median_subtraction" 
518-                 )
519-                 median_subtraction_status ["duration" ] =  self ._median_subtraction_status [
520-                     "duration" 
521-                 ]
522-                 median_subtraction_status ["completion_time" ] =  datetime .strptime (
523-                     median_subtraction_status ["start_time" ], "%Y-%m-%d %H:%M:%S.%f" 
524-                 ) +  timedelta (seconds = median_subtraction_status ["duration" ])
525-                 self ._update_module_status (
526-                     {"median_subtraction" : median_subtraction_status }
533+             if  module  ==  "median_subtraction" :
534+                 median_subtraction_duration  =  (
535+                     self ._get_median_subtraction_duration_from_log ()
527536                )
528-                 continue 
537+                 if  median_subtraction_duration  is  not   None :
538+                     median_subtraction_status  =  self ._get_module_status (
539+                         "median_subtraction" 
540+                     )
541+                     median_subtraction_status ["duration" ] =  median_subtraction_duration 
542+                     median_subtraction_status ["completion_time" ] =  datetime .strptime (
543+                         median_subtraction_status ["start_time" ], "%Y-%m-%d %H:%M:%S.%f" 
544+                     ) +  timedelta (seconds = median_subtraction_status ["duration" ])
545+                     self ._update_module_status (
546+                         {"median_subtraction" : median_subtraction_status }
547+                     )
548+                     continue 
529549
530550            module_output_json  =  self ._get_module_output_json_filename (module )
531551            command  =  [
@@ -576,26 +596,11 @@ def _get_raw_data_filepaths(self):
576596        assert  "depth_estimation"  in  self ._modules 
577597        continuous_file  =  self ._ks_output_dir  /  "continuous.dat" 
578598        if  continuous_file .exists ():
579-             if  raw_ap_fp .stat ().st_mtime  <  continuous_file .stat ().st_mtime :
580-                 # if the copied continuous.dat was actually modified, 
581-                 # median_subtraction may have been completed - let's check 
582-                 module_input_json  =  self ._module_input_json .as_posix ()
583-                 module_logfile  =  module_input_json .replace (
584-                     "-input.json" , "-run_modules-log.txt" 
585-                 )
586-                 with  open (module_logfile , "r" ) as  f :
587-                     previous_line  =  "" 
588-                     for  line  in  f .readlines ():
589-                         if  line .startswith (
590-                             "ecephys spike sorting: median subtraction module" 
591-                         ) and  previous_line .startswith ("Total processing time:" ):
592-                             # regex to search for the processing duration - a float value 
593-                             duration  =  int (
594-                                 re .search ("\d+\.?\d+" , previous_line ).group ()
595-                             )
596-                             self ._median_subtraction_status ["duration" ] =  duration 
597-                             return  continuous_file 
598-                         previous_line  =  line 
599+             if  raw_ap_fp .stat ().st_mtime  ==  continuous_file .stat ().st_mtime :
600+                 return  continuous_file 
601+             else :
602+                 if  self ._module_logfile .exists ():
603+                     return  continuous_file 
599604
600605        shutil .copy2 (raw_ap_fp , continuous_file )
601606        return  continuous_file 
@@ -614,13 +619,11 @@ def _update_module_status(self, updated_module_status={}):
614619        else :
615620            # handle cases of processing rerun on different parameters (the hash changes) 
616621            # delete outdated files 
617-             outdated_files   =   [
618-                 f 
622+             [
623+                 f . unlink () 
619624                for  f  in  self ._json_directory .glob ("*" )
620625                if  f .is_file () and  f .name  !=  self ._module_input_json .name 
621626            ]
622-             for  f  in  outdated_files :
623-                 f .unlink ()
624627
625628            modules_status  =  {
626629                module : {"start_time" : None , "completion_time" : None , "duration" : None }
@@ -673,14 +676,26 @@ def _update_total_duration(self):
673676            for  k , v  in  modules_status .items ()
674677            if  k  not  in   ("cumulative_execution_duration" , "total_duration" )
675678        )
679+ 
680+         for  m  in  self ._modules :
681+             first_start_time  =  modules_status [m ]["start_time" ]
682+             if  first_start_time  is  not   None :
683+                 break 
684+ 
685+         for  m  in  self ._modules [::- 1 ]:
686+             last_completion_time  =  modules_status [m ]["completion_time" ]
687+             if  last_completion_time  is  not   None :
688+                 break 
689+ 
690+         if  first_start_time  is  None  or  last_completion_time  is  None :
691+             return 
692+ 
676693        total_duration  =  (
677694            datetime .strptime (
678-                 modules_status [ self . _modules [ - 1 ]][ "completion_time" ] ,
695+                 last_completion_time ,
679696                "%Y-%m-%d %H:%M:%S.%f" ,
680697            )
681-             -  datetime .strptime (
682-                 modules_status [self ._modules [0 ]]["start_time" ], "%Y-%m-%d %H:%M:%S.%f" 
683-             )
698+             -  datetime .strptime (first_start_time , "%Y-%m-%d %H:%M:%S.%f" )
684699        ).total_seconds ()
685700        self ._update_module_status (
686701            {
@@ -689,6 +704,26 @@ def _update_total_duration(self):
689704            }
690705        )
691706
707+     def  _get_median_subtraction_duration_from_log (self ):
708+         raw_ap_fp  =  self ._npx_input_dir  /  "continuous.dat" 
709+         continuous_file  =  self ._ks_output_dir  /  "continuous.dat" 
710+         if  raw_ap_fp .stat ().st_mtime  <  continuous_file .stat ().st_mtime :
711+             # if the copied continuous.dat was actually modified, 
712+             # median_subtraction may have been completed - let's check 
713+             if  self ._module_logfile .exists ():
714+                 with  open (self ._module_logfile , "r" ) as  f :
715+                     previous_line  =  "" 
716+                     for  line  in  f .readlines ():
717+                         if  line .startswith (
718+                             "ecephys spike sorting: median subtraction module" 
719+                         ) and  previous_line .startswith ("Total processing time:" ):
720+                             # regex to search for the processing duration - a float value 
721+                             duration  =  int (
722+                                 re .search ("\d+\.?\d+" , previous_line ).group ()
723+                             )
724+                             return  duration 
725+                         previous_line  =  line 
726+ 
692727
693728def  run_pykilosort (
694729    continuous_file ,
0 commit comments