Skip to content

Commit 9210d50

Browse files
committed
fix claude issues
1 parent 39386aa commit 9210d50

File tree

3 files changed

+27
-20
lines changed

3 files changed

+27
-20
lines changed

bergson/collector/collector.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -610,13 +610,18 @@ def fwd_bwd(model, x: Tensor, y: Tensor, batch: dict):
610610
return fwd_bwd
611611

612612

613-
def fwd_bwd_hessian_factory(cfg: HessianConfig) -> Callable:
613+
def fwd_bwd_hessian_factory(
614+
index_cfg: IndexConfig, hessian_cfg: HessianConfig
615+
) -> Callable:
614616
def fwd_bwd_hessian(model, x: Tensor, y: Tensor, batch: dict):
615617
logits = model(x).logits[:, :-1]
616618
masks = y[:, 1:] != -100
617-
denoms = masks.sum(dim=1, dtype=model.dtype)
618-
619-
if not cfg.use_dataset_labels:
619+
denoms = (
620+
masks.sum(dim=1, dtype=model.dtype)
621+
if index_cfg.loss_reduction == "mean"
622+
else 1.0
623+
)
624+
if not hessian_cfg.use_dataset_labels:
620625
losses = F.cross_entropy(
621626
logits.reshape(-1, logits.size(-1)),
622627
y[:, 1:].flatten(),
@@ -636,6 +641,7 @@ def fwd_bwd_hessian(model, x: Tensor, y: Tensor, batch: dict):
636641
sampled_tokens.flatten(),
637642
reduction="none",
638643
).reshape_as(y[:, 1:])
644+
losses = losses.sum(1) / denoms
639645

640646
losses.sum().backward()
641647
model.zero_grad()

bergson/hessians/eigenvectors.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,6 @@ def compute_eigendecomposition(
307307
)
308308

309309
gc.collect()
310-
torch.cuda.empty_cache()
311310

312311

313312
def _merge_and_shard_eigenvectors(
@@ -365,6 +364,5 @@ def _merge_and_shard_eigenvectors(
365364

366365
del tensor
367366
gc.collect()
368-
torch.cuda.empty_cache()
369367

370368
return result_dict

bergson/hessians/hessian_approximations.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from bergson.hessians.kfac import CovarianceCollector
2121
from bergson.hessians.tkfac import TraceCovarianceCollector
2222
from bergson.utils.utils import (
23+
convert_precision_to_torch,
2324
setup_reproducibility,
2425
validate_batch_size,
2526
)
@@ -81,7 +82,7 @@ def hessian_worker(
8182
rank: int,
8283
local_rank: int,
8384
world_size: int,
84-
cfg: IndexConfig,
85+
index_cfg: IndexConfig,
8586
hessian_cfg: HessianConfig,
8687
ds: Dataset,
8788
):
@@ -135,35 +136,37 @@ def hessian_worker(
135136
world_size=world_size,
136137
)
137138

138-
model, target_modules = setup_model_and_peft(cfg)
139+
model, target_modules = setup_model_and_peft(index_cfg)
139140

140-
attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}
141+
attention_cfgs = {
142+
module: index_cfg.attention for module in index_cfg.split_attention_modules
143+
}
141144

142145
kwargs = {
143146
"model": model,
144147
"data": ds,
145-
"cfg": cfg,
148+
"cfg": index_cfg,
146149
"hessian_cfg": hessian_cfg,
147150
"target_modules": target_modules,
148151
"attention_cfgs": attention_cfgs,
149152
}
150153

151-
batches = allocate_batches(ds["length"], cfg.token_batch_size)
154+
batches = allocate_batches(ds["length"], index_cfg.token_batch_size)
152155
kwargs["batches"] = batches
153156
collect_hessians(**kwargs)
154157

155158
total_processed = torch.load(
156-
f"{cfg.partial_run_path}/total_processed.pt",
159+
f"{index_cfg.partial_run_path}/total_processed.pt",
157160
map_location="cpu",
158161
weights_only=False,
159162
)
160163

161164
compute_eigendecomposition(
162-
os.path.join(cfg.partial_run_path, "activation_sharded"),
165+
os.path.join(index_cfg.partial_run_path, "activation_sharded"),
163166
total_processed=total_processed,
164167
)
165168
compute_eigendecomposition(
166-
os.path.join(cfg.partial_run_path, "gradient_sharded"),
169+
os.path.join(index_cfg.partial_run_path, "gradient_sharded"),
167170
total_processed=total_processed,
168171
)
169172

@@ -174,7 +177,7 @@ def hessian_worker(
174177
def collect_hessians(
175178
model: PreTrainedModel,
176179
data: Dataset,
177-
cfg: IndexConfig,
180+
index_cfg: IndexConfig,
178181
*,
179182
batches: list[list[int]] | None = None,
180183
target_modules: set[str] | None = None,
@@ -190,14 +193,14 @@ def collect_hessians(
190193
hessian_dtype = (
191194
model.dtype
192195
if hessian_cfg.hessian_dtype == "auto"
193-
else hessian_cfg.hessian_dtype
196+
else convert_precision_to_torch(hessian_cfg.hessian_dtype)
194197
)
195198

196199
collector_args = {
197200
"model": model.base_model, # type: ignore
198201
"target_modules": target_modules,
199202
"attention_cfgs": attention_cfgs or {},
200-
"path": str(cfg.partial_run_path),
203+
"path": str(index_cfg.partial_run_path),
201204
}
202205
desc = f"Approximating Hessians with {hessian_cfg.method}"
203206
if ev_correction:
@@ -207,16 +210,16 @@ def collect_hessians(
207210
collector_args["dtype"] = hessian_dtype
208211
collector = HESSIAN_APPROXIMATIONS[hessian_cfg.method](**collector_args)
209212

210-
validate_batch_size(model, cfg.token_batch_size, collector)
213+
validate_batch_size(model, index_cfg.token_batch_size, collector)
211214

212215
computer = CollectorComputer(
213216
model=model, # type: ignore
214217
data=data,
215218
collector=collector,
216219
batches=batches,
217-
cfg=cfg,
220+
cfg=index_cfg,
218221
)
219222

220-
computer.forward_backward = fwd_bwd_hessian_factory(hessian_cfg)
223+
computer.forward_backward = fwd_bwd_hessian_factory(index_cfg, hessian_cfg)
221224

222225
computer.run_with_collector_hooks(desc=desc)

0 commit comments

Comments (0)