From 146bbaaf461be7ed63094b49b0ae6596d1395fce Mon Sep 17 00:00:00 2001 From: svnv-svsv-jm Date: Thu, 22 May 2025 15:36:24 +0200 Subject: [PATCH 1/2] add cb --- src/lightning/pytorch/callbacks/__init__.py | 2 + .../pytorch/callbacks/differential_privacy.py | 325 ++++++++++++++++++ tests/tests_pytorch/callbacks/test_dp.py | 104 ++++++ 3 files changed, 431 insertions(+) create mode 100644 src/lightning/pytorch/callbacks/differential_privacy.py create mode 100644 tests/tests_pytorch/callbacks/test_dp.py diff --git a/src/lightning/pytorch/callbacks/__init__.py b/src/lightning/pytorch/callbacks/__init__.py index 9ee34f3866b27..a1eaa3aa0000a 100644 --- a/src/lightning/pytorch/callbacks/__init__.py +++ b/src/lightning/pytorch/callbacks/__init__.py @@ -15,6 +15,7 @@ from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.callbacks.checkpoint import Checkpoint from lightning.pytorch.callbacks.device_stats_monitor import DeviceStatsMonitor +from lightning.pytorch.callbacks.differential_privacy import DifferentialPrivacy from lightning.pytorch.callbacks.early_stopping import EarlyStopping from lightning.pytorch.callbacks.finetuning import BackboneFinetuning, BaseFinetuning from lightning.pytorch.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler @@ -58,4 +59,5 @@ "ThroughputMonitor", "Timer", "TQDMProgressBar", + "DifferentialPrivacy", ] diff --git a/src/lightning/pytorch/callbacks/differential_privacy.py b/src/lightning/pytorch/callbacks/differential_privacy.py new file mode 100644 index 0000000000000..50e4f0dfbe375 --- /dev/null +++ b/src/lightning/pytorch/callbacks/differential_privacy.py @@ -0,0 +1,325 @@ +__all__ = ["DifferentialPrivacy"] + +import typing as ty +from copy import deepcopy + +import torch +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + +try: + from opacus import GradSampleModule + from opacus.accountants import RDPAccountant + from opacus.accountants.utils import get_noise_multiplier + from opacus.data_loader import DPDataLoader + from opacus.layers.dp_rnn import DPGRUCell + from opacus.optimizers import DPOptimizer as OpacusDPOptimizer +except ImportError as ex: + raise ImportError("Opacus is not installed. Please install it with `pip install opacus`.") from ex + +import pytorch_lightning as pl + + +def replace_grucell(module: torch.nn.Module) -> None: + """Replaces GRUCell modules with DP-counterparts.""" + for name, child in module.named_children(): + if isinstance(child, torch.nn.GRUCell) and not isinstance(child, DPGRUCell): + replacement = copy_gru(child) + setattr(module, name, replacement) + for name, child in module.named_children(): + replace_grucell(child) + + +def copy_gru(grucell: torch.nn.GRUCell) -> DPGRUCell: + """Creates a DP-GRUCell from a non-DP one.""" + input_size: int = grucell.input_size + hidden_size: int = grucell.hidden_size + bias: bool = grucell.bias + dpgrucell = DPGRUCell(input_size, hidden_size, bias) + for name, param in grucell.named_parameters(): + if "ih" in name: + _set_layer_param(dpgrucell, name, param, "ih") + elif "hh" in name: + _set_layer_param(dpgrucell, name, param, "hh") + else: + raise AttributeError(f"Unknown parameter {name}") + return dpgrucell + + +def _set_layer_param( + dpgrucell: DPGRUCell, + name: str, + param: torch.Tensor, + layer_name: str, +) -> None: + """Helper""" + layer = getattr(dpgrucell, layer_name) + if "weight" in name: + layer.weight = torch.nn.Parameter(deepcopy(param)) + elif "bias" in name: + layer.bias = torch.nn.Parameter(deepcopy(param)) + else: + raise AttributeError(f"Unknown parameter {name}") + setattr(dpgrucell, layer_name, layer) + + +def params( + optimizer: torch.optim.Optimizer, + accepted_names: list[str] = None, +) -> list[torch.nn.Parameter]: + """ + Return all parameters controlled by the optimizer + Args: + accepted_names (list[str]): + List of parameter group names you want to apply DP to. + This allows you to choose to apply DP only to specific parameter groups. + Of course, this will work only if the optimizer has named parameter groups. + If it doesn't, then this argument will be ignored and DP will be applied to all parameter groups. + Returns: + (list[torch.nn.Parameter]): Flat list of parameters from all `param_groups` + """ + # lower case + if accepted_names is not None: + accepted_names = [name.lower() for name in accepted_names] + # unwrap parameters from the param_groups into a flat list + ret = [] + for param_group in optimizer.param_groups: + if accepted_names is not None and "name" in param_group: + name: str = param_group["name"].lower() + if name.lower() in accepted_names: + ret += [p for p in param_group["params"] if p.requires_grad] + else: + ret += [p for p in param_group["params"] if p.requires_grad] + return ret + + +class DPOptimizer(OpacusDPOptimizer): + """Brainiac-2's DP-Optimizer""" + + def __init__( + self, + *args: ty.Any, + param_group_names: list[str] = None, + **kwargs: ty.Any, + ) -> None: + """Constructor.""" + self.param_group_names = param_group_names + super().__init__(*args, **kwargs) + + @property + def params(self) -> list[torch.nn.Parameter]: + """ + Returns a flat list of ``nn.Parameter`` managed by the optimizer + """ + return params(self, self.param_group_names) + + +class DifferentialPrivacy(pl.callbacks.EarlyStopping): + """Enables differential privacy using Opacus. + Converts optimizers to instances of the :class:`~opacus.optimizers.DPOptimizer` class. + This callback inherits from `EarlyStopping`, thus it is also able to stop the + training when enough privacy budget has been spent. + Please beware that Opacus does not support multi-optimizer training. + For more info, check the following links: + * https://opacus.ai/tutorials/ + * https://blog.openmined.org/differentially-private-deep-learning-using-opacus-in-20-lines-of-code/ + """ + + def __init__( + self, + budget: float = 1.0, + noise_multiplier: float = 1.0, + max_grad_norm: float = 1.0, + delta: float = None, + use_target_values: bool = False, + idx: ty.Sequence[int] = None, + log_spent_budget_as: str = "DP/spent-budget", + param_group_names: list[str] = None, + private_dataloader: bool = False, + default_alphas: ty.Sequence[ty.Union[float, int]] = None, + **gsm_kwargs: ty.Any, + ) -> None: + """Enables differential privacy using Opacus. + Converts optimizers to instances of the :class:`~opacus.optimizers.DPOptimizer` class. + This callback inherits from `EarlyStopping`, + thus it is also able to stop the training when enough privacy budget has been spent. + Args: + budget (float, optional): Defaults to 1.0. + Maximun privacy budget to spend. + noise_multiplier (float, optional): Defaults to 1.0. + Noise multiplier. + max_grad_norm (float, optional): Defaults to 1.0. + Max grad norm used for gradient clipping. + delta (float, optional): Defaults to None. + The target δ of the (ϵ,δ)-differential privacy guarantee. + Generally, it should be set to be less than the inverse of the size of the training dataset. + If `None`, this will be set to the inverse of the size of the training dataset `N`: `1/N`. + use_target_values (bool, optional): + Whether to call `privacy_engine.make_private_with_epsilon()` or `privacy_engine.make_private`. + If `True`, the value of `noise_multiplier` will be calibrated automatically so that the desired privacy + budget will be reached only at the end of the training. + idx (ty.Sequence[int]): + List of optimizer ID's to make private. Useful when a model may have more than one optimizer. + By default, all optimizers are made private. + log_spent_budget_as (str, optional): + How to log and expose the spent budget value. + Although this callback already allows you to stop the training when + enough privacy budget has been spent (see argument `stop_on_budget`), + this keyword argument can be used in combination with an `EarlyStopping` + callback, so that the latter may use this value to stop the training when enough budget has been spent. + param_group_names (list[str]): + List of parameter group names you want to apply DP to. This allows you + to choose to apply DP only to specific parameter groups. Of course, this + will work only if the optimizer has named parameter groups. If it + doesn't, then this argument will be ignored and DP will be applied to + all parameter groups. + private_dataloader (bool, optional): + Whether to make the dataloader private. Defaults to False. + **gsm_kwargs: + Input arguments for the :class:`~opacus.GradSampleModule` class. + """ + # inputs + self.budget = budget + self.delta = delta + self.noise_multiplier = noise_multiplier + self.max_grad_norm = max_grad_norm + self.use_target_values = use_target_values + self.log_spent_budget_as = log_spent_budget_as + self.param_group_names = param_group_names + self.private_dataloader = private_dataloader + self.gsm_kwargs = gsm_kwargs + if default_alphas is None: + self.default_alphas = RDPAccountant.DEFAULT_ALPHAS + list(range(64, 150)) + else: + self.default_alphas = default_alphas + # init early stopping callback + super().__init__( + monitor=self.log_spent_budget_as, + mode="max", + stopping_threshold=self.budget, + check_on_train_epoch_end=True, + # we do not want to stop if spent budget does not increase. this may even be desirable + min_delta=0, + patience=1000000, + ) + # attributes + self.epsilon: float = 0.0 + self.best_alpha: float = 0.0 + self.accountant = RDPAccountant() + self.idx = idx # optims to privatize + + def setup( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + stage: str = None, + ) -> None: + """Call the GradSampleModule() wrapper to add attributes to pl_module.""" + if stage == "fit": + replace_grucell(pl_module) + try: + pl_module = GradSampleModule(pl_module, **self.gsm_kwargs) + except ImportError as ex: + raise ImportError(f"{ex}. This may be due to a mismatch between Opacus and PyTorch version.") from ex + + def on_train_epoch_start( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + ) -> None: + """Called when the training epoch begins. Use this to make optimizers private.""" + # idx + if self.idx is None: + self.idx = range(len(trainer.optimizers)) + + # Replace current dataloaders with private counterparts + expected_batch_size = 1 + cl = trainer.fit_loop._combined_loader + if cl is not None: + dp_dls: list[DPDataLoader] = [] + for i, dl in enumerate(cl.flattened): + if isinstance(dl, DataLoader): + sample_rate: float = 1 / len(dl) + dataset_size: int = len(dl.dataset) # type: ignore + expected_batch_size = int(dataset_size * sample_rate) + if self.private_dataloader: + dp_dl = DPDataLoader.from_data_loader(dl, distributed=False) + dp_dls.append(dp_dl) + # it also allows you to easily replace the dataloaders + if self.private_dataloader: + cl.flattened = dp_dls + + # Delta + if self.delta is None: + self.delta = 1 / dataset_size + + # Make optimizers private + optimizers: list[Optimizer] = [] + dp_optimizer: ty.Union[Optimizer, DPOptimizer] + for i, optimizer in enumerate(trainer.optimizers): + if not isinstance(optimizer, DPOptimizer) and i in self.idx: + if self.use_target_values: + self.noise_multiplier = get_noise_multiplier( + target_epsilon=self.budget / 2, + target_delta=self.delta, + sample_rate=sample_rate, + epochs=trainer.max_epochs, + accountant="rdp", + ) + dp_optimizer = DPOptimizer( + optimizer=optimizer, + noise_multiplier=self.noise_multiplier, + max_grad_norm=self.max_grad_norm, + expected_batch_size=expected_batch_size, + param_group_names=self.param_group_names, + ) + dp_optimizer.attach_step_hook(self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate)) + else: + dp_optimizer = optimizer + optimizers.append(dp_optimizer) + # Replace optimizers + trainer.optimizers = optimizers + + def on_train_batch_end( # pylint: disable=unused-argument # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: ty.Any, + batch: ty.Any, + batch_idx: int, + *args: ty.Any, + ) -> None: + """Called after the batched has been digested. Use this to understand whether to stop or not.""" + self._log_and_stop_criterion(trainer, pl_module) + + def on_train_epoch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + ) -> None: + """Run at the end of the training epoch.""" + + def get_privacy_spent(self) -> tuple[float, float]: + """Estimate spent budget.""" + # get privacy budget spent so far + epsilon, best_alpha = self.accountant.get_privacy_spent( + delta=self.delta, + alphas=self.default_alphas, + ) + return float(epsilon), float(best_alpha) + + def _log_and_stop_criterion( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + ) -> None: + """Logging privacy spent: (epsilon, delta) and stopping if necessary.""" + self.epsilon, self.best_alpha = self.get_privacy_spent() + pl_module.log( + self.log_spent_budget_as, + self.epsilon, + on_epoch=True, + prog_bar=True, + ) + if self.epsilon > self.budget: + trainer.should_stop = True diff --git a/tests/tests_pytorch/callbacks/test_dp.py b/tests/tests_pytorch/callbacks/test_dp.py new file mode 100644 index 0000000000000..2320eda87f188 --- /dev/null +++ b/tests/tests_pytorch/callbacks/test_dp.py @@ -0,0 +1,104 @@ +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchmetrics +from torch.utils.data import DataLoader, TensorDataset + +import lightning.pytorch as pl +from lightning.pytorch.callbacks.differential_privacy import DifferentialPrivacy + + +class MockDataModule(pl.LightningDataModule): + def __init__(self, batch_size=32, input_dim=10, num_classes=2): + super().__init__() + self.batch_size = batch_size + self.input_dim = input_dim + self.num_classes = num_classes + + def setup(self, stage=None): + # Generate random data + X = torch.randn(1000, self.input_dim) + y = torch.randint(0, self.num_classes, (1000,)) + dataset = TensorDataset(X, y) + self.train_data, self.val_data = torch.utils.data.random_split(dataset, [800, 200]) + + def train_dataloader(self): + return DataLoader(self.train_data, batch_size=self.batch_size) + + def val_dataloader(self): + return DataLoader(self.val_data, batch_size=self.batch_size) + + +class SimpleClassifier(pl.LightningModule): + def __init__(self, input_dim=10, num_classes=2, lr=1e-3): + super().__init__() + self.save_hyperparameters() + self.model = nn.Linear(input_dim, num_classes) + self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes) + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + acc = self.accuracy(logits.softmax(dim=-1), y) + self.log("train_loss", loss, prog_bar=True) + self.log("train_acc", acc, prog_bar=True) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + acc = self.accuracy(logits.softmax(dim=-1), y) + self.log("val_loss", loss, prog_bar=True) + self.log("val_acc", acc, prog_bar=True) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.hparams.lr) + + +def test_privacy_callback() -> None: + """Test on simple classifier. + + We test that: + * the privacy budget has been spent (`epsilon > 0`); + * spent budget is greater than max privacy budget; + * traininng did not stop because `max_steps` has been reached, but because the total budget has been spent. + """ + # choose dataset + datamodule = MockDataModule() + + # choose model: choose a model with more than one optim + model = SimpleClassifier() + + # init DP callback + dp_cb = DifferentialPrivacy(budget=0.232, private_dataloader=False) + + # define training + max_steps = 20 + trainer = pl.Trainer( + logger=False, + enable_checkpointing=False, + max_steps=max_steps, + callbacks=[dp_cb], + ) + trainer.fit(model=model, datamodule=datamodule) + + # tests + epsilon, best_alpha = dp_cb.get_privacy_spent() + print(f"Total spent budget {epsilon} with alpha: {best_alpha}") + assert epsilon > 0, f"No privacy budget has been spent: {epsilon}" + assert ( + epsilon >= dp_cb.budget + ), f"Spent budget is not greater than max privacy budget: epsilon = {epsilon} and budget = {dp_cb.budget}" + assert ( + trainer.global_step < max_steps + ), "Traininng stopped because max_steps has been reached, not because the total budget has been spent." + + +if __name__ == "__main__": + pytest.main([__file__, "-x", "-s"]) From 1f5acaab915ea6baedc889a1c6722179dc7209b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 May 2025 13:38:41 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/callbacks/differential_privacy.py | 33 ++++++++++++------- tests/tests_pytorch/callbacks/test_dp.py | 13 ++++---- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/src/lightning/pytorch/callbacks/differential_privacy.py b/src/lightning/pytorch/callbacks/differential_privacy.py index 50e4f0dfbe375..01c98f472194b 100644 --- a/src/lightning/pytorch/callbacks/differential_privacy.py +++ b/src/lightning/pytorch/callbacks/differential_privacy.py @@ -52,7 +52,7 @@ def _set_layer_param( param: torch.Tensor, layer_name: str, ) -> None: - """Helper""" + """Helper.""" layer = getattr(dpgrucell, layer_name) if "weight" in name: layer.weight = torch.nn.Parameter(deepcopy(param)) @@ -94,7 +94,7 @@ def params( class DPOptimizer(OpacusDPOptimizer): - """Brainiac-2's DP-Optimizer""" + """Brainiac-2's DP-Optimizer.""" def __init__( self, @@ -108,21 +108,20 @@ def __init__( @property def params(self) -> list[torch.nn.Parameter]: - """ - Returns a flat list of ``nn.Parameter`` managed by the optimizer - """ + """Returns a flat list of ``nn.Parameter`` managed by the optimizer.""" return params(self, self.param_group_names) class DifferentialPrivacy(pl.callbacks.EarlyStopping): - """Enables differential privacy using Opacus. - Converts optimizers to instances of the :class:`~opacus.optimizers.DPOptimizer` class. - This callback inherits from `EarlyStopping`, thus it is also able to stop the - training when enough privacy budget has been spent. - Please beware that Opacus does not support multi-optimizer training. + """Enables differential privacy using Opacus. Converts optimizers to instances of the + :class:`~opacus.optimizers.DPOptimizer` class. This callback inherits from `EarlyStopping`, thus it is also able to + stop the training when enough privacy budget has been spent. Please beware that Opacus does not support multi- + optimizer training. + For more info, check the following links: * https://opacus.ai/tutorials/ * https://blog.openmined.org/differentially-private-deep-learning-using-opacus-in-20-lines-of-code/ + """ def __init__( @@ -140,6 +139,7 @@ def __init__( **gsm_kwargs: ty.Any, ) -> None: """Enables differential privacy using Opacus. + Converts optimizers to instances of the :class:`~opacus.optimizers.DPOptimizer` class. This callback inherits from `EarlyStopping`, thus it is also able to stop the training when enough privacy budget has been spent. @@ -177,6 +177,7 @@ def __init__( Whether to make the dataloader private. Defaults to False. **gsm_kwargs: Input arguments for the :class:`~opacus.GradSampleModule` class. + """ # inputs self.budget = budget @@ -227,7 +228,11 @@ def on_train_epoch_start( trainer: pl.Trainer, pl_module: pl.LightningModule, ) -> None: - """Called when the training epoch begins. Use this to make optimizers private.""" + """Called when the training epoch begins. + + Use this to make optimizers private. + + """ # idx if self.idx is None: self.idx = range(len(trainer.optimizers)) @@ -289,7 +294,11 @@ def on_train_batch_end( # pylint: disable=unused-argument # type: ignore batch_idx: int, *args: ty.Any, ) -> None: - """Called after the batched has been digested. Use this to understand whether to stop or not.""" + """Called after the batched has been digested. + + Use this to understand whether to stop or not. + + """ self._log_and_stop_criterion(trainer, pl_module) def on_train_epoch_end( diff --git a/tests/tests_pytorch/callbacks/test_dp.py b/tests/tests_pytorch/callbacks/test_dp.py index 2320eda87f188..c022a0bb58657 100644 --- a/tests/tests_pytorch/callbacks/test_dp.py +++ b/tests/tests_pytorch/callbacks/test_dp.py @@ -68,6 +68,7 @@ def test_privacy_callback() -> None: * the privacy budget has been spent (`epsilon > 0`); * spent budget is greater than max privacy budget; * traininng did not stop because `max_steps` has been reached, but because the total budget has been spent. + """ # choose dataset datamodule = MockDataModule() @@ -92,12 +93,12 @@ def test_privacy_callback() -> None: epsilon, best_alpha = dp_cb.get_privacy_spent() print(f"Total spent budget {epsilon} with alpha: {best_alpha}") assert epsilon > 0, f"No privacy budget has been spent: {epsilon}" - assert ( - epsilon >= dp_cb.budget - ), f"Spent budget is not greater than max privacy budget: epsilon = {epsilon} and budget = {dp_cb.budget}" - assert ( - trainer.global_step < max_steps - ), "Traininng stopped because max_steps has been reached, not because the total budget has been spent." + assert epsilon >= dp_cb.budget, ( + f"Spent budget is not greater than max privacy budget: epsilon = {epsilon} and budget = {dp_cb.budget}" + ) + assert trainer.global_step < max_steps, ( + "Traininng stopped because max_steps has been reached, not because the total budget has been spent." + ) if __name__ == "__main__":