Hi,
I tried to train and tune survival SVMs on a small dataset. Using a simple autotuning example, I found that the survival SVM learner can either fail (I think because of a fault in the optimization solvers) or get stuck (training never ends, with the CPU at 100%). I used lrn('surv.kaplan') as a fallback learner and set learner$timeout to work around these issues, but I think this instability is a bad sign for a learner. The problems mostly depend on the choice of type: whenever it is not 'regression', there is a high chance of running into them (in the example below, C-indexes close to 0.5 come from the Kaplan-Meier fallback). I have also seen the SVM learner fail with type = 'regression', though less often.
I am posting the following tuning example so that others can benefit from this investigation. Commenting out the learner$fallback and learner$timeout lines reproduces the issues I mentioned.
library(mlr3verse)
#> Loading required package: mlr3
library(mlr3proba)
library(survivalsvm)
#> Loading required package: survival
set.seed(42)
task = as_task_surv(x = veteran, time = 'time', event = 'status')
poe = po('encode')
task = poe$train(list(task))[[1]]
train_indxs = sample(seq_len(nrow(veteran)), 120)
test_indxs = setdiff(seq_len(nrow(veteran)), train_indxs)
learner = lrn('surv.svm',
  type = to_tune(c('regression', 'vanbelle1', 'vanbelle2', 'hybrid')),
  diff.meth = to_tune(c('makediff1', 'makediff2', 'makediff3')),
  gamma.mu = to_tune(ps(
    gamma = p_dbl(1e-03, 10, logscale = TRUE),
    mu = p_dbl(1e-03, 10, logscale = TRUE, depends = type == 'hybrid'),
    .extra_trafo = function(x, param_set) {
      list(gamma.mu = c(x$gamma, x$mu))
    },
    .allow_dangling_dependencies = TRUE
  )),
  kernel = to_tune(c('lin_kernel', 'add_kernel', 'rbf_kernel', 'poly_kernel'))
)
# fall back to the Kaplan-Meier estimator when the learner crashes
learner$fallback = lrn('surv.kaplan')
# abort training after 1 second when the learner gets stuck
learner$timeout = c('train' = 1, 'predict' = Inf)
#learner$param_set$values$eig.tol = 1e-03
#learner$param_set$values$conv.tol = 1e-03
#learner$param_set$values$posd.tol = 1e-03
#learner$param_set$values$opt.meth = 'ipop'
#learner$param_set$values$sigf = 2
#generate_design_random(learner$param_set$search_space(), 20)
generate_design_random(learner$param_set$search_space(), 3)$transpose()
#> [[1]]
#> [[1]]$type
#> [1] "hybrid"
#>
#> [[1]]$diff.meth
#> [1] "makediff3"
#>
#> [[1]]$kernel
#> [1] "lin_kernel"
#>
#> [[1]]$gamma.mu
#> [1] 0.01853109 0.97598798
#>
#>
#> [[2]]
#> [[2]]$type
#> [1] "vanbelle2"
#>
#> [[2]]$diff.meth
#> [1] "makediff3"
#>
#> [[2]]$kernel
#> [1] "add_kernel"
#>
#> [[2]]$gamma.mu
#> [1] 0.01089036
#>
#>
#> [[3]]
#> [[3]]$type
#> [1] "hybrid"
#>
#> [[3]]$diff.meth
#> [1] "makediff3"
#>
#> [[3]]$kernel
#> [1] "lin_kernel"
#>
#> [[3]]$gamma.mu
#> [1] 0.931249 1.488555
ssvm_at = AutoTuner$new(
  learner = learner,
  resampling = rsmp('cv', folds = 5),
  measure = msr('surv.cindex'),
  terminator = trm('evals', n_evals = 10),
  tuner = tnr('random_search'))
ssvm_at$train(task)
#> INFO [15:25:11.388] [bbotk] Starting to optimize 5 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]'
#> INFO [15:25:11.436] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:25:11.462] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:25:11.503] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:25:11.844] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:25:12.144] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:25:12.450] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:25:12.765] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:25:13.232] [mlr3] Finished benchmark
#> INFO [15:25:13.261] [bbotk] Result of batch 1:
#> INFO [15:25:13.263] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:25:13.263] [bbotk] regression <NA> 1.99383 NA lin_kernel 0.6893636 0 0
#> INFO [15:25:13.263] [bbotk] runtime_learners uhash
#> INFO [15:25:13.263] [bbotk] 1.597 16669f80-5de6-4c79-a768-08928c934405
#> INFO [15:25:13.271] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:25:13.289] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:25:13.293] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:25:16.780] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:25:20.174] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:25:23.868] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:25:27.330] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:25:30.758] [mlr3] Finished benchmark
#> INFO [15:25:30.785] [bbotk] Result of batch 2:
#> INFO [15:25:30.787] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:25:30.787] [bbotk] vanbelle2 makediff2 -3.545851 NA rbf_kernel 0.5 0 5
#> INFO [15:25:30.787] [bbotk] runtime_learners uhash
#> INFO [15:25:30.787] [bbotk] NA c98a2b5b-9cb6-4b2d-8669-4aea5e396cde
#> INFO [15:25:30.796] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:25:30.823] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:25:30.828] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:25:31.536] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:25:32.367] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:25:33.067] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:25:33.807] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:25:34.653] [mlr3] Finished benchmark
#> INFO [15:25:34.679] [bbotk] Result of batch 3:
#> INFO [15:25:34.681] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:25:34.681] [bbotk] hybrid makediff1 -6.114898 1.024288 rbf_kernel 0.5238242 0 0
#> INFO [15:25:34.681] [bbotk] runtime_learners uhash
#> INFO [15:25:34.681] [bbotk] 3.703 3edbcfe6-a153-47d5-b6f2-818d504735b4
#> INFO [15:25:34.692] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:25:34.714] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:25:34.719] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:25:38.152] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:25:41.555] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:25:45.237] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:25:48.701] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:25:52.225] [mlr3] Finished benchmark
#> INFO [15:25:52.255] [bbotk] Result of batch 4:
#> INFO [15:25:52.256] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:25:52.256] [bbotk] vanbelle2 makediff2 1.982577 NA rbf_kernel 0.5 0 5
#> INFO [15:25:52.256] [bbotk] runtime_learners uhash
#> INFO [15:25:52.256] [bbotk] NA 0f54c084-42ca-4e08-bc4b-818bc7922e5f
#> INFO [15:25:52.265] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:25:52.286] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:25:52.292] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:25:55.867] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:25:59.550] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:03.482] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:07.046] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:10.614] [mlr3] Finished benchmark
#> INFO [15:26:10.642] [bbotk] Result of batch 5:
#> INFO [15:26:10.644] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:10.644] [bbotk] vanbelle2 makediff2 -3.050726 NA lin_kernel 0.5 0 5
#> INFO [15:26:10.644] [bbotk] runtime_learners uhash
#> INFO [15:26:10.644] [bbotk] NA 1d365cc0-c18b-42bc-92b0-fecac6aaac4d
#> INFO [15:26:10.653] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:26:10.670] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:26:10.676] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:26:10.932] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:26:11.186] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:11.435] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:11.676] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:11.925] [mlr3] Finished benchmark
#> INFO [15:26:11.954] [bbotk] Result of batch 6:
#> INFO [15:26:11.955] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:11.955] [bbotk] regression <NA> -5.757422 NA lin_kernel 0.6854107 0 0
#> INFO [15:26:11.955] [bbotk] runtime_learners uhash
#> INFO [15:26:11.955] [bbotk] 1.127 cde21793-0b12-4566-8e8b-2bb756563a27
#> INFO [15:26:11.965] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:26:11.988] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:26:11.996] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:26:12.304] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:26:12.608] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:12.900] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:13.192] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:13.475] [mlr3] Finished benchmark
#> INFO [15:26:13.503] [bbotk] Result of batch 7:
#> INFO [15:26:13.504] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:13.504] [bbotk] regression <NA> 0.2568419 NA lin_kernel 0.6893636 0 0
#> INFO [15:26:13.504] [bbotk] runtime_learners uhash
#> INFO [15:26:13.504] [bbotk] 1.352 f4e9fa94-2b73-4d39-8e67-f864d3a7c71b
#> INFO [15:26:13.513] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:26:13.531] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:26:13.536] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:26:14.563] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:26:15.517] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:16.501] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:17.924] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:19.006] [mlr3] Finished benchmark
#> INFO [15:26:19.041] [bbotk] Result of batch 8:
#> INFO [15:26:19.043] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:19.043] [bbotk] hybrid makediff3 -1.907343 -6.24123 add_kernel 0.5645394 0 1
#> INFO [15:26:19.043] [bbotk] runtime_learners uhash
#> INFO [15:26:19.043] [bbotk] NA 8ae16fcb-7969-420b-af39-56a0ce68a74c
#> INFO [15:26:19.053] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:26:19.072] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:26:19.077] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:26:22.564] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:26:26.089] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:29.908] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:33.430] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:36.925] [mlr3] Finished benchmark
#> INFO [15:26:36.953] [bbotk] Result of batch 9:
#> INFO [15:26:36.955] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:36.955] [bbotk] vanbelle2 makediff2 -1.883382 NA lin_kernel 0.5 0 5
#> INFO [15:26:36.955] [bbotk] runtime_learners uhash
#> INFO [15:26:36.955] [bbotk] NA 5a765b5a-0741-4f75-95c3-9096c3916b65
#> INFO [15:26:36.965] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:26:36.983] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:26:36.988] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:26:37.058] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:26:37.137] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:37.210] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:37.284] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:37.357] [mlr3] Finished benchmark
#> INFO [15:26:37.386] [bbotk] Result of batch 10:
#> INFO [15:26:37.388] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:37.388] [bbotk] vanbelle1 makediff1 -5.990032 NA rbf_kernel 0.5337007 0 0
#> INFO [15:26:37.388] [bbotk] runtime_learners uhash
#> INFO [15:26:37.388] [bbotk] 0.242 d249648b-6561-4155-9347-243b96263347
#> INFO [15:26:37.410] [bbotk] Finished optimizing after 10 evaluation(s)
#> INFO [15:26:37.410] [bbotk] Result:
#> INFO [15:26:37.412] [bbotk] type diff.meth gamma mu kernel learner_param_vals x_domain
#> INFO [15:26:37.412] [bbotk] regression <NA> 1.99383 NA lin_kernel <list[3]> <list[3]>
#> INFO [15:26:37.412] [bbotk] surv.cindex
#> INFO [15:26:37.412] [bbotk] 0.6893636

Created on 2022-08-15 by the reprex package (v2.0.1)
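To see how the failures in the log above are spread over type, the per-batch results can be tallied by hand. The sketch below is base R only. It copies type, gamma, surv.cindex and errors from the ten "Result of batch" lines into a data frame. The names results, gamma_log and errors_by_type are my own and are not part of mlr3. The gamma values in the log are on the log scale because of logscale = TRUE, so exp() should map them back into the range given to p_dbl().

```r
# Values copied by hand from the ten "Result of batch" lines in the log above.
results = data.frame(
  batch = 1:10,
  type = c('regression', 'vanbelle2', 'hybrid', 'vanbelle2', 'vanbelle2',
           'regression', 'regression', 'hybrid', 'vanbelle2', 'vanbelle1'),
  gamma_log = c(1.99383, -3.545851, -6.114898, 1.982577, -3.050726,
                -5.757422, 0.2568419, -1.907343, -1.883382, -5.990032),
  surv.cindex = c(0.6893636, 0.5, 0.5238242, 0.5, 0.5,
                  0.6854107, 0.6893636, 0.5645394, 0.5, 0.5337007),
  errors = c(0, 5, 0, 5, 5, 0, 0, 1, 5, 0)
)

# Back-transform gamma from the log scale used by the tuner.
results$gamma = exp(results$gamma_log)

# Number of failed resampling iterations (out of 5 per batch) for each type.
errors_by_type = aggregate(errors ~ type, data = results, FUN = sum)
print(errors_by_type)

# Mean C-index per type; batches that failed in every fold score 0.5
# because the Kaplan-Meier fallback produced the predictions.
cindex_by_type = aggregate(surv.cindex ~ type, data = results, FUN = mean)
print(cindex_by_type)
```

In this run, the failed resampling iterations belong to vanbelle2 (every fold of four configurations) and hybrid (one fold). With only ten evaluations this is an indication, not a statistic.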