-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy path conftest.py
More file actions
34 lines (29 loc) · 796 Bytes
/
conftest.py
File metadata and controls
34 lines (29 loc) · 796 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import pytest
import torch
from datasets import Dataset
from transformers import AutoConfig, AutoModelForCausalLM
@pytest.fixture
def model():
    """Build a tiny, randomly initialised causal language model for tests.

    Only the architecture config is resolved through
    ``AutoConfig.from_pretrained`` (a Hub repo id, so first use may need
    network access or a warm cache). The model itself is created with
    ``from_config``, so its weights are freshly initialised rather than
    loaded. The RNG is seeded beforehand so that initialisation is
    reproducible from run to run.

    Returns:
        A new ``AutoModelForCausalLM`` instance built from the tiny Phi-3
        test config.
    """
    seed = 42
    # Seed before construction: from_config draws initial weights from torch's RNG.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    tiny_config = AutoConfig.from_pretrained("trl-internal-testing/tiny-Phi3ForCausalLM")
    tiny_model = AutoModelForCausalLM.from_config(tiny_config)
    return tiny_model
@pytest.fixture
def dataset():
    """Return a two-row, already-tokenised ``Dataset`` for tests.

    Each row holds five token ids. ``labels`` is an element-wise copy of
    ``input_ids`` (presumably for a causal-LM style loss — confirm against
    the tests that use it). ``attention_mask`` is all ones, so every position
    is attended to.

    Returns:
        A ``datasets.Dataset`` with columns ``input_ids``, ``labels`` and
        ``attention_mask``, in that order.
    """
    token_rows = [
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
    ]
    columns = {
        "input_ids": token_rows,
        # Independent copies, so no two columns share the same list objects.
        "labels": [list(row) for row in token_rows],
        "attention_mask": [[1] * len(row) for row in token_rows],
    }
    return Dataset.from_dict(columns)