Skip to content

Commit 5e36e23

Browse files
authored
4741 Support MetaTensor data type in Concatitemd (#4745)
* add Metatenosr support in ConcatItemd Signed-off-by: KumoLiu <[email protected]>
1 parent 9c12cd8 commit 5e36e23

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

Diff for: monai/transforms/utility/dictionary.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
929929
class ConcatItemsd(MapTransform):
930930
"""
931931
Concatenate specified items from data dictionary together on the first dim to construct a big array.
932-
Expect all the items are numpy array or PyTorch Tensor.
932+
Expect all the items are numpy array or PyTorch Tensor or MetaTensor.
933+
Return the first input's meta information when items are MetaTensor.
933934
"""
934935

935936
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
@@ -951,7 +952,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
951952
"""
952953
Raises:
953954
TypeError: When items in ``data`` differ in type.
954-
TypeError: When the item type is not in ``Union[numpy.ndarray, torch.Tensor]``.
955+
TypeError: When the item type is not in ``Union[numpy.ndarray, torch.Tensor, MetaTensor]``.
955956
956957
"""
957958
d = dict(data)
@@ -969,10 +970,12 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
969970

970971
if data_type is np.ndarray:
971972
d[self.name] = np.concatenate(output, axis=self.dim)
972-
elif data_type is torch.Tensor:
973+
elif issubclass(data_type, torch.Tensor): # type: ignore
973974
d[self.name] = torch.cat(output, dim=self.dim) # type: ignore
974975
else:
975-
raise TypeError(f"Unsupported data type: {data_type}, available options are (numpy.ndarray, torch.Tensor).")
976+
raise TypeError(
977+
f"Unsupported data type: {data_type}, available options are (numpy.ndarray, torch.Tensor, MetaTensor)."
978+
)
976979
return d
977980

978981

Diff for: tests/test_concat_itemsd.py

+22
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import numpy as np
1515
import torch
1616

17+
from monai.data import MetaTensor
1718
from monai.transforms import ConcatItemsd
1819

1920

@@ -30,6 +31,20 @@ def test_tensor_values(self):
3031
torch.testing.assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device))
3132
torch.testing.assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))
3233

34+
def test_metatensor_values(self):
35+
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
36+
input_data = {
37+
"img1": MetaTensor([[0, 1], [1, 2]], device=device),
38+
"img2": MetaTensor([[0, 1], [1, 2]], device=device),
39+
}
40+
result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data)
41+
self.assertTrue("cat_img" in result)
42+
self.assertTrue(isinstance(result["cat_img"], MetaTensor))
43+
self.assertEqual(result["img1"].meta, result["cat_img"].meta)
44+
result["cat_img"] += 1
45+
torch.testing.assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device))
46+
torch.testing.assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))
47+
3348
def test_numpy_values(self):
3449
input_data = {"img1": np.array([[0, 1], [1, 2]]), "img2": np.array([[0, 1], [1, 2]])}
3550
result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data)
@@ -52,6 +67,13 @@ def test_single_tensor(self):
5267
torch.testing.assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]]))
5368
torch.testing.assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3]]))
5469

70+
def test_single_metatensor(self):
71+
input_data = {"img": MetaTensor([[0, 1], [1, 2]])}
72+
result = ConcatItemsd(keys="img", name="cat_img")(input_data)
73+
result["cat_img"] += 1
74+
torch.testing.assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]]))
75+
torch.testing.assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3]]))
76+
5577

5678
if __name__ == "__main__":
5779
unittest.main()

0 commit comments

Comments
 (0)