@@ -130,79 +130,6 @@ def __init__(self, vllm_config: VllmConfig, rng: PRNGKey, mesh: Mesh):
130130 MultiHeadLatentAttentionWrapper .register_oot (
131131 VllmTPUMultiHeadLatentAttentionWrapper )
132132
def _patch_sdpa(self):
    """Override ``torch.nn.functional.scaled_dot_product_attention``.

    Registers a JAX implementation (via torchax's ``register_function``)
    that builds an additive attention bias from ``attn_mask`` and runs
    ``sharded_flash_attention`` on ``self.mesh``.

    The registered function accepts the same arguments as the torch
    original, with these restrictions:

    * ``dropout_p`` must be ``0.0`` (``NotImplementedError`` otherwise).
    * ``enable_gqa`` must be ``False`` (``NotImplementedError`` otherwise).
    """
    from torchax.ops.jtorch import register_function

    from tpu_inference.layers.common.attention_interface import \
        sharded_flash_attention

    @register_function(
        torch.nn.functional.scaled_dot_product_attention,
        is_jax_function=True,
        needs_env=False,
    )
    def patched_sdpa(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        if dropout_p != 0.0:
            raise NotImplementedError(
                "patched_sdpa does not support dropout_p")
        if enable_gqa is not False:
            raise NotImplementedError(
                "patched_sdpa does not support enable_gqa")

        # Q, K, V shapes: (batch, num_heads, seq_len, head_dim)
        batch = query.shape[0]
        num_heads = query.shape[1]
        q_seq_len = query.shape[2]
        kv_seq_len = key.shape[2]

        # torch documents scale=None as 1/sqrt(head_dim). Resolve it here
        # instead of forwarding None, so the result does not depend on how
        # sharded_flash_attention treats a missing sm_scale.
        if scale is None:
            scale = float(query.shape[-1])**-0.5

        # Sequence lengths are padded up to a multiple of 128 due to the
        # requirement of sharded_flash_attention.
        seq_align = 128
        q_pad = (seq_align - (q_seq_len % seq_align)) % seq_align
        kv_pad = (seq_align - (kv_seq_len % seq_align)) % seq_align

        if q_pad > 0:
            query = jnp.pad(query, ((0, 0), (0, 0), (0, q_pad), (0, 0)))
        if kv_pad > 0:
            key = jnp.pad(key, ((0, 0), (0, 0), (0, kv_pad), (0, 0)))
            value = jnp.pad(value, ((0, 0), (0, 0), (0, kv_pad), (0, 0)))

        # Use a large finite negative instead of -inf to prevent NaN.
        mask_value = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)

        # Additive attention bias; starts at zero (nothing masked).
        ab = jnp.zeros((batch, num_heads, q_seq_len, kv_seq_len),
                       dtype=jnp.float32)
        if attn_mask is not None:
            # attn_mask shape: (batch, num_heads, q_len, kv_len)
            if attn_mask.dtype == jnp.bool_:
                # Boolean mask: True keeps the position, False masks it.
                ab = jnp.where(attn_mask, ab, mask_value)
            else:
                # Non-boolean mask is added to the scores as-is.
                ab += attn_mask

        if q_pad > 0 or kv_pad > 0:
            # Padded query rows / key columns are filled with mask_value
            # so the padding keys are masked out of the softmax.
            ab = jnp.pad(
                ab,
                ((0, 0), (0, 0), (0, q_pad), (0, kv_pad)),
                mode="constant",
                constant_values=mask_value,
            )

        attn_fn = sharded_flash_attention(self.mesh,
                                          causal=is_causal,
                                          sm_scale=scale,
                                          use_attention_bias=True)
        out = attn_fn(query, key, value, ab, None)

        # Drop the rows that were only added as query padding.
        if q_pad > 0:
            out = out[:, :, :q_seq_len, :]

        return out
206133
207134 def _patch_vllm_ops (self ):
208135 # Caution: there is no public api for restore the ops.
0 commit comments