Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,5 @@ runs
*.pth

*zarr/*

monai-dev/
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ All notable changes to MONAI are documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
* Added `RandNonCentralChiNoise` and `RandNonCentralChiNoised` for generalized Rician noise simulation in MRI.

## [1.5.1] - 2025-09-22

Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
RandHistogramShift,
RandIntensityRemap,
RandKSpaceSpikeNoise,
RandNonCentralChiNoise,
RandRicianNoise,
RandScaleIntensity,
RandScaleIntensityFixedMean,
Expand Down Expand Up @@ -202,6 +203,9 @@
RandRicianNoised,
RandRicianNoiseD,
RandRicianNoiseDict,
RandNonCentralChiNoised,
RandNonCentralChiNoiseD,
RandNonCentralChiNoiseDict,
RandScaleIntensityd,
RandScaleIntensityD,
RandScaleIntensityDict,
Expand Down
105 changes: 105 additions & 0 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

__all__ = [
"RandGaussianNoise",
"RandNonCentralChiNoise",
"RandRicianNoise",
"ShiftIntensity",
"RandShiftIntensity",
Expand Down Expand Up @@ -140,6 +141,110 @@ def __call__(self, img: NdarrayOrTensor, mean: float | None = None, randomize: b
return img + noise


class RandNonCentralChiNoise(RandomizableTransform):
"""
Add non-central chi noise to an image.
This distribution is the square root of the sum of squares of k independent
Gaussian random variables, where one of the variables has a non-zero mean
(the signal).
This is a generalization of Rician noise. `degrees_of_freedom=2` is Rician noise.
See: https://en.wikipedia.org/wiki/Noncentral_chi_distribution and https://archive.ismrm.org/2024/3123_NZkvJdQat.html

Args:
prob: Probability to add noise.
mean: Mean or "centre" of the Gaussian noise distributions.
std: Standard deviation (spread) of the Gaussian noise distributions.
degrees_of_freedom: Number of Gaussian distributions (degrees of freedom).
`degrees_of_freedom=2` is Rician noise.
channel_wise: If True, treats each channel of the image separately.
relative: If True, the spread of the sampled Gaussian distributions will
be std times the standard deviation of the image or channel's intensity
histogram.
sample_std: If True, sample the spread of the Gaussian distributions
uniformly from 0 to std.
dtype: output data type, if None, same as input image. defaults to float32.

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
prob: float = 0.1,
mean: Sequence[float] | float = 0.0,
std: Sequence[float] | float = 1.0,
degrees_of_freedom: int = 64, # 64 default because typical modern brain MRI is 32 quadrature coils
channel_wise: bool = False,
relative: bool = False,
sample_std: bool = True,
dtype: DtypeLike = np.float32,
) -> None:
RandomizableTransform.__init__(self, prob)
self.prob = prob
self.mean = mean
self.std = std
if not isinstance(degrees_of_freedom, int) or degrees_of_freedom < 1:
raise ValueError("degrees_of_freedom must be an integer >= 1.")
self.degrees_of_freedom = degrees_of_freedom
self.channel_wise = channel_wise
self.relative = relative
self.sample_std = sample_std
self.dtype = dtype

def _add_noise(self, img: NdarrayOrTensor, mean: float, std: float, k: int):
dtype_np = get_equivalent_dtype(img.dtype, np.ndarray)
im_shape = img.shape
_std = self.R.uniform(0, std) if self.sample_std else std

# Create a stack of k noise arrays
noise_shape = (k, *im_shape)
all_noises_np = self.R.normal(mean, _std, size=noise_shape).astype(dtype_np, copy=False)

if isinstance(img, torch.Tensor):
all_noises = torch.tensor(all_noises_np, device=img.device)
all_noises[0] = all_noises[0] + img
sum_sq = torch.sum(all_noises**2, dim=0)
return torch.sqrt(sum_sq)


all_noises_np[0] = all_noises_np[0] + img
sum_sq = np.sum(all_noises_np**2, axis=0)
return np.sqrt(sum_sq)

def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype)
if randomize:
super().randomize(None)

if not self._do_transform:
return img

if self.channel_wise:
_mean = ensure_tuple_rep(self.mean, len(img))
_std = ensure_tuple_rep(self.std, len(img))
for i, d in enumerate(img):
img[i] = self._add_noise(
d,
mean=_mean[i],
std=_std[i] * d.std() if self.relative else _std[i],
k=self.degrees_of_freedom,
)
else:
if not isinstance(self.mean, (int, float)):
raise RuntimeError(f"If channel_wise is False, mean must be a float or int, got {type(self.mean)}.")
if not isinstance(self.std, (int, float)):
raise RuntimeError(f"If channel_wise is False, std must be a float or int, got {type(self.std)}.")
std = self.std * img.std().item() if self.relative else self.std
if not isinstance(std, (int, float)):
raise RuntimeError(f"std must be a float or int number, got {type(std)}.")
img = self._add_noise(img, mean=self.mean, std=std, k=self.degrees_of_freedom)
return img



class RandRicianNoise(RandomizableTransform):
"""
Add Rician noise to image.
Expand Down
77 changes: 77 additions & 0 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
RandGibbsNoise,
RandHistogramShift,
RandKSpaceSpikeNoise,
RandNonCentralChiNoise,
RandRicianNoise,
RandScaleIntensity,
RandScaleIntensityFixedMean,
Expand All @@ -69,6 +70,7 @@
__all__ = [
"RandGaussianNoised",
"RandRicianNoised",
"RandNonCentralChiNoised",
"ShiftIntensityd",
"RandShiftIntensityd",
"ScaleIntensityd",
Expand Down Expand Up @@ -236,6 +238,80 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class RandNonCentralChiNoised(RandomizableTransform, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandNonCentralChiNoise`.
Add non-central chi noise to image. This transform assumes all the expected fields have same shape, if want to add
different noise for every field, please use this transform separately.
This is a generalization of Rician noise. `degrees_of_freedom=2` is Rician noise.

Args:
keys: Keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
prob: Probability to add non-central chi noise to the dictionary.
mean: Mean or "centre" of the Gaussian distributions sampled to make up
the noise.
std: Standard deviation (spread) of the Gaussian distributions sampled
to make up the noise.
degrees_of_freedom: Number of Gaussian distributions (degrees of freedom).
`degrees_of_freedom=2` is Rician noise.
channel_wise: If True, treats each channel of the image separately.
relative: If True, the spread of the sampled Gaussian distributions will
be std times the standard deviation of the image or channel's intensity
histogram.
sample_std: If True, sample the spread of the Gaussian distributions
uniformly from 0 to std.
dtype: output data type, if None, same as input image. defaults to float32.
allow_missing_keys: Don't raise exception if key is missing.
"""

backend = RandNonCentralChiNoise.backend

def __init__(
self,
keys: KeysCollection,
prob: float = 0.1,
mean: Sequence[float] | float = 0.0,
std: Sequence[float] | float = 1.0,
degrees_of_freedom: int = 64,
channel_wise: bool = False,
relative: bool = False,
sample_std: bool = True,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)
self.rand_non_central_chi_noise = RandNonCentralChiNoise(
prob=1.0,
mean=mean,
std=std,
degrees_of_freedom=degrees_of_freedom,
channel_wise=channel_wise,
relative=relative,
sample_std=sample_std,
dtype=dtype,
)

def set_random_state(
self, seed: int | None = None, state: np.random.RandomState | None = None
) -> RandNonCentralChiNoised:
super().set_random_state(seed, state)
self.rand_non_central_chi_noise.set_random_state(seed, state)
return self

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
self.randomize(None)
if not self._do_transform:
for key in self.key_iterator(d):
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d

for key in self.key_iterator(d):
d[key] = self.rand_non_central_chi_noise(d[key], randomize=True)
return d

class RandRicianNoised(RandomizableTransform, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandRicianNoise`.
Expand Down Expand Up @@ -1953,6 +2029,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N

RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised
RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised
RandNonCentralChiNoiseD = RandNonCentralChiNoiseDict = RandNonCentralChiNoised
ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd
RandShiftIntensityD = RandShiftIntensityDict = RandShiftIntensityd
StdShiftIntensityD = StdShiftIntensityDict = StdShiftIntensityd
Expand Down
83 changes: 83 additions & 0 deletions tests/transforms/test_rand_noncentralchi_noise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import RandNonCentralChiNoise
from tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D

TESTS = []
for p in TEST_NDARRAYS:
TESTS.append(("test_zero_mean", p, 0, 0.1))
TESTS.append(("test_non_zero_mean", p, 1, 0.5))


class TestRandNonCentralChiNoise(NumpyImageTestCase2D):
@parameterized.expand(TESTS)
def test_correct_results(self, _, in_type, mean, std):
seed = 0
degrees_of_freedom = 64 # 64 is common due to 32 channel head coil
noise_fn = RandNonCentralChiNoise(prob=1.0, mean=mean, std=std, degrees_of_freedom=degrees_of_freedom)
noise_fn.set_random_state(seed)
im = in_type(self.imt)
noised = noise_fn(im)
if isinstance(im, torch.Tensor):
self.assertEqual(im.dtype, noised.dtype)
np.random.seed(seed)
np.random.random()
_std = np.random.uniform(0, std)

noise_shape = (degrees_of_freedom, *self.imt.shape)
all_noises = np.random.normal(mean, _std, size=noise_shape).astype(np.float32)
all_noises[0] += self.imt
sum_sq = np.sum(all_noises**2, axis=0)
expected = np.sqrt(sum_sq)

if isinstance(noised, torch.Tensor):
noised = noised.cpu()
np.testing.assert_allclose(expected, noised, atol=1e-5)

@parameterized.expand(TESTS)
def test_correct_results_dof2(self, _, in_type, mean, std):
"""
Test with k=2 (the Rician case)
"""
seed = 0
degrees_of_freedom = 2
noise_fn = RandNonCentralChiNoise(prob=1.0, mean=mean, std=std, degrees_of_freedom=degrees_of_freedom)
noise_fn.set_random_state(seed)
im = in_type(self.imt)
noised = noise_fn(im)
if isinstance(im, torch.Tensor):
self.assertEqual(im.dtype, noised.dtype)

np.random.seed(seed)
np.random.random() # for prob
_std = np.random.uniform(0, std) # for sample_std
noise_shape = (degrees_of_freedom, *self.imt.shape)
all_noises = np.random.normal(mean, _std, size=noise_shape).astype(np.float32)
all_noises[0] += self.imt
sum_sq = np.sum(all_noises**2, axis=0)
expected = np.sqrt(sum_sq)

if isinstance(noised, torch.Tensor):
noised = noised.cpu()
np.testing.assert_allclose(expected, noised, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
unittest.main()
Loading
Loading