Commit 1ba7a90

DianaPat and juliohm authored
Add ProjectionPursuit (new PR) (#141)
* projectionpursuit related files
* Update src/transforms/projectionpursuit.jl
  Co-authored-by: Júlio Hoffimann
* Update src/transforms/projectionpursuit.jl
  Co-authored-by: Júlio Hoffimann
* Change the distributions of the visual test
* Update src/transforms/projectionpursuit.jl
* Apply suggestions from code review
* Update src/transforms/projectionpursuit.jl
* Update src/transforms/projectionpursuit.jl
* Apply suggestions from code review
* Update src/transforms/projectionpursuit.jl
* Apply suggestions from code review
* Update src/transforms/projectionpursuit.jl
* Update src/transforms/projectionpursuit.jl
* Update src/transforms/projectionpursuit.jl
* Update src/transforms/projectionpursuit.jl
* Update src/transforms/projectionpursuit.jl
* Update src/transforms/projectionpursuit.jl
* Update src/transforms/projectionpursuit.jl
* Update src/transforms/projectionpursuit.jl
* Update src/transforms/projectionpursuit.jl
* Update src/transforms/projectionpursuit.jl
* Remove outdated notation, change in the visual tests.
* Update src/transforms/projectionpursuit.jl
  Co-authored-by: Júlio Hoffimann
* Update src/transforms/projectionpursuit.jl
  Co-authored-by: Júlio Hoffimann
* Fix the plot error
* Bugs corrected, tests modified.
* Move float, change perc -> 1-perc.
* Update tests
* Change test image
* Correction of Test
  Co-authored-by: Júlio Hoffimann
1 parent 95d7c58 commit 1ba7a90

21 files changed: +304 -15 lines changed

Project.toml

+2
@@ -8,6 +8,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
@@ -21,6 +22,7 @@ TransformsBase = "28dd2a49-a57a-4bfb-84ca-1a49db9b96b8"
 AbstractTrees = "0.4"
 CategoricalArrays = "0.10"
 Distributions = "0.25"
+Optim = "1.7"
 PrettyTables = "1.3, 2"
 ScientificTypes = "2.3, 3.0"
 StatsBase = "0.33"

docs/src/transforms/builtin.md

+6
@@ -152,6 +152,12 @@ DRS
 SDS
 ```
 
+## ProjectionPursuit
+
+```@docs
+ProjectionPursuit
+```
+
 ## RowTable
 
 ```@docs

src/TableTransforms.jl

+2
@@ -16,6 +16,7 @@ using PrettyTables
 using AbstractTrees
 using CategoricalArrays
 using Random
+using Optim: optimize, minimizer
 
 import Distributions: ContinuousUnivariateDistribution
 import Distributions: quantile, cdf
@@ -67,6 +68,7 @@ export
   Functional,
   EigenAnalysis,
   PCA, DRS, SDS,
+  ProjectionPursuit,
   RowTable,
   ColTable,
   →,

src/transforms.jl

+1
@@ -285,6 +285,7 @@ include("transforms/zscore.jl")
 include("transforms/quantile.jl")
 include("transforms/functional.jl")
 include("transforms/eigenanalysis.jl")
+include("transforms/projectionpursuit.jl")
 include("transforms/rowtable.jl")
 include("transforms/coltable.jl")
 include("transforms/parallel.jl")

src/transforms/projectionpursuit.jl

+201
@@ -0,0 +1,201 @@
+# ------------------------------------------------------------------
+# Licensed under the MIT License. See LICENSE in the project root.
+# ------------------------------------------------------------------
+
+"""
+    ProjectionPursuit(;tol=1e-6, maxiter=100, deg=5, perc=.9, n=100)
+
+The projection pursuit multivariate transform converts any multivariate distribution into
+the standard multivariate Gaussian distribution.
+
+This iterative algorithm repeatedly finds a direction of projection `α` that maximizes a score of
+non-Gaussianity known as the projection index `I(α)`. The samples projected along `α` are then
+transformed with the [`Quantile`](@ref) transform to remove the non-Gaussian structure. The
+other coordinates in the rotated orthonormal basis `Q = [α ...]` are left untouched.
+
+The non-singularity of Q is controlled by assuring that norm(det(Q)) ≥ `tol`. The iterative
+process terminates whenever the transformed samples are "more Gaussian" than `perc`% of `n`
+randomly generated samples from the standard multivariate Gaussian distribution, or when the
+number of iterations reaches a maximum `maxiter`.
+
+# Examples
+
+```julia
+ProjectionPursuit()
+ProjectionPursuit(deg=10)
+ProjectionPursuit(perc=.85, n=50)
+ProjectionPursuit(tol=1e-4, maxiter=250, deg=5, perc=.95, n=100)
+```
+
+See [https://doi.org/10.2307/2289161](https://doi.org/10.2307/2289161) for
+further details.
+"""
+
+struct ProjectionPursuit{T} <: StatelessFeatureTransform
+  tol::T
+  maxiter::Int
+  deg::Int
+  perc::T
+  n::Int
+end
+
+ProjectionPursuit(;tol=1e-6, maxiter=100, deg=5, perc=.9, n=100) =
+  ProjectionPursuit{typeof(tol)}(tol, maxiter, deg, perc, n)
+
+isrevertible(::Type{<:ProjectionPursuit}) = true
+
+# transforms a row of random variables into a convex combination
+# of random variables with values in [-1,1] and standard normal distribution
+rscore(Z, α) = 2 .* cdf.(Normal(), Z * α) .- 1
+
+# projection index of sample along a given direction
+function pindex(transform, Z, α)
+  d = transform.deg
+  r = rscore(Z, α)
+  I = (3/2) * mean(r)^2
+  if d > 1
+    Pⱼ₋₂, Pⱼ₋₁ = ones(length(r)), r
+    for j = 2:d
+      Pⱼ₋₂, Pⱼ₋₁ =
+        Pⱼ₋₁, (1/j) * ((2j-1) * r .* Pⱼ₋₁ - (j-1) * Pⱼ₋₂)
+      I += ((2j+1)/2) * (mean(Pⱼ₋₁))^2
+    end
+  end
+  I
+end
+
+# j-th element of the canonical basis in ℝᵈ
+basis(d, j) = float(1:d .== j)
+
+# index for all vectors in the canonical basis
+function pbasis(transform, Z)
+  q = size(Z, 2)
+  [pindex(transform, Z, basis(q, j)) for j in 1:q]
+end
+
+# projection index of the standard multivariate Gaussian
+function gaussquantiles(transform, N, q)
+  n = transform.n
+  p = 1.0 - transform.perc
+  Is = [pbasis(transform, randn(N, q)) for i in 1:n]
+  I = reduce(hcat, Is)
+  quantile.(eachrow(I), p)
+end
+
+function alphaguess(transform, Z)
+  q = size(Z, 2)
+
+  # objective function
+  func(α) = pindex(transform, Z, α)
+
+  # evaluate objective along axes
+  j = argmax(j -> func(basis(q, j)), 1:q)
+  α = basis(q, j)
+  I = func(α)
+
+  # evaluate objective along diagonals
+  diag(α, s, e) = (1/√(2+2s*α⋅e)) * (α + s * e)
+  for eᵢ in basis.(q, 1:q)
+    d₊ = diag(α, +1, eᵢ)
+    d₋ = diag(α, -1, eᵢ)
+    f₊ = func(d₊)
+    f₋ = α⋅eᵢ != 1.0 ? func(d₋) : 0.0
+    f, d = f₊ > f₋ ? (f₊, d₊) : (f₋, d₋)
+    if f > I
+      α = d
+      I = f
+    end
+  end
+
+  α
+end
+
+function neldermead(transform, Z, α₀)
+  f(α) = -pindex(transform, Z, α ./ norm(α))
+  op = optimize(f, α₀)
+  minimizer(op)
+end
+
+function alphamax(transform, Z)
+  α = alphaguess(transform, Z)
+  neldermead(transform, Z, α)
+end
+
+function orthobasis(α, tol)
+  q = length(α)
+  Q, R = qr([α rand(q,q-1)])
+  while norm(diag(R)) < tol
+    Q, R = qr([α rand(q,q-1)])
+  end
+  Q
+end
+
+function rmstructure(transform, Z, α)
+  # find orthonormal basis for rotation
+  Q = orthobasis(α, transform.tol)
+
+  # remove structure of first rotated axis
+  newtable, qcache = apply(Quantile(1), Tables.table(Z * Q))
+
+  # undo rotation, i.e recover original axis-aligned features
+  Z₊ = Tables.matrix(newtable) * Q'
+
+  Z₊, (Q, qcache)
+end
+
+sphering() = Quantile() → EigenAnalysis(:VDV)
+
+function applyfeat(transform::ProjectionPursuit, table, prep)
+  # retrieve column names
+  cols = Tables.columns(table)
+  names = Tables.columnnames(cols)
+
+  # preprocess the data to approximately spherical shape
+  ptable, pcache = apply(sphering(), table)
+
+  # initialize scores and Gaussian quantiles
+  Z = Tables.matrix(ptable)
+  I = pbasis(transform, Z)
+  g = gaussquantiles(transform, size(Z)...)
+
+  iter = 0; caches = []
+  while any(I .> g) && iter ≤ transform.maxiter
+    # choose direction with maximum projection index
+    α = alphamax(transform, Z)
+
+    # remove non-Gaussian structure
+    Z, cache = rmstructure(transform, Z, α)
+
+    # update the scores along original axes
+    I = pbasis(transform, Z)
+
+    # store cache and continue
+    push!(caches, cache)
+    iter += 1
+  end
+
+  𝒯 = (; zip(names, eachcol(Z))...)
+  newtable = 𝒯 |> Tables.materializer(table)
+  newtable, (pcache, caches)
+end
+
+function revertfeat(::ProjectionPursuit, newtable, fcache)
+  # retrieve column names
+  cols = Tables.columns(newtable)
+  names = Tables.columnnames(cols)
+
+  # caches to retrieve transform steps
+  pcache, caches = fcache
+
+  Z = Tables.matrix(newtable)
+  for (Q, qcache) in reverse(caches)
+    table = revert(Quantile(1), Tables.table(Z * Q), qcache)
+    Z = Tables.matrix(table) * Q'
+  end
+
+  table = revert(sphering(), Tables.table(Z), pcache)
+  Z = Tables.matrix(table)
+
+  𝒯 = (; zip(names, eachcol(Z))...)
+  newtable = 𝒯 |> Tables.materializer(newtable)
+end
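
To make the recurrence in `pindex` easier to follow, here is a standalone sketch that is not part of the commit. It evaluates Legendre polynomials with the same three-term recurrence and sums the weighted squared means. The names `legendre` and `index` are hypothetical, and the input `r` is assumed to lie in [-1, 1] already, as the output of `rscore` does. When the projected samples are standard normal, `r` is uniform on [-1, 1] and every term has zero expectation, so a large index suggests non-Gaussian structure.

```julia
using Statistics: mean

# Legendre polynomials P₁..P_d evaluated at the points r, using the recurrence
# j*Pⱼ = (2j-1)*r*Pⱼ₋₁ - (j-1)*Pⱼ₋₂ that also appears in pindex
# (hypothetical helper, for illustration only)
function legendre(r::AbstractVector, d::Integer)
  Pprev, Pcur = ones(length(r)), float.(r)  # P₀ and P₁
  Ps = [Pcur]
  for j in 2:d
    Pprev, Pcur = Pcur, ((2j - 1) .* r .* Pcur .- (j - 1) .* Pprev) ./ j
    push!(Ps, Pcur)
  end
  Ps
end

# index = sum over j = 1..d of (2j+1)/2 * mean(Pⱼ(r))^2
index(r, d) = sum((2j + 1) / 2 * mean(P)^2 for (j, P) in enumerate(legendre(r, d)))

r = collect(LinRange(-1, 1, 201))
P = legendre(r, 3)

# the recurrence reproduces the closed forms of P₂ and P₃
println(P[2] ≈ (3 .* r .^ 2 .- 1) ./ 2)
println(P[3] ≈ (5 .* r .^ 3 .- 3 .* r) ./ 2)
println(index(r, 3))
```

The sketch keeps all polynomials in a vector for clarity, whereas `pindex` only carries the last two, which is enough for the running sum.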

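The two geometric helpers in the file above, the diagonal direction in `alphaguess` and the QR-based basis in `orthobasis`, can be checked in isolation. The following is a minimal sketch using only the `LinearAlgebra` standard library. The names `completebasis` and `diagdir` are hypothetical, and the retry loop on `tol` is left out on the assumption that the random columns are linearly independent of `α`.

```julia
using LinearAlgebra

# complete a unit vector α to an orthonormal basis of ℝ^q via QR;
# the first column of the result equals α up to sign
# (hypothetical helper, omits the `tol` check of orthobasis)
function completebasis(α::AbstractVector)
  q = length(α)
  F = qr([α rand(q, q - 1)])
  Matrix(F.Q)  # dense q×q matrix with orthonormal columns
end

# unit vector along α + s*e for unit vectors α, e and a sign s = ±1,
# since ‖α + s*e‖² = 2 + 2s*(α⋅e)
diagdir(α, s, e) = (α + s * e) / sqrt(2 + 2s * dot(α, e))

α = normalize([1.0, 2.0, 2.0])
e = [1.0, 0.0, 0.0]
Q = completebasis(α)

Z = randn(5, 3)
println(norm(Z * Q * Q' - Z))  # close to zero: Q' undoes the rotation by Q
println(norm(diagdir(α, +1, e)), " ", norm(diagdir(α, -1, e)))
```

Because `Q` is orthogonal, `rmstructure` can rotate with `Z * Q`, transform only the first column, and return to the original axes with `Q'`.
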
test/Project.toml

+2
@@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
 ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
@@ -16,4 +17,5 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
 
 [compat]
 GR = "=0.59.0"
+PairPlots = "=0.6.0"
 Plots = "=1.22.4"

test/data/center.png

26.1 KB

test/data/eigenanalysis-1.png

17.9 KB

test/data/eigenanalysis-2.png

28 KB

test/data/projectionpursuit-1.png

258 KB

test/data/projectionpursuit-2.png

485 KB

test/data/projectionpursuit-3.png

96.1 KB

test/data/scale.png

26.9 KB

test/data/zscore.png

15.2 KB

test/runtests.jl

+2-2
@@ -9,12 +9,12 @@ using Statistics
 using Test, Random, Plots
 using ReferenceTests, ImageIO
 using StatsBase
+using PairPlots
 
 const TT = TableTransforms
 
 # set default configurations for plots
-gr(ms=1, mc=:black, aspectratio=:equal,
-   label=false, size=(600,400))
+gr(ms=1, mc=:black, label=false, size=(600,400))
 
 # workaround GR warnings
 ENV["GKSwstype"] = "100"

test/transforms.jl

+1
@@ -17,6 +17,7 @@ transformfiles = [
   "quantile.jl",
   "functional.jl",
   "eigenanalysis.jl",
+  "projectionpursuit.jl",
   "rowtable.jl",
   "coltable.jl",
   "sequential.jl",

test/transforms/center.jl

+2-2
@@ -12,8 +12,8 @@
 
   # visual tests
   if visualtests
-    p₁ = scatter(t.x, t.y, label="Original")
-    p₂ = scatter(n.x, n.y, label="Center")
+    p₁ = scatter(t.x, t.y, label="Original", aspectratio=:equal)
+    p₂ = scatter(n.x, n.y, label="Center", aspectratio=:equal)
     p = plot(p₁, p₂, layout=(1,2))
 
     @test_reference joinpath(datadir, "center.png") p

test/transforms/eigenanalysis.jl

+7-7
@@ -53,13 +53,13 @@
 
   # visual tests
   if visualtests
-    p₁ = scatter(t₁.x, t₁.y, label="Original")
-    p₂ = scatter(t₂.PC1, t₂.PC2, label="V")
-    p₃ = scatter(t₃.PC1, t₃.PC2, label="VD")
-    p₄ = scatter(t₄.PC1, t₄.PC2, label="VDV")
-    p₅ = scatter(t₅.PC1, t₅.PC2, label="PCA")
-    p₆ = scatter(t₆.PC1, t₆.PC2, label="DRS")
-    p₇ = scatter(t₇.PC1, t₇.PC2, label="SDS")
+    p₁ = scatter(t₁.x, t₁.y, label="Original", aspectratio=:equal)
+    p₂ = scatter(t₂.PC1, t₂.PC2, label="V", aspectratio=:equal)
+    p₃ = scatter(t₃.PC1, t₃.PC2, label="VD", aspectratio=:equal)
+    p₄ = scatter(t₄.PC1, t₄.PC2, label="VDV", aspectratio=:equal)
+    p₅ = scatter(t₅.PC1, t₅.PC2, label="PCA", aspectratio=:equal)
+    p₆ = scatter(t₆.PC1, t₆.PC2, label="DRS", aspectratio=:equal)
+    p₇ = scatter(t₇.PC1, t₇.PC2, label="SDS", aspectratio=:equal)
     p = plot(p₁, p₂, p₃, p₄, layout=(2,2))
     q = plot(p₂, p₃, p₄, p₅, p₆, p₇, layout=(2,3))
 
