Skip to content

Commit 38a5745

Browse files
feat: fixed number of steps
1 parent 3441b70 commit 38a5745

File tree

3 files changed

+44
-11
lines changed

3 files changed

+44
-11
lines changed

R/ols-stepwise-hierarchical.R

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,36 @@
1313
#' @param progress Logical; if \code{TRUE}, will display variable selection progress.
1414
#' @param details Logical; if \code{TRUE}, will print the regression result at
1515
#' each step.
16+
#' @param steps Number of steps after which the stepwise procedures should stop.
1617
#'
1718
#' @examples
19+
#' # forward hierarchical selection
1820
#' model <- lm(y ~ ., data = surgical)
21+
#' ols_step_hierarchical(model)
22+
#'
23+
#' # backward hierarchical selection
1924
#' model <- lm(y ~ bcs + alc_heavy + pindex + enzyme_test + liver_test + alc_mod + age + gender, data = surgical)
25+
#' ols_step_hierarchical(model, forward = FALSE)
26+
#'
27+
#' # steps
2028
#' model <- lm(y ~ bcs + alc_heavy + pindex + enzyme_test + liver_test + age + gender + alc_mod, data = surgical)
29+
#' ols_step_hierarchical(model, steps = 2)
2130
#'
2231
#' @keywords internal
2332
#'
2433
#' @noRd
2534
#'
26-
ols_step_hierarchical <- function(model, p_value = 0.1, forward = TRUE, progress = FALSE, details = FALSE) {
35+
ols_step_hierarchical <- function(model, p_value = 0.1, forward = TRUE, progress = FALSE, details = FALSE, steps = NULL) {
2736

2837
if (forward) {
29-
ols_step_hierarchical_forward(model, p_value, progress, details)
38+
ols_step_hierarchical_forward(model, p_value, progress, details, steps)
3039
} else {
31-
ols_step_hierarchical_backward(model, p_value, progress, details)
40+
ols_step_hierarchical_backward(model, p_value, progress, details, steps)
3241
}
3342

3443
}
3544

36-
ols_step_hierarchical_forward <- function(model, p_value = 0.1, progress = FALSE, details = FALSE) {
45+
ols_step_hierarchical_forward <- function(model, p_value = 0.1, progress = FALSE, details = FALSE, steps = NULL) {
3746

3847
if (details) {
3948
progress <- FALSE
@@ -56,7 +65,7 @@ ols_step_hierarchical_forward <- function(model, p_value = 0.1, progress = FALSE
5665
if (progress || details) {
5766
ols_candidate_terms(nam, "forward")
5867
}
59-
68+
6069
step <- 0
6170
rsq <- c()
6271
adjrsq <- c()
@@ -68,7 +77,7 @@ ols_step_hierarchical_forward <- function(model, p_value = 0.1, progress = FALSE
6877

6978
base_model <- lm(paste(response, "~", 1), data = l)
7079
rsq_base <- summary(base_model)$r.squared
71-
80+
7281
if (details) {
7382
ols_rsquared_init(NULL, "r2", response, rsq_base)
7483
}
@@ -77,6 +86,10 @@ ols_step_hierarchical_forward <- function(model, p_value = 0.1, progress = FALSE
7786
ols_progress_init("forward")
7887
}
7988

89+
if (!is.null(steps)) {
90+
mlen_p <- steps
91+
}
92+
8093
for (i in seq_len(mlen_p)) {
8194
predictors <- c(preds, all_pred[i])
8295
m <- lm(paste(response, "~", paste(predictors, collapse = " + ")), l)
@@ -151,7 +164,7 @@ ols_step_hierarchical_forward <- function(model, p_value = 0.1, progress = FALSE
151164

152165
}
153166

154-
ols_step_hierarchical_backward <- function(model, p_value = 0.1, progress = FALSE, details = FALSE) {
167+
ols_step_hierarchical_backward <- function(model, p_value = 0.1, progress = FALSE, details = FALSE, steps = NULL) {
155168

156169
if (details) {
157170
progress <- FALSE
@@ -166,9 +179,9 @@ ols_step_hierarchical_backward <- function(model, p_value = 0.1, progress = FALS
166179
cterms <- nam
167180

168181
if (progress || details) {
169-
ols_candidate_terms(nam, "backward")
182+
ols_candidate_terms(nam, "backward")
170183
}
171-
184+
172185
step <- 0
173186
rsq <- c()
174187
adjrsq <- c()
@@ -179,7 +192,7 @@ ols_step_hierarchical_backward <- function(model, p_value = 0.1, progress = FALS
179192
rmse <- c()
180193

181194
rsq_base <- summary(model)$r.squared
182-
195+
183196
if (details) {
184197
ols_rsquared_init(NULL, "r2", response, rsq_base)
185198
}
@@ -194,7 +207,7 @@ ols_step_hierarchical_backward <- function(model, p_value = 0.1, progress = FALS
194207
m_sum <- Anova(m)
195208
pvals <- m_sum$`Pr(>F)`[1:i]
196209
p_vals <- pvals[i]
197-
210+
198211
if (details) {
199212
d <- data.frame(predictors = predictors, p_val = pvals)
200213
ols_stepwise_table_p(d, predictors, pvals)
@@ -222,6 +235,12 @@ ols_step_hierarchical_backward <- function(model, p_value = 0.1, progress = FALS
222235
ols_stepwise_details(step, preds, rpred, response, rsq1, "removed", "rsq")
223236
}
224237

238+
if (!is.null(steps)) {
239+
if (step == steps) {
240+
break
241+
}
242+
}
243+
225244
} else {
226245

227246
if (progress || details) {

tests/testthat/test-step-backward.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ test_that("backward hierarchical selection output matches the expected result",
2929
expect_equal(k$metrics$variable, c("alc_mod", "gender", "age", "liver_test"), ignore_attr = TRUE)
3030
})
3131

32+
test_that("backward hierarchical selection stops after a specific number of steps", {
33+
model <- lm(y ~ bcs + alc_heavy + pindex + enzyme_test + liver_test + age + gender + alc_mod, data = surgical)
34+
k <- ols_step_backward_p(model, 0.1, hierarchical = TRUE, steps = 2)
35+
expect_equal(k$metrics$step, 1:2)
36+
expect_equal(k$metrics$variable, c("alc_mod", "gender"), ignore_attr = TRUE)
37+
})
38+
3239
test_that("backward elimination output matches the expected result when steps in specified", {
3340
model <- lm(y ~ bcs + alc_heavy + pindex + enzyme_test + liver_test + age + gender + alc_mod, data = surgical)
3441
k <- ols_step_backward_p(model, steps = 2)

tests/testthat/test-step-forward.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ test_that("forward hierarchical selection output matches the expected result", {
2828
expect_equal(k$metrics$variable, c("bcs", "alc_heavy", "pindex", "enzyme_test"), ignore_attr = TRUE)
2929
})
3030

31+
test_that("forward hierarchical selection stops after a specific number of steps", {
32+
model <- lm(y ~ bcs + alc_heavy + pindex + enzyme_test + liver_test + age + gender + alc_mod, data = surgical)
33+
k <- ols_step_forward_p(model, 0.1, hierarchical = TRUE, steps = 2)
34+
expect_equal(k$metrics$step, 1:2)
35+
expect_equal(k$metrics$variable, c("bcs", "alc_heavy"), ignore_attr = TRUE)
36+
})
37+
3138
test_that("stepwise forward regression stops after a specific number of steps", {
3239
model <- lm(y ~ x1 + x2 + x3 + x4, data = cement)
3340
k <- ols_step_forward_p(model, steps = 2)

0 commit comments

Comments
 (0)