Skip to content

Commit 4f8cbca

Browse files
authored
Merge pull request #79 from EleutherAI/remove_latent_loader
Fix notebooks directly
2 parents 8eb4a55 + 3366306 commit 4f8cbca

File tree

4 files changed

+56
-175
lines changed

4 files changed

+56
-175
lines changed

delphi/latents/loader.py

Lines changed: 0 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -322,123 +322,3 @@ async def _aprocess_latent(self, buffer_output: BufferOutput) -> LatentRecord:
322322
if self.transform is not None:
323323
self.transform(record)
324324
return record
325-
326-
327-
class LatentLoader:
328-
"""
329-
Loader class for processing latent records from a LatentDataset.
330-
"""
331-
332-
def __init__(
333-
self,
334-
latent_dataset: "LatentDataset",
335-
constructor: Optional[Callable] = None,
336-
sampler: Optional[Callable] = None,
337-
transform: Optional[Callable] = None,
338-
):
339-
"""
340-
Initialize a LatentLoader.
341-
342-
Args:
343-
latent_dataset (LatentDataset): The dataset to load latents from.
344-
constructor (Optional[Callable]): Function to construct latent records.
345-
sampler (Optional[Callable]): Function to sample from latent records.
346-
transform (Optional[Callable]): Function to transform latent records.
347-
"""
348-
self.latent_dataset = latent_dataset
349-
self.constructor = constructor
350-
self.sampler = sampler
351-
self.transform = transform
352-
353-
async def __aiter__(self):
354-
"""
355-
Asynchronous iterator for processing latent records.
356-
357-
Yields:
358-
LatentRecord: Processed latent records.
359-
"""
360-
for buffer in self.latent_dataset.buffers:
361-
async for record in self._aprocess_buffer(buffer):
362-
yield record
363-
364-
async def _aprocess_buffer(self, buffer):
365-
"""
366-
Asynchronously process a buffer.
367-
368-
Args:
369-
buffer (TensorBuffer): Buffer to process.
370-
371-
Yields:
372-
Optional[LatentRecord]: Processed latent record or None.
373-
"""
374-
for data in buffer:
375-
if data is not None:
376-
record = await self._aprocess_latent(data)
377-
if record is not None:
378-
yield record
379-
await asyncio.sleep(0) # Allow other coroutines to run
380-
381-
async def _aprocess_latent(self, buffer_output):
382-
"""
383-
Asynchronously process a single latent.
384-
385-
Args:
386-
buffer_output (BufferOutput): Latent data to process.
387-
388-
Returns:
389-
Optional[LatentRecord]: Processed latent record or None.
390-
"""
391-
record = LatentRecord(buffer_output.latent)
392-
if self.constructor is not None:
393-
self.constructor(record=record, buffer_output=buffer_output)
394-
if self.sampler is not None:
395-
self.sampler(record)
396-
if self.transform is not None:
397-
self.transform(record)
398-
return record
399-
400-
def __iter__(self):
401-
"""
402-
Synchronous iterator for processing latent records.
403-
404-
Yields:
405-
LatentRecord: Processed latent records.
406-
"""
407-
for buffer in self.latent_dataset.buffers:
408-
for record in self._process_buffer(buffer):
409-
yield record
410-
411-
def _process_buffer(self, buffer):
412-
"""
413-
Process a buffer synchronously.
414-
415-
Args:
416-
buffer (TensorBuffer): Buffer to process.
417-
418-
Yields:
419-
Optional[LatentRecord]: Processed latent record or None.
420-
"""
421-
for data in buffer:
422-
if data is not None:
423-
record = self._process_latent(data)
424-
if record is not None:
425-
yield record
426-
427-
def _process_latent(self, buffer_output):
428-
"""
429-
Process a single latent synchronously.
430-
431-
Args:
432-
buffer_output (BufferOutput): Latent data to process.
433-
434-
Returns:
435-
Optional[LatentRecord]: Processed latent record or None.
436-
"""
437-
record = LatentRecord(buffer_output.latent)
438-
if self.constructor is not None:
439-
self.constructor(record=record, buffer_output=buffer_output)
440-
if self.sampler is not None:
441-
self.sampler(record)
442-
if self.transform is not None:
443-
self.transform(record)
444-
return record

examples/generate_explanations.ipynb

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
"metadata": {},
2121
"outputs": [],
2222
"source": [
23-
"import asyncio\n",
2423
"import os\n",
2524
"from functools import partial\n",
2625
"\n",
@@ -30,7 +29,7 @@
3029
"from delphi.clients import OpenRouter\n",
3130
"from delphi.config import ExperimentConfig, LatentConfig\n",
3231
"from delphi.explainers import DefaultExplainer\n",
33-
"from delphi.latents import LatentDataset, LatentLoader\n",
32+
"from delphi.latents import LatentDataset\n",
3433
"from delphi.latents.constructors import default_constructor\n",
3534
"from delphi.latents.samplers import sample\n",
3635
"from delphi.pipeline import Pipeline, process_wrapper\n",
@@ -61,12 +60,7 @@
6160
"module = \".model.layers.10\" # The layer to explain\n",
6261
"latent_dict = {module: torch.arange(0,5)} # The what latents to explain\n",
6362
"\n",
64-
"dataset = LatentDataset(\n",
65-
" raw_dir=\"latents\", # The folder where the cache is stored\n",
66-
" cfg=latent_cfg,\n",
67-
" modules=[module],\n",
68-
" latents=latent_dict,\n",
69-
")\n"
63+
"\n"
7064
]
7165
},
7266
{
@@ -116,8 +110,14 @@
116110
" max_examples=latent_cfg.max_examples\n",
117111
" )\n",
118112
"sampler=partial(sample,cfg=experiment_cfg)\n",
119-
"loader = LatentLoader(dataset, constructor=constructor, sampler=sampler)\n",
120-
" "
113+
"dataset = LatentDataset(\n",
114+
" raw_dir=\"latents\", # The folder where the cache is stored\n",
115+
" cfg=latent_cfg,\n",
116+
" modules=[module],\n",
117+
" latents=latent_dict,\n",
118+
" constructor=constructor,\n",
119+
" sampler=sampler\n",
120+
") "
121121
]
122122
},
123123
{
@@ -210,7 +210,7 @@
210210
],
211211
"source": [
212212
"pipeline = Pipeline(\n",
213-
" loader,\n",
213+
" dataset,\n",
214214
" explainer_pipe,\n",
215215
")\n",
216216
"number_of_parallel_latents = 10\n",

examples/latent_contexts.ipynb

Lines changed: 34 additions & 33 deletions
Large diffs are not rendered by default.

examples/score_explanations.ipynb

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,11 @@
2424
"import os \n",
2525
"import torch\n",
2626
"import orjson\n",
27-
"import asyncio\n",
2827
"from delphi.clients import OpenRouter\n",
2928
"from delphi.config import ExperimentConfig, LatentConfig\n",
3029
"from delphi.explainers import explanation_loader\n",
3130
"from delphi.latents import (\n",
32-
" LatentDataset,\n",
33-
" LatentLoader\n",
31+
" LatentDataset\n",
3432
")\n",
3533
"from delphi.latents.constructors import default_constructor\n",
3634
"from delphi.latents.samplers import sample\n",
@@ -65,12 +63,7 @@
6563
"module = \".model.layers.10\" # The layer to score\n",
6664
"latent_dict = {module: torch.arange(0,3)} # The what latents to score\n",
6765
"\n",
68-
"dataset = LatentDataset(\n",
69-
" raw_dir=\"latents\", # The folder where the cache is stored\n",
70-
" cfg=latent_cfg,\n",
71-
" modules=[module],\n",
72-
" latents=latent_dict,\n",
73-
")\n"
66+
"\n"
7467
]
7568
},
7669
{
@@ -120,7 +113,14 @@
120113
" max_examples=latent_cfg.max_examples\n",
121114
" )\n",
122115
"sampler=partial(sample,cfg=experiment_cfg)\n",
123-
"loader = LatentLoader(dataset, constructor=constructor, sampler=sampler)\n",
116+
"dataset = LatentDataset(\n",
117+
" raw_dir=\"latents\", # The folder where the cache is stored\n",
118+
" cfg=latent_cfg,\n",
119+
" modules=[module],\n",
120+
" latents=latent_dict,\n",
121+
" constructor=constructor,\n",
122+
" sampler=sampler\n",
123+
")\n",
124124
" "
125125
]
126126
},
@@ -217,7 +217,7 @@
217217
],
218218
"source": [
219219
"pipeline = Pipeline(\n",
220-
" loader,\n",
220+
" dataset,\n",
221221
" explainer_pipe,\n",
222222
" scorer_pipe,\n",
223223
")\n",

0 commit comments

Comments
 (0)