train.py
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # suppress TensorFlow info/warning logs

import numpy as np
import cv2
from glob import glob
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Recall, Precision

from model import build_unet
from metrics import dice_loss, dice_coef, iou
""" Global parameters """
H = 512
W = 512
def create_dir(path):
""" Create a directory. """
if not os.path.exists(path):
os.makedirs(path)

def load_data(path, split=0.1):
    """ Load image/mask paths and split them into train/valid/test sets. """
    images = sorted(glob(os.path.join(path, "CXR_png", "*.png")))
    masks1 = sorted(glob(os.path.join(path, "ManualMask", "leftMask", "*.png")))
    masks2 = sorted(glob(os.path.join(path, "ManualMask", "rightMask", "*.png")))

    # Carve out a validation set, then a test set of the same size. Reusing the
    # same random_state for every call keeps the image and mask lists aligned.
    split_size = int(len(images) * split)

    train_x, valid_x = train_test_split(images, test_size=split_size, random_state=42)
    train_y1, valid_y1 = train_test_split(masks1, test_size=split_size, random_state=42)
    train_y2, valid_y2 = train_test_split(masks2, test_size=split_size, random_state=42)

    train_x, test_x = train_test_split(train_x, test_size=split_size, random_state=42)
    train_y1, test_y1 = train_test_split(train_y1, test_size=split_size, random_state=42)
    train_y2, test_y2 = train_test_split(train_y2, test_size=split_size, random_state=42)

    return (train_x, train_y1, train_y2), (valid_x, valid_y1, valid_y2), (test_x, test_y1, test_y2)
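
# For reference: the Montgomery County set ships 138 chest X-rays, so the
# default split=0.1 gives split_size = int(138 * 0.1) = 13, i.e. roughly a
# 112/13/13 train/valid/test split (the count is an assumption about the
# unmodified dataset; the printout in __main__ shows the actual numbers).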

def read_image(path):
    """ Read an image as BGR, resize to (W, H), and scale to [0, 1]. """
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.resize(x, (W, H))
    x = x / 255.0
    x = x.astype(np.float32)
    return x

def read_mask(path1, path2):
    """ Read the left and right lung masks and combine them into one binary mask. """
    x1 = cv2.imread(path1, cv2.IMREAD_GRAYSCALE)
    x2 = cv2.imread(path2, cv2.IMREAD_GRAYSCALE)
    # np.maximum avoids the uint8 wraparound that plain `x1 + x2` can produce
    # on overlapping pixels (np.uint8(255) + np.uint8(255) == 254).
    x = np.maximum(x1, x2)
    x = cv2.resize(x, (W, H))
    x = x / np.max(x)
    x = (x > 0.5).astype(np.float32)
    x = np.expand_dims(x, axis=-1)
    return x

def tf_parse(x, y1, y2):
    def _parse(x, y1, y2):
        x = x.decode()
        y1 = y1.decode()
        y2 = y2.decode()

        x = read_image(x)
        y = read_mask(y1, y2)
        return x, y

    x, y = tf.numpy_function(_parse, [x, y1, y2], [tf.float32, tf.float32])
    # tf.numpy_function erases static shape information, so restore it here.
    x.set_shape([H, W, 3])
    y.set_shape([H, W, 1])
    return x, y

def tf_dataset(X, Y1, Y2, batch=8):
    dataset = tf.data.Dataset.from_tensor_slices((X, Y1, Y2))
    dataset = dataset.shuffle(buffer_size=200)
    dataset = dataset.map(tf_parse)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(4)
    return dataset
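
# A quick way to sanity-check the pipeline is to pull a single batch and
# inspect its shapes (illustrative snippet, not part of the training flow):
#
#   for x, y in tf_dataset(train_x, train_y1, train_y2, batch=2).take(1):
#       print(x.shape, y.shape)  # expected: (2, 512, 512, 3) and (2, 512, 512, 1)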
if __name__ == "__main__":
""" Seeding """
np.random.seed(42)
tf.random.set_seed(42)
""" Directory for storing files """
create_dir("files")
""" Hyperparameters """
batch_size = 2
lr = 1e-5
num_epochs = 10
model_path = os.path.join("files", "model.h5")
csv_path = os.path.join("files", "data.csv")
""" Dataset """
dataset_path = "C:/Users/Ashwath/Desktop/Capstone/Unet lung segmentation/MontgomerySet"
(train_x, train_y1, train_y2), (valid_x, valid_y1, valid_y2), (test_x, test_y1, test_y2) = load_data(dataset_path)
print(f"Train: {len(train_x)} - {len(train_y1)} - {len(train_y2)}")
print(f"Valid: {len(valid_x)} - {len(valid_y1)} - {len(valid_y2)}")
print(f"Test: {len(test_x)} - {len(test_y1)} - {len(test_y2)}")
train_dataset = tf_dataset(train_x, train_y1, train_y2, batch=batch_size)
valid_dataset = tf_dataset(valid_x, valid_y1, valid_y2, batch=batch_size)
""" Model """
model = build_unet((H, W, 3))
metrics = [dice_coef, iou, Recall(), Precision()]
model.compile(loss=dice_loss, optimizer=Adam(lr), metrics=metrics)
callbacks = [
ModelCheckpoint(model_path, verbose=1, save_best_only=True),
ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=1e-7, verbose=1),
CSVLogger(csv_path)
]
model.fit(
train_dataset,
epochs=num_epochs,
validation_data=valid_dataset,
callbacks=callbacks
)
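
# After training, the best checkpoint can be reloaded for evaluation or
# inference. A minimal sketch, assuming the custom objects from the metrics
# module above (the prediction/threshold steps are illustrative):
#
#   model = tf.keras.models.load_model(
#       model_path,
#       custom_objects={"dice_loss": dice_loss, "dice_coef": dice_coef, "iou": iou},
#   )
#   pred = model.predict(np.expand_dims(read_image(test_x[0]), axis=0))[0]
#   cv2.imwrite("files/prediction.png", (pred > 0.5).astype(np.uint8) * 255)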