Skip to content

Commit 594bc98

Browse files
authored
Debug VideoMaker - fix failing test (#1418)
* Debug VideoMaker * Update changelog * Address comments
1 parent be9314f commit 594bc98

File tree

5 files changed

+36
-7
lines changed

5 files changed

+36
-7
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ mountainsort4_output/
44
.idea/
55
mysql_config
66
memray*
7+
root_*.y*ml
78

89
# Notebooks
910
*.ipynb

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import all foreign key references.
4343
- Ensure video files are properly added to `DLCProject` # 1367
4444
- DLC parameter handling improvements and default value corrections #1379
4545
- Fix ingestion nwb files with position objects but no spatial series #1405
46+
- Ignore `percent_frames` when using `limit` in `DLCPosVideo` #1418
4647
- Spikesorting
4748
- Implement short-transaction `SpikeSortingRecording.make` for v0 #1338
4849

src/spyglass/position/v1/dlc_utils_makevid.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def __init__(
7474
# key_hash supports resume from previous run
7575
self.temp_dir = Path(temp_dir) / f"dlc_vid_{key_hash}"
7676
self.temp_dir.mkdir(parents=True, exist_ok=True)
77+
if not self.temp_dir.exists(): # pragma: no cover
78+
raise FileNotFoundError(f"Could not create {self.temp_dir}")
7779
logger.debug(f"Temporary directory: {self.temp_dir}")
7880

7981
if not Path(video_filename).exists():
@@ -109,6 +111,7 @@ def __init__(
109111
self.timeout = 30 if test_mode else 300
110112

111113
self.ffmpeg_log_args = ["-hide_banner", "-loglevel", "error"]
114+
112115
self.ffmpeg_fmt_args = ["-c:v", "libx264", "-pix_fmt", "yuv420p"]
113116

114117
prev_backend = matplotlib.get_backend()
@@ -123,7 +126,12 @@ def __init__(
123126
)
124127
self.process_frames()
125128
plt.close(self.fig)
126-
logger.info(f"Finished video: {self.output_video_filename}")
129+
130+
if Path(self.output_video_filename).exists():
131+
logger.info(f"Finished video: {self.output_video_filename}")
132+
else:
133+
logger.error(f"Failed to create: {self.output_video_filename}")
134+
127135
logger.debug(f"Dropped frames: {self.dropped_frames}")
128136

129137
if not debug:
@@ -187,6 +195,9 @@ def _set_frame_info(self):
187195
if self.debug: # If debugging, limit frames to available data
188196
self.n_frames = min(len(self.position_mean), self.n_frames)
189197

198+
if self.n_frames == 0: # pragma: no cover
199+
raise ValueError("No frames to process!") # pragma: no cover
200+
190201
self.pad_len = len(str(self.n_frames))
191202

192203
def _set_plot_bases(self):
@@ -490,11 +501,18 @@ def ffmpeg_extract(self, start_frame, end_frame):
490501
]
491502
ret = subprocess.run(ffmpeg_cmd, stderr=subprocess.PIPE)
492503

504+
if ret.returncode != 0: # pragma: no cover
505+
logger.error(f"Error extracting frames: {ret.stderr}")
506+
if not self.temp_dir.glob("orig_*.png"): # pragma: no cover
507+
raise FileNotFoundError("No frames were extracted!")
508+
493509
extracted = len(list(self.temp_dir.glob("orig_*.png")))
494510
logger.debug(f"Extracted frames: {start_frame}, len: {extracted}")
495-
if extracted < self.batch_size - 1:
511+
512+
frame_diff = end_frame - start_frame + 1 # may be less than batch size
513+
if extracted < frame_diff:
496514
logger.warning(
497-
f"Could not extract frames: {extracted} / {self.batch_size-1}"
515+
f"Could not extract frames: {extracted} / {frame_diff}"
498516
)
499517
one_err = "\n".join(str(ret.stderr).split("\\")[-3:-1])
500518
logger.debug(f"\nExtract Error: {one_err}")
@@ -536,9 +554,16 @@ def ffmpeg_stitch_partial(self, start_frame, output_partial_video):
536554
except subprocess.CalledProcessError as e: # pragma: no cover
537555
logger.error(f"Err stitching video: {e.stderr}") # pragma: no cover
538556

557+
if not Path(output_partial_video).exists(): # pragma: no cover
558+
logger.error(f"Partial video not created: {output_partial_video}")
559+
539560
def concat_partial_videos(self):
540561
"""Concatenate all the partial videos into one final video."""
541562
partial_vids = sorted(self.temp_dir.glob("partial_*.mp4"))
563+
564+
if not partial_vids: # pragma: no cover
565+
raise FileNotFoundError("No partial videos to concatenate!")
566+
542567
logger.debug(f"Concat part vids: {len(partial_vids)}")
543568
concat_list_path = self.temp_dir / "concat_list.txt"
544569
with open(concat_list_path, "w") as f:
@@ -568,6 +593,7 @@ def concat_partial_videos(self):
568593
)
569594
except subprocess.CalledProcessError as e:
570595
logger.error(f"Error stitching partial video: {e.stderr}")
596+
raise # pragma: no cover
571597

572598

573599
def make_video(**kwargs):

src/spyglass/position/v1/position_dlc_selection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,10 +487,12 @@ def make(self, key):
487487
else None
488488
)
489489
frames = params.get("frames", None)
490+
percent_frames = params.get("percent_frames", None)
490491

491492
if limit := params.get("limit", None): # new int param for debugging
492493
output_video_filename = Path(".") / f"TEST_VID_{limit}.mp4"
493494
video_frame_inds = video_frame_inds[:limit]
495+
percent_frames = 1
494496
pos_info_df = pos_info_df.head(limit)
495497

496498
video_maker = make_video(
@@ -505,7 +507,7 @@ def make(self, key):
505507
position_time=np.asarray(pos_info_df.index),
506508
processor=params.get("processor", "matplotlib"),
507509
frames=np.arange(frames[0], frames[1]) if frames else None,
508-
percent_frames=params.get("percent_frames", None),
510+
percent_frames=percent_frames,
509511
output_video_filename=output_video_filename,
510512
cm_to_pixels=meters_per_pixel * M_TO_CM,
511513
crop=pose_estimation_params.get("cropping"),

tests/position/v1/test_dlc_position.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,8 @@ def test_si_interpolate(sgp, si_params_tbl, si_key, pose_estimation_key):
9999

100100
@pytest.fixture(scope="session")
101101
def si_df(sgp, si_key, populate_si, bodyparts):
102-
yield (
103-
sgp.v1.DLCSmoothInterp() & {**si_key, "bodypart": bodyparts[0]}
104-
).fetch1_dataframe()
102+
_ = si_key, populate_si, bodyparts
103+
yield (sgp.v1.DLCSmoothInterp() & dj.Top()).fetch1_dataframe()
105104

106105

107106
def test_cohort_fetch1_dataframe(si_df):

0 commit comments

Comments
 (0)