Skip to content

Commit baf2e4e

Browse files
committed
a monkey patch for lora_target
Former-commit-id: 622f44a
1 parent eac7f97 commit baf2e4e

2 files changed

Lines changed: 11 additions & 0 deletions

File tree

src/llmtuner/extras/constants.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,12 @@
2929
"InternLM-7B-Base": "internlm/internlm-7b",
3030
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
3131
}
32+
33+
DEFAULT_MODULE = { # will be deprecated
34+
"LLaMA": "q_proj,v_proj",
35+
"BLOOM": "query_key_value",
36+
"BLOOMZ": "query_key_value",
37+
"Falcon": "query_key_value",
38+
"Baichuan": "W_pack",
39+
"InternLM": "q_proj,v_proj"
40+
}

src/llmtuner/webui/runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Optional, Tuple
77

88
from llmtuner.extras.callbacks import LogCallback
9+
from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
910
from llmtuner.extras.logging import LoggerHandler
1011
from llmtuner.extras.misc import torch_gc
1112
from llmtuner.tuner import get_train_args, run_sft
@@ -79,6 +80,7 @@ def run_train(
7980
model_name_or_path=model_name_or_path,
8081
do_train=True,
8182
finetuning_type=finetuning_type,
83+
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
8284
prompt_template=template,
8385
dataset=",".join(dataset),
8486
dataset_dir=dataset_dir,

0 commit comments

Comments
 (0)