Skip to content

Commit 365d281

Browse files
authored
Replace and homogenize messaging with {cli} (#182)
* improve hardhat messages * switch messages to cli_ messages * add cli to wordlist
1 parent a3107f2 commit 365d281

15 files changed

Lines changed: 379 additions & 308 deletions

File tree

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## New features
44

5+
* messaging improved with {cli}
56
* add optimal threshold and support size into new 1.5 alpha `entmax15()` and `sparsemax15()`
67
`mask_types`. Add an optional `mask_topk` config parameter. (#180)
78
* tabnet is now using the `torch_ignite_adam` when available.

R/dials.R

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
check_dials <- function() {
22
if (!requireNamespace("dials", quietly = TRUE))
3-
stop("Package \"dials\" needed for this function to work. Please install it.", call. = FALSE)
3+
runtime_error("Package {.pkg dials} is needed for this function to work. Please install it.")
44
}
55

6-
check_cli <- function() {
7-
if (!requireNamespace("cli", quietly = TRUE))
8-
stop("Package \"cli\" needed for this function to work. Please install it.", call. = FALSE)
9-
}
10-
11-
126

137
#' Parameters for the tabnet model
148
#'
@@ -91,7 +85,7 @@ mask_type <- function(values = c("sparsemax", "entmax")) {
9185
dials::new_qual_param(
9286
type = "character",
9387
values = values,
94-
label = c(mask_type = "Final layer of feature selector, either sparsemax or entmax"),
88+
label = c(mask_type = "Final layer of feature selector, either 'sparsemax' or 'entmax'"),
9589
finalize = NULL
9690
)
9791
}
@@ -145,7 +139,6 @@ num_steps <- function(range = c(3L, 10L), trans = NULL) {
145139
#' @rdname tabnet_non_tunable
146140
#' @export
147141
cat_emb_dim <- function(range = NULL, trans = NULL) {
148-
check_cli()
149142
cli::cli_abort("{.var cat_emb_dim} cannot be used as a {.fun tune} parameter yet.")
150143
}
151144

R/explain.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ tabnet_explain <- function(object, new_data) {
4444
#' @export
4545
#' @rdname tabnet_explain
4646
tabnet_explain.default <- function(object, new_data) {
47-
stop(domain=NA,
48-
gettextf("`tabnet_explain()` is not defined for a '%s'.", class(object)[1]),
49-
call. = FALSE)
47+
type_error("{.fn tabnet_explain} is not defined for a {.type {class(object)[1]}}.")
5048
}
5149

5250
#' @export

R/hardhat.R

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,7 @@ tabnet_fit <- function(x, ...) {
108108
#' @export
109109
#' @rdname tabnet_fit
110110
tabnet_fit.default <- function(x, ...) {
111-
stop(domain=NA,
112-
gettextf("`tabnet_fit()` is not defined for a '%s'.", class(x)[1]),
113-
call. = FALSE)
111+
type_error("{.fn tabnet_fit} is not defined for a {.type {class(x)[1]}}.")
114112
}
115113

116114
#' @export
@@ -267,9 +265,7 @@ tabnet_pretrain <- function(x, ...) {
267265
#' @export
268266
#' @rdname tabnet_pretrain
269267
tabnet_pretrain.default <- function(x, ...) {
270-
stop(domain=NA,
271-
gettextf("`tabnet_pretrain()` is not defined for a '%s'.", class(x)[1]),
272-
call. = FALSE)
268+
type_error("{.fn tabnet_pretrain} is not defined for a {.type {class(x)[1]}}.")
273269
}
274270

275271

@@ -348,14 +344,13 @@ tabnet_bridge <- function(processed, config = tabnet_config(), tabnet_model, fro
348344
epoch_shift <- 0L
349345

350346
if (!(is.null(tabnet_model) || inherits(tabnet_model, "tabnet_fit") || inherits(tabnet_model, "tabnet_pretrain")))
351-
stop(gettextf("'%s' is not recognised as a proper TabNet model", tabnet_model),
352-
call. = FALSE)
347+
type_error("{.var {tabnet_model}} is not recognised as a proper TabNet model")
353348

354349
if (!is.null(from_epoch) && !is.null(tabnet_model)) {
355350
# model must be loaded from checkpoint
356351

357352
if (from_epoch > (length(tabnet_model$fit$checkpoints) * tabnet_model$fit$config$checkpoint_epoch))
358-
stop(gettextf("The model was trained for less than '%s' epochs", from_epoch), call. = FALSE)
353+
value_error("The model was trained for less than {.val {from_epoch}} epochs")
359354

360355
# find closest checkpoint for that epoch
361356
closest_checkpoint <- from_epoch %/% tabnet_model$fit$config$checkpoint_epoch
@@ -367,7 +362,7 @@ tabnet_bridge <- function(processed, config = tabnet_config(), tabnet_model, fro
367362
}
368363
if (task == "supervised") {
369364
if (sum(is.na(outcomes)) > 0) {
370-
stop(gettextf("Found missing values in the `%s` outcome column.", names(outcomes)), call. = FALSE)
365+
value_error("Found missing values in the {.var {names(outcomes)}} outcome column.")
371366
}
372367
if (is.null(tabnet_model)) {
373368
# new supervised model needs network initialization
@@ -377,7 +372,7 @@ tabnet_bridge <- function(processed, config = tabnet_config(), tabnet_model, fro
377372
} else if (!check_net_is_empty_ptr(tabnet_model) && inherits(tabnet_model, "tabnet_fit")) {
378373
# resume training from supervised
379374
if (!identical(processed$blueprint, tabnet_model$blueprint))
380-
stop("Model dimensions don't match.", call. = FALSE)
375+
runtime_error("Model dimensions don't match.")
381376

382377
# model is available from tabnet_model$serialized_net
383378
m <- reload_model(tabnet_model$serialized_net)
@@ -402,15 +397,16 @@ tabnet_bridge <- function(processed, config = tabnet_config(), tabnet_model, fro
402397
tabnet_model$fit$network <- reload_model(tabnet_model$fit$checkpoints[[last_checkpoint]])
403398
epoch_shift <- last_checkpoint * tabnet_model$fit$config$checkpoint_epoch
404399

405-
} else stop(gettextf("No model serialized weight can be found in `%s`, check the model history", tabnet_model), call. = FALSE)
400+
} else runtime_error("No model serialized weight can be found in {.var {tabnet_model}}, check the model history")
406401

407402
fit_lst <- tabnet_train_supervised(tabnet_model, predictors, outcomes, config = config, epoch_shift)
408403
return(new_tabnet_fit(fit_lst, blueprint = processed$blueprint))
409404

410405
} else if (task == "unsupervised") {
411406

412407
if (!is.null(tabnet_model)) {
413-
warning("`tabnet_pretrain()` from a model is not currently supported.\nThe pretraining here will start with a network initialization")
408+
warn("Using {.fn tabnet_pretrain} from a model is not currently supported.
409+
Pretraining will start from a new network initialization")
414410
}
415411
pretrain_lst <- tabnet_train_unsupervised( predictors, config = config, epoch_shift)
416412
return(new_tabnet_pretrain(pretrain_lst, blueprint = processed$blueprint))
@@ -447,7 +443,7 @@ predict_tabnet_bridge <- function(type, object, predictors, epoch, batch_size) {
447443
if (!is.null(epoch)) {
448444

449445
if (epoch > (length(object$fit$checkpoints) * object$fit$config$checkpoint_epoch))
450-
stop(gettextf("The model was trained for less than `%s` epochs", epoch), call. = FALSE)
446+
value_error("The model was trained for less than {.val {epoch}} epochs")
451447

452448
# find closest checkpoint for that epoch
453449
ind <- epoch %/% object$fit$config$checkpoint_epoch
@@ -485,7 +481,7 @@ model_pretrain_to_fit <- function(obj, x, y, config = tabnet_config()) {
485481
m <- reload_model(obj$serialized_net)
486482

487483
if (m$input_dim != tabnet_model_lst$network$input_dim)
488-
stop("Model dimensions don't match.", call. = FALSE)
484+
runtime_error("Model dimensions don't match.")
489485

490486
# perform update of selected weights into new tabnet_model
491487
m_stat_dict <- m$state_dict()
@@ -523,25 +519,25 @@ check_type <- function(outcome_ptype, type = NULL) {
523519
outcome_all_numeric <- all(purrr::map_lgl(outcome_ptype, is.numeric))
524520

525521
if (!outcome_all_numeric && !outcome_all_factor)
526-
stop(gettextf("Mixed multi-outcome type '%s' is not supported", unique(purrr::map_chr(outcome_ptype, ~class(.x)[[1]]))), call. = FALSE)
522+
not_implemented_error("Mixed multi-outcome type {.type {unique(purrr::map_chr(outcome_ptype, ~class(.x)[[1]]))}} is not supported")
527523

528524
if (is.null(type)) {
529525
if (outcome_all_factor)
530526
type <- "class"
531527
else if (outcome_all_numeric)
532528
type <- "numeric"
533529
else if (ncol(outcome_ptype) == 1)
534-
stop(gettextf("Unknown outcome type '%s'", class(outcome_ptype)), call. = FALSE)
530+
type_error("Unknown outcome type {.type {class(outcome_ptype)}}")
535531
}
536532

537533
type <- rlang::arg_match(type, c("numeric", "prob", "class"))
538534

539535
if (outcome_all_factor) {
540536
if (!type %in% c("prob", "class"))
541-
stop(gettextf("Outcome is factor and the prediction type is '%s'.", type), call. = FALSE)
537+
type_error("Outcome is factor and the prediction type is {.type {type}}.")
542538
} else if (outcome_all_numeric) {
543539
if (type != "numeric")
544-
stop(gettextf("Outcome is numeric and the prediction type is '%s'.", type), call. = FALSE)
540+
type_error("Outcome is numeric and the prediction type is {.type {type}}.")
545541
}
546542

547543
invisible(type)

R/model.R

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ resolve_loss <- function(config, dtype) {
253253
# cross entropy loss is required
254254
loss_fn <- torch::nn_cross_entropy_loss()
255255
else
256-
stop(gettextf("`%s` is not a valid loss for outcome of type %s", loss, dtype), call. = FALSE)
256+
value_error("{.val {loss}} is not a valid loss for outcome of type {.type {dtype}}")
257257

258258
loss_fn
259259
}
@@ -264,7 +264,7 @@ resolve_early_stop_monitor <- function(early_stopping_monitor, valid_split) {
264264
else if (early_stopping_monitor %in% c("train_loss", "auto"))
265265
early_stopping_monitor <- "train_loss"
266266
else
267-
stop(gettextf("%s is not a valid early-stopping metric to monitor with `valid_split` = %s", early_stopping_monitor, valid_split), call. = FALSE)
267+
value_error("{.val {early_stopping_monitor}} is not a valid early-stopping metric to monitor with {.var valid_split} = {.val {valid_split}}")
268268

269269
early_stopping_monitor
270270
}
@@ -516,11 +516,11 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
516516
config$ancestor_tt <- torch::torch_tensor(config$ancestor)$to(torch::torch_bool(), device = device)
517517
}
518518

519-
# instanciate optimizer
519+
# instantiate optimizer
520520
if (is_optim_generator(config$optimizer)) {
521521
optimizer <- config$optimizer(network$parameters, config$learn_rate)
522522
} else {
523-
stop("`optimizer` must be resolved into a torch optimizer generator.", call. = FALSE)
523+
type_error("{.var optimizer} must be resolved into a torch optimizer generator.")
524524
}
525525

526526
# define scheduler
@@ -533,7 +533,7 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
533533
} else if (config$lr_scheduler == "step") {
534534
scheduler <- torch::lr_step(optimizer, config$step_size, config$lr_decay)
535535
} else {
536-
stop("Currently only the 'step' and 'reduce_on_plateau' scheduler are supported.", call. = FALSE)
536+
not_implemented_error("Currently only the {.str step} and {.str reduce_on_plateau} scheduler are supported.")
537537
}
538538

539539
# restore previous metrics & checkpoints
@@ -598,7 +598,7 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
598598
patience_counter <- patience_counter + 1
599599
if (patience_counter >= config$early_stopping_patience){
600600
if (config$verbose)
601-
message(gettextf("Early stopping at epoch %03d", epoch))
601+
cli::cli_alert_success("Early-stopping at epoch {.val {epoch}}")
602602
break
603603
}
604604
} else {
@@ -623,10 +623,9 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
623623
if(!config$skip_importance) {
624624
importance_sample_size <- config$importance_sample_size
625625
if (is.null(config$importance_sample_size) && train_ds$.length() > 1e5) {
626-
warning(
627-
gettextf(
628-
"Computing importances for a dataset with size %s. This can consume too much memory. We are going to use a sample of size 1e5, You can disable this message by using the `importance_sample_size` argument.",
629-
train_ds$.length()))
626+
warn("Computing importances for a dataset with size {.val {train_ds$.length()}}.
627+
This can consume too much memory. We are going to use a sample of size 1e5.
628+
You can disable this message by using the {.var importance_sample_size} argument.")
630629
importance_sample_size <- 1e5
631630
}
632631
indexes <- as.numeric(torch::torch_randint(
@@ -643,6 +642,7 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
643642
} else {
644643
importances <- NULL
645644
}
645+
646646
list(
647647
network = network,
648648
metrics = metrics,

R/parsnip.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,8 @@ tabnet <- function(mode = "unknown", cat_emb_dim = NULL, decision_width = NULL,
471471
) {
472472

473473
if (!requireNamespace("parsnip", quietly = TRUE))
474-
stop("Package \"parsnip\" needed for this function to work. Please install it.", call. = FALSE)
475-
474+
runtime_error("Package {.pkg parsnip} is needed for this function to work. Please install it.")
475+
476476
if (parsnip_is_missing_tabnet(tabnet_env)) {
477477
add_parsnip_tabnet()
478478
tabnet_env$parsnip_added <- TRUE

R/pretraining.R

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,13 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
114114

115115
network$to(device = device)
116116

117-
# instanciate optimizer
117+
# instantiate optimizer
118118
if (is_optim_generator(config$optimizer)) {
119119
optimizer <- config$optimizer(network$parameters, config$learn_rate)
120-
} else
121-
stop("`optimizer` must be resolved into a torch optimizer generator.", call. = FALSE)
122-
120+
} else {
121+
type_error("{.var optimizer} must be resolved into a torch optimizer generator.")
122+
}
123+
123124

124125
# define scheduler
125126
if (is.null(config$lr_scheduler)) {
@@ -131,7 +132,7 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
131132
} else if (config$lr_scheduler == "step") {
132133
scheduler <- torch::lr_step(optimizer, config$step_size, config$lr_decay)
133134
} else {
134-
stop("Currently only the 'step' and 'reduce_on_plateau' scheduler are supported.", call. = FALSE)
135+
not_implemented_error("Currently only the {.str step} and {.str reduce_on_plateau} scheduler are supported.")
135136
}
136137

137138
# initialize metrics & checkpoints
@@ -195,7 +196,7 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
195196
patience_counter <- patience_counter + 1
196197
if (patience_counter >= config$early_stopping_patience) {
197198
if (config$verbose)
198-
rlang::inform(sprintf("Early stopping at epoch %03d", epoch))
199+
cli::cli_alert_success("Early-stopping at epoch {.val {epoch}}")
199200
break
200201
}
201202
} else {
@@ -217,26 +218,28 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
217218
}
218219

219220
network$to(device = "cpu")
220-
221-
importance_sample_size <- config$importance_sample_size
222-
if (is.null(config$importance_sample_size) && train_ds$.length() > 1e5) {
223-
warning(domain=NA,
224-
gettextf("Computing importances for a dataset with size %s. This can consume too much memory. We are going to use a sample of size 1e5. You can disable this message by using the `importance_sample_size` argument.", train_ds$.length()),
225-
call. = FALSE)
226-
importance_sample_size <- 1e5
227-
}
228-
indexes <- as.numeric(torch::torch_randint(
229-
1, train_ds$.length(), min(importance_sample_size, train_ds$.length()),
230-
dtype = torch::torch_long()
231-
))
232-
importances <- tibble::tibble(
233-
variables = colnames(x),
234-
importance = compute_feature_importance(
235-
network,
236-
train_ds$.getbatch(batch =indexes)$x$to(device = "cpu"),
237-
train_ds$.getbatch(batch =indexes)$x_na_mask$to(device = "cpu")
221+
if(!config$skip_importance) {
222+
importance_sample_size <- config$importance_sample_size
223+
if (is.null(config$importance_sample_size) && train_ds$.length() > 1e5) {
224+
warn("Computing importances for a dataset with size {.val {train_ds$.length()}}.
225+
This can consume too much memory. We are going to use a sample of size 1e5.
226+
You can disable this message by using the {.var importance_sample_size} argument.")
227+
importance_sample_size <- 1e5
228+
}
229+
indexes <- as.numeric(torch::torch_randint(
230+
1, train_ds$.length(), min(importance_sample_size, train_ds$.length()),
231+
dtype = torch::torch_long()
232+
))
233+
importances <- tibble::tibble(
234+
variables = colnames(x),
235+
importance = compute_feature_importance(
236+
network,
237+
train_ds$.getbatch(batch =indexes)$x$to(device = "cpu"),
238+
train_ds$.getbatch(batch =indexes)$x_na_mask$to(device = "cpu"))
238239
)
239-
)
240+
} else {
241+
importances <- NULL
242+
}
240243

241244
list(
242245
network = network,

R/tab-network.R

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,9 @@ tabnet_pretrainer <- torch::nn_module(
235235
self$initial_bn <- torch::nn_batch_norm1d(input_dim, momentum = momentum)
236236

237237
if (n_steps <= 0)
238-
stop("'n_steps' should be a positive integer.")
238+
value_error("{.var n_steps} should be a positive integer.")
239239
if (n_independent == 0 && n_shared == 0)
240-
stop("'n_shared' and 'n_independant' can't be both zero.")
240+
value_error("{.var n_shared} and {.var n_independent} can't be both zero.")
241241

242242
# self$virtual_batch_size <- virtual_batch_size
243243
self$embedder <- embedding_generator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
@@ -402,10 +402,10 @@ tabnet_nn <- torch::nn_module(
402402
self$cat_emb_dim <- cat_emb_dim
403403

404404
if (n_steps <= 0)
405-
stop("'n_steps' should be a positive integer.")
405+
value_error("{.var n_steps} should be a positive integer.")
406406
if (n_independent == 0 && n_shared == 0)
407-
stop("'n_shared' and 'n_independant' can't be both zero.")
408-
407+
value_error("{.var n_shared} and {.var n_independent} can't be both zero.")
408+
409409
self$virtual_batch_size <- virtual_batch_size
410410
self$embedder <- embedding_generator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
411411
self$embedder_na <- na_embedding_generator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
@@ -460,8 +460,7 @@ attentive_transformer <- torch::nn_module(
460460
else if (mask_type == "sparsemax")
461461
self$selector <- sparsemax(dim = -1L)
462462
else
463-
stop("Please choose either 'sparsemax', 'sparsemax15', 'entmax' or 'entmax15' as 'mask_type'")
464-
463+
value_error("Please choose either {.val sparsemax}, {.val sparsemax15}, {.val entmax} or {.val entmax15} as {.var mask_type}")
465464
},
466465
forward = function(priors, processed_feat) {
467466
x <- self$fc(processed_feat)
@@ -625,8 +624,9 @@ embedding_generator <- torch::nn_module(
625624

626625
# check that all embeddings dimensions are provided
627626
if (length(self$cat_emb_dims) != length(cat_dims)){
628-
msg = paste0("`cat_emb_dim` length must be 1 or the number of categorical predictors, got length ",length(self$cat_emb_dims)," for ",length(cat_dims)," categorical predictors")
629-
stop(msg)
627+
value_error("{.var cat_emb_dim} length must be 1 or the number of categorical predictors,
628+
got length {.val {length(self$cat_emb_dims)}} for {.val {length(cat_dims)}}
629+
categorical predictors")
630630
}
631631

632632
self$post_embed_dim <- as.integer(input_dim + sum(self$cat_emb_dims) - length(self$cat_emb_dims))

0 commit comments

Comments
 (0)