
Commit 0eb4b3e

pagarwl authored and facebook-github-bot committed
Adds fast gradient clipping support for the Embedding layer. (#694)

Summary:
The algorithm used is described in the paper "A Unified Fast Gradient Clipping Framework for DP-SGD": https://proceedings.neurips.cc/paper_files/paper/2023/file/a45d344b28179c8da7646bc38ff50ad8-Paper-Conference.pdf.

## Types of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue

Previously, ghost clipping was not supported in Opacus for the embedding layer. With the default DP-SGD implementation, training OOMs on large embedding layers with a large physical batch size (useful for privacy). Regular DP-SGD needs O(Bnd) memory, where B = physical batch size, n = vocab size, and d = embedding dimension. To give an example of the memory needed: we've seen embeddings with [vocab size = 1000000, dim = 5] (and higher) in real-world differential privacy applications. With a physical batch size of 16,000, the memory needed is 16,000 × 1,000,000 × 5 × 4 bytes ≈ 298.02 GiB (see the sketch below the commit message).

With this change, we need significantly less memory: O(Br), where B is the physical batch size and r is the number of unique indices in the embedding sequence. We could successfully run DP-SGD over the above example using < 8 GiB. This is a good addition to Opacus, enabling larger embedding layers to be trained with DP-SGD over larger physical batch sizes.

## How Has This Been Tested (if it applies)

Unit tests. Runs of large embedding layer training over a real-world DP application.

## Checklist

- [x] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: #694

Reviewed By: EnayatUllah

Differential Revision: D67258669

fbshipit-source-id: a8bbcb3471116edff4c8a0ac7a6c96f38edd1075
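
A quick back-of-the-envelope check of the memory figure quoted in the motivation above (a minimal sketch; the 4 bytes per element assumes float32 gradients):

```python
# Memory needed to materialize per-sample embedding gradients (O(Bnd)),
# using the sizes quoted in the commit message. Assumes float32 (4 bytes).
B = 16_000       # physical batch size
n = 1_000_000    # vocab size
d = 5            # embedding dimension

bytes_needed = B * n * d * 4
print(f"{bytes_needed / 2**30:.2f} GiB")  # 298.02 GiB
```

The O(Br) path added in this commit instead aggregates over unique (sample, index) pairs, which for the example above fits comfortably under 8 GiB.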
1 parent 0d186a4 commit 0eb4b3e

File tree

5 files changed: +547 −2 lines changed


opacus/grad_sample/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from .dp_multihead_attention import compute_sequence_bias_grad_sample  # noqa
 from .dp_rnn import compute_rnn_linear_grad_sample  # noqa
 from .embedding import compute_embedding_grad_sample  # noqa
+from .embedding_norm_sample import compute_embedding_norm_sample  # noqa
 from .grad_sample_module import GradSampleModule, create_or_accumulate_grad_sample
 from .grad_sample_module_fast_gradient_clipping import (  # noqa
     GradSampleModuleFastGradientClipping,

opacus/grad_sample/embedding.py

Lines changed: 24 additions & 2 deletions
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict
+from typing import Dict, List
 
 import torch
 import torch.nn as nn
+from opacus.grad_sample import embedding_norm_sample
 
-from .utils import register_grad_sampler
+from .utils import register_grad_sampler, register_norm_sampler
 
 
 @register_grad_sampler(nn.Embedding)
@@ -82,3 +83,24 @@ def compute_embeddingbag_gradsampler(layer, inputs, backprops):
     ret[layer.weight] = gsm
 
     return ret
+
+
+@register_norm_sampler(nn.Embedding)
+def compute_embedding_norm_sample(
+    layer: nn.Embedding,
+    activations: List[torch.Tensor],
+    backprops: torch.Tensor,
+) -> Dict[nn.Parameter, torch.Tensor]:
+    """Computes per-sample gradient norms for ``nn.Embedding`` layer.
+
+    Args:
+        layer: Layer
+        activations: Activations
+        backprops: Backpropagations
+
+    Returns:
+        A dictionary mapping the layer weight to its per-sample gradient norms
+    """
+    return embedding_norm_sample.compute_embedding_norm_sample(
+        layer, activations, backprops
+    )

opacus/grad_sample/embedding_norm_sample.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# Copyright 2024, The Opacus authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility for computing the per-sample gradient norm for the embedding layer.

Based on the algorithm from the paper:
https://proceedings.neurips.cc/paper_files/paper/2023/file/a45d344b28179c8da7646bc38ff50ad8-Paper-Conference.pdf.
"""
from typing import Dict, List

import torch
from torch import nn


def compute_embedding_norm_sample(
    layer: nn.Embedding,
    activations: List[torch.Tensor],
    backprops: torch.Tensor,
) -> Dict[nn.Parameter, torch.Tensor]:
    """Computes per-sample gradient norms for ``nn.Embedding`` layer.

    Args:
        layer: Layer
        activations: Activations
        backprops: Backpropagations

    Returns:
        A dictionary mapping the layer weight to its per-sample gradient norms

    NOTE: Here is an example input and the expected intermediate values. This
    is provided to help in understanding the algorithm:

    Inputs:
        layer: Embedding(3, 1)  # (vocab_size, embedding_dim)
        activations: [tensor([[1, 1],
                              [2, 0],
                              [2, 0]])]
        backprops: tensor([[[0.2], [0.2]],
                           [[0.3], [0.1]],
                           [[0.3], [0.1]]])
        backprops.shape: torch.Size([3, 2, 1])

    Intermediate values:
        input_ids: tensor([[1, 1],
                           [2, 0],
                           [2, 0]])
        input_ids.shape: torch.Size([3, 2])
        grad_values: tensor([[0.2000],
                             [0.2000],
                             [0.3000],
                             [0.1000],
                             [0.3000],
                             [0.1000]])
        grad_values.shape: torch.Size([6, 1])
        nrows: 3
        ncols: 2
        row_indices: tensor([[0],
                             [0],
                             [1],
                             [1],
                             [2],
                             [2]])
        flattened_indices: tensor([[1],
                                   [1],
                                   [2],
                                   [0],
                                   [2],
                                   [0]])
        paired_indices: tensor([[0, 1],
                                [0, 1],
                                [1, 2],
                                [1, 0],
                                [2, 2],
                                [2, 0]])
        unique_paired_indices: tensor([[0, 1],
                                       [1, 0],
                                       [1, 2],
                                       [2, 0],
                                       [2, 2]])
        new_index_positions: tensor([0, 0, 2, 1, 4, 3])
        num_unique_paired_indices: 5
        summed_gradients: tensor([[0.4000],
                                  [0.1000],
                                  [0.3000],
                                  [0.1000],
                                  [0.3000]])
        sqr_gradient_sum: tensor([0.1600, 0.0100, 0.0900, 0.0100, 0.0900])
        unique_batch_ids: tensor([0, 1, 1, 2, 2])
        result: tensor([0.1600, 0.1000, 0.1000])
        result_sqrt: tensor([0.4000, 0.3162, 0.3162])
    """
    device = activations[0].device
    input_ids = activations[0].to(device)
    grad_values = backprops.to(device)

    # Reshape input_ids, preserving the batch size as the first dimension
    input_ids = input_ids.reshape(input_ids.shape[0], -1)

    # Reshape grad_values, preserving the embedding dimension as the last dimension
    grad_values = grad_values.reshape(-1, grad_values.size(-1))

    # Create a 1D tensor of row indices
    nrows = input_ids.size(0)
    ncols = input_ids.size(1)
    row_indices = (
        torch.repeat_interleave(torch.arange(nrows).to(device), ncols)
        .unsqueeze(-1)
        .to(device)
    )

    # Pair the input IDs with the row indices
    flattened_indices = input_ids.view(-1, 1)
    paired_indices = torch.cat([row_indices, flattened_indices], dim=1).to(device)

    # Get unique paired indices and new index positions for aggregation
    unique_paired_indices, new_index_positions = torch.unique(
        paired_indices, dim=0, return_inverse=True, sorted=True
    )

    # Sum gradients over new index positions and compute squared gradient norms
    num_unique_paired_indices = unique_paired_indices.size(0)
    summed_gradients = torch.zeros(
        num_unique_paired_indices, grad_values.size(-1), device=device
    )
    summed_gradients = summed_gradients.index_add(
        0, new_index_positions.to(device), grad_values
    )
    sqr_gradient_sum = torch.sum(summed_gradients**2, dim=1)

    # Scatter-add the squared sums back to their respective rows
    result = torch.zeros(nrows, device=device)
    unique_batch_ids = unique_paired_indices[:, 0].to(device)
    result.scatter_add_(0, unique_batch_ids, sqr_gradient_sum)

    # Compute the square root for the final result (norm)
    result_sqrt = torch.sqrt(result)
    return {layer.weight: result_sqrt}
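
To make the worked example in the docstring concrete, here is a small self-contained check (a sketch, not part of the commit) that computes the same per-sample norms naively by materializing each per-sample embedding gradient, which is exactly the O(B × vocab_size × dim) cost the fast path avoids. It reproduces `result_sqrt` from the docstring:

```python
import torch

# Naive reference for the docstring example above: build each per-sample
# embedding gradient explicitly and take its Frobenius norm.
vocab_size, embedding_dim = 3, 1
input_ids = torch.tensor([[1, 1], [2, 0], [2, 0]])      # (batch, seq_len)
backprops = torch.tensor([[[0.2], [0.2]],
                          [[0.3], [0.1]],
                          [[0.3], [0.1]]])               # (batch, seq_len, dim)

norms = []
for ids, grads in zip(input_ids, backprops):
    per_sample_grad = torch.zeros(vocab_size, embedding_dim)
    per_sample_grad.index_add_(0, ids, grads)            # accumulate token grads into rows
    norms.append(per_sample_grad.norm())

print(torch.stack(norms))  # tensor([0.4000, 0.3162, 0.3162]), matching result_sqrt
```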

opacus/tests/grad_sample_module_fast_gradient_clipping_test.py

Lines changed: 164 additions & 0 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+import unittest
 
 import hypothesis.strategies as st
 import torch
@@ -67,6 +68,21 @@ def forward(self, x):
         return x
 
 
+class SampleEmbeddingModule(nn.Module):
+    def __init__(self, vocab_size, embedding_dim):
+        super(SampleEmbeddingModule, self).__init__()
+        self.embedding = nn.Embedding(vocab_size, embedding_dim)
+
+        # Manually set weights for the embedding layer for testing
+        self.embedding.weight = nn.Parameter(
+            torch.tensor([[0.1], [0.2], [0.3]], dtype=torch.float32)
+        )
+
+    def forward(self, x):
+        x = self.embedding(x)
+        return x
+
+
 class GradSampleModuleFastGradientClippingTest(GradSampleModuleTest):
     CLS = GradSampleModuleFastGradientClipping
 
@@ -260,3 +276,151 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
     logging.info(f"Diff = {diff}")
     msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
     assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
+
+
+class GradSampleModuleFastGradientClippingEmbeddingLayerTest(unittest.TestCase):
+
+    def test_norm_calculation(self):
+        """
+        Tests if norm calculation for embedding layer is the same between
+        standard (Opacus) and fast gradient clipping.
+        """
+        vocab_size = 3
+        embedding_dim = 1
+
+        criterion = torch.nn.CrossEntropyLoss(reduction="none")
+        noise_multiplier = 0.0
+        input_data = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
+        batch_size = 3
+        max_grad_norm = 1.0
+        sample_module = SampleEmbeddingModule(vocab_size, embedding_dim)
+        model_normal = GradSampleModule(clone_module(sample_module))
+        optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1)
+        optimizer_normal = DPOptimizer(
+            optimizer_normal,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        grad_sample_module = GradSampleModuleFastGradientClipping(
+            clone_module(sample_module),
+            max_grad_norm=max_grad_norm,
+            use_ghost_clipping=True,
+        )
+        optimizer_gc = torch.optim.SGD(grad_sample_module.parameters(), lr=1)
+        optimizer_gc = DPOptimizerFastGradientClipping(
+            optimizer_gc,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        optimizer_normal.zero_grad()
+        output_normal = model_normal(input_data)
+        target_data = torch.rand_like(output_normal)
+
+        loss_normal = torch.mean(criterion(output_normal, target_data), dim=0)
+        loss_normal.backward()
+        all_norms_normal = torch.stack(
+            [
+                torch.stack([g.norm() for g in param.grad_sample], dim=0)
+                for param in model_normal.parameters()
+            ],
+            dim=0,
+        )
+        flat_norms_normal = torch.cat([p.flatten() for p in all_norms_normal])
+
+        grad_sample_module.enable_hooks()
+        output_gc = grad_sample_module(input_data)
+
+        first_loss_per_sample = criterion(output_gc, target_data)
+        first_loss = torch.mean(first_loss_per_sample)
+        first_loss.backward(retain_graph=True)
+
+        optimizer_gc.zero_grad()
+        coeff = grad_sample_module.get_clipping_coef()
+        second_loss_per_sample = coeff * first_loss_per_sample
+        second_loss = torch.sum(second_loss_per_sample)
+        grad_sample_module.disable_hooks()
+        second_loss.backward()
+
+        all_norms_gc = [param._norm_sample for param in grad_sample_module.parameters()]
+        flat_norms_gc = torch.cat([p.flatten() for p in all_norms_gc])
+
+        diff = flat_norms_normal - flat_norms_gc
+
+        logging.info(f"Diff = {diff}")
+        msg = "Fail: Gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
+        assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
+
+    def test_gradient_calculation(self):
+        """Tests if gradients for embedding layer are the same between standard
+        (Opacus) and fast gradient clipping."""
+
+        noise_multiplier = 0.0
+        vocab_size = 3
+        embedding_dim = 1
+        batch_size = 3
+        input_data = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
+        max_grad_norm = 1.0
+        criterion = torch.nn.CrossEntropyLoss()
+
+        sample_module = SampleEmbeddingModule(vocab_size, embedding_dim)
+        model_normal = GradSampleModule(clone_module(sample_module))
+        grad_sample_module = GradSampleModuleFastGradientClipping(
+            clone_module(sample_module),
+            max_grad_norm=max_grad_norm,
+            use_ghost_clipping=True,
+        )
+
+        optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1)
+        optimizer_normal = DPOptimizer(
+            optimizer_normal,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        optimizer_gc = torch.optim.SGD(grad_sample_module.parameters(), lr=1)
+        optimizer_gc = DPOptimizerFastGradientClipping(
+            optimizer_gc,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        criterion_gc = DPLossFastGradientClipping(
+            grad_sample_module, optimizer_gc, criterion
+        )
+
+        optimizer_normal.zero_grad()
+        output_normal = model_normal(input_data)
+        target_data = torch.tensor([[[0.1], [0.1]], [[0.2], [0.3]], [[0.2], [0.3]]])
+        loss_normal = torch.mean(criterion(output_normal, target_data), dim=0)
+        loss_normal.backward()
+        optimizer_normal.step()
+
+        all_grads_normal = [param.summed_grad for param in model_normal.parameters()]
+        flat_grads_normal = torch.cat([p.flatten() for p in all_grads_normal])
+
+        optimizer_gc.zero_grad()
+        grad_sample_module.enable_hooks()
+        output_gc = grad_sample_module(input_data)
+
+        loss_gc = criterion_gc(output_gc, target_data)
+        loss_gc.backward()
+        optimizer_gc.step()
+
+        all_grads_gc = [param.grad for param in grad_sample_module.parameters()]
+        flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])
+        diff = torch.tensor(
+            [
+                (g_gc - g_normal).norm()
+                for (g_gc, g_normal) in zip(flat_grads_gc, flat_grads_normal)
+            ]
+        )
+
+        logging.info(f"Diff = {diff}")
+        msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
+        assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
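
For readers who want to try this outside the test harness, the flow the tests exercise can be condensed into the following sketch. The wrapper classes (`GradSampleModuleFastGradientClipping`, `DPOptimizerFastGradientClipping`, `DPLossFastGradientClipping`) and their arguments are taken from the tests above; the import paths, toy model, data, and hyperparameters are illustrative assumptions, not part of this commit:

```python
import torch
import torch.nn as nn

# NOTE: import paths are assumptions; the tests above already have these names in scope.
from opacus.grad_sample import GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping

# Toy embedding model and data (illustrative sizes).
vocab_size, embedding_dim, batch_size, seq_len = 50_000, 16, 256, 4
model = nn.Embedding(vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()

# Wrap the model so per-sample gradient norms come from the registered norm
# samplers (ghost clipping) instead of materialized per-sample gradients.
gs_module = GradSampleModuleFastGradientClipping(
    model, max_grad_norm=1.0, use_ghost_clipping=True
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(gs_module.parameters(), lr=0.1),
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    expected_batch_size=batch_size,
)
# The DP loss wrapper drives the two backward passes (norm pass + rescaled pass).
dp_criterion = DPLossFastGradientClipping(gs_module, optimizer, criterion)

input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
target = torch.rand(batch_size, seq_len, embedding_dim)  # toy soft target, as in the tests

optimizer.zero_grad()
gs_module.enable_hooks()  # mirrors the tests above
output = gs_module(input_ids)
loss = dp_criterion(output, target)
loss.backward()
optimizer.step()
```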
