Skip to content

Commit 6db6e23

Browse files
committed
Pipeline updates for 3.0
1 parent 01fbe90 commit 6db6e23

File tree

4 files changed

+179
-135
lines changed

4 files changed

+179
-135
lines changed

src/spyglass/position/v1/dlc_reader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def __init__(self, dlc_dir, filename_prefix=""):
6262
"shuffle": int(shuffle),
6363
"snapshotindex": self.yml["snapshotindex"],
6464
"trainingsetindex": np.where(yml_frac == pkl_frac)[0][0],
65-
"training_iteration": int(self.pkl["Scorer"].split("_")[-1]),
65+
"training_iteration": int(
66+
self.pkl["Scorer"].split("_")[-1].replace("best-", "")
67+
),
6668
}
6769

6870
self.fps = self.pkl["fps"]

src/spyglass/position/v1/pipeline_dlc_inference.py

Lines changed: 136 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -109,116 +109,128 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
109109
logger.info(
110110
f"---- Step 2: Pose Estimation | {dlc_pipeline_description} ----"
111111
)
112-
pose_estimation_selection_key = {
113-
**epoch_key,
114-
**model_key, # Includes project_name implicitly
115-
}
116-
if not (DLCPoseEstimationSelection & pose_estimation_selection_key):
117-
# Returns key if successful/exists, None otherwise
118-
sel_key = DLCPoseEstimationSelection().insert_estimation_task(
119-
pose_estimation_selection_key, skip_duplicates=skip_duplicates
120-
)
121-
if not sel_key: # If insert failed (e.g. duplicate and skip=False)
122-
if skip_duplicates and (
123-
DLCPoseEstimationSelection & pose_estimation_selection_key
124-
):
125-
logger.warning(
126-
f"Pose Estimation Selection already exists for {pose_estimation_selection_key}"
127-
)
128-
else:
129-
raise dj.errors.DataJointError(
130-
f"Failed to insert Pose Estimation Selection for {pose_estimation_selection_key}"
131-
)
132-
else:
133-
logger.warning(
134-
f"Pose Estimation Selection already exists for {pose_estimation_selection_key}"
135-
)
136-
137-
# Ensure selection exists before populating
138-
if not (DLCPoseEstimationSelection & pose_estimation_selection_key):
139-
raise dj.errors.DataJointError(
140-
f"Pose Estimation Selection missing for {pose_estimation_selection_key}"
141-
)
112+
video_file_nums = (
113+
VideoFile() & {"nwb_file_name": nwb_file_name, "epoch": epoch}
114+
).fetch("video_file_num")
115+
116+
for video_file_num in video_file_nums:
117+
pose_estimation_selection_key = {
118+
**epoch_key,
119+
**model_key, # Includes project_name implicitly
120+
"video_file_num": video_file_num,
121+
}
122+
if not (DLCPoseEstimationSelection & pose_estimation_selection_key):
123+
# Returns key if successful/exists, None otherwise
124+
sel_key = DLCPoseEstimationSelection().insert_estimation_task(
125+
pose_estimation_selection_key,
126+
skip_duplicates=skip_duplicates,
127+
)
128+
if (
129+
not sel_key
130+
): # If insert failed (e.g. duplicate and skip=False)
131+
if skip_duplicates and (
132+
DLCPoseEstimationSelection
133+
& pose_estimation_selection_key
134+
):
135+
logger.warning(
136+
f"Pose Estimation Selection already exists for {pose_estimation_selection_key}"
137+
)
138+
else:
139+
raise dj.errors.DataJointError(
140+
f"Failed to insert Pose Estimation Selection for {pose_estimation_selection_key}"
141+
)
142+
else:
143+
logger.warning(
144+
f"Pose Estimation Selection already exists for {pose_estimation_selection_key}"
145+
)
142146

143-
if not (DLCPoseEstimation & pose_estimation_selection_key):
144-
logger.info("Populating DLCPoseEstimation...")
145-
DLCPoseEstimation.populate(
146-
pose_estimation_selection_key, reserve_jobs=True, **kwargs
147-
)
148-
else:
149-
logger.info("DLCPoseEstimation already populated.")
147+
# Ensure selection exists before populating
148+
if not (DLCPoseEstimationSelection & pose_estimation_selection_key):
149+
raise dj.errors.DataJointError(
150+
f"Pose Estimation Selection missing for {pose_estimation_selection_key}"
151+
)
150152

151-
if not (DLCPoseEstimation & pose_estimation_selection_key):
152-
raise dj.errors.DataJointError(
153-
f"DLCPoseEstimation population failed for {pose_estimation_selection_key}"
154-
)
155-
pose_est_key = (
156-
DLCPoseEstimation & pose_estimation_selection_key
157-
).fetch1("KEY")
153+
if not (DLCPoseEstimation & pose_estimation_selection_key):
154+
logger.info("Populating DLCPoseEstimation...")
155+
DLCPoseEstimation.populate(
156+
pose_estimation_selection_key, **kwargs
157+
)
158+
else:
159+
logger.info("DLCPoseEstimation already populated.")
158160

159-
# --- 3. Smoothing/Interpolation (per bodypart) ---
160-
processed_bodyparts_keys = {} # Store keys for subsequent steps
161-
if run_smoothing_interp:
162-
logger.info(
163-
f"---- Step 3: Smooth/Interpolate | {dlc_pipeline_description} ----"
164-
)
165-
target_bodyparts = (
166-
bodyparts_params_dict.keys()
167-
if bodyparts_params_dict
168-
else (DLCPoseEstimation.BodyPart & pose_est_key).fetch(
169-
"bodypart"
161+
if not (DLCPoseEstimation & pose_estimation_selection_key):
162+
raise dj.errors.DataJointError(
163+
f"DLCPoseEstimation population failed for {pose_estimation_selection_key}"
170164
)
171-
)
165+
pose_est_key = (
166+
DLCPoseEstimation & pose_estimation_selection_key
167+
).fetch1("KEY")
172168

173-
for bodypart in target_bodyparts:
174-
logger.info(f"Processing bodypart: {bodypart}")
175-
current_si_params_name = bodyparts_params_dict.get(
176-
bodypart, dlc_si_params_name
169+
# --- 3. Smoothing/Interpolation (per bodypart) ---
170+
processed_bodyparts_keys = {} # Store keys for subsequent steps
171+
if run_smoothing_interp:
172+
logger.info(
173+
f"---- Step 3: Smooth/Interpolate | {dlc_pipeline_description} ----"
177174
)
178-
if not (
179-
DLCSmoothInterpParams
180-
& {"dlc_si_params_name": current_si_params_name}
181-
):
182-
raise ValueError(
183-
f"DLCSmoothInterpParams not found for {bodypart}: {current_si_params_name}"
184-
)
185-
186-
si_selection_key = {
187-
**pose_est_key,
188-
"bodypart": bodypart,
189-
"dlc_si_params_name": current_si_params_name,
190-
}
191-
if not (DLCSmoothInterpSelection & si_selection_key):
192-
DLCSmoothInterpSelection.insert1(
193-
si_selection_key, skip_duplicates=skip_duplicates
194-
)
195-
else:
196-
logger.warning(
197-
f"Smooth/Interp Selection already exists for {si_selection_key}"
198-
)
199-
200-
if not (DLCSmoothInterp & si_selection_key):
201-
logger.info(f"Populating DLCSmoothInterp for {bodypart}...")
202-
DLCSmoothInterp.populate(
203-
si_selection_key, reserve_jobs=True, **kwargs
204-
)
175+
if bodyparts_params_dict:
176+
target_bodyparts = bodyparts_params_dict.keys()
205177
else:
206-
logger.info(
207-
f"DLCSmoothInterp already populated for {bodypart}."
208-
)
209-
210-
if DLCSmoothInterp & si_selection_key:
211-
processed_bodyparts_keys[bodypart] = (
212-
DLCSmoothInterp & si_selection_key
213-
).fetch1("KEY")
214-
else:
215-
raise dj.errors.DataJointError(
216-
f"DLCSmoothInterp population failed for {si_selection_key}"
217-
)
218-
else:
219-
logger.info(
220-
f"Skipping Smoothing/Interpolation for {dlc_pipeline_description}"
221-
)
178+
target_bodyparts = (
179+
DLCPoseEstimation.BodyPart & pose_est_key
180+
).fetch("bodypart")
181+
182+
for bodypart in target_bodyparts:
183+
logger.info(f"Processing bodypart: {bodypart}")
184+
if bodyparts_params_dict is not None:
185+
current_si_params_name = bodyparts_params_dict.get(
186+
bodypart, dlc_si_params_name
187+
)
188+
else:
189+
current_si_params_name = dlc_si_params_name
190+
if not (
191+
DLCSmoothInterpParams
192+
& {"dlc_si_params_name": current_si_params_name}
193+
):
194+
raise ValueError(
195+
f"DLCSmoothInterpParams not found for {bodypart}: {current_si_params_name}"
196+
)
197+
198+
si_selection_key = {
199+
**pose_est_key,
200+
"bodypart": bodypart,
201+
"dlc_si_params_name": current_si_params_name,
202+
}
203+
if not (DLCSmoothInterpSelection & si_selection_key):
204+
DLCSmoothInterpSelection.insert1(
205+
si_selection_key, skip_duplicates=skip_duplicates
206+
)
207+
else:
208+
logger.warning(
209+
f"Smooth/Interp Selection already exists for {si_selection_key}"
210+
)
211+
212+
if not (DLCSmoothInterp & si_selection_key):
213+
logger.info(
214+
f"Populating DLCSmoothInterp for {bodypart}..."
215+
)
216+
DLCSmoothInterp.populate(si_selection_key, **kwargs)
217+
else:
218+
logger.info(
219+
f"DLCSmoothInterp already populated for {bodypart}."
220+
)
221+
222+
if DLCSmoothInterp & si_selection_key:
223+
processed_bodyparts_keys[bodypart] = (
224+
DLCSmoothInterp & si_selection_key
225+
).fetch1("KEY")
226+
else:
227+
raise dj.errors.DataJointError(
228+
f"DLCSmoothInterp population failed for {si_selection_key}"
229+
)
230+
else:
231+
logger.info(
232+
f"Skipping Smoothing/Interpolation for {dlc_pipeline_description}"
233+
)
222234

223235
# --- Steps 4-7 require bodyparts_params_dict ---
224236
if not bodyparts_params_dict:
@@ -253,7 +265,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
253265
"dlc_si_cohort_selection_name": cohort_selection_name,
254266
"bodyparts_params_dict": bodyparts_params_dict,
255267
}
256-
if not (DLCSmoothInterpCohortSelection & cohort_selection_key):
268+
if not (DLCSmoothInterpCohortSelection & pose_est_key):
257269
DLCSmoothInterpCohortSelection.insert1(
258270
cohort_selection_key, skip_duplicates=skip_duplicates
259271
)
@@ -264,9 +276,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
264276

265277
if not (DLCSmoothInterpCohort & cohort_selection_key):
266278
logger.info("Populating DLCSmoothInterpCohort...")
267-
DLCSmoothInterpCohort.populate(
268-
cohort_selection_key, reserve_jobs=True, **kwargs
269-
)
279+
DLCSmoothInterpCohort.populate(cohort_selection_key, **kwargs)
270280
else:
271281
logger.info("DLCSmoothInterpCohort already populated.")
272282

@@ -299,9 +309,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
299309

300310
if not (DLCCentroid & centroid_selection_key):
301311
logger.info("Populating DLCCentroid...")
302-
DLCCentroid.populate(
303-
centroid_selection_key, reserve_jobs=True, **kwargs
304-
)
312+
DLCCentroid.populate(centroid_selection_key, **kwargs)
305313
else:
306314
logger.info("DLCCentroid already populated.")
307315

@@ -336,9 +344,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
336344

337345
if not (DLCOrientation & orientation_selection_key):
338346
logger.info("Populating DLCOrientation...")
339-
DLCOrientation.populate(
340-
orientation_selection_key, reserve_jobs=True, **kwargs
341-
)
347+
DLCOrientation.populate(orientation_selection_key, **kwargs)
342348
else:
343349
logger.info("DLCOrientation already populated.")
344350

@@ -368,12 +374,12 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
368374
"dlc_si_cohort_centroid": centroid_key[
369375
"dlc_si_cohort_selection_name"
370376
],
371-
"centroid_analysis_file_name": centroid_key[
372-
"analysis_file_name"
373-
],
374377
"dlc_model_name": centroid_key["dlc_model_name"],
375-
"epoch": centroid_key["epoch"],
376378
"nwb_file_name": centroid_key["nwb_file_name"],
379+
"epoch": centroid_key["epoch"],
380+
"video_file_num": centroid_key["video_file_num"],
381+
"project_name": centroid_key["project_name"],
382+
"dlc_model_name": centroid_key["dlc_model_name"],
377383
"dlc_model_params_name": centroid_key["dlc_model_params_name"],
378384
"dlc_centroid_params_name": centroid_key[
379385
"dlc_centroid_params_name"
@@ -382,9 +388,6 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
382388
"dlc_si_cohort_orientation": orientation_key[
383389
"dlc_si_cohort_selection_name"
384390
],
385-
"orientation_analysis_file_name": orientation_key[
386-
"analysis_file_name"
387-
],
388391
"dlc_orientation_params_name": orientation_key[
389392
"dlc_orientation_params_name"
390393
],
@@ -400,9 +403,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
400403

401404
if not (DLCPosV1 & pos_selection_key):
402405
logger.info("Populating DLCPosV1...")
403-
DLCPosV1.populate(
404-
pos_selection_key, reserve_jobs=True, **kwargs
405-
)
406+
DLCPosV1.populate(pos_selection_key, **kwargs)
406407
else:
407408
logger.info("DLCPosV1 already populated.")
408409

@@ -464,9 +465,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
464465

465466
if not (DLCPosVideo & video_selection_key):
466467
logger.info("Populating DLCPosVideo...")
467-
DLCPosVideo.populate(
468-
video_selection_key, reserve_jobs=True, **kwargs
469-
)
468+
DLCPosVideo.populate(video_selection_key, **kwargs)
470469
else:
471470
logger.info("DLCPosVideo already populated.")
472471
elif generate_video and not final_pos_key:
@@ -503,8 +502,8 @@ def populate_spyglass_dlc_pipeline_v1(
503502
dlc_orientation_params_name: str = "default",
504503
bodyparts_params_dict: Optional[Dict[str, str]] = None,
505504
run_smoothing_interp: bool = True,
506-
run_centroid: bool = True,
507-
run_orientation: bool = True,
505+
run_centroid: bool = False,
506+
run_orientation: bool = False,
508507
generate_video: bool = False,
509508
dlc_pos_video_params_name: str = "default",
510509
skip_duplicates: bool = True,
@@ -602,6 +601,15 @@ def populate_spyglass_dlc_pipeline_v1(
602601
f"Found {len(epochs_to_process)} epoch(s) to process: {sorted(epochs_to_process)}"
603602
)
604603

604+
if bodyparts_params_dict is None and run_centroid:
605+
raise ValueError(
606+
"bodyparts_params_dict must be provided when running centroid calculation."
607+
)
608+
if bodyparts_params_dict is None and run_orientation:
609+
raise ValueError(
610+
"bodyparts_params_dict must be provided when running orientation calculation."
611+
)
612+
605613
# --- Prepare arguments for each epoch ---
606614
process_args_list = []
607615
for epoch in epochs_to_process:

src/spyglass/position/v1/pipeline_dlc_training.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,17 +176,25 @@ def run_spyglass_dlc_training_v1(
176176
f"DLCModelTraining population failed for {selection_key}"
177177
)
178178

179-
if not (DLCModelSource() & selection_key):
179+
model_source_key = {
180+
"project_name": project_key["project_name"],
181+
"dlc_model_name": (
182+
f"{project_key['project_name']}_"
183+
f"{params_key['dlc_training_params_name']}_"
184+
f"{selection_key['training_id']:02d}"
185+
),
186+
}
187+
if not (DLCModelSource() & model_source_key):
180188
raise dj.errors.DataJointError(
181-
f"DLCModelSource entry missing for {selection_key}"
189+
f"DLCModelSource entry missing for {model_source_key}"
182190
)
183191

184192
# Populate DLCModel
185193
logger.info(
186194
f"---- Step 4: Populating DLCModel for Project: {project_name} ----"
187195
)
188196
model_key = {
189-
**(DLCModelSource & selection_key).fetch1("KEY"),
197+
**(DLCModelSource & model_source_key).fetch1("KEY"),
190198
"dlc_model_params_name": dlc_model_params_name,
191199
}
192200
DLCModelSelection().insert1(

0 commit comments

Comments
 (0)