Commit 0dcc7bc

Handle the post-processing with a separate model
1 parent e48a215
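
Previously, the pooling type and normalization flag were forwarded to NPUW through the NPUW_TEXT_EMBED_POST_TYPE and NPUW_TEXT_EMBED_NORMALIZE properties. This change instead builds a standalone post-processing ov::Model (pooling plus optional L2 normalization), compiles it for CPU, and feeds the embedding model's last_hidden_state through it on the NPU path.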

File tree

2 files changed: +99 −43 lines

src/cpp/src/rag/text_embedding_pipeline.cpp

Lines changed: 97 additions & 25 deletions
@@ -45,6 +45,11 @@ bool has_token_type_ids_input(const T& inputs) {
     return false;
 }
 
+void set_node_name(std::shared_ptr<ov::Node> node, const std::string& name) {
+    node->set_friendly_name(name);
+    node->get_output_tensor(0).set_names({name});
+}
+
 /**
  * CLS pooling slices first element from seq_length dimension
  * [batch_size, seq_length, hidden_size] -> [batch_size, seq_length[0], hidden_size]
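
The new set_node_name helper assigns both the friendly name and the tensor name on output port 0; the tensor name is what InferRequest::set_tensor / get_tensor resolve against when the post-model is driven by name further down. A minimal sketch of the naming pattern (the element type and shape are illustrative, not from this commit):

    // Name a Parameter so an infer request can address it as "input_ids".
    auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{1, 8});
    set_node_name(param, "input_ids");
    // request.set_tensor("input_ids", tensor) now matches via the tensor name.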
@@ -62,12 +67,10 @@ std::shared_ptr<op::Op> get_cls_pooling_op(const ov::Output<ov::Node>& last_hidd
     return std::make_shared<op::v15::Squeeze>(slice, squeeze_axis);
 }
 
-std::shared_ptr<op::Op> get_mean_pooling_op(std::shared_ptr<Model> model,
-                                            const ov::Output<ov::Node>& last_hidden_state_node) {
+std::shared_ptr<op::Op> get_mean_pooling_op(const ov::Output<ov::Node>& last_hidden_state_node,
+                                            const ov::Output<ov::Node>& attention_mask) {
     auto shape_of = std::make_shared<op::v3::ShapeOf>(last_hidden_state_node);
 
-    auto attention_mask = model->input("attention_mask").get_node()->outputs()[0];
-
     auto unsqueze_axis = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
 
     auto unsqueze = std::make_shared<op::v0::Unsqueeze>(attention_mask, unsqueze_axis);
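
For reference, the graph that get_mean_pooling_op builds (Unsqueeze the mask, multiply, ReduceSum, then the final Divide shown at the top of the next hunk) computes a masked mean per batch row b, with hidden states h and attention mask m; the max (presumably against a small epsilon, per the max_expanded_mask name) guards against division by zero for an all-padding row:

    $$e_b = \frac{\sum_t m_{bt}\, h_{bt}}{\max\!\left(\sum_t m_{bt},\ \varepsilon\right)}$$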
@@ -95,8 +98,8 @@ std::shared_ptr<op::Op> get_mean_pooling_op(std::shared_ptr<Model> model,
     return std::make_shared<op::v1::Divide>(sum_hidden_state, max_expanded_mask);
 }
 
-std::shared_ptr<op::Op> get_last_token_pooling_op(std::shared_ptr<Model> model,
-                                                  const ov::Output<ov::Node>& last_hidden_state_node,
+std::shared_ptr<op::Op> get_last_token_pooling_op(const ov::Output<ov::Node>& last_hidden_state_node,
+                                                  const ov::Output<ov::Node>& attention_mask,
                                                   const TextEmbeddingPipeline::Config& config) {
     const auto left_padding = config.padding_side.has_value() && config.padding_side.value() == "left";
@@ -115,8 +118,6 @@ std::shared_ptr<op::Op> get_last_token_pooling_op(std::shared_ptr<Model> model,
         return std::make_shared<op::v15::Squeeze>(slice, squeeze_axis);
     }
 
-    auto attention_mask = model->input("attention_mask").get_node()->outputs()[0];
-
     auto axis_1 = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
     auto reduce_sum = std::make_shared<op::v1::ReduceSum>(attention_mask, axis_1);
     auto subtract_1 = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
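
With right padding (the default path after the left-padding early return above), the index of the last real token in row b is recovered from the attention mask, and the hidden state is gathered at that position along the sequence axis:

    $$\text{idx}_b = \Big(\sum_t m_{bt}\Big) - 1, \qquad e_b = h_{b,\,\text{idx}_b}$$

which is exactly the ReduceSum minus 1 feeding the op::v8::Gather at the top of the next hunk.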
@@ -125,31 +126,71 @@ std::shared_ptr<op::Op> get_last_token_pooling_op(std::shared_ptr<Model> model,
     return std::make_shared<op::v8::Gather>(last_hidden_state_node, subtract, axis_1, 1);
 }
 
+std::shared_ptr<op::Op> create_post_ops(const ov::Output<ov::Node>& input,
+                                        const ov::Output<ov::Node>& attention_mask,
+                                        const TextEmbeddingPipeline::Config& config) {
+    if (config.pooling_type == TextEmbeddingPipeline::PoolingType::CLS) {
+        return get_cls_pooling_op(input);
+    } else if (config.pooling_type == TextEmbeddingPipeline::PoolingType::MEAN) {
+        return get_mean_pooling_op(input, attention_mask);
+    } else if (config.pooling_type == TextEmbeddingPipeline::PoolingType::LAST_TOKEN) {
+        return get_last_token_pooling_op(input, attention_mask, config);
+    }
+
+    OPENVINO_THROW("Pooling type is not supported");
+}
+
+std::shared_ptr<op::Op> create_normalize_ops(const ov::Output<ov::Node>& input,
+                                             const TextEmbeddingPipeline::Config& config) {
+    if (config.normalize) {
+        auto axis_const = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector{1});
+        return std::make_shared<op::v0::NormalizeL2>(input, axis_const, 1e-12, op::EpsMode::MAX);
+    }
+    return std::dynamic_pointer_cast<op::Op>(input.get_node_shared_ptr());
+}
+
 std::shared_ptr<Model> apply_postprocessing(std::shared_ptr<Model> model, const TextEmbeddingPipeline::Config& config) {
     ov::preprocess::PrePostProcessor processor(model);
 
     processor.output().postprocess().custom([model, &config](const ov::Output<ov::Node>& node) {
-        if (config.pooling_type == TextEmbeddingPipeline::PoolingType::CLS) {
-            return get_cls_pooling_op(node);
-        } else if (config.pooling_type == TextEmbeddingPipeline::PoolingType::MEAN) {
-            return get_mean_pooling_op(model, node);
-        } else if (config.pooling_type == TextEmbeddingPipeline::PoolingType::LAST_TOKEN) {
-            return get_last_token_pooling_op(model, node, config);
-        }
-
-        OPENVINO_THROW("Pooling type is not supported");
+        auto attention_mask = model->input("attention_mask").get_node()->outputs()[0];
+        return create_post_ops(node, attention_mask, config);
     });
 
     if (config.normalize) {
-        processor.output().postprocess().custom([](const ov::Output<ov::Node>& node) {
-            auto axis_const = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector{1});
-            return std::make_shared<op::v0::NormalizeL2>(node, axis_const, 1e-12, op::EpsMode::MAX);
+        processor.output().postprocess().custom([&config](const ov::Output<ov::Node>& node) {
+            return create_normalize_ops(node, config);
         });
     }
 
     return processor.build();
 }
 
+std::shared_ptr<ov::Model> create_post_model(std::shared_ptr<ov::Model> model,
+                                             const TextEmbeddingPipeline::Config& config,
+                                             ov::Dimension::value_type max_prompt_size) {
+    auto output_node = model->outputs()[0];
+    auto output_shape = output_node.get_partial_shape();
+    auto input_param =
+        std::make_shared<ov::op::v0::Parameter>(output_node.get_element_type(), ov::PartialShape{1, max_prompt_size, output_shape[2]});
+    set_node_name(input_param, "input_ids");
+
+    auto attention_mask = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{1, max_prompt_size});
+    set_node_name(attention_mask, "attention_mask");
+
+    auto post_output = create_post_ops(input_param, attention_mask, config);
+    auto post_normalize_output = create_normalize_ops(post_output, config);
+    OPENVINO_ASSERT(post_normalize_output != nullptr);
+
+    auto result_node = std::make_shared<ov::op::v0::Result>(post_normalize_output);
+    set_node_name(result_node, "last_hidden_state");
+    auto post_model =
+        std::make_shared<ov::Model>(ov::OutputVector{result_node}, ov::ParameterVector{input_param, attention_mask});
+    post_model->set_friendly_name(model->get_friendly_name() + "_post_process");
+    post_model->validate_nodes_and_infer_types();
+    return post_model;
+}
+
 std::optional<size_t> read_max_position_embeddings(const std::filesystem::path& models_path) {
     // config.json not found. Skip parameters initialization from file, use defaults.
     const std::filesystem::path& json_path = models_path / "config.json";
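
Note that the post-model's hidden-state input is (somewhat confusingly) named "input_ids", matching how post_model_infer feeds it further down. A minimal sketch of exercising the detached model on its own, assuming a loaded embedding model and a config as above (the shapes and the 1024 prompt size are placeholders):

    ov::Core core;
    auto post_model = create_post_model(model, config, /*max_prompt_size=*/1024);
    auto compiled = core.compile_model(post_model, "CPU");
    auto request = compiled.create_infer_request();
    request.set_tensor("input_ids", last_hidden_state);    // [1, 1024, hidden_size]
    request.set_tensor("attention_mask", attention_mask);  // [1, 1024], i64, zero-padded tail
    request.infer();
    ov::Tensor pooled = request.get_tensor("last_hidden_state");  // pooled, optionally L2-normalized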
@@ -238,16 +279,22 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
 
         ov::CompiledModel compiled_model;
         if (device == "NPU" && model->is_dynamic()) {
+            OPENVINO_ASSERT(m_config.max_length.has_value(), "The parameter max_length is not set");
+
             bool is_padding_on_left = m_config.padding_side.has_value() && m_config.padding_side.value() == "left";
             if (is_padding_on_left && is_seq_len_fixed &&
                 config.pooling_type != TextEmbeddingPipeline::PoolingType::MEAN) {
-                OPENVINO_THROW("Padding on left is only supported for the mean post-processing type");
+                OPENVINO_THROW("Padding on left is only supported for the MEAN pooling type");
             }
 
             auto kv_pos = ov::genai::utils::get_kv_axes_pos(model);
             utils::KVDesc kv_desc;
             std::tie(compiled_model, kv_desc) =
                 utils::compile_decoder_for_npu_text_embedding(model, properties, kv_pos, m_config);
+
+            auto post_model = create_post_model(model, m_config, m_config.max_length.value());
+            auto post_compiled_model = core.compile_model(post_model, "CPU");
+            m_post_request = post_compiled_model.create_infer_request();
         } else {
             model = apply_postprocessing(model, m_config);
             compiled_model = core.compile_model(model, device, properties);
@@ -296,9 +343,11 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
 private:
     Tokenizer m_tokenizer;
     InferRequest m_request;
+    InferRequest m_post_request;
    Config m_config;
    AnyMap m_tokenization_params;
    std::optional<size_t> m_max_position_embeddings;
+    ov::Tensor m_attention_mask;
 
    void reshape_model(std::shared_ptr<Model>& model) {
        ov::PartialShape target_shape{ov::Dimension::dynamic(), ov::Dimension::dynamic()};
@@ -336,6 +385,28 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
         model->reshape(input_name_to_shape);
     }
 
+    ov::Tensor post_model_infer(ov::Tensor input) {
+        if (m_post_request) {
+            m_post_request.set_tensor("input_ids", input);
+
+            auto attention_mask_tensor = m_post_request.get_tensor("attention_mask");
+
+            std::copy_n(m_attention_mask.data<int64_t>(),
+                        m_attention_mask.get_size(),
+                        attention_mask_tensor.data<int64_t>());
+            if (m_attention_mask.get_size() < attention_mask_tensor.get_size()) {
+                std::fill_n(attention_mask_tensor.data<int64_t>() + m_attention_mask.get_size(),
+                            attention_mask_tensor.get_size() - m_attention_mask.get_size(),
+                            0);
+            }
+
+            m_post_request.infer();
+            return m_post_request.get_tensor("last_hidden_state");
+        }
+
+        return input;
+    }
+
     void start_embed_async(std::vector<std::string>& texts) {
         if (m_config.batch_size.has_value()) {
             // if batch_size is set, model shape is fixed
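
The copy/fill in post_model_infer exists because the post-model was built with a static [1, max_prompt_size] attention mask, while the tokenizer's mask only covers the actual prompt; zero-filling the tail keeps padded positions out of the pooling. The same pattern in isolation (sizes are illustrative; assumes <algorithm> is included):

    // Pad a length-3 mask into a static length-8 buffer with zeros.
    int64_t src[3] = {1, 1, 1};
    int64_t dst[8];
    std::copy_n(src, 3, dst);        // real tokens
    std::fill_n(dst + 3, 8 - 3, 0);  // padded tail contributes nothing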
@@ -347,10 +418,11 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
         }
 
         const auto encoded = m_tokenizer.encode(texts, m_tokenization_params);
-
         m_request.set_tensor("input_ids", encoded.input_ids);
         m_request.set_tensor("attention_mask", encoded.attention_mask);
 
+        m_attention_mask = encoded.attention_mask;
+
         // fill token_type_ids
         // todo: pass token_type_ids from tokenizer
         if (has_token_type_ids_input(m_request.get_compiled_model().inputs())) {
@@ -366,9 +438,8 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
         m_request.wait();
 
         // [batch_size, hidden_size]
-        const Tensor last_hidden_state = m_request.get_tensor("last_hidden_state");
-
-        return to_embedding_result(last_hidden_state);
+        const auto last_hidden_state = m_request.get_tensor("last_hidden_state");
+        return to_embedding_result(post_model_infer(last_hidden_state));
     };
 
     std::vector<std::string> format_texts(const std::vector<std::string>& texts) {
@@ -398,6 +469,7 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
 
         std::vector<std::vector<float>> result;
         const auto shape = last_hidden_state.get_shape();
+
         const size_t batch_size = shape[0];
         const size_t hidden_size = shape[1];

src/cpp/src/utils.cpp

Lines changed: 2 additions & 18 deletions
@@ -118,9 +118,7 @@ void update_npu_config_whisper(ov::AnyMap& config,
 
 void update_npu_config_text_embedding(ov::AnyMap& config,
                                       const ov::genai::utils::KVAxesPosition& kv_pos,
-                                      const ov::genai::utils::KVDesc& kv_desc,
-                                      const std::string& post_type,
-                                      const bool is_to_normalize) {
+                                      const ov::genai::utils::KVDesc& kv_desc) {
     update_config(config, {"NPU_USE_NPUW", "YES"});
     update_config(config, {"NPUW_LLM", "YES"});
     update_config(config, {"NPUW_LLM_BATCH_DIM", kv_pos.batch});
@@ -131,8 +129,6 @@ void update_npu_config_text_embedding(ov::AnyMap& config,
     update_config(config, {"NPUW_LLM_SHARED_HEAD", "NO"});
 
     update_config(config, {"NPUW_TEXT_EMBED", "YES"});
-    update_config(config, {"NPUW_TEXT_EMBED_POST_TYPE", post_type});
-    update_config(config, {"NPUW_TEXT_EMBED_NORMALIZE", is_to_normalize});
 }
 
 inline bool is_paged_attention_available() {
@@ -639,18 +635,6 @@ void get_npu_model_config(ov::AnyMap& properties,
     }
 }
 
-std::string get_post_type_string(const TextEmbeddingPipeline::Config& config) {
-    std::string post_type;
-    if (config.pooling_type == TextEmbeddingPipeline::PoolingType::CLS) {
-        post_type = "cls";
-    } else if (config.pooling_type == TextEmbeddingPipeline::PoolingType::MEAN) {
-        post_type = "mean";
-    } else {
-        post_type = "last_token";
-    }
-    return post_type;
-}
-
 void get_npu_text_embedding_config(ov::AnyMap& properties,
                                    const KVAxesPosition& kv_pos,
                                    KVDesc& kv_desc,
@@ -661,7 +645,7 @@ void get_npu_text_embedding_config(ov::AnyMap& properties,
         kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(1024u);
     }
     kv_desc.min_response_len = kv_desc.max_prompt_len;
-    update_npu_config_text_embedding(properties, kv_pos, kv_desc, get_post_type_string(text_embed_config), text_embed_config.normalize);
+    update_npu_config_text_embedding(properties, kv_pos, kv_desc);
 }
 
 std::pair<ov::CompiledModel, KVDesc> compile_decoder_for_npu_impl(const std::shared_ptr<ov::Model>& model,
