Skip to content

feat: support for wavelets pipe op #149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ Imports:
tf (>= 0.3.5)
Suggests:
rpart,
tsfeatures,
testthat (>= 3.2.0),
tsfeatures,
wavelets,
withr
Remotes:
tidyfun/tf
Expand All @@ -55,6 +56,7 @@ Collate:
'PipeOpFDAScaleRange.R'
'PipeOpFDASmooth.R'
'PipeOpFDATsfeatures.R'
'PipeOpFDAWavelets.R'
'PipeOpFPCA.R'
'TaskClassif_phoneme.R'
'TaskRegr_dti.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export(PipeOpFDAInterpol)
export(PipeOpFDAScaleRange)
export(PipeOpFDASmooth)
export(PipeOpFDATsfeatures)
export(PipeOpFDAWavelets)
export(PipeOpFPCA)
import(R6)
import(checkmate)
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# mlr3fda (development version)

* mlr3fda now depends on R 4.1.0 instead of R 3.1.0 to reflect tf requiring 4.1.0
* New PipeOp: `PipeOpFDATsfeatures`
* New PipeOps:
* `PipeOpFDATsfeatures`
* `PipeOpFDAWavelets`

# mlr3fda 0.2.0

Expand Down
100 changes: 100 additions & 0 deletions R/PipeOpFDAWavelets.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#' @title Discrete Wavelet transform features
#' @name mlr_pipeops_fda.wavelets
#'
#' @description
#' This `PipeOp` extracts discrete wavelet transform coefficients from functional columns.
#' For more details, see [wavelets::dwt()], which is called internally.
#'
#' @section Parameters:
#' The parameters are the parameters inherited from [`PipeOpTaskPreprocSimple`][mlr3pipelines::PipeOpTaskPreprocSimple],
#' as well as the following parameters:
#' * `filter` :: `character(1)` | `numeric()` | [wavelets::wt.filter()]\cr
#' Specifies which filter should be used. Must be either [wavelets::wt.filter()] object, an even numeric vector or a
#' string. In case of a string must be one of `"d"`|`"la"`|`"bl"`|`"c"` followed by an even number for the level of
#' the filter. The level of the filter needs to be smaller or equal then the time-series length.
#' For more information and acceptable filters see `help(wt.filter)`. Defaults to `"la8"`.
#' * `n.levels` :: `integer(1)`\cr
#' An integer specifying the level of the decomposition.
#' * `boundary` :: `character(1)`\cr
#' Boundary to be used. `"periodic"` assumes circular time series, for `"reflection"` the series is extended to twice
#' its length. Default is `"periodic"`.
#' * `fast` :: `logical(1)`\cr
#' Should the pyramid algorithm be calculated with an internal C function? Default is `TRUE`.
#' @export
#' @examples
#' task = tsk("fuel")
#' po_wavelets = po("fda.wavelets")
#' task_wavelets = po_wavelets$train(list(task))[[1L]]
#' task_wavelets$data()
PipeOpFDAWavelets = R6Class("PipeOpFDAWavelets",
inherit = PipeOpTaskPreprocSimple,
public = list(
#' @description Initializes a new instance of this Class.
#' @param id (`character(1)`)\cr
#' Identifier of resulting object, default is `"fda.wavelets"`.
#' @param param_vals (named `list()`)\cr
#' List of hyperparameter settings, overwriting the hyperparameter settings that would
#' otherwise be set during construction. Default `list()`.
initialize = function(id = "fda.wavelets", param_vals = list()) {
param_set = ps(
filter = p_uty(
default = "la8", tags = c("train", "predict"), custom_check = crate(function(x) {
if (test_class(x, "wt.filter")) {
return(TRUE)
}
if (test_string(x)) {
choices = c(
paste0("d", c(2, 4, 6, 8, 10, 12, 14, 16, 18, 20)),
paste0("la", c(8, 10, 12, 14, 16, 18, 20)),
paste0("bl", c(14, 18, 20)),
paste0("c", c(6, 12, 18, 24, 30)),
"haar"
)
return(check_choice(x, choices))
}
if (test_numeric(x) && length(x) %% 2L == 0L) {
return(TRUE)
}
"Must be either a string, an even numeric vector or wavelet filter object"
})
),
n.levels = p_int(tags = c("train", "predict")),
boundary = p_fct(default = "periodic", c("periodic", "reflection"), tags = c("train", "predict")),
fast = p_lgl(default = TRUE, tags = c("train", "predict"))
)

super$initialize(
id = id,
param_set = param_set,
param_vals = param_vals,
packages = c("mlr3fda", "mlr3pipelines", "tf", "wavelets"),
feature_types = c("tfd_reg", "tfd_irreg"),
tags = "fda"
)
}
),

private = list(
.transform_dt = function(dt, levels) {
pars = self$param_set$get_values()
filter = pars$filter %??% "la8"

cols = imap(dt, function(x, nm) {
feats = map_dtr(
tf::tf_evaluations(x),
function(x) {
wt = invoke(wavelets::dwt, X = x, .args = pars)
feats = unlist(c(wt@W, wt@V[[wt@level]]), use.names = FALSE)
as.data.table(t(feats))
},
.fill = TRUE
)
setnames(feats, sprintf("%s_wav_%s_%i", nm, filter, seq_len(ncol(feats))))
})
setDT(unlist(unname(cols), recursive = FALSE))
}
)
)

#' @include zzz.R
register_po("fda.wavelets", PipeOpFDAWavelets)
96 changes: 96 additions & 0 deletions man/mlr_pipeops_fda.wavelets.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions tests/testthat/_snaps/PipeOpFDAWavelets.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# PipeOpFDAWavelets input validation validation

Code
po("fda.wavelets", filter = "la4")
Condition
Error in `self$assert()`:
! Assertion on 'xs' failed: filter: Must be element of set {'d2','d4','d6','d8','d10','d12','d14','d16','d18','d20','la8','la10','la12','la14','la16','la18','la20','bl14','bl18','bl20','c6','c12','c18','c24','c30','haar'}, but is 'la4'.

---

Code
po("fda.wavelets", filter = "invalid_filter")
Condition
Error in `self$assert()`:
! Assertion on 'xs' failed: filter: Must be element of set {'d2','d4','d6','d8','d10','d12','d14','d16','d18','d20','la8','la10','la12','la14','la16','la18','la20','bl14','bl18','bl20','c6','c12','c18','c24','c30','haar'}, but is 'invalid_filter'.

---

Code
po("fda.wavelets", filter = c(1, 2, 3))
Condition
Error in `self$assert()`:
! Assertion on 'xs' failed: filter: Must be either a string, an even numeric vector or wavelet filter object.

---

Code
po("fda.wavelets", filter = list("la8"))
Condition
Error in `self$assert()`:
! Assertion on 'xs' failed: filter: Must be either a string, an even numeric vector or wavelet filter object.

47 changes: 47 additions & 0 deletions tests/testthat/test_PipeOpFDAWavelets.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
test_that("PipeOpFDAWavelets - basic properties", {
pop = po("fda.wavelets")
expect_pipeop(pop)
expect_identical(pop$id, "fda.wavelets")
})

test_that("PipeOpFDAWavelets input validation validation", {
skip_if_not_installed("wavelets")
expect_no_error(po("fda.wavelets", filter = wavelets::wt.filter()))
expect_no_error(po("fda.wavelets", filter = "la8"))
expect_no_error(po("fda.wavelets", filter = 1:10))
expect_snapshot(po("fda.wavelets", filter = "la4"), error = TRUE)
expect_snapshot(po("fda.wavelets", filter = "invalid_filter"), error = TRUE)
expect_snapshot(po("fda.wavelets", filter = c(1, 2, 3)), error = TRUE)
expect_snapshot(po("fda.wavelets", filter = list("la8")), error = TRUE)
})

test_that("PipeOpFDAWavelets works", {
skip_if_not_installed("wavelets")
task = tsk("fuel")

pop = po("fda.wavelets")
task_wav = train_pipeop(pop, list(task))[[1L]]
new_data = task_wav$data()
expect_task(task_wav)
expect_identical(dim(new_data), c(task$nrow, 362L))
expect_match(setdiff(names(new_data), c("heatan", "h20")), "_wav_la8_[0-9]+$")

pop = po("fda.wavelets", filter = "haar", boundary = "reflection")
task_wav = train_pipeop(pop, list(task))[[1L]]
new_data = task_wav$data()
expect_task(task_wav)
walk(new_data, expect_numeric)
expect_identical(dim(new_data), c(task$nrow, 726L))
expect_match(setdiff(names(new_data), c("heatan", "h20")), "_wav_haar_[0-9]+$")

# irregular data works
task = tsk("dti")
task$select(setdiff(task$feature_names, "sex"))
pop = po("fda.wavelets")
task_wav = train_pipeop(pop, list(task))[[1L]]
new_data = task_wav$data()
expect_task(task_wav)
walk(new_data, expect_numeric)
expect_identical(dim(new_data), c(task$nrow, 144L))
expect_match(setdiff(names(new_data), c("pasat")), "_wav_la8_[0-9]+$")
})