Skip to content

Commit 0d591b2

Browse files
committed
Override add_noise in DPOptimizerFastGradientClipping to use adjusted noise multiplier for adaptive clipping
1 parent 1f9283a commit 0d591b2

File tree

3 files changed

+40
-4
lines changed

3 files changed

+40
-4
lines changed

opacus/optimizers/adaclipoptimizer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
_mark_as_processed,
2828
)
2929

30-
3130
logger = logging.getLogger(__name__)
3231

3332

opacus/optimizers/optimizer_fast_gradient_clipping.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
import torch
2222
from torch.optim import Optimizer
2323

24-
from .optimizer import DPOptimizer
24+
from .optimizer import (
25+
DPOptimizer,
26+
_check_processed_flag,
27+
_generate_noise,
28+
_mark_as_processed,
29+
)
2530

2631

2732
logger = logging.getLogger(__name__)
@@ -188,3 +193,32 @@ def clip_and_accumulate(self):
188193
Redefines a parent class' function to not do anything
189194
"""
190195
pass
196+
197+
def add_noise(self):
198+
"""
199+
Adds noise to gradients. Stores noised result in ``p.grad``.
200+
201+
Uses ``_adjusted_noise_multiplier`` if set (for adaptive clipping),
202+
otherwise falls back to ``noise_multiplier``. This separation ensures
203+
correct privacy accounting when using adaptive clipping, where the
204+
original noise_multiplier is preserved for accounting while the
205+
adjusted value is used only for noise generation.
206+
"""
207+
# Use adjusted noise multiplier if set (for adaptive clipping),
208+
# otherwise use the standard noise_multiplier
209+
effective_noise_multiplier = getattr(
210+
self, "_adjusted_noise_multiplier", self.noise_multiplier
211+
)
212+
213+
for p in self.params:
214+
_check_processed_flag(p.summed_grad)
215+
216+
noise = _generate_noise(
217+
std=effective_noise_multiplier * self.max_grad_norm,
218+
reference=p.summed_grad,
219+
generator=self.generator,
220+
secure_mode=self.secure_mode,
221+
)
222+
p.grad = (p.summed_grad + noise).view_as(p)
223+
224+
_mark_as_processed(p.summed_grad)

opacus/utils/adaptive_clipping/adaptive_clipping_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,13 @@ def backward(self):
9595
per_sample_norms
9696
)
9797

98-
# update max grad norm and noise multiplier
98+
# update max grad norm and adjusted noise multiplier for noise generation
99+
# Note: We set _adjusted_noise_multiplier instead of noise_multiplier to preserve
100+
# the original noise_multiplier for correct privacy accounting (see Theorem 1 in
101+
# https://arxiv.org/pdf/1905.03871.pdf)
99102
self.module.max_grad_norm = new_max_grad_norm
100103
self.optimizer.max_grad_norm = new_max_grad_norm
101-
self.optimizer.noise_multiplier = new_noise_multiplier
104+
self.optimizer._adjusted_noise_multiplier = new_noise_multiplier
102105

103106
# get the loss rescaling coefficients using the updated max_grad_norm
104107
coeff = torch.where(

0 commit comments

Comments (0)