Skip to content

Commit 05b234f

Browse files
committed
Update TFRecord example.
1 parent 90f76da commit 05b234f

File tree

1 file changed

+27
-28
lines changed

1 file changed

+27
-28
lines changed

examples/mnist_tfrecord.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -33,40 +33,33 @@
3333
tensors, save the model weights, and then evaluate the
3434
model using the numpy based Keras API.
3535
36-
Gets to 99.1% test accuracy after 78 epochs
37-
(there is still a lot of margin for parameter tuning).
36+
Gets to ~99.1% validation accuracy after 5 epochs
37+
(high variance from run to run: 98.9-99.3).
3838
'''
39-
import os
40-
import copy
41-
import time
42-
4339
import numpy as np
4440

4541
import tensorflow as tf
42+
import keras
4643
from keras import backend as K
47-
from keras.models import Model
4844
from keras import layers
49-
from keras import objectives
50-
from keras.utils import np_utils
51-
from keras import objectives
5245

5346
from tensorflow.contrib.learn.python.learn.datasets import mnist
5447

5548
if K.backend() != 'tensorflow':
5649
raise RuntimeError('This example can only run with the '
57-
'TensorFlow backend for the time being, '
50+
'TensorFlow backend, '
5851
'because it requires TFRecords, which '
5952
'are not supported on other platforms.')
6053

6154

6255
def cnn_layers(x_train_input):
6356
x = layers.Conv2D(32, (3, 3),
6457
activation='relu', padding='valid')(x_train_input)
58+
x = layers.MaxPooling2D(pool_size=(2, 2))(x)
6559
x = layers.Conv2D(64, (3, 3), activation='relu')(x)
6660
x = layers.MaxPooling2D(pool_size=(2, 2))(x)
67-
x = layers.Dropout(0.25)(x)
6861
x = layers.Flatten()(x)
69-
x = layers.Dense(128, activation='relu')(x)
62+
x = layers.Dense(512, activation='relu')(x)
7063
x = layers.Dropout(0.5)(x)
7164
x_train_out = layers.Dense(classes,
7265
activation='softmax',
@@ -78,7 +71,7 @@ def cnn_layers(x_train_input):
7871
batch_size = 128
7972
batch_shape = (batch_size, 28, 28, 1)
8073
steps_per_epoch = 469
81-
epochs = 78
74+
epochs = 5
8275
classes = 10
8376

8477
# The capacity variable controls the maximum queue size
@@ -120,39 +113,45 @@ def cnn_layers(x_train_input):
120113

121114
x_train_input = layers.Input(tensor=x_train_batch, batch_shape=x_batch_shape)
122115
x_train_out = cnn_layers(x_train_input)
123-
train_model = Model(inputs=x_train_input, outputs=x_train_out)
124-
125-
cce = objectives.categorical_crossentropy(y_train_batch, x_train_out)
126-
train_model.add_loss(cce)
127-
128-
# Do not pass the loss directly to model.compile()
129-
# because it is not yet supported for Input Tensors.
130-
train_model.compile(optimizer='rmsprop',
131-
loss=None,
132-
metrics=['accuracy'])
116+
train_model = keras.models.Model(inputs=x_train_input, outputs=x_train_out)
117+
118+
# Pass the target tensor `y_train_batch` to `compile`
119+
# via the `target_tensors` keyword argument:
120+
train_model.compile(optimizer=keras.optimizers.RMSprop(lr=2e-3, decay=1e-5),
121+
loss='categorical_crossentropy',
122+
metrics=['accuracy'],
123+
target_tensors=[y_train_batch])
133124
train_model.summary()
134125

126+
# Fit the model using data from the TFRecord data tensors.
135127
coord = tf.train.Coordinator()
136128
threads = tf.train.start_queue_runners(sess, coord)
129+
137130
train_model.fit(epochs=epochs,
138131
steps_per_epoch=steps_per_epoch)
139132

133+
# Save the model weights.
140134
train_model.save_weights('saved_wt.h5')
141135

136+
# Clean up the TF session.
142137
coord.request_stop()
143138
coord.join(threads)
144139
K.clear_session()
145140

146141
# Second Session to test loading trained model without tensors
147142
x_test = np.reshape(data.validation.images, (data.validation.images.shape[0], 28, 28, 1))
148143
y_test = data.validation.labels
149-
x_test_inp = layers.Input(batch_shape=(None,) + (x_test.shape[1:]))
144+
x_test_inp = layers.Input(shape=(x_test.shape[1:]))
150145
test_out = cnn_layers(x_test_inp)
151-
test_model = Model(inputs=x_test_inp, outputs=test_out)
146+
test_model = keras.models.Model(inputs=x_test_inp, outputs=test_out)
152147

153148
test_model.load_weights('saved_wt.h5')
154-
test_model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
149+
test_model.compile(optimizer='rmsprop',
150+
loss='categorical_crossentropy',
151+
metrics=['accuracy'])
155152
test_model.summary()
156153

157-
loss, acc = test_model.evaluate(x_test, np_utils.to_categorical(y_test), classes)
154+
loss, acc = test_model.evaluate(x_test,
155+
keras.utils.to_categorical(y_test),
156+
classes)
158157
print('\nTest accuracy: {0}'.format(acc))

0 commit comments

Comments
 (0)