Skip to content

Commit ff52302

Browse files
author
YorickHe
committed
refine the inference on gradio
1 parent 41bfa1b commit ff52302

2 files changed

Lines changed: 50 additions & 18 deletions

File tree

app.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,28 @@ class UploadTarget(enum.Enum):
3030
PERSONAL_PROFILE = 'Personal Profile'
3131
LORA_LIaBRARY = 'LoRA Library'
3232

33+
def update_cloth(style_index):
34+
prompts = []
35+
if style_index == 0:
36+
example_prompt = generate_pos_prompt(styles[style_index]['name'],
37+
cloth_prompt[0]['prompt'])
38+
for prompt in cloth_prompt:
39+
prompts.append(prompt['name'])
40+
else:
41+
example_prompt = generate_pos_prompt(styles[style_index]['name'],
42+
styles[style_index]['add_prompt_style'])
43+
prompts.append(styles[style_index]['cloth_name'])
44+
return gr.Radio.update(choices=prompts, value=prompts[0]), gr.Textbox.update(value=example_prompt)
45+
46+
47+
def update_prompt(style_index, cloth_index):
48+
if style_index == 0:
49+
pos_prompt = generate_pos_prompt(styles[style_index]['name'],
50+
cloth_prompt[cloth_index]['prompt'])
51+
else:
52+
pos_prompt = generate_pos_prompt(styles[style_index]['name'],
53+
styles[style_index]['add_prompt_style'])
54+
return gr.Textbox.update(value=pos_prompt)
3355

3456
def concatenate_images(images):
3557
heights = [img.shape[0] for img in images]
@@ -75,6 +97,7 @@ def launch_pipeline(uuid,
7597
base_model = 'ly261666/cv_portrait_model'
7698
before_queue_size = inference_threadpool._work_queue.qsize()
7799
before_done_count = inference_done_count
100+
style_model = styles[style_model]['name']
78101

79102
if style_model == styles[0]['name']:
80103
style_model_path = None
@@ -274,21 +297,26 @@ def inference_input():
274297
with gr.Column():
275298
user_models = gr.Radio(label="模型选择(Model list)", choices=HOT_MODELS, type="value",
276299
value=HOT_MODELS[0])
277-
pos_prompt = gr.Textbox(label="Prompt", lines=3,
278-
value=generate_pos_prompt(None, cloth_prompt[0]['prompt']))
279-
style_model = gr.Textbox(label="风格模型(Style model)", value=styles[0]['name'])
280-
281-
prompts = []
282-
for prompt in cloth_prompt[0:1]:
283-
prompts.append([styles[0]['name'], generate_pos_prompt(styles[0]['name'], prompt['prompt'])])
284-
for style in styles[1:]:
285-
prompts.append([style['name'], generate_pos_prompt(style['name'], style['add_prompt_style'])])
286-
gr.Examples(prompts,
287-
inputs=[style_model, pos_prompt], label='提示词和风格示例(Prompt and styles examples)')
288-
multiplier_style = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.05, label='multiplier_style')
300+
style_model_list = []
301+
for style in styles:
302+
style_model_list.append(style['name'])
303+
style_model = gr.Dropdown(choices=style_model_list, value=styles[0]['name'],
304+
type="index", label="风格模型(Style model)")
305+
306+
prompts=[]
307+
for prompt in cloth_prompt:
308+
prompts.append(prompt['name'])
309+
cloth_style = gr.Radio(choices=prompts, value=cloth_prompt[0]['name'],
310+
type="index", label="服装风格(Cloth style)")
311+
312+
with gr.Accordion("高级选项(Expert)", open=False):
313+
pos_prompt = gr.Textbox(label="提示语(Prompt)", lines=3,
314+
value=generate_pos_prompt(None, cloth_prompt[0]['prompt']), interactive=True)
315+
multiplier_style = gr.Slider(minimum=0, maximum=1, value=0.25,
316+
step=0.05, label='风格权重(Multiplier style)')
289317
with gr.Box():
290318
num_images = gr.Number(
291-
label='生成图片数量(Number of photos)', value=6, precision=1)
319+
label='生成图片数量(Number of photos)', value=6, precision=1, minimum=1, maximum=6)
292320
gr.Markdown('''
293321
注意:最多支持生成6张图片!(You may generate a maximum of 6 photos at one time!)
294322
''')
@@ -301,6 +329,9 @@ def inference_input():
301329
gr.Markdown('生成结果(Result)')
302330
output_images = gr.Gallery(label='Output', show_label=False).style(columns=3, rows=2, height=600,
303331
object_fit="contain")
332+
333+
style_model.change(update_cloth, style_model, [cloth_style, pos_prompt])
334+
cloth_style.change(update_prompt, [style_model, cloth_style], [pos_prompt])
304335
display_button.click(fn=launch_pipeline,
305336
inputs=[uuid, pos_prompt, user_models, num_images, style_model, multiplier_style],
306337
outputs=[infer_progress, output_images])

facechain/constants.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66

77
cloth_prompt = [
8-
{'name': 'working suit', 'prompt': 'high-class business/working suit'}, # male and female
9-
{'name': 'armor', 'prompt': 'silver armor'}, # male
10-
{'name': 'T-shirt', 'prompt': 'T-shirt'}, # male and female
11-
{'name': 'hanfu', 'prompt': 'beautiful traditional hanfu, upper_body'}, # female
12-
{'name': 'gown', 'prompt': 'an elegant evening gown'}, # female
8+
{'name': '工作服(working suit)', 'prompt': 'high-class business/working suit'}, # male and female
9+
{'name': '盔甲风(armor)', 'prompt': 'silver armor'}, # male
10+
{'name': 'T恤衫(T-shirt)', 'prompt': 'T-shirt'}, # male and female
11+
{'name': '汉服风(hanfu)', 'prompt': 'beautiful traditional hanfu, upper_body'}, # female
12+
{'name': '女士晚礼服(gown)', 'prompt': 'an elegant evening gown'}, # female
1313
]
1414

1515
styles = [
@@ -19,6 +19,7 @@
1919
'revision': 'v1.0.0',
2020
'bin_file': 'xiapei.safetensors',
2121
'multiplier_style': 0.35,
22+
'cloth_name': '汉服风(hanfu)',
2223
'add_prompt_style': 'red, hanfu, tiara, crown, '},
2324
]
2425

0 commit comments

Comments
 (0)