feat/callback-lr_schedule #317
Merged
Changes from all commits (48 commits)
574da40
started docs for callback
cxzhang4 b2ec70c
Merge branch 'main' into feat/callback-lr_schedule
cxzhang4 795c0ba
ough outline
cxzhang4 36fa621
add test file
cxzhang4 f03a741
pre-christmas work from laptop
cxzhang4 6664ec4
very simple change
cxzhang4 e25c623
idrk
cxzhang4 58ff44a
fewioj
cxzhang4 f4d46ee
typed out all pre-implemented schedulers. TODO: implement params that…
cxzhang4 31aa858
a comment
cxzhang4 1ff1e88
check class fn
cxzhang4 3702ff6
awejoifeawjoi
cxzhang4 b5913c1
comment
cxzhang4 43994d3
trains for cosine annealing, still need to test
cxzhang4 3899452
modified check_class_or_list to not accept vectors
cxzhang4 71ad99f
some tests from Claude, not run yet
cxzhang4 78f0b8b
LLM tests fail
cxzhang4 5d96fab
sketch tests, incomplete
cxzhang4 0f21732
Update R/CallbackSetLRScheduler.R
cxzhang4 539b36d
named additional arg for all torch-provided schedulers
cxzhang4 34f0845
playing around with stuff
cxzhang4 a94963d
weoiweoij
cxzhang4 d808346
accept incoming
cxzhang4 e1cac87
tests look ok, still need to debug implementation and document
cxzhang4 4d677b8
ewfojiewfoij
cxzhang4 ac1af70
idk how to have a named arg for the scheduler_fn and also support an …
cxzhang4 731b0c6
stuff
cxzhang4 c9432db
tests still don't work, .scheduler seems to be expected and the custo…
cxzhang4 aa74077
passes veryyy simple tests
cxzhang4 210de81
note to self
cxzhang4 3844b6e
change callbacksetlrscheduler to use additional_args$.scheduler inste…
cxzhang4 9acd639
functionally ok based on the two simple tests
cxzhang4 6046725
looks ok?
cxzhang4 216c937
began pr issues
cxzhang4 d548ef7
Merge branch 'main' into feat/callback-lr_schedule
cxzhang4 a6b8467
more revisions from PR
cxzhang4 826f984
tests passing
cxzhang4 d7a834d
delete attic file
cxzhang4 d311c85
allow custom id
cxzhang4 327ada3
chatgpt says the large paramsets look ok
cxzhang4 c222202
Merge branch 'main' into feat/callback-lr_schedule
cxzhang4 668fa6b
test passes but does not allow setting the parameter later
cxzhang4 5542c73
Update R/CallbackSetLRScheduler.R
cxzhang4 7b34465
comment on test:
cxzhang4 254b8cc
removed unnecessary test
cxzhang4 58a20d4
address comments from (hopefully final) review
cxzhang4 53c4002
Update R/CallbackSetLRScheduler.R
sebffischer 6d375ca
docs(news): update changelog
sebffischer
R/CallbackSetLRScheduler.R (new file, +188 lines):

```r
#' @title Learning Rate Scheduling Callback
#'
#' @name mlr_callback_set.lr_scheduler
#'
#' @description
#' Changes the learning rate based on the schedule specified by a `torch::lr_scheduler`.
#'
#' As of this writing, the following are available: [torch::lr_cosine_annealing()],
#' [torch::lr_lambda()], [torch::lr_multiplicative()], [torch::lr_one_cycle()],
#' [torch::lr_reduce_on_plateau()], [torch::lr_step()], and custom schedulers defined
#' with [torch::lr_scheduler()].
#'
#' @param .scheduler (`lr_scheduler_generator`)\cr
#'   The `torch` scheduler generator (e.g. `torch::lr_step`).
#' @param ... (any)\cr
#'   The scheduler-specific arguments.
#'
#' @export
CallbackSetLRScheduler = R6Class("CallbackSetLRScheduler",
  inherit = CallbackSet,
  lock_objects = FALSE,
  public = list(
    #' @field scheduler_fn (`lr_scheduler_generator`)\cr
    #'   The `torch` function that creates a learning rate scheduler.
    scheduler_fn = NULL,
    #' @field scheduler (`LRScheduler`)\cr
    #'   The learning rate scheduler wrapped by this callback.
    scheduler = NULL,
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function(.scheduler, step_on_epoch, ...) {
      assert_class(.scheduler, "lr_scheduler_generator")
      assert_flag(step_on_epoch)

      self$scheduler_fn = .scheduler
      private$.scheduler_args = list(...)

      # step either once per epoch or once per batch
      if (step_on_epoch) {
        self$on_epoch_end = function() self$scheduler$step()
      } else {
        self$on_batch_end = function() self$scheduler$step()
      }
    },
    #' @description
    #' Creates the scheduler using the optimizer from the context.
    on_begin = function() {
      self$scheduler = invoke(self$scheduler_fn, optimizer = self$ctx$optimizer, .args = private$.scheduler_args)
    }
  ),
  private = list(
    .scheduler_args = NULL
  )
)

# some of the schedulers accept lists
# so they can treat different parameter groups differently
check_class_or_list = function(x, classname) {
  if (is.list(x)) check_list(x, types = classname) else check_class(x, classname)
}
```
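The helper above can be exercised on its own; a minimal sketch, assuming only the checkmate package (which provides `check_list()` and `check_class()` and returns either `TRUE` or an error-message string):

```r
library(checkmate)

# same helper as above: a scalar check, or an element-wise check for
# per-parameter-group values
check_class_or_list = function(x, classname) {
  if (is.list(x)) check_list(x, types = classname) else check_class(x, classname)
}

check_class_or_list(0.9, "numeric")             # TRUE: a single value
check_class_or_list(list(0.85, 0.9), "numeric") # TRUE: one value per parameter group
check_class_or_list(c(0.85, 0.9), "list")       # an error-message string, not TRUE
```

This matches the commit "modified check_class_or_list to not accept vectors": a plain numeric vector fails the list branch, so per-group values must be passed as a list.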
```r
#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_cosine_annealing", function() {
  TorchCallback$new(
    callback_generator = CallbackSetLRScheduler,
    param_set = ps(
      T_max = p_int(tags = c("train", "required")),
      eta_min = p_dbl(default = 0, tags = "train"),
      last_epoch = p_int(default = -1, tags = "train"),
      verbose = p_lgl(default = FALSE, tags = "train")
    ),
    id = "lr_cosine_annealing",
    label = "Cosine Annealing LR Scheduler",
    man = "mlr3torch::mlr_callback_set.lr_scheduler",
    additional_args = list(.scheduler = torch::lr_cosine_annealing, step_on_epoch = TRUE)
  )
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_lambda", function() {
  TorchCallback$new(
    callback_generator = CallbackSetLRScheduler,
    param_set = ps(
      lr_lambda = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "function")),
      last_epoch = p_int(default = -1, tags = "train"),
      verbose = p_lgl(default = FALSE, tags = "train")
    ),
    id = "lr_lambda",
    label = "Multiplication by Function LR Scheduler",
    man = "mlr3torch::mlr_callback_set.lr_scheduler",
    additional_args = list(.scheduler = torch::lr_lambda, step_on_epoch = TRUE)
  )
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_multiplicative", function() {
  TorchCallback$new(
    callback_generator = CallbackSetLRScheduler,
    param_set = ps(
      lr_lambda = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "function")),
      last_epoch = p_int(default = -1, tags = "train"),
      verbose = p_lgl(default = FALSE, tags = "train")
    ),
    id = "lr_multiplicative",
    label = "Multiplication by Factor LR Scheduler",
    man = "mlr3torch::mlr_callback_set.lr_scheduler",
    additional_args = list(.scheduler = torch::lr_multiplicative, step_on_epoch = TRUE)
  )
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_one_cycle", function() {
  TorchCallback$new(
    callback_generator = CallbackSetLRScheduler,
    param_set = ps(
      max_lr = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "numeric")),
      total_steps = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
      epochs = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
      steps_per_epoch = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
      pct_start = p_dbl(default = 0.3, tags = "train"),
      anneal_strategy = p_fct(default = "cos", levels = c("cos", "linear"), tags = "train"), # this is a string in the torch fn
      cycle_momentum = p_lgl(default = TRUE, tags = "train"),
      base_momentum = p_uty(default = 0.85, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
      max_momentum = p_uty(default = 0.95, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
      div_factor = p_dbl(default = 25, tags = "train"),
      final_div_factor = p_dbl(default = 1e4, tags = "train"),
      verbose = p_lgl(default = FALSE, tags = "train")
    ),
    id = "lr_one_cycle",
    label = "1cycle LR Scheduler",
    man = "mlr3torch::mlr_callback_set.lr_scheduler",
    additional_args = list(.scheduler = torch::lr_one_cycle, step_on_epoch = FALSE)
  )
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_reduce_on_plateau", function() {
  TorchCallback$new(
    callback_generator = CallbackSetLRScheduler,
    param_set = ps(
      mode = p_fct(default = "min", levels = c("min", "max"), tags = "train"),
      factor = p_dbl(default = 0.1, tags = "train"),
      patience = p_int(default = 10, tags = "train"),
      threshold = p_dbl(default = 1e-04, tags = "train"),
      threshold_mode = p_fct(default = "rel", levels = c("rel", "abs"), tags = "train"),
      cooldown = p_int(default = 0, tags = "train"),
      min_lr = p_uty(default = 0, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
      eps = p_dbl(default = 1e-08, tags = "train"),
      verbose = p_lgl(default = FALSE, tags = "train")
    ),
    id = "lr_reduce_on_plateau",
    label = "Reduce on Plateau LR Scheduler",
    man = "mlr3torch::mlr_callback_set.lr_scheduler",
    additional_args = list(.scheduler = torch::lr_reduce_on_plateau, step_on_epoch = TRUE)
  )
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_step", function() {
  TorchCallback$new(
    callback_generator = CallbackSetLRScheduler,
    param_set = ps(
      step_size = p_int(tags = c("train", "required")),
      gamma = p_dbl(default = 0.1, tags = "train"),
      last_epoch = p_int(default = -1, tags = "train")
    ),
    id = "lr_step",
    label = "Step Decay LR Scheduler",
    man = "mlr3torch::mlr_callback_set.lr_scheduler",
    additional_args = list(.scheduler = torch::lr_step, step_on_epoch = TRUE)
  )
})
```
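Once registered, these callbacks can be pulled from the dictionary like any other mlr3torch callback. A minimal usage sketch, assuming mlr3torch is installed and that `t_clbk()` and `lrn("classif.mlp")` behave as elsewhere in the package (the task and hyperparameter values below are illustrative only):

```r
library(mlr3torch)

# retrieve the cosine annealing callback and set its required T_max
learner = lrn("classif.mlp",
  epochs = 10, batch_size = 32,
  callbacks = t_clbk("lr_cosine_annealing", T_max = 10)
)
learner$train(tsk("iris"))
```

Since `step_on_epoch = TRUE` is baked into this callback's `additional_args`, the scheduler steps once per epoch; `lr_one_cycle` is registered with `step_on_epoch = FALSE` and therefore steps once per batch.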
```r
#' @param x (`function`)\cr
#'   The `torch` scheduler generator defined using `torch::lr_scheduler()`.
#' @param step_on_epoch (`logical(1)`)\cr
#'   Whether the scheduler steps after every epoch.
as_lr_scheduler = function(x, step_on_epoch) {
  assert_class(x, "lr_scheduler_generator")
  assert_flag(step_on_epoch)

  class_name = class(x)[1L]

  TorchCallback$new(
    callback_generator = CallbackSetLRScheduler,
    param_set = inferps(x),
    id = if (class_name == "") "lr_custom" else class_name,
    label = "Custom LR Scheduler",
    man = "mlr3torch::mlr_callback_set.lr_scheduler",
    additional_args = list(.scheduler = x, step_on_epoch = step_on_epoch)
  )
}
```
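For schedulers not shipped with torch, `as_lr_scheduler()` wraps a generator created by `torch::lr_scheduler()`. A hedged sketch, loosely following the `torch::lr_scheduler()` documentation; the class body and decay logic are illustrative assumptions, not part of this PR:

```r
library(torch)
library(mlr3torch)

# a custom step-decay scheduler: multiply each group's lr by gamma
# every step_size epochs
lr_simple_decay = lr_scheduler(
  "lr_simple_decay",
  initialize = function(optimizer, step_size = 5, gamma = 0.5, last_epoch = -1) {
    self$step_size = step_size
    self$gamma = gamma
    super$initialize(optimizer, last_epoch)
  },
  get_lr = function() {
    lrs = sapply(self$optimizer$param_groups, function(g) g$lr)
    if (self$last_epoch > 0 && self$last_epoch %% self$step_size == 0) {
      lrs = lrs * self$gamma
    }
    lrs
  }
)

# wrap it as a TorchCallback; the id is inferred from the class name,
# and the parameter set from the generator's formals via inferps()
callback = as_lr_scheduler(lr_simple_decay, step_on_epoch = TRUE)
```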