171171
172172@inline foreach_system (f, systems) = foreach_noalloc (f, systems)
173173
174+ # This is just for readability to loop over all systems with indices without allocations.
175+ @inline function foreach_system_indexed (f, semi:: Union{NamedTuple, Semidiscretization} )
176+ return foreach_system_indexed (f, semi. systems)
177+ end
178+
179+ @inline function foreach_system_indexed (f, systems:: Tuple )
180+ indices = ntuple (identity, Val (length (systems)))
181+
182+ return foreach_noalloc (indices, systems) do (index, system)
183+ f (index, system)
184+ end
185+ end
186+
187+ # This is just for readability to loop over all systems with wrapped arrays.
188+ @inline function foreach_system_wrapped (f, semi:: Union{NamedTuple, Semidiscretization} ,
189+ v_ode, u_ode)
190+ return foreach_system_indexed (semi) do system_index, system
191+ f (system,
192+ wrap_v (v_ode, system, semi, system_index),
193+ wrap_u (u_ode, system, semi, system_index))
194+ end
195+ end
196+
197+ @inline function foreach_system_wrapped (f, semi:: Union{NamedTuple, Semidiscretization} ,
198+ dv_ode, v_ode, u_ode)
199+ return foreach_system_indexed (semi) do system_index, system
200+ f (system,
201+ wrap_v (dv_ode, system, semi, system_index),
202+ wrap_v (v_ode, system, semi, system_index),
203+ wrap_u (u_ode, system, semi, system_index))
204+ end
205+ end
174206"""
175207 semidiscretize(semi, tspan; reset_threads=true)
176208
@@ -334,24 +366,40 @@ end
334366# We have to pass `system` here for type stability,
335367# since the type of `system` determines the return type.
336368@inline function wrap_v (v_ode, system, semi)
337- (; ranges_v) = semi
369+ return wrap_v (v_ode, system, semi, system_indices (system, semi))
370+ end
338371
339- range = ranges_v[system_indices (system, semi)]
372+ @inline function wrap_v (v_ode, system, semi, system_index:: Integer )
373+ return wrap_v (v_ode, system, semi. ranges_v[system_index])
374+ end
340375
341- @boundscheck @assert length (range) ==
342- v_nvariables (system) * n_integrated_particles (system)
376+ @inline function wrap_v (v_ode, system, range:: AbstractUnitRange )
377+ @boundscheck begin
378+ expected = v_nvariables (system) * n_integrated_particles (system)
379+ range_length = length (range)
380+ range_length == expected ||
381+ throw (DimensionMismatch (" v range length $range_length does not match expected $expected " ))
382+ end
343383
344384 return wrap_array (v_ode, range,
345385 (StaticInt (v_nvariables (system)), n_integrated_particles (system)))
346386end
347387
348388@inline function wrap_u (u_ode, system, semi)
349- (; ranges_u) = semi
389+ return wrap_u (u_ode, system, semi, system_indices (system, semi))
390+ end
350391
351- range = ranges_u[system_indices (system, semi)]
392+ @inline function wrap_u (u_ode, system, semi, system_index:: Integer )
393+ return wrap_u (u_ode, system, semi. ranges_u[system_index])
394+ end
352395
353- @boundscheck @assert length (range) ==
354- u_nvariables (system) * n_integrated_particles (system)
396+ @inline function wrap_u (u_ode, system, range:: AbstractUnitRange )
397+ @boundscheck begin
398+ expected = u_nvariables (system) * n_integrated_particles (system)
399+ range_length = length (range)
400+ range_length == expected ||
401+ throw (DimensionMismatch (" u range length $range_length does not match expected $expected " ))
402+ end
355403
356404 return wrap_array (u_ode, range,
357405 (StaticInt (u_nvariables (system)), n_integrated_particles (system)))
@@ -639,19 +687,18 @@ end
639687
640688function system_interaction! (dv_ode, v_ode, u_ode, semi)
641689 # Call `interact!` for each pair of systems
642- foreach_system (semi) do system
643- foreach_system (semi) do neighbor
690+ foreach_system_indexed (semi) do system_index, system
691+ foreach_system_indexed (semi) do neighbor_index, neighbor
644692 # Construct string for the interactions timer.
645693 # Avoid allocations from string construction when no timers are used.
646694 if timeit_debug_enabled ()
647- system_index = system_indices (system, semi)
648- neighbor_index = system_indices (neighbor, semi)
649695 timer_str = " $(timer_name (system))$system_index -$(timer_name (neighbor))$neighbor_index "
650696 else
651697 timer_str = " "
652698 end
653699
654- interact! (dv_ode, v_ode, u_ode, system, neighbor, semi, timer_str= timer_str)
700+ interact! (dv_ode, v_ode, u_ode, system, neighbor, semi,
701+ system_index, neighbor_index; timer_str= timer_str)
655702 end
656703 end
657704
@@ -663,12 +710,21 @@ end
663710# dv_ode, du_ode = copy(sol.u[end]).x; v_ode, u_ode = copy(sol.u[end]).x;
664711# @btime TrixiParticles.interact!($dv_ode, $v_ode, $u_ode, $fluid_system, $fluid_system, $semi);
665712@inline function interact! (dv_ode, v_ode, u_ode, system, neighbor, semi; timer_str= " " )
666- dv = wrap_v (dv_ode, system, semi)
667- v_system = wrap_v (v_ode, system, semi)
668- u_system = wrap_u (u_ode, system, semi)
713+ system_index = system_indices (system, semi)
714+ neighbor_index = system_indices (neighbor, semi)
715+
716+ return interact! (dv_ode, v_ode, u_ode, system, neighbor, semi,
717+ system_index, neighbor_index; timer_str= timer_str)
718+ end
719+
720+ @inline function interact! (dv_ode, v_ode, u_ode, system, neighbor, semi,
721+ system_index:: Integer , neighbor_index:: Integer ; timer_str= " " )
722+ dv = wrap_v (dv_ode, system, semi, system_index)
723+ v_system = wrap_v (v_ode, system, semi, system_index)
724+ u_system = wrap_u (u_ode, system, semi, system_index)
669725
670- v_neighbor = wrap_v (v_ode, neighbor, semi)
671- u_neighbor = wrap_u (u_ode, neighbor, semi)
726+ v_neighbor = wrap_v (v_ode, neighbor, semi, neighbor_index )
727+ u_neighbor = wrap_u (u_ode, neighbor, semi, neighbor_index )
672728
673729 @trixi_timeit timer () timer_str begin
674730 interact! (dv, v_system, u_system, v_neighbor, u_neighbor, system, neighbor, semi)
0 commit comments