Skip to content

Commit 2208df5

Browse files
committed
refactor(moseq_report)
1 parent 1510047 commit 2208df5

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

element_moseq/moseq_report.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,15 @@ def make(self, key):
102102
"KEY"
103103
)
104104
coordinates = (moseq_train.PreProcessing & model_key).fetch1("coordinates")
105-
config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1("config_file")
106-
kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(config_path=config_file)
105+
use_bodyparts = (moseq_train.BodyParts & model_key).fetch1("use_bodyparts")
106+
fps = (moseq_train.PreProcessing & model_key).fetch1("average_frame_rate")
107+
107108
plot_similarity_dendrogram(
108109
coordinates=coordinates,
109110
results=results,
110111
save_path=(inference_output_dir / "similarity_dendrogram").as_posix(),
111-
**kpms_dj_config_dict,
112+
use_bodyparts=use_bodyparts,
113+
fps=fps,
112114
)
113115

114116
# Insert the record
@@ -163,14 +165,19 @@ def make(self, key):
163165
model_key = (moseq_infer.Model * moseq_train.SelectedFullFit & key).fetch1(
164166
"KEY"
165167
)
166-
coordinates_dict = (moseq_train.PreProcessing & model_key).fetch1("coordinates")
168+
coordinates = (moseq_train.PreProcessing & model_key).fetch1("coordinates")
167169
kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1(
168170
"config_file"
169171
)
170172
kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(
171173
config_path=kpms_dj_config_file
172174
)
173175

176+
# Get use_bodyparts from the BodyParts table
177+
use_bodyparts = (moseq_train.BodyParts & model_key).fetch1("use_bodyparts")
178+
179+
fps = (moseq_train.PreProcessing & model_key).fetch1("average_frame_rate")
180+
174181
# Construct output directory
175182
kpms_processed = moseq_train.get_kpms_processed_data_dir()
176183
output_dir = Path(model_dir) / inference_output_dir
@@ -185,20 +192,27 @@ def make(self, key):
185192
# Load results
186193
results = h5py.File(results_file, "r")
187194

195+
logger.info(f"Generating trajectory plots for {key}")
188196
# Generate trajectory plots
189197
generate_trajectory_plots(
190-
coordinates=coordinates_dict,
198+
coordinates=coordinates,
191199
results=results,
192200
output_dir=trajectory_dir.as_posix(),
193-
**kpms_dj_config_dict,
201+
use_bodyparts=use_bodyparts,
202+
fps=fps,
203+
skeleton=kpms_dj_config_dict.get("skeleton", []),
194204
)
195205

206+
logger.info(f"Generating grid movies for {key}")
196207
# Generate grid movies
197208
generate_grid_movies(
198-
coordinates=coordinates_dict,
209+
coordinates=coordinates,
199210
results=results,
200211
output_dir=grid_movies_dir.as_posix(),
201-
**kpms_dj_config_dict,
212+
use_bodyparts=use_bodyparts,
213+
fps=fps,
214+
overlay_keypoints=True,
215+
skeleton=kpms_dj_config_dict.get("skeleton", []),
202216
)
203217

204218
# Calculate duration

0 commit comments

Comments
 (0)