Skip to content

Commit f17126d

Browse files
author
Camille Touron
committed
add tests on train and solve schedule shapes, devices, bounds
1 parent b851fd7 commit f17126d

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

tests/vf_estimator_test.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,78 @@ def _build_vector_field_estimator_and_tensors(
185185
)
186186
condition = condition
187187
return estimator, inputs, condition
188+
189+
190+
@pytest.mark.gpu
191+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
192+
@pytest.mark.parametrize(
193+
"estimator_type,sde_type",
194+
[
195+
("score", "vp"),
196+
("score", "subvp"),
197+
("score", "ve"),
198+
("flow", None),
199+
],
200+
)
201+
def test_train_schedule(device, estimator_type, sde_type):
202+
"""Test on shapes, bounds and devices for train and solve schedules
203+
of vector field estimators (flow or score)
204+
"""
205+
embedding_net = torch.nn.Identity()
206+
t_min = torch.tensor([0.0], device=device)
207+
t_max = torch.tensor([1.0], device=device)
208+
209+
if estimator_type == "flow":
210+
estimator = build_flow_matching_estimator(
211+
torch.randn(100, 1),
212+
torch.randn(100, 1),
213+
embedding_net=embedding_net,
214+
)
215+
estimator.to(device)
216+
217+
else:
218+
estimator = build_score_matching_estimator(
219+
torch.randn(100, 1),
220+
torch.randn(100, 1),
221+
embedding_net=embedding_net,
222+
sde_type=sde_type,
223+
)
224+
estimator.to(device)
225+
# Train schedule only defined for score estimators
226+
# Schedule with default bounds
227+
train_schedule_default = estimator.train_schedule(300)
228+
assert train_schedule_default.shape == torch.Size((300,))
229+
assert train_schedule_default.max() <= estimator.t_max
230+
assert train_schedule_default.min() >= estimator.t_min
231+
assert str(train_schedule_default.device).split(":")[0] == device.split(":")[0]
232+
233+
# Schedule with given bounds
234+
train_schedule = estimator.train_schedule(300, t_min, t_max)
235+
assert train_schedule.shape == torch.Size((300,))
236+
assert train_schedule.max() <= t_max.item()
237+
assert train_schedule.min() >= t_min.item()
238+
assert str(train_schedule.device).split(":")[0] == device.split(":")[0]
239+
240+
# Solve schedule with default bounds
241+
solve_schedule_default = estimator.solve_schedule(
242+
300, t_max=estimator.t_max, t_min=estimator.t_min
243+
)
244+
assert torch.allclose(
245+
solve_schedule_default[0], torch.tensor([estimator.t_max], device=device)
246+
)
247+
assert torch.allclose(
248+
solve_schedule_default[-1], torch.tensor([estimator.t_min], device=device)
249+
)
250+
assert solve_schedule_default.shape == torch.Size((300,))
251+
assert torch.all(solve_schedule_default[:-1] - solve_schedule_default[1:] >= 0)
252+
assert str(solve_schedule_default.device).split(":")[0] == device.split(":")[0]
253+
254+
# Solve schedule with given bounds
255+
solve_schedule = estimator.solve_schedule(
256+
300, t_max=t_max.item(), t_min=t_min.item()
257+
)
258+
assert torch.allclose(solve_schedule[0], t_max)
259+
assert torch.allclose(solve_schedule[-1], t_min)
260+
assert solve_schedule_default.shape == torch.Size((300,))
261+
assert torch.all(solve_schedule[:-1] - solve_schedule[1:] >= 0)
262+
assert str(solve_schedule.device).split(":")[0] == device.split(":")[0]

0 commit comments

Comments
 (0)