Merged
48 commits
574da40
started docs for callback
cxzhang4 Dec 6, 2024
b2ec70c
Merge branch 'main' into feat/callback-lr_schedule
cxzhang4 Dec 10, 2024
795c0ba
ough outline
cxzhang4 Dec 10, 2024
36fa621
add test file
cxzhang4 Dec 10, 2024
f03a741
pre-christmas work from laptop
cxzhang4 Dec 23, 2024
6664ec4
very simple change
cxzhang4 Dec 23, 2024
e25c623
idrk
cxzhang4 Dec 24, 2024
58ff44a
fewioj
cxzhang4 Dec 27, 2024
f4d46ee
typed out all pre-implemented schedulers. TODO: implement params that…
cxzhang4 Dec 28, 2024
31aa858
a comment
cxzhang4 Dec 28, 2024
1ff1e88
check class fn
cxzhang4 Dec 29, 2024
3702ff6
awejoifeawjoi
cxzhang4 Dec 29, 2024
b5913c1
comment
cxzhang4 Dec 29, 2024
43994d3
trains for cosine annealing, still need to test
cxzhang4 Dec 30, 2024
3899452
modified check_class_or_list to not accept vectors
cxzhang4 Dec 30, 2024
71ad99f
some tests from Claude, not run yet
cxzhang4 Dec 30, 2024
78f0b8b
LLM tests fail
cxzhang4 Dec 31, 2024
5d96fab
sketch tests, incomplete
cxzhang4 Jan 3, 2025
0f21732
Update R/CallbackSetLRScheduler.R
cxzhang4 Jan 7, 2025
539b36d
named additional arg for all torch-provided schedulers
cxzhang4 Jan 7, 2025
34f0845
playing around with stuff
cxzhang4 Jan 7, 2025
a94963d
weoiweoij
cxzhang4 Jan 7, 2025
d808346
accept incoming
cxzhang4 Jan 8, 2025
e1cac87
tests look ok, still need to debug implementation and document
cxzhang4 Jan 8, 2025
4d677b8
ewfojiewfoij
cxzhang4 Jan 8, 2025
ac1af70
idk how to have a named arg for the scheduler_fn and also support an …
cxzhang4 Jan 8, 2025
731b0c6
stuff
cxzhang4 Jan 8, 2025
c9432db
tests still don't work, .scheduler seems to be expected and the custo…
cxzhang4 Jan 9, 2025
aa74077
passes veryyy simple tests
cxzhang4 Jan 9, 2025
210de81
note to self
cxzhang4 Jan 9, 2025
3844b6e
change callbacksetlrscheduler to use additional_args$.scheduler inste…
cxzhang4 Jan 9, 2025
9acd639
functionally ok based on the two simple tests
cxzhang4 Jan 10, 2025
6046725
looks ok?
cxzhang4 Jan 10, 2025
216c937
began pr issues
cxzhang4 Jan 10, 2025
d548ef7
Merge branch 'main' into feat/callback-lr_schedule
cxzhang4 Jan 10, 2025
a6b8467
more revisions from PR
cxzhang4 Jan 10, 2025
826f984
tests passing
cxzhang4 Jan 10, 2025
d7a834d
delete attic file
cxzhang4 Jan 10, 2025
d311c85
allow custom id
cxzhang4 Jan 10, 2025
327ada3
chatgpt says the large paramsets look ok
cxzhang4 Jan 10, 2025
c222202
Merge branch 'main' into feat/callback-lr_schedule
cxzhang4 Jan 10, 2025
668fa6b
test passes but does not allow setting the parameter later
cxzhang4 Jan 10, 2025
5542c73
Update R/CallbackSetLRScheduler.R
cxzhang4 Jan 11, 2025
7b34465
comment on test:
cxzhang4 Jan 11, 2025
254b8cc
removed unnecessary test
cxzhang4 Jan 13, 2025
58a20d4
address comments from (hopefully final) review
cxzhang4 Jan 14, 2025
53c4002
Update R/CallbackSetLRScheduler.R
sebffischer Jan 15, 2025
6d375ca
docs(news): update changelog
sebffischer Jan 15, 2025
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -85,6 +85,7 @@ Collate:
'CallbackSetCheckpoint.R'
'CallbackSetEarlyStopping.R'
'CallbackSetHistory.R'
'CallbackSetLRScheduler.R'
'CallbackSetProgress.R'
'CallbackSetTB.R'
'CallbackSetUnfreeze.R'
1 change: 1 addition & 0 deletions NAMESPACE
@@ -63,6 +63,7 @@ S3method(unmarshal_model,learner_torch_model_marshaled)
export(CallbackSet)
export(CallbackSetCheckpoint)
export(CallbackSetHistory)
export(CallbackSetLRScheduler)
export(CallbackSetProgress)
export(CallbackSetTB)
export(CallbackSetUnfreeze)
1 change: 1 addition & 0 deletions NEWS.md
@@ -11,6 +11,7 @@
* feat: Added multimodal melanoma example task
* feat: Added a callback to iteratively unfreeze parameters for finetuning
* fix: torch learners can now be used with `AutoTuner`
* feat: Added different learning rate schedulers as callbacks

# mlr3torch 0.1.2

188 changes: 188 additions & 0 deletions R/CallbackSetLRScheduler.R
@@ -0,0 +1,188 @@
#' @title Learning Rate Scheduling Callback
#'
#' @name mlr_callback_set.lr_scheduler
#'
#' @description
#' Changes the learning rate based on the schedule specified by a `torch::lr_scheduler`.
#'
#' As of this writing, the following are available: [torch::lr_cosine_annealing()], [torch::lr_lambda()], [torch::lr_multiplicative()], [torch::lr_one_cycle()],
#' [torch::lr_reduce_on_plateau()], [torch::lr_step()], and custom schedulers defined with [torch::lr_scheduler()].
#'
#' @param .scheduler (`lr_scheduler_generator`)\cr
#' The `torch` scheduler generator (e.g. `torch::lr_step`).
#' @param ... (any)\cr
#' The scheduler-specific arguments.
#'
#' @export
CallbackSetLRScheduler = R6Class("CallbackSetLRScheduler",
inherit = CallbackSet,
lock_objects = FALSE,
public = list(
#' @field scheduler_fn (`lr_scheduler_generator`)\cr
#' The `torch` function that creates a learning rate scheduler.
scheduler_fn = NULL,
#' @field scheduler (`LRScheduler`)\cr
#' The learning rate scheduler wrapped by this callback.
scheduler = NULL,
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(.scheduler, step_on_epoch, ...) {
assert_class(.scheduler, "lr_scheduler_generator")
assert_flag(step_on_epoch)

self$scheduler_fn = .scheduler
private$.scheduler_args = list(...)
if (step_on_epoch) {
self$on_epoch_end = function() self$scheduler$step()
} else {
self$on_batch_end = function() self$scheduler$step()
}
},
#' @description
#' Creates the scheduler using the optimizer from the context.
on_begin = function() {
self$scheduler = invoke(self$scheduler_fn, optimizer = self$ctx$optimizer, .args = private$.scheduler_args)
}
),
private = list(
.scheduler_args = NULL
)
)

# Some schedulers accept lists so that different parameter groups can be treated differently.
check_class_or_list = function(x, classname) {
if (is.list(x)) check_list(x, types = classname) else check_class(x, classname)
}
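
A minimal sketch of how this check behaves, following directly from the definition above:

check_class_or_list(function(epoch) 0.95^epoch, "function") # TRUE
check_class_or_list(list(function(e) 0.9^e, function(e) 0.99^e), "function") # TRUE
check_class_or_list("not a function", "function") # an error message, so the check fails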

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_cosine_annealing", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
T_max = p_int(tags = c("train", "required")),
eta_min = p_dbl(default = 0, tags = "train"),
last_epoch = p_int(default = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_cosine_annealing",
label = "Cosine Annealing LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_cosine_annealing, step_on_epoch = TRUE)
)
})
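
For orientation, a minimal usage sketch: the callback is retrieved with `t_clbk()` and passed to a torch learner. The learner id, task, and parameter values here are illustrative, not prescribed by this PR:

library(mlr3torch)
# retrieve the callback; its scheduler arguments become hyperparameters
# of the learner, prefixed with the callback id
cb = t_clbk("lr_cosine_annealing", T_max = 10, eta_min = 1e-5)
learner = lrn("classif.mlp", epochs = 10, batch_size = 32, callbacks = cb)
learner$train(tsk("iris"))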

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_lambda", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
lr_lambda = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "function")),
last_epoch = p_int(default = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_scheduler",
label = "Multiplication by Function LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_lambda, step_on_epoch = TRUE)
)
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_multiplicative", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
lr_lambda = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "function")),
last_epoch = p_int(default = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_multiplicative",
label = "Multiplication by Factor LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_multiplicative, step_on_epoch = TRUE)
)
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_one_cycle", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
max_lr = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "numeric")),
total_steps = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
epochs = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
steps_per_epoch = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
pct_start = p_dbl(default = 0.3, tags = "train"),
anneal_strategy = p_fct(default = "cos", levels = c("cos", "linear"), tags = "train"), # this is a string in the torch fn
cycle_momentum = p_lgl(default = TRUE, tags = "train"),
base_momentum = p_uty(default = 0.85, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
max_momentum = p_uty(default = 0.95, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
div_factor = p_dbl(default = 25, tags = "train"),
final_div_factor = p_dbl(default = 1e4, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_one_cycle",
label = "1cyle LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_one_cycle, step_on_epoch = FALSE)
)
})
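
Because the 1cycle policy anneals the learning rate within epochs, this scheduler steps after every batch (`step_on_epoch = FALSE`). A configuration sketch with illustrative values; per `torch::lr_one_cycle()`, one supplies either `total_steps` or both `epochs` and `steps_per_epoch`:

cb = t_clbk("lr_one_cycle", max_lr = 0.1, epochs = 10, steps_per_epoch = 25)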

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_reduce_on_plateau", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
mode = p_fct(default = "min", levels = c("min", "max"), tags = "train"),
factor = p_dbl(default = 0.1, tags = "train"),
patience = p_int(default = 10, tags = "train"),
threshold = p_dbl(default = 1e-04, tags = "train"),
threshold_mode = p_fct(default = "rel", levels = c("rel", "abs"), tags = "train"),
cooldown = p_int(default = 0, tags = "train"),
min_lr = p_uty(default = 0, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
eps = p_dbl(default = 1e-08, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_reduce_on_plateau",
label = "Reduce on Plateau LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_reduce_on_plateau, step_on_epoch = TRUE)
)
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_step", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
step_size = p_int(tags = c("train", "required")),
gamma = p_dbl(default = 0.1, tags = "train"),
last_epoch = p_int(default = -1, tags = "train")
),
id = "lr_step",
label = "Step Decay LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_step, step_on_epoch = TRUE)
)
})

#' @param x (`function`)\cr
#' The `torch` scheduler generator defined using `torch::lr_scheduler()`.
#' @param step_on_epoch (`logical(1)`)\cr
#' Whether the scheduler steps after every epoch (`TRUE`) or after every batch (`FALSE`).
as_lr_scheduler = function(x, step_on_epoch) {
assert_class(x, "lr_scheduler_generator")
assert_flag(step_on_epoch)

class_name = class(x)[1L]

TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = inferps(x),
id = if (class_name == "") "lr_custom" else class_name,
label = "Custom LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = x, step_on_epoch = step_on_epoch)
)
}
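
A sketch of wrapping a custom scheduler. This assumes `torch::lr_scheduler()` follows PyTorch's `_LRScheduler` protocol, i.e. that the generated class can read `self$base_lrs` and `self$last_epoch` inside `get_lr()`; treat those field names as assumptions:

# hypothetical custom scheduler: decay every base learning rate by `gamma` per step
lr_geometric = torch::lr_scheduler(
  "lr_geometric",
  initialize = function(optimizer, gamma = 0.95, last_epoch = -1) {
    self$gamma = gamma
    super$initialize(optimizer, last_epoch)
  },
  get_lr = function() {
    sapply(self$base_lrs, function(lr) lr * self$gamma^self$last_epoch)
  }
)
cb = as_lr_scheduler(lr_geometric, step_on_epoch = TRUE)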
9 changes: 6 additions & 3 deletions R/TorchCallback.R
@@ -192,8 +192,10 @@ TorchCallback = R6Class("TorchCallback",
#' @template param_label
#' @template param_packages
#' @template param_man
+ #' @param additional_args (`list()` or `NULL`)\cr
+ #'   Additional construction arguments, passed on unchanged. For learning rate schedulers, this holds the scheduler generator (`.scheduler`) and `step_on_epoch`.
initialize = function(callback_generator, param_set = NULL, id = NULL,
- label = NULL, packages = NULL, man = NULL) {
+ label = NULL, packages = NULL, man = NULL, additional_args = NULL) {
assert_class(callback_generator, "R6ClassGenerator")

param_set = assert_param_set(param_set %??% inferps(callback_generator))
Expand All @@ -206,7 +208,8 @@ TorchCallback = R6Class("TorchCallback",
param_set = param_set,
packages = union(packages, "mlr3torch"),
label = label,
- man = man
+ man = man,
+ additional_args = additional_args
)
}
),
Expand All @@ -215,7 +218,7 @@ TorchCallback = R6Class("TorchCallback",
)
)

- #' @title Create a Callback Desctiptor
+ #' @title Create a Callback Descriptor
#'
#' @description
#' Convenience function to create a custom [`TorchCallback`].
10 changes: 7 additions & 3 deletions R/TorchDescriptor.R
@@ -37,7 +37,9 @@ TorchDescriptor = R6Class("TorchDescriptor",
#' @template param_packages
#' @template param_label
#' @template param_man
- initialize = function(generator, id = NULL, param_set = NULL, packages = NULL, label = NULL, man = NULL) {
+ #' @param additional_args (`list()` or `NULL`)\cr
+ #'   Additional construction arguments, passed on unchanged. For learning rate schedulers, this holds the scheduler generator (`.scheduler`) and `step_on_epoch`.
+ initialize = function(generator, id = NULL, param_set = NULL, packages = NULL, label = NULL, man = NULL, additional_args = NULL) {
assert_true(is.function(generator) || inherits(generator, "R6ClassGenerator"))
self$generator = generator
self$param_set = assert_r6(param_set, "ParamSet", null.ok = TRUE) %??% inferps(generator)
@@ -63,6 +65,7 @@
self$id = assert_string(id %??% class(generator)[[1L]], min.chars = 1L)
self$label = assert_string(label %??% self$id, min.chars = 1L)
self$packages = assert_names(unique(union(packages, c("torch", "mlr3torch"))), type = "strict")
+ private$.additional_args = assert_list(additional_args, null.ok = TRUE)
},
#' @description
#' Prints the object
@@ -86,9 +89,9 @@
# The torch generators could also be constructed with the $new() method, but then the return value
# would be the actual R6 class and not the wrapped function.
if (is.function(self$generator)) {
- invoke(self$generator, .args = self$param_set$get_values())
+ invoke(self$generator, .args = c(self$param_set$get_values(), private$.additional_args))
} else {
- invoke(self$generator$new, .args = self$param_set$get_values())
+ invoke(self$generator$new, .args = c(self$param_set$get_values(), private$.additional_args))
}
},
#' @description
@@ -107,6 +110,7 @@
}
),
private = list(
+ .additional_args = NULL,
.additional_phash_input = function() {
stopf("Classes inheriting from TorchDescriptor must implement the .additional_phash_input() method.")
},
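
With this change, the stored `additional_args` are spliced into the construction call next to the configured parameter values. A sketch of the resulting behavior, assuming the method shown above is the descriptor's `generate()`:

cb = t_clbk("lr_step", step_size = 30)
# roughly equivalent to:
# CallbackSetLRScheduler$new(.scheduler = torch::lr_step, step_on_epoch = TRUE, step_size = 30)
cb$generate()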
6 changes: 5 additions & 1 deletion man/TorchCallback.Rd

Some generated files are not rendered by default.

6 changes: 5 additions & 1 deletion man/TorchDescriptor.Rd

Some generated files are not rendered by default.
