3939from jax ._src .pallas .mosaic import primitives as tpu_primitives
4040from jax ._src .pallas .mosaic import tpu_info
4141from jax ._src .state import indexing
42- import numpy as np
4342import jax .numpy as jnp
4443
4544
@@ -693,7 +692,7 @@ def cumulative_copy_in(self):
693692 @property
694693 def current_copy_in_slot (self ):
695694 """Index in multiple buffer corresponding to the current slot."""
696- return lax .rem (self .cumulative_copy_in , np .uint32 (self .buffer_count ))
695+ return lax .rem (self .cumulative_copy_in , jnp .uint32 (self .buffer_count ))
697696
698697 @property
699698 def cumulative_copy_out (self ):
@@ -704,7 +703,7 @@ def cumulative_copy_out(self):
704703 @property
705704 def current_copy_out_slot (self ):
706705 """Index in multiple buffer corresponding to the current copy slot."""
707- return lax .rem (self .cumulative_copy_out , np .uint32 (self .buffer_count ))
706+ return lax .rem (self .cumulative_copy_out , jnp .uint32 (self .buffer_count ))
708707
709708 @property
710709 def cumulative_wait_in (self ):
@@ -715,7 +714,7 @@ def cumulative_wait_in(self):
715714 @property
716715 def current_wait_in_slot (self ):
717716 """Index in multiple buffer corresponding to the current wait slot."""
718- return lax .rem (self .cumulative_wait_in , np .uint32 (self .buffer_count ))
717+ return lax .rem (self .cumulative_wait_in , jnp .uint32 (self .buffer_count ))
719718
720719 @property
721720 def cumulative_wait_out (self ):
@@ -726,7 +725,7 @@ def cumulative_wait_out(self):
726725 @property
727726 def current_wait_out_slot (self ):
728727 """Index in multiple buffer corresponding to the current wait slot."""
729- return lax .rem (self .cumulative_wait_out , np .uint32 (self .buffer_count ))
728+ return lax .rem (self .cumulative_wait_out , jnp .uint32 (self .buffer_count ))
730729
731730 @property
732731 def next_fetch_indices (self ):
@@ -781,12 +780,12 @@ def compute_slice(self, grid_indices):
781780 def initialize_slots (self ) -> BufferedRef :
782781 return dataclasses .replace (
783782 self ,
784- copy_in_slot = np .uint32 (0 ) if self .buffer_type .is_input else None ,
785- wait_in_slot = np .uint32 (0 ) if self .buffer_type .is_input else None ,
786- copy_out_slot = np .uint32 (0 ) if self .buffer_type .is_output else None ,
787- wait_out_slot = np .uint32 (0 ) if self .buffer_type .is_output else None ,
783+ copy_in_slot = jnp .uint32 (0 ) if self .buffer_type .is_input else None ,
784+ wait_in_slot = jnp .uint32 (0 ) if self .buffer_type .is_input else None ,
785+ copy_out_slot = jnp .uint32 (0 ) if self .buffer_type .is_output else None ,
786+ wait_out_slot = jnp .uint32 (0 ) if self .buffer_type .is_output else None ,
788787 next_fetch = (
789- tuple (np .int32 (0 ) for _ in range (self ._grid_rank ))
788+ tuple (jnp .int32 (0 ) for _ in range (self ._grid_rank ))
790789 if self ._grid_rank is not None
791790 else None
792791 ),
@@ -1012,20 +1011,18 @@ def fmap(bref, *f_args):
10121011
10131012
10141013def _filter_indices (
1015- indices : tuple [int | np .int32 | jax .Array , ...],
1016- grid : tuple [int | np .int32 | jax .Array , ...]
1017- ) -> tuple [int | np .int32 | jax .Array , ...]:
1014+ indices : tuple [int | jax .Array , ...], grid : tuple [int | jax .Array , ...]
1015+ ) -> tuple [int | jax .Array , ...]:
10181016 return tuple (
1019- np . int32 ( 0 ) if isinstance (g , int ) and g == 1 else i
1017+ 0 if isinstance (g , int ) and g == 1 else i
10201018 for i , g in zip (indices , grid , strict = True )
10211019 )
10221020
10231021
10241022def _next_index (
1025- indices : tuple [int | np .int32 | jax .Array , ...],
1026- grid : tuple [int | np .int32 | jax .Array , ...],
1023+ indices : tuple [int | jax .Array , ...], grid : tuple [int | jax .Array , ...],
10271024 allow_overflow : bool = False ,
1028- ) -> tuple [int | np . int32 | jax .Array , ...]:
1025+ ) -> tuple [int | jax .Array , ...]:
10291026 """Increments the grid indices by one.
10301027
10311028 Args:
@@ -1047,23 +1044,23 @@ def _next_index(
10471044 if allow_overflow and (position == len (grid ) - 1 ):
10481045 carry = False
10491046 else :
1050- carry = inc == ( np . int32 ( g ) if isinstance ( g , int ) else g )
1051- out .append (jax .lax .select (carry , np . int32 ( 0 ) , inc ))
1047+ carry = inc == g
1048+ out .append (jax .lax .select (carry , 0 , inc ))
10521049 if allow_overflow :
10531050 return tuple (reversed (out ))
10541051 else :
10551052 return _filter_indices (tuple (reversed (out )), grid )
10561053
10571054
10581055def _prev_index (
1059- indices : tuple [int | np . int32 | jax .Array , ...], grid : tuple [int | np . int32 | jax .Array , ...]
1060- ) -> tuple [int | np . int32 | jax .Array , ...]:
1056+ indices : tuple [int | jax .Array , ...], grid : tuple [int | jax .Array , ...]
1057+ ) -> tuple [int | jax .Array , ...]:
10611058 out = []
10621059 borrow : bool | jax .Array = True
10631060 for i , g in reversed (list (zip (indices , grid , strict = True ))):
10641061 dec = jax .lax .select (borrow , i - 1 , i )
10651062 borrow = dec == - 1
1066- out .append (jax .lax .select (borrow , np . int32 ( g - 1 ) if isinstance ( g , int ) else ( g - 1 ) , dec ))
1063+ out .append (jax .lax .select (borrow , g - 1 , dec ))
10671064 return _filter_indices (tuple (reversed (out )), grid )
10681065
10691066
@@ -1073,9 +1070,9 @@ class Scheduler:
10731070 def __init__ (
10741071 self ,
10751072 step : jax .Array ,
1076- indices : tuple [int | np . int32 | jax .Array , ...],
1077- grid : tuple [int | np . int32 | jax .Array , ...],
1078- grid_offsets : tuple [int | np . int32 | jax .Array , ...],
1073+ indices : tuple [int | jax .Array , ...],
1074+ grid : tuple [int | jax .Array , ...],
1075+ grid_offsets : tuple [int | jax .Array , ...],
10791076 num_stages : int ,
10801077 trace_scopes = True ,
10811078 _explicit_indices : bool = False ,
@@ -1102,12 +1099,8 @@ def __init__(
11021099 self .num_steps = math .prod (grid )
11031100
11041101 # First and last inner step conditionals.
1105- self .first_step = step == np .int32 (0 )
1106- self .last_step = step == (
1107- np .int32 (self .num_steps - 1 )
1108- if isinstance (self .num_steps , int )
1109- else (self .num_steps - 1 )
1110- )
1102+ self .first_step = step == 0
1103+ self .last_step = step == self .num_steps - 1
11111104
11121105 # Derived grid indices for present, previous, and next steps.
11131106 self .indices = tuple (
@@ -1158,9 +1151,7 @@ def out_of_fetch(self, buffered_ref):
11581151 # lookahead this will depend on whether the lookahead reached the end.
11591152 if not buffered_ref .is_buffered :
11601153 return jnp .bool (False )
1161- ub = self .num_steps - buffered_ref .buffer_count + 1
1162- ub_32 = np .int32 (ub ) if isinstance (ub , int ) else ub
1163- return self .step >= ub_32
1154+ return self .step >= (self .num_steps - buffered_ref .buffer_count + 1 )
11641155
11651156 def has_changed (self , buffered_ref ):
11661157 if not buffered_ref .is_buffered or buffered_ref .is_trivial_windowing :
@@ -1430,13 +1421,13 @@ def make_output_bref(out_spec, out_ref):
14301421
14311422
14321423def _partition_grid (
1433- grid : tuple [np . int32 | jax .Array , ...],
1424+ grid : tuple [int | jax .Array , ...],
14341425 core_axis : tuple [int | str , ...] | int | str | None ,
14351426 dimension_semantics : tuple [GridDimensionSemantics , ...] | None ,
1436- ) -> tuple [tuple [np . int32 | jax .Array , ...], tuple [np . int32 | jax .Array , ...]]:
1427+ ) -> tuple [tuple [int | jax .Array , ...], tuple [int | jax .Array , ...]]:
14371428 if core_axis is None :
14381429 # We aren't partitioning the grid
1439- return grid , (np . int32 ( 0 ) ,) * len (grid )
1430+ return grid , (0 ,) * len (grid )
14401431 if isinstance (core_axis , int ):
14411432 num_cores = num_programs (core_axis )
14421433 core_id = program_id (core_axis )
@@ -1450,7 +1441,7 @@ def _partition_grid(
14501441 )
14511442 if num_cores == 1 :
14521443 # We aren't partitioning the grid
1453- return grid , (np . int32 ( 0 ) ,) * len (grid )
1444+ return grid , (0 ,) * len (grid )
14541445
14551446 # If dimension_semantics aren't provided, we assume it is all arbitrary.
14561447 if dimension_semantics is None :
@@ -1485,7 +1476,7 @@ def _partition_grid(
14851476 grid , first_divisible_dimension , partitioned_dim_size
14861477 )
14871478 offsets = jax_util .tuple_update (
1488- (np . int32 ( 0 ) ,) * len (grid ),
1479+ (0 ,) * len (grid ),
14891480 first_divisible_dimension ,
14901481 partitioned_dim_offset ,
14911482 )
@@ -1538,7 +1529,7 @@ def _partition_grid(
15381529 core_id * base_num_iters + rem ,
15391530 )
15401531 offsets = jax_util .tuple_update (
1541- (np . int32 ( 0 ) ,) * len (grid ),
1532+ (0 ,) * len (grid ),
15421533 partition_dimension ,
15431534 grid_offset ,
15441535 )
@@ -1620,11 +1611,7 @@ def emit_pipeline(
16201611 if not (core_axis is None or core_axis_name is None ):
16211612 raise ValueError ("core_axis and core_axis_name cannot both be provided." )
16221613 core_axis_ = core_axis_name if core_axis is None else core_axis
1623- grid , grid_offsets = _partition_grid (grid , core_axis_ , dimension_semantics ) # type: ignore
1624- grid = tuple (np .int32 (g ) if isinstance (g , int ) else g for g in grid ) # type: ignore
1625- grid_offsets = tuple (
1626- np .int32 (g ) if isinstance (g , int ) else g for g in grid_offsets
1627- )
1614+ grid , grid_offsets = _partition_grid (grid , core_axis_ , dimension_semantics )
16281615
16291616 num_steps = math .prod (grid )
16301617 in_specs = _normalize_specs (in_specs )
@@ -1717,15 +1704,13 @@ def loop_body(step, carry):
17171704
17181705 if no_pipelining :
17191706 # Debugging mode where all copies are synchronous.
1720- lower_bnd = np .int32 (0 )
1721- upper_bnd = np .int32 (num_steps ) if isinstance (num_steps , int ) else num_steps
1722- initial_indices = (np .int32 (0 ),) * len (grid )
1707+ initial_indices = (0 ,) * len (grid )
17231708 brefs = map_brefs (lambda bref : bref .initialize_slots (), allocations )
17241709
17251710 @functools .partial (
17261711 jax .lax .fori_loop ,
1727- lower_bnd ,
1728- upper_bnd ,
1712+ 0 ,
1713+ num_steps ,
17291714 init_val = (brefs , initial_indices ),
17301715 )
17311716 def _loop_body (step , carry ):
@@ -1756,9 +1741,7 @@ def _loop_body(step, carry):
17561741 @when (num_steps > 0 )
17571742 def _ ():
17581743 # pipeline prologue
1759- lower_bnd = np .int32 (0 )
1760- upper_bnd = np .int32 (num_steps ) if isinstance (num_steps , int ) else num_steps
1761- initial_indices = (np .int32 (0 ),) * len (grid )
1744+ initial_indices = (0 ,) * len (grid )
17621745 scheduler = make_scheduler (0 , initial_indices )
17631746 brefs = map_brefs (lambda bref : bref .initialize_slots (), allocations )
17641747 def _sync_copy_in (bref , ref ):
@@ -1777,8 +1760,7 @@ def _sync_copy_in(bref, ref):
17771760
17781761 # pipeline loop
17791762 brefs , next_indices = lax .fori_loop (
1780- lower_bnd , upper_bnd ,
1781- loop_body , (brefs , initial_indices )
1763+ 0 , num_steps , loop_body , (brefs , initial_indices )
17821764 )
17831765
17841766 # pipeline epilogue
0 commit comments