Skip to content

Commit d274469

Browse files
committed
Updated Documentation
1 parent b307ca5 commit d274469

File tree

4 files changed

+17
-19
lines changed

4 files changed

+17
-19
lines changed

examples/example.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
push!(LOAD_PATH, "../src")
22

3-
using MixtureDensityNetworks, Distributions, CairoMakie
3+
using MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
44

55
const n_samples = 1000
66
const epochs = 1000
@@ -18,7 +18,9 @@ function main()
1818
model = MDN(epochs=epochs, mixtures=mixtures, layers=layers)
1919

2020
# Fit Model
21-
lc = fit!(model, X, Y)
21+
lc = with_logger(TerminalLogger()) do
22+
fit!(model, X, Y)
23+
end
2224

2325
# Plot Learning Curve
2426
fig, _, _ = lines(1:epochs, lc, axis=(;xlabel="Epochs", ylabel="Loss"))

src/MixtureDensityNetworks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,4 @@ include("interface.jl")
1313

1414
export likelihood_loss, MDN, fit!, predict, predict_mean, predict_mode
1515

16-
17-
end
16+
end

src/interface.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ Fit the model to the data given by X and Y.
4141
- `X`: A dxn matrix where d is the number of features and n is the number of samples.
4242
- `Y`: A 1xn matrix where n is the number of samples.
4343
"""
44-
function fit!(model::MDN, X::Matrix{<:Number}, Y::Matrix{<:Number})
45-
fit!(model, Float32.(X), Float32.(Y))
44+
function fit!(model::MDN, X::Matrix{<:Real}, Y::Matrix{<:Real})
45+
fit!(model, Float64.(X), Float64.(Y))
4646
end
4747

4848
function fit!(model::MDN, X::Matrix{Float64}, Y::Matrix{Float64})
@@ -101,8 +101,8 @@ Predict the full conditional distribution P(Y|X).
101101
# Returns
102102
Returns a vector of Distributions.MixtureModel objects representing the conditional distribution for each sample.
103103
"""
104-
function predict(model::MDN, X::Matrix{<:Number})
105-
predict(model, Float32.(X))
104+
function predict(model::MDN, X::Matrix{<:Real})
105+
predict(model, Float64.(X))
106106
end
107107

108108
function predict(model::MDN, X::Matrix{Float64})
@@ -128,8 +128,8 @@ Predict the mean of the conditional distribution P(Y|X).
128128
# Returns
129129
Returns a vector of real numbers representing the mean of the conditional distribution P(Y|X) for each sample.
130130
"""
131-
function predict_mean(model::MDN, X::Matrix{<:Number})
132-
predict_mean(model, Float32.(X))
131+
function predict_mean(model::MDN, X::Matrix{<:Real})
132+
predict_mean(model, Float64.(X))
133133
end
134134

135135
function predict_mean(model::MDN, X::Matrix{Float64})
@@ -149,8 +149,8 @@ Predict the mean of the Gaussian with the largest prior in the conditional distr
149149
# Returns
150150
Returns a vector of real numbers representing the mean of the gaussian with the largest prior for each sample.
151151
"""
152-
function predict_mode(model::MDN, X::Matrix{<:Number})
153-
predict_mode(model, Float32.(X))
152+
function predict_mode(model::MDN, X::Matrix{<:Real})
153+
predict_mode(model, Float64.(X))
154154
end
155155

156156
function predict_mode(model::MDN, X::Matrix{Float64})

src/model.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,14 @@ function MixtureDensityNetwork(input::Int, layers::Vector{Int}, mixtures::Int)
3535
return MixtureDensityNetwork(hidden_layer, μ, Σ, π)
3636
end
3737

38-
"MixtureModel forward pass."
38+
function (m::MixtureDensityNetwork)(X::AbstractMatrix{<:Real})
39+
Float64.(X) |> m
40+
end
41+
3942
function (m::MixtureDensityNetwork)(X::AbstractMatrix{Float64})
40-
# Forward Pass
4143
h = m.hidden(X)
4244
μ = m.μ(h)
4345
Σ = m.Σ(h)
4446
π = m.π(h)
4547
μ, Σ, π
46-
end
47-
48-
"MixtureModel forward pass."
49-
function (m::MixtureDensityNetwork)(X::AbstractMatrix{<:Number})
50-
Float32.(X) |> m
5148
end

0 commit comments

Comments
 (0)