
Commit 49816c2

Updated Documentation
1 parent cc8f4a6 commit 49816c2

File tree

4 files changed: +82 -27 lines changed

- README.md
- docs/src/index.md
- examples/mlj_example.jl
- examples/native_example.jl


README.md

Lines changed: 67 additions & 9 deletions
````diff
@@ -5,13 +5,16 @@
 [![Build Status](https://github.com/JoshuaBillson/MixtureDensityNetworks.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JoshuaBillson/MixtureDensityNetworks.jl/actions/workflows/CI.yml?query=branch%3Amain)
 [![Coverage](https://codecov.io/gh/JoshuaBillson/MixtureDensityNetworks.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JoshuaBillson/MixtureDensityNetworks.jl)
 
-Mixture Density Networks (MDNs) were first proposed by [Bishop (1994)](https://publications.aston.ac.uk/id/eprint/373/1/NCRG_94_004.pdf). We can think of them as a specialized type of neural network, which are typically employed when our data has a lot of uncertainty or when the relationship between features and labels is one-to-many. Unlike a traditional neural network, which predicts a point-estimate equal to the mode of the learned conditional distribution P(Y|X), an MDN maintains the full condtional distribution by predicting the parameters of a Gaussian Mixture Model (GMM). The multi-modal nature of GMMs are precisely what makes MDNs so well-suited to modeling one-to-many relationships. This package aims to provide a simple interface for defining, training, and deploying MDNs.
+This package provides a simple interface for defining, training, and deploying Mixture Density Networks (MDNs). MDNs were first proposed by [Bishop (1994)](https://publications.aston.ac.uk/id/eprint/373/1/NCRG_94_004.pdf). We can think of an MDN as a specialized type of Artificial Neural Network (ANN), which takes some features `X` and returns a distribution over the labels `Y` under a Gaussian Mixture Model (GMM). Unlike an ANN, MDNs maintain the full conditional distribution P(Y|X). This makes them particularly well-suited for situations where we want to maintain some measure of the uncertainty in our predictions. Moreover, because GMMs can represent multimodal distributions, MDNs are capable of modelling one-to-many relationships, which occur when each input `X` can be associated with more than one output `Y`.
 
-# Example
+![](https://github.com/JoshuaBillson/MixtureDensityNetworks.jl/blob/main/docs/src/figures/PredictedDistribution.png?raw=true)
 
-Below is an example of fitting an MDN to the visualized one-to-many distribution.
+# MLJ Compatibility
 
-![](https://github.com/JoshuaBillson/MixtureDensityNetworks.jl/blob/main/docs/src/figures/PredictedDistribution.png?raw=true)
+This package implements the interface specified by [MLJModelInterface](https://github.com/JuliaAI/MLJModelInterface.jl) and is thus fully compatible
+with the MLJ ecosystem. Below is an example demonstrating the use of this package in conjunction with MLJ.
+
+# Example (Native Interface)
 
 ```julia
 using MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
@@ -21,16 +24,15 @@ const epochs = 1000
 const mixtures = 6
 const layers = [128, 128]
 
-
 function main()
     # Generate Data
     X, Y = generate_data(n_samples)
 
     # Create Model
-    machine = MDN(epochs=epochs, mixtures=mixtures, layers=layers) |> Machine
+    machine = MixtureDensityNetworks.Machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers))
 
     # Fit Model
-    report = with_logger(ConsoleLogger()) do
+    report = with_logger(TerminalLogger()) do
         fit!(machine, X, Y)
     end
 
@@ -50,9 +52,65 @@ function main()
     fig = Figure(resolution=(1000, 500))
     density(fig[1,1], rand(cond, 10000), npoints=10000)
     save("ConditionalDistribution.png", fig)
-
-    return machine
 end
 
 main()
 ```
+
+# Example (MLJ Interface)
+
+```julia
+using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ
+
+const n_samples = 1000
+const epochs = 1000
+const mixtures = 6
+const layers = [128, 128]
+
+function main()
+    # Generate Data
+    X, Y = generate_data(n_samples)
+
+    # Create Model
+    mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers), MLJ.table(X'), Y[1,:])
+
+    # Evaluate Model
+    with_logger(TerminalLogger()) do
+        @info "Evaluating..."
+        evaluation = MLJ.evaluate!(
+            mach,
+            resampling=Holdout(shuffle=true),
+            measure=[rsq, rmse, mae, mape],
+            operation=MLJ.predict_mean
+        )
+        names = ["R²", "RMSE", "MAE", "MAPE"]
+        metrics = round.(evaluation.measurement, digits=3)
+        @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
+    end
+
+    # Fit Model
+    with_logger(TerminalLogger()) do
+        @info "Training..."
+        MLJ.fit!(mach)
+    end
+
+    # Plot Learning Curve
+    fig, _, _ = lines(1:epochs, MLJ.training_losses(mach), axis=(;xlabel="Epochs", ylabel="Loss"))
+    save("LearningCurve.png", fig)
+
+    # Plot Learned Distribution
+    Ŷ = MLJ.predict(mach) .|> rand
+    fig, ax, plt = scatter(X[1,:], Ŷ, markersize=4, label="Predicted Distribution")
+    scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
+    axislegend(ax, position=:lt)
+    save("PredictedDistribution.png", fig)
+
+    # Plot Conditional Distribution
+    cond = MLJ.predict(mach, MLJ.table(reshape([-2.1], (1,1))))[1]
+    fig = Figure(resolution=(1000, 500))
+    density(fig[1,1], rand(cond, 10000), npoints=10000)
+    save("ConditionalDistribution.png", fig)
+end
+
+main()
+```
````
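
A note on the prediction type: the calls `rand(cond, 10000)`, `rand.(Ŷ)`, and `MLJ.predict_mean` in the examples above suggest that each prediction behaves like a Distributions.jl mixture. Below is a minimal sketch of working with such an object; the weights, means, and standard deviations are made-up stand-ins for network outputs, not values produced by this package.

```julia
using Distributions

# Hypothetical outputs of the network for a single input x (illustrative values only)
weights = [0.3, 0.5, 0.2]                      # mixing coefficients; must sum to 1
μs, σs  = [-2.0, 0.5, 3.0], [0.4, 1.0, 0.7]    # component means and standard deviations

# The conditional distribution P(Y|X = x) as a Gaussian mixture
cond = MixtureModel(Normal.(μs, σs), weights)

rand(cond, 10000)   # draw samples, as the plotting code above does
mean(cond)          # a point estimate, analogous to MLJ.predict_mean
logpdf(cond, 0.0)   # log-density at y = 0; MDN training typically minimizes its negative
```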

docs/src/index.md

Lines changed: 11 additions & 11 deletions
````diff
@@ -4,7 +4,7 @@ CurrentModule = MixtureDensityNetworks
 
 # MixtureDensityNetworks
 
-Mixture Density Networks (MDNs) were first proposed by [Bishop (1994)](https://publications.aston.ac.uk/id/eprint/373/1/NCRG_94_004.pdf). We can think of them as a specialized type of neural network, which are typically employed when our data has a lot of uncertainty or when the relationship between features and labels is one-to-many. Unlike a traditional neural network, which predicts a point-estimate equal to the mode of the learned conditional distribution P(Y|X), an MDN maintains the full condtional distribution by predicting the parameters of a Gaussian Mixture Model (GMM). The multi-modal nature of GMMs are precisely what makes MDNs so well-suited to modeling one-to-many relationships. This package aims to provide a simple interface for defining, training, and deploying MDNs.
+This package provides a simple interface for defining, training, and deploying Mixture Density Networks (MDNs). MDNs were first proposed by [Bishop (1994)](https://publications.aston.ac.uk/id/eprint/373/1/NCRG_94_004.pdf). We can think of an MDN as a specialized type of Artificial Neural Network (ANN), which takes some features `X` and returns a distribution over the labels `Y` under a Gaussian Mixture Model (GMM). Unlike an ANN, MDNs maintain the full conditional distribution P(Y|X). This makes them particularly well-suited for situations where we want to maintain some measure of the uncertainty in our predictions. Moreover, because GMMs can represent multimodal distributions, MDNs are capable of modelling one-to-many relationships, which occur when each input `X` can be associated with more than one output `Y`.
 
 # Example
 
@@ -14,9 +14,7 @@ using Distributions, CairoMakie, MixtureDensityNetworks
 
 const n_samples = 1000
 
-Y = rand(Uniform(-10.5, 10.5), 1, n_samples)
-μ = 7sin.(0.75 .* Y) + 0.5 .* Y
-X = rand.(Normal.(μ, 1.0))
+X, Y = generate_data(n_samples)
 
 fig, ax, plt = scatter(X[1,:], Y[1,:], markersize=5)
 ```
@@ -59,36 +57,38 @@ density(fig[1,1], rand(cond, 10000), npoints=10000)
 
 Below is a script for running the complete example.
 ```julia
-using MixtureDensityNetworks, Distributions, CairoMakie
+using MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
 
 const n_samples = 1000
 const epochs = 1000
-const mixtures = 5
+const mixtures = 6
 const layers = [128, 128]
 
 function main()
     # Generate Data
     X, Y = generate_data(n_samples)
 
     # Create Model
-    model = MDN(epochs=epochs, mixtures=mixtures, layers=layers)
+    machine = MixtureDensityNetworks.Machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers))
 
     # Fit Model
-    lc = fit!(model, X, Y)
+    report = with_logger(TerminalLogger()) do
+        fit!(machine, X, Y)
+    end
 
     # Plot Learning Curve
-    fig, _, _ = lines(1:epochs, lc, axis=(;xlabel="Epochs", ylabel="Loss"))
+    fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))
     save("LearningCurve.png", fig)
 
     # Plot Learned Distribution
-    Ŷ = predict(model, X)
+    Ŷ = predict(machine, X)
     fig, ax, plt = scatter(X[1,:], rand.(Ŷ), markersize=4, label="Predicted Distribution")
     scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
    axislegend(ax, position=:lt)
     save("PredictedDistribution.png", fig)
 
     # Plot Conditional Distribution
-    cond = predict(model, reshape([-2.0], (1,1)))[1]
+    cond = predict(machine, reshape([-2.0], (1,1)))[1]
     fig = Figure(resolution=(1000, 500))
     density(fig[1,1], rand(cond, 10000), npoints=10000)
     save("ConditionalDistribution.png", fig)
````

examples/mlj_example.jl

Lines changed: 2 additions & 2 deletions
````diff
@@ -48,6 +48,6 @@ function main()
     fig = Figure(resolution=(1000, 500))
     density(fig[1,1], rand(cond, 10000), npoints=10000)
     save("ConditionalDistribution.png", fig)
+end
 
-    return mach
-end
+main()
````
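
The layout conversion in the MLJ example (`MLJ.table(X')` and `Y[1,:]`) is easy to miss when adapting this script: the native interface uses features-by-samples matrices, while MLJ machines expect a Tables.jl-compatible table of rows plus a target vector. A small sketch of that conversion, assuming the same 1 × n layout that `generate_data` produces:

```julia
using MLJ

X = rand(1, 5)           # native layout: 1 feature × 5 samples
Y = rand(1, 5)           # native layout for targets

Xtable = MLJ.table(X')   # transpose to samples × features, then wrap as a table
y = Y[1, :]              # targets as a plain vector, one entry per sample
```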

examples/native_example.jl

Lines changed: 2 additions & 5 deletions
````diff
@@ -5,13 +5,12 @@ const epochs = 1000
 const mixtures = 6
 const layers = [128, 128]
 
-
 function main()
     # Generate Data
     X, Y = generate_data(n_samples)
 
     # Create Model
-    machine = MDN(epochs=epochs, mixtures=mixtures, layers=layers) |> Machine
+    machine = MixtureDensityNetworks.Machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers))
 
     # Fit Model
     report = with_logger(TerminalLogger()) do
@@ -34,8 +33,6 @@ function main()
     fig = Figure(resolution=(1000, 500))
     density(fig[1,1], rand(cond, 10000), npoints=10000)
     save("ConditionalDistribution.png", fig)
-
-    return machine
 end
 
-#main()
+main()
````
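
Finally, a hedged sketch of the loss that `fit!` presumably minimizes, since this is the standard MDN objective going back to Bishop (1994): the mean negative log-likelihood of each label under its predicted mixture. The function below is illustrative only and is not an API of this package.

```julia
using Distributions

# Mean negative log-likelihood of labels y under per-sample predicted mixtures.
# An illustrative stand-in for the internal training loss, not a package API.
nll(dists::AbstractVector{<:UnivariateDistribution}, y::AbstractVector{<:Real}) =
    -mean(logpdf.(dists, y))

nll([Normal(0, 1), Normal(1, 1)], [0.1, 0.9])  # example: loss for two predictions
```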
