Skip to content

Commit 1763e0a

Browse files
committed
wip
1 parent cdebc3f commit 1763e0a

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

delphi/__main__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import orjson
1010
import torch
1111
from datasets import Dataset, load_dataset
12-
from simple_parsing import ArgumentParser, field, list_field
12+
from simple_parsing import ArgumentParser
1313
from sparsify.data import chunk_and_tokenize
1414
from torch import Tensor
1515
from torchtyping import TensorType
@@ -26,14 +26,13 @@
2626
from delphi.config import CacheConfig, ExperimentConfig, LatentConfig, RunConfig
2727
from delphi.explainers import ContrastiveExplainer, DefaultExplainer
2828
from delphi.latents import LatentCache, LatentDataset
29-
from delphi.latents.loader import LatentLoader
3029
from delphi.latents.constructors import default_constructor
3130
from delphi.latents.samplers import sample
3231
from delphi.log.result_analysis import log_results
3332
from delphi.pipeline import Pipe, Pipeline, process_wrapper
3433
from delphi.scorers import DetectionScorer, FuzzingScorer
35-
from delphi.sparse_coders import load_sparse_coders
3634
from delphi.semantic_index.index import build_or_load_index, load_index
35+
from delphi.sparse_coders import load_sparse_coders
3736
from delphi.utils import assert_type
3837

3938

@@ -165,8 +164,13 @@ def explainer_postprocess(result):
165164
tokenizer=dataset.tokenizer,
166165
threshold=0.3,
167166
verbose=run_cfg.verbose,
167+
)
168+
169+
explainer_pipe = Pipe(
170+
process_wrapper(
171+
explainer,
172+
postprocess=explainer_postprocess,
168173
),
169-
postprocess=explainer_postprocess,
170174
)
171175

172176
# Builds the record from result returned by the pipeline

delphi/explainers/contrastive_explainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(
1515
self,
1616
client,
1717
tokenizer,
18-
index: faiss.IndexFlatL2,
18+
index: faiss.Index,
1919
verbose: bool = False,
2020
activations: bool = False,
2121
cot: bool = False,
@@ -35,6 +35,7 @@ def __init__(
3535
self.generation_kwargs = generation_kwargs
3636

3737
async def __call__(self, record):
38+
breakpoint()
3839
messages = self._build_prompt(record.train)
3940

4041
response = await self.client.generate(

delphi/tests/e2e.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ async def test():
4242
num_gpus=torch.cuda.device_count(),
4343
filter_bos=True,
4444
verbose=True,
45+
semantic_index=True,
4546
)
4647

4748
start_time = time.time()

0 commit comments

Comments
 (0)