-
Notifications
You must be signed in to change notification settings - Fork 586
Expand file tree
/
Copy pathtest_mps_basic.py
More file actions
246 lines (179 loc) · 8.42 KB
/
test_mps_basic.py
File metadata and controls
246 lines (179 loc) · 8.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""Apple Silicon MPS smoke tests for TransformerLens.
Design principles:
- All tests skip automatically on non-MPS runners (Linux, Windows, CPU-only Macs)
- Only float32 is used (bfloat16 is unsupported on MPS)
- Only small models are loaded (roneneldan/TinyStories-1M, ~50MB)
- torch.mps.empty_cache() + gc.collect() between tests to stay within memory budget
- TRANSFORMERLENS_ALLOW_MPS=1 must be set for get_device() to return "mps"
CI: These tests are run via the `mps-checks` job in .github/workflows/checks.yml
which sets TRANSFORMERLENS_ALLOW_MPS=1 and runs on macos-latest.
"""
import gc
import os
import warnings
import pytest
import torch
# Every test below needs the Metal (MPS) backend; on runners without it
# (Linux CI, Windows, CPU-only Macs) the whole module is skipped.
_NO_MPS = not torch.backends.mps.is_available()
pytestmark = pytest.mark.skipif(
    _NO_MPS,
    reason="MPS not available on this runner — skipping Apple Silicon tests",
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Hugging Face id of the small model used by the model-loading tests
# (~50MB, safe for 1GB runner budget).
SMALL_MODEL = "roneneldan/TinyStories-1M"
def _load_tiny_model(device: str = "mps"):
    """Build a HookedTransformer for ``SMALL_MODEL`` on ``device``.

    Weights are requested as float32 because bfloat16 is unsupported on MPS.
    The TransformerLens import is deferred to call time so that merely
    collecting this module does not import the library.
    """
    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained(
        SMALL_MODEL,
        device=device,
        dtype=torch.float32,
    )
    return model
def _cleanup(model=None):
"""Free GPU memory between tests."""
if model is not None:
del model
torch.mps.empty_cache()
gc.collect()
# ---------------------------------------------------------------------------
# 1. Device detection (no model load — instant)
# ---------------------------------------------------------------------------
def test_mps_device_available():
"""Sanity check: MPS backend is present and built on this runner."""
assert torch.backends.mps.is_available(), "MPS not available"
assert torch.backends.mps.is_built(), "MPS not built into this PyTorch"
def test_mps_get_device_returns_mps_with_env_var():
    """get_device() auto-selects MPS when TRANSFORMERLENS_ALLOW_MPS=1 is set."""
    from transformer_lens.utilities.devices import get_device

    # None means "was unset". An empty string is a distinct value that must be
    # restored, not deleted, so do not use a truthiness check here.
    original = os.environ.get("TRANSFORMERLENS_ALLOW_MPS")
    try:
        os.environ["TRANSFORMERLENS_ALLOW_MPS"] = "1"
        device = get_device()
        assert isinstance(device, str)
        assert device == "mps", f"Expected 'mps', got '{device}'"
    finally:
        # Put the environment back exactly as it was before the test.
        if original is None:
            os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
        else:
            os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original
def test_mps_get_device_falls_back_to_cpu_without_env_var():
    """get_device() falls back to CPU when TRANSFORMERLENS_ALLOW_MPS is unset (safety default)."""
    from transformer_lens.utilities.devices import get_device

    # None means "was unset"; an empty-string value must be restored as-is.
    original = os.environ.get("TRANSFORMERLENS_ALLOW_MPS")
    try:
        os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
        device = get_device()
        # Assumes the runner has no CUDA, so the safe default is 'cpu'.
        assert isinstance(device, str)
        assert (
            device == "cpu"
        ), f"Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not '{device}'"
    finally:
        # The variable was popped above, so only a previously-set value
        # (including "") needs to be written back.
        if original is not None:
            os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original
def test_mps_warn_if_mps_emits_warning_without_env_var():
    """warn_if_mps() emits a UserWarning when MPS is used without the env var."""
    import transformer_lens.utilities.devices as devices_module
    from transformer_lens.utilities import warn_if_mps

    # None means "was unset"; an empty-string value must be restored as-is.
    original = os.environ.get("TRANSFORMERLENS_ALLOW_MPS")
    # Module-level flag that appears to suppress repeat warnings; saved so the
    # test leaves it as it found it.
    original_warned = devices_module._mps_warned
    try:
        os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
        devices_module._mps_warned = False  # reset so warning fires
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            warn_if_mps("mps")
        assert any(
            "MPS backend" in str(warning.message) for warning in caught
        ), "Expected MPS warning but got: " + str([str(x.message) for x in caught])
    finally:
        if original is not None:
            os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original
        devices_module._mps_warned = original_warned
# ---------------------------------------------------------------------------
# 2. Raw tensor operations on Metal (no model load)
# ---------------------------------------------------------------------------
def test_mps_tensor_basic_operations():
    """Basic tensor arithmetic runs on the Metal GPU and matches the CPU result."""
    x = torch.randn(16, 32, device="mps", dtype=torch.float32)
    y = torch.randn(16, 32, device="mps", dtype=torch.float32)

    z = x + y
    assert z.device.type == "mps"

    w = torch.matmul(x, y.T)
    assert w.device.type == "mps"
    assert w.shape == (16, 16)

    # Verify the results come back to CPU correctly: right device AND right
    # values when recomputed on the CPU from the same inputs.
    x_cpu, y_cpu = x.cpu(), y.cpu()
    z_cpu = z.cpu()
    assert z_cpu.device.type == "cpu"
    assert torch.allclose(z_cpu, x_cpu + y_cpu)
    assert torch.allclose(w.cpu(), x_cpu @ y_cpu.T, rtol=1e-4, atol=1e-4)

    # Drop the MPS tensors first: empty_cache() can only return memory that
    # no live tensor still occupies.
    del x, y, z, w
    _cleanup()
def test_mps_softmax_and_layernorm():
    """Softmax and LayerNorm — core transformer ops — work correctly on MPS."""
    x = torch.randn(4, 16, 64, device="mps", dtype=torch.float32)

    softmax_out = torch.nn.functional.softmax(x, dim=-1)
    assert softmax_out.device.type == "mps"
    assert torch.allclose(softmax_out.sum(dim=-1), torch.ones(4, 16, device="mps"), atol=1e-5)

    ln = torch.nn.LayerNorm(64).to("mps")
    ln_out = ln(x)
    assert ln_out.device.type == "mps"
    # A freshly constructed LayerNorm has weight=1 and bias=0, so each
    # normalized row should have ~zero mean and ~unit (biased) variance.
    assert torch.allclose(ln_out.mean(dim=-1), torch.zeros(4, 16, device="mps"), atol=1e-4)
    assert torch.allclose(
        ln_out.var(dim=-1, unbiased=False), torch.ones(4, 16, device="mps"), atol=1e-3
    )

    # Release the MPS tensors and module before asking the allocator to
    # return cached memory.
    del x, softmax_out, ln, ln_out
    _cleanup()
# ---------------------------------------------------------------------------
# 3. Model loading and forward pass on Metal
# ---------------------------------------------------------------------------
def test_mps_model_forward_pass():
    """TinyStories-1M loads and runs a forward pass on the Metal GPU."""
    model = _load_tiny_model(device="mps")

    tokens = model.to_tokens("Once upon a time")
    assert tokens.device.type == "mps", f"Tokens should be on MPS, got {tokens.device}"

    logits = model(tokens)
    assert logits.device.type == "mps", f"Logits should be on MPS, got {logits.device}"
    # Leading dims are (batch, pos) of the input tokens; last dim is the vocab.
    assert logits.shape[:2] == tokens.shape
    assert logits.shape[-1] == model.cfg.d_vocab
    assert not torch.isnan(logits).any(), "NaN values in logits — possible MPS compute error"

    # _cleanup(model) would only drop the helper's own reference, so release
    # ours here. logits may also hold an autograd graph that references the
    # weights, so it goes too.
    del tokens, logits, model
    _cleanup()
def test_mps_run_with_cache():
    """run_with_cache() returns cache tensors on the Metal GPU."""
    model = _load_tiny_model(device="mps")
    tokens = model.to_tokens("The quick brown fox")

    logits, cache = model.run_with_cache(tokens)
    assert logits.device.type == "mps"

    # Check a representative set of cache keys
    hook_q = cache["blocks.0.attn.hook_q"]
    assert hook_q.device.type == "mps", f"Cache tensor not on MPS: {hook_q.device}"
    assert not torch.isnan(hook_q).any(), "NaN in attention query cache"

    # Drop every MPS-resident object (the cache may also reference the model)
    # so empty_cache() inside _cleanup() can actually reclaim the memory.
    del tokens, logits, cache, hook_q, model
    _cleanup()
def test_mps_activation_hook_fires_on_metal():
    """run_with_hooks() fires hooks and hook tensors are on the Metal GPU."""
    model = _load_tiny_model(device="mps")
    tokens = model.to_tokens("Apple Silicon rocks")
    hook_devices = []
    hook_shapes = []

    def capture_hook(value, hook):
        # Record where the activation lives and its shape, then pass it
        # through unchanged so the forward pass is unaffected.
        hook_devices.append(value.device.type)
        hook_shapes.append(value.shape)
        return value

    model.run_with_hooks(
        tokens,
        fwd_hooks=[
            ("blocks.0.attn.hook_q", capture_hook),
            ("blocks.0.mlp.hook_post", capture_hook),
        ],
    )

    assert len(hook_devices) == 2, f"Expected 2 hooks to fire, got {len(hook_devices)}"
    for device in hook_devices:
        assert device == "mps", f"Hook tensor not on MPS: {device}"
    # Assumes TransformerLens' (batch, pos, ...) activation layout: each
    # captured activation should start with the token tensor's shape.
    for shape in hook_shapes:
        assert shape[:2] == tokens.shape, (
            f"Hook activation shape {tuple(shape)} does not start with "
            f"token shape {tuple(tokens.shape)}"
        )

    # Release our references before emptying the MPS cache.
    del tokens, model
    _cleanup()
def test_mps_float32_inference():
    """Explicit float32 model loads and infers correctly on MPS."""
    model = _load_tiny_model(device="mps")

    # Verify all parameters are float32. A comprehension reports every
    # offender at once and, unlike a for-loop, leaves no `param` variable
    # alive afterwards to pin a weight tensor.
    wrong_dtype = [
        (name, param.dtype)
        for name, param in model.named_parameters()
        if param.dtype != torch.float32
    ]
    assert not wrong_dtype, f"Parameters with wrong dtype: {wrong_dtype}"

    tokens = model.to_tokens("Testing float32 on Metal")
    logits = model(tokens)
    assert logits.dtype == torch.float32

    # Release our references so empty_cache() can reclaim the weights.
    del tokens, logits, model
    _cleanup()
def test_mps_loss_computation():
    """Loss computation (return_type='loss') works on MPS."""
    model = _load_tiny_model(device="mps")

    loss = model("Once upon a time in a land", return_type="loss")
    assert isinstance(loss, torch.Tensor)
    assert loss.device.type == "mps"
    assert not torch.isnan(loss), f"NaN loss — possible MPS compute error: {loss}"
    assert loss.item() > 0, "Loss should be positive"

    # loss may carry an autograd graph referencing the weights, so drop it
    # along with the model before emptying the MPS cache.
    del loss, model
    _cleanup()