Using trace=False no longer works. I don't know whether it has been deprecated but not yet removed, or whether this is a bug. Tested on 0.5.7 and 0.5.10.
import torch as th
import torch.nn as nn
from nnsight import NNsight

class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, input_ids, attention_mask=None):
        return {"logits": self.linear(input_ids.float())}

dummy_model = NNsight(DummyModel())
dummy_out = dummy_model.trace(th.randn(10), trace=False)
print(dummy_out)

  File "/mnt/nw/home/c.dumas/projects/nnterp/.venv/lib/python3.10/site-packages/nnsight/intervention/tracing/base.py", line 364, in parse
    raise WithBlockNotFoundError(message)
nnsight.intervention.tracing.base.WithBlockNotFoundError: With block not found at line 23
We looked here:
        return {"logits": self.linear(input_ids.float())}
dummy_model = NNsight(DummyModel())
dummy_out = dummy_model.trace(th.randn(10), trace=False) <--- HERE
print(dummy_out)