Skip to content

Commit 0257d60

Browse files
feat(macros): add macros.jl file
1 parent 2c8d381 commit 0257d60

1 file changed

Lines changed: 254 additions & 0 deletions

File tree

src/macros.jl

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
export @parameterize, @measurify
2+
3+
"""
4+
@parameterize(container_expr, union_expr)
5+
6+
Takes a parameterized container expression with a `_` placeholder
7+
(e.g., `Array{_, 3}` or `Vector{_}`) and a Union type (e.g., `REALTYPES`),
8+
and expands it into a Union of concrete container types at compile time.
9+
10+
# Example
11+
```julia
12+
`@parameterize Vector{_} REALTYPES`
13+
...expands to: `Union{Vector{Float64}, Vector{Measurement{Float64}}}`
14+
15+
`@parameterize Array{_, 3} COMPLEXTYPES`
16+
...expands to: `Union{Array{Complex{Float64}, 3}, Array{Complex{Measurement{Float64}}, 3}}`
17+
```
18+
"""
19+
macro parameterize(container_expr, union_expr)
20+
# Evaluate the Union type from the provided expression
21+
local union_type
22+
try
23+
# Core.eval gets the *value* of the symbol passed in (e.g., the actual Union type)
24+
union_type = Core.eval(__module__, union_expr)
25+
catch e
26+
error("Expression `$union_expr` could not be evaluated. Make sure it's a defined const or type.")
27+
end
28+
29+
# Sanity check
30+
if !(union_type isa Union)
31+
error("Second argument must be a Union type. Got a `$(typeof(union_type))` instead.")
32+
end
33+
34+
# Base.uniontypes gets the component types, e.g., (Float64, Measurement{Float64})
35+
component_types = Base.uniontypes(union_type)
36+
37+
# Define a recursive function to substitute the placeholder `_`
38+
function substitute_placeholder(expr, replacement_type)
39+
# If the current part of the expression is the placeholder symbol,
40+
# we replace it with the target type (e.g., Float64).
41+
if expr == :_
42+
return replacement_type
43+
# If the current part is another expression (like `Array{_,3}`),
44+
# we need to recurse into its arguments to find the placeholder.
45+
elseif expr isa Expr
46+
# Rebuild the expression with the substituted arguments.
47+
new_args = [substitute_placeholder(arg, replacement_type) for arg in expr.args]
48+
return Expr(expr.head, new_args...)
49+
# Otherwise, it's a literal or symbol we don't need to change (e.g., `:Array` or `3`).
50+
else
51+
return expr
52+
end
53+
end
54+
55+
# Build the list of new, concrete types
56+
# For each type in the original Union, create a new container expression
57+
# by calling our substitution function.
58+
parameterized_types = [substitute_placeholder(container_expr, t) for t in component_types]
59+
60+
# Wrap the new types in a single `Union{...}` expression and escape
61+
final_expr = Expr(:curly, :Union, parameterized_types...)
62+
return esc(final_expr)
63+
end
64+
65+
"""
66+
@measurify(function_definition)
67+
68+
Wraps a function definition. If any argument tied to a parametric type `T` is a
69+
`Measurement`, this macro automatically promotes any other arguments of the same
70+
parametric type `T` to `Measurement` with zero uncertainty. Other arguments
71+
(e.g., `i::Int`) are ignored.
72+
"""
73+
macro measurify(def)
74+
# Normalize to long form
75+
if def.head == :(=)
76+
call = def.args[1]
77+
body = def.args[2]
78+
def = Expr(:function, call, body)
79+
elseif def.head == :function
80+
# ok
81+
else
82+
error("@measurify must wrap a function definition")
83+
end
84+
85+
sig = def.args[1]
86+
body = def.args[2]
87+
88+
# Extract call expr and where-clauses
89+
call_expr = sig
90+
where_items = Any[]
91+
if sig isa Expr && sig.head == :where
92+
call_expr = sig.args[1]
93+
where_items = sig.args[2:end]
94+
end
95+
fname = call_expr.args[1]
96+
97+
# Bounds for each where typevar: Dict{Symbol,Any}
98+
bounds = Dict{Symbol,Any}()
99+
for w in where_items
100+
if w isa Symbol
101+
bounds[w] = :Any
102+
elseif w isa Expr && w.head == :(<:)
103+
tv = w.args[1]::Symbol
104+
ub = w.args[2]
105+
bounds[tv] = ub
106+
else
107+
error("@measurify: unsupported where item: $w")
108+
end
109+
end
110+
typevars = collect(keys(bounds))
111+
112+
# Split positional vs keyword args in the signature
113+
posargs = Any[]
114+
kwexpr = nothing
115+
for a in call_expr.args[2:end]
116+
if a isa Expr && a.head == :parameters
117+
kwexpr = a
118+
else
119+
push!(posargs, a)
120+
end
121+
end
122+
123+
# Collect names and find which are promotable (annotated exactly as one of the where typevars)
124+
names = Symbol[]
125+
promotable = Symbol[]
126+
wrapper_posargs = Any[]
127+
128+
for a in posargs
129+
if a isa Symbol
130+
push!(names, a)
131+
push!(wrapper_posargs, a) # untyped positional
132+
elseif a isa Expr && a.head == :(::)
133+
nm = a.args[1]::Symbol
134+
ty = a.args[2]
135+
push!(names, nm)
136+
if ty isa Symbol && haskey(bounds, ty)
137+
# replace ::T with ::Bound(T)
138+
push!(promotable, nm)
139+
push!(wrapper_posargs, Expr(:(::), nm, bounds[ty]))
140+
else
141+
push!(wrapper_posargs, a)
142+
end
143+
else
144+
error("@measurify: unsupported arg form: $a")
145+
end
146+
end
147+
148+
# Build the tight/original method exactly as written
149+
tight = def
150+
151+
# Build the loose wrapper signature: same name, same args,
152+
# but with ::T replaced by ::Bound(T), and NO where-clauses.
153+
wrapper_call = Expr(:call, fname, wrapper_posargs...)
154+
if kwexpr !== nothing
155+
push!(wrapper_call.args, kwexpr) # keep kw defaults/types as-is
156+
end
157+
158+
# Build the call to the tight method (same arg names; keywords forwarded as k=k)
159+
forward_call = Expr(:call, fname, (:($n) for n in names)...)
160+
if kwexpr !== nothing
161+
# transform each kw def into k=k for forwarding
162+
pairs = Any[]
163+
for e in kwexpr.args
164+
kn = e isa Expr && e.head == :(=) ? e.args[1] :
165+
e isa Expr && e.head == :(::) ? e.args[1] :
166+
e isa Symbol ? e : error("@measurify: bad kw: $e")
167+
push!(pairs, Expr(:(=), kn, kn))
168+
end
169+
push!(forward_call.args, Expr(:parameters, pairs...))
170+
end
171+
172+
# If nothing is promotable, wrapper just forwards (harmless)
173+
promote_tuple = Expr(:tuple, (:($n) for n in promotable)...)
174+
175+
# make a fresh name for the promoted tuple
176+
pp = gensym(:promoted)
177+
178+
# rebinding statements: NO `local`
179+
rebinding = Any[]
180+
for (i, nm) in enumerate(promotable)
181+
push!(rebinding, :($(nm) = $(pp)[$i]))
182+
end
183+
184+
wrapper_body = quote
185+
$(length(promotable) == 0 ? :(nothing) : quote
186+
$(pp) = promote($(promote_tuple.args...))
187+
$(rebinding...)
188+
end)
189+
$(forward_call)
190+
end
191+
192+
loose = Expr(:function, wrapper_call, wrapper_body)
193+
194+
# @info "measurify input" def
195+
# ... build `tight`, `loose` ...
196+
out = Expr(:block, :(Base.@__doc__ $tight), loose)
197+
# @info "measurify output" out
198+
return esc(out)
199+
200+
end
201+
202+
"""
203+
$(TYPEDSIGNATURES)
204+
205+
Automatically exports public functions, types, and modules from a module. This is meant for temporary development chores and should never be used in production code.
206+
207+
# Arguments
208+
209+
- None.
210+
211+
# Returns
212+
213+
- An `export` expression containing all public symbols that should be exported.
214+
215+
# Notes
216+
217+
This macro scans the current module for all defined symbols and automatically generates an `export` statement for public functions, types, and submodules, excluding built-in and private names. Private names are considered those starting with an underscore ('_'), as per standard Julia conventions.
218+
219+
# Examples
220+
221+
```julia
222+
@autoexport
223+
```
224+
"""
225+
macro autoexport()
226+
mod = __module__
227+
228+
# Get all names defined in the module, including unexported ones
229+
all_names = names(mod; all=true)
230+
231+
# List of names to explicitly exclude
232+
excluded_names = Set([:eval, :include, :using, :import, :export, :require])
233+
234+
# Filter out private names (starting with '_'), module name, built-in functions, and auto-generated method symbols
235+
public_names = Symbol[]
236+
for name in all_names
237+
str_name = string(name)
238+
239+
startswith(str_name, "@_") && continue # Skip private macros
240+
startswith(str_name, "_") && continue # Skip private names
241+
name === nameof(mod) && continue # Skip the module's own name
242+
name in excluded_names && continue # Skip built-in functions
243+
startswith(str_name, "#") && continue # Skip generated method symbols (e.g., #eval, #include)
244+
245+
if isdefined(mod, name)
246+
val = getfield(mod, name)
247+
if val isa Function || val isa Type || val isa Module
248+
push!(public_names, name)
249+
end
250+
end
251+
end
252+
253+
return esc(Expr(:export, public_names...))
254+
end

0 commit comments

Comments
 (0)