Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8ef905b
Adds a transform to generate heatmap from landmarks
eclipse0922 Sep 18, 2025
226bf90
Adds heatmap generation demo and tests
eclipse0922 Sep 19, 2025
3097baf
Enables batched input for heatmap generation
eclipse0922 Sep 21, 2025
08a715a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2025
25ceb7f
Refactors GenerateHeatmap transforms for clarity
eclipse0922 Sep 21, 2025
9e33e7c
rename parameter
eclipse0922 Sep 21, 2025
62831e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2025
15ec97a
fix formatting
eclipse0922 Sep 21, 2025
5bc7993
fix flake8
eclipse0922 Sep 21, 2025
0e907bb
fix meta tensor problem
eclipse0922 Sep 21, 2025
54a81a5
better unit tests
eclipse0922 Sep 21, 2025
4b367ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2025
eafe59a
fix test error
eclipse0922 Sep 21, 2025
9f10dcf
Merge branch 'dev' into generate_heatmap_transforms
eclipse0922 Sep 23, 2025
aaf2833
address the code review.
eclipse0922 Sep 24, 2025
1bf0850
fixes
eclipse0922 Sep 25, 2025
2c7b4d0
Improves GenerateHeatmap transform and documentation
eclipse0922 Sep 26, 2025
1b5888b
Fixes heatmap normalization and shape checking
eclipse0922 Sep 26, 2025
fd4be38
Adds `GenerateHeatmap` transform
eclipse0922 Sep 26, 2025
fc28c71
fix nitpick comments
eclipse0922 Sep 26, 2025
4d9c5ad
Remove batch, use GaussianFilter
eclipse0922 Nov 4, 2025
a6575cc
fix the comment by ai review
eclipse0922 Nov 4, 2025
9bbd608
fix dtype
eclipse0922 Nov 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@
AsDiscrete,
DistanceTransformEDT,
FillHoles,
GenerateHeatmap,
Invert,
KeepLargestConnectedComponent,
LabelFilter,
Expand All @@ -319,6 +320,9 @@
FillHolesD,
FillHolesd,
FillHolesDict,
GenerateHeatmapd,
GenerateHeatmapD,
GenerateHeatmapDict,
InvertD,
Invertd,
InvertDict,
Expand Down
158 changes: 157 additions & 1 deletion monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,14 @@
remove_small_objects,
)
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option
from monai.utils import (
TransformBackends,
convert_data_type,
convert_to_tensor,
ensure_tuple,
get_equivalent_dtype,
look_up_option,
)
from monai.utils.type_conversion import convert_to_dst_type

__all__ = [
Expand All @@ -54,6 +61,7 @@
"SobelGradients",
"VoteEnsemble",
"Invert",
"GenerateHeatmap",
"DistanceTransformEDT",
]

Expand Down Expand Up @@ -742,6 +750,154 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
return self.post_convert(out_pt, img)


class GenerateHeatmap(Transform):
    """
    Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.

    A unit impulse is placed at the (rounded) position of every landmark in its own channel and the
    result is blurred with a ``GaussianFilter``.

    Notes:
        - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
        - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D.
        - Output layout uses channel-first convention with one channel per landmark.
        - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions (2 or 3).
        - Output heatmap shape: (N, Y, X) for 2D or (N, Z, Y, X) for 3D.
        - Each channel index corresponds to one landmark.
        - Landmarks with non-finite coordinates, or with any coordinate outside ``[0, size)``,
          are skipped: no impulse is placed in their channel.

    Args:
        sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
        spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
        truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window.
        normalize: when ``True``, every channel with a positive peak is divided by that peak so that
            its maximum becomes 1.
        dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).

    Raises:
        ValueError: when ``sigma`` or ``truncated`` is non-positive, or ``dtype`` is not a floating-point type.

    """

    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

    def __init__(
        self,
        sigma: Sequence[float] | float = 5.0,
        spatial_shape: Sequence[int] | None = None,
        truncated: float = 4.0,
        normalize: bool = True,
        dtype: np.dtype | torch.dtype | type = np.float32,
    ) -> None:
        # sigma is always stored as a tuple; broadcasting to the landmark dimensionality is
        # deferred to call time because the dimensionality is only known from the points.
        if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)):
            if any(s <= 0 for s in sigma):
                raise ValueError("Argument `sigma` values must be positive.")
            self._sigma = tuple(float(s) for s in sigma)
        else:
            if float(sigma) <= 0:
                raise ValueError("Argument `sigma` must be positive.")
            self._sigma = (float(sigma),)
        if truncated <= 0:
            raise ValueError("Argument `truncated` must be positive.")
        self.truncated = float(truncated)
        self.normalize = normalize
        self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
        self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
        # Validate that dtype is floating-point for meaningful Gaussian values
        if not self.torch_dtype.is_floating_point:
            raise ValueError(f"Argument `dtype` must be a floating-point type, got {self.torch_dtype}")
        self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)

    def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:
        """
        Args:
            points: landmark coordinates as ndarray/Tensor with shape (N, D),
                ordered as (Y, X) for 2D or (Z, Y, X) for 3D, where N is the number
                of landmarks and D is the spatial dimensionality.
            spatial_shape: spatial size as a sequence. If None, uses the value provided at construction.

        Returns:
            Heatmaps with shape (N, *spatial), one channel per landmark.

        Raises:
            ValueError: if points shape/dimension or spatial_shape is invalid.
        """
        original_points = points
        points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)

        if points_t.ndim != 2:
            raise ValueError(
                f"Argument `points` must be a 2D array with shape (num_points, spatial_dims), got shape {points_t.shape}."
            )

        if points_t.shape[-1] not in (2, 3):
            raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.")

        device = points_t.device
        num_points, spatial_dims = points_t.shape

        target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
        sigma = self._resolve_sigma(spatial_dims)

        # Create sparse image with one impulse per valid landmark, each in its own channel
        heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device)
        bounds_t = torch.as_tensor(target_shape, device=device, dtype=points_t.dtype)

        # A landmark is usable only if every coordinate is finite and lies within [0, size)
        valid = torch.isfinite(points_t).all(dim=1) & ((points_t >= 0) & (points_t < bounds_t)).all(dim=1)
        # Round to the nearest voxel, then cap at size - 1 (e.g., 9.7 rounds to 10 in a size-10 array).
        # No lower clamp is needed: valid coordinates are >= 0, so they never round below 0.
        centers = torch.minimum(points_t[valid].round().long(), (bounds_t - 1).long())
        channels = torch.nonzero(valid, as_tuple=True)[0]
        heatmap[(channels, *centers.unbind(dim=1))] = 1.0

        # Apply Gaussian blur using GaussianFilter
        # Reshape to (num_points, 1, *spatial) for per-channel filtering
        heatmap_input = heatmap.unsqueeze(1)  # Add channel dimension

        gaussian_filter = GaussianFilter(
            spatial_dims=spatial_dims, sigma=sigma, truncated=self.truncated, approx="erf", requires_grad=False
        ).to(device=device, dtype=self.torch_dtype)

        heatmap_blurred = gaussian_filter(heatmap_input)
        heatmap = heatmap_blurred.squeeze(1)  # Remove channel dimension

        # Normalize per channel if requested; channels without a positive peak are divided by 1 (unchanged)
        if self.normalize:
            peaks = heatmap.amax(dim=tuple(range(1, heatmap.ndim)), keepdim=True)
            heatmap.div_(torch.where(peaks > 0, peaks, torch.ones_like(peaks)))

        target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype
        converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype)
        return converted

    def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]:
        """
        Pick the spatial shape to use (call-time value first, constructor value as fallback) and
        validate it against the landmark dimensionality.

        Args:
            call_shape: spatial shape passed at call time, or ``None``.
            spatial_dims: number of spatial dimensions of the landmarks.

        Returns:
            The spatial shape as a tuple of ``spatial_dims`` positive ints; a single value is broadcast.

        Raises:
            ValueError: if no shape is available, its length does not match ``spatial_dims``,
                or any size is not positive.
        """
        shape = call_shape if call_shape is not None else self.spatial_shape
        if shape is None:
            raise ValueError("Argument `spatial_shape` must be provided either at construction time or call time.")
        shape_tuple = ensure_tuple(shape)
        if len(shape_tuple) != spatial_dims:
            if len(shape_tuple) == 1:
                shape_tuple = shape_tuple * spatial_dims  # type: ignore
            else:
                raise ValueError(
                    "Argument `spatial_shape` length must match the landmarks' spatial dims (or pass a single int to broadcast)."
                )
        resolved = tuple(int(s) for s in shape_tuple)
        # Fail early with a clear message instead of an opaque error from tensor allocation or filtering.
        if any(s <= 0 for s in resolved):
            raise ValueError("Argument `spatial_shape` values must be positive.")
        return resolved

    def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
        """
        Return one sigma per spatial dimension, broadcasting a single stored value if necessary.

        Args:
            spatial_dims: number of spatial dimensions of the landmarks.

        Raises:
            ValueError: if the stored sigma has more than one value and its length differs from ``spatial_dims``.
        """
        if len(self._sigma) == spatial_dims:
            return self._sigma
        if len(self._sigma) == 1:
            return self._sigma * spatial_dims
        raise ValueError("Argument `sigma` sequence length must equal the number of spatial dimensions.")


class ProbNMS(Transform):
"""
Performs probability based non-maximum suppression (NMS) on the probabilities map via
Expand Down
Loading
Loading