diff --git a/examples/domain_decomp/heat_equation.jl b/examples/domain_decomp/heat_equation.jl index e1224f5..e45a896 100644 --- a/examples/domain_decomp/heat_equation.jl +++ b/examples/domain_decomp/heat_equation.jl @@ -12,9 +12,7 @@ using BlockArrays include("fnc_utils.jl") - -# How many points in our mesh -N = 200 +N = 100 # Number of nodes in the global problem. # GLOBAL PROBLEM SETUP AND SOLVE ################################ @@ -26,8 +24,8 @@ bump(x, mu=0, sigma=10) = begin z = exp.(-(x .- mu) .^ 2 ./ sigma) end -# Construct the global source term. -rhs_func(x) = 3bump(x, 1 / 2, 1 / 20) - 3bump(x, -1 / 2, 1 / 20) + +rhs_func(x) = 2bump(x, 1 / 2, 1 / 20) - 4bump(x, -2 / 3, 1 / 30) # x, u = bvplin(1 / 20, x -> 0, x -> 0, x -> -rhs_func(x), [-1, 1], -1 / 2, -1 / 2, N) # x, u = bvplin(1 / 20, x -> 0, x -> 0, x -> -rhs_func(x), [-1, 1], 0,0, N) @@ -35,23 +33,27 @@ rhs_func(x) = 3bump(x, 1 / 2, 1 / 20) - 3bump(x, -1 / 2, 1 / 20) solveN = bvplin_solver(1 / 20, zero, zero, x -> -rhs_func(x), [-1, 1], N) x, u = solveN(0, 0) -p = plot(x, rhs_func.(x), label="b") -p = plot!(p, x, u, label="bvplin_soln") +p = plot(x, rhs_func.(x), label="b", lw=3, title="Solution n=$N") +p = plot!(p, x, u, label="bvplin_soln", lw=3, xlabel="x", ylabel="u", legend=:bottomright) +plt2 = deepcopy(p) + # LOCAL SOLVER SETUP AND SOLVE ############################## - # Index arithmetic to divide the mesh with AB amount of overlap. Nhalf = ceil(Int, N / 2) -AB = 10 + +xAright = 0.1 +xBleft = -0.1 +AB = length(findall(xBleft .<= x .<= xAright)) A = Nhalf + ceil(Int, AB / 2) B = Nhalf + ceil(Int, AB / 2) # Create the local solvers for each subdomain. -solveA = bvplin_solver(1 / 20, zero, zero, x -> -rhs_func(x), [-1, 0.1], A) -solveB = bvplin_solver(1 / 20, zero, zero, x -> -rhs_func(x), [-0.1, 1], B) +solveA = bvplin_solver(1 / 20, zero, zero, x -> -rhs_func(x), [-1, xAright], A) +solveB = bvplin_solver(1 / 20, zero, zero, x -> -rhs_func(x), [xBleft, 1], B) # "wrap" a solver to just take in a u and return the projection of that u to the nearest solution. wrap_solver(solver) = u -> begin @@ -67,40 +69,97 @@ f1 = wrap_solver(solveA) f2 = wrap_solver(solveB) -# CELLULAR SHEAF SETUP -###################### +""" +Make restriction maps. In this case, they are both projections. +""" +function restriction_matrices(A, B, AB) + p1 = zeros(AB, A) + p1[1:AB, end-AB+1:end] .= I(AB) + + p2 = zeros(AB, B) + p2[1:AB, 1:AB] .= I(AB) + return p1, p2 +end + +p1, p2 = restriction_matrices(A, B, AB) + +function alternating_projection(u₀, niter=1) + function update(u1, u2) + u1, u2 = f1(u1), f2(u2) + mid = (p1 * u1 + p2 * u2) / 2 + u1[end-AB+1:end] = mid + u2[1:AB] = mid + return u1, u2 + end + u1 = u₀[1:A] + u2 = u₀[N-B:end] + for i in 1:niter + u1, u2 = update(u1, u2) + end + return vcat(u1, u2[AB+1:end]) +end +# Plot both the solutions and the error over iterations + +begin + rplt = plot(xlabel="x", ylabel="error", title="Error 1:$A, $(N-B):$N") + rplt_tail = plot(xlabel="x", ylabel="error", title="Error 1:$A, $(N-B):$N") + iters = [1, 5, 10, 50, 100, 200] + for i in iters + ualt = alternating_projection(zeros(N), i) + plot!(p, x, ualt, label="ualt_$i", linestyle=:dash, lw=2) + if i < 200 + plot!(rplt, x, ualt - u, label="resid_$i", lw=2, ls=:dash) + else + scatter!(rplt_tail, x[2:end], ((ualt - u) ./ u)[2:end], label="resid_$i", lw=2, ls=:dash) + end + println("Residual of ualt_$i: ", norm(ualt - u)) + end + vline!(p, [xBleft, xAright], linestyle=:dash) + vline!(rplt, [xBleft, xAright], linestyle=:dash) + plt = plot(p, rplt, rplt_tail, layout=[1; 1; 1], size=(800, 700)) +end +plt -# make restriction maps. In this case, they are both projections. -p1 = zeros(AB, A) -p1[1:AB, end-AB+1:end] .= I(AB) -p1 -p2 = zeros(AB, B) -p2[1:AB, 1:AB] = I(AB) -p2 +# CELLULAR SHEAF SETUP +###################### # Make cellular sheaf. -s = CellularSheaf([A, B], [AB]) -set_edge_maps!(s, 1, 2, 1, p1, p2) +#s = CellularSheaf([A, B], [AB]) +#set_edge_maps!(s, 1, 2, 1, p1, p2) # Set up homological program using the local solvers we defined earlier. -hp = CollocationHP([f1, f2], s) +# hp = CollocationHP([f1, f2], s) +# CELLULAR SHEAF SETUP +###################### # Use ADMM to solve the HP. -primal_sol, dual_sol = solve(hp, ADMM(2.0, 1)) +#primal_sol, dual_sol = solve(hp, ADMM(2.0, 1)) + +#=function lift_matching_family(primal_sol) + u1 = primal_sol[Block(1)] + u2 = primal_sol[Block(2)] + u_solA = vcat(u1[1:end-AB], u2) + u_solB = vcat(u1, u2[AB+1:end]) + return u_solA, u_solB +end =# -# Solution analysis and visualization. -u1 = primal_sol[Block(1)] -u2 = primal_sol[Block(2)] +#u_solA, u_solB = lift_matching_family(primal_sol) +#@show norm(u_solA - u_solB) -u_solA = vcat(u1[1:end-AB], u2) -u_solB = vcat(u1, u2[AB+1:end]) +# p = scatter!(p, x, u_solA, label="hpA") +# p = scatter!(p, x, u_solB, label="hpB") +# plot!(p, u, color=:teal, label="reference") +# plot!(p, rhs_func.(x), color=:purple, label="b") + +#=u₀ = vcat(u[1:A], u[N-B+1:end]) +primal_sol, dual_sol = solve(hp, ADMM(2.0, 1), u₀) +u_solA, u_solB = lift_matching_family(primal_sol) @show norm(u_solA - u_solB) +p = scatter!(p, x, u_solB, label="hp-fp")=# -p = scatter!(p, x, u_solA, label="hpA") -p = scatter!(p, x, u_solB, label="hpB") -# plot!(p, u, color=:teal, label="reference") -# plot!(p, rhs_func.(x), color=:purple, label="b") \ No newline at end of file +#u1 = f1(primal_sol[Block(1)]) +#u2 = f2(primal_sol[Block(2)]) \ No newline at end of file diff --git a/src/homological_programming/HomologicalPrograms.jl b/src/homological_programming/HomologicalPrograms.jl index 17bf73a..722333a 100644 --- a/src/homological_programming/HomologicalPrograms.jl +++ b/src/homological_programming/HomologicalPrograms.jl @@ -41,18 +41,31 @@ struct CollocationHP <: AbstractHomologicalProgram sheaf::AbstractCellularSheaf end + # ADMM solver for domain decomposition methods. -function solve(h::CollocationHP, alg::ADMM) - # Initialize storage for algorithm state. - y = BlockArray(zeros(sum(h.sheaf.vertex_stalks)), h.sheaf.vertex_stalks) # Dual variable - z = BlockArray(zeros(sum(h.sheaf.vertex_stalks)), h.sheaf.vertex_stalks) # Primal variable - x_star = BlockArray(zeros(sum(h.sheaf.vertex_stalks)), h.sheaf.vertex_stalks) # Temporary variable +function solve(h::CollocationHP, alg::ADMM, x₀=nothing) + stalks = h.sheaf.vertex_stalks + n = sum(stalks) + # Dual variable + y = BlockArray(zeros(n), stalks) + # Primal variable + z = BlockArray(zeros(n), stalks) + # Temporary variable + x_star = !isnothing(x₀) ? + BlockArray(x₀, stalks) : + BlockArray(zeros(n), stalks) + + #regularized_objectives = [(z, y) -> (x -> f(x) + alg.step_size / 2 * (x - z + y)' * (x - z + y)) for f in h.objectives] + # Run the ADMM iteration. for k in 1:alg.num_iters # Run local solver for each node of the cellular sheaf for (i, f) in enumerate(h.node_solvers) res_x = f(z - y) + #res_x = optimize(f(z[Block(i)], y[Block(i)]), zeros(stalks[i]), LBFGS(); autodiff=:forward) + println(length(x_star[Block(i)])) + println(length(res_x)) x_star[Block(i)] = res_x end # Project to the nearest global section of the cellular sheaf