Skip to content

Commit 9191781

Browse files
committed
fix type conflict
1 parent 8efdc0c commit 9191781

2 files changed

Lines changed: 36 additions & 9 deletions

File tree

samples/cpp/text_generation/lora_greedy_causal_lm.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ int main(int argc, char* argv[]) try {
1515
using namespace ov::genai;
1616

1717
Adapter adapter(adapter_path);
18-
LLMPipeline pipe(models_path, device, adapters(adapter, 0.75, AdapterConfig::MODE_FUSE)); // register all required adapters here
18+
LLMPipeline pipe(models_path, device, adapters(adapter, 0.75, AdapterConfig::MODE_DYNAMIC)); // register all required adapters here
1919

2020
// Resetting config to set greedy behaviour ignoring generation config from model directory.
2121
// It helps to compare two generations with and without LoRA adapter.
@@ -24,12 +24,12 @@ int main(int argc, char* argv[]) try {
2424
pipe.set_generation_config(config);
2525

2626
std::cout << "Generate with LoRA adapter and alpha set to 0.75:" << std::endl;
27-
// std::cout << pipe.generate(prompt, max_new_tokens(100), adapters(adapter, 0.75)) << std::endl;
28-
std::cout << pipe.generate(prompt, max_new_tokens(100)) << std::endl;
27+
std::cout << pipe.generate(prompt, max_new_tokens(100), adapters(adapter, 0.75)) << std::endl;
28+
// std::cout << pipe.generate(prompt, max_new_tokens(100)) << std::endl;
2929

3030
std::cout << "\n-----------------------------";
3131
std::cout << "\nGenerate without LoRA adapter:" << std::endl;
32-
// LLMPipeline pipe1(models_path, device); // register all required adapters here
32+
// // LLMPipeline pipe1(models_path, device); // register all required adapters here
3333
std::cout << pipe.generate(prompt, max_new_tokens(100), adapters()) << std::endl;
3434

3535
} catch (const std::exception& error) {

src/cpp/src/lora/adapter.cpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ using LoRAConstantTensors = std::map<std::string, LoRAConstantNode>;
176176

177177
// Group constant tensors loaded from LoRA adapter file into constants
178178
LoRAConstantTensors group_lora_constant_tensors(const ConstantMap& tensors, const std::vector<RegexParser>& const_parsers) {
179+
std::cout << "group_lora_constant_tensors" << std::endl;
179180
LoRAConstantTensors result;
180181
for(const auto& named_tensor: tensors) {
181182
for (const auto& const_parser : const_parsers) {
@@ -236,7 +237,7 @@ struct LoRAParameters {
236237

237238
using LoRAParametersGetter = std::function<std::optional<LoRAParameters>(NodePtr node)>;
238239

239-
// Maps a given layer name to corresponding LoRA tensors based on the default name mapping schema.
240+
// Maps a given layer name to corresponding LoRA tensors based on the default name mapping schema.
240241
// Layer name should start with a given prefix that is eliminated from the name before search for matching LoRA tensor.
241242
// It works for a single LoRA adapter.
242243
// Returns std::nullopt, if there is no LoRA adapter for a given layer name.
@@ -256,7 +257,7 @@ struct LoRAWeightGetterDefault {
256257
std::replace(name_with_underscores.begin(), name_with_underscores.end(), '.', '_');
257258
std::vector<std::string> variants{name, name_with_underscores};
258259
// auto it = std::find_if(lora_tensors->begin(), lora_tensors->end(), [this, variants](const LoRATensors::value_type& pair){
259-
auto it = std::find_if(lora_tensors->begin(), lora_tensors->end(), [this, variants](const std::pair<std::string, TENSOR_TYPE>& pair){
260+
auto it = std::find_if(lora_tensors->begin(), lora_tensors->end(), [this, variants](const std::pair<std::string, TENSOR_TYPE>& pair) {
260261
std::string lora_name = pair.first;
261262
// TODO: Make this filtering for prefix once in ctor as a more efficient solution
262263
if(lora_name.find(prefix) == 0) {
@@ -476,6 +477,7 @@ struct LoRAStateGetterForConst : public BaseStateGetter {
476477
variable_ids(variable_ids) {}
477478

478479
std::optional<LoRAConstantNode> operator() (NodePtr node) const {
480+
std::cout << "LoRAStateGetterForConst operator()" << std::endl;
479481
std::string name = node->get_friendly_name();
480482
if (auto params = getter(name)) {
481483
// FIXME: Potential name conflict if LoRA is applied multiple times by using this infrastructure independently each time (not a recommended approach).
@@ -484,6 +486,8 @@ struct LoRAStateGetterForConst : public BaseStateGetter {
484486
LoRAConstantNode result;
485487
ov::op::util::VariableInfo variable_info;
486488

489+
std::cout << "111 " << params->tensor->get_output_element_type(0) << std::endl;
490+
487491
// FIXME: No guarantees on ordering of state in InferRequest makes impossible using indices of variables later, forced to use variable_id instead
488492
variable_info = ov::op::util::VariableInfo{
489493
params->tensor->get_output_shape(0),
@@ -615,12 +619,29 @@ class LoRAReplaceConstantTransformDynamic : public LoRAReplaceConstantTransform
615619
bool apply(NodePtr node, const LoRAConstantNode& lora_weight) override {
616620
auto consumers = node->get_output_target_inputs(0);
617621

618-
// std::cout << lora_weight.tensor->get_output_shape(0) << std::endl;
622+
std::cout << lora_weight.tensor->get_element_type() << std::endl;
623+
std::cout << node->get_element_type() << std::endl;
624+
625+
const auto node_type = node->get_element_type();
626+
627+
// Convert the tensor to the node's element type
628+
// TODO: check that it's safe
629+
auto lora_output = lora_weight.tensor;
630+
if (lora_weight.tensor->get_element_type() != node_type) {
631+
lora_output = std::make_shared<ov::op::v0::Convert>(lora_weight.tensor, node_type);
632+
}
633+
619634
// std::cout << node->get_output_shape(0) << std::endl;
620635

621-
std::shared_ptr<ov::Model> then_body = std::make_shared<ov::Model>(lora_weight.tensor, ov::ParameterVector{}),
622-
else_body = std::make_shared<ov::Model>(node, ov::ParameterVector{});
636+
// std::shared_ptr<ov::Model> then_body = std::make_shared<ov::Model>(lora_weight.tensor, ov::ParameterVector{}),
637+
// else_body = std::make_shared<ov::Model>(node, ov::ParameterVector{});
638+
// if a constant LoRA weight is present
623639
std::shared_ptr<ov::op::v8::If> if_node = std::make_shared<ov::op::v8::If>(if_input);
640+
// then insert it
641+
std::shared_ptr<ov::Model> then_body = std::make_shared<ov::Model>(lora_output, ov::ParameterVector{});
642+
// otherwise keep the model's original weight
643+
std::shared_ptr<ov::Model> else_body = std::make_shared<ov::Model>(node, ov::ParameterVector{});
644+
624645
if_node->set_then_body(then_body);
625646
if_node->set_else_body(else_body);
626647
if_node->set_output(then_body->get_results()[0], else_body->get_results()[0]);
@@ -632,6 +653,7 @@ class LoRAReplaceConstantTransformDynamic : public LoRAReplaceConstantTransform
632653

633654
for (auto& consumer : consumers) {
634655
consumer.replace_source_output(if_node->output(0));
656+
std::cout << consumer.get_node()->get_element_type() << std::endl;
635657
}
636658
return true;
637659
}
@@ -668,10 +690,14 @@ NodePtr tensors_multiplication(NodePtr input,
668690
const auto target_shape = target.get_partial_shape();
669691
const auto target_rank = target_shape.rank().get_length();
670692

693+
std::cout << 'tensors_multiplication' << std::endl;
694+
671695
for (size_t i = 0; i < multipliers.size(); ++i) {
672696
NodePtr normalized = multipliers[i];
697+
std::cout <<"aaa "<< normalized->get_output_element_type(0).get_type_name() << std::endl;
673698
if (normalized->get_output_element_type(0) != target_type) {
674699
normalized = std::make_shared<v0::Convert>(normalized, target_type);
700+
std::cout <<"bbb "<< normalized->get_output_element_type(0).get_type_name() << std::endl;
675701
if (std::dynamic_pointer_cast<v0::Constant>(normalized)) {
676702
input->get_rt_info()["decompression"];
677703
}
@@ -1320,6 +1346,7 @@ struct AdapterControllerImpl {
13201346
}
13211347

13221348
void set_new_adapter_tensors (ov::InferRequest& infer_request, bool alpha_only = false) {
1349+
std::cout << "!!!!!!!!!!!!!!set_new_adapter_tensors"<< std::endl;
13231350
if(current_config.get_mode() != AdapterConfig::MODE_AUTO && current_config.get_mode() != AdapterConfig::MODE_DYNAMIC && current_config.get_mode() != AdapterConfig::MODE_STATIC_RANK ) {
13241351
return;
13251352
}

0 commit comments

Comments
 (0)