Skip to content

Commit a8dfb18

Browse files
authored
[GPU] Reuse kv cache mem if it is not changed from previous infer (#28361)
### Details: - When a kv cache variable is reset, new memory is allocated for it. - However, if the variable's memory has not changed since the previous iteration, the previously allocated memory can be reused. ### Tickets: - *ticket-id*
1 parent f616896 commit a8dfb18

File tree

4 files changed

+30
-7
lines changed

4 files changed

+30
-7
lines changed

src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ struct network {
193193
const ov::intel_gpu::VariableStateInfo& get_variable_info(const std::string &variable_id) const;
194194
const ov::intel_gpu::VariablesMap& get_variables() const;
195195
const ov::intel_gpu::VariablesInfoMap& get_variables_info() const;
196+
void set_reuse_variable_mem(bool reuse = false);
197+
// Reports the flag last stored via set_reuse_variable_mem(); the backing member defaults to false.
bool is_reuse_variable_mem() { return _reuse_variable_mem; }
196198

197199
const ExecutionConfig& get_config() const { return _config; }
198200

@@ -216,6 +218,7 @@ struct network {
216218
bool _is_dynamic = false;
217219
bool _enable_profiling = false;
218220
bool _reset_arguments;
221+
bool _reuse_variable_mem = false;
219222

220223
std::unordered_map<primitive_id, std::shared_ptr<primitive_inst>> _primitives;
221224
std::vector<shared_mem_type> _in_out_shared_mem_types;

src/plugins/intel_gpu/src/graph/network.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,5 +1028,9 @@ void network::set_variables_state_info(const std::string& variable_id,
10281028
_variables_state_info.at(variable_id).m_primitives.insert(p);
10291029
}
10301030

1031+
// Stores the flag that is_reuse_variable_mem() reports back.
// NOTE(review): this is one flag for the whole network, yet SyncInferRequest::enqueue()
// writes it once per variable inside its loop, so the last variable visited decides the
// value that realloc_if_needed() later reads for every primitive — confirm this is intended.
void network::set_reuse_variable_mem(bool reuse) {
1032+
_reuse_variable_mem = reuse;
1033+
}
1034+
10311035

10321036
} // namespace cldnn

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -624,16 +624,24 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
624624
_max_output_layout_count[j] = 0;
625625
}
626626
} else {
627-
_outputs[0] = variable.get_memory();
627+
GPU_DEBUG_TRACE_DETAIL
628+
<< id() << " : realloc_if_needed: can_be_optimized = false and memories are not being shared"
629+
<< std::endl;
630+
if (!get_network().is_reuse_variable_mem()) {
631+
GPU_DEBUG_TRACE_DETAIL << "Update output mem with new variable mem" << std::endl;
632+
_outputs[0] = variable.get_memory();
633+
_max_output_layout_count[0] = variable.get_actual_mem_size() / dt_sizes_in_B[0];
628634

629-
if (auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable)) {
630-
_outputs[2] = compressed_cache_variable->get_compression_scale_state()->get_memory();
635+
if (auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable)) {
636+
_outputs[2] = compressed_cache_variable->get_compression_scale_state()->get_memory();
631637

632-
if (compressed_cache_variable->has_zp_state()) {
633-
_outputs[3] = compressed_cache_variable->get_compression_zp_state()->get_memory();
638+
if (compressed_cache_variable->has_zp_state()) {
639+
_outputs[3] = compressed_cache_variable->get_compression_zp_state()->get_memory();
640+
}
634641
}
642+
} else {
643+
GPU_DEBUG_TRACE_DETAIL << "Can reuse variable mem of prev request" << std::endl;
635644
}
636-
GPU_DEBUG_TRACE_DETAIL << id() << " : realloc_if_needed: can_be_optimized = false and memories are not being shared" << std::endl;
637645
}
638646
} else {
639647
variable.set_layout(_impl_params->output_layouts[0]);

src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,21 @@ void SyncInferRequest::enqueue() {
295295
std::move(events.begin(), events.end(), std::back_inserter(dependencies));
296296
}
297297

298+
auto network = m_graph->get_network();
298299
for (const auto& it : m_variables) {
299300
const auto& name = it.first;
300301
const auto& variable = it.second;
302+
if (network->has_variable(name)) {
303+
const auto& prev_var = network->get_variable(name);
304+
if (prev_var.get_memory() == variable->get_memory()) {
305+
network->set_reuse_variable_mem(true);
306+
continue;
307+
}
308+
}
309+
network->set_reuse_variable_mem(false);
301310
prepare_state(name, variable);
302311
}
303312

304-
auto network = m_graph->get_network();
305313
network->set_shape_predictor(m_shape_predictor);
306314

307315
m_internal_outputs.clear();

0 commit comments

Comments
 (0)