Skip to content

Commit 1aea958

Browse files
refactor: enforce line length and use cli for latent prior checks (#580)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: Sam Abbott <s.e.abbott12@gmail.com>
1 parent 1d551c6 commit 1aea958

2 files changed

Lines changed: 90 additions & 0 deletions

File tree

R/prior.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ epidist_prior <- function(
4444
enforce_presence = FALSE
4545
) {
4646
assert_epidist(data)
47+
.check_latent_priors(data, prior)
4748
default <- brms::default_prior(formula, data = data)
4849
model <- epidist_model_prior(data, formula)
4950
if (!is.null(model)) {
@@ -133,3 +134,45 @@ epidist_family_prior.lognormal <- function(family, formula, ...) {
133134
prior <- prior + sigma_prior
134135
return(prior)
135136
}
137+
138+
.check_latent_priors <- function(data, prior) {
139+
if (is_epidist_latent_model(data) && !is.null(prior)) {
140+
# Define parameters to check with their severity and messages
141+
params <- list(
142+
list(
143+
name = "swindow_raw",
144+
severity = "stop",
145+
msg = "Priors for the secondary event window (swindow_raw) must be uniform(0, 1)." # nolint
146+
),
147+
list(
148+
name = "pwindow_raw",
149+
severity = "warning",
150+
msg = "Non-uniform priors for the primary event window (pwindow_raw) are not fully supported and may lead to misleading posterior predictions and log-likelihoods." # nolint
151+
)
152+
)
153+
154+
for (param in params) {
155+
rows <- which(
156+
prior$dpar == param$name |
157+
grepl(param$name, prior$prior, fixed = TRUE)
158+
)
159+
160+
if (length(rows) > 0) {
161+
for (i in rows) {
162+
p_str <- prior$prior[i]
163+
p_clean <- gsub("\\s+", "", p_str)
164+
165+
is_uniform <- grepl("uniform(0,1)", p_clean, fixed = TRUE)
166+
if (!is_uniform) {
167+
if (param$severity == "stop") {
168+
cli::cli_abort(param$msg, call = NULL)
169+
} else {
170+
cli::cli_warn(param$msg, call = NULL)
171+
}
172+
}
173+
}
174+
}
175+
}
176+
}
177+
return(invisible(NULL))
178+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
test_that("epidist_prior warns or errors for latent model non-uniform priors", {
2+
# Mock data object
3+
data <- list(
4+
relative_obs_time = 1,
5+
pwindow = 1,
6+
woverlap = 1,
7+
swindow = 1,
8+
delay = 1,
9+
.row_id = 1
10+
)
11+
class(data) <- c("epidist_latent_model", "data.frame")
12+
13+
# Mock brmsprior object creator
14+
mock_prior <- function(prior_str, class = "b", dpar = "", ...) {
15+
return(data.frame(
16+
prior = prior_str,
17+
class = class,
18+
dpar = dpar,
19+
stringsAsFactors = FALSE
20+
))
21+
}
22+
class(mock_prior) <- "brmsprior"
23+
24+
# Since we cannot easily run epidist_prior because it depends on
25+
# brms::default_prior and other things we will test the checking function
26+
# directly.
27+
28+
# Case 1: swindow_raw non-uniform -> Error
29+
p1 <- mock_prior("normal(0,1)", dpar = "swindow_raw")
30+
expect_error(.check_latent_priors(data, p1), "secondary event")
31+
32+
# Case 2: swindow_raw uniform -> No error
33+
p2 <- mock_prior("uniform(0,1)", dpar = "swindow_raw")
34+
expect_no_error(.check_latent_priors(data, p2))
35+
36+
# Case 3: pwindow_raw non-uniform -> Warning
37+
p3 <- mock_prior("normal(0,1)", dpar = "pwindow_raw")
38+
expect_warning(.check_latent_priors(data, p3), "primary event")
39+
40+
# Case 4: pwindow_raw uniform -> No warning
41+
p4 <- mock_prior("uniform(0,1)", dpar = "pwindow_raw")
42+
expect_no_warning(.check_latent_priors(data, p4))
43+
44+
# Case 5: mixed
45+
p5 <- rbind(p1, p3)
46+
expect_error(.check_latent_priors(data, p5), "secondary event")
47+
})

0 commit comments

Comments
 (0)