@@ -321,9 +321,9 @@ function wrapper(name::Symbol, t::GAT, mod)
321321 :($ (GlobalRef ($ (TheoryInterface), :impl_type ))(x, $ s))
322322 end
323323
324+ Xdict = :(Dict (zip (nameof .($ Ts), [$ (Xs... )])))
324325 gv = :($ (GlobalRef ($ (Scopes), :getvalue )))
325326 it = :($ (GlobalRef ($ (TheoryInterface), :impl_type )))
326-
327327 esc (quote
328328 # Catch any potential docs above the macro call
329329 const $ (doctarget) = nothing
@@ -332,10 +332,16 @@ function wrapper(name::Symbol, t::GAT, mod)
332332 # Declare the wrapper struct
333333 struct $ n <: $abs
334334 val:: Any
335+ types:: Dict{Symbol, Type}
335336 function $n (x:: Any )
336- try $ ($ (GlobalRef (TheoryInterface, :implements )))(x, $ ($ name), [$ (Xs... )])
337- catch MethodError false end || error (" Invalid $($ ($ (name))) model: $x " )
338- new (x)
337+ # TODO opt into checking whether the methods are defined
338+ # right now we just implicitly check whether the types are defined
339+ types = try
340+ $ Xdict
341+ catch MethodError
342+ error (" Invalid $($ ($ (name))) model: $x " )
343+ end
344+ new (x, types)
339345 end
340346 end
341347 # Apply the caught documentation to the new struct
@@ -345,33 +351,62 @@ function wrapper(name::Symbol, t::GAT, mod)
345351 $ (Expr (:macrocall , $ (GlobalRef (StructEquality, Symbol (" @struct_hash_equal" ))), $ (mod), $ (:n )))
346352
347353 $ gv (x:: $n ) = x. val
348- $ it (x:: $n , o:: Symbol ) = $ it (x. val, $ ($ name), o)
354+ $ it (x:: $n , o:: Symbol ) = x. types[o]
355+ Base. getindex (x:: $n , k:: Symbol ) = x. methods[k]
349356
350357 # Dispatch on model value for all declarations in theory
351358 $ (map (filter (x-> x[2 ] isa $ AlgDeclaration, $ (identvalues (t)))) do (x,j)
352359 if j isa $ (AlgDeclaration)
353360 op = nameof (x)
354- :($ ($ (name)). $ op (x:: $ (($ (:n ))), args... ; kw... ) =
355- $ ($ (name)). $ op[x. val](args... ; kw... ))
361+ :(@inline function $ ($ (name)). $ op (x:: $ (($ (:n ))), args... ; kw... )
362+ $ ($ (name)). $ op (WithModel (x. val), args... ; kw... )
363+ end )
356364 end
357365 end ... )
366+
358367 nothing
359368 end )
360369 end
361370
362371 macro typed_wrapper (n, abs)
363372 doctarget = gensym ()
364373 Ts = nameof .($ (sorts)($ t))
374+ Tnames = QuoteNode .(Ts)
365375 Xs = map (Ts) do s
366376 :($ (GlobalRef ($ (TheoryInterface), :impl_type ))(x, $ ($ (name)), $ ($ (Meta. quot)(s))))
367377 end
368378 XTs = map (zip (Ts,Xs)) do (T,X)
369379 :($ X <: $T || error (" Mismatch $($ ($ (Meta. quot)(T))) : $($ X) ⊄ $($ T) " ))
370380 end
371381
382+ Xdict = :(Dict (zip ($ Ts, [$ (Xs... )])))
383+
372384 gv = :($ (GlobalRef ($ (Scopes), :getvalue )))
373385 it = :($ (GlobalRef ($ (TheoryInterface), :impl_type )))
374386
387+ # As an experiment, we put the correct types of the arguments explicitly
388+ # though this doesn't ultimately affect whether there is dynamic dispatch
389+ ms = vcat (map (collect (values ($ t. resolvers))) do res
390+ map (collect (pairs (res. bysignature))) do (sig, meth′)
391+ meth = $ t[meth′]. value
392+ is_acc = nameof (typeof (meth)) == :AlgAccessor
393+ is_typ = nameof (typeof (meth)) == :AlgTypeConstructor
394+ op = nameof (meth. declaration)
395+ args = if is_acc
396+ [:($ (gensym (:val )):: $ (nameof (meth. typecondecl)))]
397+ else
398+ vcat (is_typ ? [:($ (gensym (:val )):: Any )] : Expr[], map (meth[meth. args]) do argbind
399+ :($ (gensym (nameof (argbind))):: $ (nameof (argbind. value. body. head)))
400+ end )
401+ end
402+ argnames = [first (a. args) for a in args]
403+ :(@inline function $ ($ (name)). $ op (x:: $ (($ (:n ))){$ (Ts... )}, $ (args... ); kw... ) where {$ (Ts... )}
404+ $ ($ (name)). $ op (WithModel (x. val), $ (argnames... ); kw... )
405+ end )
406+ end
407+ end ... )
408+
409+
375410 esc (quote
376411 # Catch any potential docs above the macro call
377412 const $ (doctarget) = nothing
@@ -380,16 +415,24 @@ function wrapper(name::Symbol, t::GAT, mod)
380415 # Declare the wrapper struct
381416 struct $ n{$ (Ts... )} <: $abs
382417 val:: Any
418+ types:: Dict{Symbol, Type}
419+
383420 function $n {$(Ts...)} (x:: Any ) where {$ (Ts... )}
384- $ ($ (GlobalRef (TheoryInterface, :implements )))(x, $ ($ name), [$ (Xs... )]) || error (" Invalid $($ ($ (name))) model: $x " )
385- $ (XTs... )
386- # TODO ? CHECK THAT THE GIVEN PARAMETERS MATCH Xs?
387- new {$(Ts...)} (x)
421+ types = try
422+ $ Xdict
423+ catch MethodError
424+ error (" Invalid $($ ($ (name))) model: $x " )
425+ end
426+ all (zip ([$ (Ts... )], [$ (Tnames... )])) do (T1,k)
427+ types[k] <: T1 || error (" Bad type for $k : $(types[k]) ⊄ $T1 " )
428+ end
429+ new {$(Ts...)} (x, types)
388430 end
389431
390432 function $n (x:: Any )
391433 $ ($ (GlobalRef (TheoryInterface, :implements )))(x, $ ($ name), [$ (Xs... )]) || error (" Invalid $($ ($ (name))) model: $x " )
392- new {$(Xs...)} (x)
434+ types = $ Xdict
435+ new {$(Xs...)} (x, types)
393436 end
394437 end
395438 # Apply the caught documentation to the new struct
@@ -399,16 +442,9 @@ function wrapper(name::Symbol, t::GAT, mod)
399442 $ (Expr (:macrocall , $ (GlobalRef (StructEquality, Symbol (" @struct_hash_equal" ))), $ (mod), $ (:n )))
400443
401444 $ gv (x:: $n ) = x. val
402- $ it (x:: $n , o:: Symbol ) = $ it (x . val, $ ( $ name), o)
445+ $ it (x:: $n , o:: Symbol ) = x . types[o]
403446
404- # Dispatch on model value for all declarations in theory
405- $ (map (filter (x-> x[2 ] isa $ AlgDeclaration, $ (identvalues (t)))) do (x,j)
406- if j isa $ (AlgDeclaration)
407- op = nameof (x)
408- :($ ($ (name)). $ op (x:: $ (($ (:n ))), args... ; kw... ) =
409- $ ($ (name)). $ op[x. val](args... ; kw... ))
410- end
411- end ... )
447+ $ (ms... )
412448 nothing
413449 end )
414450 end
418454parse_wrapper_input (n:: Symbol ) = n, Any
419455parse_wrapper_input (n:: Expr ) = n. head == :< : ? n. args : error (" Bad input for wrapper" )
420456
421-
422457end # module
0 commit comments