@@ -178,6 +178,8 @@ class BufferType(enum.Enum):
178178 ACCUMULATOR = 3
179179 INPUT_OUTPUT = 4
180180
181+ MANUAL = 5
182+
181183
182184@tree_util .register_pytree_node_class
183185@dataclasses .dataclass (frozen = True )
@@ -234,6 +236,10 @@ def tree_flatten(self):
234236 def tree_unflatten (cls , meta , data ):
235237 return cls (* meta , * data )
236238
239+ @staticmethod
240+ def buffer_types () -> type [BufferType ]:
241+ return BufferType
242+
237243 @classmethod
238244 def create (cls , spec , dtype , buffer_type ) -> BufferedRef :
239245 """Create a BufferedRef.
@@ -1034,6 +1040,7 @@ def pipeline(
10341040 prefetch = None ,
10351041 postyeet = None ,
10361042 schedule = None ,
1043+ body_prologue = None ,
10371044 ):
10381045 """
10391046 Run the pipeline.
@@ -1056,6 +1063,8 @@ def pipeline(
10561063 Called during the outputs phase in the first inner step.
10571064 schedule: manually specified pipeline schedules for brefs, None indicates
10581065 default schedule.
1066+ body_prologue: For running code within the grid environment before the
1067+ body is run. Useful for updating manual refs.
10591068 """
10601069 if scratches is None :
10611070 scratches = ()
@@ -1119,6 +1128,9 @@ def loop_body(step, _):
11191128 lambda : None )
11201129
11211130 # run the kernel!
1131+ if body_prologue is not None :
1132+ with scheduler .grid_env ():
1133+ body_prologue ()
11221134 current_refs = map_brefs (lambda x : x .current_ref , brefs )
11231135 with scheduler ._named_scope ("ep_run_kernel" ):
11241136 with scheduler .grid_env ():
0 commit comments