Skip to content

Commit 1510047

Browse files
committed
fix(inference)
1 parent bbf691f commit 1510047

File tree

1 file changed

+56
-81
lines changed

1 file changed

+56
-81
lines changed

element_moseq/moseq_infer.py

Lines changed: 56 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -223,39 +223,26 @@ def make_fetch(self, key):
223223
) # model dir relative to processed data directory
224224

225225
model_key = (Model * moseq_train.SelectedFullFit & key).fetch1("KEY")
226-
checkpoint_file_path = (
226+
fullfit_checkpoint_path = (
227227
moseq_train.FullFit.File & model_key & 'file_name="checkpoint.h5"'
228228
).fetch1("file_path")
229-
kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1(
230-
"config_file"
231-
)
232-
pca_file_path = (
229+
fullfit_kpms_dj_config_file = (
230+
moseq_train.FullFit.ConfigFile & model_key
231+
).fetch1("config_file")
232+
fullfit_pca_file_path = (
233233
moseq_train.PCAFit.File & model_key & 'file_name="pca.p"'
234234
).fetch1("file_path")
235235

236-
data_file_path = (
237-
moseq_train.PCAFit.File & model_key & 'file_name="data.pkl"'
238-
).fetch1("file_path")
239-
metadata_file_path = (
240-
moseq_train.PCAFit.File & model_key & 'file_name="metadata.pkl"'
241-
).fetch1("file_path")
242-
coordinates, confidences = (moseq_train.PreProcessing & model_key).fetch(
243-
"coordinates", "confidences"
244-
)
245236
return (
246237
keypointset_dir,
247238
inference_output_dir,
248239
num_iterations,
249240
task_mode,
250241
model_dir_rel,
251242
model_file,
252-
checkpoint_file_path,
253-
kpms_dj_config_file,
254-
pca_file_path,
255-
data_file_path,
256-
metadata_file_path,
257-
coordinates,
258-
confidences,
243+
fullfit_checkpoint_path,
244+
fullfit_kpms_dj_config_file,
245+
fullfit_pca_file_path,
259246
)
260247

261248
def make_compute(
@@ -267,34 +254,12 @@ def make_compute(
267254
task_mode,
268255
model_dir_rel,
269256
model_file,
270-
checkpoint_file_path,
271-
kpms_dj_config_file,
272-
pca_file_path,
273-
data_file_path,
274-
metadata_file_path,
275-
coordinates,
276-
confidences,
257+
fullfit_checkpoint_path,
258+
fullfit_kpms_dj_config_file,
259+
fullfit_pca_file_path,
277260
):
278261
"""
279262
Compute model inference results.
280-
281-
Args:
282-
key (dict): `InferenceTask` primary key.
283-
keypointset_dir (str): Directory containing keypoint data.
284-
inference_output_dir (str): Output directory for inference results.
285-
num_iterations (int): Number of iterations for model fitting.
286-
model_id (int): Model ID.
287-
pose_estimation_method (str): Pose estimation method.
288-
task_mode (str): Task mode ('trigger' or 'load').
289-
290-
Raises:
291-
FileNotFoundError: If no pca model (`pca.p`) found in the parent model directory.
292-
FileNotFoundError: If no model (`checkpoint.h5`) found in the model directory.
293-
NotImplementedError: If the format method is not `deeplabcut`.
294-
FileNotFoundError: If no valid `kpms_dj_config` found in the parent model directory.
295-
296-
Returns:
297-
tuple: Inference results including duration, results data, and sampled instances.
298263
"""
299264
from keypoint_moseq import (
300265
apply_model,
@@ -303,6 +268,7 @@ def make_compute(
303268
load_keypoints,
304269
load_pca,
305270
load_results,
271+
outlier_removal,
306272
save_results_as_csv,
307273
)
308274

@@ -311,51 +277,60 @@ def make_compute(
311277

312278
start_time = datetime.now(timezone.utc)
313279

314-
# Get directories first
280+
# Get directories for new recordings
315281
kpms_root = moseq_train.get_kpms_root_data_dir()
316282
kpms_processed = moseq_train.get_kpms_processed_data_dir()
317283

318284
# Construct the full path to the inference output directory
319285
inference_output_dir = kpms_processed / model_dir_rel / inference_output_dir
320-
321-
if task_mode == "trigger":
322-
if not inference_output_dir.exists():
323-
inference_output_dir.mkdir(parents=True, exist_ok=True)
324-
325286
keypointset_dir = find_full_path(kpms_root, keypointset_dir)
326287

327288
if task_mode == "trigger":
328-
kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(
329-
config_path=kpms_dj_config_file
289+
# load saved model data
290+
fullfit_kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(
291+
config_path=fullfit_kpms_dj_config_file
330292
)
293+
fullfit_model, _, _, _ = load_checkpoint(path=fullfit_checkpoint_path)
294+
# fullfit_model_pca = pickle.load(open(fullfit_pca_file_path, "rb"))
331295

332-
metadata = pickle.load(open(metadata_file_path, "rb"))
333-
data = pickle.load(open(data_file_path, "rb"))
334-
model_data = pickle.load(open(model_file, "rb"))
335-
if task_mode == "trigger":
336-
results = apply_model(
337-
model_name=inference_output_dir.name,
338-
model=model_data,
339-
data=data,
340-
metadata=metadata,
341-
pca=pca_file_path,
342-
project_dir=inference_output_dir.parent,
343-
results_path=(inference_output_dir / "results.h5"),
344-
return_model=False,
345-
num_iters=num_iterations or DEFAULT_NUM_ITERS,
346-
overwrite=True,
347-
save_results=True,
348-
**kpms_dj_config_dict,
349-
)
296+
# Load new data
297+
coordinates, confidences, bodyparts = load_keypoints(
298+
filepath_pattern=keypointset_dir, format="deeplabcut"
299+
)
300+
coordinates, confidences = outlier_removal(
301+
coordinates,
302+
confidences,
303+
inference_output_dir,
304+
overwrite=False,
305+
**fullfit_kpms_dj_config_dict,
306+
)
307+
data, metadata = format_data(
308+
coordinates, confidences, **fullfit_kpms_dj_config_dict
309+
)
350310

351-
# Create results directory and save CSV files
352-
save_results_as_csv(
353-
results=results,
354-
save_dir=(inference_output_dir / "results_as_csv").as_posix(),
355-
)
311+
# # apply saved model to new data
312+
results = apply_model(
313+
model=fullfit_model,
314+
data=data,
315+
metadata=metadata,
316+
project_dir=inference_output_dir,
317+
model_name=inference_output_dir.name,
318+
results_path=(inference_output_dir / "results.h5"),
319+
return_model=False,
320+
num_iters=num_iterations or DEFAULT_NUM_ITERS,
321+
overwrite=True,
322+
save_results=True,
323+
**fullfit_kpms_dj_config_dict,
324+
)
325+
326+
# Create results directory and save CSV files
327+
save_results_as_csv(
328+
results=results,
329+
save_dir=(inference_output_dir / "results_as_csv").as_posix(),
330+
)
356331

357-
end_time = datetime.now(timezone.utc)
358-
duration_seconds = (end_time - start_time).total_seconds()
332+
end_time = datetime.now(timezone.utc)
333+
duration_seconds = (end_time - start_time).total_seconds()
359334

360335
else:
361336
duration_seconds = None
@@ -406,7 +381,7 @@ class VideoSequence(dj.Part):
406381
latent_states : longblob # Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model
407382
centroids : longblob # Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model
408383
headings : longblob # Inferred heading (h). The heading of the animal in each frame, as estimated by the model
409-
file_csv : filepath@moseq-infer-processed # File path of the temporal sequence of motion data (CSV format)
384+
file : filepath@moseq-infer-processed # File path of the temporal sequence of motion data (CSV format)
410385
"""
411386

412387
class SampledInstance(dj.Part):
@@ -509,7 +484,7 @@ def make(self, key):
509484
"latent_states": latent_states[vid],
510485
"centroids": filtered_centroids[vid],
511486
"headings": filtered_headings[vid],
512-
"file_csv": (
487+
"file": (
513488
inference_output_dir / "results_as_csv" / f"{vid}.csv"
514489
).as_posix(),
515490
}

0 commit comments

Comments
 (0)