Skip to content

Commit 219ed0d

Browse files
authored
Fixed inconsistency for shapes of non-square images (#1202)
Signed-off-by: Charlelie Laurent <[email protected]>
1 parent afe3966 commit 219ed0d

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

examples/weather/corrdiff/conf/base/dataset/cwb.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19]
2323
# Indices of output channels
2424
out_channels: [0, 1, 2, 3]
2525
# Shape of the image
26-
img_shape_x: 448
27-
img_shape_y: 448
26+
img_shape_x: 448 # domain width
27+
img_shape_y: 448 # domain height
2828
# Add grid coordinates to the image
2929
add_grid: true
3030
# Factor to downscale the image

examples/weather/corrdiff/datasets/cwb.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,9 +399,9 @@ def __getitem__(self, idx):
399399
# crop and downsamples
400400
# rolling
401401
if self.train and self.roll:
402-
y_roll = random.randint(0, self.img_shape_y)
402+
x_roll = random.randint(0, self.img_shape_x)
403403
else:
404-
y_roll = 0
404+
x_roll = 0
405405

406406
# channels
407407
input = input[self.in_channels, :, :]
@@ -411,7 +411,7 @@ def __getitem__(self, idx):
411411
target = self._create_lowres_(target, factor=self.ds_factor)
412412

413413
reshape_args = (
414-
y_roll,
414+
x_roll,
415415
self.train,
416416
self.n_history,
417417
self.in_channels,
@@ -468,7 +468,7 @@ def time(self):
468468

469469
def image_shape(self):
470470
"""Get the shape of the image (same for input and output)."""
471-
return (self.img_shape_x, self.img_shape_y)
471+
return (self.img_shape_y, self.img_shape_x)
472472

473473
def normalize_input(self, x):
474474
"""Convert input from physical units to normalized data."""

examples/weather/corrdiff/datasets/img_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
def reshape_fields(
2323
img,
2424
inp_or_tar,
25-
y_roll,
25+
x_roll,
2626
train,
2727
n_history,
2828
in_channels,
@@ -39,7 +39,7 @@ def reshape_fields(
3939
):
4040
"""
4141
Takes in np array of size (n_history+1, c, h, w) and returns torch tensor of
42-
size ((n_channels*(n_history+1), img_shape_x, img_shape_y)
42+
size ((n_channels*(n_history+1), img_shape_y, img_shape_x)
4343
"""
4444

4545
if len(np.shape(img)) == 3:
@@ -59,7 +59,7 @@ def reshape_fields(
5959
means = np.load(global_means_path)[:, channels]
6060
stds = np.load(global_stds_path)[:, channels]
6161

62-
img = img[:, :, :img_shape_x, :img_shape_y]
62+
img = img[:, :, :img_shape_y, :img_shape_x]
6363

6464
if normalize and train:
6565
if normalization == "minmax":
@@ -70,11 +70,11 @@ def reshape_fields(
7070
img /= stds
7171

7272
if roll:
73-
img = np.roll(img, y_roll, axis=-1)
73+
img = np.roll(img, x_roll, axis=-1)
7474

7575
if inp_or_tar == "inp":
76-
img = np.reshape(img, (n_channels * (n_history + 1), img_shape_x, img_shape_y))
76+
img = np.reshape(img, (n_channels * (n_history + 1), img_shape_y, img_shape_x))
7777
elif inp_or_tar == "tar":
78-
img = np.reshape(img, (n_channels, img_shape_x, img_shape_y))
78+
img = np.reshape(img, (n_channels, img_shape_y, img_shape_x))
7979

8080
return torch.as_tensor(img)

0 commit comments

Comments
 (0)