Skip to content

Commit 22beb50

Browse files
cicichen01facebook-github-bot
authored andcommitted
Type Annotation for influence (#1247)
Summary: as titled. Differential Revision: D55035833
1 parent 949ec60 commit 22beb50

5 files changed

+14
-13
lines changed

captum/influence/_core/arnoldi_influence_function.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _parameter_distill(
157157
k: Optional[int],
158158
hessian_reg: float,
159159
hessian_inverse_tol: float,
160-
):
160+
) -> Tuple[Tensor, List[Tuple[Tensor, ...]]]:
161161
"""
162162
This takes the output of `_parameter_arnoldi`, and extracts the top-k eigenvalues
163163
/ eigenvectors of the matrix that `_parameter_arnoldi` found the Krylov subspace

captum/influence/_core/influence_function.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ def _get_dataset_embeddings_intermediate_quantities_influence_function(
596596
batch_embeddings_fn: Callable,
597597
inputs_dataset: DataLoader,
598598
aggregate: bool,
599-
):
599+
) -> Tensor:
600600
"""
601601
given `batch_embeddings_fn`, which produces the embeddings for a given batch,
602602
returns either the embeddings for an entire dataset (if `aggregate` is false),

captum/influence/_core/similarity_influence.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020

21-
def euclidean_distance(test, train) -> Tensor:
21+
def euclidean_distance(test: Tensor, train: Tensor) -> Tensor:
2222
r"""
2323
Calculates the pairwise euclidean distance for batches of feature vectors.
2424
Tensors test and train have shape (batch_size_1, *), and (batch_size_2, *).
@@ -31,7 +31,7 @@ def euclidean_distance(test, train) -> Tensor:
3131
return similarity
3232

3333

34-
def cosine_similarity(test, train, replace_nan=0) -> Tensor:
34+
def cosine_similarity(test: Tensor, train: Tensor, replace_nan: int = 0) -> Tensor:
3535
r"""
3636
Calculates the pairwise cosine similarity for batches of feature vectors.
3737
Tensors test and train have shape (batch_size_1, *), and (batch_size_2, *).

captum/influence/_core/tracincp_fast_rand_proj.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ def _basic_computation_tracincp_fast(
720720
targets: Tensor,
721721
loss_fn: Optional[Union[Module, Callable]] = None,
722722
reduction_type: Optional[str] = None,
723-
):
723+
) -> Tuple[Tensor, Tensor]:
724724
"""
725725
For instances of TracInCPFast and children classes, computation of influence scores
726726
or self influence scores repeatedly calls this function for different checkpoints
@@ -1363,7 +1363,7 @@ def _set_projections_tracincp_fast_rand_proj(
13631363
def _process_src_intermediate_quantities_tracincp_fast_rand_proj(
13641364
self,
13651365
src_intermediate_quantities: torch.Tensor,
1366-
):
1366+
) -> None:
13671367
"""
13681368
Assumes `self._get_intermediate_quantities_tracin_fast_rand_proj` returns
13691369
vector representations for each example, and that influence between a

captum/influence/_utils/common.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import (
55
Any,
66
Callable,
7+
Dict,
78
Iterable,
89
List,
910
NamedTuple,
@@ -613,7 +614,7 @@ def _influence_batch_intermediate_quantities_influence_function(
613614
influence_inst: "IntermediateQuantitiesInfluenceFunction",
614615
test_batch: Tuple[Any, ...],
615616
train_batch: Tuple[Any, ...],
616-
):
617+
) -> Tensor:
617618
"""
618619
computes influence of a test batch on a train batch, for implementations of
619620
`IntermediateQuantitiesInfluenceFunction`
@@ -628,7 +629,7 @@ def _influence_helper_intermediate_quantities_influence_function(
628629
influence_inst: "IntermediateQuantitiesInfluenceFunction",
629630
inputs_dataset: Union[Tuple[Any, ...], DataLoader],
630631
show_progress: bool,
631-
):
632+
) -> Tensor:
632633
"""
633634
Helper function that computes influence scores for implementations of
634635
`NaiveInfluenceFunction` which implement the `compute_intermediate_quantities`
@@ -666,7 +667,7 @@ def _self_influence_helper_intermediate_quantities_influence_function(
666667
influence_inst: "IntermediateQuantitiesInfluenceFunction",
667668
inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]],
668669
show_progress: bool,
669-
):
670+
) -> Tensor:
670671
"""
671672
Helper function that computes self-influence scores for implementations of
672673
`NaiveInfluenceFunction` which implement the `compute_intermediate_quantities`
@@ -983,14 +984,14 @@ def _compute_batch_loss_influence_function_base(
983984
raise Exception
984985

985986

986-
def _set_attr(obj, names, val):
987+
def _set_attr(obj, names, val) -> None:
987988
if len(names) == 1:
988989
setattr(obj, names[0], val)
989990
else:
990991
_set_attr(getattr(obj, names[0]), names[1:], val)
991992

992993

993-
def _del_attr(obj, names):
994+
def _del_attr(obj, names) -> None:
994995
if len(names) == 1:
995996
delattr(obj, names[0])
996997
else:
@@ -1006,7 +1007,7 @@ def _model_make_functional(model, param_names, params):
10061007
return params
10071008

10081009

1009-
def _model_reinsert_params(model, param_names, params, register=False):
1010+
def _model_reinsert_params(model, param_names, params, register: bool = False) -> None:
10101011
for param_name, param in zip(param_names, params):
10111012
_set_attr(
10121013
model,
@@ -1024,7 +1025,7 @@ def _custom_functional_call(model, d, features):
10241025
return out
10251026

10261027

1027-
def _functional_call(model, d, features):
1028+
def _functional_call(model: Module, d: Dict[str, Tensor], features):
10281029
"""
10291030
Makes a call to `model.forward`, which is treated as a function of the parameters
10301031
in `d`, a dict from parameter name to parameter, instead of as a function of

0 commit comments

Comments
 (0)