Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions bergson/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ class Query:
def execute(self):
"""Query the gradient dataset."""

if (
os.path.exists(self.index_cfg.run_path)
and self.index_cfg.save_index
and not self.index_cfg.run_path == "Sdfs"
):
if os.path.exists(self.index_cfg.run_path) and self.index_cfg.save_index:
raise ValueError(
"Index path already exists and save_index is True - "
"running this query will overwrite the existing gradients. "
Expand Down
34 changes: 10 additions & 24 deletions bergson/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,8 @@ class AttentionConfig:
class QueryConfig:
"""Config for querying an index on the fly."""

run_path: str = ""
"""Path to the query dataset. If empty, a new dataset will be built
using the query_data config and an index config."""

query_data: DataConfig = field(default_factory=DataConfig)
"""Data to use for the query."""
query_path: str = ""
"""Path to the query dataset."""

query_method: Literal["mean", "nearest"] = "mean"
"""Method to use for computing the query."""
Expand All @@ -83,25 +79,15 @@ class QueryConfig:
"""Whether to write the query dataset gradient processor
to disk."""

apply_query_preconditioner: Literal["none", "existing", "precompute"] = "none"
"""Whether to apply (or compute and apply) a preconditioner
computed over the query dataset."""

apply_index_preconditioner: Literal["none", "existing", "precompute"] = "none"
"""Whether to apply (or compute and apply) a preconditioner
computed over the index dataset."""

query_preconditioner_path: str | None = None
"""Path to a precomputed preconditioner. This does not affect
the ability to compute a new preconditioner during gradient collection.
The precomputed preconditioner is applied to the query dataset
gradients."""
"""Path to a precomputed preconditioner. The precomputed
preconditioner is applied to the query dataset gradients."""

index_preconditioner_path: str | None = None
"""Path to a precomputed preconditioner. This does not affect
the ability to compute a new preconditioner during gradient collection.
The precomputed preconditioner is applied to the query dataset
gradients."""
"""Path to a precomputed preconditioner. The precomputed
preconditioner is applied to the query dataset gradients.
This does not affect the ability to compute a new
preconditioner during gradient collection."""

mixing_coefficient: float = 0.5
"""Coefficient to weight the application of the query preconditioner
Expand All @@ -123,10 +109,10 @@ class IndexConfig:
"""Config for building the index and running the model/dataset pipeline."""

run_path: str = field(positional=True)
"""Name of the run. Used to create a directory for the index."""
"""Name of the run. Used to create a directory for run artifacts."""

save_index: bool = True
"""Whether to write the gradients to disk."""
"""Whether to write the gradient index to disk."""

save_processor: bool = True
"""Whether to write the gradient processor to disk."""
Expand Down
58 changes: 18 additions & 40 deletions bergson/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import os
import socket
from copy import deepcopy
from datetime import timedelta
from typing import cast

Expand All @@ -20,15 +19,14 @@
PreTrainedModel,
)

from .build import build_gradient_dataset, estimate_advantage
from .build import estimate_advantage
from .collection import collect_gradients
from .data import (
IndexConfig,
QueryConfig,
allocate_batches,
load_data_string,
load_gradient_dataset,
load_gradients,
tokenize,
)
from .gradients import GradientProcessor
Expand All @@ -42,55 +40,35 @@ def get_query_data(index_cfg: IndexConfig, query_cfg: QueryConfig):
may be mixed as described in https://arxiv.org/html/2410.17413v1#S3.
"""
# Collect the query gradients if they don't exist
if not os.path.exists(query_cfg.run_path):
# Create a copy of the index configuration that uses the query dataset
cfg = deepcopy(index_cfg)
cfg.data = deepcopy(query_cfg.query_data)
cfg.run_path = query_cfg.run_path

if query_cfg.save_processor or query_cfg.apply_query_preconditioner != "none":
cfg.save_processor = True
cfg.save_index = True
cfg.streaming = False

print("Building query dataset...")
build_gradient_dataset(cfg)

# Collect the index preconditioner if it doesn't exist
if (
query_cfg.apply_index_preconditioner != "none"
and query_cfg.index_preconditioner_path is None
):
print(
"Building index dataset gradient processor. Warning: "
"this will take approximately as long as the query itself."
if not os.path.exists(query_cfg.query_path):
raise FileNotFoundError(
f"Query dataset not found at {query_cfg.query_path}. "
"Please build a query dataset index first."
)
build_gradient_dataset(index_cfg)

# Load the query dataset
with open(os.path.join(query_cfg.run_path, "info.json"), "r") as f:
with open(os.path.join(query_cfg.query_path, "info.json"), "r") as f:
target_modules = json.load(f)["dtype"]["names"]

query_ds = load_gradient_dataset(query_cfg.run_path, concatenate_gradients=False)
query_ds = load_gradient_dataset(query_cfg.query_path, concatenate_gradients=False)
query_ds = query_ds.with_format("torch", columns=target_modules)

use_q = query_cfg.apply_query_preconditioner != "none"
use_i = query_cfg.apply_index_preconditioner != "none"
use_q = query_cfg.query_preconditioner_path is not None
use_i = query_cfg.index_preconditioner_path is not None

if use_q or use_i:
q, i = {}, {}
if use_q:
assert query_cfg.query_preconditioner_path is not None
q = GradientProcessor.load(
query_cfg.query_preconditioner_path or query_cfg.run_path,
query_cfg.query_preconditioner_path,
map_location="cuda",
).preconditioners
if use_i:
i_path = (
query_cfg.index_preconditioner_path
or index_cfg.processor_path
or index_cfg.run_path
)
i = GradientProcessor.load(i_path, map_location="cuda").preconditioners
assert query_cfg.index_preconditioner_path is not None
i = GradientProcessor.load(
query_cfg.index_preconditioner_path, map_location="cuda"
).preconditioners

mixed_preconditioner = (
{
Expand Down Expand Up @@ -148,7 +126,7 @@ def sum_(*cols):
(acc[module] / len(query_ds)).to(device=device, dtype=dtype)
for module in query_cfg.modules
],
dim=1,
dim=0,
)

@torch.inference_mode()
Expand Down Expand Up @@ -320,8 +298,8 @@ def worker(
else:
attention_cfgs = {}

if not query_cfg.modules:
query_cfg.modules = load_gradients(query_cfg.run_path).dtype.names
with open(os.path.join(query_cfg.query_path, "info.json"), "r") as f:
query_cfg.modules = json.load(f)["dtype"]["names"]

query_ds = query_ds.with_format("torch", columns=query_cfg.modules)

Expand Down