Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bergson/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .faiss_index import FaissConfig
from .gradcheck import FiniteDiff
from .gradients import GradientCollector, GradientProcessor
from .score_writer import MemmapScoreWriter

__all__ = [
"collect_gradients",
Expand All @@ -18,4 +19,5 @@
"IndexConfig",
"DataConfig",
"AttentionConfig",
"MemmapScoreWriter",
]
4 changes: 3 additions & 1 deletion bergson/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ class Query:

def execute(self):
"""Query the gradient dataset."""
assert self.query_cfg.scores_path
assert self.query_cfg.query_path

if os.path.exists(self.index_cfg.run_path) and self.index_cfg.save_index:
raise ValueError(
"Index path already exists and save_index is True - "
"running this query will overwrite the existing gradients. "
"If you meant to query the existing gradients, use "
"If you meant to query the existing gradients use "
"Attributor instead."
)

Expand Down
4 changes: 2 additions & 2 deletions bergson/attributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def scores(self) -> Tensor:
class Attributor:
def __init__(
self,
index_path: str,
index_path: Path,
device: str = "cpu",
dtype: torch.dtype = torch.float32,
unit_norm: bool = False,
Expand All @@ -59,7 +59,7 @@ def __init__(
f"faiss_{faiss_cfg.index_factory.replace(',', '_')}"
f"{'_cosine' if unit_norm else ''}"
)
faiss_path = Path(index_path) / faiss_index_name
faiss_path = index_path / faiss_index_name

if not (faiss_path / "config.json").exists():
FaissIndex.create_index(
Expand Down
21 changes: 17 additions & 4 deletions bergson/build.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
import os
import shutil
import socket
from dataclasses import is_dataclass
from datetime import timedelta
from pathlib import Path
from typing import cast

import pandas as pd
Expand Down Expand Up @@ -133,7 +136,7 @@ def worker(
print(f"Loading processor from '{cfg.processor_path}'")

processor = GradientProcessor.load(
cfg.processor_path,
Path(cfg.processor_path),
map_location=f"cuda:{rank}",
)
else:
Expand Down Expand Up @@ -170,6 +173,8 @@ def worker(
save_index=cfg.save_index,
save_processor=cfg.save_processor,
drop_columns=cfg.drop_columns,
token_batch_size=cfg.token_batch_size,
module_wise=cfg.module_wise,
)
else:
# Convert each shard to a Dataset then map over its gradients
Expand All @@ -185,7 +190,7 @@ def flush():
model,
ds_shard,
processor,
os.path.join(cfg.partial_run_path, f"shard-{shard_id:05d}"),
cfg.partial_run_path / f"shard-{shard_id:05d}",
batches=batches,
kl_divergence=cfg.loss_fn == "kl",
loss_reduction=cfg.loss_reduction,
Expand All @@ -196,6 +201,8 @@ def flush():
# Save a processor state checkpoint after each shard
save_processor=cfg.save_processor,
drop_columns=cfg.drop_columns,
token_batch_size=cfg.token_batch_size,
module_wise=cfg.module_wise,
)
buf.clear()
shard_id += 1
Expand Down Expand Up @@ -254,8 +261,14 @@ def build_gradient_dataset(cfg: IndexConfig):
)

# Write index config to json
os.makedirs(cfg.partial_run_path, exist_ok=True)
with open(os.path.join(cfg.partial_run_path, "index_config.json"), "w") as f:
json.dump(cfg, f)
index_cfg_dict = cfg.__dict__
for key in index_cfg_dict:
if is_dataclass(index_cfg_dict[key]):
index_cfg_dict[key] = index_cfg_dict[key].__dict__

json.dump(index_cfg_dict, f)

world_size = torch.cuda.device_count()
if world_size <= 1:
Expand Down Expand Up @@ -288,6 +301,6 @@ def build_gradient_dataset(cfg: IndexConfig):
ctx.wait()

try:
os.rename(cfg.partial_run_path, cfg.run_path)
shutil.move(cfg.partial_run_path, cfg.run_path)
except Exception:
pass
41 changes: 21 additions & 20 deletions bergson/collection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from typing import Callable, Literal
from pathlib import Path
from typing import Literal

import numpy as np
import torch
Expand All @@ -12,13 +13,14 @@
from .data import create_index, pad_and_tensor
from .gradients import AttentionConfig, GradientCollector, GradientProcessor
from .peft import set_peft_enabled
from .score_writer import ScoreWriter


def collect_gradients(
model: PreTrainedModel,
data: Dataset,
processor: GradientProcessor,
path: str,
path: Path,
*,
batches: list[list[int]] | None = None,
kl_divergence: bool | None = None,
Expand All @@ -29,7 +31,9 @@ def collect_gradients(
save_index: bool = True,
save_processor: bool = True,
drop_columns: bool = False,
query_callback: Callable[[dict[str, torch.Tensor]], torch.Tensor] | None = None,
score_writer: ScoreWriter | None = None,
token_batch_size: int | None = None,
module_wise: bool = False,
):
"""
Compute projected gradients using a subset of the dataset.
Expand All @@ -53,14 +57,17 @@ def collect_gradients(
lo = torch.finfo(dtype).min
hi = torch.finfo(dtype).max

def callback(name: str, g: torch.Tensor):
def callback(name: str, g: torch.Tensor, indices: list[int]):
g = g.flatten(1).clamp_(lo, hi)
if save_index:
# Asynchronously move the gradient to CPU and convert to the final dtype
mod_grads[name] = g.to(device="cpu", dtype=dtype, non_blocking=True)
else:
mod_grads[name] = g.to(dtype=dtype)

if score_writer and module_wise:
score_writer(indices, mod_grads, name=name)

# Compute the outer product of the flattened gradient
if not skip_preconditioners:
g = g.float()
Expand Down Expand Up @@ -94,12 +101,6 @@ def callback(name: str, g: torch.Tensor):
dtype=dtype,
fill_value=0.0,
)
per_doc_scores = torch.full(
(len(data),),
device=model.device,
dtype=dtype,
fill_value=0.0,
)

for indices in tqdm(batches, disable=rank != 0, desc="Building index"):
batch = data[indices]
Expand All @@ -118,6 +119,8 @@ def callback(name: str, g: torch.Tensor):
set_peft_enabled(model, True)

with collector:
collector.indices = indices

ft_lps = torch.log_softmax(model(x).logits[:, :-1], dim=-1)

# Compute average KL across all unmasked tokens
Expand All @@ -129,6 +132,8 @@ def callback(name: str, g: torch.Tensor):
losses.mean().backward()
else:
with collector:
collector.indices = indices

logits = model(x).logits[:, :-1]

losses = F.cross_entropy(
Expand Down Expand Up @@ -156,9 +161,11 @@ def callback(name: str, g: torch.Tensor):
for module_name in mod_grads.keys():
grad_buffer[module_name][indices] = mod_grads[module_name].numpy()

if query_callback is not None:
scores = query_callback(mod_grads)
per_doc_scores[indices] = scores.detach().type_as(per_doc_scores)
if score_writer is not None:
if module_wise:
score_writer.finalize_module_wise(indices)
else:
score_writer(indices, mod_grads)

mod_grads.clear()
per_doc_losses[indices] = losses.detach().type_as(per_doc_losses)
Expand All @@ -178,13 +185,7 @@ def callback(name: str, g: torch.Tensor):
feature=Value("float16" if dtype == torch.float16 else "float32"),
new_fingerprint="loss",
)
data = data.add_column(
"scores",
per_doc_scores.cpu().numpy(),
feature=Value("float16" if dtype == torch.float16 else "float32"),
new_fingerprint="scores",
)
data.save_to_disk(path + "/data.hf")
data.save_to_disk(path / "data.hf")

if save_processor:
processor.save(path)
Expand Down
Loading