Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions CLAUDE.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Always test your changes by running the appropriate script or CLI command. Never complete a task without testing your changes until the script or CLI command runs without issues for 3 minutes+ (at minimum). If you find an error unrelated to your task, at minimum quote the exact error back to me when you have completed your task and offer to investigate and fix it.
Always test your changes by running the appropriate script or CLI command. Never complete a task without testing your changes until the script or CLI command runs without issues. If it's a long-running script let it run for at least a few iterations of the main loop. If you find an error unrelated to your task, at minimum quote the exact error back to me when you have completed your task and offer to investigate and fix it.

## Project Structure and Conventions

Expand All @@ -18,17 +18,39 @@ Put imports at the top of the file unless you have a good reason to do otherwise

# Development

Never use try/except blocks - fail fast, fail explicitly.

Never use "fallbacks".

Do not write lines longer than 88 characters.

Don't use ALL CAPS unless it's proper English (e.g. an acronym).

Don't keep default run path values inside low level code - if a module calls another module, the higher level module should always pass through inject a base path.

Don't save data to a directory that is not in the .gitignore - especially the data/ directory.

Don't remove large datasets from the HF cache without asking.

You can call CLI commands without prefixing `python -m`, like `bergson build`.

Use `pre-commit run --all-files` if you forget to install pre-commit and it doesn't run in the hook.

Run bash commands in the dedicated tmux pane named "claude" if it is available.

Don't keep default run path values inside low level code - if a module calls another module, the higher level module should always pass through inject a base path.
Don't betray lineage. An example of betraying lineage is duplicating a file, making changes in the duplicate, then calling it "foo_fixed" rather than "foo". Instead, commit the file and modify it directly. Another example is adding a RoundButton to a module containing a Button but not updating the original Button to be called RectangleButton. This betrays that the rectangular button was written first.

Don't save data to a directory that is not in the gitignore - especially the data/ directory.
If you think some data files (e.g. CSVs) have been invalidated but you're not 100% sure, you can add them to a .gitignore'd archive directory along with an equivalentally named markdown file explaining the context.

Don't remove large datasets from the HF cache without asking.
File names always use snake case - in_memory, not inmemory.

When writing files to disk python scripts should choose their own filenames but be provided with their file paths.

### Documentation

Do not mark documentation for code that has been removed as deprecated - simply remove the documentation.

No context leakage: do not write code or comments that link features to the specific experiment for which the feature was developed, unless it's only useful for that particular experiment. Be as generic as is correctly possible and not more.

### Tests

Expand Down
23 changes: 23 additions & 0 deletions bergson/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
import shutil
from dataclasses import asdict
from datetime import timedelta
from pathlib import Path

import torch
import torch.distributed as dist
from datasets import Dataset, IterableDataset
from tqdm.auto import tqdm

from bergson.collection import collect_gradients
from bergson.collector.gradient_collectors import GradientCollector
from bergson.config import IndexConfig
from bergson.data import allocate_batches
from bergson.distributed import launch_distributed_run
from bergson.utils.auto_batch_size import (
determine_batch_size,
)
from bergson.utils.utils import assert_type, setup_reproducibility
from bergson.utils.worker_utils import (
create_processor,
Expand Down Expand Up @@ -63,6 +68,24 @@ def build_worker(
model, target_modules = setup_model_and_peft(cfg)
processor = create_processor(model, ds, cfg, target_modules)

# Auto batch size determination if enabled
if cfg.autobatchsize:
cfg.token_batch_size = determine_batch_size(
root=Path(".cache"),
cfg=cfg,
model=model,
collector=GradientCollector(
model=model.base_model,
cfg=cfg,
processor=processor,
target_modules=target_modules,
data=ds,
scorer=None,
reduce_cfg=None,
),
starting_batch_size=cfg.token_batch_size,
)

attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}

kwargs = {
Expand Down
3 changes: 3 additions & 0 deletions bergson/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ class IndexConfig:
token_batch_size: int = 2048
"""Batch size in tokens for building the index."""

autobatchsize: bool = False
"""Whether to automatically determine the optimal batch size."""

processor_path: str = ""
"""Path to a precomputed processor."""

Expand Down
Loading
Loading