Skip to content

Commit 604ed83

Browse files
authored
Refactor Scale (#140)
* Refactor Scale * Update Scale tests * Fix tests and code style * Make low and high more generic
1 parent a075bb8 commit 604ed83

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

src/transforms/scale.jl

+7-6
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,18 @@ Scale(r"[ace]", low=0.3, high=0.7)
3535
3636
* The `low` and `high` values are restricted to the interval [0, 1].
3737
"""
38-
struct Scale{S<:ColSpec,T<:Real} <: ColwiseFeatureTransform
38+
struct Scale{S<:ColSpec,T} <: ColwiseFeatureTransform
3939
colspec::S
4040
low::T
4141
high::T
4242

43-
function Scale(colspec::S, low::T, high::T) where {S<:ColSpec,T<:Real}
43+
function Scale(colspec::S, low::T, high::T) where {S<:ColSpec,T}
4444
@assert 0 low high 1 "invalid quantiles"
4545
new{S,T}(colspec, low, high)
4646
end
4747
end
4848

49-
Scale(colspec::ColSpec, low::Real, high::Real) =
49+
Scale(colspec::ColSpec, low, high) =
5050
Scale(colspec, promote(low, high)...)
5151

5252
Scale(; low=0.25, high=0.75) = Scale(AllSpec(), low, high)
@@ -59,10 +59,11 @@ assertions(::Type{<:Scale}) = [assert_continuous]
5959
isrevertible(::Type{<:Scale}) = true
6060

6161
function colcache(transform::Scale, x)
62-
levels = (transform.low, transform.high)
63-
xl, xh = quantile(x, levels)
62+
low = convert(eltype(x), transform.low)
63+
high = convert(eltype(x), transform.high)
64+
xl, xh = quantile(x, (low, high))
6465
xl == xh && ((xl, xh) = (zero(xl), one(xh)))
65-
(xl=xl, xh=xh)
66+
(; xl, xh)
6667
end
6768

6869
colapply(::Scale, x, c) = @. (x - c.xl) / (c.xh - c.xl)

test/transforms/scale.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
# columntype does not change
4343
for FT in (Float16, Float32)
4444
t = Table(; x=rand(FT, 10))
45-
for T in (MinMax(), Scale(low=FT(0), high=FT(0.5)))
45+
for T in (MinMax(), Interquartile(), Scale(low=0, high=0.5))
4646
n, c = apply(T, t)
4747
@test Tables.columntype(t, :x) == Tables.columntype(n, :x)
4848
tₒ = revert(T, n, c)

0 commit comments

Comments
 (0)