-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathchild_net.py
executable file
·75 lines (62 loc) · 2.57 KB
/
child_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import tensorflow as tf
import keras
from keras import models, layers, datasets, utils, backend, optimizers, initializers
from keras.layers import Dense, Dropout, Activation, Flatten, Input
from keras.layers import Conv2D, MaxPooling2D
# Constructing Child Networks
class ChildNetwork:
    """Small CNN classifier that can be retrained on different training sets.

    The model is built once (inside the ``child`` TF variable scope) and
    compiled with RMSprop and categorical cross-entropy. ``reinitialize``
    swaps in a new training set and resets the network's variables so that a
    subsequent ``train`` call starts from scratch; ``evaluate`` scores the
    model on the fixed test set given at construction.
    """

    def __init__(self, x_train, y_train, x_test, y_test, sess, batch_size=200, epoch=10,
                 num_classes=10):
        """Build and compile the child network.

        Args:
            x_train, y_train: initial training inputs / targets. The model's
                input shape is taken from ``x_train.shape[1:]``.
            x_test, y_test: held-out data used by ``evaluate``.
            sess: TF1 session used to run variable initializers in
                ``reinitialize``; presumably the Keras backend session — TODO
                confirm with callers.
            batch_size: mini-batch size for ``train``.
            epoch: number of epochs per ``train`` call.
            num_classes: width of the softmax output. Targets must be one-hot
                with this many columns because the loss is
                ``categorical_crossentropy``.
        """
        self.x_train = x_train
        self.y_train = y_train
        self.x_test = x_test
        self.y_test = y_test
        self.opt = keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.epochs = epoch
        # NOTE(review): not read anywhere in this class (train() does no
        # augmentation); kept for callers that may inspect it.
        self.data_augmentation = True
        self.model = self.build_model()
        self.sess = sess
        # Compile once; the same optimizer instance is reused across retrains.
        self.model.compile(loss='categorical_crossentropy',
                           optimizer=self.opt,
                           metrics=['accuracy'])

    def reinitialize(self, x, y):
        """Replace the training set and reset the network to a fresh state.

        Re-runs the initializers of every global variable under the ``child``
        scope (the layer weights from ``build_model``) and of the optimizer's
        own variables. The optimizer's iteration counter and (after the first
        ``train``) its RMSprop accumulators are created outside the ``child``
        scope, so without this they would carry over. The learning-rate decay
        would then keep advancing, and stale moving averages would be reused
        on the next ``train``.

        Args:
            x: new training inputs.
            y: new training targets (one-hot, ``num_classes`` columns).
        """
        self.x_train = x
        self.y_train = y
        var_to_init = list(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='child'))
        seen = {id(v) for v in var_to_init}
        opt_vars = []
        iterations = getattr(self.opt, 'iterations', None)
        if iterations is not None:
            opt_vars.append(iterations)
        # Optimizer slot variables exist only once a training function was built.
        opt_vars.extend(getattr(self.opt, 'weights', None) or [])
        for var in opt_vars:
            if id(var) not in seen:
                seen.add(id(var))
                var_to_init.append(var)
        init_new_vars_op = tf.variables_initializer(var_to_init)
        self.sess.run(init_new_vars_op)

    def build_model(self):
        """Return the (uncompiled) Sequential CNN.

        Architecture follows the Keras cifar10_cnn example
        (https://keras.io/examples/cifar10_cnn/): two conv blocks
        (32 then 64 filters, each conv-relu-conv-relu-maxpool-dropout),
        then a 512-unit dense layer with dropout and a softmax over
        ``self.num_classes`` outputs. Layers are created inside the ``child``
        variable scope so ``reinitialize`` can find their variables.
        """
        with tf.variable_scope("child", reuse=tf.AUTO_REUSE):
            model = models.Sequential()
            model.add(Conv2D(32, (3, 3), padding='same',
                             input_shape=self.x_train.shape[1:]))
            model.add(Activation('relu'))
            model.add(Conv2D(32, (3, 3)))
            model.add(Activation('relu'))
            model.add(MaxPooling2D(pool_size=(2, 2)))
            model.add(Dropout(0.25))
            model.add(Conv2D(64, (3, 3), padding='same'))
            model.add(Activation('relu'))
            model.add(Conv2D(64, (3, 3)))
            model.add(Activation('relu'))
            model.add(MaxPooling2D(pool_size=(2, 2)))
            model.add(Dropout(0.25))
            model.add(Flatten())
            model.add(Dense(512))
            model.add(Activation('relu'))
            model.add(Dropout(0.5))
            model.add(Dense(self.num_classes))
            model.add(Activation('softmax'))
        return model

    def train(self):
        """Fit the model silently on the current training set (no validation)."""
        self.model.fit(self.x_train, self.y_train,
                       batch_size=self.batch_size,
                       epochs=self.epochs,
                       verbose=0)

    def evaluate(self):
        """Return ``[loss, accuracy]`` of the model on the test set."""
        return self.model.evaluate(self.x_test, self.y_test, verbose=0)