Skip to content

Commit 6720dca

Browse files
committed
fix(trajectoryplots)
1 parent 7c627e2 commit 6720dca

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

element_moseq/moseq_report.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -155,28 +155,35 @@ def make(self, key):
155155

156156
start_time = datetime.now(timezone.utc)
157157

158-
results_file = (moseq_infer.Inference & key).fetch1(
159-
"syllable_segmentation_file"
160-
)
158+
kpms_processed = moseq_train.get_kpms_processed_data_dir()
159+
kpms_root = moseq_train.get_kpms_root_data_dir()
160+
161+
# From trained model
161162
model_dir = (moseq_infer.Model & key).fetch1("model_dir")
162-
inference_output_dir = (moseq_infer.InferenceTask & key).fetch1(
163-
"inference_output_dir"
164-
)
165163
model_key = (moseq_infer.Model * moseq_train.SelectedFullFit & key).fetch1(
166164
"KEY"
167165
)
168-
coordinates = (moseq_train.PreProcessing & model_key).fetch1("coordinates")
169-
kpms_dj_config_file = (moseq_train.FullFit.ConfigFile & model_key).fetch1(
166+
use_bodyparts = (moseq_train.BodyParts & model_key).fetch1("use_bodyparts")
167+
kpms_dj_config_path = (moseq_train.FullFit.ConfigFile & model_key).fetch1(
170168
"config_file"
171169
)
172170
kpms_dj_config_dict = kpms_reader.load_kpms_dj_config(
173-
config_path=kpms_dj_config_file
171+
config_path=kpms_dj_config_path
174172
)
175173

176-
# Get use_bodyparts from the BodyParts table
177-
use_bodyparts = (moseq_train.BodyParts & model_key).fetch1("use_bodyparts")
178-
179-
fps = (moseq_train.PreProcessing & model_key).fetch1("average_frame_rate")
174+
# From new recordings
175+
inference_output_dir = (moseq_infer.InferenceTask & key).fetch1(
176+
"inference_output_dir"
177+
)
178+
coordinates = (moseq_infer.Inference & key).fetch1("coordinates")
179+
results_file = (moseq_infer.Inference & key).fetch1(
180+
"syllable_segmentation_file"
181+
)
182+
results = h5py.File(results_file, "r")
183+
fps = (moseq_infer.Inference & key).fetch1("average_frame_rate")
184+
kpset_dir = (moseq_infer.InferenceTask & key).fetch1("keypointset_dir")
185+
kpset_dir = find_full_path(kpms_root, kpset_dir)
186+
kpset_dir = Path(kpset_dir).as_posix()
180187

181188
# Construct output directory
182189
kpms_processed = moseq_train.get_kpms_processed_data_dir()
@@ -189,9 +196,6 @@ def make(self, key):
189196
trajectory_dir.mkdir(parents=True, exist_ok=True)
190197
grid_movies_dir.mkdir(parents=True, exist_ok=True)
191198

192-
# Load results
193-
results = h5py.File(results_file, "r")
194-
195199
logger.info(f"Generating trajectory plots for {key}")
196200
# Generate trajectory plots
197201
generate_trajectory_plots(
@@ -206,8 +210,8 @@ def make(self, key):
206210
logger.info(f"Generating grid movies for {key}")
207211
# Generate grid movies
208212
generate_grid_movies(
209-
coordinates=coordinates,
210213
results=results,
214+
video_path=kpset_dir,
211215
output_dir=grid_movies_dir.as_posix(),
212216
use_bodyparts=use_bodyparts,
213217
fps=fps,

0 commit comments

Comments
 (0)