Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2911,7 +2911,7 @@ def _convert_element_type_lowering_rule(
old_dtype = in_aval.dtype
out_type = ctx.aval_to_ir_type(out_aval)

if old_dtype == new_dtype:
if old_dtype == new_dtype or x.type == out_type:
return x

if new_dtype.itemsize == 8:
Expand Down
90 changes: 54 additions & 36 deletions jax/_src/pallas/mosaic/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax._src.pallas.mosaic import tpu_info
from jax._src.state import indexing
import numpy as np
import jax.numpy as jnp


Expand Down Expand Up @@ -692,7 +693,7 @@ def cumulative_copy_in(self):
@property
def current_copy_in_slot(self):
"""Index in multiple buffer corresponding to the current slot."""
return lax.rem(self.cumulative_copy_in, jnp.uint32(self.buffer_count))
return lax.rem(self.cumulative_copy_in, np.uint32(self.buffer_count))

@property
def cumulative_copy_out(self):
Expand All @@ -703,7 +704,7 @@ def cumulative_copy_out(self):
@property
def current_copy_out_slot(self):
"""Index in multiple buffer corresponding to the current copy slot."""
return lax.rem(self.cumulative_copy_out, jnp.uint32(self.buffer_count))
return lax.rem(self.cumulative_copy_out, np.uint32(self.buffer_count))

@property
def cumulative_wait_in(self):
Expand All @@ -714,7 +715,7 @@ def cumulative_wait_in(self):
@property
def current_wait_in_slot(self):
"""Index in multiple buffer corresponding to the current wait slot."""
return lax.rem(self.cumulative_wait_in, jnp.uint32(self.buffer_count))
return lax.rem(self.cumulative_wait_in, np.uint32(self.buffer_count))

@property
def cumulative_wait_out(self):
Expand All @@ -725,7 +726,7 @@ def cumulative_wait_out(self):
@property
def current_wait_out_slot(self):
"""Index in multiple buffer corresponding to the current wait slot."""
return lax.rem(self.cumulative_wait_out, jnp.uint32(self.buffer_count))
return lax.rem(self.cumulative_wait_out, np.uint32(self.buffer_count))

@property
def next_fetch_indices(self):
Expand Down Expand Up @@ -780,12 +781,12 @@ def compute_slice(self, grid_indices):
def initialize_slots(self) -> BufferedRef:
return dataclasses.replace(
self,
copy_in_slot=jnp.uint32(0) if self.buffer_type.is_input else None,
wait_in_slot=jnp.uint32(0) if self.buffer_type.is_input else None,
copy_out_slot=jnp.uint32(0) if self.buffer_type.is_output else None,
wait_out_slot=jnp.uint32(0) if self.buffer_type.is_output else None,
copy_in_slot=np.uint32(0) if self.buffer_type.is_input else None,
wait_in_slot=np.uint32(0) if self.buffer_type.is_input else None,
copy_out_slot=np.uint32(0) if self.buffer_type.is_output else None,
wait_out_slot=np.uint32(0) if self.buffer_type.is_output else None,
next_fetch=(
tuple(jnp.int32(0) for _ in range(self._grid_rank))
tuple(np.int32(0) for _ in range(self._grid_rank))
if self._grid_rank is not None
else None
),
Expand Down Expand Up @@ -1011,18 +1012,20 @@ def fmap(bref, *f_args):


def _filter_indices(
indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...]
) -> tuple[int | jax.Array, ...]:
indices: tuple[int | np.int32 | jax.Array, ...],
grid: tuple[int | np.int32 | jax.Array, ...]
) -> tuple[int | np.int32 | jax.Array, ...]:
return tuple(
0 if isinstance(g, int) and g == 1 else i
np.int32(0) if isinstance(g, int) and g == 1 else i
for i, g in zip(indices, grid, strict=True)
)


def _next_index(
indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...],
indices: tuple[int | np.int32 | jax.Array, ...],
grid: tuple[int | np.int32 | jax.Array, ...],
allow_overflow: bool = False,
) -> tuple[int | jax.Array, ...]:
) -> tuple[int | np.int32 | jax.Array, ...]:
"""Increments the grid indices by one.

Args:
Expand All @@ -1044,23 +1047,23 @@ def _next_index(
if allow_overflow and (position == len(grid) - 1):
carry = False
else:
carry = inc == g
out.append(jax.lax.select(carry, 0, inc))
carry = inc == (np.int32(g) if isinstance(g, int) else g)
out.append(jax.lax.select(carry, np.int32(0), inc))
if allow_overflow:
return tuple(reversed(out))
else:
return _filter_indices(tuple(reversed(out)), grid)


def _prev_index(
indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...]
) -> tuple[int | jax.Array, ...]:
indices: tuple[np.int32 | jax.Array, ...], grid: tuple[np.int32 | jax.Array, ...]
) -> tuple[np.int32 | jax.Array, ...]:
out = []
borrow: bool | jax.Array = True
for i, g in reversed(list(zip(indices, grid, strict=True))):
dec = jax.lax.select(borrow, i - 1, i)
borrow = dec == -1
out.append(jax.lax.select(borrow, g - 1, dec))
out.append(jax.lax.select(borrow, np.int32(g - 1) if isinstance(g, int) else (g - 1), dec))
return _filter_indices(tuple(reversed(out)), grid)


Expand All @@ -1070,9 +1073,9 @@ class Scheduler:
def __init__(
self,
step: jax.Array,
indices: tuple[int | jax.Array, ...],
grid: tuple[int | jax.Array, ...],
grid_offsets: tuple[int | jax.Array, ...],
indices: tuple[np.int32 | jax.Array, ...],
grid: tuple[np.int32 | jax.Array, ...],
grid_offsets: tuple[np.int32 | jax.Array, ...],
num_stages: int,
trace_scopes=True,
_explicit_indices: bool = False,
Expand All @@ -1099,8 +1102,12 @@ def __init__(
self.num_steps = math.prod(grid)

# First and last inner step conditionals.
self.first_step = step == 0
self.last_step = step == self.num_steps - 1
self.first_step = step == np.int32(0)
self.last_step = step == (
np.int32(self.num_steps - 1)
if isinstance(self.num_steps, int)
else (self.num_steps - 1)
)

# Derived grid indices for present, previous, and next steps.
self.indices = tuple(
Expand Down Expand Up @@ -1151,7 +1158,9 @@ def out_of_fetch(self, buffered_ref):
# lookahead this will depend on whether the lookahead reached the end.
if not buffered_ref.is_buffered:
return jnp.bool(False)
return self.step >= (self.num_steps - buffered_ref.buffer_count + 1)
ub = self.num_steps - buffered_ref.buffer_count + 1
ub_32 = np.int32(ub) if isinstance(ub, int) else ub
return self.step >= ub_32

def has_changed(self, buffered_ref):
if not buffered_ref.is_buffered or buffered_ref.is_trivial_windowing:
Expand Down Expand Up @@ -1421,13 +1430,13 @@ def make_output_bref(out_spec, out_ref):


def _partition_grid(
grid: tuple[int | jax.Array, ...],
grid: tuple[np.int32 | jax.Array, ...],
core_axis: tuple[int | str, ...] | int | str | None,
dimension_semantics: tuple[GridDimensionSemantics, ...] | None,
) -> tuple[tuple[int | jax.Array, ...], tuple[int | jax.Array, ...]]:
) -> tuple[tuple[np.int32 | jax.Array, ...], tuple[np.int32 | jax.Array, ...]]:
if core_axis is None:
# We aren't partitioning the grid
return grid, (0,) * len(grid)
return grid, (np.int32(0),) * len(grid)
if isinstance(core_axis, int):
num_cores = num_programs(core_axis)
core_id = program_id(core_axis)
Expand All @@ -1441,7 +1450,7 @@ def _partition_grid(
)
if num_cores == 1:
# We aren't partitioning the grid
return grid, (0,) * len(grid)
return grid, (np.int32(0),) * len(grid)

# If dimension_semantics aren't provided, we assume it is all arbitrary.
if dimension_semantics is None:
Expand Down Expand Up @@ -1476,7 +1485,7 @@ def _partition_grid(
grid, first_divisible_dimension, partitioned_dim_size
)
offsets = jax_util.tuple_update(
(0,) * len(grid),
(np.int32(0),) * len(grid),
first_divisible_dimension,
partitioned_dim_offset,
)
Expand Down Expand Up @@ -1529,7 +1538,7 @@ def _partition_grid(
core_id * base_num_iters + rem,
)
offsets = jax_util.tuple_update(
(0,) * len(grid),
(np.int32(0),) * len(grid),
partition_dimension,
grid_offset,
)
Expand Down Expand Up @@ -1612,6 +1621,10 @@ def emit_pipeline(
raise ValueError("core_axis and core_axis_name cannot both be provided.")
core_axis_ = core_axis_name if core_axis is None else core_axis
grid, grid_offsets = _partition_grid(grid, core_axis_, dimension_semantics)
grid = tuple(np.int32(g) if isinstance(g, int) else g for g in grid)
grid_offsets = tuple(
np.int32(g) if isinstance(g, int) else g for g in grid_offsets
)

num_steps = math.prod(grid)
in_specs = _normalize_specs(in_specs)
Expand Down Expand Up @@ -1704,13 +1717,15 @@ def loop_body(step, carry):

if no_pipelining:
# Debugging mode where all copies are synchronous.
initial_indices = (0,) * len(grid)
lower_bnd = np.int32(0)
upper_bnd = np.int32(num_steps) if isinstance(num_steps, int) else num_steps
initial_indices = (np.int32(0),) * len(grid)
brefs = map_brefs(lambda bref: bref.initialize_slots(), allocations)

@functools.partial(
jax.lax.fori_loop,
0,
num_steps,
lower_bnd,
upper_bnd,
init_val=(brefs, initial_indices),
)
def _loop_body(step, carry):
Expand Down Expand Up @@ -1741,7 +1756,9 @@ def _loop_body(step, carry):
@when(num_steps > 0)
def _():
# pipeline prologue
initial_indices = (0,) * len(grid)
lower_bnd = np.int32(0)
upper_bnd = np.int32(num_steps) if isinstance(num_steps, int) else num_steps
initial_indices = (np.int32(0),) * len(grid)
scheduler = make_scheduler(0, initial_indices)
brefs = map_brefs(lambda bref: bref.initialize_slots(), allocations)
def _sync_copy_in(bref, ref):
Expand All @@ -1760,7 +1777,8 @@ def _sync_copy_in(bref, ref):

# pipeline loop
brefs, next_indices = lax.fori_loop(
0, num_steps, loop_body, (brefs, initial_indices)
lower_bnd, upper_bnd,
loop_body, (brefs, initial_indices)
)

# pipeline epilogue
Expand Down
60 changes: 60 additions & 0 deletions tests/pallas/tpu_sparsecore_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2511,6 +2511,66 @@ def pipeline(x_ref, o_ref):

np.testing.assert_array_equal(kernel(x), x + 1)

def test_basic_x64(self):
self.skip_if_tc_tiling()
self.enter_context(jax.enable_x64(True))
num_steps = 16
x = jnp.arange(num_steps * self.num_lanes, dtype=jnp.int32).reshape(-1, 8)

@self.vector_subcore_kernel(
out_shape=x,
in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),),
out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
)
def kernel(x_hbm_ref, o_hbm_ref):

@functools.partial(
pltpu.emit_pipeline,
grid=(num_steps,),
in_specs=pl.BlockSpec(
(pl.Squeezed(), self.num_lanes), lambda i: (i, 0)
),
out_specs=pl.BlockSpec(
(pl.Squeezed(), self.num_lanes), lambda i: (i, 0)
),
)
def pipeline(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1

pipeline(x_hbm_ref, o_hbm_ref)

np.testing.assert_array_equal(kernel(x), x + 1)

def test_pipeline_disable_jit(self):
self.skip_if_tc_tiling()
num_steps = 16
x = jnp.arange(num_steps * self.num_lanes, dtype=jnp.int32).reshape(-1, 8)

@self.vector_subcore_kernel(
out_shape=x,
in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),),
out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
)
def kernel(x_hbm_ref, o_hbm_ref):

@functools.partial(
pltpu.emit_pipeline,
grid=(num_steps,),
in_specs=pl.BlockSpec(
(pl.Squeezed(), self.num_lanes), lambda i: (i, 0)
),
out_specs=pl.BlockSpec(
(pl.Squeezed(), self.num_lanes), lambda i: (i, 0)
),
)
def pipeline(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1

pipeline(x_hbm_ref, o_hbm_ref)

with jax.disable_jit():
kernel(x)

def test_gather_with_emit(self):
self.skip_if_tc_tiling()
sc_mesh = sc_core.VectorSubcoreMesh(
Expand Down
Loading