Commit b16e77b

Fix gdn attention value_dim calculation
Signed-off-by: Haowen Ning <hning@google.com>
Parent: 3a89711

1 file changed, 1 insertion(+), 1 deletion(-)
tpu_inference/layers/vllm/ops/gdn_attention.py

@@ -673,7 +673,7 @@ def gdn_attention_core_tpu(
         # E.g. they are in [Q Q | K K | V V] layout. We need [Q K | Q K | Q K] layout.
         # Use reorder_concatenated_tensor_for_sharding to reorder into correct layout
         key_dim = n_kq * d_k
-        value_dim = n_v * d_k
+        value_dim = n_v * d_v
         tp_size = mesh.shape[ShardingAxisName.ATTN_HEAD]
         j_mixed_qkv = reorder_concatenated_tensor_for_sharding(
             j_mixed_qkv, [key_dim, key_dim, value_dim], tp_size, -1)
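
This is a one-character fix, but it matters whenever the value head dimension d_v differs from the key/query head dimension d_k, as it can in GDN (gated DeltaNet) layers: sizing the V segment as n_v * d_k puts the Q/K/V split boundaries at the wrong offsets before the tensor-parallel reorder. Below is a minimal, hypothetical NumPy sketch of what the reorder step plausibly does; `reorder_for_sharding` and the head counts/dims are illustrative stand-ins, not the actual tpu_inference implementation.

```python
import numpy as np


def reorder_for_sharding(x, segment_dims, tp_size, axis=-1):
    """Illustrative stand-in for reorder_concatenated_tensor_for_sharding.

    The `axis` dimension of `x` is a concatenation of segments with widths
    `segment_dims` (e.g. [Q | K | V]). Each segment is split into `tp_size`
    equal shards, and the output groups one shard of every segment per
    tensor-parallel rank: [Q0 K0 V0 | Q1 K1 V1 | ...].
    """
    assert sum(segment_dims) == x.shape[axis], "segment widths must tile the axis"
    segments = np.split(x, np.cumsum(segment_dims)[:-1], axis=axis)
    shards = [np.split(seg, tp_size, axis=axis) for seg in segments]
    per_rank = [
        np.concatenate([seg_shards[rank] for seg_shards in shards], axis=axis)
        for rank in range(tp_size)
    ]
    return np.concatenate(per_rank, axis=axis)


# Hypothetical head counts and dims where d_v != d_k.
n_kq, n_v, d_k, d_v, tp_size = 4, 4, 32, 64, 2
key_dim = n_kq * d_k    # width of each of the Q and K segments
value_dim = n_v * d_v   # the fix: V heads use the value head dim d_v

width = 2 * key_dim + value_dim
mixed_qkv = np.arange(2 * width, dtype=np.float32).reshape(2, width)

out = reorder_for_sharding(mixed_qkv, [key_dim, key_dim, value_dim], tp_size)
assert out.shape == mixed_qkv.shape

# With the old value_dim = n_v * d_k, the segment widths sum to
# 2*key_dim + n_v*d_k = 384, not the tensor's actual width of 512, so the
# split boundaries would land inside the V data (here the assert would fire).
```

Under the old formula the boundary check fails whenever d_v != d_k (or, in an implementation that slices by fixed offsets, part of V would be assigned to the wrong tensor-parallel rank), which is what the one-line fix above prevents.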
