Skip to content

Commit cfbf7ef

Browse files
eliascarvjuliohm
andauthored
Add Indicator transform (#181)
* Add 'Indicator' transform * Update code * Fix typo * Apply suggestions * Add tests * [WIP] Add docstring * Apply suggestions from code review Co-authored-by: Júlio Hoffimann <[email protected]> * Apply suggestions * Update the default value of 'k' * Update docstring * Apply suggestions from code review Co-authored-by: Júlio Hoffimann <[email protected]> * Add Indicator to docs * Update src/transforms/indicator.jl --------- Co-authored-by: Júlio Hoffimann <[email protected]>
1 parent 10e00ef commit cfbf7ef

File tree

7 files changed

+219
-1
lines changed

7 files changed

+219
-1
lines changed

docs/src/transforms.md

+6
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ Coerce
7474
Levels
7575
```
7676

77+
## Indicator
78+
79+
```@docs
80+
Indicator
81+
```
82+
7783
## OneHot
7884

7985
```@docs

src/TableTransforms.jl

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ export
5757
Coalesce,
5858
Coerce,
5959
Levels,
60+
Indicator,
6061
OneHot,
6162
Identity,
6263
Center,

src/transforms.jl

+1
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ include("transforms/replace.jl")
284284
include("transforms/coalesce.jl")
285285
include("transforms/coerce.jl")
286286
include("transforms/levels.jl")
287+
include("transforms/indicator.jl")
287288
include("transforms/onehot.jl")
288289
include("transforms/center.jl")
289290
include("transforms/scale.jl")

src/transforms/indicator.jl

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
const SCALES = [:quantile, :linear]
2+
3+
"""
4+
Indicator(col; k=10, scale=:quantile, categ=false)
5+
6+
Transforms continuous variable into `k` indicator variables defined by
7+
half-intervals of `col` values in a given `scale`. Optionally, specify the `categ`
8+
option to return binary categorical values as opposed to raw 1s and 0s.
9+
10+
Given a sequence of increasing threshold values `t1 < t2 < ... < tk`, the indicator
11+
transform converts a continuous variable `Z` into a sequence of `k` variables
12+
`Z_1 = Z <= t1`, `Z_2 = Z <= t2`, ..., `Z_k = Z <= tk`.
13+
14+
## Scales:
15+
16+
* `:quantile` - threshold values are calculated using the `quantile(Z, p)` function
17+
with a linear range of probabilities.
18+
* `:linear` - threshold values are calculated using a linear range.
19+
20+
# Examples
21+
22+
```julia
23+
Indicator(1, k=3)
24+
Indicator(:a, k=6, scale=:linear)
25+
Indicator("a", k=9, scale=:linear, categ=true)
26+
```
27+
"""
28+
struct Indicator{S<:ColSpec} <: StatelessFeatureTransform
29+
colspec::S
30+
k::Int
31+
scale::Symbol
32+
categ::Bool
33+
34+
function Indicator(col, k, scale, categ)
35+
if k < 1
36+
throw(ArgumentError("`k` must be greater than or equal to 1"))
37+
end
38+
39+
if scale SCALES
40+
throw(ArgumentError("invalid `scale` option, use `:quantile` or `:linear`"))
41+
end
42+
43+
cs = colspec([col])
44+
new{typeof(cs)}(cs, k, scale, categ)
45+
end
46+
end
47+
48+
Indicator(col; k=10, scale=:quantile, categ=false) = Indicator(col, k, scale, categ)
49+
50+
assertions(transform::Indicator) = [SciTypeAssertion{Continuous}(transform.colspec)]
51+
52+
isrevertible(::Type{<:Indicator}) = true
53+
54+
function _intervals(transform::Indicator, x)
55+
k = transform.k
56+
ts = if transform.scale === :quantile
57+
quantile(x, range(0, 1, k + 1))
58+
else
59+
range(extrema(x)..., k + 1)
60+
end
61+
ts[(begin + 1):end]
62+
end
63+
64+
function applyfeat(transform::Indicator, feat, prep)
65+
cols = Tables.columns(feat)
66+
names = Tables.columnnames(cols) |> collect
67+
columns = Any[Tables.getcolumn(cols, nm) for nm in names]
68+
69+
name = choose(transform.colspec, names) |> first
70+
ind = findfirst(==(name), names)
71+
x = columns[ind]
72+
73+
k = transform.k
74+
ts = _intervals(transform, x)
75+
tuples = map(1:k) do i
76+
nm = Symbol("$(name)_$i")
77+
while nm names
78+
nm = Symbol("$(nm)_")
79+
end
80+
(nm, x .≤ ts[i])
81+
end
82+
83+
newnames = first.(tuples)
84+
newcolumns = last.(tuples)
85+
86+
# convert to categorical arrays if necessary
87+
newcolumns = transform.categ ? categorical.(newcolumns, levels=[false, true]) : newcolumns
88+
89+
splice!(names, ind, newnames)
90+
splice!(columns, ind, newcolumns)
91+
92+
inds = ind:(ind + length(newnames) - 1)
93+
94+
𝒯 = (; zip(names, columns)...)
95+
newfeat = 𝒯 |> Tables.materializer(feat)
96+
newfeat, (name, x, inds)
97+
end
98+
99+
function revertfeat(::Indicator, newfeat, fcache)
100+
cols = Tables.columns(newfeat)
101+
names = Tables.columnnames(cols) |> collect
102+
columns = Any[Tables.getcolumn(cols, nm) for nm in names]
103+
104+
oname, ocolumn, inds = fcache
105+
106+
splice!(names, inds, [oname])
107+
splice!(columns, inds, [ocolumn])
108+
109+
𝒯 = (; zip(names, columns)...)
110+
𝒯 |> Tables.materializer(newfeat)
111+
end

src/transforms/onehot.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function applyfeat(transform::OneHot, feat, prep)
3939
names = Tables.columnnames(cols) |> collect
4040
columns = Any[Tables.getcolumn(cols, nm) for nm in names]
4141

42-
name = choose(transform.colspec, names)[1]
42+
name = choose(transform.colspec, names) |> first
4343
ind = findfirst(==(name), names)
4444
x = columns[ind]
4545

test/transforms.jl

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ transformfiles = [
99
"coalesce.jl",
1010
"coerce.jl",
1111
"levels.jl",
12+
"indicator.jl",
1213
"onehot.jl",
1314
"identity.jl",
1415
"center.jl",

test/transforms/indicator.jl

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
@testset "Indicator" begin
2+
a = [5.8, 6.4, 6.4, 9.8, 7.6, 8.2, 4.5, 2.5, 1.7, 2.3]
3+
b = [8.4, 1.4, 7.2, 1.8, 9.4, 1.0, 2.0, 5.2, 9.4, 6.2]
4+
c = [4.1, 5.6, 7.1, 9.1, 5.9, 9.5, 5.7, 9.0, 6.6, 9.9]
5+
d = [7.5, 2.2, 1.6, 2.8, 1.2, 1.5, 3.7, 2.0, 8.3, 8.2]
6+
t = Table(; a, b, c, d)
7+
8+
T = Indicator(:a, k=1, scale=:quantile)
9+
n, c = apply(T, t)
10+
@test Tables.columnnames(n) == (:a_1, :b, :c, :d)
11+
@test n.a_1 == Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
12+
@test n.a_1 isa BitVector
13+
tₒ = revert(T, n, c)
14+
@test t == tₒ
15+
16+
T = Indicator(:b, k=2, scale=:quantile)
17+
n, c = apply(T, t)
18+
@test Tables.columnnames(n) == (:a, :b_1, :b_2, :c, :d)
19+
@test n.b_1 == Bool[0, 1, 0, 1, 0, 1, 1, 1, 0, 0]
20+
@test n.b_2 == Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
21+
@test n.b_1 isa BitVector
22+
@test n.b_2 isa BitVector
23+
tₒ = revert(T, n, c)
24+
@test t == tₒ
25+
26+
T = Indicator(:c, k=3, scale=:quantile, categ=true)
27+
n, c = apply(T, t)
28+
@test Tables.columnnames(n) == (:a, :b, :c_1, :c_2, :c_3, :d)
29+
@test n.c_1 == categorical(Bool[1, 1, 0, 0, 1, 0, 1, 0, 0, 0])
30+
@test n.c_2 == categorical(Bool[1, 1, 1, 0, 1, 0, 1, 0, 1, 0])
31+
@test n.c_3 == categorical(Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
32+
@test n.c_1 isa CategoricalVector{Bool}
33+
@test n.c_2 isa CategoricalVector{Bool}
34+
@test n.c_3 isa CategoricalVector{Bool}
35+
tₒ = revert(T, n, c)
36+
@test t == tₒ
37+
38+
T = Indicator(:d, k=4, scale=:quantile, categ=true)
39+
n, c = apply(T, t)
40+
@test Tables.columnnames(n) == (:a, :b, :c, :d_1, :d_2, :d_3, :d_4)
41+
@test n.d_1 == categorical(Bool[0, 0, 1, 0, 1, 1, 0, 0, 0, 0])
42+
@test n.d_2 == categorical(Bool[0, 1, 1, 0, 1, 1, 0, 1, 0, 0])
43+
@test n.d_3 == categorical(Bool[0, 1, 1, 1, 1, 1, 1, 1, 0, 0])
44+
@test n.d_4 == categorical(Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
45+
@test n.d_1 isa CategoricalVector{Bool}
46+
@test n.d_2 isa CategoricalVector{Bool}
47+
@test n.d_3 isa CategoricalVector{Bool}
48+
@test n.d_4 isa CategoricalVector{Bool}
49+
tₒ = revert(T, n, c)
50+
@test t == tₒ
51+
52+
T = Indicator(:a, k=1, scale=:linear)
53+
n, c = apply(T, t)
54+
@test Tables.columnnames(n) == (:a_1, :b, :c, :d)
55+
@test n.a_1 == Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
56+
@test n.a_1 isa BitVector
57+
tₒ = revert(T, n, c)
58+
@test t == tₒ
59+
60+
T = Indicator(:b, k=2, scale=:linear)
61+
n, c = apply(T, t)
62+
@test Tables.columnnames(n) == (:a, :b_1, :b_2, :c, :d)
63+
@test n.b_1 == Bool[0, 1, 0, 1, 0, 1, 1, 1, 0, 0]
64+
@test n.b_2 == Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
65+
@test n.b_1 isa BitVector
66+
@test n.b_2 isa BitVector
67+
tₒ = revert(T, n, c)
68+
@test t == tₒ
69+
70+
T = Indicator(:c, k=3, scale=:linear, categ=true)
71+
n, c = apply(T, t)
72+
@test Tables.columnnames(n) == (:a, :b, :c_1, :c_2, :c_3, :d)
73+
@test n.c_1 == categorical(Bool[1, 1, 0, 0, 1, 0, 1, 0, 0, 0])
74+
@test n.c_2 == categorical(Bool[1, 1, 1, 0, 1, 0, 1, 0, 1, 0])
75+
@test n.c_3 == categorical(Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
76+
@test n.c_1 isa CategoricalVector{Bool}
77+
@test n.c_2 isa CategoricalVector{Bool}
78+
@test n.c_3 isa CategoricalVector{Bool}
79+
tₒ = revert(T, n, c)
80+
@test t == tₒ
81+
82+
T = Indicator(:d, k=4, scale=:linear, categ=true)
83+
n, c = apply(T, t)
84+
@test Tables.columnnames(n) == (:a, :b, :c, :d_1, :d_2, :d_3, :d_4)
85+
@test n.d_1 == categorical(Bool[0, 1, 1, 1, 1, 1, 0, 1, 0, 0])
86+
@test n.d_2 == categorical(Bool[0, 1, 1, 1, 1, 1, 1, 1, 0, 0])
87+
@test n.d_3 == categorical(Bool[0, 1, 1, 1, 1, 1, 1, 1, 0, 0])
88+
@test n.d_4 == categorical(Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
89+
@test n.d_1 isa CategoricalVector{Bool}
90+
@test n.d_2 isa CategoricalVector{Bool}
91+
@test n.d_3 isa CategoricalVector{Bool}
92+
@test n.d_4 isa CategoricalVector{Bool}
93+
tₒ = revert(T, n, c)
94+
@test t == tₒ
95+
96+
@test_throws ArgumentError Indicator(:a, k=0)
97+
@test_throws ArgumentError Indicator(:a, scale=:test)
98+
end

0 commit comments

Comments
 (0)