@@ -963,31 +963,36 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,
963963 if (m_lm_head_request) {
964964 LOG_DEBUG (" Calling inference for LM head model asynchronously" );
965965 m_lm_head_request->start_async ();
966- }
966+ m_llm_profile[" N/generate:3.update_kvcache" ].record ([&]() {
967+ if (kvcache_desc.num_stored_tokens < kvcache_desc.total_size ) {
968+ update_kvcache_for (m_kvcache_request,
969+ m_kvcache_in_ports,
970+ m_kvcache_out_ports,
971+ input_tokens_len,
972+ kvcache_desc.v_tensors_transposed_gen );
973+ }
974+ });
975+ m_lm_head_request->wait ();
976+ LOG_DEBUG (" Calling inference for LM head model -- done." );
967977
968- m_llm_profile[" N/generate:3.update_kvcache" ].record ([&]() {
969- if (kvcache_desc.num_stored_tokens < kvcache_desc.total_size ) {
970- update_kvcache_for (m_kvcache_request,
971- m_kvcache_in_ports,
972- m_kvcache_out_ports,
973- input_tokens_len,
974- kvcache_desc.v_tensors_transposed_gen );
975- }
976- });
978+ m_logits = m_lm_head_request->get_tensor (m_lm_head_logits_port);
979+ } else {
980+ m_llm_profile[" N/generate:3.update_kvcache" ].record ([&]() {
981+ if (kvcache_desc.num_stored_tokens < kvcache_desc.total_size ) {
982+ update_kvcache_for (m_kvcache_request,
983+ m_kvcache_in_ports,
984+ m_kvcache_out_ports,
985+ input_tokens_len,
986+ kvcache_desc.v_tensors_transposed_gen );
987+ }
988+ });
977989
978- m_llm_profile[" N/generate:4.lm_head" ].record ([&]() {
979- if (m_lm_head_request) {
980- m_lm_head_request->wait ();
981- LOG_DEBUG (" Calling inference for LM head model -- done." );
982- m_logits = m_lm_head_request->get_tensor (m_lm_head_logits_port);
983- } else {
984- m_logits = m_kvcache_request->get_tensor (m_kvcache_out_ports.at (layer_names::logits));
985- }
990+ m_logits = m_kvcache_request->get_tensor (m_kvcache_out_ports.at (layer_names::logits));
991+ }
986992
987- if (m_eagle3_ext.is_eagle3_model ()) {
988- m_eagle3_ext.update_last_hidden_state (m_kvcache_request, m_kvcache_out_ports);
989- }
990- });
993+ if (m_eagle3_ext.is_eagle3_model ()) {
994+ m_eagle3_ext.update_last_hidden_state (m_kvcache_request, m_kvcache_out_ports);
995+ }
991996
992997 LOG_DEBUG (" Done" );
993998}
0 commit comments