Skip to content

Commit b17c225

Browse files
authored
Explicit epoch format handling #1443 (#1459)
* Explicit epoch format handling #1443 * Update changelog * PR feedback * Refactor TaskEpoch inserts
1 parent af64081 commit b17c225

File tree

4 files changed

+335
-73
lines changed

4 files changed

+335
-73
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ DecodingParameters().alter()
6262
- Improve error transparency on duplicate `Electrode` ids #1454
6363
- Remove pre-existing `Units` from created analysis nwb files #1453
6464
- Allow multiple VideoFile entries during ingestion #1462
65+
- Handle epoch formats with varying zero-padding #1459
6566
- Decoding
6667
- Ensure results directory is created if it doesn't exist #1362
6768
- Change BLOB fields to LONGBLOB in DecodingParameters #1463

src/spyglass/common/common_task.py

Lines changed: 159 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,79 @@ class TaskEpoch(SpyglassMixin, dj.Imported):
131131
camera_names : blob # list of keys corresponding to entry in CameraDevice
132132
"""
133133

134+
@classmethod
135+
def _get_valid_camera_names(cls, camera_ids, camera_names, context=""):
136+
"""Get valid camera names for given camera IDs.
137+
138+
Parameters
139+
----------
140+
camera_ids : list
141+
List of camera IDs to validate
142+
camera_names : dict
143+
Mapping of camera ID to camera name
144+
context : str, optional
145+
Context string for warning message
146+
147+
Returns
148+
-------
149+
list or None
150+
List of camera name dicts, or None if no valid cameras found
151+
"""
152+
valid_camera_ids = [
153+
camera_id
154+
for camera_id in camera_ids
155+
if camera_id in camera_names.keys()
156+
]
157+
if valid_camera_ids:
158+
return [
159+
{"camera_name": camera_names[camera_id]}
160+
for camera_id in valid_camera_ids
161+
]
162+
if camera_ids: # Only warn if camera_ids were specified
163+
logger.warning(
164+
f"No camera device found with ID {camera_ids}{context}\n"
165+
)
166+
return None
167+
168+
@classmethod
169+
def _process_task_epochs(
170+
cls, base_key, task_epochs, nwb_file_name, session_intervals
171+
):
172+
"""Process task epochs and create TaskEpoch insert entries.
173+
174+
Parameters
175+
----------
176+
base_key : dict
177+
Base key dict with task_name, camera_names, etc.
178+
task_epochs : list
179+
List of epoch numbers/identifiers
180+
nwb_file_name : str
181+
Name of the NWB file
182+
session_intervals : list
183+
Available interval names from IntervalList
184+
185+
Returns
186+
-------
187+
list
188+
List of dicts ready for TaskEpoch insertion
189+
"""
190+
inserts = []
191+
for epoch in task_epochs:
192+
epoch_key = base_key.copy()
193+
epoch_key["epoch"] = epoch
194+
target_interval = cls.get_epoch_interval_name(
195+
epoch, session_intervals
196+
)
197+
if target_interval is None:
198+
continue
199+
epoch_key["interval_list_name"] = target_interval
200+
inserts.append(epoch_key)
201+
return inserts
202+
134203
def make(self, key):
135204
"""Populate TaskEpoch from the processing module in the NWB file."""
136205
nwb_file_name = key["nwb_file_name"]
206+
nwb_dict = dict(nwb_file_name=nwb_file_name)
137207
nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name)
138208
nwbf = get_nwb_file(nwb_file_abspath)
139209
config = get_config(nwb_file_abspath, calling_table=self.camel_name)
@@ -148,6 +218,7 @@ def make(self, key):
148218
# get the camera ID
149219
camera_id = int(str.split(device.name)[1])
150220
camera_names[camera_id] = device.camera_name
221+
151222
if device_list := config.get("CameraDevice"):
152223
for device in device_list:
153224
camera_names.update(
@@ -171,6 +242,8 @@ def make(self, key):
171242
)
172243
return
173244

245+
sess_intervals = (IntervalList & nwb_dict).fetch("interval_list_name")
246+
174247
task_inserts = [] # inserts for Task table
175248
task_epoch_inserts = [] # inserts for TaskEpoch table
176249
for task_table in tasks_mod.data_interfaces.values():
@@ -181,79 +254,50 @@ def make(self, key):
181254
for task in task_df.itertuples(index=False):
182255
key["task_name"] = task.task_name
183256

184-
# get the CameraDevice used for this task (primary key is
185-
# camera name so we need to map from ID to name)
186-
187-
camera_ids = task.camera_id
188-
valid_camera_ids = [
189-
camera_id
190-
for camera_id in camera_ids
191-
if camera_id in camera_names.keys()
192-
]
193-
if valid_camera_ids:
194-
key["camera_names"] = [
195-
{"camera_name": camera_names[camera_id]}
196-
for camera_id in valid_camera_ids
197-
]
198-
else:
199-
logger.warning(
200-
f"No camera device found with ID {camera_ids} in NWB "
201-
+ f"file {nwbf}\n"
202-
)
203-
# Add task environment
257+
# Get valid camera names for this task
258+
camera_names_list = self._get_valid_camera_names(
259+
task.camera_id,
260+
camera_names,
261+
context=f" in NWB file {nwbf}",
262+
)
263+
if camera_names_list:
264+
key["camera_names"] = camera_names_list
265+
266+
# Add task environment if present
204267
if hasattr(task, "task_environment"):
205268
key["task_environment"] = task.task_environment
206269

207-
# get the interval list for this task, which corresponds to the
208-
# matching epoch for the raw data. Users should define more
209-
# restrictive intervals as required for analyses
210-
211-
session_intervals = (
212-
IntervalList() & {"nwb_file_name": nwb_file_name}
213-
).fetch("interval_list_name")
214-
for epoch in task.task_epochs:
215-
key["epoch"] = epoch
216-
target_interval = self.get_epoch_interval_name(
217-
epoch, session_intervals
270+
# Process all epochs for this task
271+
task_epoch_inserts.extend(
272+
self._process_task_epochs(
273+
key, task.task_epochs, nwb_file_name, sess_intervals
218274
)
219-
if target_interval is None:
220-
logger.warning("Skipping epoch.")
221-
continue
222-
key["interval_list_name"] = target_interval
223-
task_epoch_inserts.append(key.copy())
275+
)
224276

225277
# Add tasks from config
226278
for task in config_tasks:
227-
new_key = {
279+
task_key = {
228280
**key,
229281
"task_name": task.get("task_name"),
230282
"task_environment": task.get("task_environment", None),
231283
}
232-
# add cameras
233-
camera_ids = task.get("camera_id", [])
234-
valid_camera_ids = [
235-
camera_id
236-
for camera_id in camera_ids
237-
if camera_id in camera_names.keys()
238-
]
239-
if valid_camera_ids:
240-
new_key["camera_names"] = [
241-
{"camera_name": camera_names[camera_id]}
242-
for camera_id in valid_camera_ids
243-
]
244-
session_intervals = (
245-
IntervalList() & {"nwb_file_name": nwb_file_name}
246-
).fetch("interval_list_name")
247-
for epoch in task.get("task_epochs", []):
248-
new_key["epoch"] = epoch
249-
target_interval = self.get_epoch_interval_name(
250-
epoch, session_intervals
284+
285+
# Add cameras if specified
286+
camera_names_list = self._get_valid_camera_names(
287+
task.get("camera_id", []), camera_names
288+
)
289+
if camera_names_list:
290+
task_key["camera_names"] = camera_names_list
291+
292+
# Process all epochs for this task
293+
task_epoch_inserts.extend(
294+
self._process_task_epochs(
295+
task_key,
296+
task.get("task_epochs", []),
297+
nwb_file_name,
298+
sess_intervals,
251299
)
252-
if target_interval is None:
253-
logger.warning("Skipping epoch.")
254-
continue
255-
new_key["interval_list_name"] = target_interval
256-
task_epoch_inserts.append(key.copy())
300+
)
257301

258302
# check if the task entries are in the Task table and if not, add it
259303
[
@@ -264,23 +308,66 @@ def make(self, key):
264308

265309
@classmethod
266310
def get_epoch_interval_name(cls, epoch, session_intervals):
267-
"""Get the interval name for a given epoch based on matching number"""
268-
target_interval = str(epoch).zfill(2)
311+
"""Get the interval name for a given epoch based on matching number.
312+
313+
This method implements flexible matching to handle various epoch tag
314+
formats. It tries multiple formats to find a match:
315+
1. Exact match (e.g., "1")
316+
2. Two-digit zero-padded (e.g., "01")
317+
3. Three-digit zero-padded (e.g., "001")
318+
319+
Parameters
320+
----------
321+
epoch : int or str
322+
The epoch number to search for
323+
session_intervals : list of str
324+
List of interval names from IntervalList
325+
326+
Returns
327+
-------
328+
str or None
329+
The matching interval name, or None if no unique match is found
330+
331+
Examples
332+
--------
333+
>>> session_intervals = ["1", "02", "003"]
334+
>>> TaskEpoch.get_epoch_interval_name(1, session_intervals)
335+
'1'
336+
>>> TaskEpoch.get_epoch_interval_name(2, session_intervals)
337+
'02'
338+
>>> TaskEpoch.get_epoch_interval_name(3, session_intervals)
339+
'003'
340+
"""
341+
if epoch in session_intervals:
342+
return epoch
343+
344+
# Try multiple formats:
345+
possible_formats = [
346+
str(epoch), # Try exact match first (e.g., "1")
347+
str(epoch).zfill(2), # Try 2-digit zero-pad (e.g., "01")
348+
str(epoch).zfill(3), # Try 3-digit zero-pad (e.g., "001")
349+
]
350+
unique_formats = list(dict.fromkeys(possible_formats))
351+
352+
# Find matches for any format, remove duplicates preserving order
269353
possible_targets = [
270354
interval
271355
for interval in session_intervals
272-
if target_interval in interval
356+
for target in unique_formats
357+
if target in interval
273358
]
274-
if not possible_targets:
275-
logger.warning(f"Interval not found for epoch {epoch}.")
276-
elif len(possible_targets) > 1:
277-
logger.warning(
278-
f"Multiple intervals found for epoch {epoch}. "
279-
+ f"matches are {possible_targets}."
280-
)
281-
else:
359+
360+
if len(set(possible_targets)) == 1:
282361
return possible_targets[0]
283362

363+
warn = "Multiple" if len(possible_targets) > 1 else "No"
364+
365+
logger.warning(
366+
f"{warn} interval(s) found for epoch {epoch}. "
367+
f"Available intervals: {session_intervals}"
368+
)
369+
return None
370+
284371
@classmethod
285372
def update_entries(cls, restrict=True):
286373
"""Update entries in the TaskEpoch table based on a restriction."""

0 commit comments

Comments
 (0)