
Commit 4f44dd0: Created MC-WBDN Model
1 parent 8cdc813, commit 4f44dd0

26 files changed: +3601 additions, -1784 deletions

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -11,4 +11,6 @@ validation/
 test/
 old_data/
 experiments
+remove_experiment.py
+resample.py

backend/callbacks.py

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ def get_callbacks(config: Dict[str, Any], val_data: ImgSequence, model: Model) -
     checkpoint = ModelCheckpoint(f"checkpoints/{model.name}/{model.name}", save_best_only=False, monitor='val_loss', mode='min', save_weights_only=True)
     prediction_logger = PredictionCallback(val_data, model)
     learning_rate_scheduler = LearningRateScheduler(lr_scheduler)
-    early_stopping = EarlyStopping(monitor="val_loss", min_delta=0.0001, patience=15, verbose=1, mode="min")
+    early_stopping = EarlyStopping(monitor="val_loss", min_delta=0.0001, patience=10, verbose=1, mode="min")
     return [tensorboard, csv, checkpoint, prediction_logger, learning_rate_scheduler, early_stopping] if get_create_logs(config) else [learning_rate_scheduler, early_stopping]
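
Note: the only functional change here is the early-stopping patience dropping from 15 to 10 epochs. As a standalone, hedged illustration (the toy model and data below are not from this commit; only the EarlyStopping arguments mirror the change), the same settings behave like this when handed to Keras:

```python
import numpy as np
import tensorflow as tf

# Toy model and data, purely illustrative; only the EarlyStopping arguments
# mirror the committed change (patience=10, min_delta=0.0001 on val_loss).
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", min_delta=0.0001, patience=10, verbose=1, mode="min"
)

# Training halts once val_loss fails to improve by at least 0.0001
# for 10 consecutive epochs, five fewer than before this commit.
model.fit(x, y, validation_split=0.25, epochs=100, callbacks=[early_stopping])
```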

backend/config.py

Lines changed: 17 additions & 8 deletions

@@ -154,19 +154,28 @@ def get_num_experiments(config: Dict[str, Any]) -> int:
     return config["experiments"]


-def get_water_threshold(config: Dict[str, Any]) -> float:
+def get_random_subsample(config: Dict[str, Any]) -> bool:
     """
-    Get the water threshold that patches must meet to avoid being discarded
+    Get the setting for whether or not the data pipeline should sub-sample 512x512 patches
     :param config: A dictionary storing the project configuration; typically loaded from an external file
-    :returns: The water threshold as a percentage that must be met by a patch to avoid being discarded
+    :returns: Whether or not to randomly sub-sample patches
     """
-    return config["hyperparameters"]["water_threshold"]
+    return config["hyperparameters"]["random_subsample"]


-def get_random_subsample(config: Dict[str, Any]) -> bool:
+def get_water_threshold(config: Dict[str, Any]) -> int:
     """
-    Get the setting for whether or not the data pipeline should sub-sample 512x512 patches
+    Get the water threshold for waterbody transfer
     :param config: A dictionary storing the project configuration; typically loaded from an external file
-    :returns: Whether or not to randomly sub-sample patches
+    :returns: The content threshold (percent) to be applied to waterbody transfer
     """
-    return config["hyperparameters"]["random_subsample"]
+    return config["hyperparameters"]["water_threshold"]
+
+
+def get_mixed_precision(config: Dict[str, Any]) -> bool:
+    """
+    Return True if we want to use mixed precision to speed up trainig/inference at the cost of accuracy
+    :param config: A dictionary storing the project configuration; typically loaded from an external file
+    :returns: A boolean indicating whether or not we want to use mixed precision
+    """
+    return config["use_mixed_precision"]
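
Note: get_mixed_precision reads a new top-level use_mixed_precision key. A minimal sketch of how a caller might act on the flag; the configure_precision helper below is illustrative and is not part of this commit, only the getter and the config key are:

```python
import tensorflow as tf

from backend.config import get_mixed_precision


def configure_precision(config: dict) -> None:
    """Illustrative helper: switch Keras to mixed precision when the config asks for it."""
    if get_mixed_precision(config):
        # Layers compute in float16 while variables stay float32;
        # model outputs should remain float32 for numerically stable losses.
        tf.keras.mixed_precision.set_global_policy("mixed_float16")


# Minimal config fragment containing only the key this getter reads.
configure_precision({"use_mixed_precision": True})
```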

backend/data_loader.py

Lines changed: 14 additions & 3 deletions

@@ -7,11 +7,12 @@

 class DataLoader:
     """A class to save and load images from disk"""
-    def __init__(self, timestamp: int = 1, overlapping_patches: bool = False, random_subsample: bool = False):
+    def __init__(self, timestamp: int = 1, overlapping_patches: bool = False, random_subsample: bool = False, upscale_swir: bool = True):
         self.timestamp = timestamp
         self.folders = {1: "2018.04", 2: "2018.12", 3: "2019.02"}
         self.overlapping_patches = overlapping_patches
         self.random_subsample = random_subsample
+        self.upscale_swir = upscale_swir

     def get_rgb_features(self, tile_number: int, coords: Tuple[int, int] = (0, 0), preprocess_img: bool = True, tile_dir: str = "tiles") -> np.ndarray:
         """

@@ -38,8 +39,10 @@ def get_swir_features(self, tile_number: int, coords: Tuple[int, int] = None, pr
         :return: The SWIR features of the matching patch,
         """
         tile = self.read_image(f"data/{self.folders.get(self.timestamp, 1)}/{tile_dir}/swir/swir.{tile_number}.tif", preprocess_img=preprocess_img)
-        tile = np.resize(cv2.resize(tile, (1024, 1024), interpolation = cv2.INTER_AREA), (1024, 1024, 1))
-        return self.subsample_tile(tile, coords=coords) if coords is not None else tile
+        tile = np.resize(cv2.resize(tile, (1024, 1024), interpolation = cv2.INTER_AREA), (1024, 1024, 1)) if self.upscale_swir else tile
+        if coords is not None:
+            return self.subsample_tile(tile, coords=coords) if self.upscale_swir else self.subsample_swir_tile(tile, coords=coords)
+        return tile

     def get_mask(self, tile_number: int, coords: Tuple[int, int] = None, preprocess_img: bool = True, tile_dir: str = "tiles") -> np.ndarray:
         """

@@ -80,6 +83,14 @@ def subsample_tile(self, tile: np.ndarray, coords: Tuple[int, int] = (0, 0)) ->
         Take a 512X512 sub-patch from a 1024X12024 tile.
         """
         return tile[coords[1]:coords[1]+512, coords[0]:coords[0]+512, :]
+
+    def subsample_swir_tile(self, tile: np.ndarray, coords: Tuple[int, int] = (0, 0)) -> np.ndarray:
+        """
+        Take a 512X512 sub-patch from a 1024X12024 tile.
+        """
+        y_coord = coords[1] // 2
+        x_coord = coords[0] // 2
+        return tile[y_coord:y_coord+256, x_coord:x_coord+256, :]

     def get_patch_coords(self, patch_index: int = 0) -> Tuple[int, int]:
         """Get the coordinates for a patch inside a tile from a given patch_index. If random_subsample is True, the coords will be selected randomly."""

backend/metrics.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ def MIOU():
     m = tf.keras.metrics.MeanIoU(num_classes=2)
     def MIoU(y_true, y_pred):
         m.reset_states()
-        y_true, y_pred = flatten(y_true), flatten(tf.where(y_pred >= 0.5, 1.1, 0.0))
+        y_true, y_pred = y_true, tf.where(y_pred >= 0.5, 1.1, 0.0)
         _ = m.update_state(y_true, y_pred)
         return m.result()
     return MIoU
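
Note: the explicit flatten() calls were dropped; tf.keras.metrics.MeanIoU flattens its inputs internally, so the closure can now be fed whole (batch, height, width, 1) tensors. A quick, hedged usage check with random data, illustrative only:

```python
import numpy as np

from backend.metrics import MIOU

miou = MIOU()
y_true = np.random.randint(0, 2, size=(4, 128, 128, 1)).astype("float32")
y_pred = np.random.uniform(0.0, 1.0, size=(4, 128, 128, 1)).astype("float32")

# Predictions are thresholded at 0.5 inside the closure before the confusion
# matrix is accumulated; the result is the mean IoU over the two classes.
print(float(miou(y_true, y_pred)))
```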

backend/pipeline.py

Lines changed: 75 additions & 51 deletions

@@ -1,6 +1,7 @@
 import gc
 import os
 import math
+import statistics
 import json
 import random
 import shutil

@@ -14,15 +15,15 @@
 from tensorflow.keras.backend import clear_session
 from backend.utils import adjust_rgb
 from backend.metrics import MIOU
-from backend.config import get_timestamp, get_waterbody_transfer, get_random_subsample
+from backend.config import get_timestamp, get_waterbody_transfer, get_random_subsample, get_fusion_head, get_water_threshold
 from models.utils import evaluate_model
 from backend.data_loader import DataLoader


 class ImgSequence(KerasSequence):
-    def __init__(self, timestamp: int, tiles: List[int], batch_size: int = 32, bands: Sequence[str] = None, is_train: bool = False, random_subsample: bool = False):
+    def __init__(self, timestamp: int, tiles: List[int], batch_size: int = 32, bands: Sequence[str] = None, is_train: bool = False, random_subsample: bool = False, upscale_swir: bool = True):
         # Initialize Member Variables
-        self.data_loader = DataLoader(timestamp, overlapping_patches=is_train, random_subsample=(random_subsample and is_train))
+        self.data_loader = DataLoader(timestamp, overlapping_patches=is_train, random_subsample=(random_subsample and is_train), upscale_swir=upscale_swir)
         self.batch_size = batch_size
         self.bands = ["RGB"] if bands is None else bands
         self.indices = []

@@ -39,7 +40,7 @@ def __init__(self, timestamp: int, tiles: List[int], batch_size: int = 32, bands
     def __len__(self) -> int:
         return math.ceil(len(self.indices) / self.batch_size)

-    def __getitem__(self, idx):
+    def __getitem__(self, idx, normalize_data=True):
         # Create Batch
         feature_batches = {"RGB": [], "NIR": [], "SWIR": [], "mask": []}
         batch = self.indices[idx*self.batch_size:(idx+1)*self.batch_size]

@@ -54,7 +55,10 @@ def __getitem__(self, idx):

         # Add Features To Batch
         for key, val in features.items():
-            feature_batches[key].append(DataLoader.normalize_channels(val.astype("float32")) if key != "mask" else val)
+            if normalize_data:
+                feature_batches[key].append(DataLoader.normalize_channels(val.astype("float32")) if key != "mask" else val)
+            else:
+                feature_batches[key].append(val)

         # Return Batch
         return [np.array(feature_batches[band]).astype("float32") for band in ("RGB", "NIR", "SWIR") if len(feature_batches[band]) > 0], np.array(feature_batches["mask"]).astype("float32")

@@ -90,52 +94,63 @@ def predict_batch(self, model: Model, directory: str):
         os.mkdir(model_directory)

         # Iterate Over All Patches In Batch
-        MIoUs, MIoU = [], MIOU()
-        for patch_index in self.indices:
+        MIoUs, MIoU, i = [], MIOU(), 0
+        for batch in range(len(self)):
+
+            # Get Batch
+            features, masks = self.__getitem__(batch, normalize_data=False)
+            normalized_features, _ = self.__getitem__(batch)
+            rgb_features = features[0] if "RGB" in self.bands else None
+            nir_features = features[1 if "RGB" in self.bands else 0] if "NIR" in self.bands else None
+            swir_features = features[2] if "SWIR" in self.bands else None

-            # Load Features And Mask
-            features = self._get_features(patch_index)
-            mask = features["mask"]
-
             # Get Prediction
-            prediction = model.predict([np.array([DataLoader.normalize_channels(features[band].astype("float32"))]) for band in self.bands])
-            MIoUs.append([patch_index, MIoU(mask.astype("float32"), prediction).numpy()])
-
-            # Plot Features
-            i = 0
-            _, axs = plt.subplots(1, len(self.bands) + 2)
-            for band in self.bands:
-                axs[i].imshow(adjust_rgb(features[band], gamma=0.5) if band == "RGB" else features[band])
-                axs[i].set_title(band, fontsize=6)
-                axs[i].axis("off")
+            predictions = model.predict(normalized_features)
+
+            # Iterate Over Each Prediction In The Batch
+            for p in range(predictions.shape[0]):
+
+                mask = masks[p, ...]
+                prediction = predictions[p, ...]
+                MIoUs.append([self.indices[i], MIoU(mask, prediction).numpy()])
+
+                # Plot Features
+                col = 0
+                _, axs = plt.subplots(1, len(self.bands) + 2)
+                for band, feature in zip(["RGB", "NIR", "SWIR"], [rgb_features, nir_features, swir_features]):
+                    if feature is not None:
+                        axs[col].imshow(adjust_rgb(feature[p, ...], gamma=0.5) if feature.shape[-1] == 3 else feature[p, ...])
+                        axs[col].set_title(band, fontsize=6)
+                        axs[col].axis("off")
+                        col += 1
+
+                # Plot Ground Truth
+                axs[col].imshow(mask)
+                axs[col].set_title("Ground Truth", fontsize=6)
+                axs[col].axis("off")
+                col += 1
+
+                # Plot Prediction
+                axs[col].imshow(np.where(prediction < 0.5, 0, 1))
+                axs[col].set_title(f"Prediction ({MIoUs[-1][1]:.3f})", fontsize=6)
+                axs[col].axis("off")
+                col += 1
+
+                # Save Prediction To Disk
+                plt.tight_layout()
+                plt.savefig(f"{model_directory}/prediction.{self.indices[i]}.png", dpi=300, bbox_inches='tight')
+                plt.cla()
+                plt.close()
+
+                # Housekeeping
+                gc.collect()
+                clear_session()
                 i += 1
-
-            # Plot Ground Truth
-            axs[i].imshow(mask)
-            axs[i].set_title("Ground Truth", fontsize=6)
-            axs[i].axis("off")
-            i += 1
-
-            # Plot Prediction
-            axs[i].imshow(np.where(prediction < 0.5, 0, 1)[0])
-            axs[i].set_title(f"Prediction ({MIoUs[-1][1]:.3f})", fontsize=6)
-            axs[i].axis("off")
-            i += 1
-
-            # Save Prediction To Disk
-            plt.tight_layout()
-            plt.savefig(f"{model_directory}/prediction.{patch_index}.png", dpi=300, bbox_inches='tight')
-            plt.cla()
-            plt.close()
-
-            # Housekeeping
-            gc.collect()
-            clear_session()

         # Save MIoU For Each Patch
-        summary = np.array(MIoUs)
-        df = pandas.DataFrame(summary[:, 1:], columns=["MIoU"], index=summary[:, 0].astype("int32"))
-        df.to_csv(f"{model_directory}/Evaluation.csv", index_label="patch")
+        # summary = np.array(MIoUs)
+        # df = pandas.DataFrame(summary[:, 1:], columns=["MIoU"], index=summary[:, 0].astype("int32"))
+        # df.to_csv(f"{model_directory}/Evaluation.csv", index_label="patch")

         # Evaluate Final Performance
         results = evaluate_model(model, self)

@@ -171,10 +186,14 @@ def _get_features(self, patch: int, subsample: bool = True) -> Dict[str, np.ndar


 class WaterbodyTransferImgSequence(ImgSequence):
+    def __init__(self, timestamp: int, tiles: List[int], batch_size: int = 32, bands: Sequence[str] = None, is_train: bool = False, random_subsample: bool = False, upscale_swir: bool = True, water_threshold: int = 5):
+        super().__init__(timestamp, tiles, batch_size, bands, is_train, random_subsample, upscale_swir)
+        self.water_threshold = water_threshold
+
     """A data pipeline that returns tiles with transplanted waterbodies"""
     def _get_features(self, patch: int, subsample: bool = True) -> Dict[str, np.ndarray]:
         tile_index = patch // 100
-        return self.data_loader.get_features(patch, self.bands, tile_dir="tiles" if tile_index <= 400 else "transplanted_tiles")
+        return self.data_loader.get_features(patch, self.bands, tile_dir="tiles" if tile_index <= 400 else f"transplanted_tiles_{self.water_threshold}")


 def load_dataset(config) -> Tuple[ImgSequence, ImgSequence, ImgSequence]:

@@ -188,15 +207,20 @@ def load_dataset(config) -> Tuple[ImgSequence, ImgSequence, ImgSequence]:
     batch_size = config["hyperparameters"]["batch_size"]

     # Read Batches From JSON File
-    batch_filename = "batches/transplanted.json" if get_waterbody_transfer(config) else "batches/tiles.json"
+    water_threshold = get_water_threshold(config)
+    batch_filename = f"batches/transplanted_tiles_{water_threshold}.json" if get_waterbody_transfer(config) else "batches/tiles.json"
    with open(batch_filename) as f:
        batch_file = json.loads(f.read())

     # Choose Type Of Data Pipeline Based On Project Config
     Constructor = WaterbodyTransferImgSequence if get_waterbody_transfer(config) else ImgSequence

     # Create Train, Validation, And Test Data
-    train_data = Constructor(get_timestamp(config), batch_file["train"], batch_size=batch_size, bands=bands, is_train=True, random_subsample=get_random_subsample(config))
-    val_data = ImgSequence(get_timestamp(config), batch_file["validation"], batch_size=batch_size, bands=bands, is_train=False)
-    test_data = ImgSequence(get_timestamp(config), batch_file["test"], batch_size=batch_size, bands=bands, is_train=False)
+    upscale_swir = get_fusion_head(config) != "paper"
+    if get_waterbody_transfer(config):
+        train_data = WaterbodyTransferImgSequence(get_timestamp(config), batch_file["train"], batch_size=batch_size, bands=bands, is_train=True, random_subsample=get_random_subsample(config), upscale_swir=upscale_swir, water_threshold=water_threshold)
+    else:
+        train_data = ImgSequence(get_timestamp(config), batch_file["train"], batch_size=batch_size, bands=bands, is_train=True, random_subsample=get_random_subsample(config), upscale_swir=upscale_swir)
+    val_data = ImgSequence(get_timestamp(config), batch_file["validation"], batch_size=12, bands=bands, is_train=False, upscale_swir=upscale_swir)
+    test_data = ImgSequence(get_timestamp(config), batch_file["test"], batch_size=12, bands=bands, is_train=False, upscale_swir=upscale_swir)
     return train_data, val_data, test_data
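
Note: predict_batch now fetches each batch twice, once with normalize_data=False for plotting and once normalized for inference, and then splits the returned feature list by band. Because __getitem__ returns only the enabled bands, in the fixed order RGB, NIR, SWIR, the index of each band shifts when RGB is absent. A self-contained sketch of that indexing rule (the list entries below are placeholders, not real batches):

```python
# Mirrors the band-indexing logic used in predict_batch above.
def split_bands(features, bands):
    rgb = features[0] if "RGB" in bands else None
    nir = features[1 if "RGB" in bands else 0] if "NIR" in bands else None
    swir = features[2] if "SWIR" in bands else None
    return rgb, nir, swir


# With RGB disabled, NIR moves to index 0.
print(split_bands(["nir_batch"], ["NIR"]))                      # (None, 'nir_batch', None)
print(split_bands(["rgb_batch", "nir_batch"], ["RGB", "NIR"]))  # ('rgb_batch', 'nir_batch', None)
```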
