11use crate :: {
2- Continuation , Executable , ExecutionCtx , Executor , Suspend ,
2+ Continuation , Executable , ExecutionCtx , Executor , InstrKind , Pending , Suspend ,
33 execution:: {
44 lower:: { LowerFunctionBody , LowerFunctionHead } ,
55 resolve:: { ResolveFunctionBody , ResolveFunctionHead } ,
@@ -13,6 +13,8 @@ use attributes::{Exposure, Privacy, SymbolOwnership, Tag};
1313use by_address:: ByAddress ;
1414use derivative:: Derivative ;
1515use diagnostics:: ErrorDiagnostic ;
16+ use itertools:: Itertools ;
17+ use std:: collections:: HashMap ;
1618
1719#[ derive( Clone , Derivative ) ]
1820#[ derivative( Debug , PartialEq , Eq , Hash ) ]
@@ -44,6 +46,30 @@ pub struct ResolveEvaluation<'env> {
4446 #[ derivative( Debug = "ignore" ) ]
4547 #[ derivative( PartialEq = "ignore" ) ]
4648 lowered_func_body : Suspend < ' env , ( ) > ,
49+
50+ #[ derivative( Hash = "ignore" ) ]
51+ #[ derivative( Debug = "ignore" ) ]
52+ #[ derivative( PartialEq = "ignore" ) ]
53+ transitive_func_state : TransitiveFuncState < ' env > ,
54+ }
55+
56+ /// State machine for completing transitive function dependencies
57+ #[ derive( Clone ) ]
58+ enum TransitiveFuncState < ' env > {
59+ Initialize ,
60+ Step {
61+ scan_next : Vec < & ' env FuncBody < ' env > > ,
62+ resolved_bodies : HashMap < ByAddress < & ' env FuncHead < ' env > > , & ' env FuncBody < ' env > > ,
63+ requests : HashMap < ByAddress < & ' env FuncHead < ' env > > , Pending < ' env , & ' env FuncBody < ' env > > > ,
64+ } ,
65+ LowerBodies (
66+ Vec < (
67+ & ' env FuncHead < ' env > ,
68+ & ' env FuncBody < ' env > ,
69+ Pending < ' env , ir:: FuncRef < ' env > > ,
70+ ) > ,
71+ ) ,
72+ Finished ,
4773}
4874
4975impl < ' env > ResolveEvaluation < ' env > {
@@ -56,6 +82,7 @@ impl<'env> ResolveEvaluation<'env> {
5682 resolved_func_body : None ,
5783 lowered_func_head : None ,
5884 lowered_func_body : None ,
85+ transitive_func_state : TransitiveFuncState :: Initialize ,
5986 }
6087 }
6188}
@@ -96,7 +123,7 @@ impl<'env> Executable<'env> for ResolveEvaluation<'env> {
96123 let Some ( resolved_func_head) = executor. demand ( self . resolved_func_head ) else {
97124 return suspend ! (
98125 self . resolved_func_head,
99- executor. request( ResolveFunctionHead :: new( self . comptime_view, & ast_func. head ) ) ,
126+ executor. request( ResolveFunctionHead :: new( self . comptime_view, & ast_func) ) ,
100127 ctx
101128 ) ;
102129 } ;
@@ -106,30 +133,114 @@ impl<'env> Executable<'env> for ResolveEvaluation<'env> {
106133 self . resolved_func_body,
107134 executor. request( ResolveFunctionBody :: new(
108135 self . comptime_view,
109- ast_func,
110136 resolved_func_head
111137 ) ) ,
112138 ctx
113139 ) ;
114140 } ;
115141
116- // 3) Lower the function
142+ // 3) Resolve transitive function dependencies
143+ if let TransitiveFuncState :: Initialize = & self . transitive_func_state {
144+ self . transitive_func_state = TransitiveFuncState :: Step {
145+ scan_next : vec ! [ resolved_func_body] ,
146+ resolved_bodies : HashMap :: new ( ) ,
147+ requests : HashMap :: new ( ) ,
148+ } ;
149+ }
150+
151+ match & mut self . transitive_func_state {
152+ TransitiveFuncState :: Initialize => unreachable ! ( ) ,
153+ TransitiveFuncState :: Step {
154+ scan_next,
155+ resolved_bodies,
156+ requests,
157+ } => {
158+ // Receive completed resolved function bodies
159+ {
160+ let truth = executor. truth . read ( ) . unwrap ( ) ;
161+ for ( head, pending) in std:: mem:: take ( requests) {
162+ let body = truth. demand ( pending) ;
163+ resolved_bodies. insert ( head, body) ;
164+ scan_next. push ( body) ;
165+ }
166+ }
167+
168+ // Search newly resolved function bodies
169+ for body in scan_next. drain ( ..) {
170+ for instr in body. cfg . iter_instrs_unordered ( ) {
171+ // For any function calls made
172+ if let InstrKind :: Call ( _, call_target) = & instr. kind {
173+ let call_target = call_target. as_ref ( ) . unwrap ( ) ;
174+ let key = ByAddress ( call_target. callee ) ;
175+
176+ // If we didn't already resolve the callee and aren't already planning
177+ // to request the callee to be resolved
178+ if !resolved_bodies. contains_key ( & key) && !requests. contains_key ( & key) {
179+ // Then add the callee to the set of functions to resolve in the next
180+ // suspend.
181+ requests. insert (
182+ key,
183+ executor. request ( ResolveFunctionBody :: new ( key. view , key. 0 ) ) ,
184+ ) ;
185+ }
186+ }
187+ }
188+ }
189+
190+ // Suspend and continue stepping if any requests need to be made
191+ if !requests. is_empty ( ) {
192+ ctx. suspend_on ( requests. iter ( ) . map ( |( _, pending) | pending) ) ;
193+ return Err ( Continuation :: suspend ( self ) ) ;
194+ }
195+
196+ // Otherwise, all function CFG bodies have been resolved,
197+ // and now we have to lower all of the function heads used...
198+ let mut pending_heads = Vec :: new ( ) ;
199+ for ( head, body) in resolved_bodies {
200+ let request = executor. request ( LowerFunctionHead :: new ( head) ) ;
201+ pending_heads. push ( ( * * head, * body, request) ) ;
202+ }
203+
204+ // Suspend on all function heads being lowered that we need
205+ ctx. suspend_on ( pending_heads. iter ( ) . map ( |( _, _, pending) | pending) ) ;
206+ self . transitive_func_state = TransitiveFuncState :: LowerBodies ( pending_heads) ;
207+ return Err ( Continuation :: suspend ( self ) ) ;
208+ }
209+ TransitiveFuncState :: LowerBodies ( pending_heads) => {
210+ // Extract lowered function heads that we waited on
211+ let pending = {
212+ let truth = executor. truth . read ( ) . unwrap ( ) ;
213+ pending_heads
214+ . into_iter ( )
215+ . map ( |( head, body, request) | {
216+ let ir_func_ref = truth. demand ( * request) ;
217+ LowerFunctionBody :: new ( * ir_func_ref, head, body)
218+ } )
219+ . collect_vec ( )
220+ } ;
221+
222+ // Suspend on all function bodies being lowered that we need
223+ ctx. suspend_on ( executor. request_many ( pending. into_iter ( ) ) ) ;
224+ self . transitive_func_state = TransitiveFuncState :: Finished ;
225+ return Err ( Continuation :: suspend ( self ) ) ;
226+ }
227+ TransitiveFuncState :: Finished => ( ) ,
228+ }
229+
230+ // 4) Lower the entry point function head
117231 let Some ( lowered_func_head) = executor. demand ( self . lowered_func_head ) else {
118232 return suspend ! (
119233 self . lowered_func_head,
120- executor. request( LowerFunctionHead :: new(
121- self . comptime_view,
122- resolved_func_head
123- ) ) ,
234+ executor. request( LowerFunctionHead :: new( resolved_func_head) ) ,
124235 ctx
125236 ) ;
126237 } ;
127238
239+ // 5) Lower the entry point function body
128240 let Some ( _lowered_func_body) = executor. demand ( self . lowered_func_body ) else {
129241 return suspend ! (
130242 self . lowered_func_body,
131243 executor. request( LowerFunctionBody :: new(
132- self . comptime_view,
133244 lowered_func_head,
134245 resolved_func_head,
135246 resolved_func_body,
@@ -138,10 +249,10 @@ impl<'env> Executable<'env> for ResolveEvaluation<'env> {
138249 ) ;
139250 } ;
140251
141- // 4 ) Obtain the intermediate representation for comptime so far
252+ // 6 ) Obtain the intermediate representation for comptime so far
142253 let ir = self . comptime_view . graph ( |graph| graph. ir ) ;
143254
144- // 5 ) Interpret the function and raise any interpretation errors
255+ // 7 ) Interpret the function and raise any interpretation errors
145256 let mut interpreter =
146257 Interpreter :: new ( ComptimeSystemSyscallHandler :: default ( ) , ir, Some ( 1_000_000 ) ) ;
147258
@@ -152,10 +263,10 @@ impl<'env> Executable<'env> for ResolveEvaluation<'env> {
152263 // The actual entry point result should be void
153264 entry_point_result. kind . unwrap_literal ( ) . unwrap_void ( ) ;
154265
155- // 6 ) Examine the result value that was baked by the function
266+ // 8 ) Examine the result value that was baked by the function
156267 let exit_value = interpreter. exit_value ( ) ;
157268
158- // 7 ) Expect that the exit value is transferrable
269+ // 9 ) Expect that the exit value is transferrable
159270 let Some ( exit_value) = exit_value else {
160271 return Err ( ErrorDiagnostic :: new (
161272 "Compile-time evaluation must evaluate to transferable value" ,
@@ -164,7 +275,7 @@ impl<'env> Executable<'env> for ResolveEvaluation<'env> {
164275 . into ( ) ) ;
165276 } ;
166277
167- // 8 ) Translate the constant value into a literal value
278+ // 10 ) Translate the constant value into a literal value
168279 // and/or static data that can be used as a literal.
169280 Ok ( ctx. alloc ( Evaluated :: new_unsigned ( exit_value) ) )
170281
0 commit comments