-
Notifications
You must be signed in to change notification settings - Fork 156
Expand file tree
/
Copy pathconftest.py
More file actions
131 lines (103 loc) · 3.03 KB
/
conftest.py
File metadata and controls
131 lines (103 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Shared pytest fixtures and configuration."""
import os
import sys
import tempfile
from pathlib import Path
from typing import Generator
import pytest
# Put the directory one level above this file's directory at the front of
# sys.path so that modules living there can be imported by the tests.
_PROJECT_ROOT = str(Path(__file__).parent.parent)
sys.path.insert(0, _PROJECT_ROOT)
@pytest.fixture
def temp_dir() -> Generator[Path, None, None]:
    """Provide a scratch directory that exists only while a test runs.

    The directory comes from ``tempfile.TemporaryDirectory`` and is cleaned
    up when its context manager exits after the test has finished.

    Yields:
        Path: Location of the scratch directory.
    """
    with tempfile.TemporaryDirectory() as scratch_name:
        scratch = Path(scratch_name)
        yield scratch
@pytest.fixture
def mock_config() -> dict:
    """Return a fresh mock configuration with common training settings.

    A new dictionary is built on every call, so a test may mutate it without
    affecting other tests.

    Returns:
        dict: Top-level training settings plus a nested ``"augmentation"``
        section.
    """
    augmentation_settings = {
        "severity": 3,
        "width": 3,
        "depth": -1,
        "alpha": 1.0,
    }
    config = {
        "batch_size": 32,
        "num_workers": 2,
        "learning_rate": 0.001,
        "epochs": 10,
    }
    config["augmentation"] = augmentation_settings
    return config
@pytest.fixture
def sample_image_path(temp_dir: Path) -> Path:
    """Write a placeholder ``test_image.jpg`` into ``temp_dir`` and return it.

    The file contains a short byte string rather than real image data. It
    suits tests that only need the file to exist, not tests that decode it.

    Args:
        temp_dir: Scratch directory supplied by the ``temp_dir`` fixture.

    Returns:
        Path: Path of the placeholder file.
    """
    placeholder = temp_dir.joinpath("test_image.jpg")
    placeholder.write_bytes(b"dummy image content")
    return placeholder
@pytest.fixture
def mock_dataset_config() -> dict:
    """Return a fresh mock dataset description named ``"cifar10"``.

    Returns:
        dict: Dataset name, data directory, class count, image size, and
        three-element ``mean``/``std`` normalisation lists (presumably one
        value per RGB channel).
    """
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.2023, 0.1994, 0.2010]
    dataset = {
        "name": "cifar10",
        "data_dir": "./data",
        "num_classes": 10,
        "image_size": 32,
        "mean": channel_mean,
        "std": channel_std,
    }
    return dataset
@pytest.fixture(autouse=True)
def reset_random_seeds():
    """Seed every available random number generator before each test.

    Being ``autouse``, this fixture runs for every test in this conftest's
    scope without being requested. Python's ``random`` module is always
    seeded. NumPy and PyTorch are seeded only when they can be imported, so
    a missing optional dependency does not make every test error out.
    """
    import random

    seed = 42
    random.seed(seed)

    # Keep the try bodies to the import alone so that an ImportError raised
    # for any other reason is not silently swallowed.
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        np.random.seed(seed)

    try:
        import torch
    except ImportError:
        pass
    else:
        torch.manual_seed(seed)
        # GPU generators are seeded separately, and only if CUDA is usable.
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
@pytest.fixture
def capture_output():
    """Redirect ``sys.stdout`` and ``sys.stderr`` to in-memory buffers.

    The previous streams are restored when the fixture is torn down. The
    ``contextlib`` redirect managers restore them in their ``__exit__``, so
    this also happens if the generator is closed at ``yield`` instead of
    being resumed.

    Yields:
        tuple: ``(stdout, stderr)`` ``StringIO`` objects that receive
        everything written to the redirected streams.
    """
    from contextlib import redirect_stderr, redirect_stdout
    from io import StringIO

    fake_stdout = StringIO()
    fake_stderr = StringIO()
    with redirect_stdout(fake_stdout), redirect_stderr(fake_stderr):
        yield fake_stdout, fake_stderr
# Custom markers used to categorise tests by type.
def pytest_configure(config):
    """Register this suite's custom markers with pytest.

    Each entry is added to the ``markers`` ini value as ``"name: description"``.

    Args:
        config: The pytest config object passed to this hook.
    """
    marker_lines = (
        "unit: Unit tests that test individual components",
        "integration: Integration tests that test component interactions",
        "slow: Tests that take a long time to run",
    )
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)