Skip to content

Commit a2311be

Browse files
committed
significantly reduces allocations for relaxation.
1 parent 2b8d58b commit a2311be

File tree

2 files changed

+14
-114
lines changed

2 files changed

+14
-114
lines changed

src/callbacks_step/relaxation.jl

Lines changed: 10 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -68,159 +68,57 @@ end
6868
gamma_hi = 3 * one(tnew) / 2
6969

7070
@unpack tmp1 = semi.cache # of size N
71-
tmp222 = similar(qold) # of size nvariables * N
71+
tmp2 = similar(qold) # of size nvariables * N
7272
# @unpack tmp222 = semi.cache # of size nvariables * N and ArrayPartition
7373

74-
7574
function relaxation_functional(tmp1, q, semi)
7675
return integrate_quantity!(tmp1, relaxation_callback.invariant, q, semi)
7776
end
7877

79-
function convex_combination!(tmp222, gamma, old, new)
80-
@.. tmp222 = old + gamma * (new - old)
81-
return nothing
78+
function convex_combination(gamma, told, tnew) # for scalars
79+
return @.. told + gamma * (tnew - told)
8280
end
8381

84-
function convex_combination(gamma, old, new)
85-
return @.. old + gamma * (new - old)
86-
end
87-
88-
function root(g)
89-
convex_combination!(tmp222, g, qold, qnew)
90-
return (relaxation_functional(tmp1, tmp222 ,semi) - energy_old)
91-
end
92-
93-
energy_old = relaxation_functional(tmp1, qold, semi)
94-
95-
@trixi_timeit timer() "relaxation" begin
96-
convex_combination!(tmp222, gamma_lo, qold, qnew)
97-
val1 = relaxation_functional(tmp1, tmp222, semi) - energy_old
98-
99-
convex_combination!(tmp222, gamma_hi, qold, qnew)
100-
val2 = relaxation_functional(tmp1, tmp222, semi) - energy_old
101-
102-
if (val1 * val2) > 0
103-
terminate_integration = true
104-
else
105-
gamma = find_zero(root, (gamma_lo, gamma_hi), AlefeldPotraShi())
106-
end
107-
108-
if gamma < eps(typeof(gamma))
109-
terminate_integration = true
110-
end
111-
112-
113-
convex_combination!(tmp222, gamma, qold, qnew)
114-
DiffEqBase.set_u!(integrator, tmp222)
115-
116-
if !isapprox(tnew, first(integrator.opts.tstops))
117-
118-
# convex_combination!(tmp222, gamma, told, tnew)
119-
120-
tgamma = convex_combination(gamma, told, tnew)
121-
DiffEqBase.set_t!(integrator, tgamma)
122-
end
123-
124-
if terminate_integration
125-
terminate!(integrator)
126-
end
127-
end
128-
return nothing
129-
end
130-
131-
132-
133-
#= This method is called as callback during the time integration.
134-
@inline function (relaxation_callback::RelaxationCallback)(integrator)
135-
semi = integrator.p
136-
told = integrator.tprev
137-
qold = integrator.uprev
138-
tnew = integrator.t
139-
qnew = integrator.u
140-
141-
terminate_integration = false
142-
gamma_lo = one(tnew) / 2
143-
gamma_hi = 3 * one(tnew) / 2
144-
145-
@unpack tmp1 = semi.cache # of size N
146-
@unpack tmp222 = semi.cache # of size nvariables * N
147-
tmp2 = tmp222
148-
tmp3 = similar(tmp2)
149-
150-
151-
152-
function relaxation_functional(tmp1, q, semi)
153-
return integrate_quantity!(tmp1, relaxation_callback.invariant, q, semi)
154-
end
155-
156-
function convex_combination!(tmp2, gamma, old, new)
157-
@. tmp2 = old + gamma * (new - old)
82+
function convex_combination!(tmp2, gamma, uold, unew) # for arrays
83+
@.. tmp2 = uold + gamma * (unew - uold)
15884
return nothing
15985
end
16086

161-
function convex_combination(gamma, old, new)
162-
return @.. old + gamma * (new - old)
163-
end
164-
16587
function root(g)
16688
convex_combination!(tmp2, g, qold, qnew)
167-
return (relaxation_functional(tmp1, tmp2 ,semi) - energy_old)
168-
end
169-
170-
function relaxation_functional2(q, semi)
171-
@unpack tmp1 = semi.cache
172-
return integrate_quantity!(tmp1, relaxation_callback.invariant, q, semi)
89+
return (relaxation_functional(tmp1, tmp2, semi) - energy_old)
17390
end
17491

17592
energy_old = relaxation_functional(tmp1, qold, semi)
17693

17794
@trixi_timeit timer() "relaxation" begin
95+
convex_combination!(tmp2, gamma_lo, qold, qnew)
96+
val1 = relaxation_functional(tmp1, tmp2, semi) - energy_old
17897

17998
convex_combination!(tmp2, gamma_hi, qold, qnew)
180-
val2 = relaxation_functional2(tmp2, semi) - energy_old
181-
182-
convex_combination!(tmp3, gamma_lo, qold, qnew)
183-
val1 = relaxation_functional2(tmp3, semi) - energy_old
184-
185-
186-
teststs =
187-
val1_dif = relaxation_functional2(convex_combination(gamma_lo, qold, qnew), semi) - energy_old
188-
val2_dif = relaxation_functional2(convex_combination(gamma_hi, qold, qnew), semi) -energy_old
189-
190-
@show val1
191-
@show val1_dif
192-
@show val2
193-
@show val2_dif
194-
99+
val2 = relaxation_functional(tmp1, tmp2, semi) - energy_old
195100

196101
if (val1 * val2) > 0
197102
terminate_integration = true
198103
else
199104
gamma = find_zero(root, (gamma_lo, gamma_hi), AlefeldPotraShi())
200-
@show gamma
201105
end
202106

203107
if gamma < eps(typeof(gamma))
204108
terminate_integration = true
205109
end
206110

207-
208111
convex_combination!(tmp2, gamma, qold, qnew)
209112
DiffEqBase.set_u!(integrator, tmp2)
210113

211114
if !isapprox(tnew, first(integrator.opts.tstops))
212-
213-
# convex_combination!(tmp2, gamma, told, tnew)
214-
215-
tgamma = convex_combination(gamma, told, tnew)
115+
tgamma = convex_combination(gamma, told, tnew) # scalar combination
216116
DiffEqBase.set_t!(integrator, tgamma)
217117
end
218118

219119
if terminate_integration
220120
terminate!(integrator)
221121
end
222122
end
223-
@show "einmal durch"
224123
return nothing
225124
end
226-
=#

src/semidiscretization.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@ function Semidiscretization(mesh, equations, initial_condition, solver;
5858
RealT = real(solver), uEltype = RealT,
5959
# tmp1 is needed for the `RelaxationCallback`
6060
initial_cache = (tmp1 = Array{RealT}(undef, nnodes(mesh)),
61-
tmp222 = ArrayPartition(ntuple(_ -> zeros(real(solver), nnodes(mesh)),
62-
Val(nvariables(equations))))))
61+
#tmp222 = ArrayPartition(ntuple(_ -> zeros(real(solver), nnodes(mesh)),
62+
#tmp222 = Array{RealT}(undef, nvariables(equations)*nnodes(mesh),
63+
# Val(nvariables(equations)))
64+
))
6365
cache = (;
6466
create_cache(mesh, equations, solver, initial_condition, boundary_conditions,
6567
RealT, uEltype)...,

0 commit comments

Comments
 (0)