Skip to content

Commit be6178e

Browse files
Fix error in merge conflict
1 parent 6dd4bfb commit be6178e

File tree

1 file changed

+0
-73
lines changed

1 file changed

+0
-73
lines changed

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -130,79 +130,6 @@ def __init__(self, vllm_config: VllmConfig, rng: PRNGKey, mesh: Mesh):
130130
MultiHeadLatentAttentionWrapper.register_oot(
131131
VllmTPUMultiHeadLatentAttentionWrapper)
132132

def _patch_sdpa(self):
    """Route torch SDPA calls to the sharded JAX flash-attention kernel.

    Registers (via torchax's ``register_function``) a JAX replacement for
    ``torch.nn.functional.scaled_dot_product_attention`` that calls
    ``sharded_flash_attention`` on ``self.mesh``. Dropout and
    ``enable_gqa`` are rejected with ``NotImplementedError``.
    """
    # Imported lazily so the patch is only pulled in when applied.
    from torchax.ops.jtorch import register_function

    from tpu_inference.layers.common.attention_interface import \
        sharded_flash_attention

    @register_function(
        torch.nn.functional.scaled_dot_product_attention,
        is_jax_function=True,
        needs_env=False,
    )
    def patched_sdpa(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        # This replacement implements neither dropout nor grouped-query
        # attention; fail loudly rather than silently ignore the flags.
        if dropout_p != 0.0:
            raise NotImplementedError(
                "patched_sdpa does not support dropout_p")
        if enable_gqa is not False:
            raise NotImplementedError(
                "patched_sdpa does not support enable_gqa")

        # Q, K, V shapes: (batch, num_heads, seq_len, head_dim)
        batch = query.shape[0]
        num_heads = query.shape[1]
        q_seq_len = query.shape[2]
        kv_seq_len = key.shape[2]

        # padding due to the requirement of sharded_flash_attention
        # (sequence lengths are rounded up to a multiple of 128)
        q_pad = (128 - (q_seq_len % 128)) % 128
        kv_pad = (128 - (kv_seq_len % 128)) % 128

        if q_pad > 0:
            query = jnp.pad(query, ((0, 0), (0, 0), (0, q_pad), (0, 0)))
        if kv_pad > 0:
            key = jnp.pad(key, ((0, 0), (0, 0), (0, kv_pad), (0, 0)))
            value = jnp.pad(value, ((0, 0), (0, 0), (0, kv_pad), (0, 0)))

        # Prevent nan while using -inf: a large-but-finite negative bias
        # instead of -inf keeps the softmax numerically well-defined.
        mask_value = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
        # Additive attention bias, built at the unpadded sizes first.
        ab = jnp.zeros((batch, num_heads, q_seq_len, kv_seq_len),
                       dtype=jnp.float32)
        if attn_mask is not None:
            # attn_mask shape: (batch, num_heads, q_len, kv_len)
            if attn_mask.dtype == jnp.bool_:
                # Boolean mask: False positions are masked out.
                ab = jnp.where(attn_mask, ab, mask_value)
            else:
                # Float mask: added to the logits as-is (torch semantics).
                ab += attn_mask

        if q_pad > 0 or kv_pad > 0:
            # Padded query/key positions get the large-negative bias so
            # they contribute ~zero attention weight.
            ab = jnp.pad(
                ab,
                ((0, 0), (0, 0), (0, q_pad), (0, kv_pad)),
                mode="constant",
                constant_values=mask_value,
            )

        # NOTE(review): the trailing None argument's meaning depends on
        # sharded_flash_attention's signature — not visible here; confirm.
        attn_fn = sharded_flash_attention(self.mesh,
                                          causal=is_causal,
                                          sm_scale=scale,
                                          use_attention_bias=True)
        out = attn_fn(query, key, value, ab, None)

        # Trim query padding back off; kv padding does not appear in the
        # output's sequence dimension.
        if q_pad > 0:
            out = out[:, :, :q_seq_len, :]

        return out
206133

207134
def _patch_vllm_ops(self):
208135
# Caution: there is no public api for restore the ops.

0 commit comments

Comments
 (0)