
Commit cc90d74

Remove usage of deprecated AsChannelFirst, AddChannel, and SplitChannel (#6281)
Fixes #6280

### Description
This PR removes the usage of `AsChannelFirst`, `AddChannel`, and `SplitChannel` (deprecated since v0.8) and sets their removed version to v1.3.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Behrooz <[email protected]>
1 parent 98b6a15 commit cc90d74

22 files changed (+234 -173 lines)
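The replacement pattern used throughout this PR is a one-line change at each call site: `AddChannel()` becomes `EnsureChannelFirst(channel_dim="no_channel")`, and `AsChannelFirst(channel_dim=...)` becomes `EnsureChannelFirst(channel_dim=...)`. The sketch below is not part of the commit; it only illustrates the expected shapes and assumes a MONAI version where `EnsureChannelFirst` accepts a plain tensor once `channel_dim` is given explicitly (as the updated tests in this diff do).

# Hedged migration sketch (not from this commit); tensor shapes are illustrative.
import torch
from monai.transforms import EnsureChannelFirst

no_channel = torch.rand(64, 64)       # (spatial_dim_1, spatial_dim_2), no channel dim
channel_last = torch.rand(64, 64, 3)  # (spatial_dim_1, spatial_dim_2, num_channels)

# Previously AddChannel(): prepend a length-1 channel dimension.
add_channel = EnsureChannelFirst(channel_dim="no_channel")
print(add_channel(no_channel).shape)  # expected: (1, 64, 64)

# Previously AsChannelFirst(channel_dim=-1): move the last (channel) dim to the front.
as_channel_first = EnsureChannelFirst(channel_dim=-1)
print(as_channel_first(channel_last).shape)  # expected: (3, 64, 64)

`SplitChannel`/`SplitChanneld` are only re-flagged here with `removed="1.3"`; their drop-in replacements remain `SplitDim`/`SplitDimd`.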

Diff for: monai/transforms/transform.py (+2 -4)

@@ -275,8 +275,7 @@ def __call__(self, data: Any):
     #. string data without shape, `LoadImage` transform expects file paths,
     #. most of the pre-/post-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,
-       except for example: `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and
-       `AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels),
+       except for example: `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...])

     - the channel dimension is often not omitted even if number of channels is one.

@@ -441,8 +440,7 @@ def __call__(self, data):
     #. string data without shape, `LoadImaged` transform expects file paths,
     #. most of the pre-/post-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,
-       except for example: `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and
-       `AsChannelFirstd` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels)
+       except for example: `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...])

     - the channel dimension is often not omitted even if number of channels is one.

Diff for: monai/transforms/utility/array.py (+49 -58)

@@ -142,38 +142,6 @@ def __call__(self, data: Any) -> Any:
         return data


-@deprecated(since="0.8", msg_suffix="please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.")
-class AsChannelFirst(Transform):
-    """
-    Change the channel dimension of the image to the first dimension.
-
-    Most of the image transformations in ``monai.transforms``
-    assume the input image is in the channel-first format, which has the shape
-    (num_channels, spatial_dim_1[, spatial_dim_2, ...]).
-
-    This transform could be used to convert, for example, a channel-last image array in shape
-    (spatial_dim_1[, spatial_dim_2, ...], num_channels) into the channel-first format,
-    so that the multidimensional image array can be correctly interpreted by the other transforms.
-
-    Args:
-        channel_dim: which dimension of input image is the channel, default is the last dimension.
-    """
-
-    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
-
-    def __init__(self, channel_dim: int = -1) -> None:
-        if not (isinstance(channel_dim, int) and channel_dim >= -1):
-            raise ValueError(f"invalid channel dimension ({channel_dim}).")
-        self.channel_dim = channel_dim
-
-    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
-        """
-        Apply the transform to `img`.
-        """
-        out: NdarrayOrTensor = convert_to_tensor(moveaxis(img, self.channel_dim, 0), track_meta=get_track_meta())
-        return out
-
-
 class AsChannelLast(Transform):
     """
     Change the channel dimension of the image to the last dimension.

@@ -204,31 +172,6 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         return out


-@deprecated(since="0.8", msg_suffix="please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.")
-class AddChannel(Transform):
-    """
-    Adds a 1-length channel dimension to the input image.
-
-    Most of the image transformations in ``monai.transforms``
-    assumes the input image is in the channel-first format, which has the shape
-    (num_channels, spatial_dim_1[, spatial_dim_2, ...]).
-
-    This transform could be used, for example, to convert a (spatial_dim_1[, spatial_dim_2, ...])
-    spatial image into the channel-first format so that the
-    multidimensional image array can be correctly interpreted by the other
-    transforms.
-    """
-
-    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
-
-    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
-        """
-        Apply the transform to `img`.
-        """
-        out: NdarrayOrTensor = convert_to_tensor(img[None], track_meta=get_track_meta())
-        return out
-
-
 class EnsureChannelFirst(Transform):
     """
     Adjust or add the channel dimension of input data to ensure `channel_first` shape.

@@ -291,6 +234,54 @@ def __call__(self, img: torch.Tensor, meta_dict: Mapping | None = None) -> torch
         return convert_to_tensor(result, track_meta=get_track_meta())  # type: ignore


+@deprecated(
+    since="0.8",
+    removed="1.3",
+    msg_suffix="please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.",
+)
+class AsChannelFirst(EnsureChannelFirst):
+    """
+    Change the channel dimension of the image to the first dimension.
+    Most of the image transformations in ``monai.transforms``
+    assume the input image is in the channel-first format, which has the shape
+    (num_channels, spatial_dim_1[, spatial_dim_2, ...]).
+    This transform could be used to convert, for example, a channel-last image array in shape
+    (spatial_dim_1[, spatial_dim_2, ...], num_channels) into the channel-first format,
+    so that the multidimensional image array can be correctly interpreted by the other transforms.
+    Args:
+        channel_dim: which dimension of input image is the channel, default is the last dimension.
+    """
+
+    def __init__(self, channel_dim: int = -1) -> None:
+        super().__init__(channel_dim=channel_dim)
+
+
+@deprecated(
+    since="0.8",
+    removed="1.3",
+    msg_suffix="please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead"
+    " with `channel_dim='no_channel'`.",
+)
+class AddChannel(EnsureChannelFirst):
+    """
+    Adds a 1-length channel dimension to the input image.
+
+    Most of the image transformations in ``monai.transforms``
+    assumes the input image is in the channel-first format, which has the shape
+    (num_channels, spatial_dim_1[, spatial_dim_2, ...]).
+
+    This transform could be used, for example, to convert a (spatial_dim_1[, spatial_dim_2, ...])
+    spatial image into the channel-first format so that the
+    multidimensional image array can be correctly interpreted by the other
+    transforms.
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __init__(self) -> None:
+        super().__init__(channel_dim="no_channel")
+
+
 class RepeatChannel(Transform):
     """
     Repeat channel data to construct expected input shape for models.

@@ -391,7 +382,7 @@ def __call__(self, img: torch.Tensor) -> list[torch.Tensor]:
         return outputs


-@deprecated(since="0.8", msg_suffix="please use `SplitDim` instead.")
+@deprecated(since="0.8", removed="1.3", msg_suffix="please use `SplitDim` instead.")
 class SplitChannel(SplitDim):
     """
     Split Numpy array or PyTorch Tensor data according to the channel dim.
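After this change the deprecated array classes are thin shims over `EnsureChannelFirst`, so old call sites keep working (with a deprecation warning) until the v1.3 removal. A hedged equivalence check, assuming a MONAI release where the shims still exist (deprecated since 0.8, removed in 1.3):

# Hedged sketch: both paths should yield the same channel-first shape on a
# MONAI release that still ships the AddChannel shim (i.e. before v1.3).
import warnings
import torch
from monai.transforms import AddChannel, EnsureChannelFirst

img = torch.rand(32, 32, 32)  # (spatial_dim_1, spatial_dim_2, spatial_dim_3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # suppress the deprecation message for this comparison
    old_style = AddChannel()(img)

new_style = EnsureChannelFirst(channel_dim="no_channel")(img)
assert tuple(old_style.shape) == tuple(new_style.shape) == (1, 32, 32, 32)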

Diff for: monai/transforms/utility/dictionary.py (+45 -52)

@@ -33,10 +33,8 @@
 from monai.transforms.traits import MultiSampleTrait, RandomizableTrait
 from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
 from monai.transforms.utility.array import (
-    AddChannel,
     AddCoordinateChannels,
     AddExtremePointsChannel,
-    AsChannelFirst,
     AsChannelLast,
     CastToType,
     ClassesToIndices,

@@ -221,31 +219,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
         return d


-class AsChannelFirstd(MapTransform):
-    """
-    Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelFirst`.
-    """
-
-    backend = AsChannelFirst.backend
-
-    def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None:
-        """
-        Args:
-            keys: keys of the corresponding items to be transformed.
-                See also: :py:class:`monai.transforms.compose.MapTransform`
-            channel_dim: which dimension of input image is the channel, default is the last dimension.
-            allow_missing_keys: don't raise exception if key is missing.
-        """
-        super().__init__(keys, allow_missing_keys)
-        self.converter = AsChannelFirst(channel_dim=channel_dim)
-
-    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
-        d = dict(data)
-        for key in self.key_iterator(d):
-            d[key] = self.converter(d[key])
-        return d
-
-
 class AsChannelLastd(MapTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelLast`.

@@ -271,30 +244,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
         return d


-class AddChanneld(MapTransform):
-    """
-    Dictionary-based wrapper of :py:class:`monai.transforms.AddChannel`.
-    """
-
-    backend = AddChannel.backend
-
-    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
-        """
-        Args:
-            keys: keys of the corresponding items to be transformed.
-                See also: :py:class:`monai.transforms.compose.MapTransform`
-            allow_missing_keys: don't raise exception if key is missing.
-        """
-        super().__init__(keys, allow_missing_keys)
-        self.adder = AddChannel()
-
-    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
-        d = dict(data)
-        for key in self.key_iterator(d):
-            d[key] = self.adder(d[key])
-        return d
-
-
 class EnsureChannelFirstd(MapTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`.

@@ -336,6 +285,50 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
         return d


+@deprecated(
+    since="0.8",
+    removed="1.3",
+    msg_suffix="please use MetaTensor data type and monai.transforms.EnsureChannelFirstd instead.",
+)
+class AsChannelFirstd(EnsureChannelFirstd):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelFirst`.
+    """
+
+    def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None:
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+                See also: :py:class:`monai.transforms.compose.MapTransform`
+            channel_dim: which dimension of input image is the channel, default is the last dimension.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys=keys, channel_dim=channel_dim, allow_missing_keys=allow_missing_keys)
+
+
+@deprecated(
+    since="0.8",
+    removed="1.3",
+    msg_suffix="please use MetaTensor data type and monai.transforms.EnsureChannelFirstd instead"
+    " with `channel_dim='no_channel'`.",
+)
+class AddChanneld(EnsureChannelFirstd):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.AddChannel`.
+    """
+
+    backend = EnsureChannelFirstd.backend
+
+    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+                See also: :py:class:`monai.transforms.compose.MapTransform`
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys, channel_dim="no_channel")
+
+
 class RepeatChanneld(MapTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`.

@@ -452,7 +445,7 @@ def __call__(
         return d


-@deprecated(since="0.8", msg_suffix="please use `SplitDimd` instead.")
+@deprecated(since="0.8", removed="1.3", msg_suffix="please use `SplitDimd` instead.")
 class SplitChanneld(SplitDimd):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`.
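For dictionary-based pipelines the same substitution applies: `EnsureChannelFirstd(keys=..., channel_dim="no_channel")` takes over from `AddChanneld`, as in the `test_cross_validation.py` change below. A minimal sketch, not part of the commit; keys and shapes are made up for illustration:

# Hedged sketch of the dictionary-transform migration; keys/shapes are illustrative.
import torch
from monai.transforms import Compose, EnsureChannelFirstd, ScaleIntensityd

sample = {"image": torch.rand(96, 96, 64), "label": torch.rand(96, 96, 64)}

pipeline = Compose(
    [
        EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
        ScaleIntensityd(keys="image"),
    ]
)

out = pipeline(sample)
print(out["image"].shape, out["label"].shape)  # expected: (1, 96, 96, 64) for both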

Diff for: tests/test_arraydataset.py (+25 -8)

@@ -22,18 +22,18 @@
 from torch.utils.data import DataLoader

 from monai.data import ArrayDataset
-from monai.transforms import AddChannel, Compose, LoadImage, RandAdjustContrast, RandGaussianNoise, Spacing
+from monai.transforms import Compose, EnsureChannelFirst, LoadImage, RandAdjustContrast, RandGaussianNoise, Spacing

 TEST_CASE_1 = [
-    Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]),
-    Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]),
+    Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim="no_channel"), RandGaussianNoise(prob=1.0)]),
+    Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim="no_channel"), RandGaussianNoise(prob=1.0)]),
     (0, 1),
     (1, 128, 128, 128),
 ]

 TEST_CASE_2 = [
-    Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast(prob=1.0)]),
-    Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast(prob=1.0)]),
+    Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim="no_channel"), RandAdjustContrast(prob=1.0)]),
+    Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim="no_channel"), RandAdjustContrast(prob=1.0)]),
     (0, 1),
     (1, 128, 128, 128),
 ]

@@ -50,13 +50,30 @@ def __call__(self, input_):


 TEST_CASE_3 = [
-    TestCompose([LoadImage(image_only=True), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]),
-    TestCompose([LoadImage(image_only=True), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]),
+    TestCompose(
+        [
+            LoadImage(image_only=True),
+            EnsureChannelFirst(channel_dim="no_channel"),
+            Spacing(pixdim=(2, 2, 4)),
+            RandAdjustContrast(prob=1.0),
+        ]
+    ),
+    TestCompose(
+        [
+            LoadImage(image_only=True),
+            EnsureChannelFirst(channel_dim="no_channel"),
+            Spacing(pixdim=(2, 2, 4)),
+            RandAdjustContrast(prob=1.0),
+        ]
+    ),
     (0, 2),
     (1, 64, 64, 33),
 ]

-TEST_CASE_4 = [Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]), (1, 128, 128, 128)]
+TEST_CASE_4 = [
+    Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim="no_channel"), RandGaussianNoise(prob=1.0)]),
+    (1, 128, 128, 128),
+]


 class TestArrayDataset(unittest.TestCase):

Diff for: tests/test_compose.py (+3 -3)

@@ -194,8 +194,8 @@ def __call__(self, data):
         c.randomize()

     def test_err_msg(self):
-        transforms = mt.Compose([abs, mt.AddChannel(), round])
-        with self.assertRaisesRegex(Exception, "AddChannel"):
+        transforms = mt.Compose([abs, mt.EnsureChannelFirst(), round])
+        with self.assertRaisesRegex(Exception, "EnsureChannelFirst"):
             transforms(42.1)

     def test_data_loader(self):

@@ -244,7 +244,7 @@ def test_data_loader_2(self):
         set_determinism(None)

     def test_flatten_and_len(self):
-        x = mt.AddChannel()
+        x = mt.EnsureChannelFirst(channel_dim="no_channel")
         t1 = mt.Compose([x, x, x, x, mt.Compose([mt.Compose([x, x]), x, x])])

         t2 = t1.flatten()

Diff for: tests/test_cross_validation.py (+6 -2)

@@ -16,7 +16,7 @@

 from monai.apps import CrossValidation, DecathlonDataset
 from monai.data import MetaTensor
-from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd
+from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd
 from tests.utils import skip_if_downloading_fails, skip_if_quick


@@ -25,7 +25,11 @@ class TestCrossValidation(unittest.TestCase):
     def test_values(self):
         testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
         train_transform = Compose(
-            [LoadImaged(keys=["image", "label"]), AddChanneld(keys=["image", "label"]), ScaleIntensityd(keys="image")]
+            [
+                LoadImaged(keys=["image", "label"]),
+                EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
+                ScaleIntensityd(keys="image"),
+            ]
         )
         val_transform = LoadImaged(keys=["image", "label"])
