Skip to content

Commit 8a3e0ad

Browse files
committed
add test
1 parent 7303d64 commit 8a3e0ad

1 file changed

Lines changed: 80 additions & 0 deletions

File tree

src/test/unit/services/sample/hmc_nuts_diag_e_adapt_test.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,83 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, term_buffer_0) {
235235
}
236236
}
237237
}
238+
239+
TEST_F(ServicesSampleHmcNutsDiagEAdapt, term_buffer_1) {
240+
unsigned int random_seed = 0;
241+
unsigned int chain = 1;
242+
double init_radius = 0;
243+
int num_warmup = 150;
244+
int num_samples = 10;
245+
int num_thin = 1;
246+
bool save_warmup = true;
247+
int refresh = 0;
248+
double stepsize = 1.0;
249+
double stepsize_jitter = 0.0;
250+
int max_depth = 10;
251+
double delta = .8;
252+
double gamma = .05;
253+
double kappa = .75;
254+
double t0 = 10;
255+
unsigned int init_buffer = 49;
256+
unsigned int term_buffer = 1;
257+
unsigned int window = 100;
258+
stan::test::unit::instrumented_interrupt interrupt;
259+
EXPECT_EQ(interrupt.call_count(), 0);
260+
261+
stan::services::sample::hmc_nuts_diag_e_adapt(
262+
model, context, random_seed, chain, init_radius, num_warmup, num_samples,
263+
num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
264+
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
265+
logger, init, parameter, diagnostic);
266+
267+
EXPECT_EQ(0, logger.call_count_error());
268+
int num_output_lines = (num_warmup + num_samples) / num_thin;
269+
EXPECT_EQ(num_output_lines, parameter.call_count("vector_double"));
270+
271+
std::vector<std::string> messages = parameter.string_values();
272+
for (auto msg : messages) {
273+
if (msg.find("Step size") != std::string::npos) {
274+
EXPECT_NE("Step size = 1", msg);
275+
}
276+
}
277+
}
278+
279+
TEST_F(ServicesSampleHmcNutsDiagEAdapt, no_stepsize_adapt) {
280+
unsigned int random_seed = 0;
281+
unsigned int chain = 1;
282+
double init_radius = 0;
283+
int num_warmup = 150;
284+
int num_samples = 10;
285+
int num_thin = 1;
286+
bool save_warmup = true;
287+
int refresh = 0;
288+
double stepsize = 1.0;
289+
double stepsize_jitter = 0.0;
290+
int max_depth = 10;
291+
double delta = .8;
292+
double gamma = .05;
293+
double kappa = .75;
294+
double t0 = 10;
295+
unsigned int init_buffer = 0;
296+
unsigned int term_buffer = 0;
297+
unsigned int window = 50;
298+
stan::test::unit::instrumented_interrupt interrupt;
299+
EXPECT_EQ(interrupt.call_count(), 0);
300+
301+
stan::services::sample::hmc_nuts_diag_e_adapt(
302+
model, context, random_seed, chain, init_radius, num_warmup, num_samples,
303+
num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
304+
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
305+
logger, init, parameter, diagnostic);
306+
307+
EXPECT_EQ(0, logger.call_count_error());
308+
int num_output_lines = (num_warmup + num_samples) / num_thin;
309+
EXPECT_EQ(num_output_lines, parameter.call_count("vector_double"));
310+
311+
std::vector<std::string> messages = parameter.string_values();
312+
for (auto msg : messages) {
313+
if (msg.find("Step size") != std::string::npos) {
314+
EXPECT_NE("Step size = 1", msg);
315+
}
316+
}
317+
}

0 commit comments

Comments
 (0)