-
Notifications
You must be signed in to change notification settings - Fork 586
Expand file tree
/
Copy pathconftest.py
More file actions
90 lines (68 loc) · 2.46 KB
/
conftest.py
File metadata and controls
90 lines (68 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Global pytest configuration for memory management and test optimization."""
import gc
import random
import tempfile
from pathlib import Path
import numpy as np
import pytest
import torch
@pytest.fixture(autouse=True, scope="function")
def cleanup_memory():
    """Release Python and accelerator memory after every test function.

    All work happens in the teardown phase (after ``yield``). Garbage
    collection runs *before* the accelerator caches are emptied. Tensors that
    are only kept alive by reference cycles are freed by ``gc.collect()``,
    which returns their blocks to PyTorch's caching allocator. Only after that
    can ``empty_cache()`` release those blocks back to the device. With the
    opposite order, such blocks would stay cached until the next teardown.
    """
    yield
    # Break reference cycles first so cycle-held tensors reach the cache.
    gc.collect()
    # Then hand unoccupied cached blocks back to each available accelerator.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
@pytest.fixture(autouse=True, scope="class")
def cleanup_class_memory():
    """Release Python and accelerator memory when a class-scoped teardown runs.

    This performs the same cleanup as a per-test pass: a garbage collection
    followed by emptying the CUDA and/or MPS allocator caches. It is not more
    aggressive; it simply runs at class scope.

    NOTE(review): for test functions that are not inside a class, pytest
    appears to fall back to per-test scope for class-scoped fixtures. That
    would make this run alongside any function-scoped cleanup. Confirm whether
    that duplication is intended.
    """
    yield
    # Collect first so tensors held only by reference cycles are returned to
    # the caching allocator before the caches are emptied.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
# pytest hook: pin GC thresholds and seed every RNG before tests are collected.
def pytest_configure(config):
    """Pin garbage-collection thresholds and seed the Python, NumPy and torch RNGs.

    Seeding here, before collection, means randomized test parametrization
    starts from the same RNG state in every pytest process, including
    parallel workers (the original intent of this hook).

    Args:
        config: The pytest ``Config`` object (unused).
    """
    # NOTE(review): (700, 10, 10) is CPython's long-standing default, so on
    # most interpreters this pins the thresholds rather than making GC more
    # aggressive. It only changes behaviour where the interpreter default
    # differs. Confirm that is the intent.
    gc.set_threshold(700, 10, 10)

    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the generators of all devices (CPU, and CUDA
    # when present), so no separate torch.cuda seeding calls are needed.
    torch.manual_seed(seed)
@pytest.fixture(autouse=True, scope="session")
def _enable_hf_retry_for_tests():
    """Call TransformerLens' ``enable_hf_retry`` once for the whole session.

    This is done in a session-scoped autouse fixture rather than in
    ``pytest_configure`` so that jaxtyping is installed before the
    ``transformer_lens`` import below happens.
    """
    from transformer_lens.utilities import hf_utils

    # presumably turns on retries for Hugging Face Hub requests; see hf_utils
    hf_utils.enable_hf_retry()
    yield
@pytest.fixture(scope="session")
def gpt2_tokenizer():
    """Provide the ``"gpt2"`` tokenizer, loaded once and shared by the session."""
    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    return tokenizer
@pytest.fixture(scope="session")
def gpt2_hooked_processed():
    """Provide a session-wide ``HookedTransformer`` for ``"gpt2"`` on the CPU.

    The same instance is handed to every test that requests it, so tests must
    treat it as read-only. Any mutation leaks into later tests in the session.
    """
    import transformer_lens

    model = transformer_lens.HookedTransformer.from_pretrained("gpt2", device="cpu")
    return model
def pytest_sessionfinish(session, exitstatus):
    """Release Python and accelerator memory once the test session ends.

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are returned to PyTorch's caching allocator. The
    CUDA/MPS caches are emptied afterwards, so those blocks are actually
    released.

    Args:
        session: The pytest ``Session`` object (unused).
        exitstatus: The session's exit status (unused).
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory as a ``Path``; it is deleted afterwards."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        # Same cleanup the context-manager form performs on exit.
        scratch.cleanup()