Adds fast gradient clipping support for the Embedding layer. (#694)
Summary:
The algorithm used is described in the paper 'A Unified Fast Gradient Clipping Framework for DP-SGD': https://proceedings.neurips.cc/paper_files/paper/2023/file/a45d344b28179c8da7646bc38ff50ad8-Paper-Conference.pdf.
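
In a nutshell (notation here is mine, not taken from the paper or the code): if sample $i$ contains token ids $x_{i,t}$ with output gradients $g_{i,t}$, then row $v$ of the per-sample weight gradient is $\sum_{t:\,x_{i,t}=v} g_{i,t}$, so the per-sample norm can be accumulated over the unique $(i, v)$ pairs without ever materializing the full $n \times d$ per-sample gradient:

$$
\lVert \nabla_W \ell_i \rVert_F^2 \;=\; \sum_{v \in \mathrm{unique}(x_i)} \Big\lVert \sum_{t:\, x_{i,t} = v} g_{i,t} \Big\rVert_2^2
$$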

## Types of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue

Previously, ghost clipping was not supported in Opacus for the embedding layer. With the default DP-SGD implementation, training OOMs on large embedding layers when using the large physical batch sizes that are useful for privacy. Regular DP-SGD needs O(Bnd) memory, where B = physical batch size, n = vocab size, d = embedding dimension. To give a sense of the memory needed: we've seen embeddings with [vocab size = 1,000,000, dim = 5] (and larger) in real-world differential privacy applications. With a physical batch size of 16,000, the memory needed is 16,000 × 1,000,000 × 5 × 4 bytes ≈ 298.02 GB.

With this change, we need significantly less memory: O(Br), where B is the physical batch size and r is the number of unique indices in the embedding sequence. We could successfully run DP-SGD on the above example using < 8 GiB.
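
A quick back-of-the-envelope check of the figures above (a sketch, not part of the change; assumes float32, i.e. 4 bytes per value):

```python
# Dense per-sample gradients for the embedding layer: O(B * n * d) float32 values.
B, n, d = 16_000, 1_000_000, 5          # physical batch size, vocab size, embedding dim
dense_bytes = B * n * d * 4             # 4 bytes per float32 value
print(f"{dense_bytes / 2**30:.2f} GiB")  # -> 298.02, the ~298 GB figure quoted above
```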

This is a valuable addition to Opacus, enabling larger embedding layers to be trained with DP-SGD over larger physical batch sizes.

## How Has This Been Tested (if it applies)

Unit tests, plus training runs over a large embedding layer in a real-world DP application.
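
For reference, a minimal way to run just the new embedding-layer tests from the diff below (a sketch; the dotted module path is my assumption based on the file location, not stated in this PR):

```python
# Load and run only the embedding-layer test case added in this PR.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "opacus.tests.grad_sample_module_fast_gradient_clipping_test."
    "GradSampleModuleFastGradientClippingEmbeddingLayerTest"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```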

## Checklist

- [x] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: #694

Reviewed By: EnayatUllah

Differential Revision: D67258669

fbshipit-source-id: a8bbcb3471116edff4c8a0ac7a6c96f38edd1075
pagarwl authored and facebook-github-bot committed Feb 12, 2025
1 parent 0d186a4 commit 0eb4b3e
Showing 5 changed files with 547 additions and 2 deletions.
1 change: 1 addition & 0 deletions opacus/grad_sample/__init__.py
@@ -17,6 +17,7 @@
from .dp_multihead_attention import compute_sequence_bias_grad_sample # noqa
from .dp_rnn import compute_rnn_linear_grad_sample # noqa
from .embedding import compute_embedding_grad_sample # noqa
from .embedding_norm_sample import compute_embedding_norm_sample # noqa
from .grad_sample_module import GradSampleModule, create_or_accumulate_grad_sample
from .grad_sample_module_fast_gradient_clipping import ( # noqa
GradSampleModuleFastGradientClipping,
26 changes: 24 additions & 2 deletions opacus/grad_sample/embedding.py
@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from typing import Dict, List

import torch
import torch.nn as nn
from opacus.grad_sample import embedding_norm_sample

from .utils import register_grad_sampler
from .utils import register_grad_sampler, register_norm_sampler


@register_grad_sampler(nn.Embedding)
@@ -82,3 +83,24 @@ def compute_embeddingbag_gradsampler(layer, inputs, backprops):
ret[layer.weight] = gsm

return ret


@register_norm_sampler(nn.Embedding)
def compute_embedding_norm_sample(
layer: nn.Embedding,
activations: List[torch.Tensor],
backprops: torch.Tensor,
) -> Dict[nn.Parameter, torch.Tensor]:
"""Computes gradient norms for ``nn.Embedding`` layer.
Args:
layer: Layer
activations: Activations
backprops: Backpropagations
Returns:
A dictionary mapping the layer's weight to its per-sample gradient norms
"""
return embedding_norm_sample.compute_embedding_norm_sample(
layer, activations, backprops
)
148 changes: 148 additions & 0 deletions opacus/grad_sample/embedding_norm_sample.py
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# Copyright 2024, The Opacus authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility for computing gradient norm for the embedding layer.
Based on the algorithm from the paper:
https://proceedings.neurips.cc/paper_files/paper/2023/file/a45d344b28179c8da7646bc38ff50ad8-Paper-Conference.pdf.
"""
from typing import Dict, List

import torch
from torch import nn


def compute_embedding_norm_sample(
layer: nn.Embedding,
activations: List[torch.Tensor],
backprops: torch.Tensor,
) -> Dict[nn.Parameter, torch.Tensor]:
"""Computes per sample gradient norms for ``nn.Embedding`` layer.
Args:
layer: Layer
activations: Activations
backprops: Backpropagations
Returns:
A dictionary mapping the layer's weight to its per-sample gradient norms
NOTE: Here is an example input, and the expected intermediate values. This
is provided to help in understanding the algorithm:
Inputs:
layer: Embedding(3, 1) # (vocab_size, embedding_dim)
activations: [tensor([[1, 1],
[2, 0],
[2, 0]])]
backprops: tensor([[[0.2], [0.2]],
[[0.3], [0.1]],
[[0.3], [0.1]]])
backprops.shape: torch.Size([3, 2, 1])
Intermediate values:
input_ids: tensor([[1, 1],
[2, 0],
[2, 0]])
input_ids.shape: torch.Size([3, 2])
grad_values: tensor([[0.2000],
[0.2000],
[0.3000],
[0.1000],
[0.3000],
[0.1000]])
grad_values.shape: torch.Size([6, 1])
nrows: 3
ncols: 2
row_indices: tensor([[0],
[0],
[1],
[1],
[2],
[2]])
flattened_indices: tensor([[1],
[1],
[2],
[0],
[2],
[0]])
paired_indices: tensor([[0, 1],
[0, 1],
[1, 2],
[1, 0],
[2, 2],
[2, 0]])
unique_paired_indices: tensor([[0, 1],
[1, 0],
[1, 2],
[2, 0],
[2, 2]])
new_index_positions: tensor([0, 0, 2, 1, 4, 3])
num_unique_paired_indices: 5
summed_gradients: tensor([[0.4000],
[0.1000],
[0.3000],
[0.1000],
[0.3000]])
sqr_gradient_sum: tensor([0.1600, 0.0100, 0.0900, 0.0100, 0.0900])
unique_batch_ids: tensor([0, 1, 1, 2, 2])
result: tensor([0.1600, 0.1000, 0.1000])
result_sqrt: tensor([0.4000, 0.3162, 0.3162])
"""
device = activations[0].device
input_ids = activations[0].to(device)
grad_values = backprops.to(device)

# Reshape input_ids preserving the batch size as the first dimension
input_ids = input_ids.reshape(input_ids.shape[0], -1)

# Reshape grad_values preserving the embedding dimension as the last dimension
grad_values = grad_values.reshape(-1, grad_values.size(-1))

# Create 1D tensor of row indices
nrows = input_ids.size(0)
ncols = input_ids.size(1)
row_indices = (
torch.repeat_interleave(torch.arange(nrows).to(device), ncols)
.unsqueeze(-1)
.to(device)
)

# Pair the input IDs with the row indices
flattened_indices = input_ids.view(-1, 1)
paired_indices = torch.cat([row_indices, flattened_indices], dim=1).to(device)

# Get unique paired indices and new index positions for aggregation
unique_paired_indices, new_index_positions = torch.unique(
paired_indices, dim=0, return_inverse=True, sorted=True
)

# Sum gradients over new index positions and compute squared gradient norms
num_unique_paired_indices = unique_paired_indices.size(0)
summed_gradients = torch.zeros(
num_unique_paired_indices, grad_values.size(-1), device=device
)
summed_gradients = summed_gradients.index_add(
0, new_index_positions.to(device), grad_values
)
sqr_gradient_sum = torch.sum(summed_gradients**2, dim=1)

# Scatter add the squared sums back to their respective rows
result = torch.zeros(nrows, device=device)
unique_batch_ids = unique_paired_indices[:, 0].to(device)
result.scatter_add_(0, unique_batch_ids, sqr_gradient_sum)

# Compute the square root for the final result (norm)
result_sqrt = torch.sqrt(result)
return {layer.weight: result_sqrt}
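
To make the docstring walk-through above concrete, here is a minimal sketch (not part of the commit) that replays that exact example through the new helper; it assumes a build of Opacus that includes this file:

```python
# Replay the docstring example: Embedding(3, 1) with the activations/backprops above.
import torch
from torch import nn
from opacus.grad_sample import embedding_norm_sample

layer = nn.Embedding(3, 1)  # (vocab_size, embedding_dim); weight values don't affect the norms
activations = [torch.tensor([[1, 1], [2, 0], [2, 0]])]
backprops = torch.tensor([[[0.2], [0.2]], [[0.3], [0.1]], [[0.3], [0.1]]])

norms = embedding_norm_sample.compute_embedding_norm_sample(layer, activations, backprops)
print(norms[layer.weight])  # expected: tensor([0.4000, 0.3162, 0.3162])
```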
164 changes: 164 additions & 0 deletions opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
@@ -14,6 +14,7 @@
# limitations under the License.

import logging
import unittest

import hypothesis.strategies as st
import torch
@@ -67,6 +68,21 @@ def forward(self, x):
return x


class SampleEmbeddingModule(nn.Module):
def __init__(self, vocab_size, embedding_dim):
super(SampleEmbeddingModule, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)

# Manually set weights for the embedding layer for testing
self.embedding.weight = nn.Parameter(
torch.tensor([[0.1], [0.2], [0.3]], dtype=torch.float32)
)

def forward(self, x):
x = self.embedding(x)
return x


class GradSampleModuleFastGradientClippingTest(GradSampleModuleTest):
CLS = GradSampleModuleFastGradientClipping

@@ -260,3 +276,151 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
logging.info(f"Diff = {diff}")
msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg


class GradSampleModuleFastGradientClippingEmbeddingLayerTest(unittest.TestCase):

def test_norm_calculation(self):
"""
Tests if norm calculation for embedding layer is the same between
standard (Opacus) and fast gradient clipping.
"""
vocab_size = 3
embedding_dim = 1

criterion = torch.nn.CrossEntropyLoss(reduction="none")
noise_multiplier = 0.0
input_data = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
batch_size = 3
max_grad_norm = 1.0
sample_module = SampleEmbeddingModule(vocab_size, embedding_dim)
model_normal = GradSampleModule(clone_module(sample_module))
optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1)
optimizer_normal = DPOptimizer(
optimizer_normal,
noise_multiplier=noise_multiplier,
max_grad_norm=max_grad_norm,
expected_batch_size=batch_size,
)

grad_sample_module = GradSampleModuleFastGradientClipping(
clone_module(sample_module),
max_grad_norm=max_grad_norm,
use_ghost_clipping=True,
)
optimizer_gc = torch.optim.SGD(grad_sample_module.parameters(), lr=1)
optimizer_gc = DPOptimizerFastGradientClipping(
optimizer_gc,
noise_multiplier=noise_multiplier,
max_grad_norm=max_grad_norm,
expected_batch_size=batch_size,
)

optimizer_normal.zero_grad()
output_normal = model_normal(input_data)
target_data = torch.rand_like(output_normal)

loss_normal = torch.mean(criterion(output_normal, target_data), dim=0)
loss_normal.backward()
all_norms_normal = torch.stack(
[
torch.stack([g.norm() for g in param.grad_sample], dim=0)
for param in model_normal.parameters()
],
dim=0,
)
flat_norms_normal = torch.cat([p.flatten() for p in all_norms_normal])

grad_sample_module.enable_hooks()
output_gc = grad_sample_module(input_data)

first_loss_per_sample = criterion(output_gc, target_data)
first_loss = torch.mean(first_loss_per_sample)
first_loss.backward(retain_graph=True)

optimizer_gc.zero_grad()
coeff = grad_sample_module.get_clipping_coef()
second_loss_per_sample = coeff * first_loss_per_sample
second_loss = torch.sum(second_loss_per_sample)
grad_sample_module.disable_hooks()
second_loss.backward()

all_norms_gc = [param._norm_sample for param in grad_sample_module.parameters()]
flat_norms_gc = torch.cat([p.flatten() for p in all_norms_gc])

diff = flat_norms_normal - flat_norms_gc

logging.info(f"Diff = {diff}")
msg = "Fail: Gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg

def test_gradient_calculation(self):
"""Tests if gradients for embedding layer are the same between standard
(Opacus) and fast gradient clipping."""

noise_multiplier = 0.0
vocab_size = 3
embedding_dim = 1
batch_size = 3
input_data = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
max_grad_norm = 1.0
criterion = torch.nn.CrossEntropyLoss()

sample_module = SampleEmbeddingModule(vocab_size, embedding_dim)
model_normal = GradSampleModule(clone_module(sample_module))
grad_sample_module = GradSampleModuleFastGradientClipping(
clone_module(sample_module),
max_grad_norm=max_grad_norm,
use_ghost_clipping=True,
)

optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1)
optimizer_normal = DPOptimizer(
optimizer_normal,
noise_multiplier=noise_multiplier,
max_grad_norm=max_grad_norm,
expected_batch_size=batch_size,
)

optimizer_gc = torch.optim.SGD(grad_sample_module.parameters(), lr=1)
optimizer_gc = DPOptimizerFastGradientClipping(
optimizer_gc,
noise_multiplier=noise_multiplier,
max_grad_norm=max_grad_norm,
expected_batch_size=batch_size,
)

criterion_gc = DPLossFastGradientClipping(
grad_sample_module, optimizer_gc, criterion
)

optimizer_normal.zero_grad()
output_normal = model_normal(input_data)
target_data = torch.tensor([[[0.1], [0.1]], [[0.2], [0.3]], [[0.2], [0.3]]])
loss_normal = torch.mean(criterion(output_normal, target_data), dim=0)
loss_normal.backward()
optimizer_normal.step()

all_grads_normal = [param.summed_grad for param in model_normal.parameters()]
flat_grads_normal = torch.cat([p.flatten() for p in all_grads_normal])

optimizer_gc.zero_grad()
grad_sample_module.enable_hooks()
output_gc = grad_sample_module(input_data)

loss_gc = criterion_gc(output_gc, target_data)
loss_gc.backward()
optimizer_gc.step()

all_grads_gc = [param.grad for param in grad_sample_module.parameters()]
flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])
diff = torch.tensor(
[
(g_gc - g_normal).norm()
for (g_gc, g_normal) in zip(flat_grads_gc, flat_grads_normal)
]
)

logging.info(f"Diff = {diff}")
msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg