Skip to content

Commit aa73aa0

Browse files
epiquerasGoogle-ML-Automation
authored andcommitted
Pallas pipeline API tweaks for more advanced pipelining patterns.
PiperOrigin-RevId: 678426679
1 parent 85a466d commit aa73aa0

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

jax/_src/pallas/mosaic/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ py_library(
107107
":primitives",
108108
"//jax",
109109
"//jax:api_util",
110+
"//jax:pallas",
110111
"//jax:util",
111112
"//jax/_src/pallas",
112113
] + py_deps("numpy"),

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ class BufferType(enum.Enum):
178178
ACCUMULATOR = 3
179179
INPUT_OUTPUT = 4
180180

181+
MANUAL = 5
182+
181183

182184
@tree_util.register_pytree_node_class
183185
@dataclasses.dataclass(frozen=True)
@@ -234,6 +236,10 @@ def tree_flatten(self):
234236
def tree_unflatten(cls, meta, data):
235237
return cls(*meta, *data)
236238

239+
@staticmethod
240+
def buffer_types() -> type[BufferType]:
241+
return BufferType
242+
237243
@classmethod
238244
def create(cls, spec, dtype, buffer_type) -> BufferedRef:
239245
"""Create a BufferedRef.
@@ -1034,6 +1040,7 @@ def pipeline(
10341040
prefetch=None,
10351041
postyeet=None,
10361042
schedule=None,
1043+
body_prologue=None,
10371044
):
10381045
"""
10391046
Run the pipeline.
@@ -1056,6 +1063,8 @@ def pipeline(
10561063
Called during the outputs phase in the first inner step.
10571064
schedule: manually specified pipeline schedules for brefs, None indicates
10581065
default schedule.
1066+
body_prologue: For running code within the grid environment before the
1067+
body is run. Useful for updating manual refs.
10591068
"""
10601069
if scratches is None:
10611070
scratches = ()
@@ -1119,6 +1128,9 @@ def loop_body(step, _):
11191128
lambda: None)
11201129

11211130
# run the kernel!
1131+
if body_prologue is not None:
1132+
with scheduler.grid_env():
1133+
body_prologue()
11221134
current_refs = map_brefs(lambda x: x.current_ref, brefs)
11231135
with scheduler._named_scope("ep_run_kernel"):
11241136
with scheduler.grid_env():

0 commit comments

Comments
 (0)