
Commit 8cc99fe

99warriors authored and facebook-github-bot committed
require reduction = "sum" (#880)
Summary:
Pull Request resolved: #880

- require reduction of `loss_fn` for TracInCP to be "sum" if `sample_wise_grads_per_batch == True`, by changing documentation and assertions
- require reduction of `loss_fn` for TracInCPFast and TracInCPFastRandProj to always be "sum", by changing documentation and assertions

Although reduction can technically be "mean" for TracInCP, I thought it would be simpler for the user if we just told them it has to be "sum". Otherwise, we would allow "sum" and "mean" for TracInCP, but only "sum" for TracInCPFast and TracInCPFastRandProj, which could be confusing. In the future, we can allow reduction to be "mean" for TracInCPFast and TracInCPFastRandProj. Then, we can allow both reductions for all classes.

This diff replaces D34354958, which changed files that have since been deleted.

Reviewed By: NarineK

Differential Revision: D34533617

fbshipit-source-id: 4beeb941e4019addd81adc871e2f63b9e41dd1a3
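For orientation, here is a minimal usage sketch of the convention this commit enforces. The `loss_fn`, `batch_size`, and `sample_wise_grads_per_batch` argument names come from the docstrings diffed below; the import path, the positional `checkpoints` argument, and the toy model/dataset are assumptions for illustration, not part of this commit.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from captum.influence import TracInCP  # import path assumed

# Toy binary-classification setup (illustrative only)
model = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())
train_dataset = TensorDataset(torch.rand(32, 10), torch.rand(32, 1))
checkpoints = ["model-checkpoint-0.pt"]  # assumed: paths to saved checkpoints

# Option 1: a "per-example" loss, so sample_wise_grads_per_batch stays False
tracin = TracInCP(
    model,
    train_dataset,
    checkpoints,
    loss_fn=nn.BCELoss(reduction="none"),
    batch_size=8,
    sample_wise_grads_per_batch=False,
)

# Option 2: a "sum"-reduction loss, required when
# sample_wise_grads_per_batch is True
tracin_fast_grads = TracInCP(
    model,
    train_dataset,
    checkpoints,
    loss_fn=nn.BCELoss(reduction="sum"),
    batch_size=8,
    sample_wise_grads_per_batch=True,
)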
1 parent 1f8e900 commit 8cc99fe

File tree

captum/influence/_core/tracincp.py
captum/influence/_core/tracincp_fast_rand_proj.py

2 files changed: +84 −16 lines changed

captum/influence/_core/tracincp.py

+66 −10
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 import glob
+import warnings
 from abc import abstractmethod
 from os.path import join
 from typing import Any, Callable, Iterator, List, Optional, Union, Tuple, NamedTuple
@@ -385,7 +386,23 @@ def __init__(
                 be computed for all layers. Otherwise, they will only be computed
                 for the layers specified in `layers`.
                 Default: None
-            loss_fn (Callable, optional): The loss function applied to model.
+            loss_fn (Callable, optional): The loss function applied to model. There
+                are two options for the return type of `loss_fn`. First, `loss_fn`
+                can be a "per-example" loss function that returns a 1D Tensor of
+                losses for each example in a batch. `nn.BCELoss(reduction="none")`
+                would be a "per-example" loss function. Second, `loss_fn` can be
+                a "reduction" loss function that reduces the per-example losses
+                in a batch, and returns a single scalar Tensor. For this option,
+                the reduction must be the *sum* or the *mean* of the per-example
+                losses. For instance, `nn.BCELoss(reduction="sum")` is acceptable.
+                Note that for the first option, the `sample_wise_grads_per_batch`
+                argument must be False, and for the second option,
+                `sample_wise_grads_per_batch` must be True. Also note that for
+                the second option, if `loss_fn` has no "reduction" attribute,
+                the implementation assumes that the reduction is the *sum* of the
+                per-example losses. If this is not the case, i.e. the reduction
+                is the *mean*, please set the "reduction" attribute of `loss_fn`
+                to "mean", i.e. `loss_fn.reduction = "mean"`.
                 Default: None
             batch_size (int or None, optional): Batch size of the DataLoader created to
                 iterate through `influence_src_dataset`, if it is a Dataset.
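To make the "per-example" vs. "reduction" distinction in this docstring concrete, here is a small standalone check in plain PyTorch (independent of Captum):

import torch
import torch.nn as nn

predictions = torch.sigmoid(torch.randn(4))
targets = torch.randint(0, 2, (4,)).float()

# "per-example" loss function: returns a 1D Tensor, one loss per example
per_example = nn.BCELoss(reduction="none")(predictions, targets)
print(per_example.shape)  # torch.Size([4])

# "reduction" loss function: returns a single scalar Tensor
summed = nn.BCELoss(reduction="sum")(predictions, targets)
print(summed.shape)  # torch.Size([])

# the "sum" reduction is exactly the sum of the per-example losses
assert torch.isclose(summed, per_example.sum())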
@@ -404,10 +421,16 @@ def __init__(
                 inefficient. We offer an implementation of batch-wise gradient
                 computations w.r.t. model parameters which is computationally
                 more efficient. This implementation can be enabled by setting the
-                `sample_wise_grads_per_batch` argument to `True`. Note that our
+                `sample_wise_grads_per_batch` argument to `True`, and should be
+                enabled if and only if the `loss_fn` argument is a "reduction" loss
+                function. For example, `nn.BCELoss(reduction="sum")` would be a
+                valid `loss_fn` if this implementation is enabled (see the
+                documentation for `loss_fn` for more details). Note that our
                 current implementation enables batch-wise gradient computations
                 only for a limited number of PyTorch nn.Modules: Conv2D and Linear.
-                This list will be expanded in the near future.
+                This list will be expanded in the near future. Therefore, please
+                do not enable this implementation if gradients will be computed
+                for other kinds of layers.
                 Default: False
         """
 
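The reason a "reduction" loss works for batch-wise gradient computation is that differentiation is linear: the gradient of the sum-reduced batch loss equals the sum of the per-example gradients (and the "mean" case is the same up to a factor of batch size). A minimal sketch of that identity in plain PyTorch, not Captum's internal implementation:

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(4, 3), torch.rand(4, 1)
loss_fn = nn.BCEWithLogitsLoss(reduction="sum")

# one backward pass over the whole batch, with a "sum" reduction
batch_grad = torch.autograd.grad(loss_fn(model(x), y), model.weight)[0]

# sum of the four per-example gradients, computed one example at a time
per_example_grads = [
    torch.autograd.grad(loss_fn(model(x[i : i + 1]), y[i : i + 1]), model.weight)[0]
    for i in range(4)
]
assert torch.allclose(batch_grad, sum(per_example_grads), atol=1e-6)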
@@ -423,14 +446,47 @@ def __init__(
 
         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch
 
-        if (
-            self.sample_wise_grads_per_batch
-            and isinstance(loss_fn, Module)  # TODO: allow loss_fn to be Callable
-            and hasattr(loss_fn, "reduction")
-        ):
-            self.reduction_type = str(loss_fn.reduction)
+        # If we are able to access the reduction used by `loss_fn`, we check whether
+        # the reduction is compatible with `sample_wise_grads_per_batch`
+        if isinstance(loss_fn, Module) and hasattr(
+            loss_fn, "reduction"
+        ):  # TODO: allow loss_fn to be Callable
+            if self.sample_wise_grads_per_batch:
+                assert loss_fn.reduction in ["sum", "mean"], (
+                    'reduction for `loss_fn` must be "sum" or "mean" when '
+                    "`sample_wise_grads_per_batch` is True"
+                )
+                self.reduction_type = str(loss_fn.reduction)
+            else:
+                assert loss_fn.reduction == "none", (
+                    'reduction for `loss_fn` must be "none" when '
+                    "`sample_wise_grads_per_batch` is False"
+                )
         else:
-            self.reduction_type = "sum"
+            # if we are unable to access the reduction used by `loss_fn`, we warn
+            # the user about the assumptions we are making regarding the reduction
+            # used by `loss_fn`
+            if self.sample_wise_grads_per_batch:
+                warnings.warn(
+                    'Since `loss_fn` has no "reduction" attribute, and '
+                    "`sample_wise_grads_per_batch` is True, the implementation assumes "
+                    'that `loss_fn` is a "reduction" loss function that reduces the '
+                    "per-example losses by taking their *sum*. If `loss_fn` "
+                    "instead reduces the per-example losses by taking their mean, "
+                    'please set the reduction attribute of `loss_fn` to "mean", i.e. '
+                    '`loss_fn.reduction = "mean"`. Note that if '
+                    "`sample_wise_grads_per_batch` is True, the implementation "
+                    "assumes the reduction is either a sum or mean reduction."
+                )
+                self.reduction_type = "sum"
+            else:
+                warnings.warn(
+                    'Since `loss_fn` has no "reduction" attribute, and '
+                    "`sample_wise_grads_per_batch` is False, the implementation "
+                    'assumes that `loss_fn` is a "per-example" loss function (see '
+                    "documentation for `loss_fn` for details). Please ensure that "
+                    "this is the case."
+                )
 
         r"""
         TODO: Either restore model state after done (would have to place functionality
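Per the checks above, the reduction is only inspected when `loss_fn` is an `nn.Module` with a `reduction` attribute; otherwise the warning branch runs and a sum reduction is assumed. A hypothetical custom loss that makes its reduction visible to these checks might look like this (a sketch for illustration, not from the commit):

import torch.nn as nn

class MyMeanLoss(nn.Module):
    """Hypothetical custom loss exposing a `reduction` attribute."""

    def __init__(self):
        super().__init__()
        self.reduction = "mean"  # seen by the isinstance/hasattr check above

    def forward(self, outputs, targets):
        # reduce per-example squared errors by taking their mean
        return ((outputs - targets) ** 2).mean()

# With sample_wise_grads_per_batch=True, an instance of this class passes the
# `reduction in ["sum", "mean"]` assertion, and reduction_type becomes "mean".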

captum/influence/_core/tracincp_fast_rand_proj.py

+18 −6
@@ -114,8 +114,13 @@ def __init__(
                 learning rate if it is saved. By default uses a utility to load a
                 model saved as a state dict.
                 Default: _load_flexible_state_dict
-            loss_fn (Callable or Module): The loss function applied to model. This must
-                be specified.
+            loss_fn (Callable, optional): The loss function applied to model. `loss_fn`
+                must be a "reduction" loss function that reduces the per-example
+                losses in a batch, and returns a single scalar Tensor. Furthermore,
+                the reduction must be the *sum* of the per-example losses. For
+                instance, `nn.BCELoss(reduction="sum")` is acceptable, but
+                `nn.BCELoss(reduction="mean")` is *not* acceptable.
+                Default: None
             batch_size (int or None, optional): Batch size of the DataLoader created to
                 iterate through `influence_src_dataset`, if it is a Dataset.
                 `batch_size` should be chosen as large as possible so that certain
@@ -150,10 +155,12 @@ def __init__(
                 param.requires_grad = True
 
         assert loss_fn is not None, "loss function must not be none"
+        # If we are able to access the reduction used by `loss_fn`, we check whether
+        # the reduction is "sum", as required.
         # TODO: allow loss_fn to be Callable
         if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
-            msg = "loss_fn.reduction should be `sum` or `mean`."
-            assert loss_fn.reduction != "none", msg
+            msg = "`loss_fn.reduction` must be `sum`."
+            assert loss_fn.reduction == "sum", msg
 
     def _influence_batch_tracincp_fast(
         self,
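The tightened assertion can be exercised on its own; only `reduction="sum"` now passes. A standalone sketch of the check above, run outside the class:

import torch.nn as nn

for loss_fn in (nn.BCELoss(reduction="sum"), nn.BCELoss(reduction="mean")):
    if isinstance(loss_fn, nn.Module) and hasattr(loss_fn, "reduction"):
        try:
            assert loss_fn.reduction == "sum", "`loss_fn.reduction` must be `sum`."
            print(f"{loss_fn.reduction!r}: accepted")
        except AssertionError as err:
            print(f"{loss_fn.reduction!r}: {err}")  # only "mean" fails here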
@@ -442,8 +449,13 @@ def __init__(
                 learning rate if it is saved. By default uses a utility to load a
                 model saved as a state dict.
                 Default: _load_flexible_state_dict
-            loss_fn (Callable or Module): The loss function applied to model. This must
-                be specified.
+            loss_fn (Callable, optional): The loss function applied to model. `loss_fn`
+                must be a "reduction" loss function that reduces the per-example
+                losses in a batch, and returns a single scalar Tensor. Furthermore,
+                the reduction must be the *sum* of the per-example losses. For
+                instance, `nn.BCELoss(reduction="sum")` is acceptable, but
+                `nn.BCELoss(reduction="mean")` is *not* acceptable.
+                Default: None
             batch_size (int or None, optional): Batch size of the DataLoader created to
                 iterate through `influence_src_dataset`, if it is a Dataset.
                 `batch_size` should be chosen as large as possible so that certain
