Skip to content

Commit 4fc2761

Browse files
yucufacebook-github-bot
authored andcommitted
Create TARGETS for captum/_utils (#1250)
Summary: Create separate TARGETS files for different part of Captum project. Start with a relatively simple one: captum/_utils. Mostly TARGETS file change, with more exception to correct import, split test helper function to separate file etc. Differential Revision: D55091069
1 parent fabac35 commit 4fc2761

File tree

5 files changed

+61
-75
lines changed

5 files changed

+61
-75
lines changed

captum/_utils/av.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import captum._utils.common as common
1010
import torch
11-
from captum.attr import LayerActivation
11+
from captum.attr._core.layer.layer_activation import LayerActivation
1212
from torch import Tensor
1313
from torch.nn import Module
1414
from torch.utils.data import DataLoader, Dataset

captum/_utils/models/__init__.py

-20
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,5 @@
1-
from captum._utils.models.linear_model import (
2-
LinearModel,
3-
SGDLasso,
4-
SGDLinearModel,
5-
SGDLinearRegression,
6-
SGDRidge,
7-
SkLearnLasso,
8-
SkLearnLinearModel,
9-
SkLearnLinearRegression,
10-
SkLearnRidge,
11-
)
121
from captum._utils.models.model import Model
132

143
__all__ = [
154
"Model",
16-
"LinearModel",
17-
"SGDLinearModel",
18-
"SGDLasso",
19-
"SGDRidge",
20-
"SGDLinearRegression",
21-
"SkLearnLinearModel",
22-
"SkLearnLasso",
23-
"SkLearnRidge",
24-
"SkLearnLinearRegression",
255
]

tests/utils/evaluate_linear_model.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
from typing import cast, Dict
3+
4+
import torch
5+
from torch import Tensor
6+
7+
8+
def evaluate(test_data, classifier) -> Dict[str, Tensor]:
9+
classifier.eval()
10+
11+
l1_loss = 0.0
12+
l2_loss = 0.0
13+
n = 0
14+
l2_losses = []
15+
with torch.no_grad():
16+
for data in test_data:
17+
if len(data) == 2:
18+
x, y = data
19+
w = None
20+
else:
21+
x, y, w = data
22+
23+
out = classifier(x)
24+
25+
y = y.view(x.shape[0], -1)
26+
assert y.shape == out.shape
27+
28+
if w is None:
29+
l1_loss += (out - y).abs().sum(0).to(dtype=torch.float64)
30+
l2_loss += ((out - y) ** 2).sum(0).to(dtype=torch.float64)
31+
l2_losses.append(((out - y) ** 2).to(dtype=torch.float64))
32+
else:
33+
l1_loss += (
34+
(w.view(-1, 1) * (out - y)).abs().sum(0).to(dtype=torch.float64)
35+
)
36+
l2_loss += (
37+
(w.view(-1, 1) * ((out - y) ** 2)).sum(0).to(dtype=torch.float64)
38+
)
39+
l2_losses.append(
40+
(w.view(-1, 1) * ((out - y) ** 2)).to(dtype=torch.float64)
41+
)
42+
43+
n += x.shape[0]
44+
45+
l2_losses = torch.cat(l2_losses, dim=0)
46+
assert n > 0
47+
48+
# just to double check
49+
assert ((l2_losses.mean(0) - l2_loss / n).abs() <= 0.1).all()
50+
51+
classifier.train()
52+
return {"l1": cast(Tensor, l1_loss / n), "l2": cast(Tensor, l2_loss / n)}

tests/utils/models/linear_models/_test_linear_classifier.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import sklearn.datasets as datasets
88
import torch
9-
from tests.utils.test_linear_model import _evaluate
9+
from tests.utils.evaluate_linear_model import evaluate
1010
from torch.utils.data import DataLoader, TensorDataset
1111

1212

@@ -80,11 +80,11 @@ def compare_to_sk_learn(
8080
alpha=alpha,
8181
)
8282

83-
sklearn_stats.update(_evaluate(val_loader, sklearn_classifier))
84-
pytorch_stats.update(_evaluate(val_loader, pytorch_classifier))
83+
sklearn_stats.update(evaluate(val_loader, sklearn_classifier))
84+
pytorch_stats.update(evaluate(val_loader, pytorch_classifier))
8585

86-
train_stats_pytorch = _evaluate(train_loader, pytorch_classifier)
87-
train_stats_sklearn = _evaluate(train_loader, sklearn_classifier)
86+
train_stats_pytorch = evaluate(train_loader, pytorch_classifier)
87+
train_stats_sklearn = evaluate(train_loader, sklearn_classifier)
8888

8989
o_pytorch = {"l2": train_stats_pytorch["l2"]}
9090
o_sklearn = {"l2": train_stats_sklearn["l2"]}

tests/utils/test_linear_model.py

+3-49
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22

3-
from typing import cast, Dict, Optional, Union
3+
from typing import Optional, Union
44

55
import torch
66
from captum._utils.models.linear_model.model import (
@@ -10,56 +10,10 @@
1010
)
1111
from tests.helpers import BaseTest
1212
from tests.helpers.basic import assertTensorAlmostEqual
13+
from tests.utils.evaluate_linear_model import evaluate
1314
from torch import Tensor
1415

1516

16-
def _evaluate(test_data, classifier) -> Dict[str, Tensor]:
17-
classifier.eval()
18-
19-
l1_loss = 0.0
20-
l2_loss = 0.0
21-
n = 0
22-
l2_losses = []
23-
with torch.no_grad():
24-
for data in test_data:
25-
if len(data) == 2:
26-
x, y = data
27-
w = None
28-
else:
29-
x, y, w = data
30-
31-
out = classifier(x)
32-
33-
y = y.view(x.shape[0], -1)
34-
assert y.shape == out.shape
35-
36-
if w is None:
37-
l1_loss += (out - y).abs().sum(0).to(dtype=torch.float64)
38-
l2_loss += ((out - y) ** 2).sum(0).to(dtype=torch.float64)
39-
l2_losses.append(((out - y) ** 2).to(dtype=torch.float64))
40-
else:
41-
l1_loss += (
42-
(w.view(-1, 1) * (out - y)).abs().sum(0).to(dtype=torch.float64)
43-
)
44-
l2_loss += (
45-
(w.view(-1, 1) * ((out - y) ** 2)).sum(0).to(dtype=torch.float64)
46-
)
47-
l2_losses.append(
48-
(w.view(-1, 1) * ((out - y) ** 2)).to(dtype=torch.float64)
49-
)
50-
51-
n += x.shape[0]
52-
53-
l2_losses = torch.cat(l2_losses, dim=0)
54-
assert n > 0
55-
56-
# just to double check
57-
assert ((l2_losses.mean(0) - l2_loss / n).abs() <= 0.1).all()
58-
59-
classifier.train()
60-
return {"l1": cast(Tensor, l1_loss / n), "l2": cast(Tensor, l2_loss / n)}
61-
62-
6317
class TestLinearModel(BaseTest):
6418
MAX_POINTS: int = 3
6519

@@ -100,7 +54,7 @@ def train_and_compare(
10054

10155
self.assertTrue(model.bias() is not None if bias else model.bias() is None)
10256

103-
l2_loss = _evaluate(train_loader, model)["l2"]
57+
l2_loss = evaluate(train_loader, model)["l2"]
10458

10559
if objective == "lasso":
10660
reg = model.representation().norm(p=1).view_as(l2_loss)

0 commit comments

Comments
 (0)