33module ParseJuliaPrograms
44export @program , parse_wiring_diagram
55
6- using GeneralizedGenerated: mk_function
7- using MLStyle: @match
6+ using RuntimeGeneratedFunctions
7+ RuntimeGeneratedFunctions. init (@__MODULE__ )
8+ using MLStyle: @match , GuardBy
89
910using GATlab
1011import GATlab. Util. MetaUtils: Expr0
@@ -77,10 +78,10 @@ function parse_wiring_diagram(pres::Presentation, call::Expr0, body::Expr)::Wiri
7778
7879 # Compile...
7980 args = Symbol[ first (arg) for arg in parsed_args ]
80- kwargs = make_lookup_table (pres, syntax_module, unique_symbols (body))
81+ lookup_dict = make_lookup_table (pres, syntax_module, unique_symbols (body))
8182 func_expr = compile_recording_expr (body, args,
82- kwargs = sort! (collect (keys (kwargs ))))
83- func = mk_function ( parentmodule (syntax_module), func_expr)
83+ kwargs = sort! (collect (keys (lookup_dict ))))
84+ func = @RuntimeGeneratedFunction ( func_expr)
8485
8586 # ...and then evaluate function that records the function calls.
8687 arg_obs = syntax_module. Ob[ last (arg) for arg in parsed_args ]
@@ -91,7 +92,7 @@ function parse_wiring_diagram(pres::Presentation, call::Expr0, body::Expr)::Wiri
9192 arg_ports = [ Tuple (Port (v_in, OutputPort, i) for i in (stop- len+ 1 ): stop)
9293 for (len, stop) in zip (arg_blocks, cumsum (arg_blocks)) ]
9394 recorder = f -> (args... ) -> record_call! (diagram, f, args... )
94- value = func (recorder, arg_ports ... ; kwargs ... )
95+ value = func (recorder, lookup_dict, arg_ports ... )
9596
9697 # Add outgoing wires for return values.
9798 out_ports = normalize_arguments ((value,))
@@ -111,13 +112,16 @@ end
111112function make_lookup_table (pres:: Presentation , syntax_module:: Module , names)
112113 theory = syntax_module. Meta. theory
113114 terms = Set (nameof .(keys (theory. resolvers)))
115+ context_mod = parentmodule (syntax_module)
114116
115117 table = Dict {Symbol,Any} ()
116118 for name in names
117119 if has_generator (pres, name)
118120 table[name] = generator (pres, name)
119121 elseif name in terms
120122 table[name] = (args... ) -> invoke_term (syntax_module, name, args)
123+ elseif isdefined (context_mod, name)
124+ table[name] = getfield (context_mod, name)
121125 end
122126 end
123127 table
@@ -148,9 +152,13 @@ Rewrites the function body so that:
148152"""
149153function compile_recording_expr (body:: Expr , args:: Vector{Symbol} ;
150154 kwargs:: Vector{Symbol} = Symbol[],
151- recorder:: Symbol = Symbol (" ##recorder" )):: Expr
155+ recorder:: Symbol = Symbol (" ##recorder" ),
156+ lookup:: Symbol = Symbol (" ##lookup" )):: Expr
157+ lookup_keys_set = Set (kwargs)
152158 function rewrite (expr)
153159 @match expr begin
160+ f:: Symbol && GuardBy (in (lookup_keys_set)) =>
161+ :($ (lookup)[$ (QuoteNode (f))])
154162 Expr (:call , f, args... ) =>
155163 Expr (:call , Expr (:call , recorder, rewrite (f)), map (rewrite, args)... )
156164 Expr (:curly , f, args... ) =>
@@ -160,9 +168,7 @@ function compile_recording_expr(body::Expr, args::Vector{Symbol};
160168 end
161169 end
162170 Expr (:function ,
163- Expr (:tuple ,
164- Expr (:parameters , (Expr (:kw , kw, nothing ) for kw in kwargs). .. ),
165- recorder, args... ),
171+ Expr (:tuple , recorder, lookup, args... ),
166172 rewrite (body))
167173end
168174
0 commit comments