Skip to content

Commit f37a0a0

Browse files
committed
add dropout to transfer learning architecture plugin
1 parent 559b844 commit f37a0a0

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/ensemble-transfer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def create_model(model_settings, model_parameters, io=sys.stdout):
178178
ninput_tics = model.input.shape[1]
179179
model.trainable = trainable[i]
180180
truncated_model = tf.keras.Model(inputs=model.inputs,
181-
outputs=[model.output[0][1:splice_layers[i]],
181+
outputs=[model.output[0][:splice_layers[i]],
182182
model.output[0][splice_layers[i]]])
183183
models.append(truncated_model)
184184
audio_tic_rates.append(audio_tic_rate)
@@ -195,11 +195,10 @@ def create_model(model_settings, model_parameters, io=sys.stdout):
195195
x = x(inputs)
196196
x = m(x)
197197
hidden_layers.extend(x[0])
198-
lowerlegs.append(tf.keras.Model(inputs=inputs, outputs=x[1]))
198+
lowerlegs.append(x[1])
199199

200200
upperlegs = []
201-
for leg in lowerlegs:
202-
x = leg.output
201+
for x in lowerlegs:
203202
for (t,f,m) in conv_layers:
204203
x = ReLU()(x)
205204
x = Conv2D(m, (t,f))(x)
@@ -208,6 +207,9 @@ def create_model(model_settings, model_parameters, io=sys.stdout):
208207

209208
x = Concatenate()([Reshape((1,-1))(x) for x in upperlegs])
210209

210+
if dropout>0:
211+
x = Dropout(dropout)(x)
212+
211213
for idense, nunits in enumerate(dense_layers+[model_settings['nlabels']]):
212214
if idense>0:
213215
x = ReLU()(x)

0 commit comments

Comments
 (0)