Skip to content

Commit 865dcbc

Browse files
committed
refactor(moseq_infer)
1 parent 89dfe66 commit 865dcbc

File tree

1 file changed

+34
-36
lines changed

1 file changed

+34
-36
lines changed

element_moseq/moseq_infer.py

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class Model(dj.Manual):
8585
---
8686
model_name : varchar(1000) # User-friendly model name
8787
model_dir : varchar(1000) # Model directory relative to root data directory
88-
model_file : filepath@moseq-infer-processed # Model pkl file containing states, parameters, hyperparameters, noise prior, and random seed
88+
model_file : filepath@moseq-infer-processed # Checkpoint file (h5 format)
8989
model_desc='' : varchar(1000) # Optional. User-defined description of the model
9090
-> [nullable] moseq_train.SelectedFullFit # Optional. FullFit key.
9191
"""
@@ -225,7 +225,7 @@ def make_fetch(self, key):
225225
model_key = (Model * moseq_train.SelectedFullFit & key).fetch1("KEY")
226226
checkpoint_file_path = (
227227
moseq_train.FullFit.File & model_key & 'file_name="checkpoint.h5"'
228-
).fetch1("file")
228+
).fetch1("file_path")
229229
kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1(
230230
"config_file"
231231
)
@@ -318,48 +318,46 @@ def make_compute(
318318
# Construct the full path to the inference output directory
319319
inference_output_dir = kpms_processed / model_dir_rel / inference_output_dir
320320

321-
inference_output_dir.mkdir(parents=True, exist_ok=True)
322-
model_dir = find_full_path(kpms_processed, model_dir_rel)
323-
keypointset_dir = find_full_path(kpms_root, keypointset_dir)
321+
if task_mode == "trigger":
322+
if not inference_output_dir.exists():
323+
inference_output_dir.mkdir(parents=True, exist_ok=True)
324324

325-
kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(
326-
config_path=kpms_dj_config_file
327-
)
325+
keypointset_dir = find_full_path(kpms_root, keypointset_dir)
328326

329-
metadata = pickle.load(open(metadata_file_path, "rb"))
330-
data = pickle.load(open(data_file_path, "rb"))
331-
model_data = pickle.load(open(model_file, "rb"))
332327
if task_mode == "trigger":
333-
results = apply_model(
334-
model_name=inference_output_dir.name,
335-
model=model_data,
336-
data=data,
337-
metadata=metadata,
338-
pca=pca_file_path,
339-
project_dir=inference_output_dir.parent,
340-
results_path=(inference_output_dir / "results.h5"),
341-
return_model=False,
342-
num_iters=num_iterations or DEFAULT_NUM_ITERS,
343-
overwrite=True,
344-
save_results=True,
345-
**kpms_dj_config_dict,
328+
kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(
329+
config_path=kpms_dj_config_file
346330
)
347331

348-
# Create results directory and save CSV files
349-
save_results_as_csv(
350-
results=results,
351-
save_dir=(inference_output_dir / "results_as_csv").as_posix(),
352-
)
332+
metadata = pickle.load(open(metadata_file_path, "rb"))
333+
data = pickle.load(open(data_file_path, "rb"))
334+
model_data = pickle.load(open(model_file, "rb"))
335+
if task_mode == "trigger":
336+
results = apply_model(
337+
model_name=inference_output_dir.name,
338+
model=model_data,
339+
data=data,
340+
metadata=metadata,
341+
pca=pca_file_path,
342+
project_dir=inference_output_dir.parent,
343+
results_path=(inference_output_dir / "results.h5"),
344+
return_model=False,
345+
num_iters=num_iterations or DEFAULT_NUM_ITERS,
346+
overwrite=True,
347+
save_results=True,
348+
**kpms_dj_config_dict,
349+
)
353350

354-
end_time = datetime.now(timezone.utc)
355-
duration_seconds = (end_time - start_time).total_seconds()
351+
# Create results directory and save CSV files
352+
save_results_as_csv(
353+
results=results,
354+
save_dir=(inference_output_dir / "results_as_csv").as_posix(),
355+
)
356+
357+
end_time = datetime.now(timezone.utc)
358+
duration_seconds = (end_time - start_time).total_seconds()
356359

357360
else:
358-
# For load mode
359-
results = load_results(
360-
project_dir=model_dir,
361-
model_name=model_dir.name,
362-
)
363361
duration_seconds = None
364362

365363
results_filepath = (inference_output_dir / "results.h5").as_posix()

0 commit comments

Comments
 (0)