Skip to content

Commit aee7b9f

Browse files
Copilotlukem12345
andauthored
Replace stale GeneralizedGenerated dependency with RuntimeGeneratedFunctions (#1002)
Co-authored-by: lukem12345 <70283489+lukem12345@users.noreply.github.com>
1 parent 75e2dc0 commit aee7b9f

5 files changed

Lines changed: 57 additions & 15 deletions

File tree

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ CompTime = "0fb5dd42-039a-4ca4-a1d7-89a96eae6d39"
1212
Compose = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
1313
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1414
GATlab = "f0ffcf3b-d13a-433e-917c-cc44ccf5ead2"
15-
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
15+
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
1616
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
1717
LightXML = "9c8b4983-aa76-5018-a973-4c85ecc9e179"
1818
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -57,7 +57,7 @@ Convex = "0.16"
5757
DataFrames = "1"
5858
DataStructures = "0.17, 0.18, 0.19"
5959
GATlab = "0.2.2"
60-
GeneralizedGenerated = "0.2, 0.3"
60+
RuntimeGeneratedFunctions = "0.5"
6161
Graphs = "1"
6262
Graphviz_jll = "2"
6363
JSON3 = "1"

src/programs/GenerateJuliaPrograms.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ module GenerateJuliaPrograms
44
export Block, CompileState, compile, compile_expr, compile_block,
55
evaluate, evaluate_hom
66

7-
using GeneralizedGenerated: mk_function
7+
using RuntimeGeneratedFunctions
8+
RuntimeGeneratedFunctions.init(@__MODULE__)
89

910
using ...Catlab
1011
using GATlab
@@ -33,11 +34,20 @@ abstract type CompileState end
3334
end
3435

3536
""" Compile a morphism expression into a Julia function.
37+
38+
The optional `mod` parameter specifies a context module used to resolve
39+
symbols in the generated function. Pass the module where the referenced
40+
functions are defined so that unqualified names (e.g. `:my_func`) are looked
41+
up there. The module must have been initialised with
42+
`RuntimeGeneratedFunctions.init(mod)` beforehand.
43+
44+
To bind specific generators to arbitrary function values (including closures),
45+
use the `generators` keyword argument instead.
3646
"""
3747
function compile(mod::Module, f::HomExpr; kw...)
38-
mk_function(mod, compile_expr(f; kw...))
48+
@RuntimeGeneratedFunction(mod, compile_expr(f; kw...))
3949
end
40-
compile(f::HomExpr; kw...) = compile(Main, f; kw...)
50+
compile(f::HomExpr; kw...) = compile(GenerateJuliaPrograms, f; kw...)
4151

4252
""" Compile a morphism expression into a Julia function expression.
4353
"""

src/programs/ParseJuliaPrograms.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
module ParseJuliaPrograms
44
export @program, parse_wiring_diagram
55

6-
using GeneralizedGenerated: mk_function
7-
using MLStyle: @match
6+
using RuntimeGeneratedFunctions
7+
RuntimeGeneratedFunctions.init(@__MODULE__)
8+
using MLStyle: @match, GuardBy
89

910
using GATlab
1011
import GATlab.Util.MetaUtils: Expr0
@@ -77,10 +78,10 @@ function parse_wiring_diagram(pres::Presentation, call::Expr0, body::Expr)::Wiri
7778

7879
# Compile...
7980
args = Symbol[ first(arg) for arg in parsed_args ]
80-
kwargs = make_lookup_table(pres, syntax_module, unique_symbols(body))
81+
lookup_dict = make_lookup_table(pres, syntax_module, unique_symbols(body))
8182
func_expr = compile_recording_expr(body, args,
82-
kwargs = sort!(collect(keys(kwargs))))
83-
func = mk_function(parentmodule(syntax_module), func_expr)
83+
kwargs = sort!(collect(keys(lookup_dict))))
84+
func = @RuntimeGeneratedFunction(func_expr)
8485

8586
# ...and then evaluate function that records the function calls.
8687
arg_obs = syntax_module.Ob[ last(arg) for arg in parsed_args ]
@@ -91,7 +92,7 @@ function parse_wiring_diagram(pres::Presentation, call::Expr0, body::Expr)::Wiri
9192
arg_ports = [ Tuple(Port(v_in, OutputPort, i) for i in (stop-len+1):stop)
9293
for (len, stop) in zip(arg_blocks, cumsum(arg_blocks)) ]
9394
recorder = f -> (args...) -> record_call!(diagram, f, args...)
94-
value = func(recorder, arg_ports...; kwargs...)
95+
value = func(recorder, lookup_dict, arg_ports...)
9596

9697
# Add outgoing wires for return values.
9798
out_ports = normalize_arguments((value,))
@@ -111,13 +112,16 @@ end
111112
function make_lookup_table(pres::Presentation, syntax_module::Module, names)
112113
theory = syntax_module.Meta.theory
113114
terms = Set(nameof.(keys(theory.resolvers)))
115+
context_mod = parentmodule(syntax_module)
114116

115117
table = Dict{Symbol,Any}()
116118
for name in names
117119
if has_generator(pres, name)
118120
table[name] = generator(pres, name)
119121
elseif name in terms
120122
table[name] = (args...) -> invoke_term(syntax_module, name, args)
123+
elseif isdefined(context_mod, name)
124+
table[name] = getfield(context_mod, name)
121125
end
122126
end
123127
table
@@ -148,9 +152,13 @@ Rewrites the function body so that:
148152
"""
149153
function compile_recording_expr(body::Expr, args::Vector{Symbol};
150154
kwargs::Vector{Symbol}=Symbol[],
151-
recorder::Symbol=Symbol("##recorder"))::Expr
155+
recorder::Symbol=Symbol("##recorder"),
156+
lookup::Symbol=Symbol("##lookup"))::Expr
157+
lookup_keys_set = Set(kwargs)
152158
function rewrite(expr)
153159
@match expr begin
160+
f::Symbol && GuardBy(in(lookup_keys_set)) =>
161+
:($(lookup)[$(QuoteNode(f))])
154162
Expr(:call, f, args...) =>
155163
Expr(:call, Expr(:call, recorder, rewrite(f)), map(rewrite, args)...)
156164
Expr(:curly, f, args...) =>
@@ -160,9 +168,7 @@ function compile_recording_expr(body::Expr, args::Vector{Symbol};
160168
end
161169
end
162170
Expr(:function,
163-
Expr(:tuple,
164-
Expr(:parameters, (Expr(:kw, kw, nothing) for kw in kwargs)...),
165-
recorder, args...),
171+
Expr(:tuple, recorder, lookup, args...),
166172
rewrite(body))
167173
end
168174

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
1616
OpenSSL_jll = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
1717
PEG = "12d937ae-5f68-53be-93c9-3a6f997a20a8"
1818
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
19+
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
1920
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2021
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2122
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

test/programs/GenerateJuliaPrograms.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@ using Test
33

44
using Catlab.Theories
55
using Catlab.Programs.GenerateJuliaPrograms
6+
using RuntimeGeneratedFunctions
7+
8+
# Context module for testing compile with non-Base functions referenced by name.
9+
module TestContext
10+
using RuntimeGeneratedFunctions
11+
RuntimeGeneratedFunctions.init(@__MODULE__)
12+
my_square(x) = x^2
13+
my_cube(x) = x^3
14+
end
615

716
= Ob(FreeCartesianCategory, :ℝ)
817
plus_hom = Hom(:+, ℝℝ, ℝ)
@@ -30,6 +39,22 @@ x = collect(range(-2,stop=2,length=50))
3039
local_f(x) = x + 1
3140
@test compile(f_hom, generators=Dict(:f => local_f)).(x) == [xi+1 for xi in x]
3241

42+
# Functions not defined in Base (module-level functions in this test module).
43+
square(x) = x^2
44+
cube(x) = x^3
45+
@test compile(f_hom, generators=Dict(:f => square)).(x) == x.^2
46+
@test compile(compose(f_hom, g_hom),
47+
generators=Dict(:f => square, :g => cube)).(x) == (x.^2).^3
48+
@test compile(otimes(f_hom, g_hom),
49+
generators=Dict(:f => square, :g => cube))(2.0, 3.0) == (4.0, 27.0)
50+
51+
# Context module: functions referenced by name are resolved from the given module.
52+
my_square_hom = Hom(:my_square, ℝ, ℝ)
53+
my_cube_hom = Hom(:my_cube, ℝ, ℝ)
54+
@test compile(TestContext, my_square_hom).(x) == x.^2
55+
@test compile(TestContext, compose(my_square_hom, my_cube_hom)).(x) == (x.^2).^3
56+
@test compile(TestContext, otimes(my_square_hom, my_cube_hom))(2.0, 3.0) == (4.0, 27.0)
57+
3358
# Evaluation
3459
############
3560

0 commit comments

Comments
 (0)