14
14
import numpy as np
15
15
import torch
16
16
17
+ from monai .data import MetaTensor
17
18
from monai .transforms import ConcatItemsd
18
19
19
20
@@ -30,6 +31,20 @@ def test_tensor_values(self):
30
31
torch .testing .assert_allclose (result ["img1" ], torch .tensor ([[0 , 1 ], [1 , 2 ]], device = device ))
31
32
torch .testing .assert_allclose (result ["cat_img" ], torch .tensor ([[1 , 2 ], [2 , 3 ], [1 , 2 ], [2 , 3 ]], device = device ))
32
33
34
+ def test_metatensor_values (self ):
35
+ device = torch .device ("cuda:0" ) if torch .cuda .is_available () else torch .device ("cpu:0" )
36
+ input_data = {
37
+ "img1" : MetaTensor ([[0 , 1 ], [1 , 2 ]], device = device ),
38
+ "img2" : MetaTensor ([[0 , 1 ], [1 , 2 ]], device = device ),
39
+ }
40
+ result = ConcatItemsd (keys = ["img1" , "img2" ], name = "cat_img" )(input_data )
41
+ self .assertTrue ("cat_img" in result )
42
+ self .assertTrue (isinstance (result ["cat_img" ], MetaTensor ))
43
+ self .assertEqual (result ["img1" ].meta , result ["cat_img" ].meta )
44
+ result ["cat_img" ] += 1
45
+ torch .testing .assert_allclose (result ["img1" ], torch .tensor ([[0 , 1 ], [1 , 2 ]], device = device ))
46
+ torch .testing .assert_allclose (result ["cat_img" ], torch .tensor ([[1 , 2 ], [2 , 3 ], [1 , 2 ], [2 , 3 ]], device = device ))
47
+
33
48
def test_numpy_values (self ):
34
49
input_data = {"img1" : np .array ([[0 , 1 ], [1 , 2 ]]), "img2" : np .array ([[0 , 1 ], [1 , 2 ]])}
35
50
result = ConcatItemsd (keys = ["img1" , "img2" ], name = "cat_img" )(input_data )
@@ -52,6 +67,13 @@ def test_single_tensor(self):
52
67
torch .testing .assert_allclose (result ["img" ], torch .tensor ([[0 , 1 ], [1 , 2 ]]))
53
68
torch .testing .assert_allclose (result ["cat_img" ], torch .tensor ([[1 , 2 ], [2 , 3 ]]))
54
69
70
+ def test_single_metatensor (self ):
71
+ input_data = {"img" : MetaTensor ([[0 , 1 ], [1 , 2 ]])}
72
+ result = ConcatItemsd (keys = "img" , name = "cat_img" )(input_data )
73
+ result ["cat_img" ] += 1
74
+ torch .testing .assert_allclose (result ["img" ], torch .tensor ([[0 , 1 ], [1 , 2 ]]))
75
+ torch .testing .assert_allclose (result ["cat_img" ], torch .tensor ([[1 , 2 ], [2 , 3 ]]))
76
+
55
77
56
78
if __name__ == "__main__" :
57
79
unittest .main ()
0 commit comments