# ============================================================================
# NOTE(review): reconstructed from a whitespace-mangled unified diff.
# File boundaries from the original patch are marked with banner comments.
#
# MMD_GAN/.JuliaFormatter.toml (new file, 1 line):
#   style = "blue"
# ----------------------------------------------------------------------------
# MMD_GAN/mmd.jl (new file)
# ============================================================================

using LinearAlgebra
using Statistics  # bug fix: `mean` is used below but was never imported here

# Floor for the variance estimate so the ratio loss never divides by ≈ 0.
const min_var_est = 1e-8

"""
    linear_mmd2(fₓ, fᵧ)

Linear-time estimate of MMD² with a linear kernel. `fₓ` and `fᵧ` are feature
matrices with one sample per row; consecutive rows are paired to form the
linear-time h-statistic terms.
"""
function linear_mmd2(fₓ, fᵧ)
    Δ = fₓ - fᵧ
    return mean(sum(Δ[1:(end - 1), :] .* Δ[2:end, :]; dims=2))
end

"""
    poly_mmd2(fₓ, fᵧ; d=2, α=1.0, c=2.0)

Linear-time estimate of MMD² with the polynomial kernel
`k(u, v) = (α⋅⟨u, v⟩ + c)^d`.
"""
function poly_mmd2(fₓ, fᵧ; d=2, α=1.0, c=2.0)
    Kₓₓ = α * sum(fₓ[1:(end - 1), :] .* fₓ[2:end, :]; dims=2) .+ c
    Kᵧᵧ = α * sum(fᵧ[1:(end - 1), :] .* fᵧ[2:end, :]; dims=2) .+ c
    Kₓᵧ = α * sum(fₓ[1:(end - 1), :] .* fᵧ[2:end, :]; dims=2) .+ c
    Kᵧₓ = α * sum(fᵧ[1:(end - 1), :] .* fₓ[2:end, :]; dims=2) .+ c

    K̃ₓₓ = mean(Kₓₓ .^ d)
    K̃ᵧᵧ = mean(Kᵧᵧ .^ d)
    K̃ₓᵧ = mean(Kₓᵧ .^ d)
    K̃ᵧₓ = mean(Kᵧₓ .^ d)

    # Bug fix: the original returned `K̃ₓₓ + Kᵧᵧ - Kₓᵧ - Kᵧₓ`, mixing the
    # degree-d means with the raw (vector-valued) kernel evaluations.
    return K̃ₓₓ + K̃ᵧᵧ - K̃ₓᵧ - K̃ᵧₓ
end

"""
    _mix_rbf_kernel(X, Y, sigma_list)

Evaluate a sum of RBF kernels, one per bandwidth `σ` in `sigma_list`, on all
pairs of rows of `X` and `Y`. Returns `(Kₓₓ, Kₓᵧ, Kᵧᵧ, length(sigma_list))`.
"""
function _mix_rbf_kernel(X, Y, sigma_list)
    m = size(X, 1)

    Z = vcat(X, Y)
    ZZₜ = Z * Z'
    sq_norms = diag(ZZₜ)
    # ‖zᵢ − zⱼ‖² = ‖zᵢ‖² − 2⟨zᵢ, zⱼ⟩ + ‖zⱼ‖²; broadcasting the norm vector
    # against its transpose avoids the full temporary matrix of the original.
    exponent = sq_norms .- 2 .* ZZₜ .+ sq_norms'

    K = zeros(size(exponent))
    for σ in sigma_list
        γ = 1.0 / (2 * σ^2)
        K += exp.(-γ .* exponent)
    end

    return K[1:m, 1:m], K[1:m, (m + 1):end], K[(m + 1):end, (m + 1):end],
           length(sigma_list)
end

"""
    mix_rbf_mmd2(X, Y, sigma_list, biased=true)

MMD² between the rows of `X` and `Y` under the mixed-RBF kernel.
Throws `DimensionMismatch` when the sample counts differ.
"""
function mix_rbf_mmd2(X, Y, sigma_list, biased=true)
    # Input validation, not a debug invariant — so no `@assert`.
    size(X, 1) == size(Y, 1) ||
        throw(DimensionMismatch("X and Y must have the same number of rows"))
    Kₓₓ, Kₓᵧ, Kᵧᵧ, _ = _mix_rbf_kernel(X, Y, sigma_list)
    return _mmd2(Kₓₓ, Kₓᵧ, Kᵧᵧ, false, biased)
end

"""
    mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=true)

Return `(loss, mmd2, var_est)` for the mixed-RBF kernel, where `loss` is
MMD² divided by the (floored) standard deviation of the estimator.
"""
function mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=true)
    Kₓₓ, Kₓᵧ, Kᵧᵧ, _ = _mix_rbf_kernel(X, Y, sigma_list)
    return _mmd2_and_ratio(Kₓₓ, Kₓᵧ, Kᵧᵧ, false, biased)
end

"""
    _mmd2(Kₓₓ, Kₓᵧ, Kᵧᵧ, const_diagonal=false, biased=false)

MMD² from precomputed kernel matrices. When `const_diagonal` is a number, the
diagonals of `Kₓₓ`/`Kᵧᵧ` are assumed to equal it and are not read.
"""
function _mmd2(Kₓₓ, Kₓᵧ, Kᵧᵧ, const_diagonal=false, biased=false)
    m = size(Kₓₓ, 1)

    if const_diagonal !== false
        diagₓ = diagᵧ = const_diagonal
        sum_diagₓ = sum_diagᵧ = m * const_diagonal
    else
        diagₓ = diag(Kₓₓ)
        diagᵧ = diag(Kᵧᵧ)
        sum_diagₓ = sum(diagₓ)
        sum_diagᵧ = sum(diagᵧ)
    end

    # Row sums of the kernel matrices with their diagonals removed: K̃ e.
    Kₜₓₓ_sum = sum(vec(sum(Kₓₓ; dims=2)) .- diagₓ)
    Kₜᵧᵧ_sum = sum(vec(sum(Kᵧᵧ; dims=2)) .- diagᵧ)
    Kₓᵧ_sum = sum(Kₓᵧ)

    if biased
        mmd2 = ((Kₜₓₓ_sum + sum_diagₓ) + (Kₜᵧᵧ_sum + sum_diagᵧ) -
                2.0 * Kₓᵧ_sum) / (m * m)
    else
        mmd2 = (Kₜₓₓ_sum + Kₜᵧᵧ_sum) / (m * (m - 1)) - 2.0 * Kₓᵧ_sum / (m * m)
    end

    return mmd2
end

"""
    _mmd2_and_ratio(Kₓₓ, Kₓᵧ, Kᵧᵧ, const_diagonal=false, biased=false)

Return `(loss, mmd2, var_est)` where `loss = mmd2 / √max(var_est, min_var_est)`.

Bug fix: the original defined only a keyword-argument `mmd2_and_ratio` whose
body referenced its parameters by the wrong names (`Kₓₓ` vs `K_XX`) and called
a nonexistent `mmd2_and_variance`, while `mix_rbf_mmd2_and_ratio` called the
(then-undefined) `_mmd2_and_ratio`.
"""
function _mmd2_and_ratio(Kₓₓ, Kₓᵧ, Kᵧᵧ, const_diagonal=false, biased=false)
    mmd2, var_est = _mmd2_and_variance(Kₓₓ, Kₓᵧ, Kᵧᵧ;
                                       const_diagonal=const_diagonal,
                                       biased=biased)
    loss = mmd2 / sqrt(max(var_est, min_var_est))
    return loss, mmd2, var_est
end

# Keyword-argument variant kept for backward compatibility with the original
# public name.
function mmd2_and_ratio(Kₓₓ, Kₓᵧ, Kᵧᵧ; const_diagonal=false, biased::Bool=false)
    return _mmd2_and_ratio(Kₓₓ, Kₓᵧ, Kᵧᵧ, const_diagonal, biased)
end

"""
    _mmd2_and_variance(Kₓₓ, Kₓᵧ, Kᵧᵧ; const_diagonal=false, biased=false)

Return `(mmd2, var_est)`: the MMD² estimate and the variance of the unbiased
estimator (the formula follows the reference MMD-GAN implementation).
"""
function _mmd2_and_variance(Kₓₓ, Kₓᵧ, Kᵧᵧ; const_diagonal=false, biased=false)
    m = size(Kₓₓ, 1)  # X and Y are assumed to have the same sample count

    if const_diagonal !== false
        diagₓ = diagᵧ = const_diagonal
        sum_diagₓ = sum_diagᵧ = m * const_diagonal
        sum_diag2ₓ = sum_diag2ᵧ = m * const_diagonal^2
    else
        # Bug fix: the original used `diagm`, which BUILDS a diagonal matrix;
        # `diag` EXTRACTS the diagonal, which is what every use below needs.
        diagₓ = diag(Kₓₓ)  # (m,)
        diagᵧ = diag(Kᵧᵧ)  # (m,)
        sum_diagₓ = sum(diagₓ)
        sum_diagᵧ = sum(diagᵧ)
        sum_diag2ₓ = dot(diagₓ, diagₓ)
        sum_diag2ᵧ = dot(diagᵧ, diagᵧ)
    end

    Kₜₓₓ_sums = vec(sum(Kₓₓ; dims=2)) .- diagₓ  # K̃_XX e = K_XX e − diag
    Kₜᵧᵧ_sums = vec(sum(Kᵧᵧ; dims=2)) .- diagᵧ  # K̃_YY e = K_YY e − diag
    Kₓᵧ_sums₁ = vec(sum(Kₓᵧ; dims=1))           # K_XYᵀ e
    Kₓᵧ_sums₂ = vec(sum(Kₓᵧ; dims=2))           # K_XY e

    Kₜₓₓ_sum = sum(Kₜₓₓ_sums)
    Kₜᵧᵧ_sum = sum(Kₜᵧᵧ_sums)
    Kₓᵧ_sum = sum(Kₓᵧ_sums₁)

    Kₜₓₓ_2_sum = sum(Kₓₓ .^ 2) - sum_diag2ₓ  # ‖K̃_XX‖_F²
    Kₜᵧᵧ_2_sum = sum(Kᵧᵧ .^ 2) - sum_diag2ᵧ  # ‖K̃_YY‖_F²
    Kₓᵧ_2_sum = sum(Kₓᵧ .^ 2)                # ‖K_XY‖_F²

    if biased
        mmd2 = ((Kₜₓₓ_sum + sum_diagₓ) + (Kₜᵧᵧ_sum + sum_diagᵧ) -
                2.0 * Kₓᵧ_sum) / (m * m)
    else
        mmd2 = (Kₜₓₓ_sum + Kₜᵧᵧ_sum) / (m * (m - 1)) - 2.0 * Kₓᵧ_sum / (m * m)
    end

    var_est = (
        2.0 / (m^2 * (m - 1.0)^2) *
        (2 * dot(Kₜₓₓ_sums, Kₜₓₓ_sums) - Kₜₓₓ_2_sum +
         2 * dot(Kₜᵧᵧ_sums, Kₜᵧᵧ_sums) - Kₜᵧᵧ_2_sum)
        - (4.0 * m - 6.0) / (m^3 * (m - 1.0)^3) * (Kₜₓₓ_sum^2 + Kₜᵧᵧ_sum^2)
        + 4.0 * (m - 2.0) / (m^3 * (m - 1.0)^2) *
          (dot(Kₓᵧ_sums₂, Kₓᵧ_sums₂) + dot(Kₓᵧ_sums₁, Kₓᵧ_sums₁))
        - 4.0 * (m - 3.0) / (m^3 * (m - 1.0)^2) * Kₓᵧ_2_sum
        - (8 * m - 12) / (m^5 * (m - 1)) * Kₓᵧ_sum^2
        + 8.0 / (m^3 * (m - 1.0)) * (
            1.0 / m * (Kₜₓₓ_sum + Kₜᵧᵧ_sum) * Kₓᵧ_sum
            - dot(Kₜₓₓ_sums, Kₓᵧ_sums₂)
            - dot(Kₜᵧᵧ_sums, Kₓᵧ_sums₁))
    )

    return mmd2, var_est
end

# ============================================================================
# MMD_GAN/mmd_gan_1d.jl (new file; continues in the next hunk)
# ============================================================================
using Base.Iterators: partition
using Flux
using Flux.Optimise: update!
using Flux: logitbinarycrossentropy, binarycrossentropy
using Statistics
using Parameters: @with_kw
using Random
using Printf
using CUDA
using Zygote
using Distributions

include("./mmd.jl")

# Hyper-parameters for the 1-D MMD-GAN experiment.
@with_kw struct HyperParams
    data_size::Int = 10000
    batch_size::Int = 128
    latent_dim::Int = 1
    epochs::Int = 1000
    verbose_freq::Int = 1000       # epochs between progress reports
    num_gen::Int = 1               # generator steps per epoch
    num_enc_dec::Int = 5           # encoder/decoder steps per epoch
    lr_enc::Float64 = 1.0e-4
    lr_dec::Float64 = 1.0e-4
    lr_gen::Float64 = 1.0e-4

    lambda_AE::Float64 = 8.0       # weight of the autoencoder penalty
    target_param::Tuple{Float64,Float64} = (23.0, 1.0)  # (μ, σ) of the target
    noise_param::Tuple{Float64,Float64} = (0.0, 1.0)    # (μ, σ) of the noise
    base::Float64 = 1.0
    sigma_list::Array{Float64,1} = [1.0, 2.0, 4.0, 8.0, 16.0] ./ base
end

# Generator: 1-D noise → 1-D sample.
function generator()
    return Chain(
        Dense(1, 7), elu,
        Dense(7, 13), elu,
        Dense(13, 7), elu,
        Dense(7, 1),
    )
end

# Encoder: 1-D sample → 29-D embedding in which MMD is measured.
function encoder()
    return Chain(Dense(1, 11), elu, Dense(11, 29), elu)
end

# Decoder: 29-D embedding → 1-D reconstruction (autoencoder constraint).
function decoder()
    return Chain(Dense(29, 11), elu, Dense(11, 1))
end

"""
    data_sampler(hparams, target)

Draw a `(batch_size, 1)` matrix of samples from `Normal(target[1], target[2])`.
"""
function data_sampler(hparams, target)
    return rand(Normal(target[1], target[2]), (hparams.batch_size, 1))
end

"""
    train()

Train the 1-D MMD-GAN: the encoder/decoder pair is trained to maximize the
(kernelized) MMD between encoded target and encoded generated samples while
autoencoding both well; the generator is trained to minimize that MMD.

Returns `(gen, enc, dec, losses_gen, losses_dscr)`.

Bug fixes vs. the original:
- `@showprogress` was used without importing ProgressMeter (a `LoadError`);
  replaced with `@info` reports every `verbose_freq` epochs.
- `Flux.mse(decoded_target, target)` compared a `1×b` row against a `b×1`
  column, broadcasting to a `b×b` matrix; the target is now transposed.
- The trained models and loss histories were discarded (`train` returned
  `nothing`); they are now returned.
"""
function train()
    hparams = HyperParams()

    gen = generator()
    enc = encoder()
    dec = decoder()

    # One optimizer per sub-network, each with its own learning rate.
    gen_opt = ADAM(hparams.lr_gen)
    enc_opt = ADAM(hparams.lr_enc)
    dec_opt = ADAM(hparams.lr_dec)

    losses_gen = Float64[]
    losses_dscr = Float64[]

    gen_ps = Flux.params(gen)
    enc_ps = Flux.params(enc)
    dec_ps = Flux.params(dec)

    for ep in 1:hparams.epochs
        # --- encoder/decoder ("discriminator") updates -----------------------
        for _ in 1:hparams.num_enc_dec
            loss, back = Zygote.pullback(Flux.params(enc, dec)) do
                target = data_sampler(hparams, hparams.target_param)
                noise = data_sampler(hparams, hparams.noise_param)
                encoded_target = enc(target')
                decoded_target = dec(encoded_target)
                # Transpose so both operands are 1×batch (see docstring).
                L2_AE_target = Flux.mse(decoded_target, target')
                transformed_noise = gen(noise')
                encoded_noise = enc(transformed_noise)
                decoded_noise = dec(encoded_noise)
                L2_AE_noise = Flux.mse(decoded_noise, transformed_noise)
                # NOTE(review): enc outputs are 29×batch, so mix_rbf_mmd2
                # treats embedding dimensions as samples — TODO confirm this
                # orientation is intended.
                MMD = relu(mix_rbf_mmd2(encoded_target, encoded_noise,
                                        hparams.sigma_list))
                # Enc/dec maximize MMD subject to small reconstruction error;
                # negate so the optimizer minimizes.
                -1.0 * (sqrt(MMD) -
                        hparams.lambda_AE * (L2_AE_noise + L2_AE_target))
            end
            grads = back(one(loss))
            update!(enc_opt, enc_ps, grads)
            update!(dec_opt, dec_ps, grads)
            push!(losses_dscr, loss)
        end
        # --- generator updates ----------------------------------------------
        for _ in 1:hparams.num_gen
            loss, back = Zygote.pullback(gen_ps) do
                target = data_sampler(hparams, hparams.target_param)
                noise = data_sampler(hparams, hparams.noise_param)
                encoded_target = enc(target')
                encoded_noise = enc(gen(noise'))
                sqrt(relu(mix_rbf_mmd2(encoded_target, encoded_noise,
                                       hparams.sigma_list)))
            end
            grads = back(one(loss))
            update!(gen_opt, gen_ps, grads)
            push!(losses_gen, loss)
        end
        if ep % hparams.verbose_freq == 0
            @info "epoch $ep" dscr_loss = last(losses_dscr) gen_loss = last(losses_gen)
        end
    end

    return gen, enc, dec, losses_gen, losses_dscr
end

"""
    plot_results(gen, hparams, n_samples, bins)

Overlay histograms of `n_samples` target batches and `n_samples` generated
batches, using `bins` as the histogram bin specification.

Bug fix: the original signature `plot_results(n_samples, range)` referenced
`gen` and `hparams`, which are locals of `train()`, so it could never run;
they are now explicit arguments. NOTE(review): `histogram`/`histogram!` are
assumed to come from Plots.jl loaded by the caller — this file does not
import it.
"""
function plot_results(gen, hparams, n_samples, bins)
    target = collect(Iterators.flatten(
        data_sampler(hparams, hparams.target_param) for _ in 1:n_samples))
    transformed_noise = collect(Iterators.flatten(
        gen(data_sampler(hparams, hparams.noise_param)')' for _ in 1:n_samples))
    histogram(target; bins=bins)
    return histogram!(transformed_noise; bins=bins)
end

# ============================================================================
# MMD_GAN/test_mmd.jl (new file)
# ============================================================================
using Test

include("./mmd.jl")

# Tolerance for floating-point comparisons. The original compared floats with
# `==`, which is fragile; all checks now use `isapprox` with this tolerance.
# NOTE(review): reference values presumably come from the implementation this
# port is based on — TODO confirm their provenance.
tol = 1e-5

@testset "mix_rbf_mmd2" begin
    sigma_list = [1.0, 2.0, 4.0, 8.0, 16.0] ./ 1.0

    X = [0.4126 0.334]
    Y = [-0.3432 0.45]
    @test isapprox(mix_rbf_mmd2(X, Y, sigma_list), 0.6955453052850462; atol=tol)

    X = [-0.9964 1.6757 -1.0000 -0.9621 3.0564; -0.9966 1.6849 -1.0000 -0.9630 3.0814]
    Y = [-0.4433 0.1915 0.2270 0.0989 -0.1531; -0.4603 0.2200 0.1087 0.0844 -0.1827]
    @test isapprox(mix_rbf_mmd2(X, Y, sigma_list), 4.754166275960963; atol=tol)

    X = [0.0588 24.6208 0.6140 1.9435 21.2479; 0.0780 23.8120 0.5976 1.8904 20.5289]
    Y = [0.2057 18.4525 0.4892 1.5384 15.7635; 0.1529 20.6706 0.5340 1.6841 17.7357]
    @test isapprox(mix_rbf_mmd2(X, Y, sigma_list), 4.659662358316032; atol=tol)
end;

@testset "_mmd2" begin
    Kₓₓ = [5.0 4.999526869857079; 4.999526869857079 5.0]
    Kₓᵧ = [2.6194483555739208 2.629665928762427; 2.610644548009624 2.6206924815501655]
    Kᵧᵧ = [5.0 4.989256995960984; 4.989256995960984 5.0]
    # Unbiased (default) and biased estimates.
    @test isapprox(_mmd2(Kₓₓ, Kₓᵧ, Kᵧᵧ), 4.748558208869993; atol=tol)
    @test isapprox(_mmd2(Kₓₓ, Kₓᵧ, Kᵧᵧ, false, true), 4.754166275960963; atol=tol)
end;