@@ -83,8 +83,8 @@ def ref_ragged_paged_attention(
83
83
check_inputs_shapes (queries , k_pages , v_pages , kv_lens , page_indices ,
84
84
cu_q_lens , num_seqs )
85
85
_ , num_q_heads , head_dim = queries .shape
86
- _ , _ , kv_model_dim = k_pages .shape
87
- num_kv_heads = kv_model_dim // head_dim
86
+ _ , _ , kv_hidden_size = k_pages .shape
87
+ num_kv_heads = kv_hidden_size // head_dim
88
88
assert num_q_heads % num_kv_heads == 0
89
89
num_query_per_kv = num_q_heads // num_kv_heads
90
90
outputs = []
@@ -163,13 +163,13 @@ def check_inputs_shapes(
163
163
if k_pages .shape != v_pages .shape :
164
164
raise ValueError (
165
165
f"Expected { k_pages .shape = } to be equal to { v_pages .shape = } ." )
166
- _ , page_size , kv_model_dim = k_pages .shape
166
+ _ , page_size , kv_hidden_size = k_pages .shape
167
167
kv_packing = get_dtype_packing (k_pages .dtype )
168
168
if page_size % kv_packing != 0 :
169
169
raise ValueError (f"Expected { page_size = } is divisible by { kv_packing = } " )
170
- if kv_model_dim % head_dim != 0 :
171
- raise ValueError (f"Expected { kv_model_dim = } is divisible by { head_dim = } ." )
172
- num_kv_heads = kv_model_dim // head_dim
170
+ if kv_hidden_size % head_dim != 0 :
171
+ raise ValueError (f"Expected { kv_hidden_size = } is divisible by { head_dim = } ." )
172
+ num_kv_heads = kv_hidden_size // head_dim
173
173
if num_q_heads % num_kv_heads != 0 :
174
174
raise ValueError (f"Expected { num_q_heads = } is divisible by { num_kv_heads = } " )
175
175
max_num_seqs , _ = page_indices .shape
@@ -596,8 +596,8 @@ def ragged_paged_attention(
596
596
check_inputs_shapes (q , k_pages , v_pages , kv_lens , page_indices , cu_q_lens ,
597
597
num_seqs )
598
598
num_q , num_q_heads , head_dim = q .shape
599
- _ , page_size , kv_model_dim = k_pages .shape
600
- num_kv_heads = kv_model_dim // head_dim
599
+ _ , page_size , kv_hidden_size = k_pages .shape
600
+ num_kv_heads = kv_hidden_size // head_dim
601
601
num_q_per_blk = num_queries_per_block
602
602
num_kv_pages_per_blk = num_kv_pages_per_block
603
603
num_q_heads_per_kv_head = num_q_heads // num_kv_heads
0 commit comments