Skip to content

Commit 0a23ffe

Browse files
committed
simplify autobatchsize
1 parent eddac3a commit 0a23ffe

File tree

6 files changed

+57
-337
lines changed

6 files changed

+57
-337
lines changed

bergson/__main__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from simple_parsing import ArgumentParser, ConflictResolution
77

88
from .build import build
9-
from .cli.auto_batch_size import AutoBatchSize
109
from .config import IndexConfig, QueryConfig, ReduceConfig, ScoreConfig
1110
from .query.query_index import query
1211
from .reduce import reduce
@@ -103,7 +102,7 @@ def execute(self):
103102
class Main:
104103
"""Routes to the subcommands."""
105104

106-
command: Union[Build, Query, Reduce, Score, AutoBatchSize]
105+
command: Union[Build, Query, Reduce, Score]
107106

108107
def execute(self):
109108
"""Run the script."""

bergson/cli/auto_batch_size.py

Lines changed: 0 additions & 127 deletions
This file was deleted.

bergson/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class IndexConfig:
201201
"""Configuration for multi-node distributed preconditioner computation."""
202202

203203
max_tokens: int | None = None
204-
"""The maximum number of tokens to process. If None, all tokens will be processed. Only available for Dataset."""
204+
"""Max tokens to process. If None, all tokens processed. Dataset only."""
205205

206206
@property
207207
def partial_run_path(self) -> Path:

bergson/query/attributor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,19 @@ def __init__(
9191
norm = torch.cat(
9292
[self.grads[name] for name in self.ordered_modules], dim=1
9393
).norm(dim=1, keepdim=True)
94+
9495
for name in self.grads:
95-
self.grads[name] /= norm
96+
# Divide by norm (may create NaN/inf if norm is zero)
97+
normalized = self.grads[name] / norm
98+
# Convert NaN/inf to 0 and warn if any were found
99+
if not torch.isfinite(normalized).all():
100+
print(
101+
f"Warning: NaN/inf values detected after normalization in "
102+
f"{name}, converting to 0"
103+
)
104+
self.grads[name] = torch.nan_to_num(
105+
normalized, nan=0.0, posinf=0.0, neginf=0.0
106+
)
96107

97108
def search(
98109
self,

0 commit comments

Comments (0)