-
Notifications
You must be signed in to change notification settings - Fork 5
Update categoricaldistributions #142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
026f317
f349168
b145902
3333383
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
| ``` julia | ||
| using Pkg; Pkg.activate("docs") | ||
| # Import libraries | ||
| using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux | ||
| using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, CategoricalDistributions | ||
| theme(:lime) | ||
| ``` | ||
|
|
||
|
|
@@ -17,7 +17,7 @@ using LaplaceRedux.Data | |
| seed = 1234 | ||
| x, y = Data.toy_data_multi(seed=seed) | ||
| X = hcat(x...) | ||
| y_onehot = Flux.onehotbatch(y, unique(y)) | ||
| y_onehot = Flux.onehotbatch(y, unwrap.(unique(y))) | ||
| y_onehot = Flux.unstack(y_onehot',1) | ||
| ``` | ||
|
|
||
|
|
@@ -59,7 +59,7 @@ We set up a model | |
| ``` julia | ||
| n_hidden = 3 | ||
| D = size(X,1) | ||
| out_dim = length(unique(y)) | ||
| out_dim = length(unwrap.(unique(y))) | ||
|
||
| nn = Chain( | ||
| Dense(D, n_hidden, σ), | ||
| Dense(n_hidden, out_dim) | ||
|
|
@@ -103,7 +103,7 @@ optimize_prior!(la; verbosity=1, n_steps=100) | |
| with either the probit approximation: | ||
|
|
||
| ``` julia | ||
| _labels = sort(unique(y)) | ||
| _labels = sort(unwrap.(unique(y))) | ||
|
||
| plt_list = [] | ||
| for target in _labels | ||
| plt = plot(la, X_test, y_test; target=target, clim=(0,1)) | ||
|
|
@@ -117,7 +117,7 @@ plot(plt_list...) | |
| or the plugin approximation: | ||
|
|
||
| ``` julia | ||
| _labels = sort(unique(y)) | ||
| _labels = sort(unwrap.(unique(y))) | ||
|
||
| plt_list = [] | ||
| for target in _labels | ||
| plt = plot(la, X_test, y_test; target=target, clim=(0,1), link_approx=:plugin) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ using LaplaceRedux.Data | |
| seed = 1234 | ||
| x, y = Data.toy_data_multi(seed=seed) | ||
| X = hcat(x...) | ||
| y_onehot = Flux.onehotbatch(y, unique(y)) | ||
| y_onehot = Flux.onehotbatch(y, unwrap.(unique(y))) | ||
|
||
| y_onehot = Flux.unstack(y_onehot',1) | ||
| ``` | ||
|
|
||
|
|
@@ -59,7 +59,7 @@ We set up a model | |
| ```{julia} | ||
| n_hidden = 3 | ||
| D = size(X,1) | ||
| out_dim = length(unique(y)) | ||
| out_dim = length(unwrap.(unique(y))) | ||
|
||
| nn = Chain( | ||
| Dense(D, n_hidden, σ), | ||
| Dense(n_hidden, out_dim) | ||
|
|
@@ -105,7 +105,7 @@ with either the probit approximation: | |
| ```{julia} | ||
| #| output: true | ||
|
|
||
| _labels = sort(unique(y)) | ||
| _labels = sort(unwrap.(unique(y))) | ||
|
||
| plt_list = [] | ||
| for target in _labels | ||
| plt = plot(la, X_test, y_test; target=target, clim=(0,1)) | ||
|
|
@@ -119,7 +119,7 @@ plot(plt_list...) | |
| ```{julia} | ||
| #| output: true | ||
|
|
||
| _labels = sort(unique(y)) | ||
| _labels = sort(unwrap.(unique(y))) | ||
|
||
| plt_list = [] | ||
| for target in _labels | ||
| plt = plot(la, X_test, y_test; target=target, clim=(0,1), link_approx=:plugin) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,7 @@ include("testutils.jl") | |
| y = df[:, 3] | ||
|
|
||
| X = hcat(x...) | ||
| y_train = Flux.onehotbatch(y, unique(y)) | ||
| y_train = Flux.onehotbatch(y, unwrap.(unique(y))) | ||
|
||
| y_train = Flux.unstack(y_train'; dims=1) | ||
|
|
||
| data = zip(x, y_train) | ||
|
|
@@ -28,7 +28,7 @@ include("testutils.jl") | |
| nn = deserialize(joinpath(@__DIR__, "datafiles", "nn-binary_multi.jlb")) | ||
|
|
||
| @testset "LA - full weights - full hessian - ggn" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:classification, | ||
| hessian_structure=:full, | ||
|
|
@@ -43,7 +43,7 @@ include("testutils.jl") | |
| end | ||
|
|
||
| @testset "LA - full weights - full hessian - empfisher" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:classification, | ||
| hessian_structure=:full, | ||
|
|
@@ -58,7 +58,7 @@ include("testutils.jl") | |
| end | ||
|
|
||
| @testset "LA - last layer - full hessian - empfisher" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:classification, | ||
| hessian_structure=:full, | ||
|
|
@@ -75,7 +75,7 @@ include("testutils.jl") | |
| end | ||
|
|
||
| @testset "LA - last layer - full hessian - ggn" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:classification, | ||
| hessian_structure=:full, | ||
|
|
@@ -92,7 +92,7 @@ include("testutils.jl") | |
| end | ||
|
|
||
| @testset "LA - subnetwork - full hessian - ggn" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:classification, | ||
| hessian_structure=:full, | ||
|
|
@@ -115,7 +115,7 @@ include("testutils.jl") | |
| end | ||
|
|
||
| @testset "LA - subnetwork - full hessian - empfisher" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:classification, | ||
| hessian_structure=:full, | ||
|
|
@@ -138,7 +138,7 @@ include("testutils.jl") | |
| end | ||
|
|
||
| @testset "LA - full weights - kron - ggn" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:classification, | ||
| hessian_structure=:kron, | ||
|
|
@@ -151,7 +151,7 @@ include("testutils.jl") | |
| end | ||
|
|
||
| @testset "LA - last layer - kron - ggn" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:classification, | ||
| hessian_structure=:kron, | ||
|
|
@@ -179,7 +179,7 @@ include("testutils.jl") | |
| nn = deserialize(joinpath(@__DIR__, "datafiles", "nn-binary_regression.jlb")) | ||
|
|
||
| @testset "LA - full weights - full hessian - ggn" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:regression, | ||
| hessian_structure=:full, | ||
|
|
@@ -192,7 +192,7 @@ include("testutils.jl") | |
| end | ||
|
|
||
| @testset "LA - last layer - full hessian - ggn" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:regression, | ||
| hessian_structure=:full, | ||
|
|
@@ -205,7 +205,7 @@ include("testutils.jl") | |
| end | ||
|
|
||
| @testset "LA - subnetwork - full hessian - ggn" begin | ||
| la = Laplace( | ||
| la = LaplaceRedux.Laplace( | ||
| nn; | ||
| likelihood=:regression, | ||
| hessian_structure=:full, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing import:
`unwrap` is used here but `CategoricalDistributions` is not imported in this file. Add `using CategoricalDistributions` to the imports section to make `unwrap` available.