
Commit 3723045

pre-commit-ci[bot] authored and luciaquirke committed
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6ebe5bf commit 3723045

File tree: 7 files changed, +296 -120 lines changed


README.md

Lines changed: 1 addition & 1 deletion

@@ -230,7 +230,7 @@ The experiments discussed in [the blog post](https://blog.eleuther.ai/autointerp

 ## Development

-Run unit tests:
+Run unit tests:

 ```pytest .```

delphi/__main__.py

Lines changed: 11 additions & 14 deletions

@@ -25,7 +25,6 @@
 from delphi.log.result_analysis import log_results
 from delphi.pipeline import Pipe, Pipeline, process_wrapper
 from delphi.scorers import DetectionScorer, FuzzingScorer
-from delphi.semantic_index.index import build_or_load_index, load_index
 from delphi.sparse_coders import load_hook_to_sparse_encode, load_sparse_coders
 from delphi.utils import assert_type

@@ -97,7 +96,6 @@ def create_neighbours(


 async def process_cache(
     run_cfg: RunConfig,
-    base_path: Path,
     latents_path: Path,
     explanations_path: Path,
     scores_path: Path,

@@ -133,9 +131,6 @@ async def process_cache(
         tokenizer=tokenizer,
     )

-    if run_cfg.semantic_index:
-        index = load_index(base_path, run_cfg.cache_cfg)
-
     if run_cfg.explainer_provider == "offline":
         client = Offline(
             run_cfg.explainer_model,

@@ -165,16 +160,19 @@ async def process_cache(
             f"Explainer provider {run_cfg.explainer_provider} not supported"
         )

+    from delphi.explainers.explainer import ExplainerResult
+
     def explainer_postprocess(result):
         with open(explanations_path / f"{result.record.latent}.txt", "wb") as f:
             f.write(orjson.dumps(result.explanation))
+
+        if not isinstance(result, ExplainerResult):
+            breakpoint()
         return result

-    if run_cfg.semantic_index:
+    if run_cfg.constructor_cfg.non_activating_source == "FAISS":
         explainer = ContrastiveExplainer(
             client,
-            tokenizer=dataset.tokenizer,
-            index=index,
             threshold=0.3,
             verbose=run_cfg.verbose,
         )

@@ -189,6 +187,9 @@ def explainer_postprocess(result):

     # Builds the record from result returned by the pipeline
     def scorer_preprocess(result):
+        if isinstance(result, list):
+            result = result[0]
+
         record = result.record
         record.explanation = result.explanation
         record.extra_examples = record.not_active

@@ -259,8 +260,8 @@ def populate_cache(
     )
     data = data.shuffle(run_cfg.seed)

-    if run_cfg.semantic_index:
-        build_or_load_index(data, base_path, run_cfg.cache_cfg)
+    # if run_cfg.constructor_cfg.non_activating_source == "FAISS":
+    #     build_or_load_index(data, base_path, run_cfg.cache_cfg)

     tokens_ds = chunk_and_tokenize(
         data,  # type: ignore

@@ -368,9 +369,6 @@ async def run(
         transcode,
     )

-    if run_cfg.semantic_index:
-        load_index(base_path, run_cfg.cache_cfg)
-
     del model, hookpoint_to_sparse_encode
     if run_cfg.constructor_cfg.non_activating_source == "neighbours":
         non_redundant_hookpoints = assert_type(

@@ -398,7 +396,6 @@ async def run(
     if non_redundant_hookpoints:
         await process_cache(
             run_cfg,
-            base_path,
             latents_path,
             explanations_path,
             scores_path,
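For context on the artifact format: `explainer_postprocess` above writes one orjson-encoded file per latent under `explanations_path`. A minimal read-back sketch (the results directory and latent name below are illustrative, not taken from the repo):

```python
import orjson
from pathlib import Path

explanations_path = Path("results/explanations")  # wherever the run pointed it
latent = "layers.8_latent42"                      # illustrative latent id

# Each file holds the orjson-serialized explanation string for one latent.
explanation = orjson.loads((explanations_path / f"{latent}.txt").read_bytes())
print(explanation)
```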

delphi/config.py

Lines changed: 7 additions & 3 deletions

@@ -25,6 +25,9 @@ class SamplerConfig(Serializable):

 @dataclass
 class ConstructorConfig(Serializable):
+    faiss_embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
+    """Embedding model to use for FAISS index."""
+
     example_ctx_len: int = 32
     """Length of each sampled example sequence. Longer sequences
     reduce detection scoring performance in weak models.

@@ -41,11 +44,12 @@ class ConstructorConfig(Serializable):
     n_non_activating: int = 50
     """Number of non-activating examples to be constructed."""

-    non_activating_source: Literal["random", "neighbours"] = "random"
+    non_activating_source: Literal["random", "neighbours", "FAISS"] = "FAISS"
     """Source of non-activating examples. Random uses non-activating contexts
     sampled from any non activating window. Neighbours uses actvating contexts
-    from pre-computed latent neighbours. They are still non-activating but
-    have a higher chance of being similar to the activating examples."""
+    from pre-computed latent neighbours. FAISS uses semantic similarity search
+    to find hard negatives that are semantically similar to activating examples
+    but don't activate the latent."""

     neighbours_type: Literal[
         "co-occurrence", "decoder_similarity", "encoder_similarity"
Lines changed: 114 additions & 97 deletions

@@ -1,128 +1,145 @@
 import asyncio
-import re
+from dataclasses import dataclass

-import faiss
+import torch

-from delphi.explainers.default.prompt_builder import build_single_token_prompt
+from delphi.explainers.default.prompts import SYSTEM_CONTRASTIVE
 from delphi.explainers.explainer import Explainer, ExplainerResult
-from delphi.logger import logger
+from delphi.latents.latents import ActivatingExample, LatentRecord, NonActivatingExample


+@dataclass
 class ContrastiveExplainer(Explainer):
-    name = "contrastive"
-
-    def __init__(
-        self,
-        client,
-        tokenizer,
-        index: faiss.Index,
-        verbose: bool = False,
-        activations: bool = False,
-        cot: bool = False,
-        threshold: float = 0.6,
-        temperature: float = 0.0,
-        **generation_kwargs,
-    ):
-        self.client = client
-        self.tokenizer = tokenizer
-        self.index = index
-        self.verbose = verbose
-
-        self.activations = activations
-        self.cot = cot
-        self.threshold = threshold
-        self.temperature = temperature
-        self.generation_kwargs = generation_kwargs
-
-    async def __call__(self, record):
-        breakpoint()
-        messages = self._build_prompt(record.train)
-
+    activations: bool = True
+    """Whether to show activations to the explainer."""
+    max_examples: int = 15
+    """Maximum number of activating examples to use."""
+    max_non_activating: int = 5
+    """Maximum number of non-activating examples to use."""
+
+    async def __call__(self, record: LatentRecord) -> ExplainerResult:
+        """
+        Override the base __call__ method to use both train and not_active examples.
+
+        Args:
+            record: The latent record containing both activating and
+                non-activating examples.
+
+        Returns:
+            ExplainerResult: The explainer result containing the explanation.
+        """
+        # Sample from both activating and non-activating examples
+        activating_examples = record.train[: self.max_examples]
+
+        non_activating_examples = []
+        if len(record.not_active) > 0:
+            non_activating_examples = record.not_active[: self.max_non_activating]
+
+        # Ensure non-activating examples have normalized activations for consistency
+        for example in non_activating_examples:
+            if example.normalized_activations is None:
+                # Use zeros for non-activating examples
+                example.normalized_activations = torch.zeros_like(
+                    example.activations
+                )
+
+        # Combine examples for the prompt
+        combined_examples = activating_examples + non_activating_examples
+
+        # Build the prompt with both types of examples
+        messages = self._build_prompt(combined_examples)
+        print("message", messages[-1]["content"])
+
+        # Generate the explanation
         response = await self.client.generate(
             messages, temperature=self.temperature, **self.generation_kwargs
         )

         try:
             explanation = self.parse_explanation(response.text)
             if self.verbose:
-                return (
-                    messages[-1]["content"],
-                    response,
-                    ExplainerResult(record=record, explanation=explanation),
-                )
+                from ..logger import logger
+
+                logger.info(f"Explanation: {explanation}")
+                logger.info(f"Messages: {messages[-1]['content']}")
+                logger.info(f"Response: {response}")

             return ExplainerResult(record=record, explanation=explanation)
         except Exception as e:
+            from ..logger import logger
+
             logger.error(f"Explanation parsing failed: {e}")
             return ExplainerResult(
                 record=record, explanation="Explanation could not be parsed."
             )

-    def parse_explanation(self, text: str) -> str:
-        try:
-            match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL)
-            return (
-                match.group(1).strip() if match else "Explanation could not be parsed."
-            )
-        except Exception as e:
-            logger.error(f"Explanation parsing regex failed: {e}")
-            raise
-
-    def _highlight(self, index, example):
-        # result = f"Example {index}: "
-        result = ""
-        threshold = example.max_activation * self.threshold
-        if self.tokenizer is not None:
-            str_toks = self.tokenizer.batch_decode(example.tokens)
-            example.str_toks = str_toks
-        else:
-            str_toks = example.tokens
-            example.str_toks = str_toks
-        activations = example.activations
-
-        def check(i):
-            return activations[i] > threshold
-
-        i = 0
-        while i < len(str_toks):
-            if check(i):
-                # result += "<<"
-
-                while i < len(str_toks) and check(i):
-                    result += str_toks[i]
-                    i += 1
-                # result += ">>"
-            else:
-                # result += str_toks[i]
-                i += 1
-
-        return "".join(result)
-
-    def _join_activations(self, example):
-        activations = []
-
-        for i, activation in enumerate(example.activations):
-            if activation > example.max_activation * self.threshold:
-                activations.append(
-                    (example.str_toks[i], int(example.normalized_activations[i]))
-                )
-
-        acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations)
-
-        return "Activations: " + acts
-
-    def _build_prompt(self, examples):
+    def _build_prompt(
+        self, examples: list[ActivatingExample | NonActivatingExample]
+    ) -> list[dict]:
+        """
+        Build a prompt with both activating and non-activating examples clearly labeled.
+
+        Args:
+            examples: List containing both activating and non-activating examples.
+
+        Returns:
+            A list of message dictionaries for the prompt.
+        """
         highlighted_examples = []

-        for i, example in enumerate(examples):
-            highlighted_examples.append(self._highlight(i + 1, example))
+        # First, separate activating and non-activating examples
+        activating_examples = [
+            ex for ex in examples if isinstance(ex, ActivatingExample)
+        ]
+        non_activating_examples = [
+            ex for ex in examples if not isinstance(ex, ActivatingExample)
+        ]
+
+        # Process activating examples
+        if activating_examples:
+            highlighted_examples.append("EXAMPLES:")
+            for i, example in enumerate(activating_examples, 1):
+                str_toks = example.str_tokens
+                activations = example.activations.tolist()
+                highlighted_examples.append(
+                    f"Example {i}: {self._highlight(str_toks, activations)}"
+                )

-            if self.activations:
-                highlighted_examples.append(self._join_activations(example))
+                if self.activations and example.normalized_activations is not None:
+                    normalized_activations = example.normalized_activations.tolist()
+                    highlighted_examples.append(
+                        self._join_activations(
+                            str_toks, activations, normalized_activations
+                        )
+                    )
+
+        # Process non-activating examples
+        if non_activating_examples:
+            highlighted_examples.append("\nCOUNTEREXAMPLES:")
+            for i, example in enumerate(non_activating_examples, 1):
+                str_toks = example.str_tokens
+                activations = example.activations.tolist()
+                # Note: For non-activating examples, the _highlight method won't
+                # highlight anything since activation values will be below threshold
+                highlighted_examples.append(
+                    f"Example {i}: {self._highlight(str_toks, activations)}"
+                )

-        return build_single_token_prompt(
-            examples=highlighted_examples,
-        )
+        # Join all sections into a single string
+        highlighted_examples_str = "\n".join(highlighted_examples)
+
+        # Create messages array with the system prompt
+        return [
+            {
+                "role": "system",
+                "content": SYSTEM_CONTRASTIVE.format(prompt=""),
+            },
+            {
+                "role": "user",
+                "content": f"WORDS: {highlighted_examples_str}",
+            },
+        ]

     def call_sync(self, record):
+        """Synchronous wrapper for the asynchronous __call__ method."""
         return asyncio.run(self.__call__(record))
delphi/explainers/default/prompts.py

Lines changed: 19 additions & 0 deletions

@@ -37,6 +37,21 @@
 {prompt}
 """

+SYSTEM_CONTRASTIVE = """You are a meticulous AI researcher conducting an important investigation into patterns found in language. Your task is to analyze text and provide an explanation that thoroughly encapsulates possible patterns found in it.
+Guidelines:
+
+You will be given a list of text examples on which special words are selected and between delimiters like <<this>>. If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <<just like this>>. How important each token is for the behavior is listed after each example in parentheses.
+
+- Try to produce a concise final description. Simply describe the text latents that are common in the examples, and what patterns you found.
+- Counterexamples where no special words are present are also provided to help you understand the patterns' edge cases.
+- If the examples are uninformative, you don't need to mention them. Don't focus on giving examples of important tokens, but try to summarize the patterns found in the examples.
+- Do not mention the marker tokens (<< >>) in your explanation.
+- Do not make lists of possible explanations. Keep your explanations short and concise.
+- The last line of your response must be the formatted explanation, using [EXPLANATION]:
+
+{prompt}
+"""
+

 COT = """
 To better find the explanation for the language patterns go through the following stages:

@@ -228,3 +243,7 @@ def system(cot=False):

 def system_single_token():
     return [{"role": "system", "content": SYSTEM_SINGLE_TOKEN}]
+
+
+def system_contrastive():
+    return [{"role": "system", "content": SYSTEM_CONTRASTIVE}]
