Skip to content

Commit 4256d53

Browse files
authored
Refactor: Accept all types of categorical columns in OneHot and Levels (#218)
1 parent 482ca2e commit 4256d53

File tree

7 files changed

+94
-102
lines changed

7 files changed

+94
-102
lines changed

src/TableTransforms.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ using Random
1616
using CoDa
1717

1818
using TransformsBase: Transform, Identity,
19-
using DataScienceTraits: SciType, Continuous, coerce
2019
using ColumnSelectors: ColumnSelector, SingleColumnSelector
2120
using ColumnSelectors: AllSelector, Column, selector, selectsingle
21+
using DataScienceTraits: SciType, Continuous, Categorical, coerce
2222
using Unitful: AbstractQuantity, AffineQuantity, AffineUnits, Units
2323
using Distributions: ContinuousUnivariateDistribution, Normal
2424
using StatsBase: AbstractWeights, Weights, sample

src/assertions.jl

-24
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,3 @@ function (assertion::SciTypeAssertion{T})(table) where {T}
2525
@assert elscitype(x) <: T "the elements of the column '$nm' are not of scientific type $T"
2626
end
2727
end
28-
29-
"""
30-
ColumnTypeAssertion{T}(selector = AllSelector())
31-
32-
Asserts that the columns in the `selector` have a type `T`.
33-
"""
34-
struct ColumnTypeAssertion{T,S<:ColumnSelector}
35-
selector::S
36-
end
37-
38-
ColumnTypeAssertion{T}(selector::S) where {T,S<:ColumnSelector} = ColumnTypeAssertion{T,S}(selector)
39-
40-
ColumnTypeAssertion{T}() where {T} = ColumnTypeAssertion{T}(AllSelector())
41-
42-
function (assertion::ColumnTypeAssertion{T})(table) where {T}
43-
cols = Tables.columns(table)
44-
names = Tables.columnnames(cols)
45-
snames = assertion.selector(names)
46-
47-
for nm in snames
48-
x = Tables.getcolumn(cols, nm)
49-
@assert typeof(x) <: T "the column '$nm' is not of type $T"
50-
end
51-
end

src/transforms/levels.jl

+19-16
Original file line numberDiff line numberDiff line change
@@ -27,32 +27,35 @@ Levels(pairs::Pair{C}...; ordered=nothing) where {C<:Column} =
2727

2828
Levels(; kwargs...) = throw(ArgumentError("cannot create Levels transform without arguments"))
2929

30-
assertions(transform::Levels) = [ColumnTypeAssertion{CategoricalArray}(transform.selector)]
30+
assertions(transform::Levels) = [SciTypeAssertion{Categorical}(transform.selector)]
3131

3232
isrevertible(::Type{<:Levels}) = true
3333

34+
_revfun(x) = y -> Array(y)
35+
function _revfun(x::CategoricalArray)
36+
l, o = levels(x), isordered(x)
37+
y -> categorical(y, levels=l, ordered=o)
38+
end
39+
3440
function applyfeat(transform::Levels, feat, prep)
3541
cols = Tables.columns(feat)
3642
names = Tables.columnnames(cols)
3743
snames = transform.selector(names)
3844
ordered = transform.ordered(snames)
39-
tlevels = transform.levels
45+
leveldict = Dict(zip(snames, transform.levels))
4046

41-
results = map(names) do nm
42-
x = Tables.getcolumn(cols, nm)
47+
results = map(names) do name
48+
x = Tables.getcolumn(cols, name)
4349

44-
if nm snames
45-
o = nm ordered
46-
l = tlevels[findfirst(==(nm), snames)]
50+
if name snames
51+
o = name ordered
52+
l = leveldict[name]
4753
y = categorical(x, levels=l, ordered=o)
48-
49-
xl, xo = levels(x), isordered(x)
50-
revfunc = y -> categorical(y, levels=xl, ordered=xo)
54+
revfun = _revfun(x)
55+
y, revfun
5156
else
52-
y, revfunc = x, identity
57+
x, identity
5358
end
54-
55-
y, revfunc
5659
end
5760

5861
columns, fcache = first.(results), last.(results)
@@ -67,9 +70,9 @@ function revertfeat(::Levels, newfeat, fcache)
6770
cols = Tables.columns(newfeat)
6871
names = Tables.columnnames(cols)
6972

70-
columns = map(names, fcache) do nm, revfunc
71-
x = Tables.getcolumn(cols, nm)
72-
revfunc(x)
73+
columns = map(names, fcache) do name, revfun
74+
y = Tables.getcolumn(cols, name)
75+
revfun(y)
7376
end
7477

7578
𝒯 = (; zip(names, columns)...)

src/transforms/onehot.jl

+21-13
Original file line numberDiff line numberDiff line change
@@ -26,54 +26,62 @@ end
2626

2727
OneHot(col::Column; categ=false) = OneHot(selector(col), categ)
2828

29-
assertions(transform::OneHot) = [ColumnTypeAssertion{CategoricalArray}(transform.selector)]
29+
assertions(transform::OneHot) = [SciTypeAssertion{Categorical}(transform.selector)]
3030

3131
isrevertible(::Type{<:OneHot}) = true
3232

33+
_categ(x) = categorical(x), identity
34+
function _categ(x::CategoricalArray)
35+
l, o = levels(x), isordered(x)
36+
revfun = y -> categorical(y, levels=l, ordered=o)
37+
x, revfun
38+
end
39+
3340
function applyfeat(transform::OneHot, feat, prep)
3441
cols = Tables.columns(feat)
3542
names = Tables.columnnames(cols) |> collect
3643
columns = Any[Tables.getcolumn(cols, nm) for nm in names]
3744

3845
name = selectsingle(transform.selector, names)
3946
ind = findfirst(==(name), names)
40-
x = columns[ind]
47+
x, revfun = _categ(columns[ind])
4148

42-
xl = levels(x)
43-
onehot = map(xl) do l
49+
xlevels = levels(x)
50+
onehot = map(xlevels) do l
4451
nm = Symbol("$(name)_$l")
4552
while nm names
4653
nm = Symbol("$(nm)_")
4754
end
4855
nm, x .== l
4956
end
5057

51-
newnms, newcols = first.(onehot), last.(onehot)
58+
newnames = first.(onehot)
59+
newcolumns = last.(onehot)
5260

5361
# convert to categorical arrays if necessary
54-
newcols = transform.categ ? categorical.(newcols, levels=[false, true]) : newcols
62+
newcolumns = transform.categ ? categorical.(newcolumns, levels=[false, true]) : newcolumns
5563

56-
splice!(names, ind, newnms)
57-
splice!(columns, ind, newcols)
64+
splice!(names, ind, newnames)
65+
splice!(columns, ind, newcolumns)
5866

59-
inds = ind:(ind + length(newnms) - 1)
67+
inds = ind:(ind + length(newnames) - 1)
6068

6169
𝒯 = (; zip(names, columns)...)
6270
newfeat = 𝒯 |> Tables.materializer(feat)
63-
newfeat, (name, inds, xl, isordered(x))
71+
newfeat, (name, inds, xlevels, revfun)
6472
end
6573

6674
function revertfeat(::OneHot, newfeat, fcache)
6775
cols = Tables.columns(newfeat)
6876
names = Tables.columnnames(cols) |> collect
6977
columns = Any[Tables.getcolumn(cols, nm) for nm in names]
7078

71-
oname, inds, levels, ordered = fcache
72-
x = map(zip(columns[inds]...)) do row
79+
oname, inds, levels, revfun = fcache
80+
y = map(zip(columns[inds]...)) do row
7381
levels[findfirst(==(true), row)]
7482
end
7583

76-
ocolumn = categorical(x; levels, ordered)
84+
ocolumn = revfun(y)
7785

7886
splice!(names, inds, [oname])
7987
splice!(columns, inds, [ocolumn])

test/assertions.jl

-7
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,4 @@
2727
selector = CS.selector([:b, :e, :f])
2828
assertion = TT.SciTypeAssertion{DST.Categorical}(selector)
2929
@test_throws AssertionError assertion(table)
30-
31-
selector = CS.selector([:e, :f])
32-
assertion = TT.ColumnTypeAssertion{CategoricalArray}(selector)
33-
@test isnothing(assertion(table))
34-
selector = CS.selector([:b, :e, :f])
35-
assertion = TT.ColumnTypeAssertion{CategoricalArray}(selector)
36-
@test_throws AssertionError assertion(table)
3730
end

test/transforms/levels.jl

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
@testset "Levels" begin
2-
a = categorical(rand([true, false], 50))
3-
b = categorical(rand(["y", "n"], 50))
4-
c = categorical(rand(1:3, 50))
2+
a = Bool[1, 0, 1, 0, 1, 1]
3+
b = ["n", "y", "n", "y", "y", "y"]
4+
c = [2, 3, 1, 2, 1, 3]
55
t = Table(; a, b, c)
66

77
T = Levels(2 => ["n", "y", "m"])
88
n, c = apply(T, t)
99
@test levels(n.b) == ["n", "y", "m"]
1010
@test isordered(n.b) == false
1111
tₒ = revert(T, n, c)
12+
@test Tables.schema(tₒ) == Tables.schema(t)
1213
@test tₒ == t
1314

1415
T = Levels(:b => ["n", "y", "m"], :c => 1:4, ordered=[:c])
@@ -18,6 +19,7 @@
1819
@test levels(n.c) == [1, 2, 3, 4]
1920
@test isordered(n.c) == true
2021
tₒ = revert(T, n, c)
22+
@test Tables.schema(tₒ) == Tables.schema(t)
2123
@test tₒ == t
2224

2325
T = Levels("b" => ["n", "y", "m"], "c" => 1:4, ordered=["b"])
@@ -27,6 +29,7 @@
2729
@test levels(n.c) == [1, 2, 3, 4]
2830
@test isordered(n.c) == false
2931
tₒ = revert(T, n, c)
32+
@test Tables.schema(tₒ) == Tables.schema(t)
3033
@test tₒ == t
3134

3235
a = categorical(["yes", "no", "no", "no", "yes"])
@@ -87,9 +90,9 @@
8790
tₒ = revert(T, n, c)
8891
@test isordered(tₒ.a) == false
8992

90-
a = rand([true, false], 50)
91-
b = categorical(rand(["y", "n"], 50))
92-
c = categorical(rand(1:3, 50))
93+
a = [0.1, 0.1, 0.2, 0.2, 0.1, 0.2]
94+
b = ["n", "y", "n", "y", "y", "y"]
95+
c = [2, 3, 1, 2, 1, 3]
9396
t = Table(; a, b, c)
9497

9598
# throws: Levels without arguments
@@ -102,7 +105,7 @@
102105
@test_throws AssertionError apply(T, t)
103106

104107
# throws: non categorical column
105-
T = Levels(:a => [true, false], ordered=[:a])
108+
T = Levels(:a => [0.1, 0.2, 0.3], ordered=[:a])
106109
@test_throws AssertionError apply(T, t)
107110

108111
# throws: invalid ordered column selection

test/transforms/onehot.jl

+43-34
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,79 @@
11
@testset "OneHot" begin
2-
a = categorical(Bool[0, 1, 1, 0, 1, 1])
3-
b = categorical(["m", "f", "m", "m", "m", "f"])
4-
c = categorical([3, 2, 2, 1, 1, 3])
5-
t = Table(; a, b, c)
2+
a = Bool[0, 1, 1, 0, 1, 1]
3+
b = ["m", "f", "m", "m", "m", "f"]
4+
c = [3, 2, 2, 1, 1, 3]
5+
d = categorical(a)
6+
e = categorical(b)
7+
f = categorical(c)
8+
t = Table(; a, b, c, d, e, f)
69

710
T = OneHot(1; categ=true)
811
n, c = apply(T, t)
9-
@test Tables.columnnames(n) == (:a_false, :a_true, :b, :c)
12+
@test Tables.columnnames(n) == (:a_false, :a_true, :b, :c, :d, :e, :f)
1013
@test n.a_false == categorical(Bool[1, 0, 0, 1, 0, 0])
1114
@test n.a_true == categorical(Bool[0, 1, 1, 0, 1, 1])
1215
@test n.a_false isa CategoricalVector{Bool}
1316
@test n.a_true isa CategoricalVector{Bool}
1417
tₒ = revert(T, n, c)
15-
@test t == tₒ
18+
@test Tables.schema(tₒ) == Tables.schema(t)
19+
@test tₒ == t
1620

1721
T = OneHot(:b; categ=true)
1822
n, c = apply(T, t)
19-
@test Tables.columnnames(n) == (:a, :b_f, :b_m, :c)
23+
@test Tables.columnnames(n) == (:a, :b_f, :b_m, :c, :d, :e, :f)
2024
@test n.b_f == categorical(Bool[0, 1, 0, 0, 0, 1])
2125
@test n.b_m == categorical(Bool[1, 0, 1, 1, 1, 0])
2226
@test n.b_f isa CategoricalVector{Bool}
2327
@test n.b_m isa CategoricalVector{Bool}
2428
tₒ = revert(T, n, c)
25-
@test t == tₒ
29+
@test Tables.schema(tₒ) == Tables.schema(t)
30+
@test tₒ == t
2631

2732
T = OneHot("c"; categ=true)
2833
n, c = apply(T, t)
29-
@test Tables.columnnames(n) == (:a, :b, :c_1, :c_2, :c_3)
34+
@test Tables.columnnames(n) == (:a, :b, :c_1, :c_2, :c_3, :d, :e, :f)
3035
@test n.c_1 == categorical(Bool[0, 0, 0, 1, 1, 0])
3136
@test n.c_2 == categorical(Bool[0, 1, 1, 0, 0, 0])
3237
@test n.c_3 == categorical(Bool[1, 0, 0, 0, 0, 1])
3338
@test n.c_1 isa CategoricalVector{Bool}
3439
@test n.c_2 isa CategoricalVector{Bool}
3540
@test n.c_3 isa CategoricalVector{Bool}
3641
tₒ = revert(T, n, c)
37-
@test t == tₒ
42+
@test Tables.schema(tₒ) == Tables.schema(t)
43+
@test tₒ == t
3844

39-
T = OneHot(1; categ=false)
45+
T = OneHot(4; categ=false)
4046
n, c = apply(T, t)
41-
@test Tables.columnnames(n) == (:a_false, :a_true, :b, :c)
42-
@test n.a_false == Bool[1, 0, 0, 1, 0, 0]
43-
@test n.a_true == Bool[0, 1, 1, 0, 1, 1]
47+
@test Tables.columnnames(n) == (:a, :b, :c, :d_false, :d_true, :e, :f)
48+
@test n.d_false == Bool[1, 0, 0, 1, 0, 0]
49+
@test n.d_true == Bool[0, 1, 1, 0, 1, 1]
4450
tₒ = revert(T, n, c)
45-
@test t == tₒ
51+
@test Tables.schema(tₒ) == Tables.schema(t)
52+
@test tₒ == t
4653

47-
T = OneHot(:b; categ=false)
54+
T = OneHot(:e; categ=false)
4855
n, c = apply(T, t)
49-
@test Tables.columnnames(n) == (:a, :b_f, :b_m, :c)
50-
@test n.b_f == Bool[0, 1, 0, 0, 0, 1]
51-
@test n.b_m == Bool[1, 0, 1, 1, 1, 0]
56+
@test Tables.columnnames(n) == (:a, :b, :c, :d, :e_f, :e_m, :f)
57+
@test n.e_f == Bool[0, 1, 0, 0, 0, 1]
58+
@test n.e_m == Bool[1, 0, 1, 1, 1, 0]
5259
tₒ = revert(T, n, c)
53-
@test t == tₒ
60+
@test Tables.schema(tₒ) == Tables.schema(t)
61+
@test tₒ == t
5462

55-
T = OneHot("c"; categ=false)
63+
T = OneHot("f"; categ=false)
5664
n, c = apply(T, t)
57-
@test Tables.columnnames(n) == (:a, :b, :c_1, :c_2, :c_3)
58-
@test n.c_1 == Bool[0, 0, 0, 1, 1, 0]
59-
@test n.c_2 == Bool[0, 1, 1, 0, 0, 0]
60-
@test n.c_3 == Bool[1, 0, 0, 0, 0, 1]
65+
@test Tables.columnnames(n) == (:a, :b, :c, :d, :e, :f_1, :f_2, :f_3)
66+
@test n.f_1 == Bool[0, 0, 0, 1, 1, 0]
67+
@test n.f_2 == Bool[0, 1, 1, 0, 0, 0]
68+
@test n.f_3 == Bool[1, 0, 0, 0, 0, 1]
6169
tₒ = revert(T, n, c)
62-
@test t == tₒ
70+
@test Tables.schema(tₒ) == Tables.schema(t)
71+
@test tₒ == t
6372

6473
# name formatting
6574
b = categorical(["m", "f", "m", "m", "m", "f"])
66-
b_f = rand(10)
67-
b_m = rand(10)
75+
b_f = rand(6)
76+
b_m = rand(6)
6877
t = Table(; b, b_f, b_m)
6978

7079
T = OneHot(:b; categ=false)
@@ -76,10 +85,10 @@
7685
@test t == tₒ
7786

7887
b = categorical(["m", "f", "m", "m", "m", "f"])
79-
b_f = rand(10)
80-
b_m = rand(10)
81-
b_f_ = rand(10)
82-
b_m_ = rand(10)
88+
b_f = rand(6)
89+
b_m = rand(6)
90+
b_f_ = rand(6)
91+
b_m_ = rand(6)
8392
t = Table(; b, b_f, b_m, b_f_, b_m_)
8493

8594
T = OneHot(:b; categ=false)
@@ -91,8 +100,8 @@
91100
@test t == tₒ
92101

93102
# throws
94-
a = categorical(Bool[0, 1, 1, 0, 1, 1])
95-
b = ["m", "f", "m", "m", "m", "f"]
103+
a = Bool[0, 1, 1, 0, 1, 1]
104+
b = rand(6)
96105
t = Table(; a, b)
97106

98107
# non categorical column

0 commit comments

Comments
 (0)