@@ -152,6 +152,44 @@ def parse_arduino_text(arduino_text: list, arduino_timestamps: list, logger):
152152
153153 # If we are in the middle of a beam break, update the end times until we reach the end of the beam break
154154 if current_trial :
155+
156+ # If we are in the middle of a trial that ends at a different port, end that trial first!
157+ # This only happens during ultra-short beam breaks, e.g. this arduino snippet below.
158+ # See Github issue https://github.com/calderast/jdb_to_nwb/issues/163
159+ # ...
160+ # beam break at port A; 134084 trial 7 of 61
161+ # beam break at port A; 134130 trial 7 of 61
162+ # beam break at port C; 142826 trial 7 of 61 <-- BEAM BREAK
163+ # no Reward port C; trial 7 of 61 <-- REWARD INFO FOR THIS TRIAL
164+ # beam break at port B; 148678 trial 8 of 61 <-- NEXT BEAM BREAK IS ALREADY THE NEXT TRIAL!
165+ # rwd delivered at port B; 148731
166+ # beam break at port B; 148732 trial 9 of 61
167+ # beam break at port B; 148767 trial 9 of 61
168+ # ...
169+ if port != current_trial ["end_port" ]:
170+ trial_data .append (current_trial )
171+ logger .debug (f"Short beam break! Beam break is over. Adding trial { current_trial } " )
172+ previous_trial = current_trial
173+ current_trial = {}
174+ trial_within_session += 1
175+ trial_within_block += 1
176+
177+ # Start new trial at this port
178+ current_trial = {
179+ "start_time" : float (previous_trial ["end_time" ]),
180+ "beam_break_start" : float (arduino_timestamps [i ]),
181+ "start_port" : previous_trial .get ("end_port" , "None" ),
182+ "end_port" : port ,
183+ "trial_within_block" : trial_within_block ,
184+ "trial_within_session" : trial_within_session ,
185+ "block" : current_block .get ("block" ),
186+ }
187+ current_trial ["reward" ] = (
188+ 1
189+ if re .search (rf"rwd delivered at port { port } " , arduino_text [i + 1 ])
190+ else 0 if re .search (rf"no Reward port { port } " , arduino_text [i + 1 ]) else None
191+ )
192+
155193 current_trial ["beam_break_end" ] = float (arduino_timestamps [i ])
156194 current_trial ["end_time" ] = float (arduino_timestamps [i ])
157195
@@ -326,6 +364,14 @@ def validate_trial_and_block_data(trial_data: list, block_data: list, logger):
326364 f"Got block numbers { block_numbers } "
327365 )
328366 logger .debug (f"All block numbers are unique and match the range 1 to { len (block_data )} " )
367+
368+ # The end time of each trial must be the start time of the next trial
369+ for t1 , t2 in zip (trial_data , trial_data [1 :]):
370+ assert t1 .get ("end_time" ) == t2 .get ("start_time" ), (
371+ f"Trial { t1 .get ('trial_within_session' )} end_time { t1 .get ('end_time' )} "
372+ f"does not match trial { t2 .get ('trial_within_session' )} start_time { t2 .get ('start_time' )} "
373+ )
374+ logger .debug ("The end time of each trial matches the start time of the next trial" )
329375
330376 # There must be a legitimate reward value (1 or 0) for all trials (instead of default None)
331377 assert all (trial .get ("reward" ) in {0 , 1 } for trial in trial_data ), (
0 commit comments