33from tensorflow .keras .layers import GlobalAveragePooling2D
44from tensorflow .keras import Model
55from tensorflow .keras .regularizers import l2
6- from tensorflow .keras .models import load_model
76from tensorflow .keras .utils import get_file
7+ from keras import layers
88
99
1010URL = 'https://github.com/oarriaga/altamira-data/releases/download/v0.6/'
@@ -84,6 +84,112 @@ def build_xception(
8484 return model
8585
8686
87+ def build_minixception (input_shape , num_classes , l2_reg = 0.01 ):
88+ """Function for instantiating an Mini-Xception model.
89+
90+ # Arguments
91+ input_shape: List corresponding to the input shape
92+ of the model.
93+ num_classes: Integer.
94+ l2_reg. Float. L2 regularization used
95+ in the convolutional kernels.
96+
97+ # Returns
98+ Tensorflow-Keras model.
99+ """
100+
101+ regularization = l2 (l2_reg )
102+
103+ # base
104+ img_input = Input (input_shape )
105+ x = Conv2D (5 , (3 , 3 ), strides = (1 , 1 ), kernel_regularizer = regularization ,
106+ use_bias = False )(img_input )
107+ x = BatchNormalization ()(x )
108+ x = Activation ('relu' )(x )
109+ x = Conv2D (8 , (3 , 3 ), strides = (1 , 1 ), kernel_regularizer = regularization ,
110+ use_bias = False )(x )
111+ x = BatchNormalization ()(x )
112+ x = Activation ('relu' )(x )
113+
114+ # module 1
115+ residual = Conv2D (16 , (1 , 1 ), strides = (2 , 2 ),
116+ padding = 'same' , use_bias = False )(x )
117+ residual = BatchNormalization ()(residual )
118+
119+ x = SeparableConv2D (16 , (3 , 3 ), padding = 'same' ,
120+ depthwise_regularizer = regularization ,
121+ use_bias = False )(x )
122+ x = BatchNormalization ()(x )
123+ x = Activation ('relu' )(x )
124+ x = SeparableConv2D (16 , (3 , 3 ), padding = 'same' ,
125+ depthwise_regularizer = regularization ,
126+ use_bias = False )(x )
127+ x = BatchNormalization ()(x )
128+
129+ x = MaxPooling2D ((3 , 3 ), strides = (2 , 2 ), padding = 'same' )(x )
130+ x = layers .add ([x , residual ])
131+
132+ # module 2
133+ residual = Conv2D (32 , (1 , 1 ), strides = (2 , 2 ),
134+ padding = 'same' , use_bias = False )(x )
135+ residual = BatchNormalization ()(residual )
136+
137+ x = SeparableConv2D (32 , (3 , 3 ), padding = 'same' ,
138+ depthwise_regularizer = regularization ,
139+ use_bias = False )(x )
140+ x = BatchNormalization ()(x )
141+ x = Activation ('relu' )(x )
142+ x = SeparableConv2D (32 , (3 , 3 ), padding = 'same' ,
143+ depthwise_regularizer = regularization ,
144+ use_bias = False )(x )
145+ x = BatchNormalization ()(x )
146+
147+ x = MaxPooling2D ((3 , 3 ), strides = (2 , 2 ), padding = 'same' )(x )
148+ x = layers .add ([x , residual ])
149+
150+ # module 3
151+ residual = Conv2D (64 , (1 , 1 ), strides = (2 , 2 ),
152+ padding = 'same' , use_bias = False )(x )
153+ residual = BatchNormalization ()(residual )
154+
155+ x = SeparableConv2D (64 , (3 , 3 ), padding = 'same' ,
156+ depthwise_regularizer = regularization ,
157+ use_bias = False )(x )
158+ x = BatchNormalization ()(x )
159+ x = Activation ('relu' )(x )
160+ x = SeparableConv2D (64 , (3 , 3 ), padding = 'same' ,
161+ depthwise_regularizer = regularization ,
162+ use_bias = False )(x )
163+ x = BatchNormalization ()(x )
164+
165+ x = MaxPooling2D ((3 , 3 ), strides = (2 , 2 ), padding = 'same' )(x )
166+ x = layers .add ([x , residual ])
167+
168+ # module 4
169+ residual = Conv2D (128 , (1 , 1 ), strides = (1 , 1 ),
170+ padding = 'same' , use_bias = False )(x )
171+ residual = BatchNormalization ()(residual )
172+
173+ x = SeparableConv2D (128 , (3 , 3 ), padding = 'same' ,
174+ depthwise_regularizer = regularization ,
175+ use_bias = False )(x )
176+ x = BatchNormalization ()(x )
177+ x = Activation ('relu' )(x )
178+ x = SeparableConv2D (128 , (3 , 3 ), padding = 'same' ,
179+ depthwise_regularizer = regularization ,
180+ use_bias = False )(x )
181+ x = BatchNormalization ()(x )
182+
183+ x = layers .add ([x , residual ])
184+
185+ x = Conv2D (num_classes , (3 , 3 ), padding = 'same' )(x )
186+ x = GlobalAveragePooling2D ()(x )
187+ output = Activation ('softmax' , name = 'predictions' )(x )
188+
189+ model = Model (img_input , output )
190+ return model
191+
192+
87193def MiniXception (input_shape , num_classes , weights = None ):
88194 """Build MiniXception (see references).
89195
@@ -101,9 +207,10 @@ def MiniXception(input_shape, num_classes, weights=None):
101207 Gender Classification](https://arxiv.org/abs/1710.07557)
102208 """
103209 if weights == 'FER' :
104- filename = 'fer2013_mini_XCEPTION.119-0.65. hdf5'
210+ filename = 'fer2013_mini_XCEPTION.hdf5'
105211 path = get_file (filename , URL + filename , cache_subdir = 'paz/models' )
106- model = load_model (path )
212+ model = build_minixception (input_shape , num_classes )
213+ model .load_weights (path )
107214 else :
108215 stem_kernels = [32 , 64 ]
109216 block_data = [128 , 128 , 256 , 256 , 512 , 512 , 1024 ]
0 commit comments