Add CategoricalCrossEntropy and Gaussian noise sampling to BoltOn implementation #82

Open · wants to merge 11 commits into base: master
109 changes: 106 additions & 3 deletions tensorflow_privacy/privacy/bolt_on/losses.py
@@ -22,8 +22,7 @@
from tensorflow.python.keras import losses
from tensorflow.python.keras.regularizers import L1L2
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.platform import tf_logging as logging

from absl import logging

class StrongConvexMixin: # pylint: disable=old-style-class
"""Strong Convex Mixin base class.
@@ -237,7 +236,7 @@ def __init__(self,
dtype: tf datatype to use for tensor conversions.
"""
if label_smoothing != 0:
logging.warning("The impact of label smoothing on privacy is unknown. "
logging.warning("WARNING: The impact of label smoothing on privacy is unknown. "
"Use label smoothing at your own risk as it may not "
"guarantee privacy.")

@@ -302,3 +301,107 @@ def kernel_regularizer(self):
set to 0.5 * reg_lambda.
"""
return L1L2(l2=self.reg_lambda/2)

class StrongConvexCategoricalCrossentropy(
losses.CategoricalCrossentropy,
StrongConvexMixin
):
"""Strongly Convex CategoricalCrossentropy with softmax layer loss using l2 weight regularization."""

def __init__(self,
reg_lambda,
c_arg,
radius_constant,
from_logits=True,
label_smoothing=0,
reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
dtype=tf.float32):
"""StrongConvexCategoricalCrossentropy class.

Args:
reg_lambda: Weight regularization constant
c_arg: Penalty parameter C of the loss term
radius_constant: constant defining the length of the radius
from_logits: True if the inputs are unscaled logits. False if they are
already scaled.
label_smoothing: amount of smoothing to perform on labels, i.e. a
relaxation of trust in the labels (1 -> 1-x, 0 -> 0+x). Note: the
impact of this parameter on privacy is not known, so the
default should be used.
reduction: reduction type to use. See super class
dtype: tf datatype to use for tensor conversions.
"""

if label_smoothing != 0:
logging.warning("WARNING: The impact of label smoothing on privacy is unknown. "
"Use label smoothing at your own risk as it may not "
"guarantee privacy.")

if reg_lambda <= 0:
raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
if c_arg <= 0:
raise ValueError("c: {0}, should be >= 0".format(c_arg))
if radius_constant <= 0:
raise ValueError("radius_constant: {0}, should be >= 0".format(
radius_constant
))
self.dtype = dtype
self.C = c_arg # pylint: disable=invalid-name
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
super(StrongConvexCategoricalCrossentropy, self).__init__(
reduction=reduction,
name="strongconvexcategoricalcrossentropy",
from_logits=from_logits,
label_smoothing=label_smoothing,
)
self.radius_constant = radius_constant


def call(self, y_true, y_pred):
"""Computes loss.

Args:
y_true: Ground truth values.
y_pred: The predicted values.

Returns:
Loss values per sample.
"""
loss = super(StrongConvexCategoricalCrossentropy, self).call(y_true, y_pred)
loss = loss * self.C
return loss


def radius(self):
"""See super class."""
return self.radius_constant / self.reg_lambda


def gamma(self):
"""See super class."""
return self.reg_lambda


def beta(self, class_weight):
"""See super class."""
max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self.C * max_class_weight + self.reg_lambda


def lipchitz_constant(self, class_weight):
"""See super class."""
max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self.C * max_class_weight + self.reg_lambda * self.radius()


def kernel_regularizer(self):
"""Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).

L2 regularization is required for this loss function to be strongly convex.

Returns:
The L2 regularizer layer for this loss function, with regularizer constant
set to 0.5 * reg_lambda.
"""
return L1L2(l2=self.reg_lambda / 2)
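
For reviewers, here is a minimal usage sketch of the new loss in isolation, assuming eager execution (as the tests exercising it do). The constants (reg_lambda = C = radius_constant = 1) mirror the test cases in losses_test.py and are illustrative, not recommended privacy settings.

```python
import tensorflow as tf

from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexCategoricalCrossentropy

# Constants mirror the 'normal' test case in losses_test.py; purely illustrative.
loss = StrongConvexCategoricalCrossentropy(reg_lambda=1.0,
                                           c_arg=1.0,
                                           radius_constant=1.0)

# Per-batch loss: C * softmax cross-entropy on the logits.
logits = tf.constant([[10000.0, 0.0]])
y_true = tf.constant([[0.0, 1.0]])
print(loss(y_true, logits).numpy())  # ~10000 for this mislabeled example

# Constants the BoltOn optimizer uses for clipping and noise calibration.
print(loss.radius().numpy())              # radius_constant / reg_lambda
print(loss.gamma().numpy())               # reg_lambda
print(loss.beta(class_weight=1).numpy())  # C * max_class_weight + reg_lambda
print(loss.kernel_regularizer().l2)       # 0.5 * reg_lambda
```
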
163 changes: 158 additions & 5 deletions tensorflow_privacy/privacy/bolt_on/losses_test.py
@@ -26,10 +26,10 @@
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.regularizers import L1L2
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexCategoricalCrossentropy
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexHuber
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin


@contextmanager
def captured_output():
"""Capture std_out and std_err within context."""
@@ -40,7 +40,9 @@ def captured_output():
yield sys.stdout, sys.stderr
finally:
sys.stdout, sys.stderr = old_out, old_err

# must close new_out and new_err in order to re-initialize
new_out.close()
new_err.close()

class StrongConvexMixinTests(keras_parameterized.TestCase):
"""Tests for the StrongConvexMixin."""
@@ -151,7 +153,7 @@ def test_bad_init_params(self, reg_lambda, C, radius_constant):
'y_true': [1],
'result': 10000,
},
{'testcase_name': 'positivee gradient positive logits',
{'testcase_name': 'positive gradient positive logits',
'logits': [10000],
'y_true': [0],
'result': 10000,
@@ -242,11 +244,162 @@ def test_prints(self, init_args, fn, args, print_res):
loss = StrongConvexBinaryCrossentropy(*init_args)
if fn is not None:
getattr(loss, fn, lambda *arguments: print('error'))(*args)
self.assertRegexMatch(err.getvalue().strip(), [print_res])
self.assertRegexMatch(err.getvalue().strip(), [print_res])

class CategoricalCrossentropyTests(keras_parameterized.TestCase):
"""tests for CategoricalCrossentropy StrongConvex loss."""

@parameterized.named_parameters([
{'testcase_name': 'normal',
'reg_lambda': 1,
'C': 1,
'radius_constant': 1
}, # pylint: disable=invalid-name
])
def test_init_params(self, reg_lambda, C, radius_constant):
"""Test initialization for given arguments.
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable
loss = StrongConvexCategoricalCrossentropy(reg_lambda, C, radius_constant)
self.assertIsInstance(loss, StrongConvexCategoricalCrossentropy)

@parameterized.named_parameters([
{'testcase_name': 'negative c',
'reg_lambda': 1,
'C': -1,
'radius_constant': 1
},
{'testcase_name': 'negative radius',
'reg_lambda': 1,
'C': 1,
'radius_constant': -1
},
{'testcase_name': 'negative lambda',
'reg_lambda': -1,
'C': 1,
'radius_constant': 1
}, # pylint: disable=invalid-name
])
def test_bad_init_params(self, reg_lambda, C, radius_constant):
"""Test invalid domain for given params. Should return ValueError.
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable
with self.assertRaises(ValueError):
StrongConvexCategoricalCrossentropy(reg_lambda, C, radius_constant)

@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
# [] for compatibility with tensorflow loss calculation
{'testcase_name': 'both positive',
'logits': [[10000, 0]],
'y_true': [[1, 0]],
'result': 0,
},
{'testcase_name': 'negative gradient positive logits',
'logits': [[-10000, 0]],
'y_true': [[1, 0]],
'result': 10000,
},
{'testcase_name': 'positive gradient negative logits',
'logits': [[10000, 0]],
'y_true': [[0, 1]],
'result': 10000,
},
{'testcase_name': 'both negative',
'logits': [[-10000, 0]],
'y_true': [[0, 1]],
'result': 0
},
])
def test_calculation(self, logits, y_true, result):
"""Test the call method to ensure it returns the correct value.
Args:
logits: unscaled output of model
y_true: label
result: correct loss calculation value
"""
logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexCategoricalCrossentropy(0.00001, 1, 1)
loss = loss(y_true, logits)
self.assertEqual(loss.numpy(), result)

@parameterized.named_parameters([
{'testcase_name': 'beta',
'init_args': [1, 1, 1],
'fn': 'beta',
'args': [1],
'result': tf.constant(2, dtype=tf.float32)
},
{'testcase_name': 'gamma',
'fn': 'gamma',
'init_args': [1, 1, 1],
'args': [],
'result': tf.constant(1, dtype=tf.float32),
},
{'testcase_name': 'lipchitz constant',
'fn': 'lipchitz_constant',
'init_args': [1, 1, 1],
'args': [1],
'result': tf.constant(2, dtype=tf.float32),
},
{'testcase_name': 'kernel regularizer',
'fn': 'kernel_regularizer',
'init_args': [1, 1, 1],
'args': [],
'result': L1L2(l2=0.5),
},
])
def test_fns(self, init_args, fn, args, result):
"""Test that fn of CategoricalCrossentropy loss returns the correct result.
Args:
init_args: init values for loss instance
fn: the fn to test
args: the arguments to above function
result: the correct result from the fn
"""
loss = StrongConvexCategoricalCrossentropy(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
expected = expected.numpy()
result = result.numpy()
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
expected = expected.l2
result = result.l2
self.assertEqual(expected, result)

@parameterized.named_parameters([
{'testcase_name': 'label_smoothing',
'init_args': [1, 1, 1, True, 0.1],
'fn': None,
'args': None,
'print_res': 'The impact of label smoothing on privacy is unknown.'
},
])
def test_prints(self, init_args, fn, args, print_res):
"""Test logger warning from StrongConvexCategoricalCrossentropy.
Args:
init_args: arguments to init the object with.
fn: function to test
args: arguments to above function
print_res: print result that should have been printed.
"""
with captured_output() as (out, err): # pylint: disable=unused-variable
loss = StrongConvexCategoricalCrossentropy(*init_args)
if fn is not None:
getattr(loss, fn, lambda *arguments: print('error'))(*args)
self.assertRegexMatch(err.getvalue().strip(), [print_res])

class HuberTests(keras_parameterized.TestCase):
"""tests for BinaryCrossesntropy StrongConvex loss."""
"""tests for CategoricalCrossesntropy StrongConvex loss."""

@parameterized.named_parameters([
{'testcase_name': 'normal',
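
As a sanity check on the expected values in test_calculation above, the raw softmax cross-entropy for the extreme logits used in the tests can be verified with stock Keras, independently of this PR; with reg_lambda = 0.00001 and C = 1 the StrongConvex wrapper leaves these values essentially unchanged.

```python
import tensorflow as tf

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# 'negative gradient positive logits': the true class gets logit -10000,
# so loss = logsumexp(logits) - logit_true ~= 0 - (-10000) = 10000.
print(cce(tf.constant([[1.0, 0.0]]), tf.constant([[-10000.0, 0.0]])).numpy())  # ~10000.0

# 'both positive': the true class already dominates, so the loss collapses to ~0.
print(cce(tf.constant([[1.0, 0.0]]), tf.constant([[10000.0, 0.0]])).numpy())   # ~0.0
```
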
23 changes: 14 additions & 9 deletions tensorflow_privacy/privacy/bolt_on/models.py
@@ -113,6 +113,7 @@ def fit(self,
class_weight=None,
n_samples=None,
epsilon=2,
delta=0.1,
noise_distribution='laplace',
steps_per_epoch=None,
**kwargs): # pylint: disable=arguments-differ
@@ -133,8 +134,8 @@ def fit(self,
class_weight: the class weights to be used. Can be a scalar or 1D tensor
whose dim == n_classes.
n_samples: the number of individual samples in x.
epsilon: privacy parameter, which trades off between utility an privacy.
See the bolt-on paper for more description.
epsilon, delta: privacy parameters, which trade off utility and privacy.
See the bolt-on paper for more description.
noise_distribution: the distribution to pull noise from.
steps_per_epoch: Number of steps per training epoch, see super.
**kwargs: kwargs to keras Model.fit. See super.
@@ -169,6 +170,7 @@
'this in using n_samples.')
with self.optimizer(noise_distribution,
epsilon,
delta,
self.layers,
class_weight_,
data_size,
@@ -186,6 +188,7 @@ def fit_generator(self,
class_weight=None,
noise_distribution='laplace',
epsilon=2,
delta=0.1,
n_samples=None,
steps_per_epoch=None,
**kwargs): # pylint: disable=arguments-differ
@@ -199,8 +202,8 @@
class_weight: the class weights to be used. Can be a scalar or 1D tensor
whose dim == n_classes.
noise_distribution: the distribution to get noise from.
epsilon: privacy parameter, which trades off utility and privacy. See
BoltOn paper for more description.
epsilon, delta: privacy parameters, which trade off utility and privacy.
See the BoltOn paper for more description.
n_samples: number of individual samples in x
steps_per_epoch: Number of steps per training epoch, see super.
**kwargs: **kwargs
@@ -217,16 +220,18 @@
elif hasattr(generator, '__len__'):
data_size = len(generator)
else:
raise ValueError('The number of samples could not be determined. '
'Please make sure that if you are using a generator'
'to call this method directly with n_samples kwarg '
'passed.')
batch_size = self._validate_or_infer_batch_size(None, steps_per_epoch,
raise ValueError("The number of samples could not be determined. "
"Please make sure that if you are using a generator"
"to call this method directly with n_samples kwarg "
"passed.")
batch_size = self._validate_or_infer_batch_size(None,
steps_per_epoch,
generator)
if batch_size is None:
batch_size = 32
with self.optimizer(noise_distribution,
epsilon,
delta,
self.layers,
class_weight,
data_size,
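
To show how the new delta kwarg surfaces to callers, here is a hedged end-to-end sketch. BoltOnModel and its compile(optimizer, loss) signature are taken from the existing bolt_on package and are assumptions as far as this diff is concerned, and the 'gaussian' distribution string is inferred from the PR title; the diff above still defaults noise_distribution to 'laplace'.

```python
import tensorflow as tf

from tensorflow_privacy.privacy.bolt_on import models
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexCategoricalCrossentropy

# Toy 3-class data; shapes are only meant to make the sketch runnable.
x_train = tf.random.normal([100, 4])
y_train = tf.one_hot(tf.random.uniform([100], maxval=3, dtype=tf.int32), 3)

# Assumption: BoltOnModel(n_outputs=...) and compile(optimizer, loss) as in the
# current bolt_on package, with compile wrapping the plain optimizer in BoltOn.
model = models.BoltOnModel(n_outputs=3)
model.compile(tf.optimizers.SGD(0.1),
              StrongConvexCategoricalCrossentropy(reg_lambda=1.0,
                                                  c_arg=1.0,
                                                  radius_constant=1.0))

model.fit(x_train, y_train,
          n_samples=100,
          epsilon=2,
          delta=0.1,                      # new privacy parameter added in this diff
          noise_distribution='gaussian',  # assumption: enabled by the noise-sampling change
          batch_size=32)
```
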