Skip to content

Commit 8651599

Browse files
committed
If node fix: fixes the "ReadValue doesn't have sibling output" error
1 parent 091730a commit 8651599

2 files changed

Lines changed: 34 additions & 22 deletions

File tree

samples/cpp/text_generation/lora_greedy_causal_lm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ int main(int argc, char* argv[]) try {
3232
std::cout << "\nGenerate without LoRA adapter:" << std::endl;
3333
// // LLMPipeline pipe1(models_path, device); // register all required adapters here
3434

35-
// std::cout << pipe.generate(prompt, max_new_tokens(100)) << std::endl;
35+
std::cout << pipe.generate(prompt, max_new_tokens(100)) << std::endl;
3636

3737
} catch (const std::exception& error) {
3838
std::cerr << error.what() << '\n';

src/cpp/src/lora/adapter.cpp

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -474,12 +474,13 @@ struct LoRAStateGetterForConst : public BaseStateGetter {
474474
std::map<std::string, ov::op::util::VariableInfo>& variable_ids) :
475475
getter(getter),
476476
BaseStateGetter(model),
477-
variable_ids(variable_ids) {}
477+
variable_ids(variable_ids) {
478+
std::cout << "LoRAStateGetterForConst constructor" << std::endl;
479+
}
478480

479481
std::optional<LoRAConstantNode> operator() (NodePtr node) const {
480482
// std::cout << "LoRAStateGetterForConst operator()" << std::endl;
481483
std::string name = node->get_friendly_name();
482-
std::cout << "LoRAStateGetterForConst operator() " << name << std::endl;
483484
if (auto params = getter(name)) {
484485
// FIXME: Potential name conflict if LoRA is applied multiple times by using this infrastructure independently each time (not a recommended approach).
485486
// TODO: Check for name collisions searching for existing variables with the same names.
@@ -503,6 +504,7 @@ struct LoRAStateGetterForConst : public BaseStateGetter {
503504
}
504505

505506
NodePtr create_if_input() {
507+
std::cout << "create if input" << std::endl;
506508
auto variable_info = ov::op::util::VariableInfo{
507509
ov::Shape{1},
508510
ov::element::Type_t::boolean,
@@ -611,46 +613,52 @@ class LoRAReplaceConstantTransformDynamic : public LoRAReplaceConstantTransform
611613
public:
612614
LoRAReplaceConstantTransformDynamic(const LoRAConstantByNodeGetter& getter,
613615
const NodePtr if_input) :
614-
LoRAReplaceConstantTransform(getter), if_input(if_input) {}
616+
LoRAReplaceConstantTransform(getter), if_input(if_input) {
617+
std::cout << "LoRAReplaceConstantTransformDynamic constructor" << std::endl;
618+
}
615619

616620
protected:
617621
NodePtr if_input;
618622

619623
bool apply(NodePtr node, const LoRAConstantNode& lora_weight) override {
620624
auto consumers = node->get_output_target_inputs(0);
621625
const auto node_type = node->get_element_type();
626+
627+
std::cout << "LoRAReplaceConstantTransformDynamic apply" << std::endl;
628+
std::cout << "LoRAReplaceConstantTransformDynamic consumers size " << consumers.size() << std::endl;
622629

623-
// Приводим tensor к типу node
624-
// TODO: check that it's safe
630+
// cast to node type
625631
auto lora_output = lora_weight.tensor;
626632
if (lora_weight.tensor->get_element_type() != node_type) {
627633
lora_output = std::make_shared<ov::op::v0::Convert>(lora_weight.tensor, node_type);
628634
}
629635

630-
// std::shared_ptr<ov::Model> then_body = std::make_shared<ov::Model>(lora_weight.tensor, ov::ParameterVector{}),
631-
// else_body = std::make_shared<ov::Model>(node, ov::ParameterVector{});
632-
// если есть константный вес
633-
std::shared_ptr<ov::op::v8::If> if_node = std::make_shared<ov::op::v8::If>(if_input);
634-
// то вставь его
635-
std::shared_ptr<ov::Model> then_body = std::make_shared<ov::Model>(lora_output, ov::ParameterVector{});
636-
// иначе оригинальный вес модели
637-
std::shared_ptr<ov::Model> else_body = std::make_shared<ov::Model>(node, ov::ParameterVector{});
636+
auto if_node = std::make_shared<ov::op::v8::If>(if_input);
638637

638+
// IF branch: there is the constant weight, replace with ReadValue
639+
auto then_param = std::make_shared<ov::op::v0::Parameter>(lora_output->get_element_type(), lora_output->get_output_partial_shape(0));
640+
auto then_result = std::make_shared<ov::op::v0::Result>(then_param);
641+
auto then_body = std::make_shared<ov::Model>(ov::ResultVector{then_result}, ov::ParameterVector{then_param});
639642
if_node->set_then_body(then_body);
643+
644+
// ELSE branch: use original weight
645+
auto else_param = std::make_shared<ov::op::v0::Parameter>(node->get_element_type(), node->get_output_partial_shape(0));
646+
auto else_result = std::make_shared<ov::op::v0::Result>(else_param);
647+
auto else_body = std::make_shared<ov::Model>(ov::ResultVector{else_result}, ov::ParameterVector{else_param});
640648
if_node->set_else_body(else_body);
641-
if_node->set_output(then_body->get_results()[0], else_body->get_results()[0]);
642-
643-
// std::cout << if_node->outputs().size() << std::endl;
644649

645-
// std::cout << if_node->output(0).get_shape() << std::endl;
650+
// set if_node inputs
651+
if_node->set_input(lora_output, then_param, nullptr); // put LoRA tensor to then
652+
if_node->set_input(node, nullptr, else_param); // put original Constant to else
646653

654+
// set if_node output
655+
if_node->set_output(then_result, else_result);
647656

648657
for (auto& consumer : consumers) {
649658
consumer.replace_source_output(if_node->output(0));
650659
}
651-
return true;
652660

653-
661+
return true;
654662
}
655663
};
656664

@@ -685,7 +693,7 @@ NodePtr tensors_multiplication(NodePtr input,
685693
const auto target_shape = target.get_partial_shape();
686694
const auto target_rank = target_shape.rank().get_length();
687695

688-
std::cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!tensors_multiplication" << std::endl;
696+
// std::cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!tensors_multiplication" << std::endl;
689697

690698
for (size_t i = 0; i < multipliers.size(); ++i) {
691699
NodePtr normalized = multipliers[i];
@@ -1244,6 +1252,7 @@ struct AdapterControllerImpl {
12441252
pm.register_pass<LoRASeparateTransform>(LoRAWeightStateGetter(params_getter, model, variable_ids));
12451253
if (!const_getter.empty()) {
12461254
LoRAStateGetterForConst getter = LoRAStateGetterForConst(const_getter.front(), model, constant_variable_ids);
1255+
std::cout << "register pass for dynamic lora" << std::endl;
12471256
pm.register_pass<LoRAReplaceConstantTransformDynamic>(getter, getter.create_if_input());
12481257
}
12491258
} else if(mode == AdapterConfig::MODE_STATIC) {
@@ -1341,6 +1350,7 @@ struct AdapterControllerImpl {
13411350

13421351
void set_new_adapter_tensors(ov::InferRequest& infer_request, bool alpha_only = false) {
13431352
std::cout << "!!!!!!!!!!!!!!set_new_adapter_tensors"<< std::endl;
1353+
13441354
if(current_config.get_mode() != AdapterConfig::MODE_AUTO && current_config.get_mode() != AdapterConfig::MODE_DYNAMIC && current_config.get_mode() != AdapterConfig::MODE_STATIC_RANK ) {
13451355
return;
13461356
}
@@ -1358,6 +1368,8 @@ struct AdapterControllerImpl {
13581368
weight_getters.emplace_back(LoRAWeightGetterDefault<LoRAWeight, LoRANode>(&adapter_impl->get_tensors(), current_config.get_tensor_name_prefix().value_or("")));
13591369
}
13601370

1371+
std::cout << "const_getter size " << const_getter.size() << std::endl;
1372+
13611373
auto state = infer_request.query_state();
13621374
// TODO: Forced to use variable_id instead of index to address the state tensors, require the same order for state as for variables from plugins
13631375

@@ -1378,7 +1390,7 @@ struct AdapterControllerImpl {
13781390
set_lora_tensors(state, lora_var_ids.first, lora_var_ids.second, lora_indices, weight_getters, alpha_only);
13791391
}
13801392

1381-
std::cout <<" constant names" << std::endl;
1393+
// std::cout <<" constant names" << std::endl;
13821394
for (const auto& constant_variable_id : constant_variable_ids) {
13831395
std::cout << constant_variable_id.first << std::endl;
13841396
}

0 commit comments

Comments (0)