@@ -1,6 +1,7 @@
 #!/usr/bin/env python3

 import glob
+import warnings
 from abc import abstractmethod
 from os.path import join
 from typing import Any, Callable, Iterator, List, Optional, Union, Tuple, NamedTuple
@@ -385,7 +386,23 @@ def __init__(
                 be computed for all layers. Otherwise, they will only be computed
                 for the layers specified in `layers`.
                 Default: None
-            loss_fn (Callable, optional): The loss function applied to model.
+            loss_fn (Callable, optional): The loss function applied to model. There
+                are two options for the return type of `loss_fn`. First, `loss_fn`
+                can be a "per-example" loss function that returns a 1D Tensor of
+                losses, one for each example in a batch. `nn.BCELoss(reduction="none")`
+                would be a "per-example" loss function. Second, `loss_fn` can be
+                a "reduction" loss function that reduces the per-example losses
+                in a batch and returns a single scalar Tensor. For this option,
+                the reduction must be the *sum* or the *mean* of the per-example
+                losses. For instance, `nn.BCELoss(reduction="sum")` is acceptable.
+                Note that for the first option, the `sample_wise_grads_per_batch`
+                argument must be False, and for the second option,
+                `sample_wise_grads_per_batch` must be True. Also note that for
+                the second option, if `loss_fn` has no "reduction" attribute,
+                the implementation assumes that the reduction is the *sum* of the
+                per-example losses. If this is not the case, i.e. the reduction
+                is the *mean*, please set the "reduction" attribute of `loss_fn`
+                to "mean", i.e. `loss_fn.reduction = "mean"`.
                 Default: None
             batch_size (int or None, optional): Batch size of the DataLoader created to
                 iterate through `influence_src_dataset`, if it is a Dataset.
@@ -404,10 +421,16 @@ def __init__(
                 inefficient. We offer an implementation of batch-wise gradient
                 computations w.r.t. model parameters which is computationally
                 more efficient. This implementation can be enabled by setting the
-                `sample_wise_grad_per_batch` argument to `True`. Note that our
+                `sample_wise_grads_per_batch` argument to `True`, and should be
+                enabled if and only if the `loss_fn` argument is a "reduction" loss
+                function. For example, `nn.BCELoss(reduction="sum")` would be a
+                valid `loss_fn` if this implementation is enabled (see the
+                documentation for `loss_fn` for more details). Note that our
                 current implementation enables batch-wise gradient computations
                 only for a limited number of PyTorch nn.Modules: Conv2D and Linear.
-                This list will be expanded in the near future.
+                This list will be expanded in the near future. Therefore, please
+                do not enable this implementation if gradients will be computed
+                for other kinds of layers.
                 Default: False
         """

@@ -423,14 +446,47 @@ def __init__(

         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch

-        if (
-            self.sample_wise_grads_per_batch
-            and isinstance(loss_fn, Module)  # TODO: allow loss_fn to be Callable
-            and hasattr(loss_fn, "reduction")
-        ):
-            self.reduction_type = str(loss_fn.reduction)
+        # If we are able to access the reduction used by `loss_fn`, we check whether
+        # the reduction is compatible with `sample_wise_grads_per_batch`.
+        if isinstance(loss_fn, Module) and hasattr(
+            loss_fn, "reduction"
+        ):  # TODO: allow loss_fn to be Callable
+            if self.sample_wise_grads_per_batch:
+                assert loss_fn.reduction in ["sum", "mean"], (
+                    'reduction for `loss_fn` must be "sum" or "mean" when '
+                    "`sample_wise_grads_per_batch` is True"
+                )
+                self.reduction_type = str(loss_fn.reduction)
+            else:
+                assert loss_fn.reduction == "none", (
+                    'reduction for `loss_fn` must be "none" when '
+                    "`sample_wise_grads_per_batch` is False"
+                )
         else:
-            self.reduction_type = "sum"
+            # If we are unable to access the reduction used by `loss_fn`, we warn
+            # the user about the assumptions we are making regarding the reduction
+            # used by `loss_fn`.
+            if self.sample_wise_grads_per_batch:
+                warnings.warn(
+                    'Since `loss_fn` has no "reduction" attribute, and '
+                    "`sample_wise_grads_per_batch` is True, the implementation assumes "
+                    'that `loss_fn` is a "reduction" loss function that reduces the '
+                    "per-example losses by taking their *sum*. If `loss_fn` "
+                    "instead reduces the per-example losses by taking their mean, "
+                    'please set the reduction attribute of `loss_fn` to "mean", i.e. '
+                    '`loss_fn.reduction = "mean"`. Note that if '
+                    "`sample_wise_grads_per_batch` is True, the implementation "
+                    "assumes the reduction is either a sum or mean reduction."
+                )
+                self.reduction_type = "sum"
+            else:
+                warnings.warn(
+                    'Since `loss_fn` has no "reduction" attribute, and '
+                    "`sample_wise_grads_per_batch` is False, the implementation "
+                    'assumes that `loss_fn` is a "per-example" loss function (see '
+                    "documentation for `loss_fn` for details). Please ensure that "
+                    "this is the case."
+                )

         r"""
         TODO: Either restore model state after done (would have to place functionality
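
A minimal sketch, not part of the diff, illustrating the two `loss_fn` options described in the docstring changes above; `nn.BCELoss` and the batch size of 4 are used only as examples.

import torch
import torch.nn as nn

outputs = torch.rand(4)                       # hypothetical model outputs for a batch of 4
targets = torch.randint(0, 2, (4,)).float()   # hypothetical binary targets

# Option 1: a "per-example" loss function returns a 1D Tensor of losses, one per
# example; pair it with `sample_wise_grads_per_batch=False`.
per_example_loss_fn = nn.BCELoss(reduction="none")
print(per_example_loss_fn(outputs, targets).shape)  # torch.Size([4])

# Option 2: a "reduction" loss function returns a single scalar Tensor; the
# reduction must be "sum" or "mean"; pair it with `sample_wise_grads_per_batch=True`.
reduced_loss_fn = nn.BCELoss(reduction="sum")
print(reduced_loss_fn(outputs, targets).shape)      # torch.Size([])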
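
A second sketch, under the assumption that a user supplies a custom loss module: exposing a `reduction` attribute, as the docstring advises, lets the compatibility check in the diff read the reduction directly instead of falling back to the warning path, which assumes a *sum* reduction. The class name here is hypothetical.

import torch
import torch.nn as nn

class MeanReducedMSELoss(nn.Module):  # hypothetical custom loss module
    def __init__(self) -> None:
        super().__init__()
        # Read by the `hasattr(loss_fn, "reduction")` check in the diff above;
        # "mean" tells the implementation how the per-example losses are reduced.
        self.reduction = "mean"

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Reduce per-example squared errors by their mean, matching `self.reduction`.
        return ((outputs - targets) ** 2).mean()

loss_fn = MeanReducedMSELoss()  # pair with `sample_wise_grads_per_batch=True`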