Skip to content

Commit c6a5dae

Browse files
committed
Merge branch 'test' into main
2 parents 53e667f + 33e88dc commit c6a5dae

13 files changed

+3538
-0
lines changed

test/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Run all tests
2+
cd /home/yangx0i/deeplens_proj/debug/DeepLens
3+
pytest test/ -v
4+
5+
# Run specific test file
6+
pytest test/test_ray.py -v
7+
8+
# Run with coverage
9+
pytest test/ --cov=deeplens --cov-report=term-missing

test/conftest.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""
2+
Shared pytest fixtures for DeepLens test suite.
3+
"""
4+
5+
import os
6+
import sys
7+
8+
import pytest
9+
import torch
10+
11+
# Add deeplens to path
12+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13+
14+
15+
# =============================================================================
16+
# Device fixtures
17+
# =============================================================================
18+
@pytest.fixture(scope="session")
19+
def device():
20+
"""Return CUDA device if available, otherwise CPU."""
21+
if torch.cuda.is_available():
22+
return torch.device("cuda")
23+
else:
24+
pytest.skip("CUDA not available, skipping GPU tests")
25+
26+
27+
@pytest.fixture(scope="session")
28+
def device_cpu():
29+
"""Return CPU device."""
30+
return torch.device("cpu")
31+
32+
33+
@pytest.fixture(scope="session")
34+
def device_auto():
35+
"""Return CUDA if available, otherwise CPU (for tests that should run on either)."""
36+
if torch.cuda.is_available():
37+
return torch.device("cuda")
38+
return torch.device("cpu")
39+
40+
41+
# =============================================================================
42+
# Lens fixtures
43+
# =============================================================================
44+
@pytest.fixture(scope="function")
45+
def sample_singlet_lens(device_auto):
46+
"""Load a simple singlet lens for testing."""
47+
from deeplens import GeoLens
48+
49+
lens_path = os.path.join(
50+
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
51+
"datasets/lenses/singlet/example1.json",
52+
)
53+
lens = GeoLens(filename=lens_path)
54+
lens.to(device_auto)
55+
return lens
56+
57+
58+
@pytest.fixture(scope="function")
59+
def sample_cellphone_lens(device_auto):
60+
"""Load a cellphone lens with aspheric surfaces for testing."""
61+
from deeplens import GeoLens
62+
63+
lens_path = os.path.join(
64+
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
65+
"datasets/lenses/cellphone/cellphone68deg.json",
66+
)
67+
lens = GeoLens(filename=lens_path)
68+
lens.to(device_auto)
69+
return lens
70+
71+
72+
@pytest.fixture(scope="function")
73+
def sample_camera_lens(device_auto):
74+
"""Load a camera lens for testing."""
75+
from deeplens import GeoLens
76+
77+
lens_path = os.path.join(
78+
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
79+
"datasets/lenses/camera/ef50mm_f1.8.json",
80+
)
81+
lens = GeoLens(filename=lens_path)
82+
lens.to(device_auto)
83+
return lens
84+
85+
86+
# =============================================================================
87+
# Image fixtures
88+
# =============================================================================
89+
@pytest.fixture(scope="function")
90+
def sample_image(device_auto):
91+
"""Create a simple test image tensor [B, C, H, W]."""
92+
# Create a gradient image for testing
93+
H, W = 256, 256
94+
x = torch.linspace(0, 1, W, device=device_auto)
95+
y = torch.linspace(0, 1, H, device=device_auto)
96+
yy, xx = torch.meshgrid(y, x, indexing="ij")
97+
98+
img = torch.stack([xx, yy, (xx + yy) / 2], dim=0) # [3, H, W]
99+
img = img.unsqueeze(0) # [1, 3, H, W]
100+
return img
101+
102+
103+
@pytest.fixture(scope="function")
104+
def sample_image_small(device_auto):
105+
"""Create a small test image tensor for fast tests."""
106+
H, W = 64, 64
107+
img = torch.rand(1, 3, H, W, device=device_auto)
108+
return img
109+
110+
111+
# =============================================================================
112+
# Ray fixtures
113+
# =============================================================================
114+
@pytest.fixture(scope="function")
115+
def sample_ray(device_auto):
116+
"""Create a sample ray for testing."""
117+
from deeplens.optics.ray import Ray
118+
119+
o = torch.tensor([[0.0, 0.0, -100.0]], device=device_auto)
120+
d = torch.tensor([[0.0, 0.0, 1.0]], device=device_auto)
121+
ray = Ray(o, d, wvln=0.55, device=device_auto)
122+
return ray
123+
124+
125+
@pytest.fixture(scope="function")
126+
def sample_rays_batch(device_auto):
127+
"""Create a batch of rays for testing."""
128+
from deeplens.optics.ray import Ray
129+
130+
# Create 100 rays in a grid pattern
131+
n = 10
132+
x = torch.linspace(-1, 1, n, device=device_auto)
133+
y = torch.linspace(-1, 1, n, device=device_auto)
134+
yy, xx = torch.meshgrid(y, x, indexing="ij")
135+
136+
o = torch.stack([xx.flatten(), yy.flatten(), torch.full((n*n,), -100.0, device=device_auto)], dim=-1)
137+
d = torch.zeros_like(o)
138+
d[..., 2] = 1.0
139+
140+
ray = Ray(o, d, wvln=0.55, device=device_auto)
141+
return ray
142+
143+
144+
# =============================================================================
145+
# Path helpers
146+
# =============================================================================
147+
@pytest.fixture(scope="session")
148+
def project_root():
149+
"""Return the project root directory."""
150+
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
151+
152+
153+
@pytest.fixture(scope="session")
154+
def lenses_dir(project_root):
155+
"""Return the lenses dataset directory."""
156+
return os.path.join(project_root, "datasets/lenses")
157+
158+
159+
@pytest.fixture(scope="session")
160+
def test_output_dir(project_root):
161+
"""Return a directory for test outputs."""
162+
output_dir = os.path.join(project_root, "test/test_outputs")
163+
os.makedirs(output_dir, exist_ok=True)
164+
return output_dir

test/test_basics.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
"""
2+
Tests for deeplens/basics.py - Basic utilities and constants.
3+
"""
4+
5+
import pytest
6+
import torch
7+
8+
9+
from deeplens.basics import (
10+
DEPTH,
11+
DEFAULT_WAVE,
12+
EPSILON,
13+
PSF_KS,
14+
SPP_PSF,
15+
WAVE_RGB,
16+
DeepObj,
17+
init_device,
18+
wave_rgb,
19+
)
20+
21+
22+
class TestConstants:
23+
"""Test default constants are properly defined."""
24+
25+
def test_depth_constant(self):
26+
"""DEPTH should be a large negative value representing infinity."""
27+
assert DEPTH == -20000.0
28+
assert DEPTH < 0
29+
30+
def test_wave_rgb(self):
31+
"""WAVE_RGB should contain R, G, B wavelengths in micrometers."""
32+
assert len(WAVE_RGB) == 3
33+
assert WAVE_RGB[0] > WAVE_RGB[1] > WAVE_RGB[2] # R > G > B
34+
# All wavelengths should be in visible range (0.38 - 0.78 um)
35+
for wvln in WAVE_RGB:
36+
assert 0.38 < wvln < 0.78
37+
38+
def test_default_wave(self):
39+
"""DEFAULT_WAVE should be green wavelength."""
40+
assert 0.5 < DEFAULT_WAVE < 0.6 # Green light
41+
42+
def test_spp_psf(self):
43+
"""SPP_PSF should be a power of 2."""
44+
assert SPP_PSF > 0
45+
assert (SPP_PSF & (SPP_PSF - 1)) == 0 # Check power of 2
46+
47+
def test_psf_ks(self):
48+
"""PSF_KS should be a reasonable kernel size."""
49+
assert PSF_KS > 0
50+
assert PSF_KS < 256
51+
52+
def test_epsilon(self):
53+
"""EPSILON should be a small positive value."""
54+
assert EPSILON > 0
55+
assert EPSILON < 1e-6
56+
57+
58+
class TestInitDevice:
59+
"""Test device initialization."""
60+
61+
def test_init_device_returns_device(self):
62+
"""init_device should return a torch device."""
63+
device = init_device()
64+
assert isinstance(device, torch.device)
65+
66+
def test_init_device_cuda_or_cpu(self):
67+
"""init_device should return cuda or cpu."""
68+
device = init_device()
69+
assert device.type in ["cuda", "cpu"]
70+
71+
def test_init_device_matches_availability(self):
72+
"""init_device result should match CUDA availability."""
73+
device = init_device()
74+
if torch.cuda.is_available():
75+
assert device.type == "cuda"
76+
else:
77+
assert device.type == "cpu"
78+
79+
80+
class TestWaveRgb:
81+
"""Test random wavelength sampling."""
82+
83+
def test_wave_rgb_returns_three(self):
84+
"""wave_rgb should return 3 wavelengths."""
85+
waves = wave_rgb()
86+
assert len(waves) == 3
87+
88+
def test_wave_rgb_order(self):
89+
"""wave_rgb should return [R, G, B] in decreasing wavelength order."""
90+
waves = wave_rgb()
91+
assert waves[0] > waves[1] > waves[2]
92+
93+
def test_wave_rgb_range(self):
94+
"""All wavelengths should be in visible range."""
95+
waves = wave_rgb()
96+
for w in waves:
97+
assert 0.4 < w < 0.75
98+
99+
def test_wave_rgb_randomness(self):
100+
"""wave_rgb should produce different results (probabilistic)."""
101+
results = [tuple(wave_rgb()) for _ in range(10)]
102+
# At least some should be different
103+
assert len(set(results)) > 1
104+
105+
106+
class TestDeepObj:
107+
"""Test DeepObj base class functionality."""
108+
109+
def test_deep_obj_init(self):
110+
"""DeepObj should initialize with default dtype."""
111+
obj = DeepObj()
112+
assert obj.dtype == torch.get_default_dtype()
113+
114+
def test_deep_obj_init_custom_dtype(self):
115+
"""DeepObj should accept custom dtype."""
116+
obj = DeepObj(dtype=torch.float64)
117+
assert obj.dtype == torch.float64
118+
119+
def test_deep_obj_str(self):
120+
"""DeepObj should have string representation."""
121+
obj = DeepObj()
122+
s = str(obj)
123+
assert "DeepObj" in s
124+
125+
def test_deep_obj_clone(self):
126+
"""DeepObj clone should create independent copy."""
127+
obj = DeepObj()
128+
obj.test_attr = torch.tensor([1.0, 2.0, 3.0])
129+
cloned = obj.clone()
130+
131+
# Modify original
132+
obj.test_attr[0] = 999.0
133+
134+
# Clone should be unchanged
135+
assert cloned.test_attr[0] != 999.0
136+
137+
def test_deep_obj_to_device(self, device_auto):
138+
"""DeepObj.to() should move tensors to device."""
139+
obj = DeepObj()
140+
obj.tensor_attr = torch.tensor([1.0, 2.0, 3.0])
141+
142+
obj.to(device_auto)
143+
144+
assert obj.device.type == device_auto.type
145+
assert obj.tensor_attr.device.type == device_auto.type
146+
147+
def test_deep_obj_to_device_nested(self, device_auto):
148+
"""DeepObj.to() should handle nested DeepObj."""
149+
outer = DeepObj()
150+
inner = DeepObj()
151+
inner.data = torch.tensor([1.0, 2.0])
152+
outer.child = inner
153+
154+
outer.to(device_auto)
155+
156+
assert inner.data.device.type == device_auto.type
157+
158+
def test_deep_obj_to_device_list(self, device_auto):
159+
"""DeepObj.to() should handle tensor lists."""
160+
obj = DeepObj()
161+
obj.tensor_list = [torch.tensor([1.0]), torch.tensor([2.0])]
162+
163+
obj.to(device_auto)
164+
165+
for t in obj.tensor_list:
166+
assert t.device.type == device_auto.type
167+
168+
def test_deep_obj_astype_float32(self):
169+
"""DeepObj.astype() should convert to float32."""
170+
obj = DeepObj(dtype=torch.float64)
171+
obj.data = torch.tensor([1.0, 2.0], dtype=torch.float64)
172+
173+
obj.astype(torch.float32)
174+
175+
assert obj.dtype == torch.float32
176+
assert obj.data.dtype == torch.float32
177+
178+
def test_deep_obj_astype_float64(self):
179+
"""DeepObj.astype() should convert to float64."""
180+
obj = DeepObj(dtype=torch.float32)
181+
obj.data = torch.tensor([1.0, 2.0], dtype=torch.float32)
182+
183+
obj.astype(torch.float64)
184+
185+
assert obj.dtype == torch.float64
186+
assert obj.data.dtype == torch.float64
187+
188+
def test_deep_obj_astype_none(self):
189+
"""DeepObj.astype(None) should be no-op."""
190+
obj = DeepObj(dtype=torch.float32)
191+
original_dtype = obj.dtype
192+
193+
result = obj.astype(None)
194+
195+
assert obj.dtype == original_dtype
196+
assert result is obj
197+
198+
def test_deep_obj_astype_invalid(self):
199+
"""DeepObj.astype() should reject invalid dtypes."""
200+
obj = DeepObj()
201+
202+
with pytest.raises(AssertionError):
203+
obj.astype(torch.int32)
204+
205+
def test_deep_obj_call_raises(self):
206+
"""DeepObj.__call__() should raise if forward not implemented."""
207+
obj = DeepObj()
208+
209+
with pytest.raises(AttributeError):
210+
obj(torch.tensor([1.0]))

0 commit comments

Comments
 (0)