Skip to content

Commit b5d38c4

Browse files
authored
Merge pull request #117 from EleutherAI/convert-dtype
Use dtype utils
2 parents 689f17b + 10362d3 commit b5d38c4

File tree

12 files changed

+184
-91
lines changed

12 files changed

+184
-91
lines changed

bergson/collector/dist_preconditioners_gradient_collector.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919
from bergson.process_preconditioners import process_preconditioners
2020
from bergson.score.scorer import Scorer
21-
from bergson.utils.utils import assert_type
21+
from bergson.utils.utils import assert_type, get_gradient_dtype
2222

2323

2424
@dataclass(kw_only=True)
@@ -98,18 +98,14 @@ def setup(self) -> None:
9898
self.owned_modules: set[str] = set()
9999
self.module_to_rank: dict[str, int] = {}
100100

101-
# TODO: handle more elegantly?
102-
self.save_dtype = (
103-
torch.float32 if self.model.dtype == torch.float32 else torch.float16
104-
)
105-
101+
self.save_dtype = get_gradient_dtype(self.model)
106102
self.lo = torch.finfo(self.save_dtype).min
107103
self.hi = torch.finfo(self.save_dtype).max
108104

109105
self.per_doc_losses = torch.full(
110106
(len(self.data),),
111107
device=self.model.device,
112-
dtype=self.save_dtype,
108+
dtype=torch.float32,
113109
fill_value=0.0,
114110
)
115111

@@ -298,11 +294,7 @@ def teardown(self):
298294
self.data = self.data.add_column(
299295
"loss",
300296
self.per_doc_losses.cpu().numpy(),
301-
feature=Value(
302-
"float16"
303-
if self.save_dtype == torch.float16
304-
else "float32" # TODO: This is not robust
305-
),
297+
feature=Value("float32"),
306298
new_fingerprint="loss",
307299
)
308300

bergson/collector/gradient_collectors.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from bergson.process_preconditioners import process_preconditioners
2121
from bergson.score.scorer import Scorer
22-
from bergson.utils.utils import assert_type
22+
from bergson.utils.utils import assert_type, get_gradient_dtype
2323

2424

2525
@dataclass(kw_only=True)
@@ -93,18 +93,14 @@ def setup(self) -> None:
9393
"consider disabling bias inclusion for now."
9494
)
9595

96-
# TODO: handle more elegantly?
97-
self.save_dtype = (
98-
torch.float32 if self.model.dtype == torch.float32 else torch.float16
99-
)
100-
96+
self.save_dtype = get_gradient_dtype(self.model)
10197
self.lo = torch.finfo(self.save_dtype).min
10298
self.hi = torch.finfo(self.save_dtype).max
10399

104100
self.per_doc_losses = torch.full(
105101
(len(self.data),),
106102
device=self.model.device,
107-
dtype=self.save_dtype,
103+
dtype=torch.float32,
108104
fill_value=0.0,
109105
)
110106

@@ -263,11 +259,7 @@ def teardown(self):
263259
self.data = self.data.add_column(
264260
"loss",
265261
self.per_doc_losses.cpu().numpy(),
266-
feature=Value(
267-
"float16"
268-
if self.save_dtype == torch.float16
269-
else "float32" # TODO: This is not robust
270-
),
262+
feature=Value("float32"),
271263
new_fingerprint="loss",
272264
)
273265

@@ -302,11 +294,7 @@ class TraceCollector(HookCollectorBase):
302294
"""Dtype for stored gradients."""
303295

304296
def setup(self) -> None:
305-
# TODO: handle more elegantly?
306-
self.save_dtype = (
307-
torch.float32 if self.model.dtype == torch.float32 else torch.float16
308-
)
309-
297+
self.save_dtype = get_gradient_dtype(self.model)
310298
self.lo = torch.finfo(self.save_dtype).min
311299
self.hi = torch.finfo(self.save_dtype).max
312300

bergson/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ class ScoreConfig:
282282
batch_size: int = 1024
283283
"""Batch size for processing the query dataset."""
284284

285+
precision: Literal["auto", "bf16", "fp16", "fp32"] = "auto"
286+
"""Precision (dtype) to convert the query and index gradients to before
287+
computing the scores. If "auto", the model's gradient dtype is used."""
288+
285289

286290
@dataclass
287291
class ReduceConfig:

bergson/data.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66
from typing import Any, Sequence, cast, overload
77

8+
import ml_dtypes # noqa: F401 # registers bfloat16 dtype with numpy
89
import numpy as np
910
import pyarrow as pa
1011
import torch
@@ -20,7 +21,12 @@
2021
from numpy.typing import DTypeLike
2122

2223
from .config import DataConfig, ReduceConfig
23-
from .utils.utils import assert_type, simple_parse_args_string
24+
from .utils.utils import (
25+
assert_type,
26+
convert_dtype_to_np,
27+
simple_parse_args_string,
28+
tensor_to_numpy,
29+
)
2430

2531

2632
def ceildiv(a: int, b: int) -> int:
@@ -202,7 +208,7 @@ def create_index(
202208
"num_grads": num_grads,
203209
"dtype": struct_dtype,
204210
"grad_sizes": grad_sizes,
205-
"base_dtype": np.dtype(dtype).str,
211+
"base_dtype": np.dtype(dtype).name,
206212
},
207213
f,
208214
indent=2,
@@ -367,17 +373,16 @@ def __init__(
367373
self.rank = dist.get_rank() if dist.is_initialized() else 0
368374
if reduce_cfg is not None:
369375
num_grads = 1
376+
np_dtype = np.float32
370377
self.in_memory_grad_buffer = torch.zeros(
371378
(num_grads, sum(self.grad_sizes.values())),
372379
dtype=torch.float32,
373380
device=f"cuda:{self.rank}",
374381
)
375-
np_dtype = np.float32
376382
else:
377383
num_grads = self.num_items
384+
np_dtype = convert_dtype_to_np(dtype)
378385
self.in_memory_grad_buffer = None
379-
# TODO: Handle this more elegantly
380-
np_dtype = np.float32 if dtype == torch.float32 else np.float16
381386

382387
self.grad_buffer = create_index(
383388
path,
@@ -423,7 +428,7 @@ def __call__(self, indices: list[int], mod_grads: dict[str, torch.Tensor]):
423428
for module_name in self.grad_sizes.keys():
424429
self.grad_buffer[
425430
indices, offset : offset + mod_grads[module_name].shape[1]
426-
] = mod_grads[module_name].numpy()
431+
] = tensor_to_numpy(mod_grads[module_name])
427432
offset += mod_grads[module_name].shape[1]
428433

429434
def flush(self):
@@ -447,7 +452,7 @@ def dist_reduce(self):
447452

448453
rank = dist.get_rank() if dist.is_initialized() else 0
449454
if rank == 0:
450-
self.grad_buffer[:] = self.in_memory_grad_buffer.numpy().astype(
455+
self.grad_buffer[:] = tensor_to_numpy(self.in_memory_grad_buffer).astype(
451456
self.grad_buffer.dtype
452457
)
453458

bergson/huggingface.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import torch
1010
import torch.distributed as dist
1111
from datasets import Dataset
12-
from numpy.typing import DTypeLike
1312
from peft import PeftModel
1413
from torch import Tensor
1514
from torch.utils.data import DataLoader
@@ -22,6 +21,7 @@
2221
from bergson.data import create_index
2322
from bergson.gradients import AdafactorNormalizer, AdamNormalizer
2423
from bergson.utils.peft import detect_peft_modules
24+
from bergson.utils.utils import convert_dtype_to_torch
2525

2626

2727
class GradientCollectorCallback(TrainerCallback):
@@ -34,7 +34,7 @@ def __init__(
3434
attention_cfgs: dict[str, AttentionConfig] = {},
3535
projection_dim: int = 16,
3636
include_bias: bool = False,
37-
dtype: DTypeLike = np.float16,
37+
dtype: np.dtype = np.dtype(np.float16),
3838
accumulate_grads: bool = False,
3939
use_optimizer_state: bool = True,
4040
track_order: bool = False,
@@ -77,8 +77,7 @@ def __init__(
7777
self.mod_grads = {}
7878
self.batch_indices: Tensor | None = None
7979

80-
# TODO: Handle this more elegantly
81-
self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16
80+
self.torch_dtype = convert_dtype_to_torch(self.dtype)
8281

8382
def write_grads(self, grad_buffer: np.memmap):
8483
torch.cuda.synchronize()

bergson/normalizer/fit_normalizers.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
Normalizer,
2020
)
2121
from bergson.process_preconditioners import process_preconditioners
22-
from bergson.utils.utils import assert_type
22+
from bergson.utils.utils import assert_type, get_gradient_dtype
2323

2424

2525
@dataclass(kw_only=True)
@@ -123,11 +123,7 @@ def setup(self) -> None:
123123
"consider disabling bias inclusion for now."
124124
)
125125

126-
# TODO: handle more elegantly?
127-
self.save_dtype = (
128-
torch.float32 if self.model.dtype == torch.float32 else torch.float16
129-
)
130-
126+
self.save_dtype = get_gradient_dtype(self.model)
131127
self.lo = torch.finfo(self.save_dtype).min
132128
self.hi = torch.finfo(self.save_dtype).max
133129

bergson/score/score.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
from bergson.distributed import launch_distributed_run
1818
from bergson.gradients import GradientProcessor
1919
from bergson.score.scorer import Scorer
20-
from bergson.utils.utils import assert_type
20+
from bergson.utils.utils import (
21+
assert_type,
22+
convert_precision_to_torch,
23+
get_gradient_dtype,
24+
)
2125
from bergson.utils.worker_utils import (
2226
create_processor,
2327
setup_data_pipeline,
@@ -277,6 +281,12 @@ def score_worker(
277281
"attention_cfgs": attention_cfgs,
278282
}
279283

284+
score_dtype = (
285+
convert_precision_to_torch(score_cfg.precision)
286+
if score_cfg.precision != "auto"
287+
else get_gradient_dtype(model)
288+
)
289+
280290
if isinstance(ds, Dataset):
281291
kwargs["batches"] = allocate_batches(ds["length"], index_cfg.token_batch_size)
282292
kwargs["scorer"] = Scorer(
@@ -285,7 +295,7 @@ def score_worker(
285295
query_grads,
286296
score_cfg,
287297
device=torch.device(f"cuda:{rank}"),
288-
dtype=torch.float32 if model.dtype == torch.float32 else torch.float16,
298+
dtype=score_dtype,
289299
)
290300

291301
collect_gradients(**kwargs)
@@ -310,7 +320,7 @@ def flush(kwargs):
310320
query_grads,
311321
score_cfg,
312322
torch.device(f"cuda:{rank}"),
313-
model.dtype if model.dtype != "auto" else torch.float32,
323+
score_dtype,
314324
)
315325

316326
collect_gradients(**kwargs)

bergson/score/scorer.py

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from pathlib import Path
2-
from typing import Callable
32

43
import torch
54

@@ -8,14 +7,8 @@
87

98

109
class Scorer:
11-
scorer_callback: Callable
12-
13-
num_scores: int
14-
1510
writer: ScoreWriter
1611

17-
device: torch.device
18-
1912
def __init__(
2013
self,
2114
path: Path,
@@ -29,10 +22,14 @@ def __init__(
2922
self.dtype = dtype
3023
self.num_items = num_items
3124

32-
self.scorer_callback = self.build_scorer_callback(
33-
query_grads,
34-
score_cfg,
25+
self.query_tensor = torch.cat(
26+
[
27+
query_grads[m].to(device=self.device, dtype=self.dtype)
28+
for m in score_cfg.modules
29+
],
30+
dim=1,
3531
)
32+
self.score_cfg = score_cfg
3633

3734
num_scores = len(query_grads[score_cfg.modules[0]])
3835

@@ -47,37 +44,22 @@ def __call__(
4744
indices: list[int],
4845
mod_grads: dict[str, torch.Tensor],
4946
):
50-
first_grad = next(iter(mod_grads.values()))
51-
if first_grad.dtype != self.dtype:
47+
# Convert the gradients to the scoring dtype
48+
if next(iter(mod_grads.values())).dtype != self.dtype:
5249
mod_grads = {name: grad.to(self.dtype) for name, grad in mod_grads.items()}
5350

54-
scores = self.scorer_callback(mod_grads)
55-
self.writer(indices, scores)
56-
57-
def build_scorer_callback(
58-
self,
59-
query_grads: dict[str, torch.Tensor],
60-
score_cfg: ScoreConfig,
61-
) -> Callable:
62-
"""Unified scorer builder for all scorer types."""
63-
query_tensor = torch.cat(
64-
[
65-
query_grads[m].to(device=self.device, dtype=self.dtype)
66-
for m in score_cfg.modules
67-
],
68-
dim=1,
69-
)
51+
scores = self.score(mod_grads)
7052

71-
@torch.inference_mode()
72-
def callback(mod_grads: dict[str, torch.Tensor]):
73-
grads = torch.cat([mod_grads[m] for m in score_cfg.modules], dim=1)
74-
if score_cfg.unit_normalize:
75-
grads /= grads.norm(dim=1, keepdim=True)
53+
self.writer(indices, scores)
7654

77-
if score_cfg.score == "nearest":
78-
all_scores = grads @ query_tensor.T
79-
return all_scores.max(dim=-1).values
55+
@torch.inference_mode()
56+
def score(self, mod_grads: dict[str, torch.Tensor]):
57+
grads = torch.cat([mod_grads[m] for m in self.score_cfg.modules], dim=1)
58+
if self.score_cfg.unit_normalize:
59+
grads /= grads.norm(dim=1, keepdim=True)
8060

81-
return grads @ query_tensor.T
61+
if self.score_cfg.score == "nearest":
62+
all_scores = grads @ self.query_tensor.T
63+
return all_scores.max(dim=-1).values
8264

83-
return callback
65+
return grads @ self.query_tensor.T

0 commit comments

Comments
 (0)