Skip to content

Commit e9f97db

Browse files
authored
Add support for 'Coerce(SciType)' (#271)
1 parent 1edfa1b commit e9f97db

File tree

4 files changed

+38
-4
lines changed

4 files changed

+38
-4
lines changed

src/transforms/coerce.jl

+12-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
88
Return a copy of the table, ensuring that the scientific types of the columns match the new specification.
99
10+
Coerce(S)
11+
12+
Coerce all columns of the table to the scientific type `S`.
13+
1014
This transform uses the `DataScienceTraits.coerce` function. Please see its docstring for more details.
1115
1216
# Examples
@@ -18,23 +22,28 @@ Coerce(:a => DST.Continuous, :b => DST.Continuous)
1822
Coerce("a" => DST.Continuous, "b" => DST.Continuous)
1923
```
2024
"""
21-
struct Coerce{S<:ColumnSelector} <: StatelessFeatureTransform
25+
struct Coerce{S<:ColumnSelector,T} <: StatelessFeatureTransform
2226
selector::S
23-
scitypes::Vector{DataType}
27+
scitypes::T
2428
end
2529

2630
# Zero-argument form: always rejected, because a `Coerce` transform needs
# either a scientific type or column => scientific-type pairs.
function Coerce()
  throw(ArgumentError("cannot create Coerce transform without arguments"))
end
2731

32+
# Single scientific type form: pair the given `scitype` with `AllSelector()`
# so that the same scientific type is requested for every selected column.
function Coerce(scitype::Type{<:SciType})
  return Coerce(AllSelector(), scitype)
end
33+
2834
# Pairs form: the left-hand sides of the pairs identify the columns and the
# right-hand sides are the scientific types requested for them.
function Coerce(pairs::Pair{C,DataType}...) where {C<:Column}
  colspec = first.(pairs)
  scitypes = collect(last.(pairs))
  return Coerce(selector(colspec), scitypes)
end
2935

3036
# Trait: every `Coerce` transform type is declared revertible.
function isrevertible(::Type{<:Coerce})
  return true
end
3137

38+
# Build the name => scientific type lookup when a single `scitype` is shared
# by all selected column names `snames`.
function _typedict(scitype::Type{<:SciType}, snames)
  return Dict(name => scitype for name in snames)
end
39+
# Build the name => scientific type lookup when one entry of `scitypes` is
# given per selected column name; names and types are matched positionally
# and, as with `zip`, the shorter of the two determines the length.
function _typedict(scitypes::AbstractVector, snames)
  namedtypes = zip(snames, scitypes)
  return Dict(namedtypes)
end
40+
3241
function applyfeat(transform::Coerce, feat, prep)
3342
cols = Tables.columns(feat)
3443
names = Tables.columnnames(cols)
3544
types = Tables.schema(feat).types
3645
snames = transform.selector(names)
37-
typedict = Dict(zip(snames, transform.scitypes))
46+
typedict = _typedict(transform.scitypes, snames)
3847

3948
columns = map(names) do name
4049
x = Tables.getcolumn(cols, name)

src/transforms/rename.jl

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ struct Rename{S<:ColumnSelector,N} <: StatelessFeatureTransform
3636
end
3737
end
3838

39+
# Zero-argument form: always rejected, because a `Rename` transform needs
# either a function or column => new-name pairs.
function Rename()
  throw(ArgumentError("cannot create Rename transform without arguments"))
end
40+
3941
# Single-argument form: pair `fun` with `AllSelector()`.
# NOTE(review): presumably `fun` maps an old column name to a new one;
# confirm against the type's docstring.
function Rename(fun)
  return Rename(AllSelector(), fun)
end
4042

4143
# Pairs form: the left-hand sides of the pairs identify the columns and the
# right-hand sides are the new names (symbols) given to them.
function Rename(pairs::Pair{C,Symbol}...) where {C<:Column}
  colspec = first.(pairs)
  newnames = collect(last.(pairs))
  return Rename(selector(colspec), newnames)
end

test/transforms/coerce.jl

+21
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,32 @@
3030
@test eltype(tₒ.a) == eltype(t.a)
3131
@test eltype(tₒ.b) == eltype(t.b)
3232

33+
T = Coerce(DST.Continuous)
34+
n, c = apply(T, t)
35+
@test eltype(n.a) <: Float64
36+
@test eltype(n.b) <: Float64
37+
n, c = apply(T, t)
38+
tₒ = revert(T, n, c)
39+
@test eltype(tₒ.a) == eltype(t.a)
40+
@test eltype(tₒ.b) == eltype(t.b)
41+
42+
T = Coerce(DST.Categorical)
43+
n, c = apply(T, t)
44+
@test eltype(n.a) <: Int
45+
@test eltype(n.b) <: Int
46+
n, c = apply(T, t)
47+
tₒ = revert(T, n, c)
48+
@test eltype(tₒ.a) == eltype(t.a)
49+
@test eltype(tₒ.b) == eltype(t.b)
50+
3351
# row table
3452
rt = Tables.rowtable(t)
3553
T = Coerce(:a => DST.Continuous, :b => DST.Categorical)
3654
n, c = apply(T, rt)
3755
@test Tables.isrowtable(n)
3856
rtₒ = revert(T, n, c)
3957
@test rt == rtₒ
58+
59+
# error: cannot create Coerce transform without arguments
60+
@test_throws ArgumentError Coerce()
4061
end

test/transforms/rename.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@
152152
tₒ = revert(T, n, c)
153153
@test t == tₒ
154154

155-
# throws
155+
# error: cannot create Rename transform without arguments
156+
@test_throws ArgumentError Rename()
157+
# error: new names must be unique
156158
@test_throws AssertionError Rename(:a => :x, :b => :x)
157159
@test_throws AssertionError apply(Rename(:a => :c, :b => :d), t)
158160
end

0 commit comments

Comments
 (0)