@@ -444,7 +444,7 @@ def _check_loss_fn(
444
444
influence_instance : Union ["TracInCPBase" , "InfluenceFunctionBase" ],
445
445
loss_fn : Optional [Union [Module , Callable ]],
446
446
loss_fn_name : str ,
447
- sample_wise_grads_per_batch : Optional [ bool ] = None ,
447
+ sample_wise_grads_per_batch : bool = True ,
448
448
) -> str :
449
449
"""
450
450
This checks whether `loss_fn` satisfies the requirements assumed of all
@@ -469,16 +469,13 @@ def _check_loss_fn(
469
469
# attribute.
470
470
if hasattr (loss_fn , "reduction" ):
471
471
reduction = loss_fn .reduction # type: ignore
472
- if sample_wise_grads_per_batch is None :
472
+ if sample_wise_grads_per_batch :
473
473
assert reduction in [
474
474
"sum" ,
475
475
"mean" ,
476
- ], 'reduction for `loss_fn` must be "sum" or "mean"'
477
- reduction_type = str (reduction )
478
- elif sample_wise_grads_per_batch :
479
- assert reduction in ["sum" , "mean" ], (
476
+ ], (
480
477
'reduction for `loss_fn` must be "sum" or "mean" when '
481
- "`sample_wise_grads_per_batch` is True"
478
+ "`sample_wise_grads_per_batch` is True (i.e. the default value) "
482
479
)
483
480
reduction_type = str (reduction )
484
481
else :
@@ -490,18 +487,7 @@ def _check_loss_fn(
490
487
# if we are unable to access the reduction used by `loss_fn`, we warn
491
488
# the user about the assumptions we are making regarding the reduction
492
489
# used by `loss_fn`
493
- if sample_wise_grads_per_batch is None :
494
- warnings .warn (
495
- f'Since `{ loss_fn_name } ` has no "reduction" attribute, the '
496
- f'implementation assumes that `{ loss_fn_name } ` is a "reduction" loss '
497
- "function that reduces the per-example losses by taking their *sum*. "
498
- f"If `{ loss_fn_name } ` instead reduces the per-example losses by "
499
- f"taking their mean, please set the reduction attribute of "
500
- f'`{ loss_fn_name } ` to "mean", i.e. '
501
- f'`{ loss_fn_name } .reduction = "mean"`.'
502
- )
503
- reduction_type = "sum"
504
- elif sample_wise_grads_per_batch :
490
+ if sample_wise_grads_per_batch :
505
491
warnings .warn (
506
492
f"Since `{ loss_fn_name } `` has no 'reduction' attribute, and "
507
493
"`sample_wise_grads_per_batch` is True, the implementation assumes "
0 commit comments