Skip to content

Commit f147881

Browse files
committed
fix handling of cmdstan in tests for remote CI runs
1 parent 1a79e0e commit f147881

12 files changed

+146
-41
lines changed

R/adaptive_btl_refit.R

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,9 +1944,9 @@
19441944
cmdstan,
19451945
seed,
19461946
model_fn = NULL) {
1947-
.btl_mcmc_require_cmdstanr()
19481947
resolved_cmdstan <- .btl_mcmc_resolve_cmdstan_config(cmdstan %||% list())
19491948
if (is.null(model_fn)) {
1949+
.btl_mcmc_require_cmdstanr()
19501950
model_fn <- cmdstanr::cmdstan_model
19511951
}
19521952
if (!is.function(model_fn)) {
@@ -2245,6 +2245,10 @@
22452245
diagnostics <- NULL
22462246
mcmc_config_used <- NULL
22472247
cmdstan_schedule_used <- NULL
2248+
cmdstan_fit_fn <- refit_contract_ctx[["cmdstan_fit_fn"]] %||% .adaptive_link_fit_transform_cmdstan
2249+
if (!is.function(cmdstan_fit_fn)) {
2250+
rlang::abort("`refit_contract$cmdstan_fit_fn` must be a function when provided.")
2251+
}
22482252
repair_attempts <- 0L
22492253
max_attempts <- 3L
22502254
for (attempt in seq_len(max_attempts)) {
@@ -2259,19 +2263,19 @@
22592263
),
22602264
joint_used = joint_used
22612265
)
2262-
cmdstan_fit <- .adaptive_link_fit_transform_cmdstan(
2266+
cmdstan_fit <- cmdstan_fit_fn(
22632267
stan_data = stan_data_base,
22642268
variable_names = variable_names,
22652269
cmdstan = utils::modifyList(
2266-
refit_contract_ctx$cmdstan %||% list(),
2270+
refit_contract_ctx[["cmdstan"]] %||% list(),
22672271
list(
22682272
chains = as.integer(cmdstan_schedule_used$chains),
22692273
iter_warmup = as.integer(cmdstan_schedule_used$iter_warmup),
22702274
iter_sampling = as.integer(cmdstan_schedule_used$iter_sampling)
22712275
)
22722276
),
22732277
seed = as.integer((seed + attempt * 1009L) %% .Machine$integer.max),
2274-
model_fn = refit_contract_ctx$cmdstan_model_fn %||% NULL
2278+
model_fn = refit_contract_ctx[["cmdstan_model_fn"]] %||% NULL
22752279
)
22762280
draws_matrix <- as.matrix(cmdstan_fit$draws_matrix)
22772281
diagnostics <- .adaptive_link_cmdstan_validate_diagnostics(
@@ -3286,6 +3290,7 @@
32863290
}
32873291
}
32883292

3293+
btl_config <- out$config$btl_config %||% list()
32893294
cross_all <- .adaptive_link_cross_edges(out, spoke_id = spoke_id, last_refit_step = NULL)
32903295
cross_since <- .adaptive_link_cross_edges(out, spoke_id = spoke_id, last_refit_step = last_step)
32913296
startup_gap <- .adaptive_link_phase_b_startup_gap_for_spoke(out, spoke_id = spoke_id)
@@ -3306,11 +3311,13 @@
33063311
hub_lock_mode = lock_mode,
33073312
hub_lock_kappa = kappa,
33083313
shift_only_theta_treatment = theta_treatment,
3309-
cmdstan = out$config$btl_config$cmdstan %||% list(),
3314+
cmdstan = btl_config[["cmdstan"]] %||% list(),
3315+
cmdstan_fit_fn = btl_config[["cmdstan_fit_fn"]] %||% NULL,
3316+
cmdstan_model_fn = btl_config[["cmdstan_model_fn"]] %||% NULL,
33103317
link_diagnostics_thresholds = list(
3311-
divergences_max = as.integer(out$config$btl_config$divergences_max %||% 0L),
3312-
max_rhat = as.double(out$config$btl_config$max_rhat %||% 1.01),
3313-
min_ess_bulk = as.double(out$config$btl_config$ess_bulk_min %||% 400)
3318+
divergences_max = as.integer(btl_config$divergences_max %||% 0L),
3319+
max_rhat = as.double(btl_config$max_rhat %||% 1.01),
3320+
min_ess_bulk = as.double(btl_config$ess_bulk_min %||% 400)
33143321
)
33153322
)
33163323
hub_theta_init <- if (identical(refit_mode, "joint_refit") && length(hub_current) > 0L) {
@@ -5345,7 +5352,7 @@ default_btl_fit_fn <- function(state, config) {
53455352
results = results,
53465353
ids = ids_fit,
53475354
model_variant = config$model_variant %||% "btl_e_b",
5348-
cmdstan = config$cmdstan %||% list()
5355+
cmdstan = config[["cmdstan"]] %||% list()
53495356
)
53505357

53515358
fit_contract <- .adaptive_btl_extract_fit_contract(fit_out)

R/bayes_btl_mcmc_adaptive.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ as_btl_fit_contract_from_mcmc <- function(mcmc_fit, ids) {
535535
Y = as.integer(bt_data$Y)
536536
)
537537

538-
cmdstan <- config$cmdstan %||% list()
538+
cmdstan <- config[["cmdstan"]] %||% list()
539539
if (!is.list(cmdstan)) {
540540
rlang::abort("`config$cmdstan` must be a list when provided.")
541541
}

R/btl_mcmc_contracts.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,11 @@ validate_btl_mcmc_config <- function(config) {
259259
"`keep_draws` must be logical.")
260260
.btl_mcmc_check(.btl_mcmc_intish(config$thin_draws) && config$thin_draws >= 1L,
261261
"`thin_draws` must be >= 1.")
262-
if (!is.list(config$cmdstan)) {
262+
cmdstan <- config[["cmdstan"]]
263+
if (!is.list(cmdstan)) {
263264
rlang::abort("`config$cmdstan` must be a list when provided.")
264265
}
265-
cmdstan_output_dir <- config$cmdstan$output_dir %||% NULL
266+
cmdstan_output_dir <- cmdstan[["output_dir"]] %||% NULL
266267
if (!is.null(cmdstan_output_dir)) {
267268
.btl_mcmc_check(is.character(cmdstan_output_dir) && length(cmdstan_output_dir) == 1L,
268269
"`config$cmdstan$output_dir` must be a length-1 character path.")
@@ -804,7 +805,7 @@ build_round_log_row <- function(state,
804805
row$mcmc_cores_detected_physical <- as.integer(mcmc_config_used$cores_detected_physical %||% NA_integer_)
805806
row$mcmc_cores_detected_logical <- as.integer(mcmc_config_used$cores_detected_logical %||% NA_integer_)
806807
threads_per_chain <- mcmc_config_used$threads_per_chain %||%
807-
config$cmdstan$threads_per_chain %||% 1L
808+
config[["cmdstan"]][["threads_per_chain"]] %||% 1L
808809
row$mcmc_threads_per_chain <- as.integer(threads_per_chain %||% NA_integer_)
809810
row$mcmc_cmdstanr_version <- as.character(mcmc_config_used$cmdstanr_version %||% NA_character_)
810811
row

tests/testthat/helper-fixtures.R

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,94 @@ make_deterministic_fit_fn <- function(ids, fit = NULL) {
135135
get_calls = function() env$calls
136136
)
137137
}
138+
139+
make_test_link_cmdstan_fit_fn <- function() {
140+
function(stan_data, variable_names, cmdstan, seed, model_fn = NULL) {
141+
n_draws <- 4L
142+
draw_offsets <- c(-0.03, -0.01, 0.01, 0.03)
143+
delta_center <- if (is.numeric(stan_data$hub_ref_cross) && is.numeric(stan_data$spoke_ref_cross)) {
144+
mean(as.double(stan_data$hub_ref_cross) - as.double(stan_data$spoke_ref_cross), na.rm = TRUE)
145+
} else {
146+
0
147+
}
148+
if (!is.finite(delta_center)) {
149+
delta_center <- 0
150+
}
151+
hub_prior_signal <- mean(as.double(stan_data$hub_prior_sd %||% numeric()), na.rm = TRUE)
152+
if (is.finite(hub_prior_signal)) {
153+
delta_center <- delta_center + (hub_prior_signal * 0.01)
154+
}
155+
156+
build_theta_draws <- function(base_vals, prefix) {
157+
base_vals <- as.double(base_vals %||% numeric())
158+
if (length(base_vals) < 1L) {
159+
return(NULL)
160+
}
161+
out <- vapply(
162+
seq_along(base_vals),
163+
function(idx) base_vals[[idx]] + draw_offsets + ((idx - 1L) * 0.005),
164+
numeric(n_draws)
165+
)
166+
colnames(out) <- paste0(prefix, "[", seq_along(base_vals), "]")
167+
out
168+
}
169+
170+
draws <- matrix(nrow = n_draws, ncol = 0L)
171+
if ("delta" %in% variable_names) {
172+
draws <- cbind(draws, delta = delta_center + draw_offsets)
173+
}
174+
if ("log_alpha" %in% variable_names) {
175+
draws <- cbind(draws, log_alpha = c(-0.04, -0.01, 0.01, 0.04))
176+
}
177+
178+
theta_hub_draws <- build_theta_draws(stan_data$hub_ref, "theta_hub")
179+
if (!is.null(theta_hub_draws) &&
180+
("theta_hub" %in% variable_names || any(grepl("^theta_hub\\[", variable_names)))) {
181+
keep <- if ("theta_hub" %in% variable_names) {
182+
rep(TRUE, ncol(theta_hub_draws))
183+
} else {
184+
colnames(theta_hub_draws) %in% variable_names
185+
}
186+
draws <- cbind(draws, theta_hub_draws[, keep, drop = FALSE])
187+
}
188+
189+
theta_spoke_draws <- build_theta_draws(stan_data$spoke_ref, "theta_spoke")
190+
if (!is.null(theta_spoke_draws) &&
191+
("theta_spoke" %in% variable_names || any(grepl("^theta_spoke\\[", variable_names)))) {
192+
keep <- if ("theta_spoke" %in% variable_names) {
193+
rep(TRUE, ncol(theta_spoke_draws))
194+
} else {
195+
colnames(theta_spoke_draws) %in% variable_names
196+
}
197+
draws <- cbind(draws, theta_spoke_draws[, keep, drop = FALSE])
198+
}
199+
200+
if (ncol(draws) < 1L) {
201+
draws <- matrix(delta_center + draw_offsets, ncol = 1L)
202+
colnames(draws) <- "delta"
203+
}
204+
205+
list(
206+
fit = NULL,
207+
draws_matrix = draws,
208+
diagnostics = list(
209+
divergences = 0L,
210+
max_rhat = 1.0,
211+
min_ess_bulk = 1000
212+
),
213+
mcmc_config_used = list(
214+
chains = as.integer(cmdstan$chains %||% 4L),
215+
parallel_chains = as.integer(cmdstan$parallel_chains %||% cmdstan$chains %||% 4L),
216+
threads_per_chain = as.integer(cmdstan$threads_per_chain %||% 1L),
217+
cmdstanr_version = "test"
218+
)
219+
)
220+
}
221+
}
222+
223+
test_link_btl_config <- function(x = list()) {
224+
utils::modifyList(
225+
list(cmdstan_fit_fn = make_test_link_cmdstan_fit_fn()),
226+
x %||% list()
227+
)
228+
}

tests/testthat/test-5019-console-progress-and-refit-block.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ test_that("adaptive_rank_run_live prints linking-specific refit summary lines",
8181
phase_a_mode = "import",
8282
phase_a_artifacts = artifacts
8383
),
84-
btl_config = list(refit_pairs_target = 2L, stability_lag = 1L),
84+
btl_config = test_link_btl_config(list(refit_pairs_target = 2L, stability_lag = 1L)),
8585
progress = "all",
8686
progress_redraw_every = 1L
8787
)

tests/testthat/test-5026-adaptive-rank-wrapper.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ test_that("adaptive_rank wrapper supports link_one_spoke import flow", {
599599
phase_a_mode = "import",
600600
phase_a_artifacts = artifacts[c("1", "2")]
601601
),
602-
btl_config = list(refit_pairs_target = 2L),
602+
btl_config = test_link_btl_config(list(refit_pairs_target = 2L)),
603603
progress = "none",
604604
seed = 13L
605605
)
@@ -643,7 +643,7 @@ test_that("adaptive_rank wrapper supports link_multi_spoke concurrent flow", {
643643
phase_a_mode = "import",
644644
phase_a_artifacts = artifacts
645645
),
646-
btl_config = list(refit_pairs_target = 2L),
646+
btl_config = test_link_btl_config(list(refit_pairs_target = 2L)),
647647
progress = "none",
648648
seed = 17L
649649
)
@@ -681,7 +681,7 @@ test_that("adaptive_rank wrapper falls back to rank_raw when linked ranks are un
681681
hub_id = 1L,
682682
phase_a_mode = "run"
683683
),
684-
btl_config = list(refit_pairs_target = 1L),
684+
btl_config = test_link_btl_config(list(refit_pairs_target = 1L)),
685685
progress = "none",
686686
seed = 31L
687687
)

tests/testthat/test-5049-linking-candidates-round-routing.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,7 @@ test_that("coverage source propagates through selection and linking stage rows",
714714
state$warm_start_done <- TRUE
715715
state <- mark_link_phase_b_ready(state)
716716
state$round$staged_active <- TRUE
717+
state$config$btl_config <- test_link_btl_config(state$config$btl_config %||% list())
717718
draws <- matrix(
718719
seq_along(state$item_ids),
719720
nrow = 4L,

tests/testthat/test-5050-linking-refit-transforms.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ make_linking_refit_state <- function(adaptive_config = list()) {
1010
seed = 123L,
1111
adaptive_config = utils::modifyList(base_cfg, adaptive_config)
1212
)
13+
state$config$btl_config <- test_link_btl_config(state$config$btl_config %||% list())
1314

1415
draws <- matrix(
1516
c(
@@ -3786,6 +3787,7 @@ test_that("linking CmdStan schedule and refit seed are stable under fixed inputs
37863787

37873788
test_that("linking refit retries CmdStan effort until diagnostics pass", {
37883789
state <- make_linking_refit_state(list(link_refit_mode = "shift_only"))
3790+
state$config$btl_config$cmdstan_fit_fn <- NULL
37893791
state <- append_cross_step(state, 1L, "s21", "h1", 1L, spoke_id = 2L)
37903792
state <- append_cross_step(state, 2L, "h2", "s22", 0L, spoke_id = 2L)
37913793

tests/testthat/test-5052-linking-calibration-harness.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ test_that("offline linking calibration is deterministic and writes required arti
1010
judge_b = 0.05,
1111
judge_eps = 0.03,
1212
n_steps = 30L,
13+
btl_config = test_link_btl_config(),
1314
output_dir = out_dir,
1415
progress = "none"
1516
)
@@ -22,6 +23,7 @@ test_that("offline linking calibration is deterministic and writes required arti
2223
judge_b = 0.05,
2324
judge_eps = 0.03,
2425
n_steps = 30L,
26+
btl_config = test_link_btl_config(),
2527
output_dir = withr::local_tempdir(),
2628
progress = "none"
2729
)
@@ -81,6 +83,7 @@ test_that("offline calibration reuses canonical production selection utilities",
8183
seed = 17L,
8284
set_sizes = c(3L, 3L),
8385
n_steps = 50L,
86+
btl_config = test_link_btl_config(),
8487
progress = "none"
8588
)
8689

0 commit comments

Comments
 (0)