Skip to content

Commit cf95e94

Browse files
authored
Merge pull request #2058 from AdeelH/fix-xr
Fix handing of some edge cases when reading chips from `XarraySource`
2 parents 2359be1 + dd2562c commit cf95e94

File tree

2 files changed

+47
-8
lines changed

2 files changed

+47
-8
lines changed

rastervision_core/rastervision/core/data/raster_source/xarray_source.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ def __init__(self,
7979
self.full_extent = Box(0, 0, height, width)
8080
if bbox is None:
8181
bbox = self.full_extent
82+
else:
83+
if bbox not in self.full_extent:
84+
new_bbox = bbox.intersection(self.full_extent)
85+
log.warning(f'Clipping ({bbox}) to the DataArray\'s '
86+
f'full extent ({self.full_extent}). '
87+
f'New bbox={new_bbox}')
88+
bbox = new_bbox
8289

8390
super().__init__(
8491
channel_order,
@@ -133,20 +140,23 @@ def _get_chip(self,
133140
out_shape: Optional[Tuple[int, ...]] = None) -> np.ndarray:
134141
window = window.to_global_coords(self.bbox)
135142

136-
yslice, xsclice = window.to_slices()
143+
window_within_bbox = window.intersection(self.bbox)
144+
145+
yslice, xslice = window_within_bbox.to_slices()
137146
if self.temporal:
138147
chip = self.data_array.isel(
139-
x=xsclice, y=yslice, band=bands, time=time).to_numpy()
148+
x=xslice, y=yslice, band=bands, time=time).to_numpy()
140149
else:
141150
chip = self.data_array.isel(
142-
x=xsclice, y=yslice, band=bands).to_numpy()
151+
x=xslice, y=yslice, band=bands).to_numpy()
143152

144-
*batch_dims, h, w, c = chip.shape
145-
if window.size != (h, w):
146-
window_actual = window.intersection(self.full_extent)
147-
yslice, xsclice = window_actual.to_local_coords(window).to_slices()
153+
if window != window_within_bbox:
154+
*batch_dims, h, w, c = chip.shape
155+
# coords of window_within_bbox within window
156+
yslice, xslice = window_within_bbox.to_local_coords(
157+
window).to_slices()
148158
tmp = np.zeros((*batch_dims, *window.size, c))
149-
tmp[..., yslice, xsclice, :] = chip
159+
tmp[..., yslice, xslice, :] = chip
150160
chip = tmp
151161

152162
chip = fill_overflow(self.bbox, window, chip)

tests/core/data/raster_source/test_xarray_source.py

+29
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,35 @@ def test_get_raw_chip(self):
101101
chip_expected = np.array([[[0, 1, 2, 3]]], dtype=arr.dtype)
102102
np.testing.assert_array_equal(chip, chip_expected)
103103

104+
def test_get_raw_chip_overflowing_window(self):
105+
arr = np.arange(100).reshape(10, 10, 1)
106+
da = DataArray(arr, dims=['y', 'x', 'band'])
107+
rs = XarraySource(da, IdentityCRSTransformer(), bbox=Box(2, 2, 7, 7))
108+
109+
chip = rs.get_raw_chip(Box(3, 3, 7, 7))
110+
chip_expected = np.zeros((4, 4, 1))
111+
chip_expected[:2, :2] = arr[5:7, 5:7]
112+
np.testing.assert_array_equal(chip, chip_expected)
113+
114+
chip = rs.get_raw_chip(Box(-2, -2, 2, 2))
115+
chip_expected = np.zeros((4, 4, 1))
116+
chip_expected[2:, 2:] = arr[2:4, 2:4]
117+
np.testing.assert_array_equal(chip, chip_expected)
118+
119+
chip = rs.get_raw_chip(Box(-5, -5, 0, 0))
120+
chip_expected = np.zeros((5, 5, 1))
121+
np.testing.assert_array_equal(chip, chip_expected)
122+
123+
chip = rs.get_raw_chip(Box(6, 6, 9, 9))
124+
chip_expected = np.zeros((3, 3, 1))
125+
np.testing.assert_array_equal(chip, chip_expected)
126+
127+
def test_get_bbox_overflows_full_extent(self):
128+
arr = np.empty((5, 5, 1))
129+
da = DataArray(arr, dims=['y', 'x', 'band'])
130+
rs = XarraySource(da, IdentityCRSTransformer(), bbox=Box(2, 2, 5, 7))
131+
self.assertEqual(rs.bbox, Box(2, 2, 5, 5))
132+
104133
def test_get_chip(self):
105134
arr = np.ones((5, 5, 4), dtype=np.uint8)
106135
arr *= np.arange(4, dtype=np.uint8)

0 commit comments

Comments
 (0)