Skip to content

Commit 8d7e811

Browse files
kris-brownKris Brown
andauthored
change wrapper data structure (#199)
remove Performance file docs Allow some basic dependent sliced types cache method types in wrapper remove dep types cosmetic cosmetic fix Co-authored-by: Kris Brown <kris@topos.institute>
1 parent 92b59cd commit 8d7e811

2 files changed

Lines changed: 60 additions & 26 deletions

File tree

src/models/ModelInterface.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ macro instance(head, model, body)
206206
# A dictionary to look up the Julia type of a type constructor from its name (an ident)
207207
jltype_by_sort = Dict{AlgSort,Expr0}([
208208
zip(primitive_sorts(theory), instance_types)...,
209-
[s => nameof(headof(s)) for s in struct_sorts(theory)]...,
209+
[s => nameof(headof(s)) for s in struct_sorts(theory)]...,
210210
collect(theory.fixed_types)...
211211
])
212212

@@ -219,6 +219,7 @@ macro instance(head, model, body)
219219
generate_instance(theory, theory_module, jltype_by_sort, model_type, whereparams, body)
220220
end
221221

222+
222223
function generate_instance(
223224
theory::GAT,
224225
theory_module::Union{Expr0, Module},
@@ -228,8 +229,6 @@ function generate_instance(
228229
body::Expr;
229230
escape=true
230231
)
231-
# The old (Catlab) style of instance, where there is no explicit model
232-
oldinstance = isnothing(model_type)
233232

234233
# Parse the body into functions defined here and functions defined elsewhere
235234
typechecked_functions = parse_instance_body(body, theory)
@@ -514,7 +513,7 @@ function impl_type_declaration(model_type, whereparams, sort, jltype)
514513
quote
515514
if !hasmethod($(GlobalRef(ModelInterface, :impl_type)),
516515
($(model_type) where {$(whereparams...)}, Type{Val{$(gettag(methodof(sort)))}}, Type{Val{$(getlid(methodof(sort)))}}))
517-
$(GlobalRef(ModelInterface, :impl_type))(
516+
@inline $(GlobalRef(ModelInterface, :impl_type))(
518517
::$(model_type), ::Type{Val{$(gettag(methodof(sort)))}}, ::Type{Val{$(getlid(methodof(sort)))}}
519518
) where {$(whereparams...)} = $(jltype)
520519
end

src/syntax/TheoryInterface.jl

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,9 @@ function wrapper(name::Symbol, t::GAT, mod)
321321
:($(GlobalRef($(TheoryInterface), :impl_type))(x, $s))
322322
end
323323

324+
Xdict = :(Dict(zip(nameof.($Ts), [$(Xs...)])))
324325
gv = :($(GlobalRef($(Scopes), :getvalue)))
325326
it = :($(GlobalRef($(TheoryInterface), :impl_type)))
326-
327327
esc(quote
328328
# Catch any potential docs above the macro call
329329
const $(doctarget) = nothing
@@ -332,10 +332,16 @@ function wrapper(name::Symbol, t::GAT, mod)
332332
# Declare the wrapper struct
333333
struct $n <: $abs
334334
val::Any
335+
types::Dict{Symbol, Type}
335336
function $n(x::Any)
336-
try $($(GlobalRef(TheoryInterface, :implements)))(x, $($name), [$(Xs...)])
337-
catch MethodError false end || error("Invalid $($($(name))) model: $x")
338-
new(x)
337+
# TODO opt into checking whether the methods are defined
338+
# right now we just implicitly check whether the types are defined
339+
types = try
340+
$Xdict
341+
catch MethodError
342+
error("Invalid $($($(name))) model: $x")
343+
end
344+
new(x, types)
339345
end
340346
end
341347
# Apply the caught documentation to the new struct
@@ -345,33 +351,62 @@ function wrapper(name::Symbol, t::GAT, mod)
345351
$(Expr(:macrocall, $(GlobalRef(StructEquality, Symbol("@struct_hash_equal"))), $(mod), $(:n)))
346352

347353
$gv(x::$n) = x.val
348-
$it(x::$n, o::Symbol) = $it(x.val, $($name), o)
354+
$it(x::$n, o::Symbol) = x.types[o]
355+
Base.getindex(x::$n, k::Symbol) = x.methods[k]
349356

350357
# Dispatch on model value for all declarations in theory
351358
$(map(filter(x->x[2] isa $AlgDeclaration, $(identvalues(t)))) do (x,j)
352359
if j isa $(AlgDeclaration)
353360
op = nameof(x)
354-
:($($(name)).$op(x::$(($(:n))), args...; kw...) =
355-
$($(name)).$op[x.val](args...; kw...))
361+
:(@inline function $($(name)).$op(x::$(($(:n))), args...; kw...)
362+
$($(name)).$op(WithModel(x.val), args...; kw...)
363+
end)
356364
end
357365
end...)
366+
358367
nothing
359368
end)
360369
end
361370

362371
macro typed_wrapper(n, abs)
363372
doctarget = gensym()
364373
Ts = nameof.($(sorts)($t))
374+
Tnames = QuoteNode.(Ts)
365375
Xs = map(Ts) do s
366376
:($(GlobalRef($(TheoryInterface), :impl_type))(x, $($(name)), $($(Meta.quot)(s))))
367377
end
368378
XTs = map(zip(Ts,Xs)) do (T,X)
369379
:($X <: $T || error("Mismatch $($($(Meta.quot)(T))): $($X)$($T)"))
370380
end
371381

382+
Xdict = :(Dict(zip($Ts, [$(Xs...)])))
383+
372384
gv = :($(GlobalRef($(Scopes), :getvalue)))
373385
it = :($(GlobalRef($(TheoryInterface), :impl_type)))
374386

387+
# As an experiment, we put the correct types of the arguments explicitly
388+
# though this doesn't ultimately affect whether there is dynamic dispatch
389+
ms = vcat(map(collect(values($t.resolvers))) do res
390+
map(collect(pairs(res.bysignature))) do (sig, meth′)
391+
meth = $t[meth′].value
392+
is_acc = nameof(typeof(meth)) == :AlgAccessor
393+
is_typ = nameof(typeof(meth)) == :AlgTypeConstructor
394+
op = nameof(meth.declaration)
395+
args = if is_acc
396+
[:($(gensym(:val))::$(nameof(meth.typecondecl)))]
397+
else
398+
vcat(is_typ ? [:($(gensym(:val))::Any)] : Expr[], map(meth[meth.args]) do argbind
399+
:($(gensym(nameof(argbind)))::$(nameof(argbind.value.body.head)))
400+
end)
401+
end
402+
argnames = [first(a.args) for a in args]
403+
:(@inline function $($(name)).$op(x::$(($(:n))){$(Ts...)}, $(args...); kw...) where {$(Ts...)}
404+
$($(name)).$op(WithModel(x.val), $(argnames...); kw...)
405+
end)
406+
end
407+
end...)
408+
409+
375410
esc(quote
376411
# Catch any potential docs above the macro call
377412
const $(doctarget) = nothing
@@ -380,16 +415,24 @@ function wrapper(name::Symbol, t::GAT, mod)
380415
# Declare the wrapper struct
381416
struct $n{$(Ts...)} <: $abs
382417
val::Any
418+
types::Dict{Symbol, Type}
419+
383420
function $n{$(Ts...)}(x::Any) where {$(Ts...)}
384-
$($(GlobalRef(TheoryInterface, :implements)))(x, $($name), [$(Xs...)]) || error("Invalid $($($(name))) model: $x")
385-
$(XTs...)
386-
# TODO? CHECK THAT THE GIVEN PARAMETERS MATCH Xs?
387-
new{$(Ts...)}(x)
421+
types = try
422+
$Xdict
423+
catch MethodError
424+
error("Invalid $($($(name))) model: $x")
425+
end
426+
all(zip([$(Ts...)], [$(Tnames...)])) do (T1,k)
427+
types[k] <: T1 || error("Bad type for $k: $(types[k])$T1 ")
428+
end
429+
new{$(Ts...)}(x, types)
388430
end
389431

390432
function $n(x::Any)
391433
$($(GlobalRef(TheoryInterface, :implements)))(x, $($name), [$(Xs...)]) || error("Invalid $($($(name))) model: $x")
392-
new{$(Xs...)}(x)
434+
types = $Xdict
435+
new{$(Xs...)}(x, types)
393436
end
394437
end
395438
# Apply the caught documentation to the new struct
@@ -399,16 +442,9 @@ function wrapper(name::Symbol, t::GAT, mod)
399442
$(Expr(:macrocall, $(GlobalRef(StructEquality, Symbol("@struct_hash_equal"))), $(mod), $(:n)))
400443

401444
$gv(x::$n) = x.val
402-
$it(x::$n, o::Symbol) = $it(x.val, $($name), o)
445+
$it(x::$n, o::Symbol) = x.types[o]
403446

404-
# Dispatch on model value for all declarations in theory
405-
$(map(filter(x->x[2] isa $AlgDeclaration, $(identvalues(t)))) do (x,j)
406-
if j isa $(AlgDeclaration)
407-
op = nameof(x)
408-
:($($(name)).$op(x::$(($(:n))), args...; kw...) =
409-
$($(name)).$op[x.val](args...; kw...))
410-
end
411-
end...)
447+
$(ms...)
412448
nothing
413449
end)
414450
end
@@ -418,5 +454,4 @@ end
418454
parse_wrapper_input(n::Symbol) = n, Any
419455
parse_wrapper_input(n::Expr) = n.head == :<: ? n.args : error("Bad input for wrapper")
420456

421-
422457
end # module

0 commit comments

Comments
 (0)