Skip to content

Commit 3865a98

Browse files
committed
simplify logic
1 parent 60c0862 commit 3865a98

1 file changed

Lines changed: 10 additions & 17 deletions

File tree

src/sdialog/interpretability/__init__.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,9 @@ def _hook(self, module, input, output):
357357
# Check if min_token should steer all system prompt tokens (any non-integer value, e.g. the string "-1")
358358
steer_all_system_prompt = not isinstance(min_token, (int, np.integer))
359359

360-
# Check if max_token should steer all generated tokens (string "-1" or non-integer)
360+
# Check if max_token should steer all generated tokens (any string or integer -1)
361361
steer_all_generated = (
362-
max_token == "-1"
363-
or max_token == -1
364-
or not isinstance(max_token, (int, np.integer))
362+
isinstance(max_token, str) or max_token == -1
365363
)
366364

367365
if output_tensor.shape[1] > 1:
@@ -556,7 +554,7 @@ class Inspector:
556554
:param steering_interval: (min_token, max_token) steering window (optional). Defaults to (0, -1),
557555
where -1 means no upper bound.
558556
:type steering_interval: Optional[Tuple[int, int]]
559-
:param top_k: Number of top predictions to store for each token. If None, logits are not captured.
557+
:param top_k: Number of top token predictions to store for each token. If None, logits are not captured.
560558
If -1, all tokens in the vocabulary are returned with their logits. Defaults to None.
561559
:type top_k: Optional[int]
562560
:param lm_head_layer: Name of the language model head layer (e.g., "lm_head"). Defaults to "lm_head".
@@ -1054,27 +1052,22 @@ def __getitem__(self, key):
10541052
# Get the activation tensor for this cache key
10551053
rep_tensor = rep_tensor_dict[key]
10561054

1057-
# need to know if token is system prompt or generated to properly index activation tensor
1055+
# Calculate activation index: system prompt tokens come first, then generated tokens
10581056
if self.is_system_prompt:
1059-
# For system prompt tokens, handle negative indices properly
1060-
# Negative index means "from the end of system prompt tokens" for proper recursion
1057+
# For system prompt, normalize negative index relative to system prompt length
10611058
if self.token_index < 0:
1062-
# Convert negative index to positive: -1 -> last system prompt token
10631059
activation_index = self.response.length_system_prompt + self.token_index
10641060
else:
10651061
activation_index = self.token_index
10661062
else:
1067-
# For generated tokens, handle negative indices properly
1068-
# Negative index means "from the end of generated tokens", not from the end of all tokens
1063+
# For generated tokens: positive indices need offset
10691064
input_response = self.agent._hooked_responses[self.response_index]['input'][0]
10701065
if self.token_index < 0:
1071-
# Convert negative index to positive relative to generated tokens
1072-
# Then add the system prompt offset
1073-
num_generated_tokens = len(self.response.tokens)
1074-
positive_index = num_generated_tokens + self.token_index
1075-
activation_index = positive_index + input_response.length_system_prompt
1066+
# Negative index: Python indexing from end of tensor (last generated token is at the end)
1067+
activation_index = self.token_index
10761068
else:
1077-
activation_index = self.token_index + input_response.length_system_prompt
1069+
# Positive index: add system prompt offset
1070+
activation_index = input_response.length_system_prompt + self.token_index
10781071

10791072
if hasattr(rep_tensor, '__getitem__'):
10801073
return rep_tensor[activation_index]

0 commit comments

Comments
 (0)