forked from EleutherAI/bergson
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_gradients.py
More file actions
106 lines (87 loc) · 4.31 KB
/
test_gradients.py
File metadata and controls
106 lines (87 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import tempfile
from pathlib import Path
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from bergson.gradients import (
AdafactorNormalizer,
AdamNormalizer,
GradientCollector,
GradientProcessor,
LayerAdapter,
)
def test_phi3():
    """Integration test: GradientCollector must reproduce, per layer, the
    gradients obtainable from a plain backward pass — with and without a
    random projection, and with and without Adam/Adafactor normalizers —
    and GradientProcessor must round-trip through save/load unchanged.

    NOTE(review): original indentation was lost in extraction; block
    structure below is reconstructed from Python syntax and should be
    confirmed against the upstream bergson test file.
    """
    # Scratch directory for the GradientProcessor save/load round-trip.
    temp_dir = Path(tempfile.mkdtemp())

    # Tiny randomly-initialized Phi-3 (from_config => no pretrained weights),
    # so the test is fast and needs no network weights download.
    config = AutoConfig.from_pretrained("trl-internal-testing/tiny-Phi3ForCausalLM")
    model = AutoModelForCausalLM.from_config(config)

    # It's important that we use a batch size of one so that we can simply use the
    # aggregate gradients from the backward itself and compare against those
    tokens = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=model.device)
    # Standard causal-LM setup: labels == input_ids so .loss is defined.
    inputs = dict(input_ids=tokens, labels=tokens)

    # Filled by `closure` during backward; maps layer name -> collected grad.
    collected_grads = {}

    def closure(name: str, g: torch.Tensor):
        """Store the gradients in a dictionary for later comparison."""
        collected_grads[name] = g

    # Test with 16 x 16 random projection as well as with no projection
    for p in (16, None):
        processor = GradientProcessor(projection_dim=p)
        collector = GradientCollector(model, closure, processor)

        # Run backward inside the collector context so its hooks fire and
        # populate `collected_grads`.
        with collector:
            model.zero_grad()
            model(**inputs).loss.backward()

        # Per-layer normalizers built from this pass, reused in phase two.
        adafactors: dict[str, AdafactorNormalizer] = {}
        adams: dict[str, AdamNormalizer] = {}

        # Go through the motions of what GradientCollector does, but after the fact
        for name, collected_grad in collected_grads.items():
            layer = model.get_submodule(name)
            # Input/output feature sizes, resolved via LayerAdapter since the
            # attribute names differ across layer types.
            i = getattr(layer, LayerAdapter.in_attr(layer))
            o = getattr(layer, LayerAdapter.out_attr(layer))

            g = layer.weight.grad
            assert g is not None
            # Second moments (g^2), the raw statistic AdamNormalizer wraps.
            moments = g.square()

            if p is not None:
                # Reproduce the collector's two-sided random projection:
                # left factor on the output dim, right factor on the input dim.
                # Presumably `projection` is deterministic given (name, ...) so
                # these match what the collector used — TODO confirm.
                A = collector.projection(name, p, o, "left", g.device, g.dtype)
                B = collector.projection(name, p, i, "right", g.device, g.dtype)
                g = A @ g @ B.T

            # Batch size is 1, so squeezing the batch dim of the collected
            # gradient should give exactly the aggregate gradient.
            torch.testing.assert_close(g, collected_grad.squeeze(0))

            # Store normalizers for this layer
            adams[name] = AdamNormalizer(moments)
            adafactors[name] = adams[name].to_adafactor()

        # Now do it again but this time use the normalizers
        for normalizers in (adafactors, adams):
            previous_collected_grads = {}

            # First pass constructs and saves the processor; second pass
            # loads it back, to check save/load changes nothing.
            for do_load in (False, True):
                if do_load:
                    processor = GradientProcessor.load(str(temp_dir / "processor"))
                else:
                    processor = GradientProcessor(
                        normalizers=normalizers, projection_dim=p
                    )
                    processor.save(str(temp_dir / "processor"))

                collector = GradientCollector(model, closure, processor)
                with collector:
                    model.zero_grad()
                    model(**inputs).loss.backward()

                for name, collected_grad in collected_grads.items():
                    layer = model.get_submodule(name)
                    i = getattr(layer, LayerAdapter.in_attr(layer))
                    o = getattr(layer, LayerAdapter.out_attr(layer))

                    g = layer.weight.grad
                    assert g is not None
                    # NOTE(review): trailing underscore suggests in-place
                    # normalization of the raw grad — confirm in bergson.
                    g = normalizers[name].normalize_(g)

                    if p is not None:
                        A = collector.projection(name, p, o, "left", g.device, g.dtype)
                        B = collector.projection(name, p, i, "right", g.device, g.dtype)
                        g = A @ g @ B.T

                    # Compare the normalized gradient with the collected gradient. We
                    # use a higher tolerance than the default because there seems to be
                    # some non-negligible numerical error that accumulates due to the
                    # different order of operations. Maybe we should look into this
                    torch.testing.assert_close(
                        g, collected_grad.squeeze(0), atol=1e-4, rtol=1e-4
                    )

                    # Check gradients are the same after loading and restoring
                    if do_load:
                        torch.testing.assert_close(
                            collected_grad, previous_collected_grads[name]
                        )

                # Shallow snapshot of this pass's grads for the next
                # (do_load=True) pass to compare against.
                previous_collected_grads = collected_grads.copy()