@@ -235,3 +235,83 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, term_buffer_0) {
235235 }
236236 }
237237}
238+
239+ TEST_F (ServicesSampleHmcNutsDiagEAdapt, term_buffer_1) {
240+ unsigned int random_seed = 0 ;
241+ unsigned int chain = 1 ;
242+ double init_radius = 0 ;
243+ int num_warmup = 150 ;
244+ int num_samples = 10 ;
245+ int num_thin = 1 ;
246+ bool save_warmup = true ;
247+ int refresh = 0 ;
248+ double stepsize = 1.0 ;
249+ double stepsize_jitter = 0.0 ;
250+ int max_depth = 10 ;
251+ double delta = .8 ;
252+ double gamma = .05 ;
253+ double kappa = .75 ;
254+ double t0 = 10 ;
255+ unsigned int init_buffer = 49 ;
256+ unsigned int term_buffer = 1 ;
257+ unsigned int window = 100 ;
258+ stan::test::unit::instrumented_interrupt interrupt;
259+ EXPECT_EQ (interrupt.call_count (), 0 );
260+
261+ stan::services::sample::hmc_nuts_diag_e_adapt (
262+ model, context, random_seed, chain, init_radius, num_warmup, num_samples,
263+ num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
264+ delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
265+ logger, init, parameter, diagnostic);
266+
267+ EXPECT_EQ (0 , logger.call_count_error ());
268+ int num_output_lines = (num_warmup + num_samples) / num_thin;
269+ EXPECT_EQ (num_output_lines, parameter.call_count (" vector_double" ));
270+
271+ std::vector<std::string> messages = parameter.string_values ();
272+ for (auto msg : messages) {
273+ if (msg.find (" Step size" ) != std::string::npos) {
274+ EXPECT_NE (" Step size = 1" , msg);
275+ }
276+ }
277+ }
278+
279+ TEST_F (ServicesSampleHmcNutsDiagEAdapt, no_stepsize_adapt) {
280+ unsigned int random_seed = 0 ;
281+ unsigned int chain = 1 ;
282+ double init_radius = 0 ;
283+ int num_warmup = 150 ;
284+ int num_samples = 10 ;
285+ int num_thin = 1 ;
286+ bool save_warmup = true ;
287+ int refresh = 0 ;
288+ double stepsize = 1.0 ;
289+ double stepsize_jitter = 0.0 ;
290+ int max_depth = 10 ;
291+ double delta = .8 ;
292+ double gamma = .05 ;
293+ double kappa = .75 ;
294+ double t0 = 10 ;
295+ unsigned int init_buffer = 0 ;
296+ unsigned int term_buffer = 0 ;
297+ unsigned int window = 50 ;
298+ stan::test::unit::instrumented_interrupt interrupt;
299+ EXPECT_EQ (interrupt.call_count (), 0 );
300+
301+ stan::services::sample::hmc_nuts_diag_e_adapt (
302+ model, context, random_seed, chain, init_radius, num_warmup, num_samples,
303+ num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
304+ delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
305+ logger, init, parameter, diagnostic);
306+
307+ EXPECT_EQ (0 , logger.call_count_error ());
308+ int num_output_lines = (num_warmup + num_samples) / num_thin;
309+ EXPECT_EQ (num_output_lines, parameter.call_count (" vector_double" ));
310+
311+ std::vector<std::string> messages = parameter.string_values ();
312+ for (auto msg : messages) {
313+ if (msg.find (" Step size" ) != std::string::npos) {
314+ EXPECT_NE (" Step size = 1" , msg);
315+ }
316+ }
317+ }
0 commit comments