@@ -186,29 +186,6 @@ def infer_from_loader(
186186 cv2 .destroyAllWindows ()
187187
188188
189- def save_renders (
190- renders : dict [tuple [str , str ], list [np .ndarray ]],
191- save_dir : Path ,
192- counter : Counter ,
193- img_paths : list [PathType ] | None = None ,
194- ) -> None :
195- """Persist a rendered batch to disk."""
196- batch_size = len (next (iter (renders .values ())))
197-
198- for i in range (batch_size ):
199- if img_paths is not None :
200- idx = counter ()
201- img_path = Path (img_paths [idx ])
202- for (node_name , viz_name ), visualizations in renders .items ():
203- viz = visualizations [i ]
204- if img_paths is not None :
205- name = f"{ img_path .stem } _{ node_name } _{ viz_name } "
206- else :
207- name = f"{ node_name } _{ viz_name } _{ counter ()} "
208- name = name .replace ("/" , "-" )
209- cv2 .imwrite (str (save_dir / f"{ name } .png" ), viz )
210-
211-
212189class InferenceSaveWriter (BasePredictionWriter ):
213190 """Writes rendered inference batches as soon as they are
214191 predicted."""
@@ -239,7 +216,7 @@ def write_on_batch_end(
239216 if not renders :
240217 return
241218
242- save_renders ( renders , self .save_dir , self . counter , self . img_paths )
219+ self ._save_renders ( renders )
243220
244221 def write_on_epoch_end (
245222 self ,
@@ -250,6 +227,25 @@ def write_on_epoch_end(
250227 ) -> None :
251228 del trainer , pl_module , predictions , batch_indices
252229
230+ def _save_renders (
231+ self , renders : dict [tuple [str , str ], list [np .ndarray ]]
232+ ) -> None :
233+ """Persist a rendered batch to disk."""
234+ batch_size = len (next (iter (renders .values ())))
235+
236+ for i in range (batch_size ):
237+ if self .img_paths is not None :
238+ idx = self .counter ()
239+ img_path = Path (self .img_paths [idx ])
240+ for (node_name , viz_name ), visualizations in renders .items ():
241+ viz = visualizations [i ]
242+ if self .img_paths is not None :
243+ name = f"{ img_path .stem } _{ node_name } _{ viz_name } "
244+ else :
245+ name = f"{ node_name } _{ viz_name } _{ self .counter ()} "
246+ name = name .replace ("/" , "-" )
247+ cv2 .imwrite (str (self .save_dir / f"{ name } .png" ), viz )
248+
253249
254250def create_loader_from_directory (
255251 img_paths : Iterable [PathType ],
0 commit comments