2222class ImgSequence (KerasSequence ):
2323 def __init__ (self , timestamp : int , tiles : List [int ], batch_size : int = 32 , bands : Sequence [str ] = None , is_train : bool = False , random_subsample : bool = False ):
2424 # Initialize Member Variables
25- self .data_loader = DataLoader (timestamp , overlapping_patches = is_train , random_subsample = random_subsample )
25+ self .data_loader = DataLoader (timestamp , overlapping_patches = is_train , random_subsample = ( random_subsample and is_train ) )
2626 self .batch_size = batch_size
2727 self .bands = ["RGB" ] if bands is None else bands
2828 self .indices = []
@@ -170,101 +170,11 @@ def _get_features(self, patch: int, subsample: bool = True) -> Dict[str, np.ndar
170170 return self .data_loader .get_features (patch , self .bands , subsample = subsample )
171171
172172
173- class SubSampleImgSequence (ImgSequence ):
174- """A class to demonstrate the waterbody transfer method."""
175-
176- def __getitem__ (self , idx ):
177- # Create Batch
178- feature_batches = {"RGB" : [], "NIR" : [], "SWIR" : [], "mask" : []}
179- batch = self .indices [idx * self .batch_size :(idx + 1 )* self .batch_size ]
180- for b in batch :
181-
182- # Get Mask And Features For Patch b
183- features = self ._get_features (b )
184-
185- # Randomly Sample 256 x 256 Sub-Patch
186- if self .is_train :
187- features = self .generate_composite (features )
188-
189- # Add Features To Batch
190- for key , val in features .items ():
191- feature_batches [key ].append (DataLoader .normalize_channels (val .astype ("float32" )) if key != "mask" else val )
192-
193- # Return Batch
194- return [np .array (feature_batches [band ]).astype ("float32" ) for band in ("RGB" , "NIR" , "SWIR" ) if len (feature_batches [band ]) > 0 ], np .array (feature_batches ["mask" ]).astype ("float32" )
195-
196- def subsample_patch (self , patch : Dict [str , np .ndarray ], sample_random : bool = False ) -> Tuple [Dict [str , np .ndarray ]]:
197- top_left , top_right , bottom_left , bottom_right = dict (), dict (), dict (), dict ()
198- xs = [0 , 256 , 0 , 256 ] if not sample_random else [random .randint (0 , 256 ) for _ in range (4 )]
199- ys = [0 , 0 , 256 , 256 ] if not sample_random else [random .randint (0 , 256 ) for _ in range (4 )]
200- for band in patch .keys ():
201- top_left [band ] = patch [band ][xs [0 ]:xs [0 ]+ 256 , ys [0 ]:ys [0 ]+ 256 , :]
202- top_right [band ] = patch [band ][xs [1 ]:xs [1 ]+ 256 , ys [1 ]:ys [1 ]+ 256 , :]
203- bottom_left [band ] = patch [band ][xs [2 ]:xs [2 ]+ 256 , ys [2 ]:ys [2 ]+ 256 , :]
204- bottom_right [band ] = patch [band ][xs [3 ]:xs [3 ]+ 256 , ys [3 ]:ys [3 ]+ 256 , :]
205- return top_left , top_right , bottom_left , bottom_right
206-
207- def generate_composite (self , patch : Dict [str , np .ndarray ]) -> Dict [str , np .ndarray ]:
208- # Get Sub-Patches
209- top_left , top_right , bottom_left , bottom_right = self .subsample_patch (patch )
210-
211- # Apply Rotations To Quarters
212- self ._rotate_patch (top_left )
213- self ._rotate_patch (top_right )
214- self ._rotate_patch (bottom_left )
215- self ._rotate_patch (bottom_right )
216-
217- # Apply Flips To Quarters
218- self ._flip_patch (top_left )
219- self ._flip_patch (top_right )
220- self ._flip_patch (bottom_left )
221- self ._flip_patch (bottom_right )
222-
223- # Generate Composite
224- return self ._combine_quarters (top_left , top_right , bottom_left , bottom_right )
225-
226- def _combine_quarters (self , top_left : Dict [str , np .ndarray ], top_right : Dict [str , np .ndarray ], bottom_left : Dict [str , np .ndarray ], bottom_right : Dict [str , np .ndarray ]) -> Dict [str , np .ndarray ]:
227- # Generate Random Ordering Of Quarters For Composite Image
228- quarter_indices = [0 , 1 , 2 , 3 ]
229- random .shuffle (quarter_indices )
230-
231- # Assemble Composite Image
232- composite = dict ()
233- for band in top_left .keys ():
234- if band == "RGB" :
235- red_quarters = [np .reshape (quarter , (256 , 256 )) for quarter in [top_left [band ][..., 0 ], top_right [band ][..., 0 ], bottom_left [band ][..., 0 ], bottom_right [band ][..., 0 ]]]
236- green_quarters = [np .reshape (quarter , (256 , 256 )) for quarter in [top_left [band ][..., 1 ], top_right [band ][..., 1 ], bottom_left [band ][..., 1 ], bottom_right [band ][..., 1 ]]]
237- blue_quarters = [np .reshape (quarter , (256 , 256 )) for quarter in [top_left [band ][..., 2 ], top_right [band ][..., 2 ], bottom_left [band ][..., 2 ], bottom_right [band ][..., 2 ]]]
238-
239- red_composite = np .reshape (np .array (np .bmat ([[red_quarters [0 ], red_quarters [1 ]], [red_quarters [2 ], red_quarters [3 ]]])), (512 , 512 , 1 ))
240- green_composite = np .reshape (np .array (np .bmat ([[green_quarters [0 ], green_quarters [1 ]], [green_quarters [2 ], green_quarters [3 ]]])), (512 , 512 , 1 ))
241- blue_composite = np .reshape (np .array (np .bmat ([[blue_quarters [0 ], blue_quarters [1 ]], [blue_quarters [2 ], blue_quarters [3 ]]])), (512 , 512 , 1 ))
242-
243- composite [band ] = np .concatenate ((red_composite , green_composite , blue_composite ), axis = - 1 )
244- else :
245- quarters = [np .reshape (quarter , (256 , 256 )) for quarter in [top_left [band ], top_right [band ], bottom_left [band ], bottom_right [band ]]]
246- composite [band ] = np .reshape (np .array (np .bmat ([[quarters [0 ], quarters [1 ]], [quarters [2 ], quarters [3 ]]])), (512 , 512 , 1 ))
247-
248- return composite
249-
250- class TransferImgSequence (ImgSequence ):
251- """A class to demonstrate the waterbody transfer method."""
252- def __init__ (self , timestamp : int , tiles : List [int ], batch_size : int = 32 , bands : Sequence [str ] = None , is_train : bool = False , random_subsample : bool = False ):
253- # Initialize Member Variables
254- self .data_loader = DataLoader (timestamp , overlapping_patches = is_train , random_subsample = random_subsample )
255- self .batch_size = batch_size
256- self .bands = ["RGB" ] if bands is None else bands
257- self .indices = tiles
258- self .is_train = is_train
259-
260- # If We Want To Apply Waterbody Transferrence, Locate All Patches With At Least 10% Water
261- self .transfer_patches = []
262- if self .is_train :
263- for tile_index in self .indices :
264- mask = self .data_loader .get_mask (tile_index )
265- if 5.0 < self ._water_content (mask ):
266- print (tile_index , mask .shape , self ._water_content (mask ))
267- self .transfer_patches .append (tile_index )
173+ class WaterbodyTransferImgSequence (ImgSequence ):
174+ """A data pipeline that returns tiles with transplanted waterbodies"""
175+ def _get_features (self , patch : int , subsample : bool = True ) -> Dict [str , np .ndarray ]:
176+ tile_index = patch // 100
177+ return self .data_loader .get_features (patch , self .bands , tile_dir = "tiles" if tile_index <= 400 else "transplanted_tiles" )
268178
269179
270180def load_dataset (config ) -> Tuple [ImgSequence , ImgSequence , ImgSequence ]:
@@ -278,18 +188,15 @@ def load_dataset(config) -> Tuple[ImgSequence, ImgSequence, ImgSequence]:
278188 batch_size = config ["hyperparameters" ]["batch_size" ]
279189
280190 # Read Batches From JSON File
281- with open (f"batches/tiles.json" ) as f :
191+ batch_filename = "batches/transplanted.json" if get_waterbody_transfer (config ) else "batches/tiles.json"
192+ with open (batch_filename ) as f :
282193 batch_file = json .loads (f .read ())
283194
284195 # Choose Type Of Data Pipeline Based On Project Config
285- Constructor = ImgSequence
286- if get_waterbody_transfer (config ):
287- Constructor = TransferImgSequence
288- # elif get_random_subsample(config):
289- # Constructor = SubSampleImgSequence
196+ Constructor = WaterbodyTransferImgSequence if get_waterbody_transfer (config ) else ImgSequence
290197
291198 # Create Train, Validation, And Test Data
292199 train_data = Constructor (get_timestamp (config ), batch_file ["train" ], batch_size = batch_size , bands = bands , is_train = True , random_subsample = get_random_subsample (config ))
293- val_data = Constructor (get_timestamp (config ), batch_file ["validation" ], batch_size = batch_size , bands = bands , is_train = False )
294- test_data = Constructor (get_timestamp (config ), batch_file ["test" ], batch_size = batch_size , bands = bands , is_train = False )
200+ val_data = ImgSequence (get_timestamp (config ), batch_file ["validation" ], batch_size = batch_size , bands = bands , is_train = False )
201+ test_data = ImgSequence (get_timestamp (config ), batch_file ["test" ], batch_size = batch_size , bands = bands , is_train = False )
295202 return train_data , val_data , test_data
0 commit comments