Skip to content

Commit 10aecde

Browse files
committed
top_k visualization should always register forward hooks
1 parent 4aa668a commit 10aecde

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

src/sdialog/interpretability/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,9 @@ def __init__(self, cache_key, layer_key, agent, response_hook, inspector=None):
467467
self.agent = agent
468468
self.response_hook = response_hook
469469
self.inspector = inspector
470-
self.register(agent.base_model, self.inspector.inspect_input)
470+
# We always need a forward hook instead of a pre-hook if we want to capture the lm_head predictions.
471+
# So, no pre hook here is mandatory.
472+
self.register(agent.base_model, is_pre_hook=False)
471473

472474
# Initialize the logits cache for this response
473475
_ = self.agent._hook_response_logit[len(self.response_hook.responses)] = []

0 commit comments

Comments
 (0)