Skip to content

Commit fb3d9f1

Browse files
authored
Width and height swapped when creating test image in data/synthetic/create_test_image_2d (#5627)
Signed-off-by: OeslleLucena <[email protected]> Fixes #5373. ### Description Change the `synthetic.py` script to have the (height, width, depth) format. Edit all the tests that uses `create_test_image_3d` to match the new format. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: OeslleLucena <[email protected]>
1 parent e50fa88 commit fb3d9f1

11 files changed

+53
-53
lines changed

monai/data/synthetic.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020

2121
def create_test_image_2d(
22-
width: int,
2322
height: int,
23+
width: int,
2424
num_objs: int = 12,
2525
rad_max: int = 30,
2626
rad_min: int = 5,
@@ -37,8 +37,8 @@ def create_test_image_2d(
3737
an image without channel dimension, otherwise create an image with channel dimension as first dim or last dim.
3838
3939
Args:
40-
width: width of the image. The value should be larger than `2 * rad_max`.
4140
height: height of the image. The value should be larger than `2 * rad_max`.
41+
width: width of the image. The value should be larger than `2 * rad_max`.
4242
num_objs: number of circles to generate. Defaults to `12`.
4343
rad_max: maximum circle radius. Defaults to `30`.
4444
rad_min: minimum circle radius. Defaults to `5`.
@@ -50,25 +50,25 @@ def create_test_image_2d(
5050
random_state: the random generator to use. Defaults to `np.random`.
5151
5252
Returns:
53-
Randomised Numpy array with shape (`width`, `height`)
53+
Randomised Numpy array with shape (`height`, `width`)
5454
"""
5555

5656
if rad_max <= rad_min:
5757
raise ValueError("`rad_min` should be less than `rad_max`.")
5858
if rad_min < 1:
5959
raise ValueError("`rad_min` should be no less than 1.")
60-
min_size = min(width, height)
60+
min_size = min(height, width)
6161
if min_size <= 2 * rad_max:
6262
raise ValueError("the minimal size of the image should be larger than `2 * rad_max`.")
6363

64-
image = np.zeros((width, height))
64+
image = np.zeros((height, width))
6565
rs: np.random.RandomState = np.random.random.__self__ if random_state is None else random_state # type: ignore
6666

6767
for _ in range(num_objs):
68-
x = rs.randint(rad_max, width - rad_max)
69-
y = rs.randint(rad_max, height - rad_max)
68+
x = rs.randint(rad_max, height - rad_max)
69+
y = rs.randint(rad_max, width - rad_max)
7070
rad = rs.randint(rad_min, rad_max)
71-
spy, spx = np.ogrid[-x : width - x, -y : height - y]
71+
spy, spx = np.ogrid[-x : height - x, -y : width - y]
7272
circle = (spx * spx + spy * spy) <= rad * rad
7373

7474
if num_seg_classes > 1:
@@ -124,7 +124,7 @@ def create_test_image_3d(
124124
random_state: the random generator to use. Defaults to `np.random`.
125125
126126
Returns:
127-
Randomised Numpy array with shape (`width`, `height`, `depth`)
127+
Randomised Numpy array with shape (`height`, `width`, `depth`)
128128
129129
See also:
130130
:py:meth:`~create_test_image_2d`
@@ -134,19 +134,19 @@ def create_test_image_3d(
134134
raise ValueError("`rad_min` should be less than `rad_max`.")
135135
if rad_min < 1:
136136
raise ValueError("`rad_min` should be no less than 1.")
137-
min_size = min(width, height, depth)
137+
min_size = min(height, width, depth)
138138
if min_size <= 2 * rad_max:
139139
raise ValueError("the minimal size of the image should be larger than `2 * rad_max`.")
140140

141-
image = np.zeros((width, height, depth))
141+
image = np.zeros((height, width, depth))
142142
rs: np.random.RandomState = np.random.random.__self__ if random_state is None else random_state # type: ignore
143143

144144
for _ in range(num_objs):
145-
x = rs.randint(rad_max, width - rad_max)
146-
y = rs.randint(rad_max, height - rad_max)
145+
x = rs.randint(rad_max, height - rad_max)
146+
y = rs.randint(rad_max, width - rad_max)
147147
z = rs.randint(rad_max, depth - rad_max)
148148
rad = rs.randint(rad_min, rad_max)
149-
spy, spx, spz = np.ogrid[-x : width - x, -y : height - y, -z : depth - z]
149+
spy, spx, spz = np.ogrid[-x : height - x, -y : width - y, -z : depth - z]
150150
circle = (spx * spx + spy * spy + spz * spz) <= rad * rad
151151

152152
if num_seg_classes > 1:

tests/test_adn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
[{"norm": "INSTANCE", "norm_dim": 3, "dropout_dim": 1, "dropout": 0.8, "ordering": "AND"}],
4747
[
4848
{
49-
"norm": ("layer", {"normalized_shape": (64, 80)}),
49+
"norm": ("layer", {"normalized_shape": (48, 80)}),
5050
"norm_dim": 3,
5151
"dropout_dim": 1,
5252
"dropout": 0.8,
@@ -76,7 +76,7 @@ def test_adn_3d(self, args):
7676
adn = ADN(**args)
7777
print(adn)
7878
out = adn(self.imt)
79-
expected_shape = (1, self.input_channels, self.im_shape[1], self.im_shape[0], self.im_shape[2])
79+
expected_shape = (1, self.input_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])
8080
self.assertEqual(out.shape, expected_shape)
8181

8282

tests/test_convolutions.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,19 @@ class TestConvolution3D(TorchImageTestCase3D):
7070
def test_conv1(self):
7171
conv = Convolution(3, self.input_channels, self.output_channels, dropout=0.1, adn_ordering="DAN")
7272
out = conv(self.imt)
73-
expected_shape = (1, self.output_channels, self.im_shape[1], self.im_shape[0], self.im_shape[2])
73+
expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])
7474
self.assertEqual(out.shape, expected_shape)
7575

7676
def test_conv1_no_acti(self):
7777
conv = Convolution(3, self.input_channels, self.output_channels, act=None, adn_ordering="AND")
7878
out = conv(self.imt)
79-
expected_shape = (1, self.output_channels, self.im_shape[1], self.im_shape[0], self.im_shape[2])
79+
expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])
8080
self.assertEqual(out.shape, expected_shape)
8181

8282
def test_conv_only1(self):
8383
conv = Convolution(3, self.input_channels, self.output_channels, conv_only=True)
8484
out = conv(self.imt)
85-
expected_shape = (1, self.output_channels, self.im_shape[1], self.im_shape[0], self.im_shape[2])
85+
expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])
8686
self.assertEqual(out.shape, expected_shape)
8787

8888
def test_stride1(self):
@@ -92,34 +92,34 @@ def test_stride1(self):
9292
expected_shape = (
9393
1,
9494
self.output_channels,
95-
self.im_shape[1] // 2,
9695
self.im_shape[0] // 2,
96+
self.im_shape[1] // 2,
9797
self.im_shape[2] // 2,
9898
)
9999
self.assertEqual(out.shape, expected_shape)
100100

101101
def test_dilation1(self):
102102
conv = Convolution(3, self.input_channels, self.output_channels, dilation=3)
103103
out = conv(self.imt)
104-
expected_shape = (1, self.output_channels, self.im_shape[1], self.im_shape[0], self.im_shape[2])
104+
expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])
105105
self.assertEqual(out.shape, expected_shape)
106106

107107
def test_dropout1(self):
108108
conv = Convolution(3, self.input_channels, self.output_channels, dropout=0.15)
109109
out = conv(self.imt)
110-
expected_shape = (1, self.output_channels, self.im_shape[1], self.im_shape[0], self.im_shape[2])
110+
expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])
111111
self.assertEqual(out.shape, expected_shape)
112112

113113
def test_transpose1(self):
114114
conv = Convolution(3, self.input_channels, self.output_channels, is_transposed=True)
115115
out = conv(self.imt)
116-
expected_shape = (1, self.output_channels, self.im_shape[1], self.im_shape[0], self.im_shape[2])
116+
expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])
117117
self.assertEqual(out.shape, expected_shape)
118118

119119
def test_transpose2(self):
120120
conv = Convolution(3, self.input_channels, self.output_channels, strides=2, is_transposed=True)
121121
out = conv(self.imt)
122-
expected_shape = (1, self.output_channels, self.im_shape[1] * 2, self.im_shape[0] * 2, self.im_shape[2] * 2)
122+
expected_shape = (1, self.output_channels, self.im_shape[0] * 2, self.im_shape[1] * 2, self.im_shape[2] * 2)
123123
self.assertEqual(out.shape, expected_shape)
124124

125125

tests/test_denseblock.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def test_block_conv(self):
4242
expected_shape = (
4343
1,
4444
self.output_channels + self.input_channels * 2,
45-
self.im_shape[1],
4645
self.im_shape[0],
46+
self.im_shape[1],
4747
self.im_shape[2],
4848
)
4949
self.assertEqual(out.shape, expected_shape)
@@ -87,15 +87,15 @@ def test_block1(self):
8787
channels = [2, 4]
8888
conv = ConvDenseBlock(spatial_dims=3, in_channels=self.input_channels, channels=channels)
8989
out = conv(self.imt)
90-
expected_shape = (1, self.input_channels + sum(channels), self.im_shape[1], self.im_shape[0], self.im_shape[2])
90+
expected_shape = (1, self.input_channels + sum(channels), self.im_shape[0], self.im_shape[1], self.im_shape[2])
9191
self.assertEqual(out.shape, expected_shape)
9292

9393
def test_block2(self):
9494
channels = [2, 4]
9595
dilations = [1, 2]
9696
conv = ConvDenseBlock(spatial_dims=3, in_channels=self.input_channels, channels=channels, dilations=dilations)
9797
out = conv(self.imt)
98-
expected_shape = (1, self.input_channels + sum(channels), self.im_shape[1], self.im_shape[0], self.im_shape[2])
98+
expected_shape = (1, self.input_channels + sum(channels), self.im_shape[0], self.im_shape[1], self.im_shape[2])
9999
self.assertEqual(out.shape, expected_shape)
100100

101101

tests/test_integration_sliding_window.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class TestIntegrationSlidingWindow(DistTestCase):
6767
def setUp(self):
6868
set_determinism(seed=0)
6969

70-
im, seg = create_test_image_3d(25, 28, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1)
70+
im, seg = create_test_image_3d(28, 25, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1)
7171
self.img_name = make_nifti_image(im)
7272
self.seg_name = make_nifti_image(seg)
7373
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")

tests/test_invert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_invert(self):
7474
self.assertTupleEqual(orig.shape[1:], (100, 100, 100))
7575
# check the nearest interpolation mode
7676
assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
77-
self.assertTupleEqual(i.shape[1:], (100, 101, 107))
77+
self.assertTupleEqual(i.shape[1:], (101, 100, 107))
7878
# check labels match
7979
reverted = i.detach().cpu().numpy().astype(np.int32)
8080
original = LoadImage(image_only=True)(data[-1])

tests/test_invertd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,21 +104,21 @@ def test_invert(self):
104104
# check the nearest interpolation mode
105105
i = item["image_inverted"]
106106
assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
107-
self.assertTupleEqual(i.shape[1:], (100, 101, 107))
107+
self.assertTupleEqual(i.shape[1:], (101, 100, 107))
108108
i = item["label_inverted"]
109109
assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
110-
self.assertTupleEqual(i.shape[1:], (100, 101, 107))
110+
self.assertTupleEqual(i.shape[1:], (101, 100, 107))
111111

112112
# check the case that different items use different interpolation mode to invert transforms
113113
d = item["image_inverted1"]
114114
# if the interpolation mode is nearest, accumulated diff should be smaller than 1
115115
self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
116-
self.assertTupleEqual(d.shape, (1, 100, 101, 107))
116+
self.assertTupleEqual(d.shape, (1, 101, 100, 107))
117117

118118
d = item["label_inverted1"]
119119
# if the interpolation mode is not nearest, accumulated diff should be greater than 10000
120120
self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
121-
self.assertTupleEqual(d.shape, (1, 100, 101, 107))
121+
self.assertTupleEqual(d.shape, (1, 101, 100, 107))
122122

123123
# check labels match
124124
reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)

tests/test_rand_rotate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
TEST_CASES_3D: List[Tuple] = []
3838
for p in TEST_NDARRAYS_ALL:
3939
TEST_CASES_3D.append(
40-
(p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109))
40+
(p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 81, 110, 112))
4141
)
4242
TEST_CASES_3D.append(
4343
(
@@ -49,7 +49,7 @@
4949
"nearest",
5050
"border",
5151
True,
52-
(1, 89, 105, 104),
52+
(1, 97, 100, 97),
5353
)
5454
)
5555
TEST_CASES_3D.append(
@@ -62,10 +62,10 @@
6262
"nearest",
6363
"zeros",
6464
True,
65-
(1, 48, 64, 80),
65+
(1, 64, 48, 80),
6666
)
6767
)
68-
TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)))
68+
TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 64, 61, 87)))
6969

7070

7171
class TestRandRotate2D(NumpyImageTestCase2D):

tests/test_rand_rotated.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
TEST_CASES_3D: List[Tuple] = []
3232
for p in TEST_NDARRAYS_ALL:
3333
TEST_CASES_3D.append(
34-
(p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109))
34+
(p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 81, 110, 112))
3535
)
3636
TEST_CASES_3D.append(
3737
(
@@ -43,7 +43,7 @@
4343
GridSampleMode.NEAREST,
4444
GridSamplePadMode.BORDER,
4545
False,
46-
(1, 87, 104, 109),
46+
(1, 81, 110, 112),
4747
)
4848
)
4949
TEST_CASES_3D.append(
@@ -56,7 +56,7 @@
5656
"nearest",
5757
"border",
5858
True,
59-
(1, 89, 105, 104),
59+
(1, 97, 100, 97),
6060
)
6161
)
6262
TEST_CASES_3D.append(
@@ -69,7 +69,7 @@
6969
GridSampleMode.NEAREST,
7070
GridSamplePadMode.BORDER,
7171
True,
72-
(1, 89, 105, 104),
72+
(1, 97, 100, 97),
7373
)
7474
)
7575
TEST_CASES_3D.append(
@@ -82,7 +82,7 @@
8282
"nearest",
8383
"zeros",
8484
True,
85-
(1, 48, 64, 80),
85+
(1, 64, 48, 80),
8686
)
8787
)
8888
TEST_CASES_3D.append(
@@ -95,12 +95,12 @@
9595
GridSampleMode.NEAREST,
9696
GridSamplePadMode.ZEROS,
9797
True,
98-
(1, 48, 64, 80),
98+
(1, 64, 48, 80),
9999
)
100100
)
101-
TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)))
101+
TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 64, 61, 87)))
102102
TEST_CASES_3D.append(
103-
(p, (-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90))
103+
(p, (-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 64, 61, 87))
104104
)
105105

106106

tests/test_rand_weighted_crop.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def get_data(ndim):
112112
dict(spatial_size=(10, -1, -1), num_samples=3),
113113
p(im),
114114
q(weight),
115-
(1, 10, 64, 80),
116-
[[14, 32, 40], [41, 32, 40], [20, 32, 40]],
115+
(1, 10, 48, 80),
116+
[[14, 24, 40], [41, 24, 40], [20, 24, 40]],
117117
]
118118
)
119119
im = SEGN_3D
@@ -126,8 +126,8 @@ def get_data(ndim):
126126
dict(spatial_size=(10000, 400, 80), num_samples=3),
127127
p(im),
128128
q(weight),
129-
(1, 48, 64, 80),
130-
[[24, 32, 40], [24, 32, 40], [24, 32, 40]],
129+
(1, 64, 48, 80),
130+
[[32, 24, 40], [32, 24, 40], [32, 24, 40]],
131131
]
132132
)
133133
im = IMT_3D
@@ -138,11 +138,11 @@ def get_data(ndim):
138138
TESTS.append(
139139
[
140140
"bad w 3d",
141-
dict(spatial_size=(48, 64, 80), num_samples=3),
141+
dict(spatial_size=(64, 48, 80), num_samples=3),
142142
p(im),
143143
q(weight),
144-
(1, 48, 64, 80),
145-
[[24, 32, 40], [24, 32, 40], [24, 32, 40]],
144+
(1, 64, 48, 80),
145+
[[32, 24, 40], [32, 24, 40], [32, 24, 40]],
146146
]
147147
)
148148

tests/test_synthetic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
[2, {"width": 64, "height": 64, "rad_max": 10, "rad_min": 4}, 0.1479004, 0.739502, (64, 64), 5],
2222
[
2323
2,
24-
{"width": 32, "height": 28, "num_objs": 3, "rad_max": 5, "rad_min": 1, "noise_max": 0.2},
24+
{"width": 28, "height": 32, "num_objs": 3, "rad_max": 5, "rad_min": 1, "noise_max": 0.2},
2525
0.1709315,
2626
0.4040179,
2727
(32, 28),

0 commit comments

Comments
 (0)