Skip to content

Commit 4d0f662

Browse files
authored
[webui] upgrade to gradio 5 (hiyouga#6688)
1 parent 7bf09ab commit 4d0f662

File tree

8 files changed

+36
-44
lines changed

8 files changed

+36
-44
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@ accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
 tokenizers>=0.19.0,<0.20.4
-gradio>=4.0.0,<5.0.0
+gradio>=4.0.0,<6.0.0
 pandas>=2.0.0
 scipy
 einops
```

src/llamafactory/webui/chatter.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -14,7 +14,7 @@
 import json
 import os
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple

 from ..chat import ChatModel
 from ..data import Role
@@ -120,26 +120,26 @@ def unload_model(self, data) -> Generator[str, None, None]:

     def append(
         self,
-        chatbot: List[List[Optional[str]]],
-        messages: Sequence[Dict[str, str]],
+        chatbot: List[Dict[str, str]],
+        messages: List[Dict[str, str]],
         role: str,
         query: str,
-    ) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
-        return chatbot + [[query, None]], messages + [{"role": role, "content": query}], ""
+    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
+        return chatbot + [{"role": "user", "content": query}], messages + [{"role": role, "content": query}], ""

     def stream(
         self,
-        chatbot: List[List[Optional[str]]],
-        messages: Sequence[Dict[str, str]],
+        chatbot: List[Dict[str, str]],
+        messages: List[Dict[str, str]],
         system: str,
         tools: str,
         image: Optional[Any],
         video: Optional[Any],
         max_new_tokens: int,
         top_p: float,
         temperature: float,
-    ) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
-        chatbot[-1][1] = ""
+    ) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
+        chatbot.append({"role": "assistant", "content": ""})
         response = ""
         for new_text in self.stream_chat(
             messages,
@@ -166,5 +166,5 @@ def stream(
         output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
         bot_text = result

-        chatbot[-1][1] = bot_text
+        chatbot[-1] = {"role": "assistant", "content": bot_text}
         yield chatbot, output_messages
```

src/llamafactory/webui/components/chatbot.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -33,7 +33,7 @@ def create_chat_box(
     engine: "Engine", visible: bool = False
 ) -> Tuple["Component", "Component", Dict[str, "Component"]]:
     with gr.Column(visible=visible) as chat_box:
-        chatbot = gr.Chatbot(show_copy_button=True)
+        chatbot = gr.Chatbot(type="messages", show_copy_button=True)
         messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
```

src/llamafactory/webui/components/top.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -30,11 +30,10 @@


 def create_top() -> Dict[str, "Component"]:
-    available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-
     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
-        model_name = gr.Dropdown(choices=available_models, scale=3)
+        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], value=None, scale=1)
+        available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
+        model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
         model_path = gr.Textbox(scale=3)

     with gr.Row():
```

src/llamafactory/webui/components/train.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -39,9 +39,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     elem_dict = dict()

     with gr.Row():
-        training_stage = gr.Dropdown(
-            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
-        )
+        stages = list(TRAINING_STAGES.keys())
+        training_stage = gr.Dropdown(choices=stages, value=stages[0], scale=1)
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
         dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)
@@ -107,8 +106,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
                 use_llama_pro = gr.Checkbox()

             with gr.Column():
-                shift_attn = gr.Checkbox()
-                report_to = gr.Checkbox()
+                report_to = gr.Dropdown(
+                    choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
+                    value=["none"],
+                    allow_custom_value=True,
+                    multiselect=True,
+                )

     input_elems.update(
         {
@@ -123,7 +126,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             mask_history,
             resize_vocab,
             use_llama_pro,
-            shift_attn,
             report_to,
         }
     )
@@ -141,7 +143,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             mask_history=mask_history,
             resize_vocab=resize_vocab,
             use_llama_pro=use_llama_pro,
-            shift_attn=shift_attn,
             report_to=report_to,
         )
     )
```

src/llamafactory/webui/locales.py

Lines changed: 0 additions & 18 deletions
```diff
@@ -713,24 +713,6 @@
             "info": "확장된 블록의 매개변수를 학습 가능하게 만듭니다.",
         },
     },
-    "shift_attn": {
-        "en": {
-            "label": "Enable S^2 Attention",
-            "info": "Use shift short attention proposed by LongLoRA.",
-        },
-        "ru": {
-            "label": "Включить S^2 внимание",
-            "info": "Использовать сдвиг внимания на короткие дистанции предложенный LongLoRA.",
-        },
-        "zh": {
-            "label": "使用 S^2 Attention",
-            "info": "使用 LongLoRA 提出的 shift short attention。",
-        },
-        "ko": {
-            "label": "S^2 Attention 사용",
-            "info": "LongLoRA에서 제안한 shift short attention을 사용합니다.",
-        },
-    },
     "report_to": {
         "en": {
             "label": "Enable external logger",
```

src/llamafactory/webui/runner.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -144,8 +144,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
             mask_history=get("train.mask_history"),
             resize_vocab=get("train.resize_vocab"),
             use_llama_pro=get("train.use_llama_pro"),
-            shift_attn=get("train.shift_attn"),
-            report_to="all" if get("train.report_to") else "none",
+            report_to=get("train.report_to"),
             use_galore=get("train.use_galore"),
             use_apollo=get("train.use_apollo"),
             use_badam=get("train.use_badam"),
@@ -239,6 +238,12 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
             args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")

+        # report_to
+        if "none" in args["report_to"]:
+            args["report_to"] = "none"
+        elif "all" in args["report_to"]:
+            args["report_to"] = "all"
+
         # swanlab config
         if get("train.use_swanlab"):
             args["swanlab_project"] = get("train.swanlab_project")
```

src/llamafactory/webui/utils.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -111,7 +111,12 @@ def gen_cmd(args: Dict[str, Any]) -> str:
     """
     cmd_lines = ["llamafactory-cli train "]
     for k, v in clean_cmd(args).items():
-        cmd_lines.append(f"    --{k} {str(v)} ")
+        if isinstance(v, dict):
+            cmd_lines.append(f"    --{k} {json.dumps(v, ensure_ascii=False)} ")
+        elif isinstance(v, list):
+            cmd_lines.append(f"    --{k} {' '.join(map(str, v))} ")
+        else:
+            cmd_lines.append(f"    --{k} {str(v)} ")

     if os.name == "nt":
         cmd_text = "`\n".join(cmd_lines)
```

Comments: 0 commit comments