Skip to content

Commit 67ca666

Browse files
authored
Fixed multi-row task table bug (#1433)
* fixed multi-row task table bug * updated changelog
1 parent 75f6282 commit 67ca666

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import all foreign key references.
2626
- Fix error from unlinked object in `AnalysisNwbfile.create` #1396
2727
- Sort `UserEnvironment` dict objects by key for consistency #1380
2828
- Fix typo in VideoFile.make #1427
29+
- Fix bug in TaskEpoch.make so that it correctly handles multi-row task tables from NWB #1433
2930

3031
### Infrastructure
3132

src/spyglass/common/common_task.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,14 @@ def make(self, key):
177177
if not self.is_nwb_task_epoch(task_table):
178178
continue
179179
task_inserts.append(task_table)
180-
for task in task_table:
181-
key["task_name"] = task.task_name[0]
180+
task_df = task_table.to_dataframe()
181+
for task in task_df.itertuples(index=False):
182+
key["task_name"] = task.task_name
182183

183184
# get the CameraDevice used for this task (primary key is
184185
# camera name so we need to map from ID to name)
185186

186-
camera_ids = task.camera_id[0]
187+
camera_ids = task.camera_id
187188
valid_camera_ids = [
188189
camera_id
189190
for camera_id in camera_ids
@@ -201,7 +202,7 @@ def make(self, key):
201202
)
202203
# Add task environment
203204
if hasattr(task, "task_environment"):
204-
key["task_environment"] = task.task_environment[0]
205+
key["task_environment"] = task.task_environment
205206

206207
# get the interval list for this task, which corresponds to the
207208
# matching epoch for the raw data. Users should define more
@@ -210,7 +211,7 @@ def make(self, key):
210211
session_intervals = (
211212
IntervalList() & {"nwb_file_name": nwb_file_name}
212213
).fetch("interval_list_name")
213-
for epoch in task.task_epochs[0]:
214+
for epoch in task.task_epochs:
214215
key["epoch"] = epoch
215216
target_interval = self.get_epoch_interval_name(
216217
epoch, session_intervals

0 commit comments

Comments
 (0)