Commit 14e03ab

Merge pull request #373 from CliMA/ab/struct-mat-dict
Store structure matrices and vectors with names as dictionaries
2 parents 3e05b7c + 9abb476 commit 14e03ab

31 files changed: 346 additions and 271 deletions

docs/src/API/GaussianProcess.md

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ GaussianProcess(
 ::FT,
 ::PredictionType,
 ) where {GPPkg <: GaussianProcessesPackage, K <: GaussianProcesses.Kernel, KPy <: PyObject, AK <:AbstractGPs.Kernel, FT <: AbstractFloat}
-build_models!(::GaussianProcess{GPJL}, ::PairedDataContainer{FT}, input_structure_matrix, output_structure_matrix) where {FT <: AbstractFloat}
+build_models!(::GaussianProcess{GPJL}, ::PairedDataContainer{FT}, input_structure_mats, output_structure_mats) where {FT <: AbstractFloat}
 optimize_hyperparameters!(::GaussianProcess{GPJL})
 predict(::GaussianProcess{GPJL}, ::AbstractMatrix{FT}) where {FT <: AbstractFloat}
 ```

docs/src/API/RandomFeatures.md

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ build_default_prior
 ```@docs
 ScalarRandomFeatureInterface
 ScalarRandomFeatureInterface(::Int,::Int)
-build_models!(::ScalarRandomFeatureInterface, ::PairedDataContainer{FT}, input_structure_matrix, output_structure_matrix) where {FT <: AbstractFloat}
+build_models!(::ScalarRandomFeatureInterface, ::PairedDataContainer{FT}, input_structure_mats, output_structure_mats) where {FT <: AbstractFloat}
 predict(::ScalarRandomFeatureInterface, ::M) where {M <: AbstractMatrix}
 ```
 
@@ -31,7 +31,7 @@ predict(::ScalarRandomFeatureInterface, ::M) where {M <: AbstractMatrix}
 ```@docs
 VectorRandomFeatureInterface
 VectorRandomFeatureInterface(::Int, ::Int, ::Int)
-build_models!(::VectorRandomFeatureInterface, ::PairedDataContainer{FT}, input_structure_matrix, output_structure_matrix) where {FT <: AbstractFloat}
+build_models!(::VectorRandomFeatureInterface, ::PairedDataContainer{FT}, input_structure_mats, output_structure_mats) where {FT <: AbstractFloat}
 predict(::VectorRandomFeatureInterface, ::M) where {M <: AbstractMatrix}
 ```
 
docs/src/data_processing.md

Lines changed: 4 additions & 4 deletions
@@ -28,7 +28,7 @@ complex_schedule = [
 In this (rather unrealistic) chain:
 1. The inputs are decorrelated with their sample mean and covariance (and projected to a low-dimensional subspace if necessary), i.e. PCA
 2. The scaled inputs are then subject to a "Robust" univariate scaling, mapping 1st-3rd quartiles to [0,1]
-3. The outputs are decorrelated using an "output structure matrix" (provided to the emulator `output_structure_matrix=`). Furthermore, apply a dimension-reduction to a space that retains 95% of the total variance.
+3. The outputs are decorrelated using an "output structure matrix" (provided to the emulator in the `encoder_kwargs` keyword parameter, e.g. as `(; obs_noise_cov =)`). Furthermore, apply a dimension-reduction to a space that retains 95% of the total variance.
 4. In the reduced input-output space, a canonical correlation analysis is performed. Data is oriented and reduced (if necessary) to maximize the joint correlation between inputs and outputs.
 
 !!! note "Default Encoder schedule"
@@ -44,12 +44,12 @@ The schedule is then passed into the Emulator, along with the data and desired s
 ```julia
 emulator = Emulator(
     machine_learning_tool,
-    input_output_pairs;
-    output_structure_matrix = obs_noise_cov,
+    input_output_pairs;
     encoder_schedule = complex_schedule,
+    encoder_kwargs = (; obs_noise_cov = obs_noise_cov),
 )
 ```
-Note that due to the item `(decorrelate_structure_mat(retain_var=0.95), "out")` in the schedule, we must provide the `output_structure_matrix`.
+Note that due to the item `(decorrelate_structure_mat(retain_var=0.95), "out")` in the schedule, we must provide an output structure matrix. In this case, we provide `obs_noise_cov`.
 
 # Types of data processors
 
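Reviewer's sketch for item 3 of the `docs/src/data_processing.md` list above (decorrelating outputs with a structure matrix and keeping 95% of the variance). This is a minimal standalone illustration that does not use the package: `whitening_encoder` is a hypothetical helper, and the eigendecomposition-based truncation is an assumption about what such a step plausibly does, not the implementation behind `decorrelate_structure_mat`.

```julia
using LinearAlgebra

# Hypothetical helper (not part of the package): build a linear encoder from a
# symmetric positive-definite structure matrix, keeping the leading directions
# that explain at least `retain_var` of the total variance.
function whitening_encoder(structure_mat::AbstractMatrix; retain_var = 0.95)
    decomp = eigen(Symmetric(Matrix(structure_mat)))
    order = sortperm(decomp.values; rev = true)   # largest variance first
    vals = decomp.values[order]
    vecs = decomp.vectors[:, order]
    cum_frac = cumsum(vals) ./ sum(vals)
    # Fall back to keeping everything if rounding keeps cum_frac below retain_var
    k = something(findfirst(>=(retain_var), cum_frac), length(vals))
    # Rows are the retained eigenvectors, scaled to unit variance
    return Diagonal(1 ./ sqrt.(vals[1:k])) * vecs[:, 1:k]'
end

obs_noise_cov = Diagonal([4.0, 1.0, 0.01])        # toy covariance
E = whitening_encoder(obs_noise_cov; retain_var = 0.95)
```

Under these assumptions the encoded covariance `E * obs_noise_cov * E'` is the identity in the retained subspace, which is the sense in which the outputs are "decorrelated".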
docs/src/emulate.md

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@ Wrapping a predefined machine learning tool, e.g. a Gaussian process `gauss_proc
 emulator = Emulator(
     gauss_proc,
     input_output_pairs; # optional arguments after this
-    output_structure_matrix = Γy,
     encoder_schedule = encoder_schedule,
+    encoder_kwargs = (; obs_noise_cov = Γy),
 )
 ```
 The optional arguments above relate to the data processing, which is described [here](@ref data-proc)
@@ -44,7 +44,7 @@ Developers may contribute new tools by performing the following
 2. Create a struct `MyMLTool <: MachineLearningTool`, containing any arguments or optimizer options
 3. Create the following three methods to build, train, and predict with your tool (use `GaussianProcess.jl` as a guide)
 ```
-build_models!(mlt::MyMLTool, iopairs::PairedDataContainer) -> Nothing
+build_models!(mlt::MyMLTool, iopairs::PairedDataContainer, input_structure_mats::Dict{Symbol, <:StructureMatrix}, output_structure_mats::Dict{Symbol, <:StructureMatrix}) -> Nothing
 optimize_hyperparameters!(mlt::MyMLTool, args...; kwargs...) -> Nothing
 function predict(mlt::MyMLTool, new_inputs::Matrix; kwargs...) -> Matrix, Union{Matrix, Array{,3}}
 ```
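Reviewer's sketch of how a new tool might consume the dictionary arguments in the updated `build_models!` signature. Everything here is a stand-in so the snippet runs on its own: the abstract type is redefined locally, `iopairs` is left untyped in place of `PairedDataContainer`, `Dict{Symbol, <:Any}` replaces `Dict{Symbol, <:StructureMatrix}`, and the key `:obs_noise_cov` is an assumption that mirrors the `encoder_kwargs` name used elsewhere in this commit.

```julia
using LinearAlgebra

# Stand-in for the package's abstract type, so the sketch is self-contained.
abstract type MachineLearningTool end

# Hypothetical tool that only records an average noise level.
mutable struct MyMLTool <: MachineLearningTool
    noise_scale::Union{Nothing, Float64}
end

function build_models!(
    mlt::MyMLTool,
    iopairs,                                    # stand-in for PairedDataContainer
    input_structure_mats::Dict{Symbol, <:Any},
    output_structure_mats::Dict{Symbol, <:Any},
)
    # Look the matrix up by name; the key is assumed, and may be absent.
    Γ = get(output_structure_mats, :obs_noise_cov, nothing)
    mlt.noise_scale = Γ === nothing ? nothing : tr(Γ) / size(Γ, 1)
    return nothing
end

tool = MyMLTool(nothing)
input_mats = Dict{Symbol, Matrix{Float64}}()    # no input structure supplied
output_mats = Dict(:obs_noise_cov => [2.0 0.0; 0.0 4.0])
build_models!(tool, nothing, input_mats, output_mats)
```

Looking matrices up with `get` and a default keeps the tool usable when a caller supplies no structure matrix under that name.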

examples/Cloudy/Cloudy_emulate_sample.jl

Lines changed: 1 addition & 2 deletions
@@ -172,9 +172,8 @@ function main()
 emulator = Emulator(
     mlt,
     input_output_pairs;
-    input_structure_matrix = cov(priors),
-    output_structure_matrix = Γy,
     encoder_schedule = encoder_schedule,
+    encoder_kwargs = (; prior_cov = cov(priors), obs_noise_cov = Γy),
 )
 
 optimize_hyperparameters!(emulator)

examples/Darcy/emulate_sample.jl

Lines changed: 3 additions & 4 deletions
@@ -92,14 +92,13 @@ function main()
 @save joinpath(data_save_directory, "input_output_pairs.jld2") input_output_pairs
 
 # data processing
-encoding_schedule = (decorrelate_structure_mat(), "in_and_out")
+encoder_schedule = (decorrelate_structure_mat(), "in_and_out")
 
 emulator = Emulator(
     mlt,
     input_output_pairs;
-    input_structure_matrix = cov(prior),
-    output_structure_matrix = Γy,
-    encoding_schedule = encoding_schedule,
+    encoder_schedule = encoder_schedule,
+    encoder_kwargs = (; prior_cov = cov(prior), obs_noise_cov = Γy),
 )
 optimize_hyperparameters!(emulator, kernbounds = [fill(-1e2, n_params + 1), fill(1e2, n_params + 1)])
 
examples/EDMF_data/emulator-rank-test.jl

Lines changed: 2 additions & 3 deletions
@@ -256,10 +256,9 @@ function main()
 ttt[rank_id, rep_idx] = @elapsed begin
     emulator = Emulator(
         mlt,
-        train_pairs;
-        input_structure_matrix = cov(prior),
-        output_structure_matrix = truth_cov,
+        train_pairs,
         encoder_schedule = encoder_schedule,
+        encoder_kwargs = (; prior_cov = cov(prior), obs_noise_cov = truth_cov),
     )
 
     # Optimize the GP hyperparameters for better fit

examples/EDMF_data/uq_for_edmf.jl

Lines changed: 2 additions & 3 deletions
@@ -230,10 +230,9 @@ function main()
 # Fit an emulator to the data
 emulator = Emulator(
     mlt,
-    input_output_pairs;
-    input_structure_matrix = cov(prior),
-    output_structure_matrix = truth_cov,
+    input_output_pairs,
     encoder_schedule = encoder_schedule,
+    encoder_kwargs = (; prior_cov = cov(prior), obs_noise_cov = truth_cov),
 )
 
 # Optimize the GP hyperparameters for better fit

examples/Emulator/G-function/emulate-test-n-features.jl

Lines changed: 1 addition & 1 deletion
@@ -154,8 +154,8 @@ function main()
     emulator = Emulator(
         mlt,
         iopairs;
-        output_structure_matrix = Γ * I,
         encoder_schedule = deepcopy(encoder_schedule),
+        encoder_kwargs = (; obs_noise_cov = Γ * I),
     )
     optimize_hyperparameters!(emulator)
 end

examples/Emulator/G-function/emulate.jl

Lines changed: 1 addition & 1 deletion
@@ -145,8 +145,8 @@ function main()
     emulator = Emulator(
         mlt,
         iopairs;
-        output_structure_matrix = Γ * I,
         encoder_schedule = deepcopy(encoder_schedule),
+        encoder_kwargs = (; obs_noise_cov = Γ * I),
     )
     optimize_hyperparameters!(emulator)
 end
