Commit 1f8e900

NarineK authored and facebook-github-bot committed
Rename use_autograd_hacks to sample_wise_grads_per_batch (#879)
Summary: Pull Request resolved: #879. This diff renames `use_autograd_hacks` to `sample_wise_grads_per_batch` and updates the docs for `sample_wise_grads_per_batch`. Reviewed By: 99warriors. Differential Revision: D34506676. fbshipit-source-id: e60c2d6c09915f31406208dd0946e996b1d5719a
1 parent b30d69b · commit 1f8e900

5 files changed: +113 −75 lines changed
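For callers, the rename is a one-line change. A before/after sketch, with the other constructor arguments abbreviated (net, dataset, and tmpdir are placeholders):

    # before this commit
    tracin = TracInCP(net, dataset, tmpdir, use_autograd_hacks=True)
    # after this commit
    tracin = TracInCP(net, dataset, tmpdir, sample_wise_grads_per_batch=True)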

captum/_utils/gradient.py (+4 −3)

@@ -772,7 +772,7 @@ def _compute_jacobian_wrt_params(
     return tuple(grads)
 
 
-def _compute_jacobian_wrt_params_autograd_hacks(
+def _compute_jacobian_wrt_params_with_sample_wise_trick(
     model: Module,
     inputs: Tuple[Any, ...],
     labels: Optional[Tensor] = None,
@@ -781,8 +781,9 @@ def _compute_jacobian_wrt_params_autograd_hacks(
 ) -> Tuple[Any, ...]:
     r"""
     Computes the Jacobian of a batch of test examples given a model, and optional
-    loss function and target labels. This method uses autograd_hacks to fully vectorize
-    the Jacobian calculation. Currently, only linear and conv2d layers are supported.
+    loss function and target labels. This method uses the sample-wise gradients per
+    batch trick to fully vectorize the Jacobian calculation. Currently, only
+    linear and conv2d layers are supported.
 
     User must `add_hooks(model)` before calling this function.
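A minimal sketch of the trick this docstring names, for a single nn.Linear layer (per_sample_linear_grads is an illustrative name, not a Captum function): hooks capture the layer's input activations and the gradient w.r.t. its output, and one einsum turns them into a weight gradient per sample from a single backward pass.

    import torch
    import torch.nn as nn

    def per_sample_linear_grads(layer: nn.Linear, inputs, targets):
        # Hooks record the layer input (forward) and the gradient w.r.t.
        # the layer output (backward); add_hooks(model) plays this role
        # in the Captum implementation.
        captured = {}
        fwd = layer.register_forward_hook(
            lambda mod, inp, out: captured.update(a=inp[0].detach())
        )
        bwd = layer.register_full_backward_hook(
            lambda mod, gin, gout: captured.update(g=gout[0].detach())
        )
        loss = nn.MSELoss(reduction="sum")(layer(inputs), targets)
        loss.backward()  # one backward pass for the whole batch
        fwd.remove()
        bwd.remove()
        # Per-sample outer product: shape (batch, out_features, in_features).
        return torch.einsum("bo,bi->boi", captured["g"], captured["a"])

Summing the returned tensor over the batch dimension reproduces layer.weight.grad, which is a convenient sanity check for this kind of trick.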

captum/influence/_core/tracincp.py (+18 −9)

@@ -9,7 +9,7 @@
 from captum._utils.av import AV
 from captum._utils.gradient import (
     _compute_jacobian_wrt_params,
-    _compute_jacobian_wrt_params_autograd_hacks,
+    _compute_jacobian_wrt_params_with_sample_wise_trick,
 )
 from captum.influence._core.influence import DataInfluence
 from captum.influence._utils.common import (
@@ -347,7 +347,7 @@ def __init__(
         layers: Optional[List[str]] = None,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
-        use_autograd_hacks: bool = False,
+        sample_wise_grads_per_batch: bool = False,
     ) -> None:
         r"""
         Args:
@@ -396,9 +396,18 @@ def __init__(
                 `influence_src_dataset` is a Dataset. If `influence_src_dataset`
                 is a DataLoader, then `batch_size` is ignored as an argument.
                 Default: 1
-            use_autograd_hacks (bool, optional): Experimental mode that vectorize
-                jacobian computation w.r.t parameters for a batch of inputs. Based
-                on support in autograd_hacks.
+            sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient
+                computations w.r.t. model parameters aggregate the results for a
+                batch and do not allow access to sample-wise gradients w.r.t.
+                model parameters. This forces us to iterate over each sample in
+                the batch if we want sample-wise gradients, which is
+                computationally inefficient. We offer an implementation of
+                batch-wise gradient computations w.r.t. model parameters which is
+                computationally more efficient. This implementation can be enabled
+                by setting the `sample_wise_grads_per_batch` argument to `True`.
+                Note that our current implementation enables batch-wise gradient
+                computations only for a limited number of PyTorch nn.Modules:
+                Conv2d and Linear. This list will be expanded in the near future.
                 Default: False
         """
@@ -412,10 +421,10 @@ def __init__(
             batch_size,
         )
 
-        self.use_autograd_hacks = use_autograd_hacks
+        self.sample_wise_grads_per_batch = sample_wise_grads_per_batch
 
         if (
-            self.use_autograd_hacks
+            self.sample_wise_grads_per_batch
             and isinstance(loss_fn, Module)  # TODO: allow loss_fn to be Callable
             and hasattr(loss_fn, "reduction")
         ):
@@ -644,8 +653,8 @@ def _basic_computation_tracincp(
         targets (tensor or None): If computing influence scores on a loss function,
             these are the labels corresponding to the batch `inputs`.
         """
-        if self.use_autograd_hacks:
-            return _compute_jacobian_wrt_params_autograd_hacks(
+        if self.sample_wise_grads_per_batch:
+            return _compute_jacobian_wrt_params_with_sample_wise_trick(
                 self.model,
                 inputs,
                 targets,
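A hedged usage sketch of the renamed argument (the import path, dataset, checkpoint directory, and test tensors are placeholders, not taken from this diff); note that the tests in this commit pair the flag with a reduction="sum" loss:

    import torch.nn as nn
    from captum.influence import TracInCP  # assumed public import path

    net = nn.Linear(20, 1)  # only Linear and Conv2d layers are supported
    tracin = TracInCP(
        net,
        train_dataset,    # placeholder: a Dataset of (input, label) pairs
        checkpoints_dir,  # placeholder: directory of saved model checkpoints
        batch_size=16,
        loss_fn=nn.MSELoss(reduction="sum"),
        sample_wise_grads_per_batch=True,  # the renamed flag
    )
    scores = tracin.influence(test_inputs, test_labels)  # placeholders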

tests/influence/_core/test_tracincp.py (+13 −13)

@@ -7,10 +7,10 @@
     _TestTracInRegression20DCheckIdx,
     _TestTracInXORCheckIdx,
     _TestTracInIdentityRegressionCheckIdx,
-    _TestTracInRegression1DCheckAutogradHacks,
-    _TestTracInRegression20DCheckAutogradHacks,
-    _TestTracInXORCheckAutogradHacks,
-    _TestTracInIdentityRegressionCheckAutogradHacks,
+    _TestTracInRegression1DCheckSampleWiseTrick,
+    _TestTracInRegression20DCheckSampleWiseTrick,
+    _TestTracInXORCheckSampleWiseTrick,
+    _TestTracInIdentityRegressionCheckSampleWiseTrick,
     _TestTracInRegression1DNumerical,
     _TestTracInGetKMostInfluential,
     _TestTracInSelfInfluence,
@@ -38,29 +38,29 @@ def setUp(self):
                     tmpdir,
                     batch_size=batch_size,
                     loss_fn=loss_fn,
-                    use_autograd_hacks=False,
+                    sample_wise_grads_per_batch=False,
                 )
             )
         super(TestTracInCP, self).setUp()
 
 
-class TestTracInCPCheckAutogradHacks(
-    _TestTracInRegression1DCheckAutogradHacks,
-    _TestTracInRegression20DCheckAutogradHacks,
-    _TestTracInXORCheckAutogradHacks,
-    _TestTracInIdentityRegressionCheckAutogradHacks,
+class TestTracInCPCheckSampleWiseTrick(
+    _TestTracInRegression1DCheckSampleWiseTrick,
+    _TestTracInRegression20DCheckSampleWiseTrick,
+    _TestTracInXORCheckSampleWiseTrick,
+    _TestTracInIdentityRegressionCheckSampleWiseTrick,
     BaseTest,
 ):
     def setUp(self):
         self.tracin_constructor = (
-            lambda net, dataset, tmpdir, batch_size, loss_fn, use_autograd_hacks: (
+            lambda net, dataset, tmpdir, batch_size, loss_fn, sample_wise_trick: (
                 TracInCP(
                     net,
                     dataset,
                     tmpdir,
                     batch_size=batch_size,
                     loss_fn=loss_fn,
-                    use_autograd_hacks=use_autograd_hacks,
+                    sample_wise_grads_per_batch=sample_wise_trick,
                 )
             )
         )
@@ -82,6 +82,6 @@ def setUp(self):
                     tmpdir,
                     batch_size=batch_size,
                     loss_fn=loss_fn,
-                    use_autograd_hacks=True,
+                    sample_wise_grads_per_batch=True,
                 )
             )
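The constructor lambdas above are the only place the flag name appears; the shared mixins call the factory without knowing which variant they build, which is why this rename touches only the factories. A condensed sketch of that shape (the import path is an assumption):

    from captum.influence import TracInCP  # assumed public import path

    def make_tracin(net, dataset, tmpdir, batch_size, loss_fn, sample_wise_trick):
        # Shared test mixins receive this factory and pass the final flag
        # positionally; only this wrapper knows the keyword's new name.
        return TracInCP(
            net,
            dataset,
            tmpdir,
            batch_size=batch_size,
            loss_fn=loss_fn,
            sample_wise_grads_per_batch=sample_wise_trick,
        )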

tests/influence/_utils/common.py (+45 −29)

@@ -215,7 +215,7 @@ def _test_tracin_regression(self, features: int, mode: int) -> None:
         for i in range(len(idx)):
             self.assertTrue(isSorted(idx[i]))
 
-        if mode == "check_autograd_hacks":
+        if mode == "sample_wise_trick":
 
             criterion = nn.MSELoss(reduction="none")
 
@@ -228,39 +228,47 @@ def _test_tracin_regression(self, features: int, mode: int) -> None:
                 False,
             )
 
-            # With autograd hacks
+            # With sample-wise trick
             criterion = nn.MSELoss(reduction="sum")
-            tracin_hack = self.tracin_constructor(
+            tracin_sample_wise_trick = self.tracin_constructor(
                 net, dataset, tmpdir, batch_size, criterion, True
             )
 
             train_scores = tracin.influence(train_inputs, train_labels)
-            train_scores_hack = tracin_hack.influence(train_inputs, train_labels)
-            assertTensorAlmostEqual(self, train_scores, train_scores_hack)
+            train_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
+                train_inputs, train_labels
+            )
+            assertTensorAlmostEqual(
+                self, train_scores, train_scores_sample_wise_trick
+            )
 
             test_scores = tracin.influence(test_inputs, test_labels)
-            test_scores_hack = tracin_hack.influence(test_inputs, test_labels)
-            assertTensorAlmostEqual(self, test_scores, test_scores_hack)
+            test_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
+                test_inputs, test_labels
+            )
+            assertTensorAlmostEqual(
+                self, test_scores, test_scores_sample_wise_trick
+            )
 
 
 class _TestTracInRegression1DCheckIdx(_TestTracInRegression):
     def test_tracin_regression_1D_check_idx(self):
         self._test_tracin_regression(1, "check_idx")
 
 
-class _TestTracInRegression1DCheckAutogradHacks(_TestTracInRegression):
-    def test_tracin_regression_1D_check_autograd_hacks(self):
-        self._test_tracin_regression(1, "check_autograd_hacks")
+class _TestTracInRegression1DCheckSampleWiseTrick(_TestTracInRegression):
+    def test_tracin_regression_1D_check_sample_wise_trick(self):
+        self._test_tracin_regression(1, "sample_wise_trick")
 
 
 class _TestTracInRegression20DCheckIdx(_TestTracInRegression):
     def test_tracin_regression_20D_check_idx(self):
         self._test_tracin_regression(20, "check_idx")
 
 
-class _TestTracInRegression20DCheckAutogradHacks(_TestTracInRegression):
-    def test_tracin_regression_20D_check_autograd_hacks(self):
-        self._test_tracin_regression(20, "check_autograd_hacks")
+class _TestTracInRegression20DCheckSampleWiseTrick(_TestTracInRegression):
+    def test_tracin_regression_20D_check_sample_wise_trick(self):
+        self._test_tracin_regression(20, "sample_wise_trick")
 
 
 class _TestTracInXOR:
@@ -434,7 +442,7 @@ def _test_tracin_xor(self, mode) -> None:
             influence_labels = dataset.labels[idx[i][0:5], 0]
             self.assertTrue(torch.all(testlabels[i, 0] == influence_labels))
 
-        if mode == "check_autograd_hacks":
+        if mode == "sample_wise_trick":
 
             criterion = nn.MSELoss(reduction="none")
 
@@ -447,9 +455,9 @@ def _test_tracin_xor(self, mode) -> None:
                 False,
             )
 
-            # With autograd hacks
+            # With sample-wise trick
             criterion = nn.MSELoss(reduction="sum")
-            tracin_hack = self.tracin_constructor(
+            tracin_sample_wise_trick = self.tracin_constructor(
                 net,
                 dataset,
                 tmpdir,
@@ -459,18 +467,22 @@ def _test_tracin_xor(self, mode) -> None:
             )
 
             test_scores = tracin.influence(testset, testlabels)
-            test_scores_hack = tracin_hack.influence(testset, testlabels)
-            assertTensorAlmostEqual(self, test_scores, test_scores_hack)
+            test_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
+                testset, testlabels
+            )
+            assertTensorAlmostEqual(
+                self, test_scores, test_scores_sample_wise_trick
+            )
 
 
 class _TestTracInXORCheckIdx(_TestTracInXOR):
     def test_tracin_xor_check_idx(self):
         self._test_tracin_xor("check_idx")
 
 
-class _TestTracInXORCheckAutogradHacks(_TestTracInXOR):
-    def test_tracin_xor_check_autograd_hacks(self):
-        self._test_tracin_xor("check_autograd_hacks")
+class _TestTracInXORCheckSampleWiseTrick(_TestTracInXOR):
+    def test_tracin_xor_check_sample_wise_trick(self):
+        self._test_tracin_xor("sample_wise_trick")
 
 
 class _TestTracInIdentityRegression:
@@ -537,7 +549,7 @@ def _test_tracin_identity_regression(self, mode) -> None:
         for i in range(len(idx)):
             self.assertEqual(idx[i][0], i)
 
-        if mode == "check_autograd_hacks":
+        if mode == "sample_wise_trick":
 
             criterion = nn.MSELoss(reduction="none")
 
@@ -550,9 +562,9 @@ def _test_tracin_identity_regression(self, mode) -> None:
                 False,
             )
 
-            # With autograd hacks
+            # With sample-wise trick
             criterion = nn.MSELoss(reduction="sum")
-            tracin_hack = self.tracin_constructor(
+            tracin_sample_wise_trick = self.tracin_constructor(
                 net,
                 dataset,
                 tmpdir,
@@ -562,18 +574,22 @@ def _test_tracin_identity_regression(self, mode) -> None:
             )
 
             train_scores = tracin.influence(train_inputs, train_labels)
-            train_scores_hack = tracin_hack.influence(train_inputs, train_labels)
-            assertTensorAlmostEqual(self, train_scores, train_scores_hack)
+            train_scores_tracin_sample_wise_trick = (
+                tracin_sample_wise_trick.influence(train_inputs, train_labels)
+            )
+            assertTensorAlmostEqual(
+                self, train_scores, train_scores_tracin_sample_wise_trick
+            )
 
 
 class _TestTracInIdentityRegressionCheckIdx(_TestTracInIdentityRegression):
     def test_tracin_identity_regression_check_idx(self):
         self._test_tracin_identity_regression("check_idx")
 
 
-class _TestTracInIdentityRegressionCheckAutogradHacks(_TestTracInIdentityRegression):
-    def test_tracin_identity_regression_check_autograd_hacks(self):
-        self._test_tracin_identity_regression("check_autograd_hacks")
+class _TestTracInIdentityRegressionCheckSampleWiseTrick(_TestTracInIdentityRegression):
+    def test_tracin_identity_regression_check_sample_wise_trick(self):
+        self._test_tracin_identity_regression("sample_wise_trick")
 
 
 class _TestTracInRandomProjectionRegression:
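These tests all follow one pattern: build a baseline TracInCP with reduction="none" and the trick disabled, build a second with reduction="sum" and the trick enabled, and assert the influence scores agree. A standalone sketch of the identity that pattern rests on (toy model and data, not Captum code): the gradient of a summed batch loss equals the sum of the per-sample loss gradients.

    import torch
    import torch.nn as nn

    net = nn.Linear(3, 1)
    x, y = torch.randn(4, 3), torch.randn(4, 1)

    # Baseline: per-sample gradients via an explicit loop, one backward per sample.
    per_sample = []
    for i in range(4):
        net.zero_grad()
        nn.MSELoss(reduction="sum")(net(x[i : i + 1]), y[i : i + 1]).backward()
        per_sample.append(net.weight.grad.clone())

    # One backward pass over the summed batch loss recovers the same total.
    net.zero_grad()
    nn.MSELoss(reduction="sum")(net(x), y).backward()
    assert torch.allclose(
        net.weight.grad, torch.stack(per_sample).sum(dim=0), atol=1e-5
    )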
