Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion profiling/mnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Profile.clear()
x_bin,
QARBoM.CD();
n_epochs = 20,
gibbs_steps = 3,
cd_steps = 3,
learning_rate = 0.01,
)

Expand Down
1 change: 1 addition & 0 deletions src/QARBoM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ abstract type AbstractDBN end
abstract type AbstractRBM end
abstract type DBNLayer end


export RBM, RBMClassifier, GRBM, GRBMClassifier
export QSampling, PCD, CD, FastPCD

Expand Down
88 changes: 87 additions & 1 deletion src/evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@ abstract type EvaluationMethod end
abstract type Accuracy <: EvaluationMethod end
abstract type CrossEntropy <: EvaluationMethod end
abstract type MeanSquaredError <: EvaluationMethod end
abstract type TruePositives <: EvaluationMethod end
abstract type TrueNegative <: EvaluationMethod end
abstract type FalsePositive <: EvaluationMethod end
abstract type FalseNegative <: EvaluationMethod end
abstract type Precision <: EvaluationMethod end
abstract type Recall <: EvaluationMethod end
abstract type F1Score <: EvaluationMethod end

export Accuracy, MeanSquaredError, CrossEntropy
export Accuracy, MeanSquaredError, CrossEntropy, TruePositives, TrueNegative, FalseNegative, FalsePositive, Precision, Recall, F1Score

function _evaluate(::Type{Accuracy}, metrics_dict::Dict{String, Vector{Float64}}, epoch::Int, dataset_size::Int; kwargs...)
sample = kwargs[:y_sample]
Expand All @@ -24,6 +31,71 @@ function _evaluate(::Type{CrossEntropy}, metrics_dict::Dict{String, Vector{Float
predicted = kwargs[:y_pred]
return metrics_dict["cross_entropy"][epoch] += sum(sample .* log.(predicted) .+ (1 .- sample) .* log.(1 .- predicted)) / dataset_size
end
"""
    _evaluate(::Type{TruePositives}, metrics_dict, epoch, dataset_size; kwargs...)

Accumulate one true-positive count into `metrics_dict["true_positives"][epoch]`.
A true positive is registered when the first label entry (`kwargs[:y_sample]`)
is `1.0` and the first predicted entry (`kwargs[:y_pred]`) equals `1`.
Returns the updated accumulator value.
"""
function _evaluate(::Type{TruePositives}, metrics_dict::Dict{String, Vector{Float64}}, epoch::Int, dataset_size::Int; kwargs...)
    y_true = kwargs[:y_sample]
    y_hat = kwargs[:y_pred]
    # Only the first component of label/prediction is inspected (binary case).
    hit = (y_true[1] == 1.0 && y_hat[1] == 1) ? 1 : 0
    return metrics_dict["true_positives"][epoch] += hit
end

"""
    _evaluate(::Type{TrueNegative}, metrics_dict, epoch, dataset_size; kwargs...)

Accumulate one true-negative count into `metrics_dict["true_negative"][epoch]`.
A true negative is registered when the first label entry (`kwargs[:y_sample]`)
is `0.0` and the first predicted entry (`kwargs[:y_pred]`) equals `0`.
Returns the updated accumulator value.
"""
function _evaluate(::Type{TrueNegative}, metrics_dict::Dict{String, Vector{Float64}}, epoch::Int, dataset_size::Int; kwargs...)
    y_true = kwargs[:y_sample]
    y_hat = kwargs[:y_pred]
    # Only the first component of label/prediction is inspected (binary case).
    hit = (y_true[1] == 0.0 && y_hat[1] == 0) ? 1 : 0
    return metrics_dict["true_negative"][epoch] += hit
end

"""
    _evaluate(::Type{Precision}, metrics_dict, epoch, dataset_size; kwargs...)

Set `metrics_dict["precision"][epoch]` to `TP / (TP + FP)` using the counts
already accumulated in `metrics_dict["true_positives"]` and
`metrics_dict["false_positives"]` for this epoch. Must therefore run after the
`TruePositives` and `FalsePositive` metrics have been evaluated for the epoch.
Returns the stored precision value.
"""
function _evaluate(::Type{Precision}, metrics_dict::Dict{String, Vector{Float64}}, epoch::Int, dataset_size::Int; kwargs...)
    # NOTE(review): the original body recomputed an unused true-negative count
    # (copy-paste from the TrueNegative method) and read unused kwargs; both
    # removed as dead code — precision depends only on the accumulated counts.
    tp = metrics_dict["true_positives"][epoch]
    fp = metrics_dict["false_positives"][epoch]
    return metrics_dict["precision"][epoch] = tp / (tp + fp)
end

"""
    _evaluate(::Type{Recall}, metrics_dict, epoch, dataset_size; kwargs...)

Set `metrics_dict["recall"][epoch]` to `TP / (TP + FN)` using the counts
already accumulated in `metrics_dict["true_positives"]` and
`metrics_dict["false_negatives"]` for this epoch. Must therefore run after the
`TruePositives` and `FalseNegative` metrics have been evaluated for the epoch.
Returns the stored recall value.
"""
function _evaluate(::Type{Recall}, metrics_dict::Dict{String, Vector{Float64}}, epoch::Int, dataset_size::Int; kwargs...)
    # NOTE(review): the original body recomputed an unused true-negative count
    # (copy-paste from the TrueNegative method) and read unused kwargs; both
    # removed as dead code — recall depends only on the accumulated counts.
    tp = metrics_dict["true_positives"][epoch]
    fn = metrics_dict["false_negatives"][epoch]
    return metrics_dict["recall"][epoch] = tp / (tp + fn)
end

"""
    _evaluate(::Type{FalsePositive}, metrics_dict, epoch, dataset_size; kwargs...)

Accumulate one false-positive count into `metrics_dict["false_positives"][epoch]`.
A false positive is registered when the first label entry (`kwargs[:y_sample]`)
is `0.0` but the first predicted entry (`kwargs[:y_pred]`) equals `1`.
Returns the updated accumulator value.
"""
function _evaluate(::Type{FalsePositive}, metrics_dict::Dict{String, Vector{Float64}}, epoch::Int, dataset_size::Int; kwargs...)
    y_true = kwargs[:y_sample]
    y_hat = kwargs[:y_pred]
    # Only the first component of label/prediction is inspected (binary case).
    hit = (y_true[1] == 0.0 && y_hat[1] == 1) ? 1 : 0
    return metrics_dict["false_positives"][epoch] += hit
end

"""
    _evaluate(::Type{FalseNegative}, metrics_dict, epoch, dataset_size; kwargs...)

Accumulate one false-negative count into `metrics_dict["false_negatives"][epoch]`.
A false negative is registered when the first label entry (`kwargs[:y_sample]`)
is `1.0` but the first predicted entry (`kwargs[:y_pred]`) equals `0`.
Returns the updated accumulator value.
"""
function _evaluate(::Type{FalseNegative}, metrics_dict::Dict{String, Vector{Float64}}, epoch::Int, dataset_size::Int; kwargs...)
    y_true = kwargs[:y_sample]
    y_hat = kwargs[:y_pred]
    # Only the first component of label/prediction is inspected (binary case).
    hit = (y_true[1] == 1.0 && y_hat[1] == 0) ? 1 : 0
    return metrics_dict["false_negatives"][epoch] += hit
end

"""
    _evaluate(::Type{F1Score}, metrics_dict, epoch, dataset_size; kwargs...)

Set `metrics_dict["F1"][epoch]` to `TP / (TP + (FP + FN) / 2)` using the counts
already accumulated for this epoch. Must run after the `TruePositives`,
`FalsePositive`, and `FalseNegative` metrics have been evaluated for the epoch.
Returns the stored F1 value.
"""
function _evaluate(::Type{F1Score}, metrics_dict::Dict{String, Vector{Float64}}, epoch::Int, dataset_size::Int; kwargs...)
    # NOTE(review): dropped the original's unused kwargs[:y_sample]/[:y_pred]
    # reads — F1 is derived purely from the accumulated epoch counts.
    tp = metrics_dict["true_positives"][epoch]
    fp = metrics_dict["false_positives"][epoch]
    fn = metrics_dict["false_negatives"][epoch]
    return metrics_dict["F1"][epoch] = tp / (tp + 0.5 * (fp + fn))
end

function evaluate(
rbm::Union{RBMClassifier, GRBMClassifier},
Expand Down Expand Up @@ -74,6 +146,20 @@ function _initialize_metrics(metrics::Vector{<:DataType})
metrics_dict["mse"] = Float64[]
elseif metric == CrossEntropy
metrics_dict["cross_entropy"] = Float64[]
elseif metric == TruePositives
metrics_dict["true_positives"] = Float64[]
elseif metric == TrueNegative
metrics_dict["true_negative"] = Float64[]
elseif metric == FalsePositive
metrics_dict["false_positives"] = Float64[]
elseif metric == FalseNegative
metrics_dict["false_negatives"] = Float64[]
elseif metric == Precision
metrics_dict["precision"] = Float64[]
elseif metric == Recall
metrics_dict["recall"] = Float64[]
elseif metric == F1Score
metrics_dict["F1"] = Float64[]
end
end
return metrics_dict
Expand Down
42 changes: 16 additions & 26 deletions src/fantasy_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,18 @@ mutable struct FantasyDataClassifier
y::Vector{Float64}
end

function _update_fantasy_data!(rbm::AbstractRBM, fantasy_data::Vector{FantasyData}, steps::Int)
for _ in 1:steps
for i in 1:length(fantasy_data)
fantasy_data[i].h = gibbs_sample_hidden(rbm, fantasy_data[i].v)
fantasy_data[i].v = gibbs_sample_visible(rbm, fantasy_data[i].h)
end
"""
    _update_fantasy_data!(rbm::AbstractRBM, fantasy_data::Vector{FantasyData})

Advance every persistent-chain particle by one Gibbs transition in place:
resample the hidden state from the current visible state, then resample the
visible state from the fresh hidden state.
"""
function _update_fantasy_data!(rbm::AbstractRBM, fantasy_data::Vector{FantasyData})
    for particle in fantasy_data
        particle.h = gibbs_sample_hidden(rbm, particle.v)
        particle.v = gibbs_sample_visible(rbm, particle.h)
    end
end

function _update_fantasy_data!(rbm::Union{RBMClassifier, GRBMClassifier}, fantasy_data::Vector{FantasyDataClassifier}, steps::Int)
for _ in 1:steps
for i in 1:length(fantasy_data)
fantasy_data[i].h = gibbs_sample_hidden(rbm, fantasy_data[i].v, fantasy_data[i].y)
fantasy_data[i].v = gibbs_sample_visible(rbm, fantasy_data[i].h)
fantasy_data[i].y = gibbs_sample_label(rbm, fantasy_data[i].h)
end
"""
    _update_fantasy_data!(rbm, fantasy_data::Vector{FantasyDataClassifier})

Advance every persistent-chain classifier particle by one Gibbs transition in
place: resample the hidden state from the current visible and label states,
then resample the visible and label states from the fresh hidden state.
"""
function _update_fantasy_data!(rbm::Union{RBMClassifier, GRBMClassifier}, fantasy_data::Vector{FantasyDataClassifier})
    for particle in fantasy_data
        particle.h = gibbs_sample_hidden(rbm, particle.v, particle.y)
        particle.v = gibbs_sample_visible(rbm, particle.h)
        particle.y = gibbs_sample_label(rbm, particle.h)
    end
end

Expand All @@ -34,13 +30,10 @@ function _update_fantasy_data!(
W_fast::Matrix{Float64},
a_fast::Vector{Float64},
b_fast::Vector{Float64},
steps::Int,
)
for _ in 1:steps
for i in 1:length(fantasy_data)
fantasy_data[i].h = gibbs_sample_hidden(rbm, fantasy_data[i].v, W_fast, b_fast)
fantasy_data[i].v = gibbs_sample_visible(rbm, fantasy_data[i].h, W_fast, a_fast)
end
for i in 1:length(fantasy_data)
fantasy_data[i].h = gibbs_sample_hidden(rbm, fantasy_data[i].v, W_fast, b_fast)
fantasy_data[i].v = gibbs_sample_visible(rbm, fantasy_data[i].h, W_fast, a_fast)
end
end

Expand All @@ -52,14 +45,11 @@ function _update_fantasy_data!(
a_fast::Vector{Float64},
b_fast::Vector{Float64},
c_fast::Vector{Float64},
steps::Int,
)
for _ in 1:steps
for i in 1:length(fantasy_data)
fantasy_data[i].h = gibbs_sample_hidden(rbm, fantasy_data[i].v, fantasy_data[i].y, W_fast, U_fast, b_fast)
fantasy_data[i].v = gibbs_sample_visible(rbm, fantasy_data[i].h, W_fast, a_fast)
fantasy_data[i].y = gibbs_sample_label(rbm, fantasy_data[i].h, U_fast, c_fast)
end
for i in 1:length(fantasy_data)
fantasy_data[i].h = gibbs_sample_hidden(rbm, fantasy_data[i].v, fantasy_data[i].y, W_fast, U_fast, b_fast)
fantasy_data[i].v = gibbs_sample_visible(rbm, fantasy_data[i].h, W_fast, a_fast)
fantasy_data[i].y = gibbs_sample_label(rbm, fantasy_data[i].h, U_fast, c_fast)
end
end

Expand Down
9 changes: 3 additions & 6 deletions src/rbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ mutable struct GRBMClassifier <: AbstractRBM
n_classifiers::Int # number of classifier bits
end

const RBMClassifiers = Union{RBMClassifier, GRBMClassifier}

function RBM(n_visible::Int, n_hidden::Int)
W = randn(n_visible, n_hidden)
a = zeros(n_visible)
Expand Down Expand Up @@ -167,8 +165,7 @@ conditional_prob_h(rbm::AbstractRBM, v::Vector{<:Number}) = _sigmoid.(rbm.b .+ r
conditional_prob_h(rbm::AbstractRBM, v::Vector{<:Number}, W_fast::Matrix{Float64}, b_fast::Vector{Float64}) =
_sigmoid.(rbm.b .+ b_fast .+ (rbm.W .+ W_fast)' * v)

conditional_prob_h(rbm::Union{RBMClassifier, GRBMClassifier}, v::Vector{<:Number}, y::Vector{<:Number}) =
_sigmoid.(rbm.b .+ rbm.W' * v .+ rbm.U' * y)
# P(h | v, y) for classifier RBMs: hidden activations combine visible input
# (W'v), label input (U'y), and the hidden bias b through the sigmoid.
conditional_prob_h(rbm::Union{RBMClassifier, GRBMClassifier}, v::Vector{<:Number}, y::Vector{<:Number}) = _sigmoid.(rbm.b .+ rbm.W' * v .+ rbm.U' * y)

conditional_prob_h(
rbm::Union{RBMClassifier, GRBMClassifier},
Expand Down Expand Up @@ -230,7 +227,7 @@ function reconstruct(rbm::AbstractRBM, v::Vector{<:Number})
return v_reconstructed
end

function classify(rbm::RBMClassifiers, v::Vector{<:Number})
"""
    classify(rbm::Union{RBMClassifier, GRBMClassifier}, v::Vector{<:Number})

Return the label probabilities `P(y | v)` for the visible vector `v`, as
computed by `conditional_prob_y_given_v`.
"""
function classify(rbm::Union{RBMClassifier, GRBMClassifier}, v::Vector{<:Number})
    # Widened from GRBMClassifier-only dispatch: sibling classifier methods
    # (e.g. `conditional_prob_h`) accept both classifier types, and narrowing
    # here would silently drop `classify` support for `RBMClassifier`.
    y = conditional_prob_y_given_v(rbm, v)
    return y
end
Expand All @@ -252,7 +249,7 @@ function copy_rbm(rbm::RBMClassifier)
return RBMClassifier(copy(rbm.W), copy(rbm.U), copy(rbm.a), copy(rbm.b), copy(rbm.c), rbm.n_visible, rbm.n_hidden, rbm.n_classifiers)
end

function copy_rbm!(rbm_src::RBMClassifiers, rbm_target::RBMClassifier)
function copy_rbm!(rbm_src::GRBMClassifier, rbm_target::RBMClassifier)
rbm_target.W .= rbm_src.W
rbm_target.U .= rbm_src.U
rbm_target.a .= rbm_src.a
Expand Down
12 changes: 6 additions & 6 deletions src/training/train_cd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function contrastive_divergence!(rbm::AbstractRBM, x; steps::Int, learning_rate:
return total_t_sample, total_t_gibbs, total_t_update
end

function contrastive_divergence!(rbm::RBMClassifiers, x, y; steps::Int, learning_rate::Float64 = 0.1, label_learning_rate::Float64 = 0.1)
function contrastive_divergence!(rbm::GRBMClassifier, x, y; steps::Int, learning_rate::Float64 = 0.1, label_learning_rate::Float64 = 0.1)
total_t_sample, total_t_gibbs, total_t_update = 0.0, 0.0, 0.0
for sample_i in eachindex(x)
v_data = x[sample_i]
Expand All @@ -47,7 +47,7 @@ function train!(
x_train,
::Type{CD};
n_epochs::Int,
gibbs_steps::Int = 3,
cd_steps::Int = 3,
learning_rate::Vector{Float64},
metrics::Vector{<:DataType} = [MeanSquaredError],
early_stopping::Bool = false,
Expand All @@ -71,7 +71,7 @@ function train!(
t_sample, t_gibbs, t_update = contrastive_divergence!(
rbm,
x_train;
steps = gibbs_steps,
steps = cd_steps,
learning_rate = learning_rate[epoch],
)
total_t_sample += t_sample
Expand Down Expand Up @@ -114,12 +114,12 @@ function train!(
end

function train!(
rbm::RBMClassifiers,
rbm::GRBMClassifier,
x_train,
label_train,
::Type{CD};
n_epochs::Int,
gibbs_steps::Int = 3,
cd_steps::Int = 3,
learning_rate::Vector{Float64},
label_learning_rate::Vector{Float64},
metrics::Vector{<:DataType} = [Accuracy],
Expand All @@ -146,7 +146,7 @@ function train!(
rbm,
x_train,
label_train;
steps = gibbs_steps,
steps = cd_steps,
learning_rate = learning_rate[epoch],
label_learning_rate = label_learning_rate[epoch],
)
Expand Down
18 changes: 4 additions & 14 deletions src/training/train_fast_pcd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ function fast_persistent_contrastive_divergence!(
x,
mini_batches::Vector{UnitRange{Int}},
fantasy_data::Vector{FantasyData};
steps::Int = 1,
learning_rate::Float64 = 0.1,
fast_learning_rate::Float64 = 0.1,
)
Expand Down Expand Up @@ -49,19 +48,18 @@ function fast_persistent_contrastive_divergence!(

# Update fantasy data
t_gibbs = time()
_update_fantasy_data!(rbm, fantasy_data, W_fast, a_fast, b_fast, steps)
_update_fantasy_data!(rbm, fantasy_data, W_fast, a_fast, b_fast)
total_t_gibbs += time() - t_gibbs
end
return total_t_sample, total_t_gibbs, total_t_update
end

function fast_persistent_contrastive_divergence!(
rbm::RBMClassifiers,
rbm::GRBMClassifier,
x,
label,
mini_batches::Vector{UnitRange{Int}},
fantasy_data::Vector{FantasyDataClassifier};
steps::Int = 1,
learning_rate::Float64 = 0.1,
label_learning_rate::Float64 = 0.1,
fast_learning_rate::Float64 = 0.1,
Expand Down Expand Up @@ -120,7 +118,7 @@ function fast_persistent_contrastive_divergence!(

# Update fantasy data
t_gibbs = time()
_update_fantasy_data!(rbm, fantasy_data, W_fast, U_fast, a_fast, b_fast, c_fast, steps)
_update_fantasy_data!(rbm, fantasy_data, W_fast, U_fast, a_fast, b_fast, c_fast)
total_t_gibbs += time() - t_gibbs
end
return total_t_sample, total_t_gibbs, total_t_update
Expand All @@ -132,7 +130,6 @@ end
x_train,
::Type{FastPCD};
n_epochs::Int,
gibbs_steps::Int = 1,
batch_size::Int,
learning_rate::Vector{Float64},
fast_learning_rate::Float64,
Expand All @@ -154,7 +151,6 @@ Tieleman and Hinton (2009) "Using fast weights to improve persistent contrastive
- `rbm::AbstractRBM`: The RBM to train.
- `x_train`: The training data.
- `n_epochs::Int`: The number of epochs to train the RBM.
- `gibbs_steps::Int`: The number of Gibbs Sampling steps to use.
- `batch_size::Int`: The size of the mini-batches.
- `learning_rate::Vector{Float64}`: The learning rate for each epoch.
- `fast_learning_rate::Float64`: The fast learning rate.
Expand All @@ -171,7 +167,6 @@ function train!(
x_train,
::Type{FastPCD};
n_epochs::Int,
gibbs_steps::Int = 1,
batch_size::Int,
learning_rate::Vector{Float64},
fast_learning_rate::Float64,
Expand Down Expand Up @@ -203,7 +198,6 @@ function train!(
x_train,
mini_batches,
fantasy_data;
steps = gibbs_steps,
learning_rate = learning_rate[epoch],
fast_learning_rate = fast_learning_rate,
)
Expand Down Expand Up @@ -253,7 +247,6 @@ end
label_train,
::Type{FastPCD};
n_epochs::Int,
gibbs_steps::Int = 1,
batch_size::Int,
learning_rate::Vector{Float64},
fast_learning_rate::Float64,
Expand All @@ -278,7 +271,6 @@ Tieleman and Hinton (2009) "Using fast weights to improve persistent contrastive
- `x_train`: The training data.
- `label_train`: The training labels.
- `n_epochs::Int`: The number of epochs to train the RBM.
- `gibbs_steps::Int`: The number of Gibbs Sampling steps to use.
- `batch_size::Int`: The size of the mini-batches.
- `learning_rate::Vector{Float64}`: The learning rate for each epoch.
- `fast_learning_rate::Float64`: The fast learning rate.
Expand All @@ -294,12 +286,11 @@ Tieleman and Hinton (2009) "Using fast weights to improve persistent contrastive
- `file_path`: The file path to save the metrics.
"""
function train!(
rbm::RBMClassifiers,
rbm::GRBMClassifier,
x_train,
label_train,
::Type{FastPCD};
n_epochs::Int,
gibbs_steps::Int = 1,
batch_size::Int,
learning_rate::Vector{Float64},
fast_learning_rate::Float64 = 0.1,
Expand Down Expand Up @@ -335,7 +326,6 @@ function train!(
label_train,
mini_batches,
fantasy_data;
steps = gibbs_steps,
learning_rate = learning_rate[epoch],
label_learning_rate = label_learning_rate[epoch],
fast_learning_rate = fast_learning_rate,
Expand Down
Loading
Loading