Skip to content

Commit 7396de7

Browse files
authored
Merge pull request taco-project#31 from nvidia-china-sae/linhu/quickfix-mla-tp
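Quick fix for MLA + TP: when is_mla is set, the CPU-side start offset inside a chunk is 0 instead of i * chunk size, and the test no longer divides num_kv_heads by tp_size for MLA.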
2 parents 567fa59 + 6d1e6bf commit 7396de7

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

csrc/tp_transfer_thread_group.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ void TPTransferThreadGroup::tp_group_transfer(
7070
dst_layer_ptrs = static_cast<void**>(gpu_blocks_ + i * num_layers + layer_id);
7171
src_layer_ptrs = static_cast<void**>(cpu_blocks_ + layer_id);
7272
dst_startoff_inside_chunks = 0;
73-
src_startoff_inside_chunks = i * dst_chunk_size_in_bytes;
73+
src_startoff_inside_chunks = is_mla ? 0 : i * dst_chunk_size_in_bytes;
7474
copy_size_in_bytes = dst_chunk_size_in_bytes;
7575
} else {
7676
dst_layer_ptrs = static_cast<void**>(cpu_blocks_ + layer_id);
7777
src_layer_ptrs = static_cast<void**>(gpu_blocks_ + i * num_layers + layer_id);
78-
dst_startoff_inside_chunks = i * src_chunk_size_in_bytes;
78+
dst_startoff_inside_chunks = is_mla ? 0 : i * src_chunk_size_in_bytes;
7979
src_startoff_inside_chunks = 0;
8080
copy_size_in_bytes = src_chunk_size_in_bytes;
8181
}

tests/test_kvmanager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def generate_gpu_blocks(model_config, cache_config, test_config):
123123
num_layer=num_layers,
124124
num_block=num_gpu_blocks,
125125
tokens_per_block=tokens_per_block,
126-
num_head=num_kv_heads//tp_size,
126+
num_head=num_kv_heads//tp_size if not use_mla else num_kv_heads,
127127
head_size=head_size,
128128
is_mla=model_config.use_mla
129129
)

0 commit comments

Comments
 (0)