Skip to content

Commit 556a467

Browse files
committed
Added 3D image down-sampling in visualize3d() with configurable max_volume.
1 parent b531bd0 commit 556a467

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

mipcandy/data/visualization.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from importlib.util import find_spec
2+
from math import ceil
23
from multiprocessing import get_context
34
from os import PathLike
45
from typing import Literal
56

67
import numpy as np
78
import torch
89
from matplotlib import pyplot as plt
10+
from torch import nn
911

1012
from mipcandy.common import ColorizeLabel
1113
from mipcandy.data.geometric import ensure_num_dimensions
@@ -45,7 +47,7 @@ def _visualize3d_with_pyvista(image: np.ndarray, title: str | None, cmap: str,
4547
p.show()
4648

4749

48-
def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str = "gray",
50+
def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str = "gray", max_volume: int = 1e8,
4951
backend: Literal["auto", "matplotlib", "pyvista"] = "auto", blocking: bool = False,
5052
screenshot_as: str | PathLike[str] | None = None) -> None:
5153
image = image.detach().float().cpu()
@@ -55,6 +57,12 @@ def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str = "g
5557
image = ensure_num_dimensions(image, 4)
5658
if image.ndim == 4 and image.shape[0] == 1:
5759
image = image.squeeze(0)
60+
d, h, w = image.shape
61+
total = d * h * w
62+
ratio = int(ceil((total / max_volume) ** (1 / 3))) if total > max_volume else 1
63+
if ratio > 1:
64+
image = ensure_num_dimensions(nn.functional.avg_pool3d(ensure_num_dimensions(image, 5), kernel_size=ratio,
65+
stride=ratio, ceil_mode=True), 3)
5866
image /= image.max()
5967
image = image.numpy()
6068
if backend == "auto":

0 commit comments

Comments
 (0)