|
| 1 | +```@meta |
| 2 | +CurrentModule = MixtureDensityNetworks |
| 3 | +``` |
| 4 | + |
| 5 | +# MLJ Compatibility |
| 6 | + |
| 7 | +This package implements the interface specified by [MLJModelInterface](https://github.com/JuliaAI/MLJModelInterface.jl) and is thus fully compatible |
| 8 | +with the MLJ ecosystem. Below is an example demonstrating the use of this package in conjunction with MLJ. |
| 9 | + |
| 10 | +```julia |
| 11 | +using MixtureDensityNetworks, Distributions, Logging, TerminalLoggers, CairoMakie, MLJ |
| 12 | + |
| 13 | +const n_samples = 1000 |
| 14 | +const epochs = 1000 |
| 15 | +const mixtures = 6 |
| 16 | +const layers = [128, 128] |
| 17 | + |
| 18 | +function main() |
| 19 | + # Generate Data |
| 20 | + X, Y = generate_data(n_samples) |
| 21 | + |
| 22 | + # Create Model |
| 23 | + mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers), MLJ.table(X'), Y[1,:]) |
| 24 | +
|
| 25 | + # Evaluate Model |
| 26 | + with_logger(TerminalLogger()) do |
| 27 | + @info "Evaluating..." |
| 28 | + evaluation = MLJ.evaluate!( |
| 29 | + mach, |
| 30 | + resampling=Holdout(shuffle=true), |
| 31 | + measure=[rsq, rmse, mae, mape], |
| 32 | + operation=MLJ.predict_mean |
| 33 | + ) |
| 34 | + names = ["R²", "RMSE", "MAE", "MAPE"] |
| 35 | + metrics = round.(evaluation.measurement, digits=3) |
| 36 | + @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ") |
| 37 | + end |
| 38 | +
|
| 39 | + # Fit Model |
| 40 | + with_logger(TerminalLogger()) do |
| 41 | + @info "Training..." |
| 42 | + MLJ.fit!(mach) |
| 43 | + end |
| 44 | +
|
| 45 | + # Plot Learning Curve |
| 46 | + fig, _, _ = lines(1:epochs, MLJ.training_losses(mach), axis=(;xlabel="Epochs", ylabel="Loss")) |
| 47 | + save("LearningCurve.png", fig) |
| 48 | +
|
| 49 | + # Plot Learned Distribution |
| 50 | + Ŷ = MLJ.predict(mach) .|> rand |
| 51 | + fig, ax, plt = scatter(X[1,:], Ŷ, markersize=4, label="Predicted Distribution") |
| 52 | + scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution") |
| 53 | + axislegend(ax, position=:lt) |
| 54 | + save("PredictedDistribution.png", fig) |
| 55 | +
|
| 56 | + # Plot Conditional Distribution |
| 57 | + cond = MLJ.predict(mach, MLJ.table(reshape([-2.1], (1,1))))[1] |
| 58 | + fig = Figure(resolution=(1000, 500)) |
| 59 | + density(fig[1,1], rand(cond, 10000), npoints=10000) |
| 60 | + save("ConditionalDistribution.png", fig) |
| 61 | +end |
| 62 | +
|
| 63 | +main() |
| 64 | +``` |
0 commit comments