Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 46 additions & 28 deletions src/Utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ export create_encoder_schedule,
decode_structure_matrix


const StructureMatrix = Union{UniformScaling, AbstractMatrix}

"""
$(DocStringExtensions.TYPEDSIGNATURES)
Expand Down Expand Up @@ -77,6 +78,23 @@ Base.:(==)(a::PDCP, b::PDCP) where {PDCP <: PairedDataContainerProcessor} =

####

"""
    get_structure_mat(structure_mats, name = nothing)

Pick one structure matrix out of the keyed collection `structure_mats`.

- If `name` is `nothing`, the collection must hold exactly one entry, and that entry is returned.
- Otherwise the entry stored under the key `name` is returned.

Throws an `ArgumentError` if the collection is empty, if it holds several entries and no `name`
is given, or if `name` is not one of its keys.
"""
function get_structure_mat(structure_mats, name = nothing)
    if isnothing(name)
        # `length`, not `size`: `size` has no method for dictionaries, and a tuple never equals 1
        if length(structure_mats) == 1
            return only(values(structure_mats))
        elseif isempty(structure_mats)
            throw(ArgumentError("Please provide a structure matrix."))
        else
            available = collect(keys(structure_mats))
            throw(ArgumentError("Structure matrices $(available) are present. Please indicate which to use."))
        end
    end

    # fail loudly here rather than handing `nothing` to the caller's linear algebra
    if !haskey(structure_mats, name)
        available = collect(keys(structure_mats))
        throw(ArgumentError("Structure matrix $name not found. Options: $(available)."))
    end
    return structure_mats[name]
end

function _encode_data(proc::P, data, apply_to::AS) where {P <: DataProcessor, AS <: AbstractString}
input_data, output_data = get_data(data)
Expand Down Expand Up @@ -118,12 +136,12 @@ function _initialize_and_encode_data!(
apply_to::AS,
) where {AS <: AbstractString}
input_data, output_data = get_data(data)
input_structure_mat, output_structure_mat = structure_mats
input_structure_mats, output_structure_mats = structure_mats

if apply_to == "in"
initialize_processor!(proc, input_data, input_structure_mat)
initialize_processor!(proc, input_data, input_structure_mats)
elseif apply_to == "out"
initialize_processor!(proc, output_data, output_structure_mat)
initialize_processor!(proc, output_data, output_structure_mats)
else
bad_apply_to(apply_to)
end
Expand Down Expand Up @@ -189,18 +207,26 @@ Takes in the created encoder schedule (See [`create_encoder_schedule`](@ref)), a
"""
function initialize_and_encode_with_schedule!(
encoder_schedule::VV,
io_pairs::PDC,
input_structure_mat::USorMorN1,
output_structure_mat::USorMorN2,
io_pairs::PDC;
input_structure_mats = Dict{Symbol, StructureMatrix}(),
output_structure_mats = Dict{Symbol, StructureMatrix}(),
input_cov::Union{Nothing, StructureMatrix} = nothing,
noise_cov::Union{Nothing, StructureMatrix} = nothing,
) where {
VV <: AbstractVector,
PDC <: PairedDataContainer,
USorMorN1 <: Union{UniformScaling, AbstractMatrix, Nothing},
USorMorN2 <: Union{UniformScaling, AbstractMatrix, Nothing},
}
processed_io_pairs = deepcopy(io_pairs)
processed_input_structure_mat = deepcopy(input_structure_mat)
processed_output_structure_mat = deepcopy(output_structure_mat)

processed_input_structure_mats = deepcopy(input_structure_mats)
if !isnothing(input_cov)
processed_input_structure_mats[:input_cov] = input_cov
end

processed_output_structure_mats = deepcopy(output_structure_mats)
if !isnothing(noise_cov)
processed_output_structure_mats[:noise_cov] = noise_cov
end

# apply_to is the string "in", "out" etc.
for (processor, apply_to) in encoder_schedule
Expand All @@ -209,38 +235,30 @@ function initialize_and_encode_with_schedule!(
processed = _initialize_and_encode_data!(
processor,
processed_io_pairs,
(processed_input_structure_mat, processed_output_structure_mat),
(processed_input_structure_mats, processed_output_structure_mats),
apply_to,
)

if apply_to == "in"
processed_input_structure_mat = encode_structure_matrix(processor, processed_input_structure_mat)
processed_input_structure_mats = Dict(
name => encode_structure_matrix(processor, mat)
for (name, mat) in processed_input_structure_mats
)
processed_io_pairs = PairedDataContainer(processed, get_outputs(processed_io_pairs))
elseif apply_to == "out"
processed_output_structure_mat = encode_structure_matrix(processor, processed_output_structure_mat)
processed_output_structure_mats = Dict(
name => encode_structure_matrix(processor, mat)
for (name, mat) in processed_output_structure_mats
)
processed_io_pairs = PairedDataContainer(get_inputs(processed_io_pairs), processed)
end
end

return processed_io_pairs, processed_input_structure_mat, processed_output_structure_mat
return processed_io_pairs, processed_input_structure_mats, processed_output_structure_mats
end

# Functions to encode/decode with initialized schedule

# cases when structure_matrix is Nothing:
# an absent (`nothing`) structure matrix stays `nothing` under any processor, in both directions
encode_structure_matrix(::DataProcessor, ::Nothing) = nothing
decode_structure_matrix(::DataProcessor, ::Nothing) = nothing
# `nothing` passes through a schedule unchanged, whatever the schedule contents or direction string
encode_with_schedule(::AbstractVector, ::Nothing, ::AbstractString) = nothing
decode_with_schedule(::AbstractVector, ::Nothing, ::AbstractString) = nothing

"""
$TYPEDSIGNATURES

Expand Down
12 changes: 6 additions & 6 deletions src/Utilities/canonical_correlation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ initialize_processor!(
cc::CanonicalCorrelation,
in_data::MM,
out_data::MM,
input_structure_matrix,
output_structure_matrix,
input_structure_matrices,
output_structure_matrices,
apply_to::AS,
) where {MM <: AbstractMatrix, AS <: AbstractString} = initialize_processor!(cc, in_data, out_data, apply_to)

Expand Down Expand Up @@ -205,8 +205,8 @@ Apply the `CanonicalCorrelation` encoder to a provided structure matrix
"""
function encode_structure_matrix(
    cc::CanonicalCorrelation,
    structure_matrix::SM,
) where {SM <: StructureMatrix}
    # congruence transform E * S * E' using the first stored encoder matrix
    encoder_mat = get_encoder_mat(cc)[1]
    return encoder_mat * structure_matrix * encoder_mat'
end
Expand All @@ -218,8 +218,8 @@ Apply the `CanonicalCorrelation` decoder to a provided structure matrix
"""
function decode_structure_matrix(
    cc::CanonicalCorrelation,
    enc_structure_matrix::SM,
) where {SM <: StructureMatrix}
    # congruence transform D * S_enc * D' using the first stored decoder matrix
    decoder_mat = get_decoder_mat(cc)[1]
    return decoder_mat * enc_structure_matrix * decoder_mat'
end
32 changes: 15 additions & 17 deletions src/Utilities/decorrelator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ struct Decorrelator{VV1, VV2, VV3, FT, AS <: AbstractString} <: DataContainerPro
retain_var::FT
"Switch to choose what form of matrix to use to decorrelate the data"
decorrelate_with::AS
"Name (key) of the structure matrix to use; `nothing` selects the only structure matrix provided"
structure_mat_name::Union{Nothing, Symbol}
end

"""
Expand All @@ -48,8 +50,8 @@ Constructs the `Decorrelator` struct. Users can add optional keyword arguments:
- `"sample_cov"`, see [`decorrelate_sample_cov`](@ref)
- `"combined"`, sums the `"sample_cov"` and `"structure_mat"` matrices
"""
decorrelate(; retain_var::FT = Float64(1.0), decorrelate_with = "combined", structure_mat_name = nothing) where {FT} =
    # `retain_var` is clamped into [0, 1]; `structure_mat_name` is stored as the sixth field
    Decorrelator([], [], [], clamp(retain_var, FT(0), FT(1)), decorrelate_with, structure_mat_name)

"""
$(TYPEDSIGNATURES)
Expand All @@ -58,16 +60,16 @@ Constructs the `Decorrelator` struct, setting decorrelate_with = "sample_cov". E
- `retain_var`[=`1.0`]: to project onto the leading singular vectors such that `retain_var` variance is retained
"""
decorrelate_sample_cov(; retain_var::FT = Float64(1.0)) where {FT} =
    # `retain_var` is clamped into [0, 1]; the structure-matrix name is left unset (`nothing`)
    Decorrelator([], [], [], clamp(retain_var, FT(0), FT(1)), "sample_cov", nothing)

"""
$(TYPEDSIGNATURES)

Constructs the `Decorrelator` struct, setting decorrelate_with = "structure_mat". This encoding will transform a provided structure matrix into `I`. One can additionally add keywords:
- `retain_var`[=`1.0`]: to project onto the leading singular vectors such that `retain_var` variance is retained
"""
decorrelate_structure_mat(; retain_var::FT = Float64(1.0), structure_mat_name = nothing) where {FT} =
    # `retain_var` is clamped into [0, 1]; `structure_mat_name` is stored as the sixth field
    Decorrelator([], [], [], clamp(retain_var, FT(0), FT(1)), "structure_mat", structure_mat_name)

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -121,7 +123,7 @@ Computes and populates the `data_mean` and `encoder_mat` and `decoder_mat` field
function initialize_processor!(
dd::Decorrelator,
data::MM,
structure_matrix::USorMorN,
structure_matrices::USorMorN,
) where {MM <: AbstractMatrix, USorMorN <: Union{UniformScaling, AbstractMatrix, Nothing}}
if length(get_data_mean(dd)) == 0
push!(get_data_mean(dd), vec(mean(data, dims = 2)))
Expand All @@ -132,13 +134,8 @@ function initialize_processor!(
# Can do tsvd here for large matrices
decorrelate_with = get_decorrelate_with(dd)
if decorrelate_with == "structure_mat"
if isnothing(structure_matrix)
throw(
ArgumentError(
"DataProcessor `decorrelate_structure_mat` requires a user-provided structure matrix: received `nothing`. \n please provide (for input or output as needed) as a keyword argument `Emulator(...; input_structure_matrix=..., output_structure_matrix=...)` ",
),
)
elseif isa(structure_matrix, UniformScaling)
structure_matrix = get_structure_mat(structure_matrices, dd.structure_mat_name)
if isa(structure_matrix, UniformScaling)
data_dim = size(data, 1)
svdA = svd(structure_matrix(data_dim))
rk = data_dim
Expand All @@ -151,6 +148,7 @@ function initialize_processor!(
svdA = svd(cd)
rk = rank(cd)
elseif decorrelate_with == "combined"
structure_matrix = get_structure_mat(structure_matrices, dd.structure_mat_name)
spluscd = structure_matrix + cov(data, dims = 2)
svdA = svd(spluscd)
rk = rank(spluscd)
Expand Down Expand Up @@ -214,8 +212,8 @@ Apply the `Decorrelator` encoder to a provided structure matrix
"""
function encode_structure_matrix(
    dd::Decorrelator,
    structure_matrix::SM,
) where {SM <: StructureMatrix}
    # congruence transform E * S * E' using the first stored encoder matrix
    encoder_mat = get_encoder_mat(dd)[1]
    return encoder_mat * structure_matrix * encoder_mat'
end
Expand All @@ -227,8 +225,8 @@ Apply the `Decorrelator` decoder to a provided structure matrix
"""
function decode_structure_matrix(
    dd::Decorrelator,
    enc_structure_matrix::SM,
) where {SM <: StructureMatrix}
    # congruence transform D * S_enc * D' using the first stored decoder matrix
    decoder_mat = get_decoder_mat(dd)[1]
    return decoder_mat * enc_structure_matrix * decoder_mat'
end
16 changes: 8 additions & 8 deletions src/Utilities/elementwise_scaler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ end
function initialize_processor!(
es::ElementwiseScaler,
data::MM,
T::Type{QS},
::Type{QS},
) where {MM <: AbstractMatrix, QS <: QuartileScaling}
quartiles_vec = [quantile(dd, [0.25, 0.5, 0.75]) for dd in eachrow(data)]
quartiles_mat = reduce(hcat, quartiles_vec) # 3 rows: Q1, Q2, and Q3
Expand All @@ -97,7 +97,7 @@ end
function initialize_processor!(
es::ElementwiseScaler,
data::MM,
T::Type{MMS},
::Type{MMS},
) where {MM <: AbstractMatrix, MMS <: MinMaxScaling}
minmax_vec = [[minimum(dd), maximum(dd)] for dd in eachrow(data)]
minmax_mat = reduce(hcat, minmax_vec) # 2 rows: min max
Expand All @@ -108,7 +108,7 @@ end
function initialize_processor!(
es::ElementwiseScaler,
data::MM,
T::Type{ZSS},
::Type{ZSS},
) where {MM <: AbstractMatrix, ZSS <: ZScoreScaling}
stat_vec = [[mean(dd), std(dd)] for dd in eachrow(data)]
stat_mat = reduce(hcat, stat_vec) # 2 rows: mean, std
Expand Down Expand Up @@ -156,7 +156,7 @@ $(TYPEDSIGNATURES)

Computes and populates the `shift` and `scale` fields for the `ElementwiseScaler`
"""
initialize_processor!(es::ElementwiseScaler, data::MM, structure_matrices) where {MM <: AbstractMatrix} =
    # the structure matrices argument is ignored; this forwards to the two-argument method
    initialize_processor!(es, data)


Expand All @@ -167,8 +167,8 @@ Apply the `ElementwiseScaler` encoder to a provided structure matrix
"""
function encode_structure_matrix(
    es::ElementwiseScaler,
    structure_matrix::SM,
) where {SM <: StructureMatrix}
    # build the inverse-scale diagonal once and apply it on both sides
    inv_scale = Diagonal(1 ./ get_scale(es))
    return inv_scale * structure_matrix * inv_scale
end

Expand All @@ -179,7 +179,7 @@ Apply the `ElementwiseScaler` decoder to a provided structure matrix
"""
function decode_structure_matrix(
    es::ElementwiseScaler,
    enc_structure_matrix::SM,
) where {SM <: StructureMatrix}
    # build the scale diagonal once and apply it on both sides
    scale = Diagonal(get_scale(es))
    return scale * enc_structure_matrix * scale
end
Loading