Skip to content

Commit add6be0

Browse files
committed
Rename kv_model_dim to kv_hidden_size
1 parent ddbe805 commit add6be0

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def ref_ragged_paged_attention(
8383
check_inputs_shapes(queries, k_pages, v_pages, kv_lens, page_indices,
8484
cu_q_lens, num_seqs)
8585
_, num_q_heads, head_dim = queries.shape
86-
_, _, kv_model_dim = k_pages.shape
87-
num_kv_heads = kv_model_dim // head_dim
86+
_, _, kv_hidden_size = k_pages.shape
87+
num_kv_heads = kv_hidden_size // head_dim
8888
assert num_q_heads % num_kv_heads == 0
8989
num_query_per_kv = num_q_heads // num_kv_heads
9090
outputs = []
@@ -163,13 +163,13 @@ def check_inputs_shapes(
163163
if k_pages.shape != v_pages.shape:
164164
raise ValueError(
165165
f"Expected {k_pages.shape=} to be equal to {v_pages.shape=}.")
166-
_, page_size, kv_model_dim = k_pages.shape
166+
_, page_size, kv_hidden_size = k_pages.shape
167167
kv_packing = get_dtype_packing(k_pages.dtype)
168168
if page_size % kv_packing != 0:
169169
raise ValueError(f"Expected {page_size=} is divisible by {kv_packing=}")
170-
if kv_model_dim % head_dim != 0:
171-
raise ValueError(f"Expected {kv_model_dim=} is divisible by {head_dim=}.")
172-
num_kv_heads = kv_model_dim // head_dim
170+
if kv_hidden_size % head_dim != 0:
171+
raise ValueError(f"Expected {kv_hidden_size=} is divisible by {head_dim=}.")
172+
num_kv_heads = kv_hidden_size // head_dim
173173
if num_q_heads % num_kv_heads != 0:
174174
raise ValueError(f"Expected {num_q_heads=} is divisible by {num_kv_heads=}")
175175
max_num_seqs, _ = page_indices.shape
@@ -596,8 +596,8 @@ def ragged_paged_attention(
596596
check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens,
597597
num_seqs)
598598
num_q, num_q_heads, head_dim = q.shape
599-
_, page_size, kv_model_dim = k_pages.shape
600-
num_kv_heads = kv_model_dim // head_dim
599+
_, page_size, kv_hidden_size = k_pages.shape
600+
num_kv_heads = kv_hidden_size // head_dim
601601
num_q_per_blk = num_queries_per_block
602602
num_kv_pages_per_blk = num_kv_pages_per_block
603603
num_q_heads_per_kv_head = num_q_heads // num_kv_heads

0 commit comments

Comments (0)