Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 101 additions & 6 deletions jax/_src/pallas/mosaic/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,45 @@
PipelineRefs = Union[Sequence[REF], Any]


@dataclasses.dataclass(frozen=True)
class PrefetchedInput:
"""Bundles an ref with its prefetched refs and count of buffers it has prefetched.

This type is recognized by emit_pipeline to automatically bind prefetched
windows to the corresponding BufferedRef.

Attributes:
ref: The original Ref for the input.
prefetched_ref: The prefetched Ref.
prefetched_count: The number of buffers we have prefetched ahead. Note that
prefetched_count is NOT the buffer count of the pipeline (i.e., of the
BufferedRef). It should be strictly less than the pipeline's buffer count.
This is for leaving at least one slot empty for fetching during the first
iteration.
"""

ref: ArrayRef
prefetched_ref: ArrayRef
prefetched_count: int

def bind(self, bref: BufferedRefBase) -> BufferedRefBase:
"""Binds prefetched ref and count to a BufferedRef."""
prefetched_count = self.prefetched_count
window_ref = self.prefetched_ref
if isinstance(bref, BufferedRef):
if (
not bref.is_trivial_windowing
and prefetched_count >= bref.buffer_count
):
prefetched_count = bref.buffer_count - 1
return dataclasses.replace(
bref,
window_ref=window_ref,
is_prefetched=True,
prefetched_count=prefetched_count,
)


def _create_blocked_slice(
block_index: jax.Array | int,
block_size: int,
Expand Down Expand Up @@ -300,6 +339,10 @@ def is_trivial_windowing(self) -> bool:
"""
return False

@property
def is_prefetched(self) -> bool:
return False

def initialize_slots(self):
"""Initializes slots to 0."""
raise NotImplementedError()
Expand Down Expand Up @@ -477,6 +520,12 @@ class BufferedRef(BufferedRefBase):
has_allocated_buffer: bool = dataclasses.field(
default=False, metadata=dict(static=True)
)
is_prefetched: bool = dataclasses.field(
default=False, metadata=dict(static=True)
)
prefetched_count: int = dataclasses.field(
default=0, metadata=dict(static=True)
)

def __post_init__(self):
if self.is_buffered and self.buffer_count < 1:
Expand Down Expand Up @@ -526,6 +575,7 @@ def create(
source_memory_space: tpu_core.MemorySpace | Literal[ANY] = ANY, # pyrefly: ignore[not-a-type]
tiling: Tiling | None = None,
is_trivial_windowing: bool = False,
is_prefetched: bool = False,
) -> BufferedRef:
"""Create a BufferedRef.

Expand Down Expand Up @@ -594,12 +644,16 @@ def create(
buffer_ty = ty.update(shape=(buffer_count * block_shape[0],))
else:
buffer_ty = ty.update(shape=(buffer_count, *block_shape))
if is_prefetched:
window_ref = None # to be bound to existing ref by the pipeline routine
else:
window_ref = buffer_memory_space.from_type(buffer_ty)
return cls(
_spec=spec,
_buffer_type=buffer_type,
_buffer_count=buffer_count,
_grid_rank=grid_rank if use_lookahead else None,
window_ref=buffer_memory_space.from_type(buffer_ty),
window_ref=window_ref,
copy_in_slot=None,
wait_in_slot=None,
copy_out_slot=None,
Expand All @@ -618,6 +672,7 @@ def create(
tiling=tiling,
is_trivial_windowing=is_trivial_windowing,
has_allocated_buffer=True,
is_prefetched=is_prefetched,
)

@classmethod
Expand Down Expand Up @@ -912,6 +967,9 @@ def wait_out(self, dst_ref, grid_indices):
self.sem_sends.at[wait_slot],
).wait()

def advance_next_fetch(self, grid):
return self.with_next_fetch(_next_index(self.next_fetch, grid))


def fetch_with_lookahead(buffered_ref, src_ref,
grid,
Expand Down Expand Up @@ -1208,6 +1266,12 @@ def initialize_step(self, buffered_ref, src_ref, step=0):
if (step + 1) >= buffered_ref.buffer_count:
return buffered_ref

if buffered_ref.is_prefetched:
if step < buffered_ref.prefetched_count:
if buffered_ref.use_lookahead and step > 0:
buffered_ref = buffered_ref.advance_next_fetch(self.grid)
return buffered_ref.advance_copy_in_slot()

if buffered_ref.use_lookahead:
if step == 0:
# We always fetch the first block.
Expand Down Expand Up @@ -1246,6 +1310,12 @@ def wait_in(self, buffered_ref, src_ref) -> BufferedRef:
if buffered_ref.is_trivial_windowing:
return buffered_ref
pred = self.has_changed(buffered_ref) | self.first_step
pred = pred & (
~(
buffered_ref.is_prefetched
& (self.step < buffered_ref.prefetched_count)
)
)

@when(pred)
@self._named_scope("ep_wait_in")
Expand Down Expand Up @@ -1374,6 +1444,9 @@ def _make_pipeline_allocations(
in_refs = refs[:num_in_specs]
out_refs = refs[num_in_specs:]
def make_input_bref(in_spec, in_ref):
is_prefetched = isinstance(in_ref, PrefetchedInput)
if is_prefetched:
in_ref = in_ref.ref
in_aval = _ref_to_value_aval(in_ref)
buffer_count = 2
use_lookahead = False
Expand All @@ -1395,6 +1468,7 @@ def make_input_bref(in_spec, in_ref):
source_memory_space=in_ref.memory_space,
tiling=tiling,
is_trivial_windowing=is_trivial,
is_prefetched=is_prefetched,
)
in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs)
def make_output_bref(out_spec, out_ref):
Expand Down Expand Up @@ -1659,6 +1733,25 @@ def pipeline(
if isinstance(allocations, list):
allocations = tuple(allocations)

# Bind prefetched refs from PrefetchedInput instances.
def _bind_window_ref(bref, in_ref):
assert isinstance(bref, BufferedRef), bref
if isinstance(in_ref, PrefetchedInput):
return in_ref.bind(bref)
return bref

allocations = jax.tree.map(
_bind_window_ref,
allocations,
refs,
is_leaf=lambda x: isinstance(x, (BufferedRefBase, PrefetchedInput)),
)

# Unwrap PrefetchedInput leaves so the loop body gets raw Pallas refs.
refs = jax.tree.map(
lambda r: r.ref if isinstance(r, PrefetchedInput) else r, refs
)

def make_scheduler(step, indices):
return Scheduler(
step,
Expand Down Expand Up @@ -1745,7 +1838,11 @@ def _():
scheduler = make_scheduler(0, initial_indices)
brefs = map_brefs(lambda bref: bref.initialize_slots(), allocations)
def _sync_copy_in(bref, ref):
if bref.is_trivial_windowing and bref.window_ref is not None:
if (
bref.is_trivial_windowing
and bref.window_ref is not None
and not bref.is_prefetched
):
sync_copy(ref, bref, initial_indices)

map_inputs(_sync_copy_in, brefs, refs)
Expand Down Expand Up @@ -1805,8 +1902,6 @@ def emit_pipeline_with_allocations(
out_specs=out_specs,
grid=grid)
pipeline = emit_pipeline(
body,
grid=grid,
in_specs=in_specs,
out_specs=out_specs)
body, grid=grid, in_specs=in_specs, out_specs=out_specs
)
return pipeline, make_allocations
1 change: 1 addition & 0 deletions jax/experimental/pallas/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from jax._src.pallas.mosaic.pipeline import BufferType as BufferType
from jax._src.pallas.mosaic.pipeline import emit_pipeline as emit_pipeline
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations as emit_pipeline_with_allocations
from jax._src.pallas.mosaic.pipeline import PrefetchedInput as PrefetchedInput
from jax._src.pallas.mosaic.primitives import async_copy as async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy
from jax._src.pallas.mosaic.primitives import bitcast as bitcast
Expand Down
122 changes: 120 additions & 2 deletions tests/pallas/tpu_pallas_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,126 @@ def kernel(x_hbm, out_hbm):

np.testing.assert_allclose(out, x[0])

def test_prefetched_input(self):
def pipeline_body(x_ref, o_ref):
o_ref[...] = x_ref[...]

def kernel(x_hbm_ref, o_hbm_ref):
@functools.partial(
pl.run_scoped,
x_prefetched_vmem=pltpu.VMEM((2, 128), jnp.float32),
)
def _(x_prefetched_vmem):
pltpu.sync_copy(x_hbm_ref.at[pl.ds(0, 128)], x_prefetched_vmem.at[0])

prefetched_input = pltpu.PrefetchedInput(
ref=x_hbm_ref,
prefetched_ref=x_prefetched_vmem,
prefetched_count=1,
)

pltpu.emit_pipeline(
pipeline_body,
grid=(4,),
in_specs=pl.BlockSpec(
(128,),
lambda i: (i,),
pipeline_mode=pl.Buffered(buffer_count=2),
),
out_specs=pl.BlockSpec((128,), lambda i: (i,)),
)(prefetched_input, o_hbm_ref)

x = jnp.arange(512, dtype=jnp.float32)
out = pl.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pl.ANY),
],
out_specs=pl.BlockSpec(memory_space=pl.ANY),
out_shape=jax.ShapeDtypeStruct((512,), jnp.float32),
)(x)
np.testing.assert_allclose(out, x)

def test_prefetched_input_lookahead(self):
def pipeline_body(x_ref, o_ref):
o_ref[...] = x_ref[...]

def kernel(x_hbm_ref, o_hbm_ref):
@functools.partial(
pl.run_scoped,
x_prefetched_vmem=pltpu.VMEM((2, 128), jnp.float32),
)
def _(x_prefetched_vmem):
pltpu.sync_copy(x_hbm_ref.at[pl.ds(0, 128)], x_prefetched_vmem.at[0])

prefetched_input = pltpu.PrefetchedInput(
ref=x_hbm_ref,
prefetched_ref=x_prefetched_vmem,
prefetched_count=1,
)

pltpu.emit_pipeline(
pipeline_body,
grid=(4,),
in_specs=pl.BlockSpec(
(128,),
lambda i: (i,),
pipeline_mode=pl.Buffered(buffer_count=2, use_lookahead=True),
),
out_specs=pl.BlockSpec((128,), lambda i: (i,)),
)(prefetched_input, o_hbm_ref)

x = jnp.arange(512, dtype=jnp.float32)
out = pl.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pl.ANY),
],
out_specs=pl.BlockSpec(memory_space=pl.ANY),
out_shape=jax.ShapeDtypeStruct((512,), jnp.float32),
)(x)
np.testing.assert_allclose(out, x)

def test_prefetched_input_trivial_windowing(self):
def pipeline_body(x_ref, o_ref):
o_ref[...] = x_ref[...]

def kernel(x_hbm_ref, o_hbm_ref):
@functools.partial(
pl.run_scoped,
x_prefetched_vmem=pltpu.VMEM((512,), jnp.float32),
)
def _(x_prefetched_vmem):
pltpu.sync_copy(x_hbm_ref, x_prefetched_vmem)

prefetched_input = pltpu.PrefetchedInput(
ref=x_hbm_ref,
prefetched_ref=x_prefetched_vmem,
prefetched_count=1,
)

pltpu.emit_pipeline(
pipeline_body,
grid=(1,),
in_specs=pl.BlockSpec(
(512,),
lambda i: (0,),
pipeline_mode=pl.Buffered(buffer_count=1),
),
out_specs=pl.BlockSpec((512,), lambda i: (0,)),
)(prefetched_input, o_hbm_ref)

x = jnp.arange(512, dtype=jnp.float32)
out = pl.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pl.ANY),
],
out_specs=pl.BlockSpec(memory_space=pl.ANY),
out_shape=jax.ShapeDtypeStruct((512,), jnp.float32),
)(x)
np.testing.assert_allclose(out, x)


@jtu.with_config(jax_pallas_poison_buffers=True)
class PallasCallPipelinePoisonTest(jtu.JaxTestCase):
Expand Down Expand Up @@ -542,8 +662,6 @@ def kernel(x_hbm_ref, o_hbm_ref):
)




class PallasCallMultipleBufferedPipelineTest(jtu.JaxTestCase):

def setUp(self):
Expand Down
Loading