1
+ # ------------------------------------------------------------------
2
+ # Licensed under the MIT License. See LICENSE in the project root.
3
+ # ------------------------------------------------------------------
4
+
"""
    ProjectionPursuit(; tol=1e-6, maxiter=100, deg=5, perc=.9, n=100)

The projection pursuit multivariate transform converts any multivariate distribution
into the standard multivariate Gaussian distribution.

This iterative algorithm repeatedly finds a direction of projection `α` that maximizes
a score of non-Gaussianity known as the projection index `I(α)`. The samples projected
along `α` are then transformed with the [`Quantile`](@ref) transform to remove the
non-Gaussian structure. The other coordinates in the rotated orthonormal basis
`Q = [α ...]` are left untouched.

The basis `Q` is obtained by completing `α` with random directions, and `tol` is the
tolerance used to check that this completion is non-singular. The iterative process
terminates whenever the transformed samples are "more Gaussian" than a fraction `perc`
of `n` randomly generated samples from the standard multivariate Gaussian distribution,
or when the number of iterations reaches a maximum `maxiter`.

# Keywords

* `tol`     - tolerance of the non-singularity check of the rotation basis
* `maxiter` - maximum number of iterations
* `deg`     - highest polynomial degree used in the projection index
* `perc`    - fraction (between 0 and 1) of Gaussian reference samples to beat
* `n`       - number of Gaussian reference samples

# Examples

```julia
ProjectionPursuit()
ProjectionPursuit(deg=10)
ProjectionPursuit(perc=.85, n=50)
ProjectionPursuit(tol=1e-4, maxiter=250, deg=5, perc=.95, n=100)
```

See [https://doi.org/10.2307/2289161](https://doi.org/10.2307/2289161) for
further details.
"""
struct ProjectionPursuit{T} <: StatelessFeatureTransform
    tol::T
    maxiter::Int
    deg::Int
    perc::T
    n::Int
end
41
+
42
# keyword constructor with default hyperparameters; the type parameter is the
# floating point version of the type of `tol`, so that an integer tolerance
# (e.g. `tol=0`) does not force the fractional `perc` into an integer field
function ProjectionPursuit(; tol=1e-6, maxiter=100, deg=5, perc=.9, n=100)
    T = float(typeof(tol))
    return ProjectionPursuit{T}(tol, maxiter, deg, perc, n)
end
44
+
45
# trait: any `ProjectionPursuit` type is declared revertible
isrevertible(::Type{T}) where {T<:ProjectionPursuit} = true
46
+
47
# maps the samples in the rows of Z, projected along the direction α, to
# the interval [-1,1] through R = 2Φ(Zα) - 1, where Φ is the CDF of the
# standard normal distribution
function rscore(Z, α)
    projection = Z * α
    gaussian = Normal()
    return 2 .* cdf.(gaussian, projection) .- 1
end
50
+
51
# projection index of sample along a given direction, i.e. the sum for
# j = 1, ..., deg of (2j+1)/2 * mean(Pⱼ(r))², where r = rscore(Z, α) and
# the polynomials Pⱼ follow the three-term recurrence
# j*Pⱼ = (2j-1)*r*Pⱼ₋₁ - (j-1)*Pⱼ₋₂ with P₀ = 1 and P₁ = r
function pindex(transform, Z, α)
    deg = transform.deg
    r = rscore(Z, α)

    # term of degree one
    index = (3 / 2) * mean(r)^2

    # terms of higher degree (the range is empty whenever deg ≤ 1)
    prev, curr = ones(length(r)), r
    for j in 2:deg
        next = @. (1 / j) * ((2j - 1) * r * curr - (j - 1) * prev)
        prev, curr = curr, next
        index += ((2j + 1) / 2) * mean(next)^2
    end

    return index
end
66
+
67
# j-th vector of the canonical basis of ℝᵈ, stored as floating point numbers
basis(d, j) = [float(i == j) for i in 1:d]
69
+
70
# projection index evaluated along every vector of the canonical basis,
# i.e. one value per column of Z
function pbasis(transform, Z)
    q = size(Z, 2)
    return map(1:q) do j
        pindex(transform, Z, basis(q, j))
    end
end
75
+
76
# reference projection indices of the standard multivariate Gaussian:
# for each axis, the quantile of probability 1 - perc of the indices
# computed on n independent draws of N×q standard Gaussian samples
function gaussquantiles(transform, N, q)
    nsim = transform.n
    prob = 1.0 - transform.perc

    # one column of axis-aligned indices per simulated draw
    sims = [pbasis(transform, randn(N, q)) for _ in 1:nsim]
    indices = reduce(hcat, sims)

    return [quantile(row, prob) for row in eachrow(indices)]
end
84
+
85
# initial guess for the direction of maximum projection index, searched
# among the canonical axes and then among the normalized diagonals between
# the best direction found so far and each canonical axis
function alphaguess(transform, Z)
    q = size(Z, 2)

    # objective function
    score(v) = pindex(transform, Z, v)

    # normalized u + s*e, assuming that u and e are unit vectors
    bisector(u, s, e) = (1 / √(2 + 2s * u ⋅ e)) * (u + s * e)

    # best direction among the canonical axes
    jbest = argmax(j -> score(basis(q, j)), 1:q)
    best = basis(q, jbest)
    top = score(best)

    # try to improve the objective along the diagonals
    for i in 1:q
        e = basis(q, i)
        plus = bisector(best, +1, e)
        minus = bisector(best, -1, e)
        fplus = score(plus)

        # the "minus" diagonal is not scored when best ⋅ e is exactly one,
        # where its normalization factor divides by zero
        fminus = best ⋅ e != 1.0 ? score(minus) : 0.0

        if fplus > fminus
            candidate, fcandidate = plus, fplus
        else
            candidate, fcandidate = minus, fminus
        end

        if fcandidate > top
            best, top = candidate, fcandidate
        end
    end

    return best
end
112
+
113
# refines the initial direction α₀ by minimizing the negative projection
# index with `optimize` (presumably the default derivative-free method of
# Optim.jl, given the function name); candidates are normalized before the
# index is evaluated, whereas the minimizer is returned as is
function neldermead(transform, Z, α₀)
    objective(α) = -pindex(transform, Z, α ./ norm(α))
    result = optimize(objective, α₀)
    return minimizer(result)
end
118
+
119
# direction of maximum projection index: the guess obtained from axes and
# diagonals is refined by numerical optimization
alphamax(transform, Z) = neldermead(transform, Z, alphaguess(transform, Z))
123
+
124
# orthonormal basis whose first direction is spanned by α, obtained from the
# QR factorization of α completed with q-1 random columns; the random columns
# are redrawn while any of their pivots in R has magnitude below `tol`, i.e.
# while the completed matrix is numerically singular
function orthobasis(α, tol)
    q = length(α)

    # |R₁₁| equals the norm of α and cannot be changed by a new draw, hence
    # only the pivots of the random columns are inspected; the norm of the
    # whole diagonal is never smaller than |R₁₁| and cannot detect singularity
    issingular(U) = any(u -> abs(u) < tol, diag(U)[2:end])

    Q, R = qr([α rand(q, q - 1)])
    while issingular(R)
        Q, R = qr([α rand(q, q - 1)])
    end

    return Q
end
132
+
133
# removes the non-Gaussian structure along the direction α: the samples are
# rotated to an orthonormal basis that starts with α, `Quantile(1)` is
# applied to remove the structure of the first rotated axis, and the rotation
# is undone; returns the new samples and the cache `(Q, qcache)` of the step
function rmstructure(transform, Z, α)
    Q = orthobasis(α, transform.tol)

    # transform in rotated coordinates
    rotated = Tables.table(Z * Q)
    qtable, qcache = apply(Quantile(1), rotated)

    # back to the original axis-aligned features
    Z₊ = Tables.matrix(qtable) * Q'

    return Z₊, (Q, qcache)
end
145
+
146
# preprocessing pipeline used to bring the data to approximately spherical
# shape: `Quantile()` composed with `EigenAnalysis(:VDV)` through `→`
function sphering()
    return Quantile() → EigenAnalysis(:VDV)
end
147
+
148
# applies the projection pursuit transform to the features in `table` and
# returns the transformed table together with the cache `(pcache, caches)`:
# `pcache` is the cache of the sphering step and `caches` holds one
# `(Q, qcache)` entry per iteration, in order of application; the `prep`
# argument is not used by this transform
function applyfeat(transform::ProjectionPursuit, table, prep)
    # retrieve column names
    cols = Tables.columns(table)
    names = Tables.columnnames(cols)

    # preprocess the data to approximately spherical shape
    ptable, pcache = apply(sphering(), table)

    # initialize scores and Gaussian quantiles
    Z = Tables.matrix(ptable)
    scores = pbasis(transform, Z)
    g = gaussquantiles(transform, size(Z)...)

    # iterate while some axis still scores above its Gaussian reference
    # and fewer than `maxiter` iterations have been performed
    iter = 0
    caches = [] # element type depends on the cache of the Quantile transform
    while any(scores .> g) && iter < transform.maxiter
        # choose direction with maximum projection index
        α = alphamax(transform, Z)

        # remove non-Gaussian structure
        Z, cache = rmstructure(transform, Z, α)

        # update the scores along original axes
        scores = pbasis(transform, Z)

        # store cache and continue
        push!(caches, cache)
        iter += 1
    end

    𝒯 = (; zip(names, eachcol(Z))...)
    newtable = 𝒯 |> Tables.materializer(table)
    return newtable, (pcache, caches)
end
181
+
182
# reverts the projection pursuit transform: the cached iterations are undone
# from last to first (rotate with Q, revert `Quantile(1)`, rotate back) and
# then the sphering step is reverted with its own cache
function revertfeat(::ProjectionPursuit, newtable, fcache)
    # retrieve column names
    names = Tables.columnnames(Tables.columns(newtable))

    # caches to retrieve transform steps
    pcache, caches = fcache

    # undo the iterations in reverse order
    Z = Tables.matrix(newtable)
    for (Q, qcache) in reverse(caches)
        rotated = Tables.table(Z * Q)
        reverted = revert(Quantile(1), rotated, qcache)
        Z = Tables.matrix(reverted) * Q'
    end

    # undo the preprocessing
    original = revert(sphering(), Tables.table(Z), pcache)
    X = Tables.matrix(original)

    𝒯 = (; zip(names, eachcol(X))...)
    return 𝒯 |> Tables.materializer(newtable)
end
0 commit comments