Skip to content

Commit da985ce

Browse files
authored
Fix bug where ultra-short beam breaks caused the next trial to be ignored (calderast#164)
1 parent 4c07ce1 commit da985ce

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

src/jdb_to_nwb/convert_behavior.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,44 @@ def parse_arduino_text(arduino_text: list, arduino_timestamps: list, logger):
152152

153153
# If we are in the middle of a beam break, update the end times until we reach the end of the beam break
154154
if current_trial:
155+
156+
# If we are in the middle of a trial that ends at a different port, end that trial first!
157+
# This only happens during ultra-short beam breaks, e.g. this arduino snippet below.
158+
# See Github issue https://github.com/calderast/jdb_to_nwb/issues/163
159+
# ...
160+
# beam break at port A; 134084 trial 7 of 61
161+
# beam break at port A; 134130 trial 7 of 61
162+
# beam break at port C; 142826 trial 7 of 61 <-- BEAM BREAK
163+
# no Reward port C; trial 7 of 61 <-- REWARD INFO FOR THIS TRIAL
164+
# beam break at port B; 148678 trial 8 of 61 <-- NEXT BEAM BREAK IS ALREADY THE NEXT TRIAL!
165+
# rwd delivered at port B; 148731
166+
# beam break at port B; 148732 trial 9 of 61
167+
# beam break at port B; 148767 trial 9 of 61
168+
# ...
169+
if port != current_trial["end_port"]:
170+
trial_data.append(current_trial)
171+
logger.debug(f"Short beam break! Beam break is over. Adding trial {current_trial}")
172+
previous_trial = current_trial
173+
current_trial = {}
174+
trial_within_session += 1
175+
trial_within_block += 1
176+
177+
# Start new trial at this port
178+
current_trial = {
179+
"start_time": float(previous_trial["end_time"]),
180+
"beam_break_start": float(arduino_timestamps[i]),
181+
"start_port": previous_trial.get("end_port", "None"),
182+
"end_port": port,
183+
"trial_within_block": trial_within_block,
184+
"trial_within_session": trial_within_session,
185+
"block": current_block.get("block"),
186+
}
187+
current_trial["reward"] = (
188+
1
189+
if re.search(rf"rwd delivered at port {port}", arduino_text[i + 1])
190+
else 0 if re.search(rf"no Reward port {port}", arduino_text[i + 1]) else None
191+
)
192+
155193
current_trial["beam_break_end"] = float(arduino_timestamps[i])
156194
current_trial["end_time"] = float(arduino_timestamps[i])
157195

@@ -326,6 +364,14 @@ def validate_trial_and_block_data(trial_data: list, block_data: list, logger):
326364
f"Got block numbers {block_numbers}"
327365
)
328366
logger.debug(f"All block numbers are unique and match the range 1 to {len(block_data)}")
367+
368+
# The end time of each trial must be the start time of the next trial
369+
for t1, t2 in zip(trial_data, trial_data[1:]):
370+
assert t1.get("end_time") == t2.get("start_time"), (
371+
f"Trial {t1.get('trial_within_session')} end_time {t1.get('end_time')} "
372+
f"does not match trial {t2.get('trial_within_session')} start_time {t2.get('start_time')}"
373+
)
374+
logger.debug("The end time of each trial matches the start time of the next trial")
329375

330376
# There must be a legitimate reward value (1 or 0) for all trials (instead of default None)
331377
assert all(trial.get("reward") in {0, 1} for trial in trial_data), (

0 commit comments

Comments
 (0)