# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+# pytype: disable=module-attr
"""Attentions Ops Layers."""
-
+import dataclasses
import functools
from typing import Any, Callable, Optional, Tuple
from functools import partial
import numpy as np
from packaging import version

+import jax
from jax import lax
from jax.ad_checkpoint import checkpoint_name
from jax.experimental.pallas.ops.gpu import attention as gpu_pallas_attention
from jax.experimental.pallas.ops.gpu import decode_attention as gpu_pallas_decode_attention
-from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
-from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
+from jax.experimental import pallas as pl
from jax.sharding import Mesh
-import jax
import jax.numpy as jnp

+if version.parse(jax.__version__) < version.parse("0.8.0"):
+  from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
+  from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
+else:
+  from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel
+  from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask
+
+
from flax import linen as nn
from flax import nnx
from flax.linen import partitioning

+
from MaxText import max_utils
from MaxText.common_types import (
    DEFAULT_MASK_VALUE,
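
A minimal illustration (not part of the diff) of why the version gates here are written with
packaging.version, which this file already imports, rather than raw string comparison:
lexicographic comparison mis-orders multi-digit release components.

    from packaging import version

    print("0.10.0" < "0.8.0")                                # True  (string compare, wrong order)
    print(version.parse("0.10.0") < version.parse("0.8.0"))  # False (parsed compare, correct order)
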
@@ -1080,22 +1088,58 @@ def tpu_flash_attention(
10801088 f" got { query .shape [0 ]= } /{ devices_in_data_fsdp = } "
10811089 )
10821090
1083- # create_splash_attention kernel
1084- block_sizes = splash_attention_kernel .BlockSizes (
1085- block_q = min (global_block_q , query .shape [2 ]),
1086- block_kv = min (global_block_kv , key .shape [2 ]),
1087- block_kv_compute = min (global_block_kv_compute , key .shape [2 ]),
1088- block_q_dkv = min (global_block_q_dkv , query .shape [2 ]),
1089- block_kv_dkv = min (global_block_kv_dkv , key .shape [2 ]),
1090- block_kv_dkv_compute = min (global_block_kv_dkv_compute , query .shape [2 ]),
1091- block_q_dq = None if global_use_fused_bwd_kernel else min (global_block_q_dq , query .shape [2 ]),
1092- block_kv_dq = None if global_use_fused_bwd_kernel else min (global_block_kv_dq , query .shape [2 ]),
1093- use_fused_bwd_kernel = global_use_fused_bwd_kernel ,
1094- q_layout = splash_attention_kernel .QKVLayout [global_q_layout ],
1095- k_layout = splash_attention_kernel .QKVLayout [global_k_layout ],
1096- v_layout = splash_attention_kernel .QKVLayout [global_v_layout ],
1097- )
+    # create_splash_attention config
+    def create_sa_config(config, query, key, attn_logits_soft_cap):
+      if version.parse(jax.__version__) >= version.parse("0.8.0"):
+        sa_config = splash_attention_kernel.SplashConfig(
+            block_q=min(global_block_q, query.shape[2]),
+            block_kv=min(global_block_kv, key.shape[2]),
+            block_kv_compute=min(global_block_kv_compute, key.shape[2]),
+            block_q_dkv=min(global_block_q_dkv, query.shape[2]),
+            block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
+            block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
+            block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
+            block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
+            use_fused_bwd_kernel=True,  # tokamax only supports fused bwd kernel
+            q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
+            k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
+            v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
+            attn_logits_soft_cap=attn_logits_soft_cap,
+            residual_checkpoint_name="context",
+            fwd_cost_estimate=pl.CostEstimate(
+                flops=config.cost_estimate_flops_fwd,
+                transcendentals=0,
+                bytes_accessed=0,
+            )
+            if config.cost_estimate_flops_fwd >= 0
+            else None,
+            bwd_cost_estimate=pl.CostEstimate(
+                flops=config.cost_estimate_flops_bwd,
+                transcendentals=0,
+                bytes_accessed=0,
+            )
+            if config.cost_estimate_flops_bwd >= 0
+            else None,
+            dq_reduction_steps=config.dq_reduction_steps if config.dq_reduction_steps > 0 else None,
+        )
+      else:
+        sa_config = splash_attention_kernel.BlockSizes(
+            block_q=min(global_block_q, query.shape[2]),
+            block_kv=min(global_block_kv, key.shape[2]),
+            block_kv_compute=min(global_block_kv_compute, key.shape[2]),
+            block_q_dkv=min(global_block_q_dkv, query.shape[2]),
+            block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
+            block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
+            block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
+            block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
+            use_fused_bwd_kernel=global_use_fused_bwd_kernel,
+            q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
+            k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
+            v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
+        )
+      return sa_config

+    sa_config = create_sa_config(self.config, query, key, attn_logits_soft_cap)
    mask_shape = (query.shape[2], key.shape[2])  # (q_seq_len, kv_seq_len)
    if self.attention_type == AttentionType.FULL:
      mask = splash_attention_mask.FullMask(mask_shape)
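
A minimal numeric illustration (not part of the diff, hypothetical sizes) of the block-size
clamping used by both config paths above: each configured block is capped at the actual
sequence length, so short sequences never request a block larger than themselves.

    global_block_q, global_block_kv = 512, 1024   # hypothetical tuning values
    q_seq_len, kv_seq_len = 128, 256              # stand-ins for query.shape[2] and key.shape[2]

    block_q = min(global_block_q, q_seq_len)      # -> 128
    block_kv = min(global_block_kv, kv_seq_len)   # -> 256
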
@@ -1122,35 +1166,68 @@ def tpu_flash_attention(

      mask &= ChunkedCausalMask(shape=(query.shape[2], key.shape[2]), chunk_size=self.chunk_attn_window_size)

-    # Create multi-head mask
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    max_logit_value = None
+    if version.parse(jax.__version__) >= version.parse("0.8.0"):
+      # Create mask
+      single_head_mask = mask  # tokamax now uses a single mask and broadcasts it to all heads
+      if self.config.use_max_logit_estimate > 0:
+        sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)
+
+      # Create the splash attention kernel object separately, jit it for performance
+      @partial(
+          jax.jit,
+          static_argnames=[
+              "single_head_mask",
+              "shard_head_size",
+          ],
+      )
+      def wrap_splash_kernel(single_head_mask, shard_head_size=1):
+        splash_kernel = splash_attention_kernel.make_splash_mha(
+            mask=single_head_mask,
+            config=sa_config,
+            q_seq_shards=cp_size,  # axis for sequence sharding
+        )
+        return splash_kernel

-    # Create the splash attention kernel object separately, jit it for performance
-    @partial(
-        jax.jit,
-        static_argnames=[
-            "multi_head_mask",
-            "shard_head_size",
-        ],
-    )
-    def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
-      splash_kernel = splash_attention_kernel.make_splash_mha(
-          mask=multi_head_mask,
-          head_shards=shard_head_size,  # the size of the axis if sharding over heads
-          q_seq_shards=cp_size,  # axis for sequence sharding
-          block_sizes=block_sizes,
-          attn_logits_soft_cap=attn_logits_soft_cap,
-          residual_checkpoint_name="context",
+      logical_axis_rules_head = np.array(
+          [self.mesh.shape[physical_axes] for physical_axes in dict(self.config.logical_axis_rules)[HEAD]]
+      )
+      shard_head_size = np.prod(logical_axis_rules_head)
+      splash_kernel = wrap_splash_kernel(single_head_mask, int(shard_head_size))
+      if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
+        segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,))
+      else:
+        segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
+    else:
+      # Create multi-head mask
+      multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+      # Create the splash attention kernel object separately, jit it for performance
+      @partial(
+          jax.jit,
+          static_argnames=[
+              "multi_head_mask",
+              "shard_head_size",
+          ],
      )
-      return splash_kernel
+      def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
+        splash_kernel = splash_attention_kernel.make_splash_mha(
+            mask=multi_head_mask,
+            head_shards=shard_head_size,  # the size of the axis if sharding over heads
+            q_seq_shards=cp_size,  # axis for sequence sharding
+            block_sizes=sa_config,
+            attn_logits_soft_cap=attn_logits_soft_cap,
+            residual_checkpoint_name="context",
+        )
+        return splash_kernel

-    logical_axis_rules_head = np.array(
-        [self.mesh.shape[physical_axes] for physical_axes in dict(self.config.logical_axis_rules)[HEAD]]
-    )
-    shard_head_size = np.prod(logical_axis_rules_head)
-    splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
-    named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel)
-    segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+      logical_axis_rules_head = np.array(
+          [self.mesh.shape[physical_axes] for physical_axes in dict(self.config.logical_axis_rules)[HEAD]]
+      )
+      shard_head_size = np.prod(logical_axis_rules_head)
+      splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
+      named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel)
+      segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)

    # Now call the function wrap_flash_attention which does the actual computation.
    # The splash kernel is passed as a parameter to the function. Since we have the shard map
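
A small self-contained sketch (hypothetical config class, not the real SplashConfig) of the
dataclasses.replace pattern used above for max_logit_const: it returns a new instance with one
field overridden and every other field unchanged.

    import dataclasses
    from typing import Optional

    @dataclasses.dataclass(frozen=True)
    class FakeConfig:
      block_q: int = 512
      max_logit_const: Optional[float] = None

    cfg = FakeConfig()
    cfg = dataclasses.replace(cfg, max_logit_const=30.0)
    print(cfg.block_q, cfg.max_logit_const)   # 512 30.0
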
@@ -1214,9 +1291,17 @@ def wrap_flash_attention(
      if version.parse(jax.__version__) < version.parse("0.7.2.dev20250824"):
        attention_output = jax.vmap(splash_kernel)(query, key, value, decoder_segment_ids_tuple)
      else:
-        attention_output = jax.vmap(splash_kernel, in_axes=(0, 0, 0, 0, None))(
-            query, key, value, decoder_segment_ids_tuple, sinks
-        )
+        if version.parse(jax.__version__) >= version.parse("0.8.0"):
+          if max_logit_value is not None:
+            attention_output = jax.vmap(partial(splash_kernel, max_logit_value=max_logit_value))(
+                query, key, value, decoder_segment_ids_tuple
+            )
+          else:
+            attention_output = jax.vmap(splash_kernel)(query, key, value, decoder_segment_ids_tuple)
+        else:
+          attention_output = jax.vmap(splash_kernel, in_axes=(0, 0, 0, 0, None))(
+              query, key, value, decoder_segment_ids_tuple, sinks
+          )
      return attention_output

    x = wrap_flash_attention(
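
A minimal illustration (not part of the diff) of the call pattern in the hunk above: jax.vmap
maps the kernel over the leading batch axis of its positional arguments, while a keyword bound
with functools.partial (like max_logit_value) is shared across the batch. The kernel below is a
hypothetical stand-in, not the splash kernel.

    import functools
    import jax
    import jax.numpy as jnp

    def toy_kernel(q, k, scale=1.0):
      # One per-example problem; q and k here are (seq, head_dim)-shaped stand-ins.
      return (q * k).sum() * scale

    q = jnp.ones((4, 16, 8))   # leading axis 4 is the batch that vmap maps over
    k = jnp.ones((4, 16, 8))
    out = jax.vmap(functools.partial(toy_kernel, scale=0.5))(q, k)   # shape (4,)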