Commit 24cc037

Multivariate models now implement a diagonal covariance matrix.
1 parent 53c9dbc commit 24cc037

11 files changed: +75 -37 lines changed
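
The headline change swaps the full covariance parameterization of each multivariate mixture component for a per-dimension variance vector. As a rough illustration of the difference (hypothetical values; only Distributions.jl and LinearAlgebra are assumed):

```julia
using Distributions, LinearAlgebra

μ = [0.0, 2.0]      # a component mean predicted by the network
σ² = [0.25, 0.81]   # per-dimension variances, as predicted after this commit

# Diagonal covariance: output * mixtures variance parameters per observation
d_diag = MultivariateNormal(μ, Diagonal(σ²))

# Full covariance (the old approach): output^2 * mixtures parameters,
# kept positive definite via a masked Cholesky factor
Σ = [0.25 0.10; 0.10 0.81]
d_full = MultivariateNormal(μ, Σ)
```

A diagonal component cannot capture correlation between output dimensions on its own, but the mixture as a whole can still approximate correlated targets, and the construction is always positive definite with far fewer parameters.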

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 name = "MixtureDensityNetworks"
 uuid = "521d8788-cab4-41cb-a05a-da376f16ad79"
 authors = ["Joshua Billson"]
-version = "0.2.1"
+version = "0.2.2"
 
 [deps]
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

README.md

Lines changed: 15 additions & 13 deletions

@@ -17,38 +17,39 @@ with the MLJ ecosystem. Below is an example demonstrating the use of this packag
 # Example (Native Interface)
 
 ```julia
-using MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
+using Flux, MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
 
 const n_samples = 1000
 const epochs = 1000
-const mixtures = 6
+const batchsize = 128
+const mixtures = 8
 const layers = [128, 128]
 
 function main()
     # Generate Data
     X, Y = generate_data(n_samples)
 
     # Create Model
-    machine = MixtureDensityNetworks.Machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers))
+    model = MixtureDensityNetwork(1, 1, layers, mixtures)
 
     # Fit Model
-    report = with_logger(TerminalLogger()) do
-        fit!(machine, X, Y)
+    model, report = with_logger(TerminalLogger()) do
+        MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)
     end
 
     # Plot Learning Curve
     fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))
     save("LearningCurve.png", fig)
 
     # Plot Learned Distribution
-    Ŷ = predict(machine, X)
+    Ŷ = model(X)
     fig, ax, plt = scatter(X[1,:], rand.(Ŷ), markersize=4, label="Predicted Distribution")
     scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
     axislegend(ax, position=:lt)
     save("PredictedDistribution.png", fig)
 
     # Plot Conditional Distribution
-    cond = predict(machine, reshape([-2.0], (1,1)))[1]
+    cond = model(reshape([-2.1], (1,1)))[1]
     fig = Figure(resolution=(1000, 500))
     density(fig[1,1], rand(cond, 10000), npoints=10000)
     save("ConditionalDistribution.png", fig)

@@ -60,21 +61,22 @@ main()
 # Example (MLJ Interface)
 
 ```julia
-using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ
+using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ, Random
 
 const n_samples = 1000
-const epochs = 1000
-const mixtures = 6
+const epochs = 500
+const batchsize = 128
+const mixtures = 8
 const layers = [128, 128]
 
 function main()
     # Generate Data
     X, Y = generate_data(n_samples)
 
     # Create Model
-    mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers), MLJ.table(X'), Y[1,:])
+    mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers, batchsize=batchsize), MLJ.table(X'), Y[1,:])
 
-    # Evaluate Model
+    # Fit Model on Training Data, Then Evaluate on Test
     with_logger(TerminalLogger()) do
         @info "Evaluating..."
         evaluation = MLJ.evaluate!(

@@ -88,7 +90,7 @@ function main()
         @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
     end
 
-    # Fit Model
+    # Fit Model on Entire Dataset
     with_logger(TerminalLogger()) do
         @info "Training..."
         MLJ.fit!(mach)

docs/src/index.md

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ density(fig[1,1], rand(cond, 10000), npoints=10000)
 
 Below is a script for running the complete example.
 ```julia
-using MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
+using Flux, MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
 
 const n_samples = 1000
 const epochs = 1000

docs/src/mlj.md

Lines changed: 0 additions & 2 deletions

@@ -10,8 +10,6 @@ with the MLJ ecosystem. Below is an example demonstrating the use of this packag
 ```julia
 using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ, Random
 
-Random.seed!(123)
-
 const n_samples = 1000
 const epochs = 500
 const batchsize = 128

examples/mlj_example.jl

Lines changed: 0 additions & 2 deletions

@@ -1,7 +1,5 @@
 using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ, Random
 
-Random.seed!(123)
-
 const n_samples = 1000
 const epochs = 500
 const batchsize = 128
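
With the explicit seed removed, the MLJ examples now draw fresh random data and weights on each run, so their plots and metrics will vary slightly between executions.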

examples/multivariate_example.jl

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+using Flux, MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
+
+const n_samples = 1000
+const epochs = 250
+const batchsize = 256
+const mixtures = 12
+const layers = [256, 512, 1024]
+
+function main()
+    # Generate Data
+    Y = rand(Uniform(-10.5, 10.5), 1, n_samples)
+    μ_x = (7sin.(0.75 .* Y) + 0.5 .* Y)
+    X = rand.(Normal.(μ_x, 0.5))
+    μ_z = (-0.5 .* X) .+ 2.0
+    Z = rand.(Normal.(μ_z, 0.6))
+    Y = cat(Y, Z, dims=1)
+
+    # Normalize Features
+    X̄ = (X .- mean(X, dims=2)) ./ std(X, dims=2)
+
+    # Create Model
+    model = MixtureDensityNetwork(1, 2, layers, mixtures)
+
+    # Fit Model
+    model, report = with_logger(TerminalLogger()) do
+        MixtureDensityNetworks.fit!(model, X̄, Y; batchsize=batchsize, epochs=epochs)
+    end
+
+    # Plot Learning Curve
+    fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))
+    save("MultivariateLearningCurve.png", fig)
+
+    # Plot Learned Distribution
+    Ŷ = model(X̄) .|> rand
+    fig = Figure(resolution=(2000, 1000), figure_padding=100)
+    ax1 = Axis3(fig[1,1], title="True Distribution", elevation=0.2π, azimuth=0.25π, titlesize=48, titlegap=50)
+    ax2 = Axis3(fig[1,2], title="Predicted Distribution", elevation=0.2π, azimuth=0.25π, titlesize=48, titlegap=50)
+    scatter!(ax1, X[1,:], Y[1,:], Y[2,:], markersize=3.0)
+    scatter!(ax2, X[1,:], [x[1] for x in Ŷ], [x[2] for x in Ŷ], markersize=3.0)
+    xlims!(ax1, -15, 15); zlims!(ax1, -7, 10); ylims!(ax1, -13, 13)
+    xlims!(ax2, -15, 15); zlims!(ax2, -7, 10); ylims!(ax2, -13, 13)
+    save("MultivariateDistributions.png", fig)
+end
+
+main()
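
The new example extends the one-dimensional inverse-sine toy problem with a second target dimension `Z` that depends linearly on `X`, so the network must learn a two-dimensional, multimodal predictive distribution. This is precisely the case served by the new diagonal-covariance head.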

src/layers.jl

Lines changed: 3 additions & 8 deletions

@@ -84,7 +84,7 @@ function MultivariateGMM(input::Int, output::Int, mixtures::Int)
 
     # Construct Output Layer
     μ = Flux.Dense(input=>(output * mixtures), init=init)
-    Σ = Flux.Dense(input=>(output * output * mixtures), init=init)
+    Σ = Flux.Dense(input=>(output * mixtures), exp, init=init)
     w = Flux.Chain(Flux.Dense(input=>mixtures, init=init), x -> Flux.softmax(x; dims=1))
 
     # Return Layer

@@ -98,16 +98,11 @@ end
 function (m::MultivariateGMM)(X::AbstractMatrix{Float64})
     # Forward Pass
     μ = reshape(m.μ(X), (m.mixtures, m.outputs, :))
-    Σ = reshape(m.Σ(X), (m.mixtures, m.outputs, m.outputs, :))
+    D = reshape(m.Σ(X), (m.mixtures, m.outputs, :))
     w = reshape(m.w(X), (m.mixtures, :))
 
-    # Get Cholesky Decomposition Of Σ
-    d_mask = [b == c ? 1.0 : 0.0 for a in 1:1, b in 1:m.outputs, c in 1:m.outputs, d in 1:1]
-    u_mask = [b < c ? 1.0 : 0.0 for a in 1:1, b in 1:m.outputs, c in 1:m.outputs, d in 1:1]
-    U = exp.(Σ .* d_mask) .+ (Σ .* u_mask)
-
     # Return Distributions
     return map(eachindex(w[1,:])) do obs
-        MixtureModel([MultivariateNormal(μ[mixture,:,obs], U[mixture,:,:,obs]' * U[mixture,:,:,obs] + 1e-9I) for mixture in eachindex(μ[:,1,1])], w[:,obs])
+        MixtureModel([MultivariateNormal(μ[mixture,:,obs], Diagonal(D[mixture,:,obs])) for mixture in eachindex(μ[:,1,1])], w[:,obs])
     end
 end
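
Putting the new forward pass together: the `exp` activation on the Σ head guarantees positive variances, and each observation receives a mixture of diagonal-covariance Gaussians with softmax weights. Below is a minimal standalone sketch that follows the shapes in the diff; the loose layer variables stand in for the actual `MultivariateGMM` struct, and the dimensions are made up:

```julia
using Flux, Distributions, LinearAlgebra

input, output, mixtures = 4, 2, 3
μ_layer = Flux.Dense(input => output * mixtures)
Σ_layer = Flux.Dense(input => output * mixtures, exp)  # exp keeps variances positive
w_layer = Flux.Chain(Flux.Dense(input => mixtures), x -> Flux.softmax(x; dims=1))

X = rand(Float32, input, 5)  # 5 observations

# Reshape flat layer outputs into (mixtures, outputs, observations)
μ = reshape(μ_layer(X), (mixtures, output, :))
D = reshape(Σ_layer(X), (mixtures, output, :))
w = reshape(w_layer(X), (mixtures, :))

# One diagonal-covariance GMM per observation, as in the new method body
dists = map(1:size(X, 2)) do obs
    MixtureModel([MultivariateNormal(μ[k, :, obs], Diagonal(D[k, :, obs])) for k in 1:mixtures], w[:, obs])
end
```

Compared with the deleted Cholesky-mask construction, this drops the `U' * U + 1e-9I` product and the comprehension-built masks, and it avoids materializing an `output × output` matrix for every mixture and observation.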

src/losses.jl

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ Compute the negative log-likelihood loss for a set of labels `y` under a set of
 
 # Parameters
 - `distributions`: A vector of multivariate Gaussian Mixture Model distributions.
-- `y`: A kxn matrix of labels where k is the dimension of each label and n is the number of samples.
+- `y`: A dxn matrix of labels where d is the dimension of each label and n is the number of samples.
 """
 function likelihood_loss(distributions::Vector{<:MixtureModel{Multivariate}}, y::Matrix{<:Real})
     return likelihood_loss(distributions, Float64.(y))
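
For reference, the documented contract (a d×n label matrix and one predicted mixture per sample) implies a loss of roughly the following shape. This is a hedged sketch, not necessarily the package's exact implementation, and `nll_sketch` is a hypothetical name:

```julia
using Distributions, Statistics

# Mean negative log-likelihood of d×n labels under n predicted mixtures
function nll_sketch(distributions::Vector{<:MixtureModel{Multivariate}}, y::Matrix{Float64})
    return -mean(logpdf(d, yᵢ) for (d, yᵢ) in zip(distributions, eachcol(y)))
end
```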

src/mlj_interface.jl

Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ MLJModelInterface.metadata_model(
     input_scitype=MMI.Table(MMI.Continuous),
     target_scitype=AbstractVector{<:MMI.Continuous},
     load_path="MixtureDensityNetworks.MDN",
-    human_name="MDN",
+    human_name="Mixture Density Network",
 )
 
 """

src/model.jl

Lines changed: 3 additions & 3 deletions

@@ -1,7 +1,7 @@
 """
 $(TYPEDEF)
 
-A custom Flux model whose predictions paramaterize a Gaussian Mixture Model.
+A Flux model for implementing a standard Mixture Density Network.
 
 # Parameters
 $(TYPEDFIELDS)

@@ -19,8 +19,8 @@ $(TYPEDSIGNATURES)
 Construct a standard Mixture Density Network.
 
 # Parameters
-- `input`: The length of the input feature vectors.
-- `output`: The length of the output feature vectors.
+- `input`: The dimension of the input features.
+- `output`: The dimension of the output. Setting output = 1 indicates a univariate model, whereas output > 1 indicates a multivariate model.
 - `layers`: The topology of the hidden layers, starting from the first layer.
 - `mixtures`: The number of Gaussian mixtures to use in the predicted distribution.
 """
