-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathunet.py
More file actions
137 lines (122 loc) · 4.9 KB
/
unet.py
File metadata and controls
137 lines (122 loc) · 4.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.models import Model, load_model, save_model
from tensorflow.keras.layers import (
Input,
Activation,
BatchNormalization,
Dropout,
Lambda,
Conv2D,
Conv2DTranspose,
MaxPooling2D,
concatenate,
)
from tensorflow.keras import backend as K
def _double_conv(x, filters):
    """Apply the repeated 3x3 double-convolution unit used throughout the net.

    Layer ordering mirrors the original hand-written graph exactly:
    Conv2D -> ReLU -> Conv2D -> BatchNorm -> ReLU. Note the first conv has
    no BatchNorm, and the second conv's *raw* (pre-BN, pre-activation)
    output is what the encoder exposes as its skip connection.

    Parameters:
    -----------
    x : tensor
        Input feature map.
    filters : int
        Number of filters for both Conv2D layers.

    Returns:
    --------
    (skip, out) : tuple of tensors
        `skip` is the second Conv2D output before normalization (used as
        the encoder skip connection); `out` is the BN + ReLU result.
    """
    conv = Conv2D(filters=filters, kernel_size=(3, 3), padding="same")(x)
    act = Activation("relu")(conv)
    conv = Conv2D(filters=filters, kernel_size=(3, 3), padding="same")(act)
    norm = BatchNormalization(axis=3)(conv)
    out = Activation("relu")(norm)
    return conv, out


def _up_block(x, skip, filters):
    """One decoder step: 2x2 transposed conv upsample, concat skip, double conv.

    The concatenation along axis 3 (channels) is the U-Net "gray arrow"
    skip connection that lets the decoder reuse high-resolution encoder
    features and improves gradient flow during training.

    Parameters:
    -----------
    x : tensor
        Decoder feature map to upsample.
    skip : tensor
        Matching-resolution encoder feature map to concatenate.
    filters : int
        Filter count for the transposed conv and the double conv.

    Returns:
    --------
    tensor
        Activated output of the double convolution.
    """
    up = Conv2DTranspose(
        filters, kernel_size=(2, 2), strides=(2, 2), padding="same"
    )(x)
    merged = concatenate([up, skip], axis=3)
    _, out = _double_conv(merged, filters)
    return out


def unet(input_size=(256, 256, 3)):
    """
    Create and return a U-Net model. U-Net is a convolutional neural network
    designed for fast and precise segmentation of images. It consists of a
    contracting (downsampling) path and an expansive (upsampling) path,
    which gives it a U-shaped architecture.

    Parameters:
    -----------
    input_size : tuple of int
        The size of the input images as (height, width, channels).
        Default is (256, 256, 3).

    Returns:
    --------
    model : keras.models.Model
        The constructed U-Net model with a single-channel sigmoid output
        (binary segmentation mask, same spatial size as the input).
    """
    inputs = Input(input_size)

    # Contracting / encoder path: double conv then 2x2 max-pool, doubling
    # the filter count at each level (64 -> 128 -> 256 -> 512).
    # NOTE(review): skips tap the conv output *before* BN/ReLU, exactly as
    # the original code did; classic U-Net concatenates activated features
    # instead — confirm whether this pre-activation tap is intentional.
    skip1, act1 = _double_conv(inputs, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(act1)
    skip2, act2 = _double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(act2)
    skip3, act3 = _double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(act3)
    skip4, act4 = _double_conv(pool3, 512)
    pool4 = MaxPooling2D(pool_size=(2, 2))(act4)

    # Bottleneck: deepest double conv; its skip output is unused.
    _, bottleneck = _double_conv(pool4, 1024)

    # Expansive / decoder path: after every concatenation we again apply
    # two consecutive convolutions so the model can learn to assemble a
    # more precise output. Filter counts mirror the encoder in reverse.
    dec4 = _up_block(bottleneck, skip4, 512)
    dec3 = _up_block(dec4, skip3, 256)
    dec2 = _up_block(dec3, skip2, 128)
    dec1 = _up_block(dec2, skip1, 64)

    # 1x1 conv projects to a single sigmoid channel per pixel.
    outputs = Conv2D(filters=1, kernel_size=(1, 1), activation="sigmoid")(dec1)
    return Model(inputs=[inputs], outputs=[outputs])