@@ -131,9 +131,79 @@ class TaskEpoch(SpyglassMixin, dj.Imported):
131131 camera_names : blob # list of keys corresponding to entry in CameraDevice
132132 """
133133
134+ @classmethod
135+ def _get_valid_camera_names (cls , camera_ids , camera_names , context = "" ):
136+ """Get valid camera names for given camera IDs.
137+
138+ Parameters
139+ ----------
140+ camera_ids : list
141+ List of camera IDs to validate
142+ camera_names : dict
143+ Mapping of camera ID to camera name
144+ context : str, optional
145+ Context string for warning message
146+
147+ Returns
148+ -------
149+ list or None
150+ List of camera name dicts, or None if no valid cameras found
151+ """
152+ valid_camera_ids = [
153+ camera_id
154+ for camera_id in camera_ids
155+ if camera_id in camera_names .keys ()
156+ ]
157+ if valid_camera_ids :
158+ return [
159+ {"camera_name" : camera_names [camera_id ]}
160+ for camera_id in valid_camera_ids
161+ ]
162+ if camera_ids : # Only warn if camera_ids were specified
163+ logger .warning (
164+ f"No camera device found with ID { camera_ids } { context } \n "
165+ )
166+ return None
167+
168+ @classmethod
169+ def _process_task_epochs (
170+ cls , base_key , task_epochs , nwb_file_name , session_intervals
171+ ):
172+ """Process task epochs and create TaskEpoch insert entries.
173+
174+ Parameters
175+ ----------
176+ base_key : dict
177+ Base key dict with task_name, camera_names, etc.
178+ task_epochs : list
179+ List of epoch numbers/identifiers
180+ nwb_file_name : str
181+ Name of the NWB file
182+ session_intervals : list
183+ Available interval names from IntervalList
184+
185+ Returns
186+ -------
187+ list
188+ List of dicts ready for TaskEpoch insertion
189+ """
190+ inserts = []
191+ for epoch in task_epochs :
192+ epoch_key = base_key .copy ()
193+ epoch_key ["epoch" ] = epoch
194+ target_interval = cls .get_epoch_interval_name (
195+ epoch , session_intervals
196+ )
197+ if target_interval is None :
198+ continue
199+ epoch_key ["interval_list_name" ] = target_interval
200+ inserts .append (epoch_key )
201+ return inserts
202+
134203 def make (self , key ):
135204 """Populate TaskEpoch from the processing module in the NWB file."""
136205 nwb_file_name = key ["nwb_file_name" ]
206+ nwb_dict = dict (nwb_file_name = nwb_file_name )
137207 nwb_file_abspath = Nwbfile ().get_abs_path (nwb_file_name )
138208 nwbf = get_nwb_file (nwb_file_abspath )
139209 config = get_config (nwb_file_abspath , calling_table = self .camel_name )
@@ -148,6 +218,7 @@ def make(self, key):
148218 # get the camera ID
149219 camera_id = int (str .split (device .name )[1 ])
150220 camera_names [camera_id ] = device .camera_name
221+
151222 if device_list := config .get ("CameraDevice" ):
152223 for device in device_list :
153224 camera_names .update (
@@ -171,6 +242,8 @@ def make(self, key):
171242 )
172243 return
173244
245+ sess_intervals = (IntervalList & nwb_dict ).fetch ("interval_list_name" )
246+
174247 task_inserts = [] # inserts for Task table
175248 task_epoch_inserts = [] # inserts for TaskEpoch table
176249 for task_table in tasks_mod .data_interfaces .values ():
@@ -181,79 +254,50 @@ def make(self, key):
181254 for task in task_df .itertuples (index = False ):
182255 key ["task_name" ] = task .task_name
183256
184- # get the CameraDevice used for this task (primary key is
185- # camera name so we need to map from ID to name)
186-
187- camera_ids = task .camera_id
188- valid_camera_ids = [
189- camera_id
190- for camera_id in camera_ids
191- if camera_id in camera_names .keys ()
192- ]
193- if valid_camera_ids :
194- key ["camera_names" ] = [
195- {"camera_name" : camera_names [camera_id ]}
196- for camera_id in valid_camera_ids
197- ]
198- else :
199- logger .warning (
200- f"No camera device found with ID { camera_ids } in NWB "
201- + f"file { nwbf } \n "
202- )
203- # Add task environment
257+ # Get valid camera names for this task
258+ camera_names_list = self ._get_valid_camera_names (
259+ task .camera_id ,
260+ camera_names ,
261+ context = f" in NWB file { nwbf } " ,
262+ )
263+ if camera_names_list :
264+ key ["camera_names" ] = camera_names_list
265+
266+ # Add task environment if present
204267 if hasattr (task , "task_environment" ):
205268 key ["task_environment" ] = task .task_environment
206269
207- # get the interval list for this task, which corresponds to the
208- # matching epoch for the raw data. Users should define more
209- # restrictive intervals as required for analyses
210-
211- session_intervals = (
212- IntervalList () & {"nwb_file_name" : nwb_file_name }
213- ).fetch ("interval_list_name" )
214- for epoch in task .task_epochs :
215- key ["epoch" ] = epoch
216- target_interval = self .get_epoch_interval_name (
217- epoch , session_intervals
270+ # Process all epochs for this task
271+ task_epoch_inserts .extend (
272+ self ._process_task_epochs (
273+ key , task .task_epochs , nwb_file_name , sess_intervals
218274 )
219- if target_interval is None :
220- logger .warning ("Skipping epoch." )
221- continue
222- key ["interval_list_name" ] = target_interval
223- task_epoch_inserts .append (key .copy ())
275+ )
224276
225277 # Add tasks from config
226278 for task in config_tasks :
227- new_key = {
279+ task_key = {
228280 ** key ,
229281 "task_name" : task .get ("task_name" ),
230282 "task_environment" : task .get ("task_environment" , None ),
231283 }
232- # add cameras
233- camera_ids = task .get ("camera_id" , [])
234- valid_camera_ids = [
235- camera_id
236- for camera_id in camera_ids
237- if camera_id in camera_names .keys ()
238- ]
239- if valid_camera_ids :
240- new_key ["camera_names" ] = [
241- {"camera_name" : camera_names [camera_id ]}
242- for camera_id in valid_camera_ids
243- ]
244- session_intervals = (
245- IntervalList () & {"nwb_file_name" : nwb_file_name }
246- ).fetch ("interval_list_name" )
247- for epoch in task .get ("task_epochs" , []):
248- new_key ["epoch" ] = epoch
249- target_interval = self .get_epoch_interval_name (
250- epoch , session_intervals
284+
285+ # Add cameras if specified
286+ camera_names_list = self ._get_valid_camera_names (
287+ task .get ("camera_id" , []), camera_names
288+ )
289+ if camera_names_list :
290+ task_key ["camera_names" ] = camera_names_list
291+
292+ # Process all epochs for this task
293+ task_epoch_inserts .extend (
294+ self ._process_task_epochs (
295+ task_key ,
296+ task .get ("task_epochs" , []),
297+ nwb_file_name ,
298+ sess_intervals ,
251299 )
252- if target_interval is None :
253- logger .warning ("Skipping epoch." )
254- continue
255- new_key ["interval_list_name" ] = target_interval
256- task_epoch_inserts .append (key .copy ())
300+ )
257301
258302 # check if the task entries are in the Task table and if not, add it
259303 [
@@ -264,23 +308,66 @@ def make(self, key):
264308
265309 @classmethod
266310 def get_epoch_interval_name (cls , epoch , session_intervals ):
267- """Get the interval name for a given epoch based on matching number"""
268- target_interval = str (epoch ).zfill (2 )
311+ """Get the interval name for a given epoch based on matching number.
312+
313+ This method implements flexible matching to handle various epoch tag
314+ formats. It tries multiple formats to find a match:
315+ 1. Exact match (e.g., "1")
316+ 2. Two-digit zero-padded (e.g., "01")
317+ 3. Three-digit zero-padded (e.g., "001")
318+
319+ Parameters
320+ ----------
321+ epoch : int or str
322+ The epoch number to search for
323+ session_intervals : list of str
324+ List of interval names from IntervalList
325+
326+ Returns
327+ -------
328+ str or None
329+ The matching interval name, or None if no unique match is found
330+
331+ Examples
332+ --------
333+ >>> session_intervals = ["1", "02", "003"]
334+ >>> TaskEpoch.get_epoch_interval_name(1, session_intervals)
335+ '1'
336+ >>> TaskEpoch.get_epoch_interval_name(2, session_intervals)
337+ '02'
338+ >>> TaskEpoch.get_epoch_interval_name(3, session_intervals)
339+ '003'
340+ """
341+ if epoch in session_intervals :
342+ return epoch
343+
344+ # Try multiple formats:
345+ possible_formats = [
346+ str (epoch ), # Try exact match first (e.g., "1")
347+ str (epoch ).zfill (2 ), # Try 2-digit zero-pad (e.g., "01")
348+ str (epoch ).zfill (3 ), # Try 3-digit zero-pad (e.g., "001")
349+ ]
350+ unique_formats = list (dict .fromkeys (possible_formats ))
351+
352+ # Find matches for any format, remove duplicates preserving order
269353 possible_targets = [
270354 interval
271355 for interval in session_intervals
272- if target_interval in interval
356+ for target in unique_formats
357+ if target in interval
273358 ]
274- if not possible_targets :
275- logger .warning (f"Interval not found for epoch { epoch } ." )
276- elif len (possible_targets ) > 1 :
277- logger .warning (
278- f"Multiple intervals found for epoch { epoch } . "
279- + f"matches are { possible_targets } ."
280- )
281- else :
359+
360+ if len (set (possible_targets )) == 1 :
282361 return possible_targets [0 ]
283362
363+ warn = "Multiple" if len (possible_targets ) > 1 else "No"
364+
365+ logger .warning (
366+ f"{ warn } interval(s) found for epoch { epoch } . "
367+ f"Available intervals: { session_intervals } "
368+ )
369+ return None
370+
284371 @classmethod
285372 def update_entries (cls , restrict = True ):
286373 """Update entries in the TaskEpoch table based on a restriction."""
0 commit comments