Skip to content

Commit 24da7c7

Browse files
Register GlobalMutualInformationLoss bin_centers as buffer (#8869)
Fixes #8866. ### Description This PR fixes a device placement bug in `GlobalMutualInformationLoss.__init__` where `bin_centers` was assigned as a plain Python attribute instead of being registered as a buffer. - Registered `bin_centers` as a non-persistent buffer in `image_dissimilarity.py`, ensuring that `GlobalMutualInformationLoss` now follows normal `nn.Module` buffer semantics for `.to(device)` / `dtype` moves and avoids silent gradient tracking. - Added a regression test in `test_global_mutual_information_loss.py` to verify that `bin_centers` is properly exposed through `named_buffers()`, does not require gradients, and successfully changes dtype when the module is moved. *Verification:* Passed `python -m pytest tests/losses/image_dissimilarity/test_global_mutual_information_loss.py -q -k gaussian_bin_centers_registered_buffer` and `-k ill_opts`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 837dc32 commit 24da7c7

2 files changed

Lines changed: 24 additions & 1 deletion

File tree

monai/losses/image_dissimilarity.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,11 @@ def __init__(
233233
self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"])
234234
self.num_bins = num_bins
235235
self.kernel_type = kernel_type
236+
self.bin_centers: torch.Tensor | None
237+
self.register_buffer("bin_centers", None, persistent=False)
236238
if self.kernel_type == "gaussian":
237239
self.preterm = 1 / (2 * sigma**2)
238-
self.bin_centers = bin_centers[None, None, ...]
240+
self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False)
239241
self.smooth_nr = float(smooth_nr)
240242
self.smooth_dr = float(smooth_dr)
241243

@@ -314,6 +316,8 @@ def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, to
314316
"""
315317
img = torch.clamp(img, 0, 1)
316318
img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1)
319+
if self.bin_centers is None:
320+
raise ValueError("bin_centers must be defined for gaussian parzen windowing.")
317321
weight = torch.exp(
318322
-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2
319323
) # (batch, num_sample, num_bin)

tests/losses/image_dissimilarity/test_global_mutual_information_loss.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,25 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.
116116

117117

118118
class TestGlobalMutualInformationLossIll(unittest.TestCase):
119+
def test_gaussian_bin_centers_registered_buffer(self):
120+
loss = GlobalMutualInformationLoss(kernel_type="gaussian", num_bins=16)
121+
122+
self.assertIn("bin_centers", dict(loss.named_buffers()))
123+
self.assertIsNotNone(loss.bin_centers)
124+
self.assertFalse(loss.bin_centers.requires_grad)
125+
126+
loss = loss.to(dtype=torch.float64)
127+
self.assertEqual(loss.bin_centers.dtype, torch.float64)
128+
129+
if torch.cuda.is_available():
130+
loss = loss.to(device="cuda:0")
131+
self.assertEqual(loss.bin_centers.device, torch.device("cuda:0"))
132+
133+
def test_b_spline_bin_centers_exists_as_none(self):
134+
loss = GlobalMutualInformationLoss(kernel_type="b-spline")
135+
136+
self.assertIsNone(loss.bin_centers)
137+
119138
@parameterized.expand(
120139
[
121140
(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)), # mismatched_simple_dims

0 commit comments

Comments
 (0)