forked from EleutherAI/bergson
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_gradients.py
More file actions
116 lines (96 loc) · 4.86 KB
/
test_gradients.py
File metadata and controls
116 lines (96 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import tempfile
from pathlib import Path
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from bergson.gradients import (
AdafactorNormalizer,
AdamNormalizer,
GradientCollector,
GradientProcessor,
)
def test_phi3():
    """End-to-end check that GradientCollector reproduces per-layer gradients.

    Runs a tiny Phi-3 model forward/backward under the collector and verifies:
      1. Gradients captured by the hook closure match each layer's
         ``weight.grad`` (optionally after the random projection).
      2. Adam/Adafactor-normalized gradients match when normalizers are
         supplied to the processor.
      3. A processor round-tripped through ``save()``/``load()`` yields
         identical collected gradients.
    """
    config = AutoConfig.from_pretrained("trl-internal-testing/tiny-Phi3ForCausalLM")
    model = AutoModelForCausalLM.from_config(config)

    # It's important that we use a batch size of one so that we can simply use
    # the aggregate gradients from the backward itself and compare against those
    tokens = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=model.device)
    inputs = dict(input_ids=tokens, labels=tokens)

    collected_grads: dict[str, torch.Tensor] = {}

    def closure(name: str, g: torch.Tensor):
        """Store the gradients in a dictionary for later comparison."""
        collected_grads[name] = g

    def layer_dims(layer) -> tuple[int, int]:
        """Return (out_features, in_features) for a Linear or Conv1D module."""
        if hasattr(layer, "out_features") and hasattr(layer, "in_features"):
            return layer.out_features, layer.in_features
        if hasattr(layer, "nf") and hasattr(layer, "nx"):
            return layer.nf, layer.nx
        raise ValueError(f"Unsupported layer type: {type(layer)}")

    # Context manager guarantees cleanup even on failure; the previous
    # mkdtemp() call leaked a temp directory on every run.
    with tempfile.TemporaryDirectory() as tmp:
        temp_dir = Path(tmp)

        # Test with 16 x 16 random projection as well as with no projection
        for p in (16, None):
            processor = GradientProcessor(projection_dim=p)
            collector = GradientCollector(model, closure, processor)
            with collector:
                model.zero_grad()
                model(**inputs).loss.backward()

            adafactors: dict[str, AdafactorNormalizer] = {}
            adams: dict[str, AdamNormalizer] = {}

            # Go through the motions of what GradientCollector does, but after
            # the fact
            for name, collected_grad in collected_grads.items():
                layer = model.get_submodule(name)
                o, i = layer_dims(layer)

                g = layer.weight.grad
                assert g is not None

                # Second moments are taken from the *unprojected* gradient
                moments = g.square()
                if p is not None:
                    A = collector.projection(name, p, o, "left", g.device, g.dtype)
                    B = collector.projection(name, p, i, "right", g.device, g.dtype)
                    g = A @ g @ B.T

                torch.testing.assert_close(g, collected_grad.squeeze(0))

                # Store normalizers for this layer
                adams[name] = AdamNormalizer(moments)
                adafactors[name] = adams[name].to_adafactor()

            # Now do it again but this time use the normalizers
            for normalizers in (adafactors, adams):
                previous_collected_grads: dict[str, torch.Tensor] = {}

                for do_load in (False, True):
                    if do_load:
                        # Round-trip the processor through disk
                        processor = GradientProcessor.load(str(temp_dir / "processor"))
                    else:
                        processor = GradientProcessor(
                            normalizers=normalizers, projection_dim=p
                        )
                        processor.save(str(temp_dir / "processor"))

                    collector = GradientCollector(model, closure, processor)
                    with collector:
                        model.zero_grad()
                        model(**inputs).loss.backward()

                    for name, collected_grad in collected_grads.items():
                        layer = model.get_submodule(name)
                        o, i = layer_dims(layer)

                        g = layer.weight.grad
                        assert g is not None

                        g = normalizers[name].normalize_(g)
                        if p is not None:
                            A = collector.projection(name, p, o, "left", g.device, g.dtype)
                            B = collector.projection(name, p, i, "right", g.device, g.dtype)
                            g = A @ g @ B.T

                        # Compare the normalized gradient with the collected
                        # gradient. We use a higher tolerance than the default
                        # because there seems to be some non-negligible
                        # numerical error that accumulates due to the different
                        # order of operations. Maybe we should look into this
                        torch.testing.assert_close(
                            g, collected_grad.squeeze(0), atol=1e-4, rtol=1e-4
                        )

                        # Check gradients are the same after loading and restoring
                        if do_load:
                            torch.testing.assert_close(
                                collected_grad, previous_collected_grads[name]
                            )

                    previous_collected_grads = collected_grads.copy()