Skip to content

Commit 654cc98

Browse files
Merge pull request #6 from JoshuaBillson/4-add-support-for-mlj
Add Support for MLJ
2 parents 9b2d7e9 + 9969013 commit 654cc98

19 files changed

+649
-231
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
/Manifest.toml
55
/docs/build/
66
examples/*.png
7+
test/Manifest.toml

Project.toml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.0.2"
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
99
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
10+
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
1011
Pipe = "b98c9c47-44ae-5843-9183-064241ee97a0"
1112
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1213
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -18,9 +19,3 @@ Flux = "0.13"
1819
Pipe = "1.3"
1920
ProgressLogging = "0.1"
2021
julia = "1.6"
21-
22-
[extras]
23-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
24-
25-
[targets]
26-
test = ["Test"]

README.md

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,44 @@ Below is an example of fitting an MDN to the visualized one-to-many distribution
1414
![](https://github.com/JoshuaBillson/MixtureDensityNetworks.jl/blob/main/docs/src/figures/PredictedDistribution.png?raw=true)
1515

1616
```julia
17-
using MixtureDensityNetworks, Distributions, CairoMakie
17+
using MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
1818

1919
const n_samples = 1000
2020
const epochs = 1000
21-
const mixtures = 5
21+
const mixtures = 6
2222
const layers = [128, 128]
2323

24+
2425
function main()
2526
# Generate Data
26-
Y = rand(Uniform(-10.5, 10.5), 1, n_samples)
27-
μ = 7sin.(0.75 .* Y) + 0.5 .* Y
28-
X = rand.(Normal.(μ, 1.0))
27+
X, Y = generate_data(n_samples)
2928

3029
# Create Model
31-
model = MDN(epochs=epochs, mixtures=mixtures, layers=layers)
30+
machine = MDN(epochs=epochs, mixtures=mixtures, layers=layers) |> Machine
3231

3332
# Fit Model
34-
lc = fit!(model, X, Y)
33+
report = with_logger(ConsoleLogger()) do
34+
fit!(machine, X, Y)
35+
end
3536

3637
# Plot Learning Curve
37-
fig, _, _ = lines(1:epochs, lc, axis=(;xlabel="Epochs", ylabel="Loss"))
38+
fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))
3839
save("LearningCurve.png", fig)
3940

4041
# Plot Learned Distribution
41-
Ŷ = predict(model, X)
42+
Ŷ = predict(machine, X)
4243
fig, ax, plt = scatter(X[1,:], rand.(Ŷ), markersize=4, label="Predicted Distribution")
4344
scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
4445
axislegend(ax, position=:lt)
4546
save("PredictedDistribution.png", fig)
4647

4748
# Plot Conditional Distribution
48-
cond = predict(model, reshape([-2.0], (1,1)))[1]
49+
cond = predict(machine, reshape([-2.0], (1,1)))[1]
4950
fig = Figure(resolution=(1000, 500))
5051
density(fig[1,1], rand(cond, 10000), npoints=10000)
5152
save("ConditionalDistribution.png", fig)
53+
54+
return machine
5255
end
5356

5457
main()

docs/Manifest.toml

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,12 @@ git-tree-sha1 = "2ce8695e1e699b68702c03402672a69f54b8aca9"
816816
uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
817817
version = "2022.2.0+0"
818818

819+
[[deps.MLJModelInterface]]
820+
deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
821+
git-tree-sha1 = "c8b7e632d6754a5e36c0d94a4b466a5ba3a30128"
822+
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
823+
version = "1.8.0"
824+
819825
[[deps.MLStyle]]
820826
git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
821827
uuid = "d8e11817-5142-5d16-987a-aa16d5891078"
@@ -888,10 +894,10 @@ uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
888894
version = "1.1.0"
889895

890896
[[deps.MixtureDensityNetworks]]
891-
deps = ["Distributions", "Flux", "Pipe", "ProgressLogging"]
897+
deps = ["Distributions", "DocStringExtensions", "Flux", "MLJModelInterface", "Pipe", "ProgressLogging", "Statistics"]
892898
path = ".."
893899
uuid = "521d8788-cab4-41cb-a05a-da376f16ad79"
894-
version = "1.0.0-DEV"
900+
version = "0.0.2"
895901

896902
[[deps.Mmap]]
897903
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
@@ -1214,6 +1220,11 @@ git-tree-sha1 = "2436b15f376005e8790e318329560dcc67188e84"
12141220
uuid = "7b38b023-a4d7-4c5e-8d43-3f3097f304eb"
12151221
version = "0.3.3"
12161222

1223+
[[deps.ScientificTypesBase]]
1224+
git-tree-sha1 = "a8e18eb383b5ecf1b5e6fc237eb39255044fd92b"
1225+
uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
1226+
version = "3.0.0"
1227+
12171228
[[deps.Scratch]]
12181229
deps = ["Dates"]
12191230
git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a"
@@ -1316,6 +1327,12 @@ git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
13161327
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
13171328
version = "1.4.0"
13181329

1330+
[[deps.StatisticalTraits]]
1331+
deps = ["ScientificTypesBase"]
1332+
git-tree-sha1 = "30b9236691858e13f167ce829490a68e1a597782"
1333+
uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9"
1334+
version = "3.2.0"
1335+
13191336
[[deps.Statistics]]
13201337
deps = ["LinearAlgebra", "SparseArrays"]
13211338
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

docs/make.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ makedocs(;
1717
assets=String[],
1818
),
1919
pages=[
20-
"Home" => "index.md",
20+
"Introduction" => "index.md",
21+
"MLJ Compatibility" => "mlj.md",
22+
"API (Reference Manual)" => "reference.md",
2123
],
2224
)
2325

-2.48 KB
Loading

docs/src/index.md

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ CurrentModule = MixtureDensityNetworks
66

77
Mixture Density Networks (MDNs) were first proposed by [Bishop (1994)](https://publications.aston.ac.uk/id/eprint/373/1/NCRG_94_004.pdf). We can think of them as a specialized type of neural network, which are typically employed when our data has a lot of uncertainty or when the relationship between features and labels is one-to-many. Unlike a traditional neural network, which predicts a point-estimate equal to the mode of the learned conditional distribution P(Y|X), an MDN maintains the full conditional distribution by predicting the parameters of a Gaussian Mixture Model (GMM). The multi-modal nature of GMMs is precisely what makes MDNs so well-suited to modeling one-to-many relationships. This package aims to provide a simple interface for defining, training, and deploying MDNs.
88

9-
## Example
9+
# Example
1010

1111
First, let's create our dataset. To properly demonstrate the power of MDNs, we'll generate a one-to-many dataset where each x-value can map to more than one y-value.
1212
```julia
@@ -68,9 +68,7 @@ const layers = [128, 128]
6868

6969
function main()
7070
# Generate Data
71-
Y = rand(Uniform(-10.5, 10.5), 1, n_samples)
72-
μ = 7sin.(0.75 .* Y) + 0.5 .* Y
73-
X = rand.(Normal.(μ, 1.0))
71+
X, Y = generate_data(n_samples)
7472

7573
# Create Model
7674
model = MDN(epochs=epochs, mixtures=mixtures, layers=layers)
@@ -97,16 +95,4 @@ function main()
9795
end
9896

9997
main()
100-
```
101-
102-
## Index
103-
104-
```@index
105-
```
106-
107-
## API
108-
109-
```@autodocs
110-
Modules = [MixtureDensityNetworks]
111-
Private = false
11298
```

docs/src/mlj.md

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
```@meta
2+
CurrentModule = MixtureDensityNetworks
3+
```
4+
5+
# MLJ Compatibility
6+
7+
This package implements the interface specified by [MLJModelInterface](https://github.com/JuliaAI/MLJModelInterface.jl) and is thus fully compatible
8+
with the MLJ ecosystem. Below is an example demonstrating the use of this package in conjunction with MLJ.
9+
10+
```julia
11+
using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ
12+
13+
const n_samples = 1000
14+
const epochs = 1000
15+
const mixtures = 6
16+
const layers = [128, 128]
17+
18+
function main()
19+
# Generate Data
20+
X, Y = generate_data(n_samples)
21+
22+
# Create Model
23+
mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers), MLJ.table(X'), Y[1,:])
24+
25+
# Evaluate Model
26+
with_logger(TerminalLogger()) do
27+
@info "Evaluating..."
28+
evaluation = MLJ.evaluate!(
29+
mach,
30+
resampling=Holdout(shuffle=true),
31+
measure=[rsq, rmse, mae, mape],
32+
operation=MLJ.predict_mean
33+
)
34+
names = ["R²", "RMSE", "MAE", "MAPE"]
35+
metrics = round.(evaluation.measurement, digits=3)
36+
@info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
37+
end
38+
39+
# Fit Model
40+
with_logger(TerminalLogger()) do
41+
@info "Training..."
42+
MLJ.fit!(mach)
43+
end
44+
45+
# Plot Learning Curve
46+
fig, _, _ = lines(1:epochs, MLJ.training_losses(mach), axis=(;xlabel="Epochs", ylabel="Loss"))
47+
save("LearningCurve.png", fig)
48+
49+
# Plot Learned Distribution
50+
Ŷ = MLJ.predict(mach) .|> rand
51+
fig, ax, plt = scatter(X[1,:], Ŷ, markersize=4, label="Predicted Distribution")
52+
scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
53+
axislegend(ax, position=:lt)
54+
save("PredictedDistribution.png", fig)
55+
56+
# Plot Conditional Distribution
57+
cond = MLJ.predict(mach, MLJ.table(reshape([-2.1], (1,1))))[1]
58+
fig = Figure(resolution=(1000, 500))
59+
density(fig[1,1], rand(cond, 10000), npoints=10000)
60+
save("ConditionalDistribution.png", fig)
61+
end
62+
63+
main()
64+
```

docs/src/reference.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
```@meta
2+
CurrentModule = MixtureDensityNetworks
3+
```
4+
# Index
5+
6+
```@index
7+
```
8+
9+
# API
10+
11+
```@autodocs
12+
Modules = [MixtureDensityNetworks]
13+
Private = false
14+
```

examples/mlj_example.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ
2+
3+
const n_samples = 1000
4+
const epochs = 1000
5+
const mixtures = 6
6+
const layers = [128, 128]
7+
8+
function main()
9+
# Generate Data
10+
X, Y = generate_data(n_samples)
11+
12+
# Create Model
13+
mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers), MLJ.table(X'), Y[1,:])
14+
15+
# Evaluate Model
16+
with_logger(TerminalLogger()) do
17+
@info "Evaluating..."
18+
evaluation = MLJ.evaluate!(
19+
mach,
20+
resampling=Holdout(shuffle=true),
21+
measure=[rsq, rmse, mae, mape],
22+
operation=MLJ.predict_mean
23+
)
24+
names = ["R²", "RMSE", "MAE", "MAPE"]
25+
metrics = round.(evaluation.measurement, digits=3)
26+
@info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
27+
end
28+
29+
# Fit Model
30+
with_logger(TerminalLogger()) do
31+
@info "Training..."
32+
MLJ.fit!(mach)
33+
end
34+
35+
# Plot Learning Curve
36+
fig, _, _ = lines(1:epochs, MLJ.training_losses(mach), axis=(;xlabel="Epochs", ylabel="Loss"))
37+
save("LearningCurve.png", fig)
38+
39+
# Plot Learned Distribution
40+
Ŷ = MLJ.predict(mach) .|> rand
41+
fig, ax, plt = scatter(X[1,:], Ŷ, markersize=4, label="Predicted Distribution")
42+
scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
43+
axislegend(ax, position=:lt)
44+
save("PredictedDistribution.png", fig)
45+
46+
# Plot Conditional Distribution
47+
cond = MLJ.predict(mach, MLJ.table(reshape([-2.1], (1,1))))[1]
48+
fig = Figure(resolution=(1000, 500))
49+
density(fig[1,1], rand(cond, 10000), npoints=10000)
50+
save("ConditionalDistribution.png", fig)
51+
52+
return mach
53+
end

0 commit comments

Comments
 (0)