Skip to content

Commit 38ff4d9 — "simplify autobatchsize"
(1 parent: eddac3a)

File tree

8 files changed: +83 additions, −393 deletions

bergson/__main__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -6,7 +6,6 @@
 from simple_parsing import ArgumentParser, ConflictResolution

 from .build import build
-from .cli.auto_batch_size import AutoBatchSize
 from .config import IndexConfig, QueryConfig, ReduceConfig, ScoreConfig
 from .query.query_index import query
 from .reduce import reduce
```
```diff
@@ -103,7 +102,7 @@ def execute(self):
 class Main:
     """Routes to the subcommands."""

-    command: Union[Build, Query, Reduce, Score, AutoBatchSize]
+    command: Union[Build, Query, Reduce, Score]

     def execute(self):
         """Run the script."""
```

bergson/cli/auto_batch_size.py

Lines changed: 0 additions & 127 deletions
This file was deleted.

bergson/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
```diff
@@ -201,7 +201,8 @@ class IndexConfig:
     """Configuration for multi-node distributed preconditioner computation."""

     max_tokens: int | None = None
-    """The maximum number of tokens to process. If None, all tokens will be processed. Only available for Dataset."""
+    """Max tokens to process. If None, all tokens processed. Dataset only.
+    This experimental feature may be removed in the future."""

     @property
     def partial_run_path(self) -> Path:
```

bergson/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
```diff
@@ -82,7 +82,7 @@ def launch_distributed_run(
         newline = "\n"
         raise RuntimeError(
             f"{process_name} failed with {len(result.failures)} process "
-            f"failure(s): {newline.join(result.failures)}"
+            f"failure(s): {newline.join([str(f) for f in result.failures])}"
         )
     finally:
         if ctx is not None:
```

bergson/query/attributor.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -78,7 +78,7 @@ def __init__(
         # Load the gradients into memory
         mmap = load_gradients(index_path)
         assert mmap.dtype.names is not None
-        # Copy gradients into device memory (handles bfloat16/V2 void types)
+        # Copy gradients into device memory
         self.grads = {
             name: numpy_to_tensor(mmap[name]).to(device=device, dtype=dtype)
             for name in mmap.dtype.names
@@ -91,8 +91,19 @@ def __init__(
         norm = torch.cat(
             [self.grads[name] for name in self.ordered_modules], dim=1
         ).norm(dim=1, keepdim=True)
+
         for name in self.grads:
-            self.grads[name] /= norm
+            # Divide by norm (may create NaN/inf if norm is zero)
+            normalized = self.grads[name] / norm
+            # Convert NaN/inf to 0 and warn if any were found
+            if not torch.isfinite(normalized).all():
+                print(
+                    f"Warning: NaN/inf values detected after normalization in "
+                    f"{name}, converting to 0"
+                )
+            self.grads[name] = torch.nan_to_num(
+                normalized, nan=0.0, posinf=0.0, neginf=0.0
+            )

     def search(
         self,
```

0 commit comments

Comments (0)