diff --git a/.circleci/config.yml b/.circleci/config.yml
index 2af1e34f..c095373e 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -162,6 +162,8 @@ commands:
             pip install tensorboard
             python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device <>
             python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
+            python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device <> --grad_sample_mode no_op
+            python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
          when: always
      - store_test_results:
          path: runs/cifar10/test-reports
diff --git a/examples/cifar10.py b/examples/cifar10.py
index c4471f7b..6a85eb6d 100644
--- a/examples/cifar10.py
+++ b/examples/cifar10.py
@@ -138,6 +138,26 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
 
     losses = []
     top1_acc = []
 
+    if args.grad_sample_mode == "no_op":
+        from functorch import grad_and_value, make_functional, vmap
+
+        # Functorch prepare
+        fmodel, _fparams = make_functional(model)
+
+        def compute_loss_stateless_model(params, sample, target):
+            batch = sample.unsqueeze(0)
+            targets = target.unsqueeze(0)
+
+            predictions = fmodel(params, batch)
+            loss = criterion(predictions, targets)
+            return loss
+
+        ft_compute_grad = grad_and_value(compute_loss_stateless_model)
+        ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
+        # Using model.parameters() instead of fparams
+        # as fparams seems to not point to the dynamically updated parameters
+        params = list(model.parameters())
+
     for i, (images, target) in enumerate(tqdm(train_loader)):
         images = images.to(device)
@@ -145,18 +165,28 @@
 
         # compute output
         output = model(images)
-        loss = criterion(output, target)
-        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
-        labels = target.detach().cpu().numpy()
 
-        # measure accuracy and record loss
-        acc1 = accuracy(preds, labels)
+        if args.grad_sample_mode == "no_op":
+            per_sample_grads, per_sample_losses = ft_compute_sample_grad(
+                params, images, target
+            )
+            per_sample_grads = [g.detach() for g in per_sample_grads]
+            loss = torch.mean(per_sample_losses)
+            for (p, g) in zip(params, per_sample_grads):
+                p.grad_sample = g
+        else:
+            loss = criterion(output, target)
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
 
-        losses.append(loss.item())
-        top1_acc.append(acc1)
+            # measure accuracy and record loss
+            acc1 = accuracy(preds, labels)
+            top1_acc.append(acc1)
+
+            # compute gradient and do SGD step
+            loss.backward()
 
-        # compute gradient and do SGD step
-        loss.backward()
+        losses.append(loss.item())
 
         # make sure we take a step after processing the last mini-batch in the
         # epoch to ensure we start the next epoch with a clean state
@@ -331,6 +361,7 @@ def main():
             noise_multiplier=args.sigma,
             max_grad_norm=max_grad_norm,
             clipping=clipping,
+            grad_sample_mode=args.grad_sample_mode,
         )
 
         # Store some logs
@@ -388,6 +419,7 @@
 
 def parse_args():
     parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument("--grad_sample_mode", type=str, default="hooks")
    parser.add_argument(
        "-j",
        "--workers",
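
For reference, the training-loop change above follows the standard functorch per-sample-gradient recipe. The sketch below is a self-contained illustration of that recipe, not part of the diff: the toy linear model, batch size, and random data are assumptions made only for the example. It shows how vmap over grad_and_value yields one gradient per sample, which the loop then attaches as p.grad_sample for the DP optimizer to clip and noise.

# Self-contained sketch of the functorch per-sample-gradient recipe used above.
# The linear model, shapes, and random data are illustrative assumptions.
import torch
import torch.nn as nn
from functorch import grad_and_value, make_functional, vmap

model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()

# Turn the stateful nn.Module into a pure function of (params, input).
fmodel, _fparams = make_functional(model)

def compute_loss_stateless_model(params, sample, target):
    # Re-add a batch dimension of 1: vmap hands us one sample at a time.
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    predictions = fmodel(params, batch)
    return criterion(predictions, targets)

# grad_and_value returns (grads w.r.t. params, loss); vmap maps it over the
# leading (batch) dimension of sample/target while broadcasting params.
ft_compute_grad = grad_and_value(compute_loss_stateless_model)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))

params = list(model.parameters())
images = torch.randn(16, 8)
target = torch.randint(0, 3, (16,))

per_sample_grads, per_sample_losses = ft_compute_sample_grad(params, images, target)

# Attach the per-sample gradients where DPOptimizer expects to find them.
for p, g in zip(params, per_sample_grads):
    p.grad_sample = g.detach()

assert params[0].grad_sample.shape == (16, 3, 8)  # (batch, *weight.shape)

On newer PyTorch releases the same primitives are also exposed under torch.func (functional_call, grad_and_value, vmap); functorch.make_functional is the older spelling used by the example in this diff.
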
diff --git a/opacus/grad_sample/gsm_no_op.py b/opacus/grad_sample/gsm_no_op.py
new file mode 100644
index 00000000..3fd3df8f
--- /dev/null
+++ b/opacus/grad_sample/gsm_no_op.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from opacus.grad_sample.gsm_base import AbstractGradSampleModule
+
+
+class GradSampleModuleNoOp(AbstractGradSampleModule):
+    """
+    NoOp GradSampleModule.
+    Only wraps the module. The main goal of this class is to provide the same API for all methods.
+    See README.md for more details
+    """
+
+    def __init__(
+        self,
+        m: nn.Module,
+        *,
+        batch_first=True,
+        loss_reduction="mean",
+    ):
+        if not batch_first:
+            raise NotImplementedError
+
+        super().__init__(
+            m,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+        )
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        return self._module.forward(x, *args, **kwargs)
diff --git a/opacus/grad_sample/utils.py b/opacus/grad_sample/utils.py
index 1f245025..8b5e8ff5 100644
--- a/opacus/grad_sample/utils.py
+++ b/opacus/grad_sample/utils.py
@@ -20,6 +20,7 @@
 from .grad_sample_module import GradSampleModule
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
+from .gsm_no_op import GradSampleModuleNoOp
 
 
 def register_grad_sampler(
@@ -69,6 +70,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
         return GradSampleModule
     elif grad_sample_mode == "ew":
         return GradSampleModuleExpandedWeights
+    elif grad_sample_mode == "no_op":
+        return GradSampleModuleNoOp
     else:
         raise ValueError(
             f"Unexpected grad_sample_mode: {grad_sample_mode}. "
diff --git a/opacus/optimizers/optimizer.py b/opacus/optimizers/optimizer.py
index b516db56..46a414d9 100644
--- a/opacus/optimizers/optimizer.py
+++ b/opacus/optimizers/optimizer.py
@@ -395,7 +395,7 @@ def clip_and_accumulate(self):
         """
 
         per_param_norms = [
-            g.norm(2, dim=tuple(range(1, g.ndim))) for g in self.grad_samples
+            g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
        ]
        per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
        per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
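
The optimizer.py change swaps the n-dimensional dim tuple for an explicit flatten before taking the per-sample 2-norm. Below is a quick sanity check (illustrative, not Opacus code; the (32, 10, 5) shape is an arbitrary assumption) that the two expressions agree on a well-formed grad_sample of shape (batch, *param_shape).

import torch

g = torch.randn(32, 10, 5)  # e.g. per-sample grads of a 10x5 weight

old = g.norm(2, dim=tuple(range(1, g.ndim)))  # reduce over all non-batch dims
new = g.reshape(len(g), -1).norm(2, dim=-1)   # flatten, then one 2-norm per row

assert torch.allclose(old, new)
print(new.shape)  # torch.Size([32]) -- one norm per sample

One plausible motivation for the reshape form is that it never produces an empty dim tuple, which the old expression does whenever a per-sample gradient is 1-D (batch dimension only).
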
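
Putting it together, here is a hedged sketch of how the new mode would be selected from user code; the model, data, and hyperparameters are placeholders, and the grad_sample_mode plumbing simply follows get_gsm_class above. With "no_op", Opacus only wraps the module, so the training loop itself is responsible for populating p.grad_sample (for example via the functorch recipe shown earlier) before calling optimizer.step().

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

model = nn.Linear(8, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,))),
    batch_size=16,
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.5,
    max_grad_norm=10.0,
    grad_sample_mode="no_op",  # model is wrapped in GradSampleModuleNoOp
)
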