@@ -315,3 +315,43 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, no_stepsize_adapt) {
315315 }
316316 }
317317}
318+
319+ TEST_F (ServicesSampleHmcNutsDiagEAdapt, no_stepsize_adapt_stretch) {
320+ unsigned int random_seed = 0 ;
321+ unsigned int chain = 1 ;
322+ double init_radius = 0 ;
323+ int num_warmup = 150 ;
324+ int num_samples = 10 ;
325+ int num_thin = 1 ;
326+ bool save_warmup = true ;
327+ int refresh = 0 ;
328+ double stepsize = 1.0 ;
329+ double stepsize_jitter = 0.0 ;
330+ int max_depth = 10 ;
331+ double delta = .8 ;
332+ double gamma = .05 ;
333+ double kappa = .75 ;
334+ double t0 = 10 ;
335+ unsigned int init_buffer = 0 ;
336+ unsigned int term_buffer = 0 ;
337+ unsigned int window = 40 ;
338+ stan::test::unit::instrumented_interrupt interrupt;
339+ EXPECT_EQ (interrupt.call_count (), 0 );
340+
341+ stan::services::sample::hmc_nuts_diag_e_adapt (
342+ model, context, random_seed, chain, init_radius, num_warmup, num_samples,
343+ num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
344+ delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
345+ logger, init, parameter, diagnostic);
346+
347+ EXPECT_EQ (0 , logger.call_count_error ());
348+ int num_output_lines = (num_warmup + num_samples) / num_thin;
349+ EXPECT_EQ (num_output_lines, parameter.call_count (" vector_double" ));
350+
351+ std::vector<std::string> messages = parameter.string_values ();
352+ for (auto msg : messages) {
353+ if (msg.find (" Step size" ) != std::string::npos) {
354+ EXPECT_NE (" Step size = 1" , msg);
355+ }
356+ }
357+ }
0 commit comments