Skip to content

Commit 2c72071

Browse files
committed
Improve readme
1 parent 3d04814 commit 2c72071

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ We initialize a linear probe with the same number of outputs as classes in the e
4343
```python
4444
from probe_lens.probes import LinearProbe
4545
X, y = next(iter(dataloader))
46-
probe = LinearProbe(X.shape[1], y.shape[1], class_names=spelling_task.get_classes())
46+
probe = LinearProbe(X.shape[1], y.shape[1], class_names=spelling_task.get_classes(), device=DEVICE)
4747
```
4848

4949
#### Train probe
@@ -58,6 +58,7 @@ We use the `visualize_performance` method to visualize the performance of the pr
5858
```python
5959
plot = probe.visualize_performance(dataloader)
6060
```
61+
![Confusion Matrix](confusion_matrix.png)
6162

6263

6364
## Roadmap
@@ -81,3 +82,8 @@ plot = probe.visualize_performance(dataloader)
8182
- [ ] Add more visualization experiments
8283
- [ ] ... ?
8384

85+
### Documentation
86+
- [ ] Add docstrings
87+
- [ ] Add tutorials
88+
- [ ] Reproduce experiments from major papers (SAE-Spelling, etc.)
89+

confusion_matrix.png

30.2 KB
Loading

probe_lens/probes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,13 @@ def visualize_performance(
3030

3131
accuracy = accuracy_score(gts.cpu(), preds.cpu())
3232
f2_score = fbeta_score(gts.cpu(), preds.cpu(), beta=2, average="weighted")
33-
3433
cm = confusion_matrix(gts.cpu(), preds.cpu())
35-
plt.figure(figsize=(10, 7))
3634
_class_names = (
3735
self.class_names
3836
if self.class_names
3937
else [str(i) for i in range(cm.shape[0])]
4038
)
39+
plt.figure(figsize=(10, 7))
4140
sns.heatmap(
4241
cm,
4342
annot=True,

0 commit comments

Comments
 (0)