Skip to content

Commit 846e9dc

Browse files
fix with main?
2 parents e267385 + 2829e72 commit 846e9dc

File tree

9 files changed

+216
-12
lines changed

9 files changed

+216
-12
lines changed

.env.template

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,12 @@ OFFLOAD_TYPE=cpu_model_offload
66
COGVIEW4_PATH=THUDM/CogView4-6B
77
# Optional, only needed when you don't want to use the default transformer in COGVIEW4_PATH
88
# COGVIEW4_TRANSFORMER_PATH=
9+
10+
### OPENAI API
11+
12+
OPENAI_API_KEY=
13+
OPENAI_BASE_URL=
14+
15+
### LORA PATH FOR GUI(test)
16+
17+
LORA_DIR=

src/cogkit/api/python/generation/util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
CogVideoXDPMScheduler,
77
CogVideoXImageToVideoPipeline,
88
CogVideoXPipeline,
9+
CogView4ControlPipeline,
910
CogView4Pipeline,
1011
)
1112

1213
TVideoPipeline = CogVideoXPipeline | CogVideoXImageToVideoPipeline
1314
TPipeline = CogView4Pipeline | TVideoPipeline
15+
CogviewPipline = CogView4Pipeline | CogView4ControlPipeline
1416

1517

1618
def _is_cogvideox1_0(pipeline: TVideoPipeline) -> bool:
@@ -103,7 +105,7 @@ def guess_resolution(
103105
height: int | None = None,
104106
width: int | None = None,
105107
) -> tuple[int, int]:
106-
if isinstance(pipeline, CogView4Pipeline):
108+
if isinstance(pipeline, CogviewPipline):
107109
return _guess_cogview_resolution(pipeline, height=height, width=width)
108110
if isinstance(pipeline, TVideoPipeline):
109111
return _guess_cogvideox_resolution(pipeline, height=height, width=width)

src/cogkit/api/services/image_generation.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import numpy as np
55
import os
6-
76
import torch
87

98
from cogkit.api.logging import get_logger
@@ -27,14 +26,12 @@ def __init__(self, settings: APISettings) -> None:
2726
before_generation(cogview4_pl, settings.offload_type)
2827
self._models["cogview-4"] = cogview4_pl
2928

30-
### Check if loaded models are supported
3129
for model in self._models.keys():
3230
if model not in settings._supported_models:
3331
raise ValueError(
3432
f"Registered model {model} not in supported list: {settings._supported_models}"
3533
)
3634

37-
### Check if all supported models are loaded
3835
for model in settings._supported_models:
3936
if model not in self._models:
4037
_logger.warning(f"Model {model} not loaded")
@@ -57,7 +54,6 @@ def generate(
5754
raise ValueError(f"Model {model} not loaded")
5855
width, height = list(map(int, size.split("x")))
5956

60-
# TODO: Refactor this to switch by LoRA endpoint API
6157
if lora_path is not None:
6258
adapter_name = os.path.basename(lora_path)
6359
_logger.info(f"Loaded LORA weights from {adapter_name}")
@@ -68,13 +64,12 @@ def generate(
6864

6965
output = generate_image(
7066
prompt=prompt,
71-
pipeline=self._models[model],
72-
num_images_per_prompt=num_images,
73-
output_type="np",
7467
height=height,
7568
width=width,
7669
num_inference_steps=num_inference_steps,
7770
guidance_scale=guidance_scale,
71+
num_images_per_prompt=num_images,
72+
output_type="np",
7873
)
7974

8075
image_lst = self.postprocess(output)
@@ -84,6 +79,7 @@ def is_valid_model(self, model: str) -> bool:
8479
return model in self._models
8580

8681
def postprocess(self, image_np: np.ndarray) -> list[np.ndarray]:
82+
image_np = (image_np * 255).round().astype("uint8")
8783
image_lst = np.split(image_np, image_np.shape[0], axis=0)
8884
image_lst = [img.squeeze(0) for img in image_lst]
8985
return image_lst

src/cogkit/api/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,5 @@ class APISettings(BaseSettings):
1919
# cogview-4 related settings
2020
cogview4_path: str | None = None
2121
cogview4_transformer_path: str | None = None
22+
openai_api_key: str | None = None
23+
lora_dir: str | None = None

src/cogkit/cli/inference.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,18 @@
3131
type=click.Path(dir_okay=False, writable=True),
3232
help="the path to save the generated image/video. If not provided, the generated image/video will be saved to 'output.png/mp4'.",
3333
)
34+
@click.option(
35+
"--task",
36+
type=click.Choice(
37+
choices=[mode.value for mode in GenerationMode],
38+
case_sensitive=False,
39+
),
40+
help="the generation task",
41+
)
3442
@click.option(
3543
"--image_file",
3644
type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
37-
help="the image to guide the image/video generation (for i2i/i2v generation task)",
45+
help="the image to guide the video generation (for i2v or ct2i generation task)",
3846
)
3947
@click.option(
4048
"--dtype",
@@ -150,7 +158,10 @@ def inference(
150158
_logger.info("Saving the generated video to path '%s'.", os.fspath(output_file))
151159
export_to_video(output[0], output_file, fps=fps)
152160

153-
elif task in (GenerationMode.TextToImage,):
161+
elif task in (
162+
GenerationMode.TextToImage,
163+
GenerationMode.CtrlTextToImage,
164+
):
154165
batched_images = generate_image(
155166
prompt=prompt,
156167
pipeline=pipeline,

src/cogkit/types/generation_mode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ class GenerationMode(enum.Enum):
88
TextToVideo = "t2v"
99
ImageToVideo = "i2v"
1010
TextToImage = "t2i"
11+
CtrlTextToImage = "ct2i"

src/cogkit/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from cogkit.utils.lora import load_lora_checkpoint, unload_lora_checkpoint
77
from cogkit.utils.misc import guess_generation_mode
88
from cogkit.utils.path import mkdir, resolve_path
9+
from cogkit.utils.prompt import convert_prompt
910
from cogkit.utils.random import rand_generator
1011
from cogkit.utils.load import load_pipeline
1112

@@ -19,4 +20,5 @@
1920
"resolve_path",
2021
"rand_generator",
2122
"load_pipeline",
23+
"convert_prompt",
2224
]

src/cogkit/utils/misc.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from diffusers import DiffusionPipeline
55
from PIL import Image
6-
6+
from pathlib import Path
77
from cogkit.logging import get_logger
88
from cogkit.types import GenerationMode
99

@@ -14,6 +14,7 @@
1414
"CogView4Pipeline",
1515
"CogVideoXPipeline",
1616
"CogVideoXImageToVideoPipeline",
17+
"CogView4ControlPipeline",
1718
)
1819

1920

@@ -36,6 +37,23 @@ def _check_text_to_image_params(
3637
)
3738

3839

40+
def _check_control_text_to_image_params(
41+
pl_cls_name: str,
42+
generation_mode: GenerationMode | None,
43+
image: str | Path | None,
44+
) -> None:
45+
if generation_mode is not None and generation_mode != GenerationMode.CtrlTextToImage:
46+
_logger.warning(
47+
"The pipeline `%s` does not support `%s` task. Will try the `%s` task.",
48+
pl_cls_name,
49+
generation_mode.value,
50+
GenerationMode.CtrlTextToImage,
51+
)
52+
if image is not None:
53+
err_msg = f"Image input is required in the image2video pipeline. Please provide a regular image file (image_file = {image})."
54+
raise ValueError(err_msg)
55+
56+
3957
def _check_image_to_video_params(
4058
pl_cls_name: str,
4159
generation_mode: GenerationMode | None,
@@ -66,7 +84,7 @@ def guess_generation_mode(
6684
if generation_mode is not None:
6785
generation_mode = GenerationMode(generation_mode)
6886

69-
if pl_cls_name.startswith("CogView"):
87+
if pl_cls_name == "CogView4Pipeline":
7088
# TextToImage
7189
_check_text_to_image_params(pl_cls_name, generation_mode, image)
7290
return GenerationMode.TextToImage
@@ -75,6 +93,11 @@ def guess_generation_mode(
7593
_check_image_to_video_params(pl_cls_name, generation_mode, image)
7694
return GenerationMode.ImageToVideo
7795

96+
if pl_cls_name == "CogView4ControlPipeline":
97+
# Control TextToImage
98+
_check_control_text_to_image_params(pl_cls_name, generation_mode, image)
99+
return GenerationMode.CtrlTextToImage
100+
78101
if image is not None:
79102
_logger.warning(
80103
"Pipeline `%s` does not support image input. Will ignore the image file.",

src/cogkit/utils/prompt.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import os
2+
import re
3+
4+
from cogkit.logging import get_logger
5+
from cogkit.types import GenerationMode
6+
from dotenv import find_dotenv, load_dotenv
7+
from openai import OpenAI
8+
9+
load_dotenv(find_dotenv(".env.template"))
10+
11+
_logger = get_logger(__name__)
12+
13+
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
14+
15+
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
16+
There are a few rules to follow:
17+
18+
You will only ever output a single video description per user request.
19+
20+
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
21+
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
22+
23+
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
24+
"""
25+
26+
27+
def _enhance_video_messages(prompt: str) -> list:
28+
fixed_messages = [
29+
{"role": "system", "content": sys_prompt},
30+
{
31+
"role": "user",
32+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
33+
},
34+
{
35+
"role": "assistant",
36+
"content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
37+
},
38+
{
39+
"role": "user",
40+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
41+
},
42+
{
43+
"role": "assistant",
44+
"content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
45+
},
46+
{
47+
"role": "user",
48+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
49+
},
50+
{
51+
"role": "assistant",
52+
"content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
53+
},
54+
{
55+
"role": "user",
56+
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
57+
},
58+
]
59+
60+
return fixed_messages
61+
62+
63+
def _enhance_control_image_messages(prompt: str) -> list:
64+
fixed_messages = [
65+
{
66+
"role": "system",
67+
"content": 'You are a bilingual image description assistant that works with an image generation bot. You work with an assistant bot that will draw anything you say . \n For example ,For example, outputting "a beautiful morning in the woods with the sun peaking through the trees" or "阳光透过树林的美丽清晨" will trigger your partner bot to output an image of a forest morning, as described . \n You will be prompted by people looking to create detailed , amazing images . The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive . \n There are a few rules to follow : \n - Input can be in Chinese or English. If input is in English, prompt should be written in English. If input is in Chinese, prompt should be written in Chinese.\n - You will only ever output a single image description per user request .\n - Image descriptions must be detailed and specific, including keyword categories such as subject, medium, style, additional details, color, and lighting. \n - When generating descriptions, focus on portraying the visual elements rather than delving into abstract psychological and emotional aspects. Provide clear and concise details that vividly depict the scene and its composition, capturing the tangible elements that make up the setting.\n - Do not provide the process and explanation, just return the modified description . \n ',
68+
},
69+
{
70+
"role": "user",
71+
"content": 'Create an imaginative image descriptive caption for the user input : "An anime girl stands amidst a dense flower bush."',
72+
},
73+
{
74+
"role": "assistant",
75+
"content": "This image is a beautifully crafted digital illustration in an anime style. It features Suzuran, a character from the video game Arknights. She stands gracefully amidst a picturesque meadow with lush green grass and scattered wildflowers that gently sway in the breeze. Her attire includes a detailed outfit with layered ruffles and intricate fastenings, reflecting both elegance and functionality typical of the game's aesthetic. Suzuran holds a small bouquet of flowers delicately in her hands, adding to the serene atmosphere.\n\nProminent fox ears extend from her head, signifying her animal spirit form or persona within the game. The background showcases rolling hills covered in dense foliage under a brilliant blue sky dotted with fluffy white clouds. Golden leaves float whimsically through the air, enhancing the magical quality of the scene. A distant building partially obscured by vegetation hints at human presence without detracting from the natural beauty surrounding Suzuran. Overall, this artwork captures a moment of tranquility and charm, embodying the essence of the character and the enchanting world she inhabits.",
76+
},
77+
{
78+
"role": "user",
79+
"content": 'Create an imaginative image descriptive caption for the user input : "Draw a bright convertible car with a sense of artistic design."',
80+
},
81+
{
82+
"role": "assistant",
83+
"content": 'The image showcases a meticulously crafted Delahaye 135 M.S. Special Roadster from 1937 in a highly polished and realistic rendering style that highlights its luxurious design and impeccable details. The car\'s body is an elegant deep brown with a glossy finish, exuding sophistication and timeless beauty. Its aerodynamic, streamlined shape features smooth curves accentuated by chrome detailing on the fenders and running boards. The front grille has three prominent circular headlights, adding to its classic charm while the dual exhaust pipes are tastefully integrated into the rear fenders.\n\nThe open cockpit area reveals sumptuous tan leather seats, emphasizing both comfort and elegance for the driver and passenger. The vehicle’s number plate reads "MB 175". Photographed against a dark gradient background, the focus remains solely on this automotive masterpiece, allowing viewers to appreciate every intricate detail without distraction. This image captures not just a car but a piece of art, reflecting the luxury and craftsmanship emblematic of pre-war European automobiles.',
84+
},
85+
{
86+
"role": "user",
87+
"content": 'Create an imaginative image descriptive caption for the user input : "画一个白发、神情忧郁的动漫女孩,手里拿着一支香烟"',
88+
},
89+
{
90+
"role": "assistant",
91+
"content": "这幅图像是一幅动漫风格的插画,画中描绘了一位长发飘逸的白发女孩。她神情忧郁,双眼低垂,脸上带着微微的愁容。女孩穿着浅色外套,里面搭配深色衬衫和领带,增添了她沉静却时尚的外表。她的手靠近嘴边,似乎在拿着一支香烟,画作的说明也表示她在吸烟。背景描绘了一个冬季的城市场景,地面和建筑物上覆盖着积雪。几株稀疏的植物从雪地中探出,街边停着一辆车,似乎是空荡荡的街道,增强了场景的荒凉氛围。这幅作品整体风格细致精美,典型的现代动漫插画风格,既捕捉了角色的情感表达,也呈现了周围环境的静谧氛围",
92+
},
93+
{
94+
"role": "user",
95+
"content": 'Create an imaginative image descriptive caption for the user input : "一张红色的海报,中间写有“开门大吉”"',
96+
},
97+
{
98+
"role": "assistant",
99+
"content": "这张图片采用了海报风格,色彩鲜艳,主要以红色和金色为主,寓意吉祥如意。在画面的正中间是一块红色的大匾,上面用白色大字写着“开门大吉”,四角点缀着金色的装饰图案,显得格外喜庆。匾额上方悬挂着一对红色的灯笼,增添了节日的气氛。背景左右两侧是传统的绿色中式建筑,屋顶呈现出典型的飞檐翘角设计。底部有祥云朵朵和可爱的卡通福袋,象征着好运福气满满。整张海报传达出浓厚的节日氛围。",
100+
},
101+
{
102+
"role": "user",
103+
"content": f"Create an imaginative image descriptive caption for the user input : {prompt}",
104+
},
105+
]
106+
return fixed_messages
107+
108+
109+
def clean_string(s):
110+
s = s.replace("\n", " ")
111+
s = s.strip()
112+
s = re.sub(r"\s{2,}", " ", s)
113+
return s
114+
115+
116+
def convert_prompt(
117+
prompt: str,
118+
task: GenerationMode,
119+
retry_times: int = 5,
120+
) -> str:
121+
if not os.environ.get("OPENAI_API_KEY"):
122+
_logger.warning("OPENAI_API_KEY not found, the original prompt will be used directly")
123+
return prompt
124+
client = OpenAI(
125+
api_key=os.environ.get("OPENAI_API_KEY"),
126+
base_url=os.environ.get("OPENAI_BASE_URL"),
127+
)
128+
messages = []
129+
prompt = clean_string(prompt)
130+
131+
if task in (
132+
GenerationMode.TextToVideo,
133+
GenerationMode.ImageToVideo,
134+
):
135+
messages = _enhance_video_messages(prompt)
136+
if task == GenerationMode.CtrlTextToImage:
137+
messages = _enhance_control_image_messages(prompt)
138+
139+
for i in range(retry_times):
140+
try:
141+
_logger.info("Strengthening prompt ...")
142+
response = client.chat.completions.create(
143+
messages=messages,
144+
model="glm-4-plus",
145+
temperature=0.01,
146+
top_p=0.7,
147+
stream=False,
148+
max_tokens=300,
149+
)
150+
prompt = response.choices[0].message.content
151+
if prompt:
152+
prompt = clean_string(prompt)
153+
break
154+
_logger.info("Use enhanced prompt")
155+
except Exception as e:
156+
_logger.warning("openai api ", e)
157+
158+
return prompt

0 commit comments

Comments
 (0)