Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Don't save data to a directory that is not in the gitignore - especially the dat

Don't remove large datasets from the HF cache without asking.

When a bug is reported, start by writing a test that reproduces the bug. Then fix the bug and prove it with a passing test.

### Tests

Mark tests requiring GPUs with `@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")`.
Expand Down
10 changes: 5 additions & 5 deletions bergson/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
PreprocessConfig,
QueryConfig,
ScoreConfig,
TrackstarConfig,
TrackStarConfig,
)
from .double_backward import DoubleBackwardConfig, double_backward
from .hessians.hessian_approximations import approximate_hessians
Expand Down Expand Up @@ -120,7 +120,7 @@ def execute(self):


@dataclass
class Trackstar:
class TrackStar:
"""Run preconditioners, build, and score as a single pipeline."""

index_cfg: IndexConfig
Expand All @@ -129,13 +129,13 @@ class Trackstar:

preprocess_cfg: PreprocessConfig

trackstar_cfg: TrackstarConfig
trackstar_cfg: TrackStarConfig

def execute(self):
if self.index_cfg.normalizer != "adafactor":
print(
"Warning: not using Adafactor normalizer. Pass --normalizer adafactor "
"to match the Trackstar paper."
"to match the TrackStar paper."
)

trackstar(
Expand All @@ -160,7 +160,7 @@ class Main:
"""Routes to the subcommands."""

command: Union[
Build, Query, Preconditioners, Reduce, Score, Hessian, Trackstar, Magic
Build, Query, Preconditioners, Reduce, Score, Hessian, TrackStar, Magic
]

def execute(self):
Expand Down
131 changes: 79 additions & 52 deletions bergson/collector/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
from bergson.utils.peft import set_peft_enabled
from bergson.utils.utils import assert_type

# Ensure bf16 matmuls use fp32 accumulation across tiles. Without this, cross-tile
# reductions may use reduced-precision (bf16) accumulators, losing precision when
# summing outer products across long sequences.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False


@dataclass
class HookCollectorBase(ContextDecorator, ABC):
Expand Down Expand Up @@ -753,37 +758,49 @@ def fwd_bwd_factory(cfg: IndexConfig) -> Callable:
Returns a tensor of shape [batch_size] with one loss value per sample.
"""

_AUTOCAST_DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16}
autocast_dtype = _AUTOCAST_DTYPES.get(cfg.precision)

def fwd_bwd(model, x: Tensor, y: Tensor, batch: dict):
logits = model(x).logits[:, :-1]
masks = y[:, 1:] != -100
denoms = (
masks.sum(dim=1, dtype=model.dtype) if cfg.loss_reduction == "mean" else 1.0
device_type = "cuda" if x.is_cuda else "cpu"
amp = (
torch.autocast(device_type, dtype=autocast_dtype)
if autocast_dtype
else nullcontext()
)
with amp:
logits = model(x).logits[:, :-1]
masks = y[:, 1:] != -100
denoms = (
masks.sum(dim=1, dtype=model.dtype)
if cfg.loss_reduction == "mean"
else 1.0
)

if cfg.loss_fn == "kl":
with torch.inference_mode():
set_peft_enabled(model, False)
ref_lps = torch.log_softmax(model(x).logits[:, :-1], dim=-1)
set_peft_enabled(model, True)
if cfg.loss_fn == "kl":
with torch.inference_mode():
set_peft_enabled(model, False)
ref_lps = torch.log_softmax(model(x).logits[:, :-1], dim=-1)
set_peft_enabled(model, True)

ft_lps = torch.log_softmax(logits, dim=-1)
ft_lps = torch.log_softmax(logits, dim=-1)

# Compute average KL across all unmasked tokens
kls = torch.sum(ft_lps.exp() * (ft_lps - ref_lps), dim=-1)
losses = torch.sum(kls * masks, dim=-1) / denoms
if "advantage" in batch:
losses *= torch.tensor(batch["advantage"], device=losses.device)
# Compute average KL across all unmasked tokens
kls = torch.sum(ft_lps.exp() * (ft_lps - ref_lps), dim=-1)
losses = torch.sum(kls * masks, dim=-1) / denoms
if "advantage" in batch:
losses *= torch.tensor(batch["advantage"], device=losses.device)

else:
losses = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
y[:, 1:].flatten(),
reduction="none",
label_smoothing=cfg.label_smoothing,
).reshape_as(y[:, 1:])
losses = losses.sum(1) / denoms
if "advantage" in batch:
losses *= torch.tensor(batch["advantage"], device=losses.device)
else:
losses = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
y[:, 1:].flatten(),
reduction="none",
label_smoothing=cfg.label_smoothing,
).reshape_as(y[:, 1:])
losses = losses.sum(1) / denoms
if "advantage" in batch:
losses *= torch.tensor(batch["advantage"], device=losses.device)

losses.sum().backward()
model.zero_grad()
Expand All @@ -796,37 +813,47 @@ def fwd_bwd(model, x: Tensor, y: Tensor, batch: dict):
def fwd_bwd_hessian_factory(
index_cfg: IndexConfig, hessian_cfg: HessianConfig
) -> Callable:
_AUTOCAST_DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16}
autocast_dtype = _AUTOCAST_DTYPES.get(index_cfg.precision)

def fwd_bwd_hessian(model, x: Tensor, y: Tensor, batch: dict):
logits = model(x).logits[:, :-1]
masks = y[:, 1:] != -100
denoms = (
masks.sum(dim=1, dtype=model.dtype)
if index_cfg.loss_reduction == "mean"
else 1.0
device_type = "cuda" if x.is_cuda else "cpu"
amp = (
torch.autocast(device_type, dtype=autocast_dtype)
if autocast_dtype
else nullcontext()
)
if hessian_cfg.use_dataset_labels:
losses = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
y[:, 1:].flatten(),
reduction="none",
).reshape_as(y[:, 1:])
losses = losses.sum(1) / denoms
else:
with torch.no_grad():
probs = F.softmax(logits, dim=-1)
sampled_tokens = torch.multinomial(
probs.reshape(-1, probs.size(-1)),
num_samples=1,
replacement=True,
with amp:
logits = model(x).logits[:, :-1]
masks = y[:, 1:] != -100
denoms = (
masks.sum(dim=1, dtype=model.dtype)
if index_cfg.loss_reduction == "mean"
else 1.0
)
if hessian_cfg.use_dataset_labels:
losses = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
y[:, 1:].flatten(),
reduction="none",
).reshape_as(y[:, 1:])
losses = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
sampled_tokens.flatten(),
reduction="none",
).reshape_as(y[:, 1:])
losses = losses.sum(1) / denoms
losses = losses.sum(1) / denoms
else:
with torch.no_grad():
probs = F.softmax(logits, dim=-1)
sampled_tokens = torch.multinomial(
probs.reshape(-1, probs.size(-1)),
num_samples=1,
replacement=True,
).reshape_as(y[:, 1:])
losses = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
sampled_tokens.flatten(),
reduction="none",
).reshape_as(y[:, 1:])
losses = losses.sum(1) / denoms

losses.sum().backward()
losses.sum().backward()
model.zero_grad()

return losses
Expand Down
15 changes: 10 additions & 5 deletions bergson/collector/gradient_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,13 @@ def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]):
P = self._compute_gradient(module, g)

if not self.cfg.skip_preconditioners:
P = P.float()
if name in self.processor.preconditioners:
self.processor.preconditioners[name].addmm_(P.mT, P)
else:
self.processor.preconditioners[name] = P.mT @ P
# Disable autocast so preconditioner accumulation stays in fp32
with torch.autocast(P.device.type, enabled=False):
P = P.float()
if name in self.processor.preconditioners:
self.processor.preconditioners[name].addmm_(P.mT, P)
else:
self.processor.preconditioners[name] = P.mT @ P

if self.save_index and self.preprocess_cfg.aggregation == "none":
# Asynchronously move the gradient to CPU and convert to the final
Expand All @@ -109,6 +111,9 @@ def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]):
device="cpu", dtype=self.save_dtype, non_blocking=True
)
else:
# TODO we probably only want this for build not score so it
# should happen in builder.
# Benchmark and add disk save precision flag
self.mod_grads[name] = P.to(dtype=self.save_dtype)

def process_batch(self, indices: list[int], **kwargs):
Expand Down
13 changes: 9 additions & 4 deletions bergson/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ class FaissConfig:


@dataclass
class TrackstarConfig:
class TrackStarConfig:
"""Config for the trackstar pipeline query dataset."""

query: DataConfig = field(default_factory=DataConfig)
Expand All @@ -386,9 +386,14 @@ class TrackstarConfig:
index preconditioners intersect at this component. Typical value is
~1000 out of ~65K total components."""

num_stats_sample_preconditioner: bool = True
"""Whether to use num_stats_sample items or the full dataset to
compute preconditioners."""
sample_preconditioners: bool = True
"""Whether to use num_stats_sample items to compute preconditioners.
If False, uses the full dataset."""

stats_token_batch_size: int | None = None
"""Token batch size for normalizer/preconditioner estimation steps (1-2).
These components are always collected in fp32, so use ~2x the VRAM when
in half precision. If None, uses the main token_batch_size."""

resume: bool = False
"""Skip pipeline steps whose output directory already exists."""
6 changes: 4 additions & 2 deletions bergson/normalizer/fit_normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
@dataclass(kw_only=True)
class NormalizerCollector(HookCollectorBase):
"""
Collects per-sample gradients from model layers and writes them to disk.
Collects per-sample gradients from model layers and uses them to fit
an optimizer-based second moment normalizer.

- For each forward/backward hook, we compute the gradient or a low-rank
approximation via random projections, if cfg.projection_dim is set.
- Supports normalization via Adam or Adafactor normalizers.
- Fits Adam or Adafactor normalizers.
"""

data: Dataset
Expand Down Expand Up @@ -142,6 +143,7 @@ def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]):
a = module._inputs # [N, S, I]

assert isinstance(a, torch.Tensor), "Activation cache missing for module"

name = assert_type(str, module._name)

P = g.mT @ a # [N, O, S] @ [N, S, I] → [N, O, I]
Expand Down
Loading
Loading