Skip to content

Commit 558a98f

Browse files
authored
Make assignment replacement with trixi_include recursive (#54)
* Make assignment replacement with `trixi_include` recursive * Fix * Fix error message * Update docs * Fix validation * Add tests * Remove `@test_nowarn_mod` * Fix tests * Add comment * Implement Copilot suggestions * Fix * Make `enable_assignment_validation` public API * Make recursive assignment replacement optional * Fix typo * Reformat * Fix tests * Fix warnings in tests
1 parent 58789a2 commit 558a98f

2 files changed

Lines changed: 266 additions & 12 deletions

File tree

src/trixi_include.jl

Lines changed: 112 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
# of `TrixiBase`. However, users will want to evaluate in the global scope of `Main` or something
44
# similar to manage dependencies on their own.
55
"""
6-
trixi_include([mapexpr::Function=identity,] [mod::Module=Main,] elixir::AbstractString; kwargs...)
6+
trixi_include([mapexpr::Function=identity,] [mod::Module=Main,] elixir::AbstractString;
7+
enable_assignment_validation::Bool = true,
8+
replace_assignments_recursive::Bool = false, kwargs...)
79
810
`include` the file `elixir` and evaluate its content in the global scope of module `mod`.
911
You can override specific assignments in `elixir` by supplying keyword arguments.
@@ -20,6 +22,16 @@ The optional first argument `mapexpr` can be used to transform the included code
2022
it is evaluated: for each parsed expression `expr` in `elixir`, the `include` function
2123
actually evaluates `mapexpr(expr)`. If it is omitted, `mapexpr` defaults to `identity`.
2224
25+
With `replace_assignments_recursive=true`, the keyword arguments are also passed
26+
to nested calls of `trixi_include`. This allows to override assignments in nested files as well.
27+
28+
The keyword argument `enable_assignment_validation`, which is enabled by default,
29+
can be used to enable or disable validation that all passed keyword arguments exist
30+
as assignments in `elixir`. If `enable_assignment_validation` is `true` and
31+
an assignment for a passed keyword argument is not found in `elixir`, an error is thrown.
32+
If `replace_assignments_recursive` is `true` and `elixir` contains calls to `trixi_include`
33+
itself, a warning is issued instead of an error.
34+
2335
# Examples
2436
2537
```@example
@@ -34,23 +46,40 @@ julia> redirect_stdout(devnull) do
3446
0.1
3547
```
3648
"""
37-
function trixi_include(mapexpr::Function, mod::Module, elixir::AbstractString; kwargs...)
49+
function trixi_include(mapexpr::Function, mod::Module, elixir::AbstractString;
50+
enable_assignment_validation::Bool = true,
51+
replace_assignments_recursive::Bool = false, kwargs...)
3852
# Check that all kwargs exist as assignments
3953
code = read(elixir, String)
4054
expr = Meta.parse("begin \n$code \nend")
4155
expr = insert_maxiters(expr)
4256

43-
for (key, val) in kwargs
44-
# This will throw an error when `key` is not found
45-
find_assignment(expr, key)
57+
# Validate that all kwargs exist as assignments (with warning for recursive cases).
58+
# Skip for nested calls because all kwargs are passed to all nested calls,
59+
# some of which may not use all kwargs.
60+
if enable_assignment_validation
61+
validate_assignments(expr, kwargs, elixir, replace_assignments_recursive)
4662
end
4763

4864
# Print information on potential wait time only in non-parallel case
4965
if !mpi_isparallel()
5066
@info "You just called `trixi_include`. Julia may now compile the code, please be patient."
5167
end
52-
Base.include(ex -> mapexpr(replace_assignments(insert_maxiters(ex); kwargs...)),
53-
mod, elixir)
68+
69+
if replace_assignments_recursive
70+
# Add kwarg `enable_assignment_validation` to disable validation in nested
71+
# `trixi_include` calls.
72+
Base.include(ex -> mapexpr(replace_assignments(insert_maxiters(ex),
73+
replace_assignments_recursive;
74+
enable_assignment_validation = false,
75+
replace_assignments_recursive = true,
76+
kwargs...)),
77+
mod, elixir)
78+
else
79+
Base.include(ex -> mapexpr(replace_assignments(insert_maxiters(ex);
80+
kwargs...)),
81+
mod, elixir)
82+
end
5483
end
5584

5685
function trixi_include(mod::Module, elixir::AbstractString; kwargs...)
@@ -159,24 +188,98 @@ walkexpr(f, expr::Expr) = f(Expr(expr.head, (walkexpr(f, arg) for arg in expr.ar
159188
walkexpr(f, x) = f(x)
160189

161190
# Replace assignments to `key` in `expr` by `key = val` for all `(key,val)` in `kwargs`.
162-
function replace_assignments(expr; kwargs...)
163-
# replace explicit and keyword assignments
191+
function replace_assignments(expr, recursive = false; kwargs...)
164192
expr = walkexpr(expr) do x
165193
if x isa Expr
194+
# Replace explicit and keyword assignments
166195
for (key, val) in kwargs
167196
if (x.head === Symbol("=") || x.head === :kw) &&
168197
x.args[1] === Symbol(key)
169198
x.args[2] = :($val)
170199
# dump(x)
171200
end
172201
end
202+
203+
# If `recursive` is true:
204+
# Handle `trixi_include` calls - add kwargs to them as well.
205+
is_trixi_include = (x.head === :call && length(x.args) >= 2 &&
206+
(x.args[1] === :trixi_include ||
207+
x.args[1] === :trixi_include_changeprecision))
208+
if !isempty(kwargs) && is_trixi_include && recursive
209+
210+
# Check for existing kwargs (both direct :kw and bare symbols in :parameters)
211+
existing_kwargs = Set{Symbol}()
212+
for arg in x.args[2:end] # Skip function name
213+
if arg isa Expr && arg.head === :kw
214+
# Direct keyword argument like `x=5` in `f(x=5)`
215+
push!(existing_kwargs, arg.args[1])
216+
elseif arg isa Expr && arg.head === :parameters
217+
# Keyword arguments grouped in `parameters`
218+
# like `f(; x=5)` or `f(; x)`.
219+
for nested_arg in arg.args
220+
if nested_arg isa Symbol
221+
# Bare symbol like `x` in `f(; x)`
222+
push!(existing_kwargs, nested_arg)
223+
elseif nested_arg isa Expr && nested_arg.head === :kw
224+
# Keyword argument like `x=5` in `f(; x=5)`
225+
push!(existing_kwargs, nested_arg.args[1])
226+
end
227+
end
228+
end
229+
end
230+
231+
# Add kwargs that don't already exist.
232+
# Note that existing keywords as assignment (`x=5`) don't need to be added
233+
# again because they are replaced in the loop
234+
# "Replace explicit and keyword assignments" above.
235+
# Bare symbol like `x` in `f(; x)` must have been defined in the file
236+
# before they are passed to `trixi_include`, so there must be an assignment
237+
# `x = ...` in the file, which will also be replaced in the loop above.
238+
for (key, val) in kwargs
239+
if !(Symbol(key) in existing_kwargs)
240+
push!(x.args, Expr(:kw, Symbol(key), val))
241+
end
242+
end
243+
end
173244
end
174245
return x
175246
end
176247

177248
return expr
178249
end
179250

251+
# Validate that keyword arguments passed to `trixi_include` exist as assignments
252+
# in the expression. Throw an error if they are not found or a warning for recursive calls.
253+
function validate_assignments(expr, assignments, filename, replace_assignments_recursive)
254+
isempty(assignments) && return
255+
256+
found_assignments = Set{Symbol}()
257+
has_nested_calls = false
258+
259+
walkexpr(expr) do x
260+
if x isa Expr
261+
if (x.head === Symbol("=") || x.head === :kw) && x.args[1] isa Symbol
262+
push!(found_assignments, x.args[1])
263+
elseif (x.head === :call && length(x.args) >= 2 &&
264+
(x.args[1] === :trixi_include ||
265+
x.args[1] === :trixi_include_changeprecision))
266+
has_nested_calls = true
267+
end
268+
end
269+
return x
270+
end
271+
272+
missing_assignments = setdiff(Symbol.(keys(assignments)), found_assignments)
273+
if !isempty(missing_assignments)
274+
if replace_assignments_recursive && has_nested_calls
275+
@warn "assignments $missing_assignments not found in $filename, " *
276+
"but nested trixi_include calls detected. They may be used in nested files."
277+
else
278+
throw(ArgumentError("assignments $missing_assignments not found in $filename"))
279+
end
280+
end
281+
end
282+
180283
# Find a (keyword or common) assignment to `destination` in `expr`
181284
# and return the assigned value.
182285
function find_assignment(expr, destination)

test/trixi_include.jl

Lines changed: 154 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@
2222
@trixi_test_nowarn trixi_include(path, x = 11)
2323
@test Main.x == 11
2424

25-
@test_throws "assignment `y` not found in expression" trixi_include(@__MODULE__,
26-
path,
27-
y = 3)
25+
@test_throws "assignments [:y] not found" trixi_include(@__MODULE__,
26+
path, y = 3)
2827
end
2928
end
3029

@@ -115,6 +114,158 @@
115114
end
116115
end
117116
end
117+
118+
@trixi_testset "Recursive assignment overwriting" begin
119+
# Test basic recursive kwargs passing
120+
example1 = """
121+
x = 1
122+
y = 2
123+
"""
124+
125+
example2 = """
126+
z = 3
127+
trixi_include(@__MODULE__, nested_path)
128+
"""
129+
130+
mktemp() do path1, io1
131+
write(io1, example1)
132+
close(io1)
133+
134+
mktemp() do path2, io2
135+
# Use raw string to allow backslashes in Windows paths
136+
nested_code = replace(example2, "nested_path" => "raw\"$path1\"")
137+
write(io2, nested_code)
138+
close(io2)
139+
140+
# Test that kwargs are passed recursively
141+
# Should warn about x,y not being in top file but allow due to nested calls
142+
@test_warn "assignments" trixi_include(@__MODULE__, path2;
143+
x = 10, y = 20, z = 30,
144+
replace_assignments_recursive = true)
145+
@test @isdefined x
146+
@test @isdefined y
147+
@test @isdefined z
148+
@test x == 10 # Overridden from nested file
149+
@test y == 20 # Overridden from nested file
150+
@test z == 30 # Overridden from top file
151+
152+
# Test that kwargs are NOT passed recursively
153+
@trixi_test_nowarn trixi_include(@__MODULE__, path2;
154+
x = 10, y = 20, z = 30,
155+
replace_assignments_recursive = false,
156+
enable_assignment_validation = false)
157+
158+
@test x == 1 # Not overridden from nested file
159+
@test y == 2 # Not overridden from nested file
160+
@test z == 30 # Overridden from top file
161+
162+
# Without disabling validation, this should result in an error:
163+
@test_throws "assignments [:x, :y] not found" trixi_include(@__MODULE__,
164+
path2; x = 10,
165+
y = 20, z = 30)
166+
end
167+
end
168+
169+
# Test with existing kwargs in nested calls
170+
example3 = """
171+
a = 100
172+
trixi_include(@__MODULE__, nested_path; a = 200)
173+
"""
174+
175+
example4 = """
176+
a = 1
177+
b = 2
178+
"""
179+
180+
mktemp() do path3, io3
181+
write(io3, example4)
182+
close(io3)
183+
184+
mktemp() do path4, io4
185+
nested_code = replace(example3, "nested_path" => "raw\"$path3\"")
186+
write(io4, nested_code)
187+
close(io4)
188+
189+
# Test that top-level kwargs override existing nested kwargs
190+
trixi_include(@__MODULE__, path4; a = 500, b = 600,
191+
replace_assignments_recursive = true)
192+
@test @isdefined a
193+
@test @isdefined b
194+
@test a == 500 # Top-level override wins over nested explicit kwarg
195+
@test b == 600 # Passed through to nested file
196+
end
197+
end
198+
199+
# Test bare symbol syntax with recursion
200+
example5 = """
201+
x = 42
202+
trixi_include(@__MODULE__, nested_path; x)
203+
"""
204+
205+
example6 = """
206+
x = 1
207+
"""
208+
209+
mktemp() do path5, io5
210+
write(io5, example6)
211+
close(io5)
212+
213+
mktemp() do path6, io6
214+
nested_code = replace(example5, "nested_path" => "raw\"$path5\"")
215+
write(io6, nested_code)
216+
close(io6)
217+
218+
# Test bare symbol with recursive override
219+
@trixi_test_nowarn trixi_include(@__MODULE__, path6; x = 999,
220+
replace_assignments_recursive = true)
221+
@test @isdefined x
222+
@test x == 999 # Top-level override
223+
end
224+
end
225+
226+
# Test deep nesting (3 levels)
227+
example7 = """
228+
level1 = 1
229+
"""
230+
231+
example8 = """
232+
level2 = 2
233+
trixi_include(@__MODULE__, level1_path)
234+
"""
235+
236+
example9 = """
237+
level3 = 3
238+
trixi_include(@__MODULE__, level2_path; level2 = 22)
239+
"""
240+
241+
mktemp() do path7, io7
242+
write(io7, example7)
243+
close(io7)
244+
245+
mktemp() do path8, io8
246+
level2_code = replace(example8, "level1_path" => "raw\"$path7\"")
247+
write(io8, level2_code)
248+
close(io8)
249+
250+
mktemp() do path9, io9
251+
level3_code = replace(example9, "level2_path" => "raw\"$path8\"")
252+
write(io9, level3_code)
253+
close(io9)
254+
255+
# Test 3-level deep recursive override
256+
trixi_include(@__MODULE__, path9; level1 = 111,
257+
level2 = 222, level3 = 333,
258+
replace_assignments_recursive = true)
259+
@test @isdefined level1
260+
@test @isdefined level2
261+
@test @isdefined level3
262+
@test level1 == 111 # Passed through 3 levels
263+
@test level2 == 222 # Top-level override wins over level3 explicit kwarg
264+
@test level3 == 333 # Direct override
265+
end
266+
end
267+
end
268+
end
118269
end
119270

120271
@trixi_testset "`trixi_include_changeprecision`" begin

0 commit comments

Comments
 (0)