Skip to content

Commit c2cf150

Browse files
authored
Merge pull request #259 from JuliaML/parameters
Implement the `parameters` function for some transforms
2 parents f44e433 + e925807 commit c2cf150

18 files changed

+34
-2
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,6 @@ Statistics = "1.9"
3838
StatsBase = "0.33, 0.34"
3939
Tables = "1.6"
4040
Transducers = "0.4"
41-
TransformsBase = "1.3"
41+
TransformsBase = "1.4"
4242
Unitful = "1.17"
4343
julia = "1.9"

src/TableTransforms.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ using Transducers: tcollect
2727
using NelderMead: optimise
2828

2929
import Distributions: quantile, cdf
30-
import TransformsBase: assertions, isrevertible, isinvertible
30+
import TransformsBase: assertions, parameters, isrevertible, isinvertible
3131
import TransformsBase: apply, revert, reapply, preprocess, inverse
3232

3333
include("tabletraits.jl")

src/transforms/coalesce.jl

+2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ Coalesce(; value) = Coalesce(AllSelector(), value)
4242
Coalesce(cols; value) = Coalesce(selector(cols), value)
4343
Coalesce(cols::C...; value) where {C<:Column} = Coalesce(selector(cols), value)
4444

45+
parameters(transform::Coalesce) = (; value=transform.value)
46+
4547
isrevertible(::Type{<:Coalesce}) = false
4648

4749
colcache(::Coalesce, x) = nothing

src/transforms/dropextrema.jl

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ DropExtrema(; low=0.25, high=0.75) = DropExtrema(AllSelector(), low, high)
5050
DropExtrema(cols; low=0.25, high=0.75) = DropExtrema(selector(cols), low, high)
5151
DropExtrema(cols::C...; low=0.25, high=0.75) where {C<:Column} = DropExtrema(selector(cols), low, high)
5252

53+
parameters(transform::DropExtrema) = (low=transform.low, high=transform.high)
54+
5355
isrevertible(::Type{<:DropExtrema}) = false
5456

5557
function preprocess(transform::DropExtrema, feat)

src/transforms/eigenanalysis.jl

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ EigenAnalysis(proj; maxdim=nothing, pratio=1.0) = EigenAnalysis(proj, maxdim, pr
5454

5555
assertions(::EigenAnalysis) = [scitypeassert(Continuous)]
5656

57+
parameters(transform::EigenAnalysis) = (proj=transform.proj, maxdim=transform.maxdim, pratio=transform.pratio)
58+
5759
isrevertible(::Type{EigenAnalysis}) = true
5860

5961
function applyfeat(transform::EigenAnalysis, feat, prep)

src/transforms/indicator.jl

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ Indicator(col::Column; k=10, scale=:quantile, categ=false) = Indicator(selector(
4848

4949
assertions(transform::Indicator) = [scitypeassert(Continuous, transform.selector)]
5050

51+
parameters(transform::Indicator) = (k=transform.k, scale=transform.scale)
52+
5153
isrevertible(::Type{<:Indicator}) = true
5254

5355
function _intervals(transform::Indicator, x)

src/transforms/projectionpursuit.jl

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ ProjectionPursuit(; tol=1e-6, maxiter=100, deg=5, perc=0.9, n=100, rng=Random.GL
4949

5050
assertions(::ProjectionPursuit) = [scitypeassert(Continuous)]
5151

52+
parameters(transform::ProjectionPursuit) = (tol=transform.tol, deg=transform.deg, perc=transform.perc, n=transform.n)
53+
5254
isrevertible(::Type{<:ProjectionPursuit}) = true
5355

5456
# transforms a row of random variables into a convex combination

src/transforms/quantile.jl

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ Quantile(cols::C...; dist=Normal()) where {C<:Column} = Quantile(selector(cols),
4141

4242
assertions(transform::Quantile) = [scitypeassert(Continuous, transform.selector)]
4343

44+
parameters(transform::Quantile) = (; dist=transform.dist)
45+
4446
isrevertible(::Type{<:Quantile}) = true
4547

4648
colcache(::Quantile, x) = EmpiricalDistribution(x)

src/transforms/remainder.jl

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ isrevertible(::Type{<:Remainder}) = true
2222

2323
assertions(::Remainder) = [scitypeassert(Continuous)]
2424

25+
parameters(transform::Remainder) = (; total=transform.total)
26+
2527
function applyfeat(transform::Remainder, feat, prep)
2628
cols = Tables.columns(feat)
2729
names = Tables.columnnames(cols) |> collect

src/transforms/scale.jl

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ Scale(cols::C...; low=0.25, high=0.75) where {C<:Column} = Scale(selector(cols),
5050

5151
assertions(transform::Scale) = [scitypeassert(Continuous, transform.selector)]
5252

53+
parameters(transform::Scale) = (low=transform.low, high=transform.high)
54+
5355
isrevertible(::Type{<:Scale}) = true
5456

5557
function colcache(transform::Scale, x)

test/transforms/coalesce.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
@testset "Coalesce" begin
22
@test !isrevertible(Coalesce(value=0))
33

4+
@test TT.parameters(Coalesce(value=0)) == (; value=0)
5+
46
a = [3, 2, missing, 4, 5, 3]
57
b = [missing, 4, 4, 5, 8, 5]
68
c = [1, 1, 6, 2, 4, missing]

test/transforms/dropextrema.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
@testset "DropExtrema" begin
22
@test !isrevertible(DropExtrema(:a))
33

4+
@test TT.parameters(DropExtrema(:a)) == (low=0.25, high=0.75)
5+
46
a = [6.9, 9.0, 7.8, 0.0, 5.1, 4.8, 1.1, 8.0, 5.4, 7.9]
57
b = [7.7, 4.2, 6.3, 1.4, 4.4, 0.5, 3.0, 6.1, 1.9, 1.5]
68
c = [6.1, 7.7, 5.7, 2.8, 2.8, 6.7, 8.4, 5.0, 8.9, 1.0]

test/transforms/eigenanalysis.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
@testset "EigenAnalysis" begin
2+
@test TT.parameters(EigenAnalysis(:V)) == (proj=:V, maxdim=nothing, pratio=1.0)
3+
24
# PCA test
35
x = rand(Normal(0, 10), 1500)
46
y = x + rand(Normal(0, 2), 1500)

test/transforms/indicator.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
@testset "Indicator" begin
2+
@test TT.parameters(Indicator(:a)) == (k=10, scale=:quantile)
3+
24
a = [5.8, 6.4, 6.4, 9.8, 7.6, 8.2, 4.5, 2.5, 1.7, 2.3]
35
b = [8.4, 1.4, 7.2, 1.8, 9.4, 1.0, 2.0, 5.2, 9.4, 6.2]
46
c = [4.1, 5.6, 7.1, 9.1, 5.9, 9.5, 5.7, 9.0, 6.6, 9.9]

test/transforms/projectionpursuit.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
@testset "ProjectionPursuit" begin
2+
@test TT.parameters(ProjectionPursuit()) == (tol=1.0e-6, deg=5, perc=0.9, n=100)
3+
24
rng = MersenneTwister(42)
35
N = 10_000
46
a = [2randn(rng, N ÷ 2) .+ 6; randn(rng, N ÷ 2)]

test/transforms/quantile.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
@testset "Quantile" begin
2+
@test TT.parameters(Quantile()) == (; dist=Normal())
3+
24
t = Table(z=rand(100))
35
T = Quantile()
46
n, c = apply(T, t)

test/transforms/remainder.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
@testset "Remainder" begin
22
@test isrevertible(Remainder())
33

4+
@test TT.parameters(Remainder()) == (; total=nothing)
5+
46
a = [2.0, 66.0, 0.0]
57
b = [4.0, 22.0, 2.0]
68
c = [4.0, 12.0, 98.0]

test/transforms/scale.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
@testset "Scale" begin
2+
@test TT.parameters(Scale(:a)) == (low=0.25, high=0.75)
3+
24
# constant column
35
x = fill(3.0, 10)
46
y = rand(10)

0 commit comments

Comments
 (0)