diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
index 6d53a765..dc99fd4c 100644
--- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
+++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
@@ -239,7 +239,11 @@ def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
           else:
             num_microbatches = self._num_microbatches
           microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [num_microbatches, -1]), axis=1)
+              tf.reshape(
+                  loss,
+                  [num_microbatches,
+                   tf.shape(loss)[0] // num_microbatches]),
+              axis=1)
 
       if callable(var_list):
         var_list = var_list()
@@ -250,7 +254,11 @@ def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
         else:
           num_microbatches = self._num_microbatches
         microbatch_losses = tf.reduce_mean(
-            tf.reshape(loss, [num_microbatches, -1]), axis=1)
+            tf.reshape(
+                loss,
+                [num_microbatches,
+                 tf.shape(loss)[0] // num_microbatches]),
+            axis=1)
 
       var_list = tf.nest.flatten(var_list)
 
@@ -294,7 +302,10 @@ def get_gradients(self, loss, params):
     # This code mostly follows the logic in the original DPOptimizerClass
     # in dp_optimizer.py, except that this returns only the gradients,
     # not the gradients and variables.
-    microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1])
+    microbatch_losses = tf.reshape(
+        loss,
+        [self._num_microbatches,
+         tf.shape(loss)[0] // self._num_microbatches])
 
     sample_params = (
         self._dp_sum_query.derive_sample_params(self._global_state))