-
Notifications
You must be signed in to change notification settings - Fork 586
Expand file tree
/
Copy path conftest.py
More file actions
38 lines (24 loc) · 1.03 KB
/
conftest.py
File metadata and controls
38 lines (24 loc) · 1.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Session fixtures for acceptance tests.
transformer_lens imports stay inside fixture bodies — jaxtyping's pytest_configure
hook must install before the package is first imported.
"""
import pytest
@pytest.fixture(scope="session")
def gpt2_model():
    """Provide a GPT-2 ``HookedTransformer`` shared across the test session.

    The model is built once per session on CPU. Only ``device`` is passed to
    ``from_pretrained``, so its default weight processing applies.
    """
    # Import deferred into the fixture body (see module docstring): the
    # package must not be imported before jaxtyping's pytest hook installs.
    import transformer_lens

    checkpoint = "gpt2"
    model = transformer_lens.HookedTransformer.from_pretrained(checkpoint, device="cpu")
    return model
@pytest.fixture(scope="session")
def bloom_560m_hooked():
    """Provide a ``bigscience/bloom-560m`` ``HookedTransformer`` for the session.

    Loaded once on CPU with ``default_prepend_bos=False``. With that setting,
    the model's text-processing methods do not prepend a BOS token by default.
    """
    # Deferred import, per the module docstring's jaxtyping constraint.
    import transformer_lens

    checkpoint = "bigscience/bloom-560m"
    hooked = transformer_lens.HookedTransformer.from_pretrained(
        checkpoint,
        device="cpu",
        default_prepend_bos=False,
    )
    return hooked
@pytest.fixture(scope="session")
def bloom_560m_hf_model():
    """Provide the Hugging Face ``bigscience/bloom-560m`` causal-LM model.

    Built once per session via ``AutoModelForCausalLM.from_pretrained`` with
    no extra arguments. NOTE(review): presumably a reference model to compare
    against the hooked BLOOM fixture; confirm with the tests that use it.
    """
    import transformers

    checkpoint = "bigscience/bloom-560m"
    hf_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
    return hf_model
@pytest.fixture(scope="session")
def bloom_560m_hf_tokenizer():
    """Provide the Hugging Face tokenizer for ``bigscience/bloom-560m``.

    Built once per session via ``AutoTokenizer.from_pretrained`` with no
    extra arguments.
    """
    import transformers

    checkpoint = "bigscience/bloom-560m"
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    return tokenizer