Skip to content

Commit 72071ad

Browse files
iden-kalemaj authored and facebook-github-bot committed
Use SimpleDistributedPerLayerOptimizer in hooks mode (#750)
Summary: We use SimpleDistributedPerLayerOptimizer instead of DistributedPerLayerOptimizer. The latter causes an issue when switching to `register_full_backward_hook`. The issue arises because DistributedPerLayerOptimizer uses per-parameter hooks on top of the per-module hooks. During the backward pass, the per-parameter hooks fire before the per-module hooks. Per-sample gradients are computed when the per-module hooks fire, and an error occurs when the per-parameter hooks try to access the per-sample gradients before they are computed. Forcing the order in which hooks are called is not possible with PyTorch. Differential Revision: D72420168
1 parent 6c2cde9 commit 72071ad

File tree

2 files changed

+4
-16
lines changed

2 files changed

+4
-16
lines changed

opacus/optimizers/__init__.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
# limitations under the License.
1414

1515
from .adaclipoptimizer import AdaClipDPOptimizer
16-
from .ddp_perlayeroptimizer import (
17-
DistributedPerLayerOptimizer,
18-
SimpleDistributedPerLayerOptimizer,
19-
)
16+
from .ddp_perlayeroptimizer import SimpleDistributedPerLayerOptimizer
2017
from .ddpoptimizer import DistributedDPOptimizer
2118
from .ddpoptimizer_fast_gradient_clipping import (
2219
DistributedDPOptimizerFastGradientClipping,
@@ -28,7 +25,6 @@
2825

2926
__all__ = [
3027
"AdaClipDPOptimizer",
31-
"DistributedPerLayerOptimizer",
3228
"DistributedDPOptimizer",
3329
"DPOptimizer",
3430
"DPOptimizerFastGradientClipping",
@@ -55,9 +51,7 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str
5551
elif clipping == "per_layer" and distributed is False:
5652
return DPPerLayerOptimizer
5753
elif clipping == "per_layer" and distributed is True:
58-
if grad_sample_mode == "hooks":
59-
return DistributedPerLayerOptimizer
60-
elif grad_sample_mode == "ew":
54+
if grad_sample_mode == "hooks" or grad_sample_mode == "ew":
6155
return SimpleDistributedPerLayerOptimizer
6256
else:
6357
raise ValueError(f"Unexpected grad_sample_mode: {grad_sample_mode}")

opacus/tests/multigpu_gradcheck.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@
2626
from opacus import PrivacyEngine
2727
from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
2828
from opacus.grad_sample import GradSampleModuleFastGradientClipping
29-
from opacus.optimizers.ddp_perlayeroptimizer import (
30-
DistributedPerLayerOptimizer,
31-
SimpleDistributedPerLayerOptimizer,
32-
)
29+
from opacus.optimizers.ddp_perlayeroptimizer import SimpleDistributedPerLayerOptimizer
3330
from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
3431
from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
3532
DistributedDPOptimizerFastGradientClipping,
@@ -165,10 +162,7 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
165162
grad_sample_mode=grad_sample_mode,
166163
)
167164
if clipping == "per_layer":
168-
assert isinstance(
169-
optimizer,
170-
(DistributedPerLayerOptimizer, SimpleDistributedPerLayerOptimizer),
171-
)
165+
assert isinstance(optimizer, SimpleDistributedPerLayerOptimizer)
172166
else:
173167
assert isinstance(optimizer, DistributedDPOptimizer)
174168

0 commit comments

Comments (0)