|
10 | 10 | from math import floor, isclose
|
11 | 11 | from typing import cast
|
12 | 12 |
|
| 13 | +import shapely |
13 | 14 | from rtree.index import Index, Property
|
| 15 | +from shapely import LineString |
14 | 16 | from torch import Generator, default_generator, randint, randperm
|
15 | 17 |
|
16 | 18 | from ..datasets import GeoDataset
|
@@ -110,43 +112,63 @@ def random_bbox_splitting(
|
110 | 112 | if any(n <= 0 for n in fractions):
|
111 | 113 | raise ValueError('All items in input fractions must be greater than 0.')
|
112 | 114 |
|
113 |
| - new_indexes = [ |
114 |
| - Index(interleaved=False, properties=Property(dimension=3)) for _ in fractions |
115 |
| - ] |
| 115 | + new_datasets = [deepcopy(dataset) for _ in fractions] |
116 | 116 |
|
117 |
| - for i, hit in enumerate( |
118 |
| - dataset.index.intersection(dataset.index.bounds, objects=True) |
119 |
| - ): |
120 |
| - box = BoundingBox(*hit.bounds) |
121 |
| - fraction_left = 1.0 |
| 117 | + for i in range(len(dataset)): |
| 118 | + geometry_remaining = dataset.index.geometry.iloc[i] |
| 119 | + fraction_remaining = 1.0 |
122 | 120 |
|
123 | 121 | # Randomly choose the split direction
|
124 | 122 | horizontal, flip = randint(0, 2, (2,), generator=generator)
|
125 | 123 | for j, fraction in enumerate(fractions):
|
126 |
| - if fraction_left == fraction: |
| 124 | + if isclose(fraction_remaining, fraction): |
127 | 125 | # For the last fraction, no need to split again
|
128 |
| - new_box = box |
129 |
| - elif flip: |
130 |
| - # new_box corresponds to fraction, box is the remainder that we might |
131 |
| - # split again in the next iteration. Each split is done according to |
132 |
| - # fraction wrt what's left |
133 |
| - box, new_box = box.split( |
134 |
| - (fraction_left - fraction) / fraction_left, horizontal |
135 |
| - ) |
| 126 | + new_geometry = geometry_remaining |
136 | 127 | else:
|
137 |
| - # Same as above, but without flipping |
138 |
| - new_box, box = box.split(fraction / fraction_left, horizontal) |
139 |
| - |
140 |
| - new_indexes[j].insert(i, tuple(new_box), hit.object) |
141 |
| - fraction_left -= fraction |
| 128 | + # Create a new_geometry from geometry_remaining |
| 129 | + minx, miny, maxx, maxy = geometry_remaining.bounds |
| 130 | + |
| 131 | + if flip: |
| 132 | + frac = fraction_remaining - fraction |
| 133 | + else: |
| 134 | + frac = fraction |
| 135 | + |
| 136 | + if horizontal: |
| 137 | + splity = miny + (maxy - miny) * frac / fraction_remaining |
| 138 | + line = LineString([(minx, splity), (maxx, splity)]) |
| 139 | + else: |
| 140 | + splitx = minx + (maxx - minx) * frac / fraction_remaining |
| 141 | + line = LineString([(splitx, miny), (splitx, maxy)]) |
| 142 | + |
| 143 | + geom1, geom2 = shapely.ops.split(geometry_remaining, line).geoms |
| 144 | + if horizontal: |
| 145 | + if flip: |
| 146 | + if geom1.centroid.y < splity: |
| 147 | + geometry_remaining, new_geometry = geom1, geom2 |
| 148 | + else: |
| 149 | + new_geometry, geometry_remaining = geom1, geom2 |
| 150 | + else: |
| 151 | + if geom1.centroid.y < splity: |
| 152 | + new_geometry, geometry_remaining = geom1, geom2 |
| 153 | + else: |
| 154 | + geometry_remaining, new_geometry = geom1, geom2 |
| 155 | + else: |
| 156 | + if flip: |
| 157 | + if geom1.centroid.x < splitx: |
| 158 | + geometry_remaining, new_geometry = geom1, geom2 |
| 159 | + else: |
| 160 | + new_geometry, geometry_remaining = geom1, geom2 |
| 161 | + else: |
| 162 | + if geom1.centroid.x < splitx: |
| 163 | + new_geometry, geometry_remaining = geom1, geom2 |
| 164 | + else: |
| 165 | + geometry_remaining, new_geometry = geom1, geom2 |
| 166 | + |
| 167 | + new_datasets[j].index.iloc[i].geometry = new_geometry |
| 168 | + |
| 169 | + fraction_remaining -= fraction |
142 | 170 | horizontal = not horizontal
|
143 | 171 |
|
144 |
| - new_datasets = [] |
145 |
| - for index in new_indexes: |
146 |
| - ds = deepcopy(dataset) |
147 |
| - ds.index = index |
148 |
| - new_datasets.append(ds) |
149 |
| - |
150 | 172 | return new_datasets
|
151 | 173 |
|
152 | 174 |
|
|
0 commit comments