@@ -102,13 +102,15 @@ def make(self, key):
102102 "KEY"
103103 )
104104 coordinates = (moseq_train .PreProcessing & model_key ).fetch1 ("coordinates" )
105- config_file = (moseq_train .FullFit .ConfigFile & model_key ).fetch1 ("config_file" )
106- kpms_dj_config_dict = kpms_reader .load_kpms_dj_config (config_path = config_file )
105+ use_bodyparts = (moseq_train .BodyParts & model_key ).fetch1 ("use_bodyparts" )
106+ fps = (moseq_train .PreProcessing & model_key ).fetch1 ("average_frame_rate" )
107+
107108 plot_similarity_dendrogram (
108109 coordinates = coordinates ,
109110 results = results ,
110111 save_path = (inference_output_dir / "similarity_dendrogram" ).as_posix (),
111- ** kpms_dj_config_dict ,
112+ use_bodyparts = use_bodyparts ,
113+ fps = fps ,
112114 )
113115
114116 # Insert the record
@@ -163,14 +165,19 @@ def make(self, key):
163165 model_key = (moseq_infer .Model * moseq_train .SelectedFullFit & key ).fetch1 (
164166 "KEY"
165167 )
166- coordinates_dict = (moseq_train .PreProcessing & model_key ).fetch1 ("coordinates" )
168+ coordinates = (moseq_train .PreProcessing & model_key ).fetch1 ("coordinates" )
167169 kpms_dj_config_file = (moseq_train .FullFit .ConfigFile & model_key ).fetch1 (
168170 "config_file"
169171 )
170172 kpms_dj_config_dict = kpms_reader .load_kpms_dj_config (
171173 config_path = kpms_dj_config_file
172174 )
173175
176+ # Get use_bodyparts from the BodyParts table
177+ use_bodyparts = (moseq_train .BodyParts & model_key ).fetch1 ("use_bodyparts" )
178+
179+ fps = (moseq_train .PreProcessing & model_key ).fetch1 ("average_frame_rate" )
180+
174181 # Construct output directory
175182 kpms_processed = moseq_train .get_kpms_processed_data_dir ()
176183 output_dir = Path (model_dir ) / inference_output_dir
@@ -185,20 +192,27 @@ def make(self, key):
185192 # Load results
186193 results = h5py .File (results_file , "r" )
187194
195+ logger .info (f"Generating trajectory plots for { key } " )
188196 # Generate trajectory plots
189197 generate_trajectory_plots (
190- coordinates = coordinates_dict ,
198+ coordinates = coordinates ,
191199 results = results ,
192200 output_dir = trajectory_dir .as_posix (),
193- ** kpms_dj_config_dict ,
201+ use_bodyparts = use_bodyparts ,
202+ fps = fps ,
203+ skeleton = kpms_dj_config_dict .get ("skeleton" , []),
194204 )
195205
206+ logger .info (f"Generating grid movies for { key } " )
196207 # Generate grid movies
197208 generate_grid_movies (
198- coordinates = coordinates_dict ,
209+ coordinates = coordinates ,
199210 results = results ,
200211 output_dir = grid_movies_dir .as_posix (),
201- ** kpms_dj_config_dict ,
212+ use_bodyparts = use_bodyparts ,
213+ fps = fps ,
214+ overlay_keypoints = True ,
215+ skeleton = kpms_dj_config_dict .get ("skeleton" , []),
202216 )
203217
204218 # Calculate duration
0 commit comments