Skip to content

Commit 61e411c

Browse files
committed
Update documentation.
1 parent 15a4e0b commit 61e411c

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

docs/src/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ We can fit our model to our data by calling `fit!(m, X, Y; opt=Flux.Adam(), batc
3232
containing the learning curve, best epoch, and lowest loss observed during training as its second value. We can use Makie's `lines` method to visualize
3333
the learning curve.
3434
```julia
35-
model, report = tMixtureDensityNetworks.fit!(model, X, Y; epochs=500, opt=Flux.Adam(1e-3), batchsize=128)
35+
model, report = MixtureDensityNetworks.fit!(model, X, Y; epochs=500, opt=Flux.Adam(1e-3), batchsize=128)
3636
fig, _, _ = lines(1:500, lc, axis=(;xlabel="Epochs", ylabel="Loss"))
3737
```
3838

@@ -49,7 +49,7 @@ axislegend(ax, position=:lt)
4949

5050
![](figures/PredictedDistribution.png)
5151

52-
We can also visualize the conditional distribution predicted by our model at x = -2.0.
52+
We can also visualize the conditional distribution predicted by our model at x = -2.1.
5353
```julia
5454
cond = model(reshape([-2.1], (1,1)))[1]
5555
fig = Figure(resolution=(1000, 500))

src/layers.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
$(TYPEDEF)
33
4-
A layer which returns a univriate Gaussian Mixture Model as its output.
4+
A layer which produces a univariate Gaussian Mixture Model as its output.
55
66
# Parameters
77
$(TYPEDFIELDS)
@@ -20,7 +20,7 @@ $(TYPEDSIGNATURES)
2020
Construct a layer which returns a univariate Gaussian Mixture Model as its output.
2121
2222
# Parameters
23-
- `input`: Specifies the size of the input tensor, which should have the dimensions `input x N`.
23+
- `input`: Specifies the length of the feature vectors. The layer expects a matrix with the dimensions `input x N` as input.
2424
- `mixtures`: The number of mixtures to use in the GMM.
2525
"""
2626
function UnivariateGMM(input::Int, mixtures::Int)
@@ -53,7 +53,7 @@ end
5353
"""
5454
$(TYPEDEF)
5555
56-
A layer which returns a multivariate Gaussian Mixture Model as its output.
56+
A layer which produces a multivariate Gaussian Mixture Model as its output.
5757
5858
# Parameters
5959
$(TYPEDFIELDS)
@@ -74,8 +74,8 @@ $(TYPEDSIGNATURES)
7474
Construct a layer which returns a multivariate Gaussian Mixture Model as its output.
7575
7676
# Parameters
77-
- `input`: Specifies the size of the input tensor, which should have the dimensions `input x N`.
78-
- `output`: Specifies the dimension of the labels, which should have the dimensions `output x N`.
77+
- `input`: Specifies the length of the feature vectors. The layer expects a matrix with the dimensions `input x N` as input.
78+
- `output`: Specifies the length of the label vectors. The layer returns a matrix with dimensions `output x N` as output.
7979
- `mixtures`: The number of mixtures to use in the GMM.
8080
"""
8181
function MultivariateGMM(input::Int, output::Int, mixtures::Int)

src/model.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ Flux.@functor MixtureDensityNetwork (hidden, output)
1616
"""
1717
$(TYPEDSIGNATURES)
1818
19-
Construct a new MixtureDensityNetwork.
19+
Construct a standard Mixture Density Network.
2020
2121
# Parameters
22-
- `input`: The dimension of the input features.
22+
- `input`: The length of the input feature vectors.
23+
- `output`: The length of the output feature vectors.
2324
- `layers`: The topolgy of the hidden layers, starting from the first layer.
24-
- `mixtures`: The number of Gaussian mixtures to use in the conditional distribution.
25+
- `mixtures`: The number of Gaussian mixtures to use in the predicted distribution.
2526
"""
2627
function MixtureDensityNetwork(input::Int, output::Int, layers::Vector{Int}, mixtures::Int)
2728
# Define Weight Initializer

0 commit comments

Comments
 (0)