Skip to content

Commit f0cbcbb

Browse files
committed
refactor: simplify amp. by iter. RDP explanation
1 parent 7a83f2f commit f0cbcbb

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

tutorials/mnist_lr_tutorial.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
8686
learning_rate=FLAGS.learning_rate)
8787
opt_loss = vector_loss
8888
else:
89-
optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
89+
optimizer = tf.train.GradientDescentOptimizer(
90+
learning_rate=FLAGS.learning_rate)
9091
opt_loss = scalar_loss
9192
global_step = tf.train.get_global_step()
9293
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
@@ -164,14 +165,17 @@ def print_privacy_guarantees(epochs, batch_size, samples, noise_multiplier):
164165
np.linspace(20, 100, num=81)])
165166
delta = 1e-5
166167
for p in (.5, .9, .99):
167-
steps = math.ceil(steps_per_epoch * p) # Steps in the last epoch.
168-
coef = 2 * (noise_multiplier * batch_size)**-2 * (
169-
# Accounting for privacy loss
170-
(epochs - 1) / steps_per_epoch + # ... from all-but-last epochs
171-
1 / (steps_per_epoch - steps + 1)) # ... due to the last epoch
168+
steps = math.ceil(steps_per_epoch * p) # Steps in the last epoch
169+
# compute rdp coeff for a single differing batch
170+
coeff = 2 * (noise_multiplier * batch_size)**-2
171+
# amplification by iteration from all-but-last-epochs
172+
amp_part1 = (epochs - 1) / steps_per_epoch
173+
# min amplification by iteration for at least p items due to last epoch
174+
amp_part2 = 1 / (steps_per_epoch - steps + 1)
175+
# compute rdp of output model
176+
rdp = [coeff * order * (amp_part1 + amp_part2) for order in orders]
172177
# Using RDP accountant to compute eps. Doing computation analytically is
173178
# an option.
174-
rdp = [order * coef for order in orders]
175179
eps, _, _ = get_privacy_spent(orders, rdp, target_delta=delta)
176180
print('\t{:g}% enjoy at least ({:.2f}, {})-DP'.format(
177181
p * 100, eps, delta))

0 commit comments

Comments
 (0)