Skip to content

Commit 1bc5fcf

Browse files
eliascarv and juliohm authored
Use the InverseFunctions.inverse function in Functional transform (#227)
* Use the 'InverseFunctions.inverse' function in Functional transform * Update tests * Update docstring * Update code --------- Co-authored-by: Júlio Hoffimann <[email protected]>
1 parent 33e8f01 commit 1bc5fcf

File tree

6 files changed

+94
-147
lines changed

6 files changed

+94
-147
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ CoDa = "5900dafe-f573-5c72-b367-76665857777b"
1010
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
1111
DataScienceTraits = "6cb2f572-2d2b-4ba6-bdb3-e710fa044d6c"
1212
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
13+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
NelderMead = "2f6b4ddb-b4ff-44c0-b59b-2ab99302f970"
1516
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
@@ -28,6 +29,7 @@ CoDa = "1.2"
2829
ColumnSelectors = "0.1"
2930
DataScienceTraits = "0.1"
3031
Distributions = "0.25"
32+
InverseFunctions = "0.1"
3133
LinearAlgebra = "1.9"
3234
NelderMead = "0.4"
3335
PrettyTables = "2"
@@ -36,6 +38,6 @@ Statistics = "1.9"
3638
StatsBase = "0.33, 0.34"
3739
Tables = "1.6"
3840
Transducers = "0.4"
39-
TransformsBase = "1.2"
41+
TransformsBase = "1.3"
4042
Unitful = "1.17"
4143
julia = "1.9"

src/TableTransforms.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ using ColumnSelectors: AllSelector, Column, selector, selectsingle
2121
using DataScienceTraits: SciType, Continuous, Categorical, coerce
2222
using Unitful: AbstractQuantity, AffineQuantity, AffineUnits, Units
2323
using Distributions: ContinuousUnivariateDistribution, Normal
24+
using InverseFunctions: NoInverse, inverse as invfun
2425
using StatsBase: AbstractWeights, Weights, sample
2526
using Transducers: tcollect
2627
using NelderMead: optimise
2728

2829
import Distributions: quantile, cdf
29-
import TransformsBase: assertions, isrevertible, preprocess
30-
import TransformsBase: apply, revert, reapply
30+
import TransformsBase: assertions, isrevertible, isinvertible
31+
import TransformsBase: apply, revert, reapply, preprocess, inverse
3132

3233
include("assertions.jl")
3334
include("tabletraits.jl")

src/transforms/functional.jl

+30-57
Original file line numberDiff line numberDiff line change
@@ -3,98 +3,71 @@
33
# ------------------------------------------------------------------
44

55
"""
6-
Functional(func)
6+
Functional(fun)
77
8-
The transform that applies a `func` elementwise.
8+
The transform that applies a `fun` elementwise.
99
10-
Functional(col₁ => func₁, col₂ => func₂, ..., colₙ => funcₙ)
10+
Functional(col₁ => fun₁, col₂ => fun₂, ..., colₙ => funₙ)
1111
12-
Apply the corresponding `funcᵢ` function to each `colᵢ` column.
12+
Apply the corresponding `funᵢ` function to each `colᵢ` column.
1313
1414
# Examples
1515
1616
```julia
17-
Functional(cos)
18-
Functional(sin)
19-
Functional(1 => cos, 2 => sin)
20-
Functional(:a => cos, :b => sin)
21-
Functional("a" => cos, "b" => sin)
17+
Functional(exp)
18+
Functional(log)
19+
Functional(1 => exp, 2 => log)
20+
Functional(:a => exp, :b => log)
21+
Functional("a" => exp, "b" => log)
2222
```
2323
"""
2424
struct Functional{S<:ColumnSelector,F} <: StatelessFeatureTransform
2525
selector::S
26-
func::F
26+
fun::F
2727
end
2828

29-
Functional(func) = Functional(AllSelector(), func)
29+
Functional(fun) = Functional(AllSelector(), fun)
3030

3131
Functional(pairs::Pair{C}...) where {C<:Column} = Functional(selector(first.(pairs)), last.(pairs))
3232

3333
Functional() = throw(ArgumentError("cannot create Functional transform without arguments"))
3434

35-
# known invertible functions
36-
inverse(::typeof(log)) = exp
37-
inverse(::typeof(exp)) = log
38-
inverse(::typeof(cos)) = acos
39-
inverse(::typeof(acos)) = cos
40-
inverse(::typeof(sin)) = asin
41-
inverse(::typeof(asin)) = sin
42-
inverse(::typeof(cosd)) = acosd
43-
inverse(::typeof(acosd)) = cosd
44-
inverse(::typeof(sind)) = asind
45-
inverse(::typeof(asind)) = sind
46-
inverse(::typeof(identity)) = identity
35+
isrevertible(transform::Functional) = isinvertible(transform)
4736

48-
# fallback to nothing
49-
inverse(::Any) = nothing
37+
_hasinverse(f) = !(invfun(f) isa NoInverse)
5038

51-
isrevertible(transform::Functional{AllSelector}) = !isnothing(inverse(transform.func))
39+
isinvertible(transform::Functional{AllSelector}) = _hasinverse(transform.fun)
40+
isinvertible(transform::Functional) = all(_hasinverse, transform.fun)
5241

53-
isrevertible(transform::Functional) = all(!isnothing, inverse.(transform.func))
42+
inverse(transform::Functional{AllSelector}) = Functional(transform.selector, invfun(transform.fun))
43+
inverse(transform::Functional) = Functional(transform.selector, invfun.(transform.fun))
5444

55-
_funcdict(func, names) = Dict(nm => func for nm in names)
56-
_funcdict(func::Tuple, names) = Dict(names .=> func)
45+
_fundict(transform::Functional{AllSelector}, names) = Dict(nm => transform.fun for nm in names)
46+
_fundict(transform::Functional, names) = Dict(zip(names, transform.fun))
5747

5848
function applyfeat(transform::Functional, feat, prep)
5949
cols = Tables.columns(feat)
6050
names = Tables.columnnames(cols)
6151
snames = transform.selector(names)
62-
funcs = _funcdict(transform.func, snames)
52+
fundict = _fundict(transform, snames)
6353

64-
columns = map(names) do nm
65-
x = Tables.getcolumn(cols, nm)
66-
if nm ∈ snames
67-
func = funcs[nm]
68-
y = func.(x)
54+
columns = map(names) do name
55+
x = Tables.getcolumn(cols, name)
56+
if name ∈ snames
57+
fun = fundict[name]
58+
map(fun, x)
6959
else
70-
y = x
60+
x
7161
end
72-
y
7362
end
7463

7564
𝒯 = (; zip(names, columns)...)
7665
newfeat = 𝒯 |> Tables.materializer(feat)
77-
return newfeat, (snames, funcs)
66+
67+
newfeat, nothing
7868
end
7969

8070
function revertfeat(transform::Functional, newfeat, fcache)
81-
cols = Tables.columns(newfeat)
82-
names = Tables.columnnames(cols)
83-
84-
snames, funcs = fcache
85-
86-
columns = map(names) do nm
87-
y = Tables.getcolumn(cols, nm)
88-
if nm ∈ snames
89-
func = funcs[nm]
90-
invfunc = inverse(func)
91-
x = invfunc.(y)
92-
else
93-
x = y
94-
end
95-
x
96-
end
97-
98-
𝒯 = (; zip(names, columns)...)
99-
𝒯 |> Tables.materializer(newfeat)
71+
ofeat, _ = applyfeat(inverse(transform), newfeat, nothing)
72+
ofeat
10073
end

test/metadata.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
mtₒ = revert(T, mn, mc)
2525
@test mtₒ == mt
2626

27-
T = Functional(sin)
27+
T = Functional(exp)
2828
mn, mc = apply(T, mt)
2929
tn, tc = apply(T, t)
3030
@test mn.meta == m
@@ -33,7 +33,7 @@
3333
@test mtₒ.meta == mt.meta
3434
@test Tables.matrix(mtₒ.table) ≈ Tables.matrix(mt.table)
3535

36-
T = (Functional(sin) → MinMax()) ⊔ Center()
36+
T = (Functional(exp) → MinMax()) ⊔ Center()
3737
mn, mc = apply(T, mt)
3838
tn, tc = apply(T, t)
3939
@test mn.meta == m
@@ -68,7 +68,7 @@
6868
mtₒ = revert(T, mn, mc)
6969
@test mtₒ == mt
7070

71-
T = Functional(cos)
71+
T = Functional(exp)
7272
mn, mc = apply(T, mt)
7373
tn, tc = apply(T, t)
7474
@test mn.meta == VarMeta(m.data .+ 2)
@@ -79,7 +79,7 @@
7979

8080
# first revertible branch has two transforms,
8181
# so metadata is increased by 2 + 2 = 4
82-
T = (Functional(sin) → MinMax()) ⊔ Center()
82+
T = (Functional(exp) → MinMax()) ⊔ Center()
8383
mn, mc = apply(T, mt)
8484
tn, tc = apply(T, t)
8585
@test mn.meta == VarMeta(m.data .+ 4)
@@ -90,7 +90,7 @@
9090

9191
# first revertible branch has one transform,
9292
# so metadata is increased by 2
93-
T = Center() ⊔ (Functional(sin) → MinMax())
93+
T = Center() ⊔ (Functional(exp) → MinMax())
9494
mn, mc = apply(T, mt)
9595
tn, tc = apply(T, t)
9696
@test mn.meta == VarMeta(m.data .+ 2)

test/shows.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -317,18 +317,18 @@
317317
end
318318

319319
@testset "Functional" begin
320-
T = Functional(sin)
320+
T = Functional(log)
321321

322322
# compact mode
323323
iostr = sprint(show, T)
324-
@test iostr == "Functional(all, sin)"
324+
@test iostr == "Functional(all, log)"
325325

326326
# full mode
327327
iostr = sprint(show, MIME("text/plain"), T)
328328
@test iostr == """
329329
Functional transform
330330
├─ selector = all
331-
└─ func = sin"""
331+
└─ fun = log"""
332332
end
333333

334334
@testset "EigenAnalysis" begin
@@ -419,31 +419,31 @@
419419
@testset "ParallelTableTransform" begin
420420
t1 = Scale(low=0.3, high=0.6)
421421
t2 = EigenAnalysis(:VDV)
422-
t3 = Functional(cos)
422+
t3 = Functional(exp)
423423
pipeline = t1 ⊔ t2 ⊔ t3
424424

425425
# compact mode
426426
iostr = sprint(show, pipeline)
427-
@test iostr == "Scale(all, 0.3, 0.6) ⊔ EigenAnalysis(:VDV, nothing, 1.0) ⊔ Functional(all, cos)"
427+
@test iostr == "Scale(all, 0.3, 0.6) ⊔ EigenAnalysis(:VDV, nothing, 1.0) ⊔ Functional(all, exp)"
428428

429429
# full mode
430430
iostr = sprint(show, MIME("text/plain"), pipeline)
431431
@test iostr == """
432432
ParallelTableTransform
433433
├─ Scale(all, 0.3, 0.6)
434434
├─ EigenAnalysis(:VDV, nothing, 1.0)
435-
└─ Functional(all, cos)"""
435+
└─ Functional(all, exp)"""
436436

437437
# parallel and sequential
438438
f1 = ZScore()
439439
f2 = Scale()
440-
f3 = Functional(cos)
440+
f3 = Functional(exp)
441441
f4 = Interquartile()
442442
pipeline = (f1 → f2) ⊔ (f3 → f4)
443443

444444
# compact mode
445445
iostr = sprint(show, pipeline)
446-
@test iostr == "ZScore(all) → Scale(all, 0.25, 0.75) ⊔ Functional(all, cos) → Scale(all, 0.25, 0.75)"
446+
@test iostr == "ZScore(all) → Scale(all, 0.25, 0.75) ⊔ Functional(all, exp) → Scale(all, 0.25, 0.75)"
447447

448448
# full mode
449449
iostr = sprint(show, MIME("text/plain"), pipeline)
@@ -453,7 +453,7 @@
453453
│ ├─ ZScore(all)
454454
│ └─ Scale(all, 0.25, 0.75)
455455
└─ SequentialTransform
456-
├─ Functional(all, cos)
456+
├─ Functional(all, exp)
457457
└─ Scale(all, 0.25, 0.75)"""
458458
end
459459
end

0 commit comments

Comments
 (0)