Skip to content

Commit 97ce002

Browse files
committed
fixing multifiles dataloader with resampling
1 parent 2ad7951 commit 97ce002

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

makani/utils/dataloaders/data_loader_dummy.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,6 @@ def __init__(self,
114114
torch.deg2rad(torch.tensor(self.lat_lon[1][self.crop_anchor[1] : self.crop_anchor[1] + self.crop_shape[1]])).to(torch.float32),
115115
)
116116

117-
# rescaled image shape
118-
self.img_shape_resampled = (math.ceil(self.img_shape[0] / self.subsampling_factor),
119-
math.ceil(self.img_shape[1] / self.subsampling_factor))
120-
121117
def _get_files_stats(self):
122118

123119
if self.img_shape is None:
@@ -198,6 +194,9 @@ def _get_files_stats(self):
198194
self.img_local_offset_x = self.read_anchor[0]
199195
self.img_local_offset_y = self.read_anchor[1]
200196

197+
# resampling stuff
198+
self.img_shape_resampled = (math.ceil(self.img_shape[0] / self.subsampling_factor),
199+
math.ceil(self.img_shape[1] / self.subsampling_factor))
201200
self.img_local_shape_x_resampled = self.return_shape[0]
202201
self.img_local_shape_y_resampled = self.return_shape[1]
203202
self.img_shape_x_resampled = self.img_shape_resampled[0]

makani/utils/dataloaders/data_loader_multifiles.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,6 @@ def __init__(self,
153153
self.lat_lon_local[0][::self.subsampling_factor],
154154
self.lat_lon_local[1][::self.subsampling_factor],
155155
)
156-
self.img_shape_resampled = (math.ceil(self.img_shape[0] / self.subsampling_factor),
157-
math.ceil(self.img_shape[1] / self.subsampling_factor))
158156

159157
# grid types
160158
self.grid_converter = GridConverter(
@@ -332,6 +330,9 @@ def _get_files_stats(self, enable_logging):
332330
self.img_local_offset_x = self.read_anchor[0]
333331
self.img_local_offset_y = self.read_anchor[1]
334332

333+
# resampling stuff
334+
self.img_shape_resampled = (math.ceil(self.img_shape[0] / self.subsampling_factor),
335+
math.ceil(self.img_shape[1] / self.subsampling_factor))
335336
self.img_local_shape_x_resampled = self.return_shape[0]
336337
self.img_local_shape_y_resampled = self.return_shape[1]
337338
self.img_shape_x_resampled = self.img_shape_resampled[0]

0 commit comments

Comments
 (0)