Merged
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

*Note*: We try to adhere to these practices as of version [v0.2.1].
## Version [1.3.0] - 2025-12-08
- Temporarily removed TaijaData due to issues with CategoricalDistributions 0.2 [#142]
- The docs environment now has compatibility conflicts with TaijaPlotting and RData (needs to be fixed); CategoricalDistributions 0.2 cannot be added there without conflicts
- Updated CategoricalDistributions to 0.2 in LaplaceRedux
- Explicitly used `LaplaceRedux.Laplace` in `pytorch_comparison.jl` to avoid name conflicts


## Version [1.2.0] - 2024-12-03
4 changes: 2 additions & 2 deletions Project.toml
@@ -22,13 +22,13 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Aqua = "0.8"
-CategoricalDistributions = "0.1.15"
+CategoricalDistributions = "0.2.1"
ChainRulesCore = "1.23.0"
Compat = "4.7.0"
Distributions = "0.25.109"
Flux = "0.12, 0.13, 0.14"
LinearAlgebra = "1.7, 1.10"
-MLJBase = "1"
+MLJBase = "1.11"
MLJModelInterface = "1.8.0"
MLUtils = "0.4"
Optimisers = "0.2, 0.3"
10 changes: 5 additions & 5 deletions docs/src/tutorials/multi.md
@@ -6,7 +6,7 @@
``` julia
using Pkg; Pkg.activate("docs")
# Import libraries
-using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux
+using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, CategoricalDistributions
theme(:lime)
```

@@ -17,7 +17,7 @@ using LaplaceRedux.Data
seed = 1234
x, y = Data.toy_data_multi(seed=seed)
X = hcat(x...)
-y_onehot = Flux.onehotbatch(y, unique(y))
+y_onehot = Flux.onehotbatch(y, unwrap.(unique(y)))
Copilot AI (Dec 8, 2025): Missing import: `unwrap` is used here but CategoricalDistributions is not imported in this file. Add `using CategoricalDistributions` to the imports section to make `unwrap` available.
y_onehot = Flux.unstack(y_onehot',1)
```

@@ -59,7 +59,7 @@ We set up a model
``` julia
n_hidden = 3
D = size(X,1)
-out_dim = length(unique(y))
+out_dim = length(unwrap.(unique(y)))
Copilot AI (Dec 8, 2025): Missing import: `unwrap` is used here but CategoricalDistributions is not imported in this file. Add `using CategoricalDistributions` to the imports section to make `unwrap` available.
nn = Chain(
Dense(D, n_hidden, σ),
Dense(n_hidden, out_dim)
@@ -103,7 +103,7 @@ optimize_prior!(la; verbosity=1, n_steps=100)
with either the probit approximation:

``` julia
-_labels = sort(unique(y))
+_labels = sort(unwrap.(unique(y)))
Copilot AI (Dec 8, 2025): Missing import: `unwrap` is used here but CategoricalDistributions is not imported in this file. Add `using CategoricalDistributions` to the imports section to make `unwrap` available.
plt_list = []
for target in _labels
plt = plot(la, X_test, y_test; target=target, clim=(0,1))
@@ -117,7 +117,7 @@ plot(plt_list...)
or the plugin approximation:

``` julia
-_labels = sort(unique(y))
+_labels = sort(unwrap.(unique(y)))
Copilot AI (Dec 8, 2025): Missing import: `unwrap` is used here but CategoricalDistributions is not imported in this file. Add `using CategoricalDistributions` to the imports section to make `unwrap` available.
plt_list = []
for target in _labels
plt = plot(la, X_test, y_test; target=target, clim=(0,1), link_approx=:plugin)
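The recurring `unwrap.(unique(y))` change in this file can be sketched in isolation. This is a minimal, hypothetical example (not taken from the diff): it assumes `y` is a vector of categorical values, as the review comments indicate `toy_data_multi` returns under CategoricalDistributions 0.2, and that `unwrap` recovers the raw label inside each categorical value.

``` julia
using Flux
using CategoricalArrays: categorical
using CategoricalDistributions  # assumed to make `unwrap` available, per the review comments

y = categorical(["a", "b", "a", "c"])  # categorical values, not raw strings

# `unwrap` strips the categorical wrapper, so the one-hot label set is
# built from plain labels rather than CategoricalValue objects.
labels = unwrap.(unique(y))            # raw labels in order of appearance
y_raw = unwrap.(y)                     # raw label for each observation
y_onehot = Flux.onehotbatch(y_raw, labels)
```

Passing raw labels to `onehotbatch` on both sides avoids relying on equality between categorical wrappers and their underlying values.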
8 changes: 4 additions & 4 deletions docs/src/tutorials/multi.qmd
@@ -15,7 +15,7 @@ using LaplaceRedux.Data
seed = 1234
x, y = Data.toy_data_multi(seed=seed)
X = hcat(x...)
-y_onehot = Flux.onehotbatch(y, unique(y))
+y_onehot = Flux.onehotbatch(y, unwrap.(unique(y)))
Copilot AI (Dec 8, 2025): Missing import: `unwrap` is used here but CategoricalDistributions is not imported in this file. Add `using CategoricalDistributions` to the imports section to make `unwrap` available.
y_onehot = Flux.unstack(y_onehot',1)
```

@@ -59,7 +59,7 @@ We set up a model
```{julia}
n_hidden = 3
D = size(X,1)
-out_dim = length(unique(y))
+out_dim = length(unwrap.(unique(y)))
Copilot AI (Dec 8, 2025): Missing import: `unwrap` is used here but CategoricalDistributions is not imported in this file. Add `using CategoricalDistributions` to the imports section to make `unwrap` available.
nn = Chain(
Dense(D, n_hidden, σ),
Dense(n_hidden, out_dim)
@@ -105,7 +105,7 @@ with either the probit approximation:
```{julia}
#| output: true

-_labels = sort(unique(y))
+_labels = sort(unwrap.(unique(y)))
Copilot AI (Dec 8, 2025): Missing import: `unwrap` is used here but CategoricalDistributions is not imported in this file. Add `using CategoricalDistributions` to the imports section to make `unwrap` available.
plt_list = []
for target in _labels
plt = plot(la, X_test, y_test; target=target, clim=(0,1))
@@ -119,7 +119,7 @@ plot(plt_list...)
```{julia}
#| output: true

-_labels = sort(unique(y))
+_labels = sort(unwrap.(unique(y)))
Copilot AI (Dec 8, 2025): Missing import: `unwrap` is used here but CategoricalDistributions is not imported in this file. Add `using CategoricalDistributions` to the imports section to make `unwrap` available.
plt_list = []
for target in _labels
plt = plot(la, X_test, y_test; target=target, clim=(0,1), link_approx=:plugin)
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -1,6 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
+CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -17,7 +18,6 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-TaijaData = "9d524318-b4e6-4a65-86d2-b2b72d07866c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Trapz = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
8 changes: 4 additions & 4 deletions test/laplace.jl
@@ -9,7 +9,7 @@ using MLUtils
using LinearAlgebra
using Distributions: Normal, Bernoulli, Categorical
using Random

+using CategoricalDistributions
@testset "Construction" begin

# One layer:
@@ -132,12 +132,12 @@ end
#Random.seed!(123) # For reproducibility
x, y = Data.toy_data_multi(50)
X = hcat(x...)
-y_train = Flux.onehotbatch(y, unique(y))
+y_train = Flux.onehotbatch(y, unwrap.(unique(y)))
y_train = Flux.unstack(y_train'; dims=1)
data = zip(x, y_train)
n_hidden = 3
D = size(X, 1)
-out_dim = length(unique(y))
+out_dim = length(unwrap.(unique(y)))
# Case: softmax activation function
nn = Chain(Dense(D, n_hidden, σ), Dense(n_hidden, out_dim), softmax)

@@ -374,7 +374,7 @@ end

# Classification multi:
xs, y = LaplaceRedux.Data.toy_data_multi(n)
-ytrain = Flux.onehotbatch(y, unique(y))
+ytrain = Flux.onehotbatch(y, unwrap.(unique(y)))
ytrain = Flux.unstack(ytrain'; dims=1)
X = reduce(hcat, xs)
Y = reduce(hcat, ytrain)
24 changes: 12 additions & 12 deletions test/pytorch_comparison.jl
@@ -19,7 +19,7 @@ include("testutils.jl")
y = df[:, 3]

X = hcat(x...)
-y_train = Flux.onehotbatch(y, unique(y))
+y_train = Flux.onehotbatch(y, unwrap.(unique(y)))
Copilot AI (Dec 8, 2025): Missing import: `unwrap` is used here but CategoricalDistributions is not imported in this file. Add `using CategoricalDistributions` to the imports at the top of the file to make `unwrap` available.
y_train = Flux.unstack(y_train'; dims=1)

data = zip(x, y_train)
@@ -28,7 +28,7 @@ include("testutils.jl")
nn = deserialize(joinpath(@__DIR__, "datafiles", "nn-binary_multi.jlb"))

@testset "LA - full weights - full hessian - ggn" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:classification,
hessian_structure=:full,
@@ -43,7 +43,7 @@ include("testutils.jl")
end

@testset "LA - full weights - full hessian - empfisher" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:classification,
hessian_structure=:full,
@@ -58,7 +58,7 @@ include("testutils.jl")
end

@testset "LA - last layer - full hessian - empfisher" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:classification,
hessian_structure=:full,
@@ -75,7 +75,7 @@ include("testutils.jl")
end

@testset "LA - last layer - full hessian - ggn" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:classification,
hessian_structure=:full,
@@ -92,7 +92,7 @@ include("testutils.jl")
end

@testset "LA - subnetwork - full hessian - ggn" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:classification,
hessian_structure=:full,
@@ -115,7 +115,7 @@ include("testutils.jl")
end

@testset "LA - subnetwork - full hessian - empfisher" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:classification,
hessian_structure=:full,
@@ -138,7 +138,7 @@ include("testutils.jl")
end

@testset "LA - full weights - kron - ggn" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:classification,
hessian_structure=:kron,
@@ -151,7 +151,7 @@ include("testutils.jl")
end

@testset "LA - last layer - kron - ggn" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:classification,
hessian_structure=:kron,
@@ -179,7 +179,7 @@ include("testutils.jl")
nn = deserialize(joinpath(@__DIR__, "datafiles", "nn-binary_regression.jlb"))

@testset "LA - full weights - full hessian - ggn" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:regression,
hessian_structure=:full,
Expand All @@ -192,7 +192,7 @@ include("testutils.jl")
end

@testset "LA - last layer - full hessian - ggn" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:regression,
hessian_structure=:full,
Expand All @@ -205,7 +205,7 @@ include("testutils.jl")
end

@testset "LA - subnetwork - full hessian - ggn" begin
-la = Laplace(
+la = LaplaceRedux.Laplace(
nn;
likelihood=:regression,
hessian_structure=:full,
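The blanket `la = LaplaceRedux.Laplace(...)` qualification in this file is the standard Julia remedy when two loaded packages export the same name. A minimal sketch of the pattern, using illustrative module and function names (not from the diff):

``` julia
module PkgA
export greet
greet() = "A"
end

module PkgB
export greet
greet() = "B"
end

using .PkgA, .PkgB

# Calling `greet()` unqualified would be ambiguous, since both modules
# export the name; qualifying with the module resolves the conflict.
a = PkgA.greet()
b = PkgB.greet()
```

This is presumably why the tests qualify `Laplace` explicitly once CategoricalDistributions (which may export a clashing name) enters the test environment.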