Skip to content

Commit 8ce0d7f

Browse files
committed
Add ruff fixes
1 parent cac38fe commit 8ce0d7f

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

probe_lens/probes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from tqdm import tqdm
21
import torch
32
import torch.nn as nn
3+
from tqdm.autonotebook import tqdm
44

55

66
class LinearProbe(nn.Module):
77
def __init__(self, input_dim, output_dim=1, device="cpu"):
8-
super(LinearProbe, self).__init__()
8+
super().__init__()
99
self.linear = nn.Linear(input_dim, output_dim, device=device)
1010

1111
def forward(self, x):

tests/test_probes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from probe_lens.probes import LinearProbe
21
import torch
32
import torch.nn as nn
43

4+
from probe_lens.probes import LinearProbe
5+
56

67
def test_linear_probe():
78
input_dim = 10

0 commit comments

Comments
 (0)