Skip to content

Commit cfb7f46

Browse files
committed
Fix Quantile in the presence of repeated values
1 parent 886de7c commit cfb7f46

File tree

2 files changed

+44
-37
lines changed

2 files changed

+44
-37
lines changed

src/distributions.jl

+1-31
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct EmpiricalDistribution{T} <: ContinuousUnivariateDistribution
1212

1313
function EmpiricalDistribution{T}(values) where {T}
1414
_assert(!isempty(values), "values must be provided")
15-
new(_smooth(values))
15+
new(sort(values))
1616
end
1717
end
1818

@@ -49,33 +49,3 @@ function cdf(d::EmpiricalDistribution{T}, x::T) where {T}
4949
end
5050
end
5151
end
52-
53-
# helper function that replaces repated values
54-
# by an increasing sequence of values between
55-
# the previous and the next non-repated value
56-
function _smooth(values)
57-
sorted = float.(sort(values))
58-
bounds = findall(>(0), diff(sorted))
59-
if !isempty(bounds)
60-
i = 1
61-
j = first(bounds)
62-
_linear!(sorted, i, j, sorted[j], sorted[j + 1])
63-
for k in 1:length(bounds)-1
64-
i = bounds[k] + 1
65-
j = bounds[k + 1]
66-
_linear!(sorted, i, j, sorted[i - 1], sorted[j])
67-
end
68-
i = last(bounds) + 1
69-
j = length(sorted)
70-
_linear!(sorted, i, j, sorted[i - 1], sorted[j])
71-
end
72-
sorted
73-
end
74-
75-
function _linear!(x, i, j, l, u)
76-
if i < j
77-
for k in i:j
78-
x[k] = (u - l) * (k - i) / (j - i) + l
79-
end
80-
end
81-
end

src/transforms/quantile.jl

+43-6
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,26 @@ parameters(transform::Quantile) = (; dist=transform.dist)
4545

4646
isrevertible(::Type{<:Quantile}) = true
4747

48-
colcache(::Quantile, x) = EmpiricalDistribution(x)
48+
function colcache(::Quantile, x)
49+
s = qsmooth(x)
50+
d = EmpiricalDistribution(s)
51+
d, s
52+
end
4953

5054
function colapply(transform::Quantile, x, c)
51-
origin, target = c, transform.dist
52-
qqtransform(x, origin, target)
55+
d, s = c
56+
origin, target = d, transform.dist
57+
qtransform(s, origin, target)
5358
end
5459

5560
function colrevert(transform::Quantile, y, c)
56-
origin, target = transform.dist, c
57-
qqtransform(y, origin, target)
61+
d, _ = c
62+
origin, target = transform.dist, d
63+
qtransform(y, origin, target)
5864
end
5965

6066
# transform samples from original to target distribution
61-
function qqtransform(samples, origin, target)
67+
function qtransform(samples, origin, target)
6268
# avoid evaluating the quantile at 0 or 1
6369
T = eltype(samples)
6470
pmin = T(0) + T(1e-3)
@@ -68,3 +74,34 @@ function qqtransform(samples, origin, target)
6874
quantile(target, clamp(prob, pmin, pmax))
6975
end
7076
end
77+
78+
# helper function that replaces repated values
79+
# by an increasing sequence of values between
80+
# the previous and the next non-repated value
81+
function qsmooth(values)
82+
permut = sortperm(values)
83+
sorted = float.(values[permut])
84+
bounds = findall(>(0), diff(sorted))
85+
if !isempty(bounds)
86+
i = 1
87+
j = first(bounds)
88+
qlinear!(sorted, i, j, sorted[j], sorted[j + 1])
89+
for k in 1:length(bounds)-1
90+
i = bounds[k] + 1
91+
j = bounds[k + 1]
92+
qlinear!(sorted, i, j, sorted[i - 1], sorted[j])
93+
end
94+
i = last(bounds) + 1
95+
j = length(sorted)
96+
qlinear!(sorted, i, j, sorted[i - 1], sorted[j])
97+
end
98+
sorted[sortperm(permut)]
99+
end
100+
101+
function qlinear!(x, i, j, l, u)
102+
if i < j
103+
for k in i:j
104+
x[k] = (u - l) * (k - i) / (j - i) + l
105+
end
106+
end
107+
end

0 commit comments

Comments
 (0)