4
4
from typing import (
5
5
Any ,
6
6
Callable ,
7
+ Dict ,
7
8
Iterable ,
8
9
List ,
9
10
NamedTuple ,
@@ -613,7 +614,7 @@ def _influence_batch_intermediate_quantities_influence_function(
613
614
influence_inst : "IntermediateQuantitiesInfluenceFunction" ,
614
615
test_batch : Tuple [Any , ...],
615
616
train_batch : Tuple [Any , ...],
616
- ):
617
+ ) -> Tensor :
617
618
"""
618
619
computes influence of a test batch on a train batch, for implementations of
619
620
`IntermediateQuantitiesInfluenceFunction`
@@ -628,7 +629,7 @@ def _influence_helper_intermediate_quantities_influence_function(
628
629
influence_inst : "IntermediateQuantitiesInfluenceFunction" ,
629
630
inputs_dataset : Union [Tuple [Any , ...], DataLoader ],
630
631
show_progress : bool ,
631
- ):
632
+ ) -> Tensor :
632
633
"""
633
634
Helper function that computes influence scores for implementations of
634
635
`NaiveInfluenceFunction` which implement the `compute_intermediate_quantities`
@@ -666,7 +667,7 @@ def _self_influence_helper_intermediate_quantities_influence_function(
666
667
influence_inst : "IntermediateQuantitiesInfluenceFunction" ,
667
668
inputs_dataset : Optional [Union [Tuple [Any , ...], DataLoader ]],
668
669
show_progress : bool ,
669
- ):
670
+ ) -> Tensor :
670
671
"""
671
672
Helper function that computes self-influence scores for implementations of
672
673
`NaiveInfluenceFunction` which implement the `compute_intermediate_quantities`
@@ -983,14 +984,14 @@ def _compute_batch_loss_influence_function_base(
983
984
raise Exception
984
985
985
986
986
- def _set_attr (obj , names , val ):
987
+ def _set_attr (obj , names , val ) -> None :
987
988
if len (names ) == 1 :
988
989
setattr (obj , names [0 ], val )
989
990
else :
990
991
_set_attr (getattr (obj , names [0 ]), names [1 :], val )
991
992
992
993
993
- def _del_attr (obj , names ):
994
+ def _del_attr (obj , names ) -> None :
994
995
if len (names ) == 1 :
995
996
delattr (obj , names [0 ])
996
997
else :
@@ -1006,7 +1007,7 @@ def _model_make_functional(model, param_names, params):
1006
1007
return params
1007
1008
1008
1009
1009
- def _model_reinsert_params (model , param_names , params , register = False ):
1010
+ def _model_reinsert_params (model , param_names , params , register : bool = False ) -> None :
1010
1011
for param_name , param in zip (param_names , params ):
1011
1012
_set_attr (
1012
1013
model ,
@@ -1024,7 +1025,7 @@ def _custom_functional_call(model, d, features):
1024
1025
return out
1025
1026
1026
1027
1027
- def _functional_call (model , d , features ):
1028
+ def _functional_call (model : Module , d : Dict [ str , Tensor ] , features ):
1028
1029
"""
1029
1030
Makes a call to `model.forward`, which is treated as a function of the parameters
1030
1031
in `d`, a dict from parameter name to parameter, instead of as a function of
0 commit comments