Skip to content

Commit 1a99017

Browse files
committed
Fix typing linear
1 parent d0e6690 commit 1a99017

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/utils/test_linear_model.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22

3-
from typing import Dict, Optional, Union
3+
from typing import cast, Dict, Optional, Union
44

55
import torch
66
from captum._utils.models.linear_model.model import (
@@ -15,8 +15,8 @@
1515
def _evaluate(test_data, classifier) -> Dict[str, Tensor]:
1616
classifier.eval()
1717

18-
l1_loss = torch.tensor(0.0)
19-
l2_loss = torch.tensor(0.0)
18+
l1_loss = 0.0
19+
l2_loss = 0.0
2020
n = 0
2121
l2_losses = []
2222
with torch.no_grad():
@@ -56,7 +56,7 @@ def _evaluate(test_data, classifier) -> Dict[str, Tensor]:
5656
assert ((l2_losses.mean(0) - l2_loss / n).abs() <= 0.1).all()
5757

5858
classifier.train()
59-
return {"l1": l1_loss / n, "l2": l2_loss / n}
59+
return {"l1": cast(Tensor, l1_loss / n), "l2": cast(Tensor, l2_loss / n)}
6060

6161

6262
class TestLinearModel(BaseTest):

0 commit comments

Comments
 (0)