Commit 9104a53

fix: migrating from compat to keras regularizer
1 parent 87c01eb

File tree

1 file changed: +5 -10 lines changed


tutorials/mnist_lr_tutorial.py

@@ -39,11 +39,6 @@
 from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
 from tensorflow_privacy.privacy.optimizers import dp_optimizer
 
-if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
-  GradientDescentOptimizer = tf.train.GradientDescentOptimizer
-else:
-  GradientDescentOptimizer = tf.optimizers.SGD  # pylint: disable=invalid-name
-
 FLAGS = flags.FLAGS
 
 flags.DEFINE_boolean(
@@ -66,10 +61,10 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
   logits = tf.layers.dense(
       inputs=input_layer,
       units=nclasses,
-      kernel_regularizer=tf.contrib.layers.l2_regularizer(
-          scale=FLAGS.regularizer),
-      bias_regularizer=tf.contrib.layers.l2_regularizer(
-          scale=FLAGS.regularizer))
+      kernel_regularizer=tf.keras.regularizers.l2(
+          l=FLAGS.regularizer),
+      bias_regularizer=tf.keras.regularizers.l2(
+          l=FLAGS.regularizer))
 
   # Calculate loss as a vector (to support microbatches in DP-SGD).
   vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -91,7 +86,7 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
         learning_rate=FLAGS.learning_rate)
     opt_loss = vector_loss
   else:
-    optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
+    optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
     opt_loss = scalar_loss
   global_step = tf.train.get_global_step()
   train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
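
For context, a minimal, self-contained sketch of the regularizer swap this commit makes. It assumes a TF 1.x runtime where tf.contrib still exists alongside tf.keras; coeff is a hypothetical stand-in for FLAGS.regularizer. One detail worth verifying against the installed version: tf.contrib.layers.l2_regularizer is built on tf.nn.l2_loss, which carries a factor of 1/2, while the Keras penalty is l * sum(w**2), so the same coefficient may yield a penalty twice as large after the migration.

    import tensorflow as tf  # assumes TF 1.x, where tf.contrib and tf.keras coexist

    coeff = 0.01  # hypothetical stand-in for FLAGS.regularizer
    weights = tf.constant([[1.0, -2.0], [3.0, 0.5]])

    # Old API (removed along with tf.contrib in TF 2.x):
    # penalty = scale * tf.nn.l2_loss(w) = scale * sum(w**2) / 2
    old_penalty = tf.contrib.layers.l2_regularizer(scale=coeff)(weights)

    # New API used by this commit: penalty = l * sum(w**2), no 1/2 factor
    new_penalty = tf.keras.regularizers.l2(l=coeff)(weights)

    with tf.Session() as sess:
        old_val, new_val = sess.run([old_penalty, new_penalty])
        print(old_val, new_val)  # expectation (worth checking): new_val ~ 2 * old_val

If matching the pre-migration loss exactly matters, halving the coefficient passed to the Keras regularizer would compensate for the missing 1/2 factor.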
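Similarly, a sketch of the non-DP branch after this commit: plain TF 1.x SGD with the usual global-step bookkeeping. The learning rate and loss are hypothetical stand-ins; the sketch uses tf.train.get_or_create_global_step() so it runs standalone, whereas the tutorial relies on the Estimator having created the global step before tf.train.get_global_step() is called.

    import tensorflow as tf  # TF 1.x graph-mode APIs, as in the tutorial

    learning_rate = 0.1           # hypothetical stand-in for FLAGS.learning_rate
    var = tf.Variable(1.0)
    scalar_loss = tf.square(var)  # hypothetical stand-in for the model's scalar loss

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    global_step = tf.train.get_or_create_global_step()  # tutorial uses get_global_step()
    train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op)  # one SGD step: var <- var - lr * d(loss)/d(var)
        print(sess.run([var, global_step]))  # expect [0.8, 1]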
