11import gc
22import os
33import math
4+ import statistics
45import json
56import random
67import shutil
1415from tensorflow .keras .backend import clear_session
1516from backend .utils import adjust_rgb
1617from backend .metrics import MIOU
17- from backend .config import get_timestamp , get_waterbody_transfer , get_random_subsample
18+ from backend .config import get_timestamp , get_waterbody_transfer , get_random_subsample , get_fusion_head , get_water_threshold
1819from models .utils import evaluate_model
1920from backend .data_loader import DataLoader
2021
2122
2223class ImgSequence (KerasSequence ):
23- def __init__ (self , timestamp : int , tiles : List [int ], batch_size : int = 32 , bands : Sequence [str ] = None , is_train : bool = False , random_subsample : bool = False ):
24+ def __init__ (self , timestamp : int , tiles : List [int ], batch_size : int = 32 , bands : Sequence [str ] = None , is_train : bool = False , random_subsample : bool = False , upscale_swir : bool = True ):
2425 # Initialize Member Variables
25- self .data_loader = DataLoader (timestamp , overlapping_patches = is_train , random_subsample = (random_subsample and is_train ))
26+ self .data_loader = DataLoader (timestamp , overlapping_patches = is_train , random_subsample = (random_subsample and is_train ), upscale_swir = upscale_swir )
2627 self .batch_size = batch_size
2728 self .bands = ["RGB" ] if bands is None else bands
2829 self .indices = []
@@ -39,7 +40,7 @@ def __init__(self, timestamp: int, tiles: List[int], batch_size: int = 32, bands
3940 def __len__ (self ) -> int :
4041 return math .ceil (len (self .indices ) / self .batch_size )
4142
42- def __getitem__ (self , idx ):
43+ def __getitem__ (self , idx , normalize_data = True ):
4344 # Create Batch
4445 feature_batches = {"RGB" : [], "NIR" : [], "SWIR" : [], "mask" : []}
4546 batch = self .indices [idx * self .batch_size :(idx + 1 )* self .batch_size ]
@@ -54,7 +55,10 @@ def __getitem__(self, idx):
5455
5556 # Add Features To Batch
5657 for key , val in features .items ():
57- feature_batches [key ].append (DataLoader .normalize_channels (val .astype ("float32" )) if key != "mask" else val )
58+ if normalize_data :
59+ feature_batches [key ].append (DataLoader .normalize_channels (val .astype ("float32" )) if key != "mask" else val )
60+ else :
61+ feature_batches [key ].append (val )
5862
5963 # Return Batch
6064 return [np .array (feature_batches [band ]).astype ("float32" ) for band in ("RGB" , "NIR" , "SWIR" ) if len (feature_batches [band ]) > 0 ], np .array (feature_batches ["mask" ]).astype ("float32" )
@@ -90,52 +94,63 @@ def predict_batch(self, model: Model, directory: str):
9094 os .mkdir (model_directory )
9195
9296 # Iterate Over All Patches In Batch
93- MIoUs , MIoU = [], MIOU ()
94- for patch_index in self .indices :
97+ MIoUs , MIoU , i = [], MIOU (), 0
98+ for batch in range (len (self )):
99+
100+ # Get Batch
101+ features , masks = self .__getitem__ (batch , normalize_data = False )
102+ normalized_features , _ = self .__getitem__ (batch )
103+ rgb_features = features [0 ] if "RGB" in self .bands else None
104+ nir_features = features [1 if "RGB" in self .bands else 0 ] if "NIR" in self .bands else None
105+ swir_features = features [2 ] if "SWIR" in self .bands else None
95106
96- # Load Features And Mask
97- features = self ._get_features (patch_index )
98- mask = features ["mask" ]
99-
100107 # Get Prediction
101- prediction = model .predict ([np .array ([DataLoader .normalize_channels (features [band ].astype ("float32" ))]) for band in self .bands ])
102- MIoUs .append ([patch_index , MIoU (mask .astype ("float32" ), prediction ).numpy ()])
103-
104- # Plot Features
105- i = 0
106- _ , axs = plt .subplots (1 , len (self .bands ) + 2 )
107- for band in self .bands :
108- axs [i ].imshow (adjust_rgb (features [band ], gamma = 0.5 ) if band == "RGB" else features [band ])
109- axs [i ].set_title (band , fontsize = 6 )
110- axs [i ].axis ("off" )
108+ predictions = model .predict (normalized_features )
109+
110+ # Iterate Over Each Prediction In The Batch
111+ for p in range (predictions .shape [0 ]):
112+
113+ mask = masks [p , ...]
114+ prediction = predictions [p , ...]
115+ MIoUs .append ([self .indices [i ], MIoU (mask , prediction ).numpy ()])
116+
117+ # Plot Features
118+ col = 0
119+ _ , axs = plt .subplots (1 , len (self .bands ) + 2 )
120+ for band , feature in zip (["RGB" , "NIR" , "SWIR" ], [rgb_features , nir_features , swir_features ]):
121+ if feature is not None :
122+ axs [col ].imshow (adjust_rgb (feature [p , ...], gamma = 0.5 ) if feature .shape [- 1 ] == 3 else feature [p , ...])
123+ axs [col ].set_title (band , fontsize = 6 )
124+ axs [col ].axis ("off" )
125+ col += 1
126+
127+ # Plot Ground Truth
128+ axs [col ].imshow (mask )
129+ axs [col ].set_title ("Ground Truth" , fontsize = 6 )
130+ axs [col ].axis ("off" )
131+ col += 1
132+
133+ # Plot Prediction
134+ axs [col ].imshow (np .where (prediction < 0.5 , 0 , 1 ))
135+ axs [col ].set_title (f"Prediction ({ MIoUs [- 1 ][1 ]:.3f} )" , fontsize = 6 )
136+ axs [col ].axis ("off" )
137+ col += 1
138+
139+ # Save Prediction To Disk
140+ plt .tight_layout ()
141+ plt .savefig (f"{ model_directory } /prediction.{ self .indices [i ]} .png" , dpi = 300 , bbox_inches = 'tight' )
142+ plt .cla ()
143+ plt .close ()
144+
145+ # Housekeeping
146+ gc .collect ()
147+ clear_session ()
111148 i += 1
112-
113- # Plot Ground Truth
114- axs [i ].imshow (mask )
115- axs [i ].set_title ("Ground Truth" , fontsize = 6 )
116- axs [i ].axis ("off" )
117- i += 1
118-
119- # Plot Prediction
120- axs [i ].imshow (np .where (prediction < 0.5 , 0 , 1 )[0 ])
121- axs [i ].set_title (f"Prediction ({ MIoUs [- 1 ][1 ]:.3f} )" , fontsize = 6 )
122- axs [i ].axis ("off" )
123- i += 1
124-
125- # Save Prediction To Disk
126- plt .tight_layout ()
127- plt .savefig (f"{ model_directory } /prediction.{ patch_index } .png" , dpi = 300 , bbox_inches = 'tight' )
128- plt .cla ()
129- plt .close ()
130-
131- # Housekeeping
132- gc .collect ()
133- clear_session ()
134149
135150 # Save MIoU For Each Patch
136- summary = np .array (MIoUs )
137- df = pandas .DataFrame (summary [:, 1 :], columns = ["MIoU" ], index = summary [:, 0 ].astype ("int32" ))
138- df .to_csv (f"{ model_directory } /Evaluation.csv" , index_label = "patch" )
151+ # summary = np.array(MIoUs)
152+ # df = pandas.DataFrame(summary[:, 1:], columns=["MIoU"], index=summary[:, 0].astype("int32"))
153+ # df.to_csv(f"{model_directory}/Evaluation.csv", index_label="patch")
139154
140155 # Evaluate Final Performance
141156 results = evaluate_model (model , self )
@@ -171,10 +186,14 @@ def _get_features(self, patch: int, subsample: bool = True) -> Dict[str, np.ndar
171186
172187
173188class WaterbodyTransferImgSequence (ImgSequence ):
189+ def __init__ (self , timestamp : int , tiles : List [int ], batch_size : int = 32 , bands : Sequence [str ] = None , is_train : bool = False , random_subsample : bool = False , upscale_swir : bool = True , water_threshold : int = 5 ):
190+ super ().__init__ (timestamp , tiles , batch_size , bands , is_train , random_subsample , upscale_swir )
191+ self .water_threshold = water_threshold
192+
174193 """A data pipeline that returns tiles with transplanted waterbodies"""
175194 def _get_features (self , patch : int , subsample : bool = True ) -> Dict [str , np .ndarray ]:
176195 tile_index = patch // 100
177- return self .data_loader .get_features (patch , self .bands , tile_dir = "tiles" if tile_index <= 400 else "transplanted_tiles " )
196+ return self .data_loader .get_features (patch , self .bands , tile_dir = "tiles" if tile_index <= 400 else f"transplanted_tiles_ { self . water_threshold } " )
178197
179198
180199def load_dataset (config ) -> Tuple [ImgSequence , ImgSequence , ImgSequence ]:
@@ -188,15 +207,20 @@ def load_dataset(config) -> Tuple[ImgSequence, ImgSequence, ImgSequence]:
188207 batch_size = config ["hyperparameters" ]["batch_size" ]
189208
190209 # Read Batches From JSON File
191- batch_filename = "batches/transplanted.json" if get_waterbody_transfer (config ) else "batches/tiles.json"
210+ water_threshold = get_water_threshold (config )
211+ batch_filename = f"batches/transplanted_tiles_{ water_threshold } .json" if get_waterbody_transfer (config ) else "batches/tiles.json"
192212 with open (batch_filename ) as f :
193213 batch_file = json .loads (f .read ())
194214
195215 # Choose Type Of Data Pipeline Based On Project Config
196216 Constructor = WaterbodyTransferImgSequence if get_waterbody_transfer (config ) else ImgSequence
197217
198218 # Create Train, Validation, And Test Data
199- train_data = Constructor (get_timestamp (config ), batch_file ["train" ], batch_size = batch_size , bands = bands , is_train = True , random_subsample = get_random_subsample (config ))
200- val_data = ImgSequence (get_timestamp (config ), batch_file ["validation" ], batch_size = batch_size , bands = bands , is_train = False )
201- test_data = ImgSequence (get_timestamp (config ), batch_file ["test" ], batch_size = batch_size , bands = bands , is_train = False )
219+ upscale_swir = get_fusion_head (config ) != "paper"
220+ if get_waterbody_transfer (config ):
221+ train_data = WaterbodyTransferImgSequence (get_timestamp (config ), batch_file ["train" ], batch_size = batch_size , bands = bands , is_train = True , random_subsample = get_random_subsample (config ), upscale_swir = upscale_swir , water_threshold = water_threshold )
222+ else :
223+ train_data = ImgSequence (get_timestamp (config ), batch_file ["train" ], batch_size = batch_size , bands = bands , is_train = True , random_subsample = get_random_subsample (config ), upscale_swir = upscale_swir )
224+ val_data = ImgSequence (get_timestamp (config ), batch_file ["validation" ], batch_size = 12 , bands = bands , is_train = False , upscale_swir = upscale_swir )
225+ test_data = ImgSequence (get_timestamp (config ), batch_file ["test" ], batch_size = 12 , bands = bands , is_train = False , upscale_swir = upscale_swir )
202226 return train_data , val_data , test_data
0 commit comments