Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MMD_GAN/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
style = "blue"
166 changes: 166 additions & 0 deletions MMD_GAN/mmd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
using LinearAlgebra
using Statistics

const min_var_est = 1e-8

# Linear-time MMD estimate with a linear kernel: consecutive rows of the
# feature-difference matrix are paired and their inner products averaged.
function linear_mmd2(fₓ, fᵧ)
    δ = fₓ - fᵧ
    pairwise = sum(δ[1:end-1, :] .* δ[2:end, :], dims=2)
    return mean(pairwise)
end

# Helper function to compute linear time MMD with a polynomial kernel
function poly_mmd2(fₓ, fᵧ; d=2, α=1.0, c=2.0)
Kₓₓ = α * sum(fₓ[1:end-1, :] .* fₓ[2:end, :], dims=2) .+ c
Kᵧᵧ = α * sum(fᵧ[1:end-1, :] .* fᵧ[2:end, :], dims=2) .+ c
Kₓᵧ = α * sum(fₓ[1:end-1, :] .* fᵧ[2:end, :], dims=2) .+ c
Kᵧₓ = α * sum(fᵧ[1:end-1, :] .* fₓ[2:end, :], dims=2) .+ c

K̃ₓₓ = mean(Kₓₓ .^ d)
K̃ᵧᵧ = mean(Kᵧᵧ .^ d)
K̃ₓᵧ = mean(Kₓᵧ .^ d)
K̃ᵧₓ = mean(Kᵧₓ .^ d)

return K̃ₓₓ + Kᵧᵧ - Kₓᵧ - Kᵧₓ
end

# Mixture-of-RBF kernel over the stacked samples Z = [X; Y] (rows = samples):
# K[i, j] = Σ_{σ ∈ sigma_list} exp(-‖zᵢ - zⱼ‖² / (2σ²)).
#
# Returns (K_XX, K_XY, K_YY, length(sigma_list)).
function _mix_rbf_kernel(X, Y, sigma_list)
    m = size(X, 1)

    Z = vcat(X, Y)
    ZZᵀ = Z * Z'
    sq_norms = diag(ZZᵀ)  # ‖zᵢ‖² for each row of Z
    # Pairwise squared distances via ‖zᵢ‖² - 2⟨zᵢ, zⱼ⟩ + ‖zⱼ‖².
    # (The original expanded the diagonal with `broadcast(+, …, zeros(n, n))`,
    # allocating a needless n×n temporary; broadcasting does it directly.)
    exponent = sq_norms .- 2 .* ZZᵀ .+ sq_norms'

    K = zeros(size(exponent))
    for σ in sigma_list
        γ = 1.0 / (2 * σ^2)
        K .+= exp.(-γ .* exponent)
    end

    return K[1:m, 1:m], K[1:m, m+1:end], K[m+1:end, m+1:end], length(sigma_list)
end

# Mixed-RBF maximum mean discrepancy between sample matrices X and Y
# (rows are samples); `biased` selects the biased MMD² estimator.
function mix_rbf_mmd2(X, Y, sigma_list, biased=true)
    @assert size(X, 1) == size(Y, 1) "X and Y must have the same number of rows"
    kernels = _mix_rbf_kernel(X, Y, sigma_list)
    Kₓₓ, Kₓᵧ, Kᵧᵧ = kernels[1], kernels[2], kernels[3]
    return _mmd2(Kₓₓ, Kₓᵧ, Kᵧᵧ, false, biased)
end

# Mixed-RBF MMD² together with the variance-normalised ratio and the
# variance estimate (see `mmd2_and_ratio`).
#
# BUG FIX: the original called `_mmd2_and_ratio`, a function that does not
# exist anywhere in this file (the defined one is `mmd2_and_ratio`, which
# takes `const_diagonal`/`biased` as keywords, not positionals).
function mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=true)
    Kₓₓ, Kₓᵧ, Kᵧᵧ, _ = _mix_rbf_kernel(X, Y, sigma_list)
    return mmd2_and_ratio(Kₓₓ, Kₓᵧ, Kᵧᵧ; const_diagonal=false, biased=biased)
end

# Squared maximum mean discrepancy from precomputed kernel matrices.
# `const_diagonal` may be a number giving the (constant) kernel diagonal,
# or `false` to read it off Kₓₓ/Kᵧᵧ; `biased` picks the biased estimator.
function _mmd2(Kₓₓ, Kₓᵧ, Kᵧᵧ, const_diagonal=false, biased=false)
    m = size(Kₓₓ, 1)

    # Diagonal entries and their totals, either constant or extracted.
    if const_diagonal === false
        dₓ = diag(Kₓₓ)
        dᵧ = diag(Kᵧᵧ)
        trₓ = sum(dₓ)
        trᵧ = sum(dᵧ)
    else
        dₓ = dᵧ = const_diagonal
        trₓ = trᵧ = m * const_diagonal
    end

    # Off-diagonal totals of Kₓₓ / Kᵧᵧ, and the full cross-kernel total.
    offₓ = sum(sum(Kₓₓ, dims=2) .- dₓ)
    offᵧ = sum(sum(Kᵧᵧ, dims=2) .- dᵧ)
    cross = sum(Kₓᵧ)

    if biased
        return (offₓ + trₓ) / (m * m) + (offᵧ + trᵧ) / (m * m) - 2.0 * cross / (m * m)
    else
        return offₓ / (m * (m - 1)) + offᵧ / (m * (m - 1)) - 2.0 * cross / (m * m)
    end
end

# Return (loss, mmd2, var_est) where loss is MMD² divided by the square root
# of its (floored) variance estimate.
#
# BUG FIX: the original body referenced `Kₓₓ/Kₓᵧ/Kᵧᵧ` although the parameters
# are named `K_XX/K_XY/K_YY`, and it called `mmd2_and_variance`, which is not
# defined (the helper below is `_mmd2_and_variance`) — both UndefVarErrors.
function mmd2_and_ratio(K_XX, K_XY, K_YY; const_diagonal::Bool=false, biased::Bool=false)
    mmd2, var_est = _mmd2_and_variance(K_XX, K_XY, K_YY; const_diagonal=const_diagonal, biased=biased)
    # Floor the variance at min_var_est to avoid dividing by ~0.
    loss = mmd2 / √(max(var_est, min_var_est))
    return loss, mmd2, var_est
end

# Compute the MMD² estimate together with an estimate of its variance from
# precomputed kernel matrices (X and Y must have the same sample count m).
# NOTE(review): the variance expression appears to follow Sutherland et al.,
# "Generative Models and Model Criticism via Optimized MMD" — verify against
# that reference before relying on it.
function _mmd2_and_variance(Kₓₓ, Kₓᵧ, Kᵧᵧ; const_diagonal=false, biased=false)
    m = size(Kₓₓ, 1) # assume X, Y are the same shape

    # Get the various sums of kernels that we'll use
    if const_diagonal !== false
        diagₓ = diagᵧ = const_diagonal
        sum_diagₓ = sum_diagᵧ = m * const_diagonal
        sum_diag2ₓ = sum_diag2ᵧ = m * const_diagonal^2
    else
        # BUG FIX: the original called `diagm`, which *constructs* a matrix
        # from a vector (and errors on a matrix argument); `diag` extracts
        # the length-m diagonal vector, as the shape comments intend.
        diagₓ = diag(Kₓₓ) # (m,)
        diagᵧ = diag(Kᵧᵧ) # (m,)
        sum_diagₓ = sum(diagₓ)
        sum_diagᵧ = sum(diagᵧ)
        sum_diag2ₓ = dot(diagₓ, diagₓ)
        sum_diag2ᵧ = dot(diagᵧ, diagᵧ)
    end

    Kₜₓₓ_sums = sum(Kₓₓ, dims=2) .- diagₓ # \tilde{K}_XX * e = Kₓₓ * e - diagₓ
    Kₜᵧᵧ_sums = sum(Kᵧᵧ, dims=2) .- diagᵧ # \tilde{K}_YY * e = Kᵧᵧ * e - diagᵧ
    Kₓᵧ_sums₁ = sum(Kₓᵧ, dims=1) # K_{XY}^T * e
    Kₓᵧ_sums₂ = sum(Kₓᵧ, dims=2) # K_{XY} * e

    Kₜₓₓ_sum = sum(Kₜₓₓ_sums)
    Kₜᵧᵧ_sum = sum(Kₜᵧᵧ_sums)
    Kₓᵧ_sum = sum(Kₓᵧ_sums₁)

    Kₜₓₓ_2_sum = sum(Kₓₓ .^ 2) .- sum_diag2ₓ # \| \tilde{K}_XX \|_F^2
    Kₜᵧᵧ_2_sum = sum(Kᵧᵧ .^ 2) .- sum_diag2ᵧ # \| \tilde{K}_YY \|_F^2
    Kₓᵧ_2_sum = sum(Kₓᵧ .^ 2) # \| K_{XY} \|_F^2

    if biased
        mmd2 = ((Kₜₓₓ_sum + sum_diagₓ) / (m * m)
                +
                (Kₜᵧᵧ_sum + sum_diagᵧ) / (m * m)
                -
                2.0 * Kₓᵧ_sum / (m * m))
    else
        mmd2 = (Kₜₓₓ_sum / (m * (m - 1))
                +
                Kₜᵧᵧ_sum / (m * (m - 1))
                -
                2.0 * Kₓᵧ_sum / (m * m))
    end

    var_est = (
        2.0 / (m^2 * (m - 1.0)^2) * (2 * dot(Kₜₓₓ_sums, Kₜₓₓ_sums) - Kₜₓₓ_2_sum + 2 * dot(Kₜᵧᵧ_sums, Kₜᵧᵧ_sums) - Kₜᵧᵧ_2_sum)
        -
        (4.0 * m - 6.0) / (m^3 * (m - 1.0)^3) * (Kₜₓₓ_sum^2 + Kₜᵧᵧ_sum^2)
        +
        4.0 * (m - 2.0) / (m^3 * (m - 1.0)^2) * (dot(Kₓᵧ_sums₂, Kₓᵧ_sums₂) + dot(Kₓᵧ_sums₁, Kₓᵧ_sums₁))
        -
        4.0 * (m - 3.0) / (m^3 * (m - 1.0)^2) * (Kₓᵧ_2_sum) - (8 * m - 12) / (m^5 * (m - 1)) * Kₓᵧ_sum^2
        +
        8.0 / (m^3 * (m - 1.0)) * (
            1.0 / m * (Kₜₓₓ_sum + Kₜᵧᵧ_sum) * Kₓᵧ_sum
            -
            dot(Kₜₓₓ_sums, Kₓᵧ_sums₂)
            -
            dot(Kₜᵧᵧ_sums, Kₓᵧ_sums₁))
    )

    return mmd2, var_est
end
126 changes: 126 additions & 0 deletions MMD_GAN/mmd_gan_1d.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
using Base.Iterators: partition
using Flux
using Flux.Optimise: update!
using Flux: logitbinarycrossentropy, binarycrossentropy
using Statistics
using Parameters: @with_kw
using Random
using Printf
using CUDA
using Zygote
using Distributions

include("./mmd.jl")

# Hyper-parameters for the 1-D MMD-GAN experiment. `@with_kw` (Parameters.jl)
# generates a keyword constructor with the defaults below.
# NOTE: field order matters — `sigma_list`'s default reads `base` above it.
@with_kw struct HyperParams
    data_size::Int = 10000      # NOTE(review): not read in this file — confirm against callers
    batch_size::Int = 128       # samples drawn per data_sampler call
    latent_dim::Int = 1         # NOTE(review): declared but not read in this file
    epochs::Int = 1000          # outer training iterations in train()
    verbose_freq::Int = 1000    # NOTE(review): not read in this file
    num_gen::Int = 1            # generator updates per epoch
    num_enc_dec::Int = 5        # encoder/decoder updates per epoch
    lr_enc::Float64 = 1.0e-4    # ADAM learning rate for the encoder
    lr_dec::Float64 = 1.0e-4    # ADAM learning rate for the decoder
    lr_gen::Float64 = 1.0e-4    # ADAM learning rate for the generator

    lambda_AE::Float64 = 8.0    # weight of the autoencoder reconstruction terms
    target_param::Tuple{Float64,Float64} = (23.0, 1.0)  # (mean, std) of the target Normal
    noise_param::Tuple{Float64,Float64} = (0.0, 1.0)    # (mean, std) of the input noise Normal
    base::Float64 = 1.0         # scale divisor applied to sigma_list
    sigma_list::Array{Float64,1} = [1.0, 2.0, 4.0, 8.0, 16.0] ./ base  # RBF bandwidths for the MMD
end

# Generator: a 1 → 7 → 13 → 7 → 1 MLP with elu activations that maps noise
# samples to the data space.
function generator()
    return Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
end

# Encoder: lifts 1-D samples into a 29-dimensional feature space (the kernel
# is applied to these features).
function encoder()
    return Chain(
        Dense(1, 11),
        elu,
        Dense(11, 29),
        elu,
    )
end

# Decoder: maps the 29-dimensional encoding back down to 1-D.
function decoder()
    return Chain(
        Dense(29, 11),
        elu,
        Dense(11, 1),
    )
end

# Draw a (batch_size, 1) matrix of samples from Normal(target[1], target[2]).
function data_sampler(hparams, target)
    μ, σ = target
    return rand(Normal(μ, σ), (hparams.batch_size, 1))
end

# Train the MMD-GAN: per epoch, update the encoder/decoder (which play the
# role of the kernel discriminator) `num_enc_dec` times, then the generator
# `num_gen` times. Returns the collected loss histories.
#
# Fixes vs. the original:
#   * dropped `@showprogress` — ProgressMeter is never imported by this file,
#     so the macro was an UndefVarError
#   * removed unused locals (`mse`, `cum_dis_loss`, `cum_gen_loss`,
#     `train_steps`)
#   * loss histories are typed (`Float64[]`) and returned to the caller
#     (previously collected and discarded)
function train()
    hparams = HyperParams()

    gen = generator()
    enc = encoder()
    dec = decoder()

    # One ADAM optimiser per sub-network.
    gen_opt = ADAM(hparams.lr_gen)
    enc_opt = ADAM(hparams.lr_enc)
    dec_opt = ADAM(hparams.lr_dec)

    gen_ps = Flux.params(gen)
    enc_ps = Flux.params(enc)
    dec_ps = Flux.params(dec)

    losses_gen = Float64[]
    losses_dscr = Float64[]

    for ep in 1:hparams.epochs
        # --- encoder/decoder ("discriminator") updates ---
        for _ in 1:hparams.num_enc_dec
            loss, back = Zygote.pullback(Flux.params(enc, dec)) do
                target = data_sampler(hparams, hparams.target_param)
                noise = data_sampler(hparams, hparams.noise_param)
                # Autoencoder reconstruction terms keep enc/dec informative.
                encoded_target = enc(target')
                decoded_target = dec(encoded_target)
                # NOTE(review): decoded_target is 1×batch while target is
                # batch×1, so this mse broadcasts to batch×batch — confirm
                # whether `target'` was intended (kept as-is here).
                L2_AE_target = Flux.mse(decoded_target, target)
                transformed_noise = gen(noise')
                encoded_noise = enc(transformed_noise)
                decoded_noise = dec(encoded_noise)
                L2_AE_noise = Flux.mse(decoded_noise, transformed_noise)
                # relu clamps tiny negative MMD estimates before sqrt.
                MMD = relu(mix_rbf_mmd2(encoded_target, encoded_noise, hparams.sigma_list))
                # Negated: enc/dec *maximise* MMD while keeping reconstruction low.
                -1.0 * (sqrt(MMD) - hparams.lambda_AE * (L2_AE_noise + L2_AE_target))
            end
            grads = back(1.0f0)
            update!(enc_opt, enc_ps, grads)
            update!(dec_opt, dec_ps, grads)
            push!(losses_dscr, loss)
        end
        # --- generator updates: minimise the (root) MMD ---
        for _ in 1:hparams.num_gen
            loss, back = Zygote.pullback(gen_ps) do
                target = data_sampler(hparams, hparams.target_param)
                noise = data_sampler(hparams, hparams.noise_param)
                encoded_target = enc(target')
                encoded_noise = enc(gen(noise'))
                sqrt(relu(mix_rbf_mmd2(encoded_target, encoded_noise, hparams.sigma_list)))
            end
            grads = back(1.0f0)
            update!(gen_opt, gen_ps, grads)
            push!(losses_gen, loss)
        end
    end

    return losses_gen, losses_dscr
end

# Overlay histograms of `n_samples` batches of real target samples and of
# generator outputs, using `range` as the bin edges.
#
# BUG FIX: the original read `hparams` and `gen` as globals that only exist
# as locals inside `train()`, so every call raised UndefVarError. They are
# now required keyword arguments (positional signature unchanged).
# NOTE(review): `histogram`/`histogram!` come from Plots.jl, which this file
# never loads — the caller must have Plots in scope.
function plot_results(n_samples, range; gen, hparams)
    target = [data_sampler(hparams, hparams.target_param) for _ in 1:n_samples]
    target = collect(Iterators.flatten(target))
    transformed_noise = [gen(data_sampler(hparams, hparams.noise_param)')' for _ in 1:n_samples]
    transformed_noise = collect(Iterators.flatten(transformed_noise))
    histogram(target, bins=range)
    histogram!(transformed_noise, bins=range)
end
29 changes: 29 additions & 0 deletions MMD_GAN/test_mmd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using Test

include("./mmd.jl")

tol = 1e-5

@testset "mix_rbf_mmd2" begin
    sigma_list = [1.0, 2.0, 4.0, 8.0, 16.0] ./ 1.0

    # Compare floating-point results with ≈ and the file-level `tol` instead
    # of exact `==`: bitwise equality is brittle across Julia/BLAS versions
    # (and `tol` was previously defined but never used).
    X = [0.4126 0.334]
    Y = [-0.3432 0.45]
    @test mix_rbf_mmd2(X, Y, sigma_list) ≈ 0.6955453052850462 atol = tol

    X = [-0.9964 1.6757 -1.0000 -0.9621 3.0564; -0.9966 1.6849 -1.0000 -0.9630 3.0814]
    Y = [-0.4433 0.1915 0.2270 0.0989 -0.1531; -0.4603 0.2200 0.1087 0.0844 -0.1827]
    @test mix_rbf_mmd2(X, Y, sigma_list) ≈ 4.754166275960963 atol = tol

    X = [0.0588 24.6208 0.6140 1.9435 21.2479; 0.0780 23.8120 0.5976 1.8904 20.5289]
    Y = [0.2057 18.4525 0.4892 1.5384 15.7635; 0.1529 20.6706 0.5340 1.6841 17.7357]
    @test mix_rbf_mmd2(X, Y, sigma_list) ≈ 4.659662358316032 atol = tol
end;

@testset "_mmd2" begin
    Kₓₓ = [5.0 4.999526869857079; 4.999526869857079 5.0]
    Kₓᵧ = [2.6194483555739208 2.629665928762427; 2.610644548009624 2.6206924815501655]
    Kᵧᵧ = [5.0 4.989256995960984; 4.989256995960984 5.0]
    # Unbiased (default) and biased estimators; use ≈ with `tol` rather than
    # exact float equality, which is brittle across Julia/BLAS versions.
    @test _mmd2(Kₓₓ, Kₓᵧ, Kᵧᵧ) ≈ 4.748558208869993 atol = tol
    @test _mmd2(Kₓₓ, Kₓᵧ, Kᵧᵧ, false, true) ≈ 4.754166275960963 atol = tol
end;