@@ -185,3 +185,78 @@ def _build_vector_field_estimator_and_tensors(
185185 )
186186 condition = condition
187187 return estimator , inputs , condition
188+
189+
190+ @pytest .mark .gpu
191+ @pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
192+ @pytest .mark .parametrize (
193+ "estimator_type,sde_type" ,
194+ [
195+ ("score" , "vp" ),
196+ ("score" , "subvp" ),
197+ ("score" , "ve" ),
198+ ("flow" , None ),
199+ ],
200+ )
201+ def test_train_schedule (device , estimator_type , sde_type ):
202+ """Test on shapes, bounds and devices for train and solve schedules
203+ of vector field estimators (flow or score)
204+ """
205+ embedding_net = torch .nn .Identity ()
206+ t_min = torch .tensor ([0.0 ], device = device )
207+ t_max = torch .tensor ([1.0 ], device = device )
208+
209+ if estimator_type == "flow" :
210+ estimator = build_flow_matching_estimator (
211+ torch .randn (100 , 1 ),
212+ torch .randn (100 , 1 ),
213+ embedding_net = embedding_net ,
214+ )
215+ estimator .to (device )
216+
217+ else :
218+ estimator = build_score_matching_estimator (
219+ torch .randn (100 , 1 ),
220+ torch .randn (100 , 1 ),
221+ embedding_net = embedding_net ,
222+ sde_type = sde_type ,
223+ )
224+ estimator .to (device )
225+ # Train schedule only defined for score estimators
226+ # Schedule with default bounds
227+ train_schedule_default = estimator .train_schedule (300 )
228+ assert train_schedule_default .shape == torch .Size ((300 ,))
229+ assert train_schedule_default .max () <= estimator .t_max
230+ assert train_schedule_default .min () >= estimator .t_min
231+ assert str (train_schedule_default .device ).split (":" )[0 ] == device .split (":" )[0 ]
232+
233+ # Schedule with given bounds
234+ train_schedule = estimator .train_schedule (300 , t_min , t_max )
235+ assert train_schedule .shape == torch .Size ((300 ,))
236+ assert train_schedule .max () <= t_max .item ()
237+ assert train_schedule .min () >= t_min .item ()
238+ assert str (train_schedule .device ).split (":" )[0 ] == device .split (":" )[0 ]
239+
240+ # Solve schedule with default bounds
241+ solve_schedule_default = estimator .solve_schedule (
242+ 300 , t_max = estimator .t_max , t_min = estimator .t_min
243+ )
244+ assert torch .allclose (
245+ solve_schedule_default [0 ], torch .tensor ([estimator .t_max ], device = device )
246+ )
247+ assert torch .allclose (
248+ solve_schedule_default [- 1 ], torch .tensor ([estimator .t_min ], device = device )
249+ )
250+ assert solve_schedule_default .shape == torch .Size ((300 ,))
251+ assert torch .all (solve_schedule_default [:- 1 ] - solve_schedule_default [1 :] >= 0 )
252+ assert str (solve_schedule_default .device ).split (":" )[0 ] == device .split (":" )[0 ]
253+
254+ # Solve schedule with given bounds
255+ solve_schedule = estimator .solve_schedule (
256+ 300 , t_max = t_max .item (), t_min = t_min .item ()
257+ )
258+ assert torch .allclose (solve_schedule [0 ], t_max )
259+ assert torch .allclose (solve_schedule [- 1 ], t_min )
260+ assert solve_schedule_default .shape == torch .Size ((300 ,))
261+ assert torch .all (solve_schedule [:- 1 ] - solve_schedule [1 :] >= 0 )
262+ assert str (solve_schedule .device ).split (":" )[0 ] == device .split (":" )[0 ]
0 commit comments