Skip to content

Commit cfca53b

Browse files
committed
feat: support input transformations for features
1 parent d89fe57 commit cfca53b

7 files changed

+76
-13
lines changed

R/SurrogateLearner.R

+20-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
#' Can be `"mean"` to use mean imputation or `"random"` to sample values uniformly at random between the empirical minimum and maximum.
1616
#' Default is `"random"`.
1717
#' }
18+
#' \item{`input_trafo`}{`character(1)`\cr
19+
#' Which input transformation should be applied to numeric and integer features?
20+
#' Can be `"none"` for no transformation or `"unitcube"` to min-max scale each feature to `[0, 1]` based on the bounds of the search space.
21+
#' Default is `"none"`.
22+
#' }
1823
#' }
1924
#'
2025
#' @export
@@ -74,9 +79,10 @@ SurrogateLearner = R6Class("SurrogateLearner",
7479

7580
ps = ps(
7681
catch_errors = p_lgl(),
77-
impute_method = p_fct(c("mean", "random"), default = "random")
82+
impute_method = p_fct(c("mean", "random"), default = "random"),
83+
input_trafo = p_fct(c("none", "unitcube"), default = "none")
7884
)
79-
ps$values = list(catch_errors = TRUE, impute_method = "random")
85+
ps$values = list(catch_errors = TRUE, impute_method = "random", input_trafo = "none")
8086

8187
super$initialize(learner = learner, archive = archive, cols_x = cols_x, cols_y = col_y, param_set = ps)
8288
},
@@ -90,7 +96,10 @@ SurrogateLearner = R6Class("SurrogateLearner",
9096
#' @return [data.table::data.table()] with the columns `mean` and `se`.
9197
predict = function(xdt) {
9298
assert_xdt(xdt)
93-
xdt = fix_xdt_missing(xdt, cols_x = self$cols_x, archive = self$archive)
99+
xdt = fix_xdt_missing(copy(xdt), cols_x = self$cols_x, archive = self$archive)
100+
if (self$param_set$values$input_trafo == "unitcube") {
101+
xdt = input_trafo_unitcube(xdt, search_space = self$archive$search_space)
102+
}
94103

95104
pred = self$learner$predict_newdata(newdata = xdt)
96105
if (self$learner$predict_type == "se") {
@@ -157,7 +166,10 @@ SurrogateLearner = R6Class("SurrogateLearner",
157166
private = list(
158167
# Train learner with new data.
159168
.update = function() {
160-
xydt = self$archive$data[, c(self$cols_x, self$cols_y), with = FALSE]
169+
xydt = copy(self$archive$data[, c(self$cols_x, self$cols_y), with = FALSE])
170+
if (self$param_set$values$input_trafo == "unitcube") {
171+
xydt = input_trafo_unitcube(xydt, search_space = self$archive$search_space)
172+
}
161173
task = TaskRegr$new(id = "surrogate_task", backend = xydt, target = self$cols_y)
162174
assert_learnable(task, learner = self$learner)
163175
self$learner$train(task)
@@ -166,7 +178,10 @@ SurrogateLearner = R6Class("SurrogateLearner",
166178
# Train learner with new data.
167179
# Operates on an asynchronous archive and performs imputation as needed.
168180
.update_async = function() {
169-
xydt = self$archive$rush$fetch_tasks_with_state(states = c("queued", "running", "finished"))[, c(self$cols_x, self$cols_y, "state"), with = FALSE]
181+
xydt = copy(self$archive$rush$fetch_tasks_with_state(states = c("queued", "running", "finished"))[, c(self$cols_x, self$cols_y, "state"), with = FALSE])
182+
if (self$param_set$values$input_trafo == "unitcube") {
183+
xydt = input_trafo_unitcube(xydt, search_space = self$archive$search_space)
184+
}
170185
if (self$param_set$values$impute_method == "mean") {
171186
mean_y = mean(xydt[[self$cols_y]], na.rm = TRUE)
172187
xydt[c("queued", "running"), (self$cols_y) := mean_y, on = "state"]

R/SurrogateLearnerCollection.R

+20-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
#' Can be `"mean"` to use mean imputation or `"random"` to sample values uniformly at random between the empirical minimum and maximum.
1818
#' Default is `"random"`.
1919
#' }
20+
#' \item{`input_trafo`}{`character(1)`\cr
21+
#' Which input transformation should be applied to numeric and integer features?
22+
#' Can be `"none"` for no transformation or `"unitcube"` to min-max scale each feature to `[0, 1]` based on the bounds of the search space.
23+
#' Default is `"none"`.
24+
#' }
2025
#' }
2126
#'
2227
#' @export
@@ -89,9 +94,10 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
8994

9095
ps = ps(
9196
catch_errors = p_lgl(),
92-
impute_method = p_fct(c("mean", "random"), default = "random")
97+
impute_method = p_fct(c("mean", "random"), default = "random"),
98+
input_trafo = p_fct(c("none", "unitcube"), default = "none")
9399
)
94-
ps$values = list(catch_errors = TRUE, impute_method = "random")
100+
ps$values = list(catch_errors = TRUE, impute_method = "random", input_trafo = "none")
95101

96102
super$initialize(learner = learners, archive = archive, cols_x = cols_x, cols_y = cols_y, param_set = ps)
97103
},
@@ -107,7 +113,10 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
107113
#' @return list of [data.table::data.table()]s with the columns `mean` and `se`.
108114
predict = function(xdt) {
109115
assert_xdt(xdt)
110-
xdt = fix_xdt_missing(xdt, cols_x = self$cols_x, archive = self$archive)
116+
xdt = fix_xdt_missing(copy(xdt), cols_x = self$cols_x, archive = self$archive)
117+
if (self$param_set$values$input_trafo == "unitcube") {
118+
xdt = input_trafo_unitcube(xdt, search_space = self$archive$search_space)
119+
}
111120

112121
preds = lapply(self$learner, function(learner) {
113122
pred = learner$predict_newdata(newdata = xdt)
@@ -185,7 +194,10 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
185194
.update = function() {
186195
assert_true((length(self$cols_y) == length(self$learner)) || length(self$cols_y) == 1L) # either as many cols_y as learner or only one
187196
one_to_multiple = length(self$cols_y) == 1L
188-
xydt = self$archive$data[, c(self$cols_x, self$cols_y), with = FALSE]
197+
xydt = copy(self$archive$data[, c(self$cols_x, self$cols_y), with = FALSE])
198+
if (self$param_set$values$input_trafo == "unitcube") {
199+
xydt = input_trafo_unitcube(xydt, search_space = self$archive$search_space)
200+
}
189201
features = setdiff(names(xydt), self$cols_y)
190202
tasks = lapply(self$cols_y, function(col_y) {
191203
# if this turns out to be a bottleneck, we can also operate on a single task here
@@ -214,7 +226,10 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
214226
assert_true((length(self$cols_y) == length(self$learner)) || length(self$cols_y) == 1L) # either as many cols_y as learner or only one
215227
one_to_multiple = length(self$cols_y) == 1L
216228

217-
xydt = self$archive$rush$fetch_tasks_with_state(states = c("queued", "running", "finished"))[, c(self$cols_x, self$cols_y, "state"), with = FALSE]
229+
xydt = copy(self$archive$rush$fetch_tasks_with_state(states = c("queued", "running", "finished"))[, c(self$cols_x, self$cols_y, "state"), with = FALSE])
230+
if (self$param_set$values$input_trafo == "unitcube") {
231+
xydt = input_trafo_unitcube(xydt, search_space = self$archive$search_space)
232+
}
218233
if (self$param_set$values$impute_method == "mean") {
219234
walk(self$cols_y, function(col) {
220235
mean_y = mean(xydt[[col]], na.rm = TRUE)

R/helper.R

+8-1
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,17 @@ assert_xdt = function(xdt) {
172172
# Assert that `x` passes `check_learner_surrogate()` and hand `x` back to the caller.
#
# @param x object to check.
# @param .var.name name used for `x` in the error message; defaults to `vname(x)`.
# @return `x`, visibly, if the check passes; otherwise `assert()` signals the failure.
assert_learner_surrogate = function(x, .var.name = vname(x)) {
  # NOTE: checkmate's assert() yields TRUE on success instead of returning x invisibly,
  # hence x is returned explicitly as the last expression below.
  assert(check_learner_surrogate(x), .var.name = .var.name)
  x
}
178177

178+
# Min-max scale numeric and integer features to the unit cube.
#
# Every column of `xydt` that belongs to a numeric or integer parameter of `search_space` is replaced by
# `(x - lower) / (upper - lower)`, using the bounds stored in the search space.
# Columns are overwritten through `set()` (presumably data.table's, i.e. by reference), so pass a copy
# if the input must stay untouched.
#
# @param xydt table holding the feature columns; any other columns (e.g. targets) are left as they are.
# @param search_space object providing `is_number`, `lower` and `upper`, each named by parameter id.
# @return `xydt` with the scaled columns.
input_trafo_unitcube = function(xydt, search_space) {
  # look the bounds up once instead of once per parameter
  lower = search_space$lower
  upper = search_space$upper
  # numeric or integer parameters; a parameter without a matching column cannot be scaled and is skipped,
  # otherwise `xydt[[parameter]]` would be NULL and a zero-length value would be handed to `set()`
  parameters = intersect(names(which(search_space$is_number)), names(xydt))
  for (parameter in parameters) {
    # NOTE(review): a degenerate (lower == upper) or unbounded parameter does not end up in [0, 1] here
    # (e.g. 0 / 0 gives NaN) - confirm whether such search spaces need to be supported
    set(xydt, j = parameter, value = (xydt[[parameter]] - lower[[parameter]]) / (upper[[parameter]] - lower[[parameter]]))
  }
  xydt
}
185+
179186
#' Check if Redis Server is Available
180187
#'
181188
#' Attempts to establish a connection to a Redis server using the \CRANpkg{redux} package

man/SurrogateLearner.Rd

+5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/SurrogateLearnerCollection.Rd

+5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_SurrogateLearner.R

+9-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ test_that("SurrogateLearner API works", {
2626
surrogate$learner$predict_type = "response"
2727
expect_equal(surrogate$predict_type, surrogate$learner$predict_type)
2828
expect_error({surrogate$predict_type = "response"}, "is read-only")
29+
30+
# unitcube input transformation for numeric and integer features
31+
surrogate = SurrogateLearner$new(learner = REGR_FEATURELESS, archive = inst$archive)
32+
surrogate$param_set$values$input_trafo = "unitcube"
33+
surrogate$update()
34+
expect_learner(surrogate$learner)
35+
expect_data_table(surrogate$predict(xdt), col.names = "named", nrows = 5, ncols = 2, any.missing = FALSE)
2936
})
3037

3138
test_that("predict_types are recognized", {
@@ -50,9 +57,10 @@ test_that("param_set", {
5057
inst = MAKE_INST_1D()
5158
surrogate = SurrogateLearner$new(learner = REGR_FEATURELESS, archive = inst$archive)
5259
expect_r6(surrogate$param_set, "ParamSet")
53-
expect_setequal(surrogate$param_set$ids(), c("catch_errors", "impute_method"))
60+
expect_setequal(surrogate$param_set$ids(), c("catch_errors", "impute_method", "input_trafo"))
5461
expect_equal(surrogate$param_set$class[["catch_errors"]], "ParamLgl")
5562
expect_equal(surrogate$param_set$class[["impute_method"]], "ParamFct")
63+
expect_equal(surrogate$param_set$class[["input_trafo"]], "ParamFct")
5664
expect_error({surrogate$param_set = list()}, regexp = "param_set is read-only.")
5765
})
5866

tests/testthat/test_SurrogateLearnerCollection.R

+9-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ test_that("SurrogateLearnerCollection API works", {
3535
expect_equal(surrogate$predict_type, surrogate$learner[[2L]]$predict_type)
3636
expect_error({surrogate$predict_type = "response"}, "is read-only")
3737

38+
# unitcube input transformation for numeric and integer features
39+
surrogate = SurrogateLearnerCollection$new(learners = list(REGR_FEATURELESS, REGR_FEATURELESS$clone(deep = TRUE)), archive = inst$archive)
40+
surrogate$param_set$values$input_trafo = "unitcube"
41+
surrogate$update()
42+
expect_learner(surrogate$learner[[1L]])
43+
expect_learner(surrogate$learner[[2L]])
44+
expect_list(surrogate$predict(xdt), len = 2L)
3845
})
3946

4047
test_that("predict_types are recognized", {
@@ -60,9 +67,10 @@ test_that("param_set", {
6067
inst = MAKE_INST(OBJ_1D_2, PS_1D, trm("evals", n_evals = 5L))
6168
surrogate = SurrogateLearnerCollection$new(learners = list(REGR_FEATURELESS, REGR_FEATURELESS$clone(deep = TRUE)), archive = inst$archive)
6269
expect_r6(surrogate$param_set, "ParamSet")
63-
expect_setequal(surrogate$param_set$ids(), c("catch_errors", "impute_method"))
70+
expect_setequal(surrogate$param_set$ids(), c("catch_errors", "impute_method", "input_trafo"))
6471
expect_equal(surrogate$param_set$class[["catch_errors"]], "ParamLgl")
6572
expect_equal(surrogate$param_set$class[["impute_method"]], "ParamFct")
73+
expect_equal(surrogate$param_set$class[["input_trafo"]], "ParamFct")
6674
expect_error({surrogate$param_set = list()}, regexp = "param_set is read-only.")
6775
})
6876

0 commit comments

Comments
 (0)