|
| 1 | +import argparse |
| 2 | +from os import listdir, makedirs |
| 3 | +from os.path import join |
| 4 | +import tifffile as tiff |
| 5 | +import cv2 |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +from ctc_metrics.utils.filesystem import read_tracking_file |
| 9 | + |
| 10 | + |
# Whether a field-of-interest border should be drawn.  NOTE(review): this
# flag is not read by visualize() below — border_width is passed explicitly;
# confirm whether it is consumed elsewhere before removing.
SHOW_BORDER = True

# Per-challenge border widths in pixels, keyed by CTC dataset name.  Used by
# visualize() when --border-width is given as a challenge name instead of an
# integer.
BORDER_WIDTH = {
    "BF-C2DL-HSC": 25,
    "BF-C2DL-MuSC": 25,
    "Fluo-N2DL-HeLa": 25,
    "PhC-C2DL-PSC": 25,
    "Fluo-N2DH-SIM+": 0,
    "DIC-C2DH-HeLa": 50,
    "Fluo-C2DL-Huh7": 50,
    "Fluo-C2DL-MSC": 50,
    "Fluo-N2DH-GOWT1": 50,
    "PhC-C2DH-U373": 50,
}

# Fixed seed so that instance colors are reproducible across runs.
np.random.seed(0)
# Random color palette (100000 x 3, values in [0, 256)); instance label i
# maps onto a palette row via get_palette_color().
PALETTE = np.random.randint(0, 256, (100000, 3))
| 27 | + |
| 28 | + |
def get_palette_color(i):
    """Return the palette color for instance label *i*, wrapping around."""
    return PALETTE[i % PALETTE.shape[0]]
| 32 | + |
| 33 | + |
def visualize(
        img_dir: str,
        res_dir: str,
        viz_dir: str = None,
        video_name: str = None,
        border_width=None,
        show_labels: bool = True,
        show_parents: bool = True,
        ids_to_show: list = None,
        start_frame: int = 0,
        framerate: int = 30,
        opacity: float = 0.5,
):  # pylint: disable=too-many-arguments
    """
    Visualizes the tracking results.

    Args:
        img_dir: str
            The path to the images.
        res_dir: str
            The path to the results.
        viz_dir: str
            The path to save the visualizations. Required when video_name
            is given, because the video is written into this directory.
        video_name: str
            The path to the video if a video should be created. Note that no
            visualization is available during video creation.
        border_width: str or int
            The width of the border. Either an integer or a string that
            describes the challenge name (a key of BORDER_WIDTH).
        show_labels: bool
            Print instance labels to the output.
        show_parents: bool
            Print parent labels to the output.
        ids_to_show: list
            The IDs of the instances to show. All others will be ignored.
        start_frame: int
            The frame to start the visualization.
        framerate: int
            The framerate of the video.
        opacity: float
            The opacity of the instance colors.

    Raises:
        ValueError: If border_width is a string that is neither an integer
            nor a known challenge name, or if video_name is given without
            viz_dir.

    Returns:
        None
    """
    # Define initial video parameters.
    wait_time = max(1, round(1000 / framerate))
    if border_width is None:
        border_width = 0
    elif isinstance(border_width, str):
        try:
            border_width = int(border_width)
        except ValueError as exc:
            if border_width in BORDER_WIDTH:
                border_width = BORDER_WIDTH[border_width]
            else:
                raise ValueError(
                    f"Border width '{border_width}' not recognized. "
                    f"Existing datasets: {BORDER_WIDTH.keys()}"
                ) from exc

    # Fail early with a clear message instead of a TypeError in join()
    # when the video path would be built from viz_dir=None.
    if video_name is not None and viz_dir is None:
        raise ValueError(
            "A 'viz_dir' must be given when 'video_name' is set, because "
            "the video is written into that directory.")

    # Load image and tracking data.
    images = [x for x in sorted(listdir(img_dir)) if x.endswith(".tif")]
    results = [x for x in sorted(listdir(res_dir)) if x.endswith(".tif")]
    parents = {
        l[0]: l[3] for l in read_tracking_file(join(res_dir, "res_track.txt"))
    }

    # Create visualization directory.
    if viz_dir:
        makedirs(viz_dir, exist_ok=True)

    video_writer = None

    try:
        # Loop through all images.
        while start_frame < len(images):
            # Read image file.
            img_name, res_name = images[start_frame], results[start_frame]
            img_path, res_path = join(img_dir, img_name), join(res_dir, res_name)
            print(f"\rFrame {img_name} (of {len(images)})", end="")

            # Visualize the image.
            viz = create_colored_image(
                cv2.imread(img_path),
                tiff.imread(res_path),
                labels=show_labels,
                frame=start_frame,
                parents=parents if show_parents else None,
                ids_to_show=ids_to_show,
                opacity=opacity,
            )
            if border_width > 0:
                viz = cv2.rectangle(
                    viz,
                    (border_width, border_width),
                    (viz.shape[1] - border_width, viz.shape[0] - border_width),
                    (0, 0, 255), 1
                )

            # Save the visualization.
            if video_name is not None:
                if video_writer is None:
                    video_path = join(
                        viz_dir, f"{video_name.replace('.mp4', '')}.mp4")
                    video_writer = cv2.VideoWriter(
                        video_path,
                        cv2.VideoWriter_fourcc(*"mp4v"),
                        framerate,
                        (viz.shape[1], viz.shape[0])
                    )
                video_writer.write(viz)
                start_frame += 1
                continue

            # Show the video.
            cv2.imshow("VIZ", viz)
            key = cv2.waitKey(wait_time)
            if key == ord("q"):
                # Quit the visualization.
                break
            if key == ord("w"):
                # Start or stop the auto visualization.
                if wait_time == 0:
                    wait_time = max(1, round(1000 / framerate))
                else:
                    wait_time = 0
            elif key == ord("d"):
                # Move to the next frame.
                start_frame += 1
                wait_time = 0
            elif key == ord("a"):
                # Move to the previous frame; clamp at 0 so a negative
                # index does not wrap around to the last frame.
                start_frame = max(0, start_frame - 1)
                wait_time = 0
            elif key == ord("l"):
                # Toggle the show labels option.
                show_labels = not show_labels
            elif key == ord("p"):
                # Toggle the show parents option.
                show_parents = not show_parents
            elif key == ord("s"):
                # Save the current frame (stays on the same frame).
                if viz_dir is None:
                    print("Please define the '--viz' argument to save the "
                          "visualizations.")
                    continue
                viz_path = join(viz_dir, img_name) + ".jpg"
                cv2.imwrite(viz_path, viz)
            else:
                # Move to the next frame.
                start_frame += 1
    finally:
        # Finalize the video container (otherwise the mp4 may be truncated)
        # and close any interactive windows.
        if video_writer is not None:
            video_writer.release()
        cv2.destroyAllWindows()
| 185 | + |
| 186 | + |
def create_colored_image(
        img: np.ndarray,
        res: np.ndarray,
        labels: bool = False,
        opacity: float = 0.5,
        ids_to_show=None,
        frame: int = None,
        parents: dict = None,
):
    """
    Creates a colored image from the input image and the results.

    Args:
        img: np.ndarray
            The input image (assumed BGR, values clipped to [0, 255]).
        res: np.ndarray
            The results: an instance-label mask where 0 is background.
        labels: bool
            Print instance labels to the output.
        opacity: float
            The opacity of the instance colors.
        ids_to_show: list
            The IDs of the instances to show. All others will be ignored.
        frame: int
            The frame number. Drawn once in the top-left corner.
        parents: dict
            Mapping from instance ID to parent ID; non-zero parents are
            appended to the instance label as "id(parent)".

    Returns:
        The colored image.
    """
    img = np.clip(img, 0, 255).astype(np.uint8)
    kernel = np.ones((3, 3), dtype=np.uint8)
    for i in np.unique(res):
        if i == 0:
            # Skip background.
            continue
        if ids_to_show is not None and i not in ids_to_show:
            continue
        mask = res == i
        # Outline = mask minus its erosion (1-pixel inner contour).
        mask_u8 = (mask * 255).astype(np.uint8)
        contour = (mask_u8 - cv2.erode(mask_u8, kernel)) != 0
        color = get_palette_color(i)
        # Alpha-blend the instance color over the image, solid contour.
        img[mask] = np.round((1 - opacity) * img[mask] + opacity * color)
        img[contour] = color
        if labels:
            # Print label to the center of the object.
            ys, xs = np.where(mask)
            text = str(i)
            if parents is not None and i in parents and parents[i] != 0:
                text += f"({parents[i]})"
            cv2.putText(img, text, (int(np.mean(xs)), int(np.mean(ys))),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    # Draw the frame number once, outside the instance loop.  Previously it
    # was redrawn for every instance and never drawn on frames without any
    # instances.
    if frame is not None:
        cv2.putText(img, str(frame), (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    return img
| 249 | + |
| 250 | + |
def parse_args():
    """ Parses the arguments. """
    p = argparse.ArgumentParser(description='Validates CTC-Sequences.')
    p.add_argument('--img', type=str, required=True,
                   help='The path to the images.')
    p.add_argument('--res', type=str, required=True,
                   help='The path to the results.')
    p.add_argument('--viz', type=str, default=None,
                   help='The path to save the visualizations.')
    p.add_argument('--video-name', type=str, default=None,
                   help='The path to the video if a video should be created. '
                        'Note that no visualization is available during video '
                        'creation.')
    p.add_argument('--border-width', type=str, default=None,
                   help='The width of the border. Either an integer or a '
                        'string that describes the challenge name.')
    p.add_argument('--show-no-labels', action="store_false",
                   help='Print no instance labels to the output.')
    p.add_argument('--show-no-parents', action="store_false",
                   help='Print no parent labels to the output.')
    p.add_argument('--ids-to-show', type=int, nargs='+', default=None,
                   help='The IDs of the instances to show. All others will '
                        'be ignored.')
    p.add_argument('--start-frame', type=int, default=0,
                   help='The frame to start the visualization.')
    p.add_argument('--framerate', type=int, default=10,
                   help='The framerate of the video.')
    p.add_argument('--opacity', type=float, default=0.5,
                   help='The opacity of the instance colors.')
    return p.parse_args()
| 301 | + |
| 302 | + |
def main():
    """
    Entry point when the script is executed: parse the command-line
    arguments and run the visualization.
    """
    args = parse_args()
    options = dict(
        viz_dir=args.viz,
        video_name=args.video_name,
        border_width=args.border_width,
        show_labels=args.show_no_labels,
        show_parents=args.show_no_parents,
        ids_to_show=args.ids_to_show,
        start_frame=args.start_frame,
        framerate=args.framerate,
        opacity=args.opacity,
    )
    visualize(args.img, args.res, **options)


if __name__ == "__main__":
    main()
0 commit comments