@@ -474,12 +474,13 @@ struct LoRAStateGetterForConst : public BaseStateGetter {
474474 std::map<std::string, ov::op::util::VariableInfo>& variable_ids) :
475475 getter (getter),
476476 BaseStateGetter (model),
477- variable_ids (variable_ids) {}
477+ variable_ids (variable_ids) {
478+ std::cout << " LoRAStateGetterForConst constructor" << std::endl;
479+ }
478480
479481 std::optional<LoRAConstantNode> operator () (NodePtr node) const {
480482 // std::cout << "LoRAStateGetterForConst operator()" << std::endl;
481483 std::string name = node->get_friendly_name ();
482- std::cout << " LoRAStateGetterForConst operator() " << name << std::endl;
483484 if (auto params = getter (name)) {
484485 // FIXME: Potential name conflict if LoRA is applied multiple times by using this infrastructure independently each time (not a recommended approach).
485486 // TODO: Check for name collisions searching for existing variables with the same names.
@@ -503,6 +504,7 @@ struct LoRAStateGetterForConst : public BaseStateGetter {
503504 }
504505
505506 NodePtr create_if_input () {
507+ std::cout << " create if input" << std::endl;
506508 auto variable_info = ov::op::util::VariableInfo{
507509 ov::Shape{1 },
508510 ov::element::Type_t::boolean,
@@ -611,46 +613,52 @@ class LoRAReplaceConstantTransformDynamic : public LoRAReplaceConstantTransform
611613public:
612614 LoRAReplaceConstantTransformDynamic (const LoRAConstantByNodeGetter& getter,
613615 const NodePtr if_input) :
614- LoRAReplaceConstantTransform (getter), if_input(if_input) {}
616+ LoRAReplaceConstantTransform (getter), if_input(if_input) {
617+ std::cout << " LoRAReplaceConstantTransformDynamic constructor" << std::endl;
618+ }
615619
616620protected:
617621 NodePtr if_input;
618622
619623 bool apply (NodePtr node, const LoRAConstantNode& lora_weight) override {
620624 auto consumers = node->get_output_target_inputs (0 );
621625 const auto node_type = node->get_element_type ();
626+
627+ std::cout << " LoRAReplaceConstantTransformDynamic apply" << std::endl;
628+ std::cout << " LoRAReplaceConstantTransformDynamic consumers size " << consumers.size () << std::endl;
622629
623- // Приводим tensor к типу node
624- // TODO: check that it's safe
630+ // cast to node type
625631 auto lora_output = lora_weight.tensor ;
626632 if (lora_weight.tensor ->get_element_type () != node_type) {
627633 lora_output = std::make_shared<ov::op::v0::Convert>(lora_weight.tensor , node_type);
628634 }
629635
630- // std::shared_ptr<ov::Model> then_body = std::make_shared<ov::Model>(lora_weight.tensor, ov::ParameterVector{}),
631- // else_body = std::make_shared<ov::Model>(node, ov::ParameterVector{});
632- // если есть константный вес
633- std::shared_ptr<ov::op::v8::If> if_node = std::make_shared<ov::op::v8::If>(if_input);
634- // то вставь его
635- std::shared_ptr<ov::Model> then_body = std::make_shared<ov::Model>(lora_output, ov::ParameterVector{});
636- // иначе оригинальный вес модели
637- std::shared_ptr<ov::Model> else_body = std::make_shared<ov::Model>(node, ov::ParameterVector{});
636+ auto if_node = std::make_shared<ov::op::v8::If>(if_input);
638637
638+ // IF branch: there is the constant weight, replace with ReadValue
639+ auto then_param = std::make_shared<ov::op::v0::Parameter>(lora_output->get_element_type (), lora_output->get_output_partial_shape (0 ));
640+ auto then_result = std::make_shared<ov::op::v0::Result>(then_param);
641+ auto then_body = std::make_shared<ov::Model>(ov::ResultVector{then_result}, ov::ParameterVector{then_param});
639642 if_node->set_then_body (then_body);
643+
644+ // ELSE branch: use original weight
645+ auto else_param = std::make_shared<ov::op::v0::Parameter>(node->get_element_type (), node->get_output_partial_shape (0 ));
646+ auto else_result = std::make_shared<ov::op::v0::Result>(else_param);
647+ auto else_body = std::make_shared<ov::Model>(ov::ResultVector{else_result}, ov::ParameterVector{else_param});
640648 if_node->set_else_body (else_body);
641- if_node->set_output (then_body->get_results ()[0 ], else_body->get_results ()[0 ]);
642-
643- // std::cout << if_node->outputs().size() << std::endl;
644649
645- // std::cout << if_node->output(0).get_shape() << std::endl;
650+ // set if_node inputs
651+ if_node->set_input (lora_output, then_param, nullptr ); // put LoRA tensor to then
652+ if_node->set_input (node, nullptr , else_param); // put original Constant to else
646653
654+ // set if_node output
655+ if_node->set_output (then_result, else_result);
647656
648657 for (auto & consumer : consumers) {
649658 consumer.replace_source_output (if_node->output (0 ));
650659 }
651- return true ;
652660
653-
661+ return true ;
654662 }
655663};
656664
@@ -685,7 +693,7 @@ NodePtr tensors_multiplication(NodePtr input,
685693 const auto target_shape = target.get_partial_shape ();
686694 const auto target_rank = target_shape.rank ().get_length ();
687695
688- std::cout << " !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!tensors_multiplication" << std::endl;
696+ // std::cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!tensors_multiplication" << std::endl;
689697
690698 for (size_t i = 0 ; i < multipliers.size (); ++i) {
691699 NodePtr normalized = multipliers[i];
@@ -1244,6 +1252,7 @@ struct AdapterControllerImpl {
12441252 pm.register_pass <LoRASeparateTransform>(LoRAWeightStateGetter (params_getter, model, variable_ids));
12451253 if (!const_getter.empty ()) {
12461254 LoRAStateGetterForConst getter = LoRAStateGetterForConst (const_getter.front (), model, constant_variable_ids);
1255+ std::cout << " register pass for dynamic lora" << std::endl;
12471256 pm.register_pass <LoRAReplaceConstantTransformDynamic>(getter, getter.create_if_input ());
12481257 }
12491258 } else if (mode == AdapterConfig::MODE_STATIC) {
@@ -1341,6 +1350,7 @@ struct AdapterControllerImpl {
13411350
13421351 void set_new_adapter_tensors (ov::InferRequest& infer_request, bool alpha_only = false ) {
13431352 std::cout << " !!!!!!!!!!!!!!set_new_adapter_tensors" << std::endl;
1353+
13441354 if (current_config.get_mode () != AdapterConfig::MODE_AUTO && current_config.get_mode () != AdapterConfig::MODE_DYNAMIC && current_config.get_mode () != AdapterConfig::MODE_STATIC_RANK ) {
13451355 return ;
13461356 }
@@ -1358,6 +1368,8 @@ struct AdapterControllerImpl {
13581368 weight_getters.emplace_back (LoRAWeightGetterDefault<LoRAWeight, LoRANode>(&adapter_impl->get_tensors (), current_config.get_tensor_name_prefix ().value_or (" " )));
13591369 }
13601370
1371+ std::cout << " const_getter size " << const_getter.size () << std::endl;
1372+
13611373 auto state = infer_request.query_state ();
13621374 // TODO: Forced to use variable_id instead of index to address the state tensors, require the same order for state as for variables from plugins
13631375
@@ -1378,7 +1390,7 @@ struct AdapterControllerImpl {
13781390 set_lora_tensors (state, lora_var_ids.first , lora_var_ids.second , lora_indices, weight_getters, alpha_only);
13791391 }
13801392
1381- std::cout <<" constant names" << std::endl;
1393+ // std::cout <<" constant names" << std::endl;
13821394 for (const auto & constant_variable_id : constant_variable_ids) {
13831395 std::cout << constant_variable_id.first << std::endl;
13841396 }
0 commit comments