
Commit 2905f55

Added support for Flux 0.14. Progress bars now display appropriately in both the terminal and Pluto. Moved to Julia 1.9. Released version 0.3.0.
1 parent 24cc037 commit 2905f55
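
The progress-bar change shows up in the diff as two new direct dependencies, Logging and TerminalLoggers, plus the removal of `with_logger(TerminalLogger())` wrappers from every example. How the package selects a logger internally is not shown in this diff, so the following is only a plausible sketch of that logic, using the standard Logging, TerminalLoggers, and ProgressLogging APIs:

```julia
using Logging, TerminalLoggers

# Hypothetical helper (not from this commit): upgrade the plain REPL
# ConsoleLogger to a TerminalLogger, which renders @progress events as
# progress bars; leave any other logger (e.g. Pluto's, which already
# renders progress bars) untouched.
progress_logger() = current_logger() isa ConsoleLogger ? TerminalLogger() : current_logger()

# fit! could then wrap its training loop like so:
# with_logger(progress_logger()) do
#     @progress for epoch in 1:epochs
#         ...
#     end
# end
```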

File tree

12 files changed (+86, −81 lines)


.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.7'
+          - '1.9'
         os:
           - ubuntu-latest
         arch:

Project.toml

Lines changed: 6 additions & 3 deletions
@@ -1,23 +1,26 @@
 name = "MixtureDensityNetworks"
 uuid = "521d8788-cab4-41cb-a05a-da376f16ad79"
 authors = ["Joshua Billson"]
-version = "0.2.2"
+version = "0.3.0"
 
 [deps]
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 Pipe = "b98c9c47-44ae-5843-9183-064241ee97a0"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
 
 [compat]
 Distributions = "0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
 DocStringExtensions = "0.8, 0.9"
-Flux = "0.13"
+Flux = "0.14"
 MLJModelInterface = "1"
 Pipe = "1.3"
 ProgressLogging = "0.1"
-julia = "1.6"
+TerminalLoggers = "0.1"
+julia = "1.9"

README.md

Lines changed: 15 additions & 20 deletions
@@ -33,9 +33,7 @@ function main()
     model = MixtureDensityNetwork(1, 1, layers, mixtures)
 
     # Fit Model
-    model, report = with_logger(TerminalLogger()) do
-        MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)
-    end
+    model, report = MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)
 
     # Plot Learning Curve
     fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))

@@ -61,7 +59,7 @@ main()
 # Example (MLJ Interface)
 
 ```julia
-using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ, Random
+using MixtureDensityNetworks, Distributions, CairoMakie, MLJ
 
 const n_samples = 1000
 const epochs = 500

@@ -77,24 +75,21 @@ function main()
     mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers, batchsize=batchsize), MLJ.table(X'), Y[1,:])
 
     # Fit Model on Training Data, Then Evaluate on Test
-    with_logger(TerminalLogger()) do
-        @info "Evaluating..."
-        evaluation = MLJ.evaluate!(
-            mach,
-            resampling=Holdout(shuffle=true),
-            measure=[rsq, rmse, mae, mape],
-            operation=MLJ.predict_mean
-        )
-        names = ["R²", "RMSE", "MAE", "MAPE"]
-        metrics = round.(evaluation.measurement, digits=3)
-        @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
-    end
+    @info "Evaluating..."
+    evaluation = MLJ.evaluate!(
+        mach,
+        resampling=Holdout(shuffle=true),
+        measure=[rsq, rmse, mae, mape],
+        operation=MLJ.predict_mean,
+        verbosity=2  # Need to set verbosity=2 to show training progress during evaluation
+    )
+    names = ["R²", "RMSE", "MAE", "MAPE"]
+    metrics = round.(evaluation.measurement, digits=3)
+    @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
 
     # Fit Model on Entire Dataset
-    with_logger(TerminalLogger()) do
-        @info "Training..."
-        MLJ.fit!(mach)
-    end
+    @info "Training..."
+    MLJ.fit!(mach)
 
     # Plot Learning Curve
     fig, _, _ = lines(1:epochs, MLJ.training_losses(mach), axis=(;xlabel="Epochs", ylabel="Loss"))

docs/src/index.md

Lines changed: 9 additions & 7 deletions
@@ -13,6 +13,10 @@ First, let's create our dataset. To properly demonstrate the power of MDNs, we'l
 using Flux, Distributions, CairoMakie, MixtureDensityNetworks
 
 const n_samples = 1000
+const epochs = 1000
+const batchsize = 128
+const mixtures = 8
+const layers = [128, 128]
 
 X, Y = generate_data(n_samples)

@@ -24,16 +28,16 @@ fig, ax, plt = scatter(X[1,:], Y[1,:], markersize=5)
 Now we'll define a standard univariate MDN. For this example, we construct a network with 2 hidden layers of size 128, which outputs a distribution
 with 5 Gaussian mixtures.
 ```julia
-model = MixtureDensityNetwork(1, 1, [128, 128], 5)
+model = MixtureDensityNetwork(1, 1, layers, mixtures)
 ```
 
 We can fit our model to our data by calling `fit!(m, X, Y; opt=Flux.Adam(), batchsize=32, epochs=100)`. We specify that we want to train our model for
 500 epochs with the Adam optimiser and a batch size of 128. This method returns the model with the lowest loss as its first value and a named tuple
 containing the learning curve, best epoch, and lowest loss observed during training as its second value. We can use Makie's `lines` method to visualize
 the learning curve.
 ```julia
-model, report = MixtureDensityNetworks.fit!(model, X, Y; epochs=500, opt=Flux.Adam(1e-3), batchsize=128)
-fig, _, _ = lines(1:500, lc, axis=(;xlabel="Epochs", ylabel="Loss"))
+model, report = MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)
+fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))
 ```
 
 ![](figures/LearningCurve.png)

@@ -60,7 +64,7 @@ density(fig[1,1], rand(cond, 10000), npoints=10000)
 
 Below is a script for running the complete example.
 ```julia
-using Flux, MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
+using Flux, MixtureDensityNetworks, Distributions, CairoMakie
 
 const n_samples = 1000
 const epochs = 1000

@@ -76,9 +80,7 @@ function main()
     model = MixtureDensityNetwork(1, 1, layers, mixtures)
 
     # Fit Model
-    model, report = with_logger(TerminalLogger()) do
-        MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)
-    end
+    model, report = MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)
 
     # Plot Learning Curve
     fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))
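
The prose in this file says `fit!` returns the best model and a named tuple containing the learning curve, best epoch, and lowest observed loss. As a usage sketch, the report's `learning_curve` and `best_epoch` fields (both visible elsewhere in this diff) recover all three quantities; any other field name would be an assumption:

```julia
using Flux, MixtureDensityNetworks

X, Y = generate_data(1000)
model = MixtureDensityNetwork(1, 1, [128, 128], 8)
model, report = MixtureDensityNetworks.fit!(model, X, Y; epochs=1000, opt=Flux.Adam(1e-3), batchsize=128)

report.learning_curve                     # per-epoch loss history
report.best_epoch                         # epoch at which the lowest loss occurred
report.learning_curve[report.best_epoch]  # the lowest loss itself
```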

docs/src/mlj.md

Lines changed: 14 additions & 17 deletions
@@ -8,7 +8,7 @@ This package implements the interface specified by [MLJModelInterface](https://g
 with the MLJ ecosystem. Below is an example demonstrating the use of this package in conjunction with MLJ.
 
 ```julia
-using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ, Random
+using MixtureDensityNetworks, Distributions, CairoMakie, MLJ
 
 const n_samples = 1000
 const epochs = 500

@@ -24,24 +24,21 @@ function main()
     mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers, batchsize=batchsize), MLJ.table(X'), Y[1,:])
 
     # Fit Model on Training Data, Then Evaluate on Test
-    with_logger(TerminalLogger()) do
-        @info "Evaluating..."
-        evaluation = MLJ.evaluate!(
-            mach,
-            resampling=Holdout(shuffle=true),
-            measure=[rsq, rmse, mae, mape],
-            operation=MLJ.predict_mean
-        )
-        names = ["R²", "RMSE", "MAE", "MAPE"]
-        metrics = round.(evaluation.measurement, digits=3)
-        @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
-    end
+    @info "Evaluating..."
+    evaluation = MLJ.evaluate!(
+        mach,
+        resampling=Holdout(shuffle=true),
+        measure=[rsq, rmse, mae, mape],
+        operation=MLJ.predict_mean,
+        verbosity=2  # Need to set verbosity=2 to show training progress during evaluation
+    )
+    names = ["R²", "RMSE", "MAE", "MAPE"]
+    metrics = round.(evaluation.measurement, digits=3)
+    @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
 
     # Fit Model on Entire Dataset
-    with_logger(TerminalLogger()) do
-        @info "Training..."
-        MLJ.fit!(mach)
-    end
+    @info "Training..."
+    MLJ.fit!(mach)
 
     # Plot Learning Curve
     fig, _, _ = lines(1:epochs, MLJ.training_losses(mach), axis=(;xlabel="Epochs", ylabel="Loss"))

examples/mlj_example.jl

Lines changed: 14 additions & 17 deletions
@@ -1,4 +1,4 @@
-using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ, Random
+using MixtureDensityNetworks, Distributions, CairoMakie, MLJ
 
 const n_samples = 1000
 const epochs = 500

@@ -14,24 +14,21 @@ function main()
     mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers, batchsize=batchsize), MLJ.table(X'), Y[1,:])
 
     # Fit Model on Training Data, Then Evaluate on Test
-    with_logger(TerminalLogger()) do
-        @info "Evaluating..."
-        evaluation = MLJ.evaluate!(
-            mach,
-            resampling=Holdout(shuffle=true),
-            measure=[rsq, rmse, mae, mape],
-            operation=MLJ.predict_mean
-        )
-        names = ["R²", "RMSE", "MAE", "MAPE"]
-        metrics = round.(evaluation.measurement, digits=3)
-        @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
-    end
+    @info "Evaluating..."
+    evaluation = MLJ.evaluate!(
+        mach,
+        resampling=Holdout(shuffle=true),
+        measure=[rsq, rmse, mae, mape],
+        operation=MLJ.predict_mean,
+        verbosity=2  # Need to set verbosity=2 to show training progress during evaluation
+    )
+    names = ["R²", "RMSE", "MAE", "MAPE"]
+    metrics = round.(evaluation.measurement, digits=3)
+    @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
 
     # Fit Model on Entire Dataset
-    with_logger(TerminalLogger()) do
-        @info "Training..."
-        MLJ.fit!(mach)
-    end
+    @info "Training..."
+    MLJ.fit!(mach)
 
     # Plot Learning Curve
     fig, _, _ = lines(1:epochs, MLJ.training_losses(mach), axis=(;xlabel="Epochs", ylabel="Loss"))

examples/native_example.jl

Lines changed: 1 addition & 3 deletions
@@ -14,9 +14,7 @@ function main()
     model = MixtureDensityNetwork(1, 1, layers, mixtures)
 
     # Fit Model
-    model, report = with_logger(TerminalLogger()) do
-        MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)
-    end
+    model, report = MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)
 
     # Plot Learning Curve
     fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))

src/MixtureDensityNetworks.jl

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@ using LinearAlgebra
 using ProgressLogging
 using MLJModelInterface
 using DocStringExtensions
+using Logging
+using TerminalLoggers
 using Pipe: @pipe
 
 const MMI = MLJModelInterface

src/mlj_interface.jl

Lines changed: 2 additions & 2 deletions
@@ -65,14 +65,14 @@ end
 
 function MLJModelInterface.fit(model::MDN, verbosity, X, y)
     m = MixtureDensityNetwork(size(X, 1), 1, model.layers, model.mixtures)
-    fitresult, report = MixtureDensityNetworks.fit!(m, X, y, opt=Flux.Adam(model.η), batchsize=model.batchsize, epochs=model.epochs)
+    fitresult, report = MixtureDensityNetworks.fit!(m, X, y, opt=Flux.Adam(model.η), batchsize=model.batchsize, epochs=model.epochs, verbosity=verbosity)
     cache = (;learning_curve=report.learning_curve[1:report.best_epoch])
     return fitresult, cache, report
 end
 
 function MLJModelInterface.update(model::MDN, verbosity, old_fitresult, old_cache, X, y)
     # Update Fitresult
-    fitresult, report = MixtureDensityNetworks.fit!(old_fitresult, X, y, opt=Flux.Adam(model.η), batchsize=model.batchsize, epochs=model.epochs)
+    fitresult, report = MixtureDensityNetworks.fit!(old_fitresult, X, y, opt=Flux.Adam(model.η), batchsize=model.batchsize, epochs=model.epochs, verbosity=verbosity)
 
     # Update Report
     learning_curve=vcat(old_cache.learning_curve, report.learning_curve)
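
Threading `verbosity` from MLJ's `fit`/`update` into the native `fit!` is what lets users silence or surface training progress through MLJ's standard keyword. A brief usage sketch (standard MLJ API; `MDN(epochs=500)` relying on default hyperparameters is an assumption, and the verbosity behavior is taken from the comment in the diff above):

```julia
using MLJ, MixtureDensityNetworks

X, Y = generate_data(1000)
mach = machine(MDN(epochs=500), MLJ.table(X'), Y[1,:])

fit!(mach, verbosity=0)  # train silently
fit!(mach, verbosity=1)  # default: report fitting and show training progress

# Per the diff's comment, evaluate! needs verbosity=2 before per-fold
# training progress appears (the resampling loop consumes one level).
evaluate!(mach, resampling=Holdout(shuffle=true), measure=rmse,
          operation=MLJ.predict_mean, verbosity=2)
```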

src/model.jl

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ function MixtureDensityNetwork(input::Int, output::Int, layers::Vector{Int}, mix
     layers = vcat([input], layers)
     for (dim_in, dim_out) in zip(layers, layers[2:end])
         push!(hidden, Flux.Dense(dim_in=>dim_out, init=init))
-        push!(hidden, Flux.BatchNorm(dim_out, Flux.relu, initβ=zeros, initγ=ones, ϵ=1e-5, momentum=0.1))
+        push!(hidden, Flux.BatchNorm(dim_out, Flux.relu, initβ=zeros, initγ=ones, eps=1e-5, momentum=0.1))
     end
     hidden_layer = Flux.Chain(hidden...)
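
This is the substantive Flux 0.14 compatibility fix: `BatchNorm`'s numerical-stability constant is now passed as the ASCII keyword `eps`, and the Unicode `ϵ` keyword from Flux 0.13 is no longer accepted. A minimal standalone check, assuming only the public Flux 0.14 API:

```julia
using Flux

# Flux 0.14 spells the stability constant `eps` (previously `ϵ`).
bn = Flux.BatchNorm(128, Flux.relu; initβ=zeros, initγ=ones, eps=1e-5, momentum=0.1)

x = randn(Float32, 128, 32)  # (features, batch)
y = bn(x)                    # batch-normalizes, then applies relu
size(y)                      # (128, 32)
```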
