Skip to content

Commit 1ab1fe9

Browse files
authored
Fix revert of logratio transforms (#296)
1 parent 8758da8 commit 1ab1fe9

File tree

2 files changed

+54
-30
lines changed

2 files changed

+54
-30
lines changed

src/transforms/logratio.jl

+29-19
Original file line numberDiff line numberDiff line change
@@ -23,56 +23,66 @@ assertions(::LogRatio) = [scitypeassert(Continuous)]
2323

2424
function applyfeat(transform::LogRatio, feat, prep)
2525
cols = Tables.columns(feat)
26-
onames = Tables.columnnames(cols)
27-
varnames = collect(onames)
26+
names = Tables.columnnames(cols)
27+
vars = collect(names)
28+
29+
# perform closure for full revertibility
30+
cfeat, ccache = apply(Closure(), feat)
2831

2932
# reference variable
30-
rvar = refvar(transform, varnames)
31-
_assert(rvar varnames, "invalid reference variable")
32-
rind = findfirst(==(rvar), varnames)
33+
rvar = refvar(transform, vars)
34+
_assert(rvar vars, "invalid reference variable")
35+
36+
# reference index
37+
rind = findfirst(==(rvar), vars)
3338

3439
# permute columns if necessary
35-
perm = rind lastindex(varnames)
40+
perm = rind lastindex(vars)
3641
pfeat = if perm
37-
popat!(varnames, rind)
38-
push!(varnames, rvar)
39-
feat |> Select(varnames)
42+
popat!(vars, rind)
43+
push!(vars, rvar)
44+
cfeat |> Select(vars)
4045
else
41-
feat
46+
cfeat
4247
end
4348

4449
# apply transform
4550
X = Tables.matrix(pfeat)
4651
Y = applymatrix(transform, X)
4752

4853
# new variable names
49-
newnames = newvars(transform, varnames)
54+
newnames = newvars(transform, vars)
5055

5156
# return same table type
5257
𝒯 = (; zip(newnames, eachcol(Y))...)
5358
newfeat = 𝒯 |> Tables.materializer(feat)
5459

55-
newfeat, (rind, perm, onames)
60+
newfeat, (ccache, perm, rind, vars)
5661
end
5762

5863
function revertfeat(transform::LogRatio, newfeat, fcache)
64+
# retrieve cache
65+
ccache, perm, rind, vars = fcache
66+
5967
# revert transform
6068
Y = Tables.matrix(newfeat)
6169
X = revertmatrix(transform, Y)
62-
63-
# retrieve cache
64-
rind, perm, onames = fcache
70+
pfeat = (; zip(vars, eachcol(X))...)
6571

6672
# revert the permutation if necessary
67-
if perm
68-
n = length(onames)
73+
cfeat = if perm
74+
n = length(vars)
6975
inds = collect(1:(n - 1))
7076
insert!(inds, rind, n)
71-
X = X[:, inds]
77+
pfeat |> Select(inds)
78+
else
79+
pfeat
7280
end
7381

82+
# revert closure for full revertibility
83+
𝒯 = revert(Closure(), cfeat, ccache)
84+
7485
# return same table type
75-
𝒯 = (; zip(onames, eachcol(X))...)
7686
𝒯 |> Tables.materializer(newfeat)
7787
end
7888

test/transforms/logratio.jl

+25-11
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,22 @@
1212
n, c = apply(T, t)
1313
@test Tables.schema(n).names == (:ARL1, :ARL2)
1414
@test n == t |> ALR(:c)
15-
talr = revert(T, n, c)
15+
r = revert(T, n, c)
16+
@test Tables.matrix(r) Tables.matrix(t)
17+
1618
T = CLR()
1719
n, c = apply(T, t)
1820
@test Tables.schema(n).names == (:CLR1, :CLR2, :CLR3)
19-
tclr = revert(T, n, c)
21+
r = revert(T, n, c)
22+
@test Tables.matrix(r) Tables.matrix(t)
23+
2024
T = ILR()
2125
n, c = apply(T, t)
2226
@test Tables.schema(n).names == (:ILR1, :ILR2)
2327
@test n == t |> ILR(:c)
24-
tilr = revert(T, n, c)
25-
@test Tables.matrix(talr) Tables.matrix(tclr)
26-
@test Tables.matrix(tclr) Tables.matrix(tilr)
27-
@test Tables.matrix(talr) Tables.matrix(tilr)
28+
r = revert(T, n, c)
29+
@test Tables.matrix(r) Tables.matrix(t)
2830

29-
# permute columns
3031
a = [1.0, 0.0, 1.0]
3132
b = [2.0, 2.0, 2.0]
3233
c = [3.0, 3.0, 0.0]
@@ -35,10 +36,23 @@
3536

3637
T = ALR(:c)
3738
n1, c1 = apply(T, t1)
39+
r1 = revert(T, n1, c1)
40+
n2, c2 = apply(T, t2)
41+
r2 = revert(T, n2, c2)
42+
@test n1 == n2
43+
@test Tables.matrix(r1) Tables.matrix(t1)
44+
@test Tables.schema(r1).names == (:a, :c, :b)
45+
@test Tables.matrix(r2) Tables.matrix(t2)
46+
@test Tables.schema(r2).names == (:c, :a, :b)
47+
48+
T = ILR(:c)
49+
n1, c1 = apply(T, t1)
50+
r1 = revert(T, n1, c1)
3851
n2, c2 = apply(T, t2)
52+
r2 = revert(T, n2, c2)
3953
@test n1 == n2
40-
tₒ = revert(T, n1, c1)
41-
@test Tables.schema(tₒ).names == (:a, :c, :b)
42-
tₒ = revert(T, n2, c2)
43-
@test Tables.schema(tₒ).names == (:c, :a, :b)
54+
@test Tables.matrix(r1) Tables.matrix(t1)
55+
@test Tables.schema(r1).names == (:a, :c, :b)
56+
@test Tables.matrix(r2) Tables.matrix(t2)
57+
@test Tables.schema(r2).names == (:c, :a, :b)
4458
end

0 commit comments

Comments
 (0)