11from importlib .util import find_spec
2+ from math import ceil
23from multiprocessing import get_context
34from os import PathLike
45from typing import Literal
56
67import numpy as np
78import torch
89from matplotlib import pyplot as plt
10+ from torch import nn
911
1012from mipcandy .common import ColorizeLabel
1113from mipcandy .data .geometric import ensure_num_dimensions
@@ -45,7 +47,7 @@ def _visualize3d_with_pyvista(image: np.ndarray, title: str | None, cmap: str,
4547 p .show ()
4648
4749
48- def visualize3d (image : torch .Tensor , * , title : str | None = None , cmap : str = "gray" ,
50+ def visualize3d (image : torch .Tensor , * , title : str | None = None , cmap : str = "gray" , max_volume : int = 1e8 ,
4951 backend : Literal ["auto" , "matplotlib" , "pyvista" ] = "auto" , blocking : bool = False ,
5052 screenshot_as : str | PathLike [str ] | None = None ) -> None :
5153 image = image .detach ().float ().cpu ()
@@ -55,6 +57,12 @@ def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str = "g
5557 image = ensure_num_dimensions (image , 4 )
5658 if image .ndim == 4 and image .shape [0 ] == 1 :
5759 image = image .squeeze (0 )
60+ d , h , w = image .shape
61+ total = d * h * w
62+ ratio = int (ceil ((total / max_volume ) ** (1 / 3 ))) if total > max_volume else 1
63+ if ratio > 1 :
64+ image = ensure_num_dimensions (nn .functional .avg_pool3d (ensure_num_dimensions (image , 5 ), kernel_size = ratio ,
65+ stride = ratio , ceil_mode = True ), 3 )
5866 image /= image .max ()
5967 image = image .numpy ()
6068 if backend == "auto" :
0 commit comments