Skip to content

Commit f551922

Browse files
webdemo & server related (#8)
* Update utils.py * update This reverts commit d36c96b. * add lora load support * update with simple gui * Update infer.py * Update infer.py * only remain lora loading without name * remove web * remove v2v * add setting * Update 02-API.md
1 parent 59c72c9 commit f551922

File tree

16 files changed

+345
-44
lines changed

16 files changed

+345
-44
lines changed

.env.template

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
1-
COGVIEW4_PATH=THUDM/CogView4-6B
1+
COGVIEW4_PATH=/share/official_pretrains/hf_home/CogView4-6B
2+
DTYPE=bfloat16
3+
OFFLOAD_TYPE=no_offload
4+
OPENAI_API_KEY=
5+
OPENAI_BASE_URL=
6+
LORA_DIR=

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Contribution Guide
2+
3+
We welcome your contributions to this repository. To ensure elegant code style and better code quality, we have prepared the following contribution guidelines.
4+
5+
## What We Accept
6+
7+
+ This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks).
8+
+ This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below.
9+
+ This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below.
10+
11+
## Code Style Guide
12+
13+
Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code following the steps below:
14+
15+
1. Install the required dependencies:
16+
```shell
17+
pip install ruff pre-commit
18+
```
19+
2. Then, run the following command:
20+
```shell
21+
pre-commit run --all-files
22+
```
23+
If your code complies with the standards, you should not see any errors.
24+
25+
## Naming Conventions
26+
27+
- Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English.
28+
- Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`.

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,4 @@ tmp/
263263

264264
webdoc/
265265
**/wandb/
266+
test*

docs/03-Inference/02-API.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ image = generate_image(
1818
model_id_or_path="THUDM/CogView4-6B",
1919
lora_model_id_or_path=None,
2020
transformer_path=None,
21+
output_file="sunset.png", # Images will be saved here.
2122
height=1024,
2223
width=1024,
2324
)
24-
image.save("sunset.png")
25+
2526

2627
# Text/Image-to-Video generation
2728
video = generate_video(
@@ -30,10 +31,11 @@ video = generate_video(
3031
model_id_or_path="THUDM/CogVideoX1.5-5B",
3132
lora_model_id_or_path=None,
3233
transformer_path=None,
34+
output_file="cat.mp4", # Videos will be saved here.
3335
num_frames=81,
3436
fps=16,
3537
)
36-
video.save("cat_video.mp4")
38+
3739
```
3840

3941
See function signatures in for more details.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ dependencies = [
2727
torch = ["numpy", "torch", "torchvision"]
2828
api = [
2929
"fastapi[standard]~=0.115.11",
30+
"fastapi_cli~=0.0.7",
31+
"pydantic_settings~=2.8.1",
3032
"openai~=1.67",
3133
"pydantic-settings~=2.8",
3234
"python-dotenv~=1.0",

src/cogkit/api/models/images/generation_params.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ class ImageGenerationParams(RequestParams):
1313
size: Literal[
1414
"1024x1024", "768x1344", "864x1152", "1344x768", "1152x864", "1440x720", "720x1440"
1515
] = "1024x1024"
16+
num_inference_steps: int = 50
17+
guidance_scale: float = 3.5
18+
lora_path: str | None = None
1619
user: str | None = None
1720
# ! unsupported parameters
1821
# quality: Literal["standard", "hd"] = "standard"

src/cogkit/api/routers/images.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,14 @@ def generations(
3535
status_code=HTTPStatus.NOT_FOUND,
3636
detail=f"The model `{params.model}` does not exist. Supported models: {image_generation.supported_models}",
3737
)
38-
# TODO: add exception handling
3938
image_lst = image_generation.generate(
40-
model=params.model, prompt=params.prompt, size=params.size, num_images=params.n
39+
model=params.model,
40+
prompt=params.prompt,
41+
size=params.size,
42+
num_images=params.n,
43+
num_inference_steps=params.num_inference_steps,
44+
guidance_scale=params.guidance_scale,
45+
lora_path=params.lora_path,
4146
)
4247
image_b64_lst = [ImageInResponse(b64_json=np_to_base64(image)) for image in image_lst]
4348
return ImagesResponse(created=int(time.time()), data=image_b64_lst)

src/cogkit/api/services/image_generation.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33

44
import numpy as np
5+
import os
6+
7+
import torch
58
from diffusers import CogView4Pipeline
69

710
from cogkit.api.logging import get_logger
@@ -14,8 +17,14 @@ class ImageGenerationService(object):
1417
def __init__(self, settings: APISettings) -> None:
1518
self._models = {}
1619
if settings.cogview4_path is not None:
17-
cogview4_pl = CogView4Pipeline.from_pretrained(settings.cogview4_path)
18-
cogview4_pl.enable_model_cpu_offload()
20+
cogview4_pl = CogView4Pipeline.from_pretrained(
21+
settings.cogview4_path,
22+
torch_dtype=torch.bfloat16 if settings.dtype == "bfloat16" else torch.float32,
23+
)
24+
if settings.offload_type == "cpu_model_offolad":
25+
cogview4_pl.enable_model_cpu_offload()
26+
else:
27+
cogview4_pl.to("cuda")
1928
cogview4_pl.vae.enable_slicing()
2029
cogview4_pl.vae.enable_tiling()
2130
self._models["cogview-4"] = cogview4_pl
@@ -36,18 +45,33 @@ def __init__(self, settings: APISettings) -> None:
3645
def supported_models(self) -> list[str]:
3746
return list(self._models.keys())
3847

39-
def generate(self, model: str, prompt: str, size: str, num_images: int) -> list[np.ndarray]:
48+
def generate(
49+
self,
50+
model: str,
51+
prompt: str,
52+
size: str,
53+
num_images: int,
54+
num_inference_steps: int = 50,
55+
guidance_scale: float = 3.5,
56+
lora_path: str | None = None,
57+
) -> list[np.ndarray]:
4058
if model not in self._models:
4159
raise ValueError(f"Model {model} not loaded")
4260
width, height = list(map(int, size.split("x")))
61+
if lora_path is not None:
62+
adapter_name = os.path.basename(lora_path)
63+
print(f"Loaded LORA weights from {adapter_name}")
64+
self._models[model].load_lora_weights(lora_path)
65+
else:
66+
print("Unloading LORA weights")
67+
self._models[model].unload_lora_weights()
4368

44-
# shape of image_np: (n, h, w, c)
4569
image_np = self._models[model](
4670
prompt=prompt,
4771
height=height,
4872
width=width,
49-
num_inference_steps=50,
50-
guidance_scale=3.5,
73+
num_inference_steps=num_inference_steps,
74+
guidance_scale=guidance_scale,
5175
num_images_per_prompt=num_images,
5276
output_type="np",
5377
).images

src/cogkit/api/settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
from pydantic_settings import BaseSettings, SettingsConfigDict
5+
from typing import Literal
56

67

78
class APISettings(BaseSettings):
@@ -10,3 +11,6 @@ class APISettings(BaseSettings):
1011
)
1112
_supported_models: tuple[str, ...] = ("cogview-4",)
1213
cogview4_path: str | None = None
14+
dtype: Literal["bfloat16", "float32"] = "bfloat16"
15+
offload_type: Literal["cpu_model_offolad", "no_offload"] = "no_offload"
16+
openai_api_key: str | None = None

src/cogkit/cli/inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def inference(
133133
output_file or "output.png",
134134
dtype=dtype,
135135
transformer_path=transformer_path,
136+
lora_model_id_or_path=lora_model_id_or_path,
136137
height=height,
137138
width=width,
138139
seed=seed,

0 commit comments

Comments
 (0)