Skip to content

Commit ba45e7b

Browse files
committed
Samplers: fix type hints
1 parent 9786db4 commit ba45e7b

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

tests/samplers/test_batch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]:
3030

3131

3232
class CustomGeoDataset(GeoDataset):
33-
def __init__(self, geometry: Sequence[Geometry], res: float = 10) -> None:
33+
def __init__(
34+
self, geometry: Sequence[Geometry], res: tuple[float, float] = (10, 10)
35+
) -> None:
3436
intervals = [(MINT, MAXT)] * len(geometry)
3537
index = pd.IntervalIndex.from_tuples(intervals, closed='both', name='datetime')
3638
crs = CRS.from_epsg(3005)
@@ -121,7 +123,7 @@ def test_roi(self, dataset: CustomGeoDataset) -> None:
121123

122124
def test_small_area(self) -> None:
123125
geometry = [shapely.box(0, 0, 10, 10), shapely.box(20, 20, 21, 21)]
124-
ds = CustomGeoDataset(geometry, res=1)
126+
ds = CustomGeoDataset(geometry, res=(1, 1))
125127
sampler = RandomBatchGeoSampler(ds, 2, 2, 10)
126128
for _ in sampler:
127129
continue

tests/samplers/test_single.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def __iter__(self) -> Iterator[BoundingBox]:
3737

3838

3939
class CustomGeoDataset(GeoDataset):
40-
def __init__(self, geometry: Sequence[Geometry], res: float = 10) -> None:
40+
def __init__(
41+
self, geometry: Sequence[Geometry], res: tuple[float, float] = (10, 10)
42+
) -> None:
4143
intervals = [(MINT, MAXT)] * len(geometry)
4244
index = pd.IntervalIndex.from_tuples(intervals, closed='both', name='datetime')
4345
crs = CRS.from_epsg(3005)
@@ -114,7 +116,7 @@ def test_roi(self, dataset: CustomGeoDataset) -> None:
114116

115117
def test_small_area(self) -> None:
116118
geometry = [shapely.box(0, 0, 10, 10), shapely.box(20, 20, 21, 21)]
117-
ds = CustomGeoDataset(geometry, res=1)
119+
ds = CustomGeoDataset(geometry, res=(1, 1))
118120
sampler = RandomGeoSampler(ds, 2, 10)
119121
for _ in sampler:
120122
continue

tests/samplers/test_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Licensed under the MIT License.
33

44
import math
5+
from datetime import datetime
56

67
import pytest
78

@@ -11,6 +12,9 @@
1112

1213
MAYBE_TUPLE = float | tuple[float, float]
1314

15+
MINT = datetime(2025, 4, 24)
16+
MAXT = datetime(2025, 4, 25)
17+
1418

1519
@pytest.mark.parametrize(
1620
'size,stride,expected',
@@ -34,7 +38,7 @@
3438
def test_tile_to_chips(
3539
size: MAYBE_TUPLE, stride: MAYBE_TUPLE | None, expected: MAYBE_TUPLE
3640
) -> None:
37-
bounds = BoundingBox(0, 10, 20, 30, 40, 50)
41+
bounds = BoundingBox(0, 10, 20, 30, MINT, MAXT)
3842
size = _to_tuple(size)
3943
if stride is not None:
4044
stride = _to_tuple(stride)

0 commit comments

Comments
 (0)