Skip to content

Commit b76e4d5

Browse files
cicichen01facebook-github-bot
authored andcommitted
Simplify the _check_loss_fn() logic (#1243)
Summary: The _check_loss_fn() has exact same logic when sample_wise_grads_per_batch is None and True cases. Thus simplify the logic. Differential Revision: D54883319
1 parent 5eb5498 commit b76e4d5

File tree

1 file changed

+5
-19
lines changed

1 file changed

+5
-19
lines changed

captum/influence/_utils/common.py

+5-19
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def _check_loss_fn(
444444
influence_instance: Union["TracInCPBase", "InfluenceFunctionBase"],
445445
loss_fn: Optional[Union[Module, Callable]],
446446
loss_fn_name: str,
447-
sample_wise_grads_per_batch: Optional[bool] = None,
447+
sample_wise_grads_per_batch: bool = True,
448448
) -> str:
449449
"""
450450
This checks whether `loss_fn` satisfies the requirements assumed of all
@@ -469,16 +469,13 @@ def _check_loss_fn(
469469
# attribute.
470470
if hasattr(loss_fn, "reduction"):
471471
reduction = loss_fn.reduction # type: ignore
472-
if sample_wise_grads_per_batch is None:
472+
if sample_wise_grads_per_batch:
473473
assert reduction in [
474474
"sum",
475475
"mean",
476-
], 'reduction for `loss_fn` must be "sum" or "mean"'
477-
reduction_type = str(reduction)
478-
elif sample_wise_grads_per_batch:
479-
assert reduction in ["sum", "mean"], (
476+
], (
480477
'reduction for `loss_fn` must be "sum" or "mean" when '
481-
"`sample_wise_grads_per_batch` is True"
478+
"`sample_wise_grads_per_batch` is True (i.e. the default value) "
482479
)
483480
reduction_type = str(reduction)
484481
else:
@@ -490,18 +487,7 @@ def _check_loss_fn(
490487
# if we are unable to access the reduction used by `loss_fn`, we warn
491488
# the user about the assumptions we are making regarding the reduction
492489
# used by `loss_fn`
493-
if sample_wise_grads_per_batch is None:
494-
warnings.warn(
495-
f'Since `{loss_fn_name}` has no "reduction" attribute, the '
496-
f'implementation assumes that `{loss_fn_name}` is a "reduction" loss '
497-
"function that reduces the per-example losses by taking their *sum*. "
498-
f"If `{loss_fn_name}` instead reduces the per-example losses by "
499-
f"taking their mean, please set the reduction attribute of "
500-
f'`{loss_fn_name}` to "mean", i.e. '
501-
f'`{loss_fn_name}.reduction = "mean"`.'
502-
)
503-
reduction_type = "sum"
504-
elif sample_wise_grads_per_batch:
490+
if sample_wise_grads_per_batch:
505491
warnings.warn(
506492
f"Since `{loss_fn_name}`` has no 'reduction' attribute, and "
507493
"`sample_wise_grads_per_batch` is True, the implementation assumes "

0 commit comments

Comments
 (0)