Skip to content

Commit a2e01f9

Browse files
committed
perf(scalerange): replace map_dtc() with direct data.table := calls
1 parent 1674233 commit a2e01f9

File tree

1 file changed

+30
-26
lines changed

1 file changed

+30
-26
lines changed

R/PipeOpFDAScaleRange.R

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,36 +51,40 @@ PipeOpFDAScaleRange = R6Class("PipeOpFDAScaleRange",
5151
.train_dt = function(dt, levels, target) {
5252
pars = self$param_set$get_values(tags = "train")
5353

54-
imap_dtc(dt, function(x, nm) {
55-
domain = tf::tf_domain(x)
56-
scale = (pars$upper - pars$lower) / (domain[2L] - domain[1L])
57-
offset = -domain[1L] * scale + pars$lower
58-
self$state[[nm]] = list(domain = domain, scale = scale, offset = offset)
54+
dt[,
55+
names(.SD) := imap(.SD, function(x, nm) {
56+
domain = tf::tf_domain(x)
57+
scale = (pars$upper - pars$lower) / (domain[2L] - domain[1L])
58+
offset = -domain[1L] * scale + pars$lower
59+
self$state[[nm]] = list(domain = domain, scale = scale, offset = offset)
5960

60-
args = tf::tf_arg(x)
61-
if (tf::is_reg(x)) {
62-
new_args = offset + args * scale
63-
} else {
64-
new_args = map(args, function(arg) offset + arg * scale)
65-
}
66-
invoke(tf::tfd, data = tf::tf_evaluations(x), arg = new_args)
67-
})
61+
args = tf::tf_arg(x)
62+
if (tf::is_reg(x)) {
63+
new_args = offset + args * scale
64+
} else {
65+
new_args = map(args, function(arg) offset + arg * scale)
66+
}
67+
invoke(tf::tfd, data = tf::tf_evaluations(x), arg = new_args)
68+
})
69+
]
6870
},
6971

7072
.predict_dt = function(dt, levels) {
71-
imap_dtc(dt, function(x, nm) {
72-
trafo = self$state[[nm]]
73-
if (!all(trafo$domain == tf::tf_domain(x))) {
74-
stopf("Domain of new data does not match the domain of the training data.")
75-
}
76-
args = tf::tf_arg(x)
77-
if (tf::is_reg(x)) {
78-
new_args = trafo$offset + args * trafo$scale
79-
} else {
80-
new_args = map(args, function(arg) trafo$offset + arg * trafo$scale)
81-
}
82-
invoke(tf::tfd, data = tf::tf_evaluations(x), arg = new_args)
83-
})
73+
dt[,
74+
names(.SD) := imap(.SD, function(x, nm) {
75+
trafo = self$state[[nm]]
76+
if (!all(trafo$domain == tf::tf_domain(x))) {
77+
stopf("Domain of new data does not match the domain of the training data.")
78+
}
79+
args = tf::tf_arg(x)
80+
if (tf::is_reg(x)) {
81+
new_args = trafo$offset + args * trafo$scale
82+
} else {
83+
new_args = map(args, function(arg) trafo$offset + arg * trafo$scale)
84+
}
85+
invoke(tf::tfd, data = tf::tf_evaluations(x), arg = new_args)
86+
})
87+
]
8488
}
8589
)
8690
)

0 commit comments

Comments
 (0)