Skip to content

Commit 6465433

Browse files
gneculaGoogle-ML-Automation
authored andcommitted
Rollback of 64-bit more fix for pipelines.
That change was #38154. The problem is that now constants that were 0 or 1 were turned into np.int32 and there are several `is_instance(v, int)` checks throughout Pallas that behave differently than before. I need to rethink how to make 64-bit more work for pipelines. Reverts 9b82b11 PiperOrigin-RevId: 926690252
1 parent 4b9bf3f commit 6465433

3 files changed

Lines changed: 38 additions & 116 deletions

File tree

jax/_src/pallas/mosaic/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2911,7 +2911,7 @@ def _convert_element_type_lowering_rule(
29112911
old_dtype = in_aval.dtype
29122912
out_type = ctx.aval_to_ir_type(out_aval)
29132913

2914-
if old_dtype == new_dtype or x.type == out_type:
2914+
if old_dtype == new_dtype:
29152915
return x
29162916

29172917
if new_dtype.itemsize == 8:

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 37 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
from jax._src.pallas.mosaic import primitives as tpu_primitives
4040
from jax._src.pallas.mosaic import tpu_info
4141
from jax._src.state import indexing
42-
import numpy as np
4342
import jax.numpy as jnp
4443

4544

@@ -693,7 +692,7 @@ def cumulative_copy_in(self):
693692
@property
694693
def current_copy_in_slot(self):
695694
"""Index in multiple buffer corresponding to the current slot."""
696-
return lax.rem(self.cumulative_copy_in, np.uint32(self.buffer_count))
695+
return lax.rem(self.cumulative_copy_in, jnp.uint32(self.buffer_count))
697696

698697
@property
699698
def cumulative_copy_out(self):
@@ -704,7 +703,7 @@ def cumulative_copy_out(self):
704703
@property
705704
def current_copy_out_slot(self):
706705
"""Index in multiple buffer corresponding to the current copy slot."""
707-
return lax.rem(self.cumulative_copy_out, np.uint32(self.buffer_count))
706+
return lax.rem(self.cumulative_copy_out, jnp.uint32(self.buffer_count))
708707

709708
@property
710709
def cumulative_wait_in(self):
@@ -715,7 +714,7 @@ def cumulative_wait_in(self):
715714
@property
716715
def current_wait_in_slot(self):
717716
"""Index in multiple buffer corresponding to the current wait slot."""
718-
return lax.rem(self.cumulative_wait_in, np.uint32(self.buffer_count))
717+
return lax.rem(self.cumulative_wait_in, jnp.uint32(self.buffer_count))
719718

720719
@property
721720
def cumulative_wait_out(self):
@@ -726,7 +725,7 @@ def cumulative_wait_out(self):
726725
@property
727726
def current_wait_out_slot(self):
728727
"""Index in multiple buffer corresponding to the current wait slot."""
729-
return lax.rem(self.cumulative_wait_out, np.uint32(self.buffer_count))
728+
return lax.rem(self.cumulative_wait_out, jnp.uint32(self.buffer_count))
730729

731730
@property
732731
def next_fetch_indices(self):
@@ -781,12 +780,12 @@ def compute_slice(self, grid_indices):
781780
def initialize_slots(self) -> BufferedRef:
782781
return dataclasses.replace(
783782
self,
784-
copy_in_slot=np.uint32(0) if self.buffer_type.is_input else None,
785-
wait_in_slot=np.uint32(0) if self.buffer_type.is_input else None,
786-
copy_out_slot=np.uint32(0) if self.buffer_type.is_output else None,
787-
wait_out_slot=np.uint32(0) if self.buffer_type.is_output else None,
783+
copy_in_slot=jnp.uint32(0) if self.buffer_type.is_input else None,
784+
wait_in_slot=jnp.uint32(0) if self.buffer_type.is_input else None,
785+
copy_out_slot=jnp.uint32(0) if self.buffer_type.is_output else None,
786+
wait_out_slot=jnp.uint32(0) if self.buffer_type.is_output else None,
788787
next_fetch=(
789-
tuple(np.int32(0) for _ in range(self._grid_rank))
788+
tuple(jnp.int32(0) for _ in range(self._grid_rank))
790789
if self._grid_rank is not None
791790
else None
792791
),
@@ -1012,20 +1011,18 @@ def fmap(bref, *f_args):
10121011

10131012

10141013
def _filter_indices(
1015-
indices: tuple[int | np.int32 | jax.Array, ...],
1016-
grid: tuple[int | np.int32 | jax.Array, ...]
1017-
) -> tuple[int | np.int32 | jax.Array, ...]:
1014+
indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...]
1015+
) -> tuple[int | jax.Array, ...]:
10181016
return tuple(
1019-
np.int32(0) if isinstance(g, int) and g == 1 else i
1017+
0 if isinstance(g, int) and g == 1 else i
10201018
for i, g in zip(indices, grid, strict=True)
10211019
)
10221020

10231021

10241022
def _next_index(
1025-
indices: tuple[int | np.int32 | jax.Array, ...],
1026-
grid: tuple[int | np.int32 | jax.Array, ...],
1023+
indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...],
10271024
allow_overflow: bool = False,
1028-
) -> tuple[int | np.int32 | jax.Array, ...]:
1025+
) -> tuple[int | jax.Array, ...]:
10291026
"""Increments the grid indices by one.
10301027
10311028
Args:
@@ -1047,23 +1044,23 @@ def _next_index(
10471044
if allow_overflow and (position == len(grid) - 1):
10481045
carry = False
10491046
else:
1050-
carry = inc == (np.int32(g) if isinstance(g, int) else g)
1051-
out.append(jax.lax.select(carry, np.int32(0), inc))
1047+
carry = inc == g
1048+
out.append(jax.lax.select(carry, 0, inc))
10521049
if allow_overflow:
10531050
return tuple(reversed(out))
10541051
else:
10551052
return _filter_indices(tuple(reversed(out)), grid)
10561053

10571054

10581055
def _prev_index(
1059-
indices: tuple[int | np.int32 | jax.Array, ...], grid: tuple[int | np.int32 | jax.Array, ...]
1060-
) -> tuple[int | np.int32 | jax.Array, ...]:
1056+
indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...]
1057+
) -> tuple[int | jax.Array, ...]:
10611058
out = []
10621059
borrow: bool | jax.Array = True
10631060
for i, g in reversed(list(zip(indices, grid, strict=True))):
10641061
dec = jax.lax.select(borrow, i - 1, i)
10651062
borrow = dec == -1
1066-
out.append(jax.lax.select(borrow, np.int32(g - 1) if isinstance(g, int) else (g - 1), dec))
1063+
out.append(jax.lax.select(borrow, g - 1, dec))
10671064
return _filter_indices(tuple(reversed(out)), grid)
10681065

10691066

@@ -1073,9 +1070,9 @@ class Scheduler:
10731070
def __init__(
10741071
self,
10751072
step: jax.Array,
1076-
indices: tuple[int | np.int32 | jax.Array, ...],
1077-
grid: tuple[int | np.int32 | jax.Array, ...],
1078-
grid_offsets: tuple[int | np.int32 | jax.Array, ...],
1073+
indices: tuple[int | jax.Array, ...],
1074+
grid: tuple[int | jax.Array, ...],
1075+
grid_offsets: tuple[int | jax.Array, ...],
10791076
num_stages: int,
10801077
trace_scopes=True,
10811078
_explicit_indices: bool = False,
@@ -1102,12 +1099,8 @@ def __init__(
11021099
self.num_steps = math.prod(grid)
11031100

11041101
# First and last inner step conditionals.
1105-
self.first_step = step == np.int32(0)
1106-
self.last_step = step == (
1107-
np.int32(self.num_steps - 1)
1108-
if isinstance(self.num_steps, int)
1109-
else (self.num_steps - 1)
1110-
)
1102+
self.first_step = step == 0
1103+
self.last_step = step == self.num_steps - 1
11111104

11121105
# Derived grid indices for present, previous, and next steps.
11131106
self.indices = tuple(
@@ -1158,9 +1151,7 @@ def out_of_fetch(self, buffered_ref):
11581151
# lookahead this will depend on whether the lookahead reached the end.
11591152
if not buffered_ref.is_buffered:
11601153
return jnp.bool(False)
1161-
ub = self.num_steps - buffered_ref.buffer_count + 1
1162-
ub_32 = np.int32(ub) if isinstance(ub, int) else ub
1163-
return self.step >= ub_32
1154+
return self.step >= (self.num_steps - buffered_ref.buffer_count + 1)
11641155

11651156
def has_changed(self, buffered_ref):
11661157
if not buffered_ref.is_buffered or buffered_ref.is_trivial_windowing:
@@ -1430,13 +1421,13 @@ def make_output_bref(out_spec, out_ref):
14301421

14311422

14321423
def _partition_grid(
1433-
grid: tuple[np.int32 | jax.Array, ...],
1424+
grid: tuple[int | jax.Array, ...],
14341425
core_axis: tuple[int | str, ...] | int | str | None,
14351426
dimension_semantics: tuple[GridDimensionSemantics, ...] | None,
1436-
) -> tuple[tuple[np.int32 | jax.Array, ...], tuple[np.int32 | jax.Array, ...]]:
1427+
) -> tuple[tuple[int | jax.Array, ...], tuple[int | jax.Array, ...]]:
14371428
if core_axis is None:
14381429
# We aren't partitioning the grid
1439-
return grid, (np.int32(0),) * len(grid)
1430+
return grid, (0,) * len(grid)
14401431
if isinstance(core_axis, int):
14411432
num_cores = num_programs(core_axis)
14421433
core_id = program_id(core_axis)
@@ -1450,7 +1441,7 @@ def _partition_grid(
14501441
)
14511442
if num_cores == 1:
14521443
# We aren't partitioning the grid
1453-
return grid, (np.int32(0),) * len(grid)
1444+
return grid, (0,) * len(grid)
14541445

14551446
# If dimension_semantics aren't provided, we assume it is all arbitrary.
14561447
if dimension_semantics is None:
@@ -1485,7 +1476,7 @@ def _partition_grid(
14851476
grid, first_divisible_dimension, partitioned_dim_size
14861477
)
14871478
offsets = jax_util.tuple_update(
1488-
(np.int32(0),) * len(grid),
1479+
(0,) * len(grid),
14891480
first_divisible_dimension,
14901481
partitioned_dim_offset,
14911482
)
@@ -1538,7 +1529,7 @@ def _partition_grid(
15381529
core_id * base_num_iters + rem,
15391530
)
15401531
offsets = jax_util.tuple_update(
1541-
(np.int32(0),) * len(grid),
1532+
(0,) * len(grid),
15421533
partition_dimension,
15431534
grid_offset,
15441535
)
@@ -1620,11 +1611,7 @@ def emit_pipeline(
16201611
if not (core_axis is None or core_axis_name is None):
16211612
raise ValueError("core_axis and core_axis_name cannot both be provided.")
16221613
core_axis_ = core_axis_name if core_axis is None else core_axis
1623-
grid, grid_offsets = _partition_grid(grid, core_axis_, dimension_semantics) # type: ignore
1624-
grid = tuple(np.int32(g) if isinstance(g, int) else g for g in grid) # type: ignore
1625-
grid_offsets = tuple(
1626-
np.int32(g) if isinstance(g, int) else g for g in grid_offsets
1627-
)
1614+
grid, grid_offsets = _partition_grid(grid, core_axis_, dimension_semantics)
16281615

16291616
num_steps = math.prod(grid)
16301617
in_specs = _normalize_specs(in_specs)
@@ -1717,15 +1704,13 @@ def loop_body(step, carry):
17171704

17181705
if no_pipelining:
17191706
# Debugging mode where all copies are synchronous.
1720-
lower_bnd = np.int32(0)
1721-
upper_bnd = np.int32(num_steps) if isinstance(num_steps, int) else num_steps
1722-
initial_indices = (np.int32(0),) * len(grid)
1707+
initial_indices = (0,) * len(grid)
17231708
brefs = map_brefs(lambda bref: bref.initialize_slots(), allocations)
17241709

17251710
@functools.partial(
17261711
jax.lax.fori_loop,
1727-
lower_bnd,
1728-
upper_bnd,
1712+
0,
1713+
num_steps,
17291714
init_val=(brefs, initial_indices),
17301715
)
17311716
def _loop_body(step, carry):
@@ -1756,9 +1741,7 @@ def _loop_body(step, carry):
17561741
@when(num_steps > 0)
17571742
def _():
17581743
# pipeline prologue
1759-
lower_bnd = np.int32(0)
1760-
upper_bnd = np.int32(num_steps) if isinstance(num_steps, int) else num_steps
1761-
initial_indices = (np.int32(0),) * len(grid)
1744+
initial_indices = (0,) * len(grid)
17621745
scheduler = make_scheduler(0, initial_indices)
17631746
brefs = map_brefs(lambda bref: bref.initialize_slots(), allocations)
17641747
def _sync_copy_in(bref, ref):
@@ -1777,8 +1760,7 @@ def _sync_copy_in(bref, ref):
17771760

17781761
# pipeline loop
17791762
brefs, next_indices = lax.fori_loop(
1780-
lower_bnd, upper_bnd,
1781-
loop_body, (brefs, initial_indices)
1763+
0, num_steps, loop_body, (brefs, initial_indices)
17821764
)
17831765

17841766
# pipeline epilogue

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2511,66 +2511,6 @@ def pipeline(x_ref, o_ref):
25112511

25122512
np.testing.assert_array_equal(kernel(x), x + 1)
25132513

2514-
def test_basic_x64(self):
2515-
self.skip_if_tc_tiling()
2516-
self.enter_context(jax.enable_x64(True))
2517-
num_steps = 16
2518-
x = jnp.arange(num_steps * self.num_lanes, dtype=jnp.int32).reshape(-1, 8)
2519-
2520-
@self.vector_subcore_kernel(
2521-
out_shape=x,
2522-
in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),),
2523-
out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
2524-
)
2525-
def kernel(x_hbm_ref, o_hbm_ref):
2526-
2527-
@functools.partial(
2528-
pltpu.emit_pipeline,
2529-
grid=(num_steps,),
2530-
in_specs=pl.BlockSpec(
2531-
(pl.Squeezed(), self.num_lanes), lambda i: (i, 0)
2532-
),
2533-
out_specs=pl.BlockSpec(
2534-
(pl.Squeezed(), self.num_lanes), lambda i: (i, 0)
2535-
),
2536-
)
2537-
def pipeline(x_ref, o_ref):
2538-
o_ref[...] = x_ref[...] + 1
2539-
2540-
pipeline(x_hbm_ref, o_hbm_ref)
2541-
2542-
np.testing.assert_array_equal(kernel(x), x + 1)
2543-
2544-
def test_pipeline_disable_jit(self):
2545-
self.skip_if_tc_tiling()
2546-
num_steps = 16
2547-
x = jnp.arange(num_steps * self.num_lanes, dtype=jnp.int32).reshape(-1, 8)
2548-
2549-
@self.vector_subcore_kernel(
2550-
out_shape=x,
2551-
in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),),
2552-
out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
2553-
)
2554-
def kernel(x_hbm_ref, o_hbm_ref):
2555-
2556-
@functools.partial(
2557-
pltpu.emit_pipeline,
2558-
grid=(num_steps,),
2559-
in_specs=pl.BlockSpec(
2560-
(pl.Squeezed(), self.num_lanes), lambda i: (i, 0)
2561-
),
2562-
out_specs=pl.BlockSpec(
2563-
(pl.Squeezed(), self.num_lanes), lambda i: (i, 0)
2564-
),
2565-
)
2566-
def pipeline(x_ref, o_ref):
2567-
o_ref[...] = x_ref[...] + 1
2568-
2569-
pipeline(x_hbm_ref, o_hbm_ref)
2570-
2571-
with jax.disable_jit():
2572-
kernel(x)
2573-
25742514
def test_gather_with_emit(self):
25752515
self.skip_if_tc_tiling()
25762516
sc_mesh = sc_core.VectorSubcoreMesh(

0 commit comments

Comments
 (0)