
Commit 7e1bead

simplify autobatchsize
1 parent eddac3a commit 7e1bead

File tree

10 files changed: +99 −396 lines changed


CLAUDE.md

Lines changed: 14 additions & 2 deletions

@@ -1,4 +1,4 @@
-Always test your changes by running the appropriate script or CLI command.
+Always test your changes by running the appropriate script or CLI command. Never complete a task without testing your changes until the script or CLI command runs without issues for 3 minutes+ (at minimum). If you find an error unrelated to your task, at minimum quote the exact error back to me when you have completed your task and offer to investigate and fix it.
 
 ## Project Structure and Conventions
 
@@ -12,12 +12,22 @@ Use dataclasses for config, and use simple_parsing to parse the CLI configs data
 
 Never save logs, scripts, and other random development into the root of a project. Create an appropriate directory such as runs/ or scripts/ and add it to the .gitignore.
 
+torch.cuda.empty_cache() doesn't do what you hope it will do - don't use it.
+
+Put imports at the top of the file unless you have a good reason to do otherwise.
+
 # Development
 
 You can call CLI commands without prefixing `python -m`, like `bergson build`.
 
 Use `pre-commit run --all-files` if you forget to install pre-commit and it doesn't run in the hook.
 
+Run bash commands in the dedicated tmux pane named "claude" if it is available.
+
+Don't keep default run path values inside low level code - if a module calls another module, the higher level module should always pass through inject a base path.
+
+Don't save data to a directory that is not in the gitignore - especially the data/ directory.
+
 Don't remove large datasets from the HF cache without asking.
 
 ### Tests
@@ -26,4 +36,6 @@ Mark tests requiring GPUs with `@pytest.mark.skipif(not torch.cuda.is_available(
 
 ### Environment Setup
 
-If you use need to use a venv, create and/or activate it with `python3 -m venv .venv && source .venv/bin/activate && pip install pytest`.
+If you use need to use a venv, create and/or activate it with `python3 -m venv .venv && source .venv/bin/activate`.
+
+You can pull secrets from .env.

README.md

Lines changed: 2 additions & 1 deletion

@@ -6,8 +6,9 @@ We view attribution as a counterfactual question: **_If we "unlearned" this trai
 ## Core features
 
 - Gradient store for serial queries. We provide collection-time gradient compression for efficient storage, and integrate with FAISS for fast KNN search over large stores.
-- On-the-fly queries. Query gradients without compression or disk I/O overhead via a single pass over a dataset with a set of precomputed query gradients.
+- On-the-fly queries. Query gradients without disk I/O overhead via a single pass over a dataset with a set of precomputed query gradients.
 - Experiment with multiple query strategies based on [LESS](https://arxiv.org/pdf/2402.04333).
+- Ideal for compression-free gradients.
 - Train‑time gradient collection. Capture gradients produced during training with a ~17% performance overhead.
 - Scalable. We use [FSDP2](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html), BitsAndBytes, and other performance optimizations to support large models, datasets, and clusters.
 - Integrated with HuggingFace Transformers and Datasets. We also support on-disk datasets in a variety of formats.

bergson/__main__.py

Lines changed: 1 addition & 2 deletions

@@ -6,7 +6,6 @@
 from simple_parsing import ArgumentParser, ConflictResolution
 
 from .build import build
-from .cli.auto_batch_size import AutoBatchSize
 from .config import IndexConfig, QueryConfig, ReduceConfig, ScoreConfig
 from .query.query_index import query
 from .reduce import reduce
@@ -103,7 +102,7 @@ def execute(self):
 class Main:
     """Routes to the subcommands."""
 
-    command: Union[Build, Query, Reduce, Score, AutoBatchSize]
+    command: Union[Build, Query, Reduce, Score]
 
     def execute(self):
         """Run the script."""

bergson/cli/auto_batch_size.py

Lines changed: 0 additions & 127 deletions
This file was deleted.

bergson/config.py

Lines changed: 2 additions & 1 deletion

@@ -201,7 +201,8 @@ class IndexConfig:
     """Configuration for multi-node distributed preconditioner computation."""
 
     max_tokens: int | None = None
-    """The maximum number of tokens to process. If None, all tokens will be processed. Only available for Dataset."""
+    """Max tokens to process. If None, all tokens processed. Dataset only.
+    This experimental feature may be removed in the future."""
 
     @property
     def partial_run_path(self) -> Path:

bergson/distributed.py

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ def launch_distributed_run(
         newline = "\n"
         raise RuntimeError(
             f"{process_name} failed with {len(result.failures)} process "
-            f"failure(s): {newline.join(result.failures)}"
+            f"failure(s): {newline.join([str(f) for f in result.failures])}"
         )
     finally:
         if ctx is not None:
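This one-line change matters because `str.join` requires every element to be a `str`; joining exception objects (or any non-string failure records) directly raises a `TypeError`, masking the real error report. A standalone sketch of the failure mode and the fix (the failure values are made up):

```python
# Hypothetical failure records, standing in for result.failures.
failures = [RuntimeError("rank 0 OOM"), RuntimeError("rank 1 timeout")]
newline = "\n"

# The old code path: str.join rejects non-string items with a TypeError.
try:
    newline.join(failures)
except TypeError as e:
    print(f"old code fails: {e}")

# The fix: convert each failure to its string form before joining.
message = newline.join([str(f) for f in failures])
print(message)
```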

bergson/query/attributor.py

Lines changed: 13 additions & 2 deletions

@@ -78,7 +78,7 @@ def __init__(
         # Load the gradients into memory
         mmap = load_gradients(index_path)
         assert mmap.dtype.names is not None
-        # Copy gradients into device memory (handles bfloat16/V2 void types)
+        # Copy gradients into device memory
         self.grads = {
             name: numpy_to_tensor(mmap[name]).to(device=device, dtype=dtype)
             for name in mmap.dtype.names
@@ -91,8 +91,19 @@ def __init__(
         norm = torch.cat(
             [self.grads[name] for name in self.ordered_modules], dim=1
         ).norm(dim=1, keepdim=True)
+
         for name in self.grads:
-            self.grads[name] /= norm
+            # Divide by norm (may create NaN/inf if norm is zero)
+            normalized = self.grads[name] / norm
+            # Convert NaN/inf to 0 and warn if any were found
+            if not torch.isfinite(normalized).all():
+                print(
+                    f"Warning: NaN/inf values detected after normalization in "
+                    f"{name}, converting to 0"
+                )
+            self.grads[name] = torch.nan_to_num(
+                normalized, nan=0.0, posinf=0.0, neginf=0.0
+            )
 
     def search(
         self,
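The new guard handles rows whose concatenated gradient norm is zero: dividing by a zero norm yields NaN (0/0) or inf, which would silently poison later similarity scores. A minimal sketch of the same pattern using NumPy for illustration (the real code uses `torch.nan_to_num`, which takes the same `nan`/`posinf`/`neginf` keywords; the array values here are made up):

```python
import numpy as np

# Two gradient rows; the second has zero norm, the problem case.
grads = np.array([[3.0, 4.0], [0.0, 0.0]])
norm = np.linalg.norm(grads, axis=1, keepdims=True)

# Dividing by a zero norm produces NaN (0/0) rather than raising.
with np.errstate(divide="ignore", invalid="ignore"):
    normalized = grads / norm

# Warn when the division produced non-finite values ...
if not np.isfinite(normalized).all():
    print("Warning: NaN/inf values detected after normalization, converting to 0")

# ... then replace NaN and +/-inf with 0, as the commit does with torch.nan_to_num.
cleaned = np.nan_to_num(normalized, nan=0.0, posinf=0.0, neginf=0.0)
print(cleaned)  # rows are [0.6, 0.8] and [0.0, 0.0]
```

Zeroing a zero-norm row is a deliberate choice: a sample with no gradient signal simply contributes nothing to any attribution score.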
