Skip to content

Commit db7ac33

Browse files
cicichen01 and facebook-github-bot
authored and committed
Use Python properties to cleanly apply rules to checkpoints (#1249)
Summary: As titled. Isolate the rules into individual methods for better structure and complete verification at run time. Reviewed By: vivekmig. Differential Revision: D55038334.
1 parent: f601cdd; commit: db7ac33

File tree

3 files changed

+58
-12
lines changed

3 files changed

+58
-12
lines changed

captum/influence/_core/tracincp.py

+20 -9
Original file line number | Diff line number | Diff line change
@@ -140,16 +140,9 @@ def __init__(
140140
Default: None
141141
"""
142142

143-
self.model = model
143+
self.model: Module = model
144144

145-
if isinstance(checkpoints, str):
146-
self.checkpoints = AV.sort_files(glob.glob(join(checkpoints, "*")))
147-
elif isinstance(checkpoints, List) and isinstance(checkpoints[0], str):
148-
self.checkpoints = AV.sort_files(checkpoints)
149-
else:
150-
self.checkpoints = list(checkpoints) # cast to avoid mypy error
151-
if isinstance(self.checkpoints, List):
152-
assert len(self.checkpoints) > 0, "No checkpoints saved!"
145+
self.checkpoints = checkpoints # type: ignore
153146

154147
self.checkpoints_load_func = checkpoints_load_func
155148
self.loss_fn = loss_fn
@@ -181,6 +174,24 @@ def __init__(
181174
"percentage completion of the computation, nor any time estimates."
182175
)
183176

177+
@property
178+
def checkpoints(self) -> List[str]:
179+
return self._checkpoints
180+
181+
@checkpoints.setter
182+
def checkpoints(self, checkpoints: Union[str, List[str], Iterator]) -> None:
183+
if isinstance(checkpoints, str):
184+
self._checkpoints = AV.sort_files(glob.glob(join(checkpoints, "*")))
185+
elif isinstance(checkpoints, List) and isinstance(checkpoints[0], str):
186+
self._checkpoints = AV.sort_files(checkpoints)
187+
else:
188+
self._checkpoints = list(checkpoints) # cast to avoid mypy error
189+
190+
if len(self._checkpoints) <= 0:
191+
raise ValueError(
192+
f"Invalid checkpoints provided for TracIn class: {checkpoints}!"
193+
)
194+
184195
@abstractmethod
185196
def self_influence(
186197
self,

captum/influence/_core/tracincp_fast_rand_proj.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ class TracInCPFast(TracInCPBase):
8282
def __init__(
8383
self,
8484
model: Module,
85-
final_fc_layer: Module,
85+
final_fc_layer: Union[Module, str],
8686
train_dataset: Union[Dataset, DataLoader],
8787
checkpoints: Union[str, List[str], Iterator],
8888
checkpoints_load_func: Callable = _load_flexible_state_dict,
@@ -183,7 +183,7 @@ def __init__(
183183
self.vectorize = vectorize
184184

185185
# TODO: restore prior state
186-
self.final_fc_layer = final_fc_layer
186+
self.final_fc_layer = final_fc_layer # type: ignore
187187
for param in self.final_fc_layer.parameters():
188188
param.requires_grad = True
189189

tests/influence/_core/test_tracin_validation.py

+36 -1
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ class TestTracinValidator(BaseTest):
3636
)
3737
def test_tracin_require_inputs_dataset(
3838
self,
39-
reduction,
39+
reduction: str,
4040
tracin_constructor: Callable,
4141
) -> None:
4242
"""
@@ -64,6 +64,10 @@ def test_tracin_require_inputs_dataset(
6464
tracin.influence(None, k=None)
6565

6666
def test_tracincp_fast_rand_proj_inputs(self) -> None:
67+
"""
68+
This test verifies that TracInCPFast should be initialized
69+
with a valid `final_fc_layer`.
70+
"""
6771
with tempfile.TemporaryDirectory() as tmpdir:
6872
(
6973
net,
@@ -83,3 +87,34 @@ def test_tracincp_fast_rand_proj_inputs(self) -> None:
8387
loss_fn=nn.MSELoss(),
8488
batch_size=1,
8589
)
90+
91+
@parameterized.expand(
92+
param_list,
93+
name_func=build_test_name_func(),
94+
)
95+
def test_tracincp_input_checkpoints(
96+
self, reduction: str, tracin_constructor: Callable
97+
) -> None:
98+
"""
99+
This test verifies that tracinCP and tracinCPFast
100+
class should be initialized with valid `checkpoints`.
101+
"""
102+
with tempfile.TemporaryDirectory() as invalid_tmpdir:
103+
with tempfile.TemporaryDirectory() as tmpdir:
104+
(
105+
net,
106+
train_dataset,
107+
test_samples,
108+
test_labels,
109+
) = get_random_model_and_data(tmpdir, unpack_inputs=False)
110+
111+
with self.assertRaisesRegex(
112+
ValueError, "Invalid checkpoints provided for TracIn class: "
113+
):
114+
tracin_constructor(
115+
net,
116+
train_dataset,
117+
invalid_tmpdir,
118+
loss_fn=nn.MSELoss(),
119+
batch_size=1,
120+
)

0 commit comments

Comments
 (0)