@@ -17,38 +17,39 @@ with the MLJ ecosystem. Below is an example demonstrating the use of this packag
1717# Example (Native Interface)
1818
1919``` julia
20- using MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
20+ using Flux, MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
2121
2222const n_samples = 1000
2323const epochs = 1000
24- const mixtures = 6
24+ const batchsize = 128
25+ const mixtures = 8
2526const layers = [128, 128]
2627
2728function main()
2829 # Generate Data
2930 X, Y = generate_data(n_samples)
3031
3132 # Create Model
32-     machine = MixtureDensityNetworks.Machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers))
33+     model = MixtureDensityNetwork(1, 1, layers, mixtures)
3334
3435 # Fit Model
35-     report = with_logger(TerminalLogger()) do
36-         fit!(machine, X, Y)
36+     model, report = with_logger(TerminalLogger()) do
37+         MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)
3738     end
3839
3940 # Plot Learning Curve
4041     fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))
4142     save("LearningCurve.png", fig)
4243
4344 # Plot Learned Distribution
44-     Ŷ = predict(machine, X)
45+     Ŷ = model(X)
4546     fig, ax, plt = scatter(X[1,:], rand.(Ŷ), markersize=4, label="Predicted Distribution")
4647     scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
4748     axislegend(ax, position=:lt)
4849     save("PredictedDistribution.png", fig)
4950
5051 # Plot Conditional Distribution
51-     cond = predict(machine, reshape([-2.0], (1,1)))[1]
52+     cond = model(reshape([-2.1], (1,1)))[1]
5253     fig = Figure(resolution=(1000, 500))
5354     density(fig[1,1], rand(cond, 10000), npoints=10000)
5455     save("ConditionalDistribution.png", fig)
@@ -60,21 +61,22 @@ main()
6061# Example (MLJ Interface)
6162
6263``` julia
63- using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ
64+ using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ, Random
6465
6566const n_samples = 1000
66- const epochs = 1000
67- const mixtures = 6
67+ const epochs = 500
68+ const batchsize = 128
69+ const mixtures = 8
6970const layers = [128, 128]
6971
7072function main()
7173 # Generate Data
7274 X, Y = generate_data(n_samples)
7375
7476 # Create Model
75-     mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers), MLJ.table(X'), Y[1,:])
77+     mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers, batchsize=batchsize), MLJ.table(X'), Y[1,:])
7678
77- # Evaluate Model
79+ # Fit Model on Training Data, Then Evaluate on Test
7880 with_logger(TerminalLogger()) do
7981 @info "Evaluating..."
8082 evaluation = MLJ.evaluate!(
@@ -88,7 +90,7 @@ function main()
8890 @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")
8991 end
9092
91- # Fit Model
93+ # Fit Model on Entire Dataset
9294 with_logger(TerminalLogger()) do
9395 @info "Training..."
9496 MLJ.fit!(mach)
0 commit comments