Skip to content

Commit c5be198

Browse files
rajatrc1705juliohmeliascarv
authored
feat: adds optional categorical paramater as option to onehot trans… (#161)
* feat: adds optional `categorical` paramater as option to onehot transform * fix: fixes variable name spelling on line 53 * adds tests for categorical=false category * refactor: modifications for the new OneHot function * refactor: incorporates reviews on previous push * doc: summarises description of categorical parameter * refactor: categorical as keyword parameter * modifies existing tests to accomodate change in the OneHot function (only for categorical=false) * doc: summarises description of categorical parameter * Update src/transforms/onehot.jl * Update src/transforms/onehot.jl * Update src/transforms/onehot.jl * Update src/transforms/onehot.jl * Update src/transforms/onehot.jl * Update test/transforms/onehot.jl * Update test/transforms/onehot.jl * Add kwarg constructor * Fix show test * Add 'categorical=true' tests * Add levels and fix revert bug when 'categorical=true' * Change order of levels * Rename the 'categorical' keyword to 'categ'; Check if onehot columns are of type CategoricalArray{Bool} when 'categ=true' * Fix show test Co-authored-by: Júlio Hoffimann <[email protected]> Co-authored-by: Elias Carvalho <[email protected]>
1 parent b130aeb commit c5be198

File tree

3 files changed

+56
-14
lines changed

3 files changed

+56
-14
lines changed

src/transforms/onehot.jl

+15-6
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,40 @@
33
# ------------------------------------------------------------------
44

55
"""
6-
OneHot(col)
7-
6+
OneHot(col; categ=true)
7+
88
Transforms categorical column `col` into one-hot columns of levels
99
returned by the `levels` function of CategoricalArrays.jl.
10+
The `categ` option can be used to convert resulting
11+
columns to categorical arrays as opposed to boolean vectors.
1012
1113
# Examples
1214
1315
```julia
1416
OneHot(1)
1517
OneHot(:a)
1618
OneHot("a")
19+
OneHot("a", categ=false)
1720
```
1821
"""
1922
struct OneHot{S<:ColSpec} <: StatelessFeatureTransform
2023
colspec::S
21-
function OneHot(col::Col)
24+
categ::Bool
25+
function OneHot(col, categ)
2226
cs = colspec([col])
23-
new{typeof(cs)}(cs)
27+
new{typeof(cs)}(cs, categ)
2428
end
2529
end
2630

31+
OneHot(col; categ=true) = OneHot(col, categ)
32+
2733
isrevertible(::Type{<:OneHot}) = true
2834

2935
function applyfeat(transform::OneHot, feat, prep)
3036
cols = Tables.columns(feat)
3137
names = Tables.columnnames(cols) |> collect
3238
columns = Any[Tables.getcolumn(cols, nm) for nm in names]
33-
39+
3440
name = choose(transform.colspec, names)[1]
3541
ind = findfirst(==(name), names)
3642
x = columns[ind]
@@ -48,6 +54,9 @@ function applyfeat(transform::OneHot, feat, prep)
4854

4955
newnms, newcols = first.(onehot), last.(onehot)
5056

57+
# convert to categorical arrays if necessary
58+
newcols = transform.categ ? categorical.(newcols, levels=[false, true]) : newcols
59+
5160
splice!(names, ind, newnms)
5261
splice!(columns, ind, newcols)
5362

@@ -65,7 +74,7 @@ function revertfeat(::OneHot, newfeat, fcache)
6574

6675
oname, inds, levels, ordered = fcache
6776
x = map(zip(columns[inds]...)) do row
68-
levels[findfirst(row)]
77+
levels[findfirst(==(true), row)]
6978
end
7079

7180
ocolumn = categorical(x; levels, ordered)

test/shows.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,14 @@
200200

201201
# compact mode
202202
iostr = sprint(show, T)
203-
@test iostr == "OneHot([:a])"
203+
@test iostr == "OneHot([:a], true)"
204204

205205
# full mode
206206
iostr = sprint(show, MIME("text/plain"), T)
207207
@test iostr == """
208208
OneHot transform
209-
└─ colspec = [:a]"""
209+
├─ colspec = [:a]
210+
└─ categ = true"""
210211
end
211212

212213
@testset "Identity" begin

test/transforms/onehot.jl

+38-6
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,55 @@
44
c = categorical([3, 2, 2, 1, 1, 3])
55
t = Table(; a, b, c)
66

7-
T = OneHot(1)
7+
T = OneHot(1; categ=true)
8+
n, c = apply(T, t)
9+
@test Tables.columnnames(n) == (:a_false, :a_true, :b, :c)
10+
@test n.a_false == categorical(Bool[1, 0, 0, 1, 0, 0])
11+
@test n.a_true == categorical(Bool[0, 1, 1, 0, 1, 1])
12+
@test n.a_false isa CategoricalVector{Bool}
13+
@test n.a_true isa CategoricalVector{Bool}
14+
tₒ = revert(T, n, c)
15+
@test t == tₒ
16+
17+
T = OneHot(:b; categ=true)
18+
n, c = apply(T, t)
19+
@test Tables.columnnames(n) == (:a, :b_f, :b_m, :c)
20+
@test n.b_f == categorical(Bool[0, 1, 0, 0, 0, 1])
21+
@test n.b_m == categorical(Bool[1, 0, 1, 1, 1, 0])
22+
@test n.b_f isa CategoricalVector{Bool}
23+
@test n.b_m isa CategoricalVector{Bool}
24+
tₒ = revert(T, n, c)
25+
@test t == tₒ
26+
27+
T = OneHot("c"; categ=true)
28+
n, c = apply(T, t)
29+
@test Tables.columnnames(n) == (:a, :b, :c_1, :c_2, :c_3)
30+
@test n.c_1 == categorical(Bool[0, 0, 0, 1, 1, 0])
31+
@test n.c_2 == categorical(Bool[0, 1, 1, 0, 0, 0])
32+
@test n.c_3 == categorical(Bool[1, 0, 0, 0, 0, 1])
33+
@test n.c_1 isa CategoricalVector{Bool}
34+
@test n.c_2 isa CategoricalVector{Bool}
35+
@test n.c_3 isa CategoricalVector{Bool}
36+
tₒ = revert(T, n, c)
37+
@test t == tₒ
38+
39+
T = OneHot(1; categ=false)
840
n, c = apply(T, t)
941
@test Tables.columnnames(n) == (:a_false, :a_true, :b, :c)
1042
@test n.a_false == Bool[1, 0, 0, 1, 0, 0]
1143
@test n.a_true == Bool[0, 1, 1, 0, 1, 1]
1244
tₒ = revert(T, n, c)
1345
@test t == tₒ
1446

15-
T = OneHot(:b)
47+
T = OneHot(:b; categ=false)
1648
n, c = apply(T, t)
1749
@test Tables.columnnames(n) == (:a, :b_f, :b_m, :c)
1850
@test n.b_f == Bool[0, 1, 0, 0, 0, 1]
1951
@test n.b_m == Bool[1, 0, 1, 1, 1, 0]
2052
tₒ = revert(T, n, c)
2153
@test t == tₒ
2254

23-
T = OneHot("c")
55+
T = OneHot("c"; categ=false)
2456
n, c = apply(T, t)
2557
@test Tables.columnnames(n) == (:a, :b, :c_1, :c_2, :c_3)
2658
@test n.c_1 == Bool[0, 0, 0, 1, 1, 0]
@@ -35,7 +67,7 @@
3567
b_m = rand(10)
3668
t = Table(; b, b_f, b_m)
3769

38-
T = OneHot(:b)
70+
T = OneHot(:b; categ=false)
3971
n, c = apply(T, t)
4072
@test Tables.columnnames(n) == (:b_f_, :b_m_, :b_f, :b_m)
4173
@test n.b_f_ == Bool[0, 1, 0, 0, 0, 1]
@@ -50,7 +82,7 @@
5082
b_m_ = rand(10)
5183
t = Table(; b, b_f, b_m, b_f_, b_m_)
5284

53-
T = OneHot(:b)
85+
T = OneHot(:b; categ=false)
5486
n, c = apply(T, t)
5587
@test Tables.columnnames(n) == (:b_f__, :b_m__, :b_f, :b_m, :b_f_, :b_m_)
5688
@test n.b_f__ == Bool[0, 1, 0, 0, 0, 1]
@@ -70,4 +102,4 @@
70102
# invalid column selection
71103
@test_throws AssertionError apply(OneHot(:c), t)
72104
@test_throws AssertionError apply(OneHot("c"), t)
73-
end
105+
end

0 commit comments

Comments
 (0)