Skip to content

Commit 474e3f1

Browse files
committed
fix mask
1 parent 43adc77 commit 474e3f1

3 files changed

Lines changed: 26 additions & 26 deletions

File tree

tests/test_io_rasterio.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def test_point_valid():
259259
pt = src.point(lon, lat, expression="b1*2;b1-100")
260260
assert len(pt.data) == 2
261261
assert len(pt.mask) == 1
262-
assert pt.mask[0] == 255
262+
assert pt._mask[0]
263263
assert pt.band_names == ["b1*2", "b1-100"]
264264

265265
with pytest.warns(ExpressionMixingWarning):
@@ -284,7 +284,7 @@ def test_point_valid():
284284

285285
pt = src.point(-59.53, 74.03, indexes=(1, 1, 1))
286286
assert len(pt.data) == 3
287-
assert pt.mask[0] == 0
287+
assert not pt._mask[0]
288288
assert pt.band_names == ["b1", "b1", "b1"]
289289

290290

@@ -554,7 +554,7 @@ def test_imageData_output():
554554
assert img.count == 1
555555
assert img.data_as_image().shape == (256, 256, 1)
556556

557-
assert numpy.array_equal(~img.array.mask[0] * 255, img.mask)
557+
assert numpy.array_equal(~img.array.mask[0], img._mask)
558558

559559
assert img.crs == WEB_MERCATOR_TMS.crs
560560
assert img.bounds == WEB_MERCATOR_TMS.xy_bounds(43, 24, 7)

tests/test_models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_point_data():
235235
assert pt.count == 3
236236
assert pt.data.shape == (3,)
237237
assert pt.mask.shape == (1,)
238-
assert pt.mask.tolist() == [255]
238+
assert pt._mask.tolist() == [True]
239239
assert pt.band_names == ["b1", "b2", "b3"]
240240

241241
with pytest.raises(ValueError):
@@ -257,7 +257,7 @@ def test_point_data():
257257
pts = PointData.create_from_list([pt1, pt2])
258258
assert pts.data.tolist() == [1, 2, 3]
259259
assert pts.band_names == ["b1", "b2", "b1+b2"]
260-
assert pts.mask.tolist() == [255]
260+
assert pts._mask.tolist() == [True]
261261

262262
pts = PointData.create_from_list(
263263
[
@@ -266,7 +266,7 @@ def test_point_data():
266266
]
267267
)
268268
assert pts.array.mask.tolist() == [False, True]
269-
assert pts.mask.tolist() == [0]
269+
assert pts._mask.tolist() == [False]
270270

271271
pts = PointData.create_from_list(
272272
[
@@ -275,7 +275,7 @@ def test_point_data():
275275
]
276276
)
277277
assert pts.array.mask.tolist() == [False, False]
278-
assert pts.mask.tolist() == [255]
278+
assert pts._mask.tolist() == [True]
279279

280280
pts = PointData.create_from_list(
281281
[
@@ -284,7 +284,7 @@ def test_point_data():
284284
]
285285
)
286286
assert pts.array.mask.tolist() == [True, True]
287-
assert pts.mask.tolist() == [0]
287+
assert pts._mask.tolist() == [False]
288288

289289
with pytest.raises(InvalidPointDataError):
290290
PointData.create_from_list([])

tests/test_reader.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -425,19 +425,19 @@ def test_read_nodata():
425425
with rasterio.open(COG) as src_dst:
426426
arr, mask = reader.part(src_dst, bounds, nodata=1)
427427

428-
masknodata = (arr[0] != 1).astype(numpy.uint8) * 255
428+
masknodata = (arr[0] != 1).astype(numpy.uint16) * 65535
429429
numpy.testing.assert_array_equal(mask, masknodata)
430430

431431
with rasterio.open(COG) as src_dst:
432432
arr, mask = reader.read(src_dst, nodata=1)
433433

434-
masknodata = (arr[0] != 1).astype(numpy.uint8) * 255
434+
masknodata = (arr[0] != 1).astype(numpy.uint16) * 65535
435435
numpy.testing.assert_array_equal(mask, masknodata)
436436

437437
with rasterio.open(COG) as src_dst:
438438
arr, mask = reader.read(src_dst, dst_crs="epsg:3857", nodata=1)
439439

440-
masknodata = (arr[0] != 1).astype(numpy.uint8) * 255
440+
masknodata = (arr[0] != 1).astype(numpy.uint16) * 65535
441441
numpy.testing.assert_array_equal(mask, masknodata)
442442

443443

@@ -529,7 +529,7 @@ def test_point():
529529
nodata=1,
530530
)
531531
assert pt.data == numpy.array([2800])
532-
assert pt.mask == numpy.array([255])
532+
assert pt.mask == numpy.array([[65535]])
533533
assert pt.band_names == ["b1"]
534534

535535
# resampling_method is useless with interpolate=False
@@ -553,7 +553,7 @@ def test_point():
553553
interpolate=True,
554554
)
555555
assert pt.data == numpy.array([2800])
556-
assert pt.mask == numpy.array([255])
556+
assert pt.mask == numpy.array([[65535]])
557557
assert pt.band_names == ["b1"]
558558
assert pt.pixel_location
559559
assert isinstance(pt.pixel_location[0], float)
@@ -569,7 +569,7 @@ def test_point():
569569
interpolate=True,
570570
)
571571
assert pt.data == numpy.array([2819])
572-
assert pt.mask == numpy.array([255])
572+
assert pt.mask == numpy.array([[65535]])
573573
assert pt.band_names == ["b1"]
574574

575575
# Interpolate=True + resampling=average
@@ -583,7 +583,7 @@ def test_point():
583583
interpolate=True,
584584
)
585585
assert pt.data == numpy.array([2904])
586-
assert pt.mask == numpy.array([255])
586+
assert pt.mask == numpy.array([[65535]])
587587
assert pt.band_names == ["b1"]
588588

589589
# Interpolate=True + resampling=Cubic
@@ -597,13 +597,13 @@ def test_point():
597597
interpolate=True,
598598
)
599599
assert pt.data == numpy.array([2812])
600-
assert pt.mask == numpy.array([255])
600+
assert pt.mask == numpy.array([[65535]])
601601
assert pt.band_names == ["b1"]
602602

603603
with rasterio.open(COG_SCALE) as src_dst:
604604
pt = reader.point(src_dst, [310000, 4100000], coord_crs=src_dst.crs, indexes=1)
605605
assert pt.data == numpy.array([8917])
606-
assert pt.mask == numpy.array([255])
606+
assert pt.mask == numpy.array([[32767]])
607607
assert pt.band_names == ["b1"]
608608

609609
pt = reader.point(src_dst, [310000, 4100000], coord_crs=src_dst.crs)
@@ -847,23 +847,23 @@ def test_part_no_VRT():
847847
img = reader.part(src_dst, bounds, bounds_crs="epsg:4326")
848848
assert img.height == 1453
849849
assert img.width == 1613
850-
assert img.mask[0, 0] == 255
850+
assert img.mask[0, 0] == 65535
851851
assert img.mask[-1, -1] == 0 # boundless
852852
assert img.bounds == bounds_dst_crs
853853

854854
# Use bbox in Image CRS
855855
img_crs = reader.part(src_dst, bounds_dst_crs)
856856
assert img.height == 1453
857857
assert img.width == 1613
858-
assert img_crs.mask[0, 0] == 255
858+
assert img_crs.mask[0, 0] == 65535
859859
assert img_crs.mask[-1, -1] == 0 # boundless
860860
assert img.bounds == bounds_dst_crs
861861

862862
# MaxSize
863863
img = reader.part(src_dst, bounds, bounds_crs="epsg:4326", max_size=1024)
864864
assert img.height < 1024
865865
assert img.width == 1024
866-
assert img.mask[0, 0] == 255
866+
assert img.mask[0, 0] == 65535
867867
assert img.mask[-1, -1] == 0 # boundless
868868
assert img.bounds == bounds_dst_crs
869869

@@ -877,15 +877,15 @@ def test_part_no_VRT():
877877
)
878878
assert img.height == 100
879879
assert img.width == 100
880-
assert img.mask[0, 0] == 255
880+
assert img.mask[0, 0] == 65535
881881
assert img.mask[-1, -1] == 0 # boundless
882882
assert img.bounds == bounds_dst_crs
883883

884884
# Buffer
885885
img = reader.part(src_dst, bounds, bounds_crs="epsg:4326", buffer=1)
886886
assert img.height == 1455
887887
assert img.width == 1615
888-
assert img.mask[0, 0] == 255
888+
assert img.mask[0, 0] == 65535
889889
assert img.mask[-1, -1] == 0 # boundless
890890
assert not img.bounds == bounds_dst_crs
891891

@@ -894,7 +894,7 @@ def test_part_no_VRT():
894894
img_pad = reader.part(src_dst, bounds, bounds_crs="epsg:4326", padding=1)
895895
assert img_pad.height == 1453
896896
assert img_pad.width == 1613
897-
assert img_pad.mask[0, 0] == 255
897+
assert img_pad.mask[0, 0] == 65535
898898
assert img_pad.mask[-1, -1] == 0 # boundless
899899
assert img_pad.bounds == bounds_dst_crs
900900
# Padding should not have any influence when not doing any rescaling/reprojection
@@ -1082,16 +1082,16 @@ def test_nodata_orverride():
10821082
assert prev.mask[0, 0] == 0
10831083

10841084
prev = reader.read(src_dst, max_size=100, nodata=2720)
1085-
assert prev.mask[0, 0] == 255
1085+
assert prev.mask[0, 0] == 65535
10861086
assert not numpy.all(prev.mask)
10871087

10881088

10891089
def test_tile_read_nodata_float():
10901090
"""Should work as expected when using NaN as nodata value."""
10911091
with rasterio.open(COG_NODATA_FLOAT_NAN) as src_dst:
10921092
prev = reader.read(src_dst, max_size=100)
1093-
assert prev.mask[0, 0] == 0
1094-
assert not numpy.all(prev.mask)
1093+
assert prev.mask[0, 0] == -3.4028235e38
1094+
assert not numpy.all(prev._mask)
10951095

10961096

10971097
def test_inverted_latitude_point():

0 commit comments

Comments
 (0)