 # limitations under the License.

 import logging
+import unittest

 import hypothesis.strategies as st
 import torch
@@ -67,6 +68,21 @@ def forward(self, x):
         return x


+class SampleEmbeddingModule(nn.Module):
+    def __init__(self, vocab_size, embedding_dim):
+        super(SampleEmbeddingModule, self).__init__()
+        self.embedding = nn.Embedding(vocab_size, embedding_dim)
+
+        # Manually set weights for the embedding layer for testing
+        self.embedding.weight = nn.Parameter(
+            torch.tensor([[0.1], [0.2], [0.3]], dtype=torch.float32)
+        )
+
+    def forward(self, x):
+        x = self.embedding(x)
+        return x
+
+
 class GradSampleModuleFastGradientClippingTest(GradSampleModuleTest):
     CLS = GradSampleModuleFastGradientClipping

@@ -260,3 +276,151 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         logging.info(f"Diff = {diff}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
+
+
+class GradSampleModuleFastGradientClippingEmbeddingLayerTest(unittest.TestCase):
+
+    def test_norm_calculation(self):
+        """
+        Tests if norm calculation for the embedding layer is the same between
+        standard (Opacus) and fast gradient clipping.
+        """
+        vocab_size = 3
+        embedding_dim = 1
+
+        criterion = torch.nn.CrossEntropyLoss(reduction="none")
+        noise_multiplier = 0.0
+        input_data = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
+        batch_size = 3
+        max_grad_norm = 1.0
+        sample_module = SampleEmbeddingModule(vocab_size, embedding_dim)
+        model_normal = GradSampleModule(clone_module(sample_module))
+        optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1)
+        optimizer_normal = DPOptimizer(
+            optimizer_normal,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        grad_sample_module = GradSampleModuleFastGradientClipping(
+            clone_module(sample_module),
+            max_grad_norm=max_grad_norm,
+            use_ghost_clipping=True,
+        )
+        optimizer_gc = torch.optim.SGD(grad_sample_module.parameters(), lr=1)
+        optimizer_gc = DPOptimizerFastGradientClipping(
+            optimizer_gc,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        optimizer_normal.zero_grad()
+        output_normal = model_normal(input_data)
+        target_data = torch.rand_like(output_normal)
+
+        loss_normal = torch.mean(criterion(output_normal, target_data), dim=0)
+        loss_normal.backward()
+        all_norms_normal = torch.stack(
+            [
+                torch.stack([g.norm() for g in param.grad_sample], dim=0)
+                for param in model_normal.parameters()
+            ],
+            dim=0,
+        )
+        flat_norms_normal = torch.cat([p.flatten() for p in all_norms_normal])
+
+        grad_sample_module.enable_hooks()
+        output_gc = grad_sample_module(input_data)
+
+        first_loss_per_sample = criterion(output_gc, target_data)
+        first_loss = torch.mean(first_loss_per_sample)
+        first_loss.backward(retain_graph=True)
+
+        optimizer_gc.zero_grad()
+        coeff = grad_sample_module.get_clipping_coef()
+        second_loss_per_sample = coeff * first_loss_per_sample
+        second_loss = torch.sum(second_loss_per_sample)
+        grad_sample_module.disable_hooks()
+        second_loss.backward()
+
+        all_norms_gc = [param._norm_sample for param in grad_sample_module.parameters()]
+        flat_norms_gc = torch.cat([p.flatten() for p in all_norms_gc])
+
+        diff = flat_norms_normal - flat_norms_gc
+
+        logging.info(f"Diff = {diff}")
+        msg = "Fail: Gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
+        assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
+
+    def test_gradient_calculation(self):
+        """Tests if gradients for embedding layer are the same between standard
+        (Opacus) and fast gradient clipping."""
+
+        noise_multiplier = 0.0
+        vocab_size = 3
+        embedding_dim = 1
+        batch_size = 3
+        input_data = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
+        max_grad_norm = 1.0
+        criterion = torch.nn.CrossEntropyLoss()
+
+        sample_module = SampleEmbeddingModule(vocab_size, embedding_dim)
+        model_normal = GradSampleModule(clone_module(sample_module))
+        grad_sample_module = GradSampleModuleFastGradientClipping(
+            clone_module(sample_module),
+            max_grad_norm=max_grad_norm,
+            use_ghost_clipping=True,
+        )
+
+        optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1)
+        optimizer_normal = DPOptimizer(
+            optimizer_normal,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        optimizer_gc = torch.optim.SGD(grad_sample_module.parameters(), lr=1)
+        optimizer_gc = DPOptimizerFastGradientClipping(
+            optimizer_gc,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        criterion_gc = DPLossFastGradientClipping(
+            grad_sample_module, optimizer_gc, criterion
+        )
+
+        optimizer_normal.zero_grad()
+        output_normal = model_normal(input_data)
+        target_data = torch.tensor([[[0.1], [0.1]], [[0.2], [0.3]], [[0.2], [0.3]]])
+        loss_normal = torch.mean(criterion(output_normal, target_data), dim=0)
+        loss_normal.backward()
+        optimizer_normal.step()
+
+        all_grads_normal = [param.summed_grad for param in model_normal.parameters()]
+        flat_grads_normal = torch.cat([p.flatten() for p in all_grads_normal])
+
+        optimizer_gc.zero_grad()
+        grad_sample_module.enable_hooks()
+        output_gc = grad_sample_module(input_data)
+
+        loss_gc = criterion_gc(output_gc, target_data)
+        loss_gc.backward()
+        optimizer_gc.step()
+
+        all_grads_gc = [param.grad for param in grad_sample_module.parameters()]
+        flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])
+        diff = torch.tensor(
+            [
+                (g_gc - g_normal).norm()
+                for (g_gc, g_normal) in zip(flat_grads_gc, flat_grads_normal)
+            ]
+        )
+
+        logging.info(f"Diff = {diff}")
+        msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
+        assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
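Note: the two new tests drive the second, rescaled-loss backward pass differently. `test_norm_calculation` performs the two passes by hand (first backward with `retain_graph=True`, then `get_clipping_coef()`, rescaling, `disable_hooks()`, second backward), while `test_gradient_calculation` delegates that sequence to `DPLossFastGradientClipping`. The sketch below condenses the wrapper-based flow into a standalone training step; it mirrors the diff above, but the import paths and the bare `nn.Embedding` used in place of `SampleEmbeddingModule` are assumptions, not part of this PR.

# Minimal sketch of the wrapper-based fast gradient clipping step exercised by
# test_gradient_calculation. Import paths are assumed from the Opacus layout.
import torch
import torch.nn as nn

from opacus.grad_sample import GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping

# Stand-in for SampleEmbeddingModule: a bare embedding layer and the same toy batch.
model = nn.Embedding(3, 1)
inputs = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
targets = torch.tensor([[[0.1], [0.1]], [[0.2], [0.3]], [[0.2], [0.3]]])
criterion = nn.CrossEntropyLoss()

# Wrap the module so per-sample gradient norms are computed via ghost clipping.
gs_model = GradSampleModuleFastGradientClipping(
    model, max_grad_norm=1.0, use_ghost_clipping=True
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(gs_model.parameters(), lr=1),
    noise_multiplier=0.0,
    max_grad_norm=1.0,
    expected_batch_size=3,
)
# The loss wrapper runs both backward passes internally: one to obtain
# per-sample clipping coefficients, one on the coefficient-rescaled loss.
criterion_gc = DPLossFastGradientClipping(gs_model, optimizer, criterion)

optimizer.zero_grad()
loss = criterion_gc(gs_model(inputs), targets)
loss.backward()
optimizer.step()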