Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 92 additions & 33 deletions examples/domain_decomp/heat_equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ using BlockArrays

include("fnc_utils.jl")


# How many points in our mesh
N = 200
N = 100 # Number of nodes in the global problem.

# GLOBAL PROBLEM SETUP AND SOLVE
################################
Expand All @@ -26,32 +24,36 @@ bump(x, mu=0, sigma=10) = begin
z = exp.(-(x .- mu) .^ 2 ./ sigma)
end

# Construct the global source term.
rhs_func(x) = 3bump(x, 1 / 2, 1 / 20) - 3bump(x, -1 / 2, 1 / 20)

rhs_func(x) = 2bump(x, 1 / 2, 1 / 20) - 4bump(x, -2 / 3, 1 / 30)
# x, u = bvplin(1 / 20, x -> 0, x -> 0, x -> -rhs_func(x), [-1, 1], -1 / 2, -1 / 2, N)
# x, u = bvplin(1 / 20, x -> 0, x -> 0, x -> -rhs_func(x), [-1, 1], 0,0, N)

# Compute the correct global solution for comparison.
solveN = bvplin_solver(1 / 20, zero, zero, x -> -rhs_func(x), [-1, 1], N)
x, u = solveN(0, 0)

p = plot(x, rhs_func.(x), label="b")
p = plot!(p, x, u, label="bvplin_soln")
p = plot(x, rhs_func.(x), label="b", lw=3, title="Solution n=$N")
p = plot!(p, x, u, label="bvplin_soln", lw=3, xlabel="x", ylabel="u", legend=:bottomright)
plt2 = deepcopy(p)


# LOCAL SOLVER SETUP AND SOLVE
##############################


# Index arithmetic to divide the mesh with AB amount of overlap.
Nhalf = ceil(Int, N / 2)
AB = 10

xAright = 0.1
xBleft = -0.1
AB = length(findall(xBleft .<= x .<= xAright))
A = Nhalf + ceil(Int, AB / 2)
B = Nhalf + ceil(Int, AB / 2)


# Create the local solvers for each subdomain.
solveA = bvplin_solver(1 / 20, zero, zero, x -> -rhs_func(x), [-1, 0.1], A)
solveB = bvplin_solver(1 / 20, zero, zero, x -> -rhs_func(x), [-0.1, 1], B)
solveA = bvplin_solver(1 / 20, zero, zero, x -> -rhs_func(x), [-1, xAright], A)
solveB = bvplin_solver(1 / 20, zero, zero, x -> -rhs_func(x), [xBleft, 1], B)

# "wrap" a solver to just take in a u and return the projection of that u to the nearest solution.
wrap_solver(solver) = u -> begin
Expand All @@ -67,40 +69,97 @@ f1 = wrap_solver(solveA)
f2 = wrap_solver(solveB)


# CELLULAR SHEAF SETUP
######################
"""
Make restriction maps. In this case, they are both projections.
"""
function restriction_matrices(A, B, AB)
p1 = zeros(AB, A)
p1[1:AB, end-AB+1:end] .= I(AB)

p2 = zeros(AB, B)
p2[1:AB, 1:AB] .= I(AB)
return p1, p2
end

p1, p2 = restriction_matrices(A, B, AB)

function alternating_projection(u₀, niter=1)
function update(u1, u2)
u1, u2 = f1(u1), f2(u2)
mid = (p1 * u1 + p2 * u2) / 2
u1[end-AB+1:end] = mid
u2[1:AB] = mid
return u1, u2
end
u1 = u₀[1:A]
u2 = u₀[N-B:end]
for i in 1:niter
u1, u2 = update(u1, u2)
end
return vcat(u1, u2[AB+1:end])
end

# Plot both the solutions and the error over iterations

begin
rplt = plot(xlabel="x", ylabel="error", title="Error 1:$A, $(N-B):$N")
rplt_tail = plot(xlabel="x", ylabel="error", title="Error 1:$A, $(N-B):$N")
iters = [1, 5, 10, 50, 100, 200]
for i in iters
ualt = alternating_projection(zeros(N), i)
plot!(p, x, ualt, label="ualt_$i", linestyle=:dash, lw=2)
if i < 200
plot!(rplt, x, ualt - u, label="resid_$i", lw=2, ls=:dash)
else
scatter!(rplt_tail, x[2:end], ((ualt - u) ./ u)[2:end], label="resid_$i", lw=2, ls=:dash)
end
println("Residual of ualt_$i: ", norm(ualt - u))
end
vline!(p, [xBleft, xAright], linestyle=:dash)
vline!(rplt, [xBleft, xAright], linestyle=:dash)
plt = plot(p, rplt, rplt_tail, layout=[1; 1; 1], size=(800, 700))
end
plt

# make restriction maps. In this case, they are both projections.
p1 = zeros(AB, A)
p1[1:AB, end-AB+1:end] .= I(AB)
p1

p2 = zeros(AB, B)
p2[1:AB, 1:AB] = I(AB)
p2
# CELLULAR SHEAF SETUP
######################

# Make cellular sheaf.
s = CellularSheaf([A, B], [AB])
set_edge_maps!(s, 1, 2, 1, p1, p2)
#s = CellularSheaf([A, B], [AB])
#set_edge_maps!(s, 1, 2, 1, p1, p2)

# Set up homological program using the local solvers we defined earlier.
hp = CollocationHP([f1, f2], s)
# hp = CollocationHP([f1, f2], s)
# CELLULAR SHEAF SETUP
######################

# Use ADMM to solve the HP.
primal_sol, dual_sol = solve(hp, ADMM(2.0, 1))
#primal_sol, dual_sol = solve(hp, ADMM(2.0, 1))

#=function lift_matching_family(primal_sol)
u1 = primal_sol[Block(1)]
u2 = primal_sol[Block(2)]
u_solA = vcat(u1[1:end-AB], u2)
u_solB = vcat(u1, u2[AB+1:end])
return u_solA, u_solB
end =#

# Solution analysis and visualization.
u1 = primal_sol[Block(1)]
u2 = primal_sol[Block(2)]
#u_solA, u_solB = lift_matching_family(primal_sol)
#@show norm(u_solA - u_solB)


u_solA = vcat(u1[1:end-AB], u2)
u_solB = vcat(u1, u2[AB+1:end])
# p = scatter!(p, x, u_solA, label="hpA")
# p = scatter!(p, x, u_solB, label="hpB")
# plot!(p, u, color=:teal, label="reference")
# plot!(p, rhs_func.(x), color=:purple, label="b")

#=u₀ = vcat(u[1:A], u[N-B+1:end])
primal_sol, dual_sol = solve(hp, ADMM(2.0, 1), u₀)
u_solA, u_solB = lift_matching_family(primal_sol)
@show norm(u_solA - u_solB)
p = scatter!(p, x, u_solB, label="hp-fp")=#


p = scatter!(p, x, u_solA, label="hpA")
p = scatter!(p, x, u_solB, label="hpB")
# plot!(p, u, color=:teal, label="reference")
# plot!(p, rhs_func.(x), color=:purple, label="b")
#u1 = f1(primal_sol[Block(1)])
#u2 = f2(primal_sol[Block(2)])
23 changes: 18 additions & 5 deletions src/homological_programming/HomologicalPrograms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,31 @@ struct CollocationHP <: AbstractHomologicalProgram
sheaf::AbstractCellularSheaf
end


# ADMM solver for domain decomposition methods.
function solve(h::CollocationHP, alg::ADMM)
# Initialize storage for algorithm state.
y = BlockArray(zeros(sum(h.sheaf.vertex_stalks)), h.sheaf.vertex_stalks) # Dual variable
z = BlockArray(zeros(sum(h.sheaf.vertex_stalks)), h.sheaf.vertex_stalks) # Primal variable
x_star = BlockArray(zeros(sum(h.sheaf.vertex_stalks)), h.sheaf.vertex_stalks) # Temporary variable
function solve(h::CollocationHP, alg::ADMM, x₀=nothing)
stalks = h.sheaf.vertex_stalks
n = sum(stalks)
# Dual variable
y = BlockArray(zeros(n), stalks)
# Primal variable
z = BlockArray(zeros(n), stalks)
# Temporary variable
x_star = !isnothing(x₀) ?
BlockArray(x₀, stalks) :
BlockArray(zeros(n), stalks)

#regularized_objectives = [(z, y) -> (x -> f(x) + alg.step_size / 2 * (x - z + y)' * (x - z + y)) for f in h.objectives]


# Run the ADMM iteration.
for k in 1:alg.num_iters
# Run local solver for each node of the cellular sheaf
for (i, f) in enumerate(h.node_solvers)
res_x = f(z - y)
#res_x = optimize(f(z[Block(i)], y[Block(i)]), zeros(stalks[i]), LBFGS(); autodiff=:forward)
println(length(x_star[Block(i)]))
println(length(res_x))
x_star[Block(i)] = res_x
end
# Project to the nearest global section of the cellular sheaf
Expand Down
Loading