-
Notifications
You must be signed in to change notification settings - Fork 163
Expand file tree
/
Copy pathnormal.jl
More file actions
114 lines (93 loc) · 3.79 KB
/
Copy pathnormal.jl
File metadata and controls
114 lines (93 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Singleton type of the scalar normal distribution; each sample is a `Float64`.
struct Normal <: Distribution{Float64}
end
# Singleton type of the elementwise (broadcast) normal distribution; samples are
# arrays of `Float64` (see the `broadcasted_normal` docstring for the 0-dim case).
struct BroadcastedNormal <: Distribution{Array{Float64}}
end
# Return the shape obtained by broadcasting all `args` together, e.g. `(2, 3)`
# for arguments of sizes `(1, 3)` and `(2,)`. Scalars have size `()` and so
# contribute no dimensions.
function broadcast_shapes_or_crash(args...)
    # `broadcast_shape` throws a `DimensionMismatch` if the shapes are
    # incompatible. There currently seems to be no ready-made exception-free
    # way to check this, and it makes more sense to wait until that capability
    # exists in Julia proper than to re-implement it here.
    return Base.Broadcast.broadcast_shape(map(size, args)...)
end
# Return `nothing` when `size(x) == expected_shape`; otherwise throw a
# `DimensionMismatch` whose message is `msg` followed by the expected and the
# actual shape.
function assert_has_shape(x, expected_shape; msg="Shape assertion failed")
    actual_shape = size(x)
    actual_shape == expected_shape && return nothing
    throw(DimensionMismatch(string(msg,
                                   " Expected shape: $expected_shape",
                                   " Actual shape: $actual_shape")))
end
"""
    normal(mu::Real, std::Real)

Samples a `Float64` value from a normal distribution with mean `mu` and
standard deviation `std` (the standard deviation, not the variance).
"""
const normal = Normal()
"""
    broadcasted_normal(mu::AbstractArray{<:Real, N1},
                       std::AbstractArray{<:Real, N2}) where {N1, N2}

Samples an `Array{Float64, max(N1, N2)}` of shape
`Base.Broadcast.broadcast_shape(size(mu), size(std))` where each element is
independently normally distributed, with mean and standard deviation taken from
the corresponding (broadcast) elements of `mu` and `std`. This is equivalent to
(a reshape of) a multivariate normal with diagonal covariance matrix, but its
implementation is more efficient than that of the more general `mvnormal` for
this case.

The shapes of `mu` and `std` must be broadcast-compatible; otherwise a
`DimensionMismatch` is thrown. Plain `Real` scalars are also accepted for
either argument.

If all args are 0-dimensional arrays, then sampling via
`broadcasted_normal(...)` returns a `Float64` rather than properly returning an
`Array{Float64, 0}`. This is consistent with Julia's own inconsistency on the
matter:

```jldoctest
julia> typeof(ones())
Array{Float64,0}

julia> typeof(ones() .* ones())
Float64
```
"""
const broadcasted_normal = BroadcastedNormal()
# Log-density of a normal with mean `mu` and standard deviation `std` at `x`.
function logpdf(::Normal, x::Real, mu::Real, std::Real)
    standardized = (x - mu) / std
    return -(abs2(standardized) + log(2π)) / 2 - log(std)
end
# Joint log-density of `x` under independent normals whose means and standard
# deviations are the elements of `mu` and `std` (broadcast against each other).
# Throws `DimensionMismatch` if `mu` and `std` are not broadcast-compatible or if
# `size(x)` differs from their broadcast shape.
function logpdf(::BroadcastedNormal,
                x::Union{AbstractArray{<:Real}, Real},
                mu::Union{AbstractArray{<:Real}, Real},
                std::Union{AbstractArray{<:Real}, Real})
    assert_has_shape(x, broadcast_shapes_or_crash(mu, std);
                     msg="Shape of `x` does not agree with the sample space")
    z = (x .- mu) ./ std
    # Fully dotted so the per-element log-densities are computed in one fused
    # pass without intermediate arrays.
    log_densities = .-(abs2.(z) .+ log(2π)) ./ 2 .- log.(std)
    return sum(log_densities)
end
# Partial derivatives of `logpdf(::Normal, x, mu, std)`, returned as the tuple
# (d/dx, d/dmu, d/dstd).
function logpdf_grad(::Normal, x::Real, mu::Real, std::Real)
    scaled = (x - mu) / std
    dx = -scaled / std
    dstd = -1. / std + abs2(scaled) / std
    # The density depends on x and mu only through (x - mu), so d/dmu = -d/dx.
    return (dx, -dx, dstd)
end
# Reduce `full`, a gradient computed at the broadcast shape, to the shape of the
# argument `arg` it belongs to. Entries are summed over every dimension along
# which `arg` was broadcast; a scalar argument receives the sum of all entries.
_sum_to_shape_of(::Real, full::Real) = full
_sum_to_shape_of(::Real, full::AbstractArray) = sum(full)
# Broadcasting only 0-dimensional arrays / scalars yields a plain scalar, so wrap
# it back into an array of `arg`'s (0-dimensional) shape.
_sum_to_shape_of(arg::AbstractArray, full::Real) = fill(full, size(arg))
function _sum_to_shape_of(arg::AbstractArray, full::AbstractArray)
    # `size(arg, d)` is 1 for `d > ndims(arg)`, matching Julia's broadcasting
    # rule that missing trailing dimensions act as singleton dimensions.
    dims = Tuple(d for d in 1:ndims(full)
                 if size(arg, d) == 1 && size(full, d) != 1)
    reduced = isempty(dims) ? full : sum(full; dims=dims)
    return reshape(reduced, size(arg))
end

# Gradient of `logpdf(::BroadcastedNormal, x, mu, std)`, returned as the tuple
# (d/dx, d/dmu, d/dstd). Each entry has the same shape as the corresponding
# input (a scalar for scalar inputs), so it is a true gradient rather than a sum
# of partial derivatives. Throws `DimensionMismatch` on incompatible shapes.
function logpdf_grad(::BroadcastedNormal,
                     x::Union{AbstractArray{<:Real}, Real},
                     mu::Union{AbstractArray{<:Real}, Real},
                     std::Union{AbstractArray{<:Real}, Real})
    assert_has_shape(x, broadcast_shapes_or_crash(mu, std);
                     msg="Shape of `x` does not agree with the sample space")
    z = (x .- mu) ./ std
    # Elementwise partial derivatives, all at the full broadcast shape.
    deriv_x = .-z ./ std
    deriv_mu = .-deriv_x
    deriv_std = -1.0 ./ std .+ abs2.(z) ./ std
    # An argument that was broadcast influences several elements, so its
    # gradient accumulates the partials along the broadcast dimensions.
    return (_sum_to_shape_of(x, deriv_x),
            _sum_to_shape_of(mu, deriv_mu),
            _sum_to_shape_of(std, deriv_std))
end
# Draw one sample by shifting and scaling a standard-normal draw.
function random(::Normal, mu::Real, std::Real)
    return mu + std * randn()
end
is_discrete(::Normal) = false
# Draw one independent normal sample per element of the broadcast shape of
# `mu` and `std` (a `DimensionMismatch` is thrown if they are incompatible).
function random(::BroadcastedNormal,
                mu::Union{AbstractArray{<:Real}, Real},
                std::Union{AbstractArray{<:Real}, Real})
    noise = randn(broadcast_shapes_or_crash(mu, std))
    return mu .+ std .* noise
end
# Calling a distribution object draws a sample from it.
(dist::Normal)(mu, std) = random(dist, mu, std)
(dist::BroadcastedNormal)(mu, std) = random(dist, mu, std)
# Both distributions report gradients w.r.t. the sampled value and w.r.t. both
# parameters (`mu`, `std`); see the `logpdf_grad` methods.
function has_output_grad(::Normal)
    return true
end
function has_argument_grads(::Normal)
    return (true, true)
end
function has_output_grad(::BroadcastedNormal)
    return true
end
function has_argument_grads(::BroadcastedNormal)
    return (true, true)
end
export normal
export broadcasted_normal