-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathintro.py
More file actions
172 lines (146 loc) · 5.29 KB
/
Copy pathintro.py
File metadata and controls
172 lines (146 loc) · 5.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import argparse
import json
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def init_flags():
    """Parse command-line options into the module-level ``FLAGS`` namespace.

    Options not declared here are tolerated: ``parse_known_args`` returns
    them separately and they are discarded.
    """
    global FLAGS
    arg_parser = argparse.ArgumentParser()
    # Filesystem locations.
    arg_parser.add_argument("--datadir", default="data")
    arg_parser.add_argument("--rundir", default=".")
    # Numeric training hyperparameters: (option, converter, default).
    for option, converter, fallback in (("--batch_size", int, 100),
                                        ("--epochs", int, 10),
                                        ("--lr", float, 0.5)):
        arg_parser.add_argument(option, type=converter, default=fallback)
    # Mode switches: --prepare only loads data, --test evaluates an export.
    arg_parser.add_argument("--prepare", dest='just_data', action="store_true")
    arg_parser.add_argument("--test", action="store_true")
    parsed, _ignored = arg_parser.parse_known_args()
    FLAGS = parsed
def init_data():
    """Load the MNIST datasets from ``FLAGS.datadir`` into the global ``mnist``.

    Labels are requested one-hot encoded, matching the [None, 10] label
    placeholder used by the training graph.
    """
    global mnist
    data_dir = FLAGS.datadir
    mnist = input_data.read_data_sets(data_dir, one_hot=True)
def init_train():
    """Build the full training graph and open a session for it.

    The steps run in a fixed order because each one reads module globals
    (x, y, y_, W, b, loss, accuracy) created by the steps before it, and
    the session's variable initializer must come after all variables exist.
    """
    pipeline = (
        init_model,
        init_train_op,
        init_eval_op,
        init_summaries,
        init_collections,
        init_session,
    )
    for stage in pipeline:
        stage()
def init_model():
    """Define the softmax-regression model as module globals.

    Creates:
        x: float32 placeholder of shape [None, 784] (flattened 28x28 images).
        W, b: zero-initialized weight [784, 10] and bias [10] variables.
        logits: pre-softmax scores ``x @ W + b``; kept so the loss can be
            computed directly from them (numerically stable).
        y: ``softmax(logits)``, the predicted class probabilities.
    """
    global x, y, W, b, logits
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(x, W) + b
    y = tf.nn.softmax(logits)


def init_train_op():
    """Define the label placeholder, cross-entropy loss and training op.

    Creates globals ``y_`` (one-hot labels, [None, 10]), ``loss`` (mean
    cross-entropy over the batch) and ``train_op`` (plain gradient descent
    with learning rate ``FLAGS.lr``).

    The loss is computed from ``logits`` with
    ``softmax_cross_entropy_with_logits``. The same quantity written as
    ``-sum(y_ * log(softmax(...)))`` evaluates ``log(0) = -inf`` once a
    probability underflows, turning the loss and gradients into NaN.
    """
    global y_, loss, train_op
    y_ = tf.placeholder(tf.float32, [None, 10])
    per_example = tf.nn.softmax_cross_entropy_with_logits(
        labels=y_, logits=logits)
    loss = tf.reduce_mean(per_example)
    train_op = tf.train.GradientDescentOptimizer(FLAGS.lr).minimize(loss)
def init_eval_op():
    """Define the global ``accuracy`` tensor.

    Accuracy is the fraction of rows where the highest-scoring predicted
    class (argmax of ``y``) equals the labelled class (argmax of ``y_``).
    """
    global accuracy
    predicted_class = tf.argmax(y, 1)
    true_class = tf.argmax(y_, 1)
    hits = tf.cast(tf.equal(predicted_class, true_class), tf.float32)
    accuracy = tf.reduce_mean(hits)
def init_summaries():
    """Register all TensorBoard summaries and create the summary writers.

    Writers are created last so that ``merge_all`` inside
    ``init_summary_writers`` sees every summary registered here.
    """
    init_inputs_summary()
    for variable, scope_name in ((W, "weights"), (b, "biases")):
        init_variable_summaries(variable, scope_name)
    init_op_summaries()
    init_summary_writers()
def init_inputs_summary():
    """Add an image summary showing up to 10 of the fed input images.

    Flat 784-value rows of ``x`` are reshaped to 28x28 single-channel images.
    """
    as_images = tf.reshape(x, [-1, 28, 28, 1])
    tf.summary.image("inputs", as_images, max_outputs=10)
def init_variable_summaries(var, name):
    """Add statistics summaries for ``var`` under the name scope ``name``.

    Records scalar mean, standard deviation (population form,
    sqrt(mean((var - mean)^2))), max and min, plus a histogram of values.
    NOTE(review): the source's indentation was lost; the histogram is
    assumed to sit inside the name scope like the scalars -- confirm.
    """
    with tf.name_scope(name):
        center = tf.reduce_mean(var)
        tf.summary.scalar("mean", center)
        spread = tf.sqrt(tf.reduce_mean(tf.square(var - center)))
        tf.summary.scalar("stddev", spread)
        for label, reducer in (("max", tf.reduce_max), ("min", tf.reduce_min)):
            tf.summary.scalar(label, reducer(var))
        tf.summary.histogram(name, var)
def init_op_summaries():
    """Add scalar summaries for the training loss and the accuracy."""
    for tag, tensor in (("loss", loss), ("acc", accuracy)):
        tf.summary.scalar(tag, tensor)
def init_summary_writers():
    """Merge all summaries and open the train/validate event writers.

    Sets globals ``summaries`` (merged summary op), ``train_writer``
    (``<rundir>/train``, also records the graph) and ``validate_writer``
    (``<rundir>/validate``).
    """
    global summaries, train_writer, validate_writer
    summaries = tf.summary.merge_all()
    train_dir = FLAGS.rundir + "/train"
    validate_dir = FLAGS.rundir + "/validate"
    train_writer = tf.summary.FileWriter(train_dir, tf.get_default_graph())
    validate_writer = tf.summary.FileWriter(validate_dir)
def init_collections():
    """Record key tensor names in graph collections so they survive export.

    Test mode looks up the "x", "y_" and "acc" collections by name
    (see ``tensor_by_collection_name``) after importing the meta graph.
    The JSON "inputs"/"outputs" entries are presumably for an external
    serving/tooling consumer -- not used elsewhere in this file.
    """
    entries = (
        ("inputs", json.dumps({"image": x.name})),
        ("outputs", json.dumps({"prediction": y.name})),
        ("x", x.name),
        ("y_", y_.name),
        ("acc", accuracy.name),
    )
    for collection, value in entries:
        tf.add_to_collection(collection, value)
def init_session():
    """Open the global ``sess`` and initialize all global variables in it."""
    global sess
    sess = tf.Session()
    initializer = tf.global_variables_initializer()
    sess.run(initializer)
def train():
    """Run mini-batch training for ``FLAGS.epochs`` epochs.

    One epoch is ``num_examples // batch_size`` steps. After each step,
    accuracy may be logged and a checkpoint may be written. A final
    checkpoint is always saved when the loop ends.
    """
    steps_per_epoch = mnist.train.num_examples // FLAGS.batch_size
    total_steps = steps_per_epoch * FLAGS.epochs
    for current in range(total_steps):
        batch_images, batch_labels = mnist.train.next_batch(FLAGS.batch_size)
        feed = {x: batch_images, y_: batch_labels}
        sess.run(train_op, feed)
        maybe_log_accuracy(current, feed)
        maybe_save_model(current)
    save_model()
def maybe_log_accuracy(step, last_training_batch):
    """Every 100 steps, log accuracy on the last batch and the validation set.

    Args:
        step: current global training step.
        last_training_batch: feed dict of the batch just trained on.
    """
    if step % 100 != 0:
        return
    evaluate(step, last_training_batch, train_writer, "training")
    full_validation = {x: mnist.validation.images, y_: mnist.validation.labels}
    evaluate(step, full_validation, validate_writer, "validate")
def evaluate(step, data, writer, name):
    """Compute accuracy and summaries for ``data`` and log them.

    Args:
        step: step number attached to the written summary.
        data: feed dict for ``x`` and ``y_``.
        writer: FileWriter that receives the merged summary.
        name: label used in the printed line.
    """
    acc_value, summary_proto = sess.run([accuracy, summaries], data)
    writer.add_summary(summary_proto, step)
    writer.flush()
    line = "Step %i: %s=%f" % (step, name, acc_value)
    print(line)
def maybe_save_model(step):
    """Save a checkpoint at every epoch boundary except step 0.

    An epoch is ``mnist.train.num_examples // FLAGS.batch_size`` steps, the
    same integer count ``train()`` uses. Integer division is required: with
    true division the epoch length is a non-integer float whenever
    ``num_examples`` is not a multiple of ``batch_size``. An integer step
    then rarely or never divides it evenly, so per-epoch checkpoints are
    silently skipped. The length is clamped to at least 1 to avoid a modulo
    by zero when ``batch_size`` exceeds ``num_examples``.
    """
    epoch_step = max(1, mnist.train.num_examples // FLAGS.batch_size)
    if step != 0 and step % epoch_step == 0:
        save_model()
def save_model():
    """Checkpoint the current session to ``<rundir>/model/export``.

    The ``tf.train.Saver`` is created on the first call and cached on this
    function. Every ``Saver()`` construction adds a fresh set of save/restore
    ops to the graph, so building one per call (once per epoch plus the
    final save) kept growing the graph. The first call happens after the
    training graph is fully built, so the cached saver covers all variables.
    """
    print("Saving trained model")
    tf.gfile.MakeDirs(FLAGS.rundir + "/model")
    saver = getattr(save_model, "_saver", None)
    if saver is None:
        saver = tf.train.Saver()
        save_model._saver = saver
    saver.save(sess, FLAGS.rundir + "/model/export")
def init_test():
    """Prepare test mode: open a session, restore the export, open a writer.

    The session must exist first because restoring the export runs in it.
    """
    for stage in (init_session, init_exported_collections, init_test_writer):
        stage()
def init_exported_collections():
    """Import the exported meta graph, restore its weights, rebind tensors.

    Rebinds the globals ``x``, ``y_`` and ``accuracy`` to the tensors named
    in the graph's "x", "y_" and "acc" collections.
    """
    global x, y_, accuracy
    restorer = tf.train.import_meta_graph(FLAGS.rundir + "/model/export.meta")
    restorer.restore(sess, FLAGS.rundir + "/model/export")
    x, y_, accuracy = (tensor_by_collection_name(key)
                       for key in ("x", "y_", "acc"))
def tensor_by_collection_name(name):
    """Return the tensor whose name is the first entry of collection ``name``.

    ``init_collections`` stores tensor names as ``str``. After a round trip
    through an exported meta graph they can come back as ``bytes``, so both
    forms are accepted and ``bytes`` are decoded as UTF-8.

    Raises:
        ValueError: if the collection is missing or empty in the graph.
    """
    values = tf.get_collection(name)
    if not values:
        raise ValueError(
            "graph collection %r is empty; was the model exported with it?"
            % name)
    tensor_name = values[0]
    if isinstance(tensor_name, bytes):
        tensor_name = tensor_name.decode("UTF-8")
    return sess.graph.get_tensor_by_name(tensor_name)
def init_test_writer():
    """Merge the (imported) summaries and open an event writer at ``rundir``.

    Sets the globals ``summaries`` and ``writer``.
    """
    global summaries, writer
    merged = tf.summary.merge_all()
    summaries = merged
    output_dir = FLAGS.rundir
    writer = tf.summary.FileWriter(output_dir)
def test():
    """Evaluate the restored model on the MNIST test set and log the result."""
    feed = {x: mnist.test.images, y_: mnist.test.labels}
    acc_value, merged_summary = sess.run([accuracy, summaries], feed)
    writer.add_summary(merged_summary)
    writer.flush()
    print("Test accuracy=%f" % acc_value)
if __name__ == "__main__":
    # Always parse flags and load data; --prepare stops there, --test
    # evaluates a previously exported model, otherwise train a new one.
    init_flags()
    init_data()
    if not FLAGS.just_data:
        if FLAGS.test:
            init_test()
            test()
        else:
            init_train()
            train()