
Commit e267385

OleehyO, zRzRzRzRzRzRzR, and zhangch9 authored
[refactor] Restructure generation services and etc. (#9)
* Update utils.py * update This reverts commit d36c96b. * add lora load support * update with simple gui * Update infer.py * Update infer.py * only remain lora loading without name * remove web * remove v2v * [refactor] Restructure generation services and etc. - Move generation directory to `api/python` for better code organization - Refactor image generation service to reuse Python API - Clean up .env.template by removing unused properties and fixing typos - Refactor generation_mode to fetch inference tasks directly from pipeline - Add unload_lora_checkpoint logic - Add offload mode selection in API and `before_generation` - Add docs for python API function signature - Fix lora merge logic * [fix] Fix LoRA weights merging process * [refactor] Remove lora parameter from `generate_image` and `generate_video`, add lora operations to Python API - Modified `src/cogkit/__init__.py` to updated Python API - Updated `src/cogkit/api/python/generation/(image,video).py` to remove lora param and integrate lora operations - Adjusted `src/cogkit/cli/inference.py` to align with API changes * nit: renames some variables --------- Co-authored-by: Yuxuan Zhang <[email protected]> Co-authored-by: Chenhui Zhang <[email protected]>
1 parent f551922 commit e267385

File tree

20 files changed: +412 additions, -337 deletions


.env.template

Lines changed: 7 additions & 5 deletions
@@ -1,6 +1,8 @@
-COGVIEW4_PATH=/share/official_pretrains/hf_home/CogView4-6B
+### basic configs
 DTYPE=bfloat16
-OFFLOAD_TYPE=no_offload
-OPENAI_API_KEY=
-OPENAI_BASE_URL=
-LORA_DIR=
+OFFLOAD_TYPE=cpu_model_offload
+
+### cogview4 related configs
+COGVIEW4_PATH=THUDM/CogView4-6B
+# Optional, only needed when you don't want to use the default transformer in COGVIEW4_PATH
+# COGVIEW4_TRANSFORMER_PATH=
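These variables are consumed by the API's settings layer. A hypothetical sketch of how they could map onto a pydantic-settings class — the field names mirror the `settings.*` accesses in the `image_generation.py` diff below, but the class body itself is an assumption, not part of this commit:

    # Hypothetical sketch, not the actual APISettings from this repo.
    from pydantic_settings import BaseSettings, SettingsConfigDict

    class APISettings(BaseSettings):
        # read values from a .env file copied from .env.template
        model_config = SettingsConfigDict(env_file=".env")

        dtype: str = "bfloat16"
        offload_type: str = "cpu_model_offload"
        cogview4_path: str | None = None
        cogview4_transformer_path: str | None = None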

pyproject.toml

Lines changed: 1 addition & 2 deletions
@@ -28,9 +28,8 @@ torch = ["numpy", "torch", "torchvision"]
 api = [
     "fastapi[standard]~=0.115.11",
     "fastapi_cli~=0.0.7",
-    "pydantic_settings~=2.8.1",
     "openai~=1.67",
-    "pydantic-settings~=2.8",
+    "pydantic_settings~=2.8.1",
     "python-dotenv~=1.0",
 ]

src/cogkit/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -1 +1,13 @@
 # -*- coding: utf-8 -*-
+
+
+from cogkit.api.python import generate_image, generate_video
+from cogkit.utils import load_lora_checkpoint, load_pipeline, unload_lora_checkpoint
+
+__all__ = [
+    "generate_image",
+    "generate_video",
+    "load_pipeline",
+    "load_lora_checkpoint",
+    "unload_lora_checkpoint",
+]
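Taken together, the new exports give a small end-to-end workflow. A minimal sketch, assuming the signatures shown in the diffs below; the checkpoint path and prompt are placeholders:

    from cogkit import (
        generate_image,
        load_lora_checkpoint,
        load_pipeline,
        unload_lora_checkpoint,
    )

    pipeline = load_pipeline("THUDM/CogView4-6B")
    load_lora_checkpoint(pipeline, "/path/to/lora_checkpoint")  # optional LoRA weights
    images = generate_image(
        prompt="a watercolor fox in a snowy forest",
        pipeline=pipeline,
        load_type="cpu_model_offload",  # or "cuda" with enough GPU memory
        seed=42,
    )
    images[0].save("fox.png")     # default output_type="pil"
    unload_lora_checkpoint(pipeline)  # back to the base weights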

src/cogkit/api/python/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+
+
+from .generation.image import generate_image
+from .generation.video import generate_video
+from .generation.util import before_generation
+
+__all__ = ["generate_image", "generate_video", "before_generation"]
src/cogkit/api/python/generation/image.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
from typing import Literal
5+
6+
import numpy as np
7+
import torch
8+
from PIL import Image
9+
10+
from cogkit.logging import get_logger
11+
from cogkit.utils import (
12+
rand_generator,
13+
)
14+
from diffusers import DiffusionPipeline
15+
16+
from .util import before_generation, guess_resolution
17+
18+
_logger = get_logger(__name__)
19+
20+
21+
def generate_image(
22+
prompt: str,
23+
pipeline: DiffusionPipeline,
24+
num_images_per_prompt: int = 1,
25+
output_type: Literal["pil", "pt", "np"] = "pil",
26+
load_type: Literal["cuda", "cpu_model_offload", "sequential_cpu_offload"] = "cpu_model_offload",
27+
height: int | None = None,
28+
width: int | None = None,
29+
num_inference_steps: int = 50,
30+
guidance_scale: float = 3.5,
31+
seed: int | None = None,
32+
) -> list[Image.Image] | torch.Tensor | np.ndarray:
33+
"""Generates images from a text prompt using a diffusion model pipeline.
34+
35+
This function leverages a diffusion pipeline to create images based on a given text prompt. It supports
36+
customization of image dimensions, inference steps, and guidance scale, as well as optional LoRA (Low-Rank
37+
Adaptation) fine-tuning. The output can be returned in different formats: PIL images, PyTorch tensors, or
38+
NumPy arrays.
39+
40+
Args:
41+
- prompt: The text description used to guide the image generation process.
42+
- pipeline: Preloaded DiffusionPipeline instance.
43+
- num_images_per_prompt: Number of images to generate per prompt. Defaults to 1.
44+
- output_type: Format of the output images. Options are "pil" (PIL.Image), "pt" (PyTorch tensor), or
45+
"np" (NumPy array). Defaults to "pil".
46+
- load_type: Type of offloading to use for the model, use "cuda" if you have enough GPU memory. Defaults to "cpu_model_offload".
47+
- height: Desired height of the output images in pixels. If None, inferred from the pipeline.
48+
- width: Desired width of the output images in pixels. If None, inferred from the pipeline.
49+
- num_inference_steps: Number of denoising steps during generation. Defaults to 50.
50+
- guidance_scale: Strength of the prompt guidance (classifier-free guidance scale). Defaults to 3.5.
51+
- seed: Optional random seed for reproducible results. Defaults to None.
52+
53+
Returns:
54+
A list of generated images in the specified format:
55+
- If output_type is "pil": List of PIL.Image.Image objects.
56+
- If output_type is "pt": PyTorch tensor of shape (num_images, 3, height, width) with dtype torch.uint8.
57+
- If output_type is "np": NumPy array of shape (num_images, height, width, 3) with dtype uint8.
58+
"""
59+
60+
height, width = guess_resolution(pipeline, height, width)
61+
62+
_logger.info(f"Generation config: height {height}, width {width}.")
63+
64+
before_generation(pipeline, load_type)
65+
66+
output = pipeline(
67+
prompt=prompt,
68+
height=height,
69+
width=width,
70+
num_inference_steps=num_inference_steps,
71+
guidance_scale=guidance_scale,
72+
num_images_per_prompt=num_images_per_prompt,
73+
generator=rand_generator(seed),
74+
output_type=output_type,
75+
).images
76+
77+
if output_type != "pil":
78+
output = (output * 255).round()
79+
if output_type == "pt":
80+
assert output.ndim == 4, f"Expected 4D numpy array, got {output.ndim}D array"
81+
assert output.shape[1] == 3, f"Expected 3 channels, got {output.shape[3]} channels"
82+
output = output.to(torch.uint8)
83+
elif output_type == "np":
84+
assert output.ndim == 4, f"Expected 4D torch tensor, got {output.ndim}D torch tensor"
85+
# Dim of image_np: (num_images, height, width, 3)
86+
assert output.shape[3] == 3, f"Expected 3 channels, got {output.shape[3]} channels"
87+
output = output.astype("uint8")
88+
89+
return output
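For the non-PIL branches above, the returned data is uint8 in the layouts documented in the docstring; a quick sketch (prompt and pipeline are placeholders):

    # Sketch: shapes of the non-PIL outputs, per the docstring above.
    imgs_np = generate_image("a tin robot", pipeline, output_type="np")  # (N, H, W, 3), np.uint8
    imgs_pt = generate_image("a tin robot", pipeline, output_type="pt")  # (N, 3, H, W), torch.uint8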
src/cogkit/api/python/generation/util.py

Lines changed: 14 additions & 7 deletions
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-

+from typing import Literal

 from diffusers import (
     CogVideoXDPMScheduler,
@@ -132,18 +133,24 @@ def guess_frames(pipeline: TVideoPipeline, frames: int | None = None) -> tuple[i
     return frames, fps


-def before_generation(pipeline: TPipeline) -> None:
+def before_generation(
+    pipeline: TPipeline,
+    load_type: Literal["cuda", "cpu_model_offload", "sequential_cpu_offload"] = "cpu_model_offload",
+) -> None:
     if isinstance(pipeline, TVideoPipeline):
         pipeline.scheduler = CogVideoXDPMScheduler.from_config(
             pipeline.scheduler.config, timestep_spacing="trailing"
         )

-    # * enables CPU offload for the model.
-    # turns off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
-    # and enable to("cuda")
-    pipeline.to("cuda")
-    # pipeline.enable_model_cpu_offload()
-    # pipe.enable_sequential_cpu_offload()
+    # turn off offloading (use "cuda") if you have multiple GPUs or enough GPU
+    # memory (such as an H100); it reduces inference time
+    if load_type == "cuda":
+        pipeline.to("cuda")
+    elif load_type == "cpu_model_offload":
+        pipeline.enable_model_cpu_offload()
+    elif load_type == "sequential_cpu_offload":
+        pipeline.enable_sequential_cpu_offload()
+    else:
+        raise ValueError(f"Unsupported offload type: {load_type}")
     if hasattr(pipeline, "vae"):
         pipeline.vae.enable_slicing()
         pipeline.vae.enable_tiling()
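Callers pick the `load_type`; one plausible heuristic is to key it off free GPU memory. A sketch, where the 24 GiB threshold is an arbitrary assumption and `pipeline` is a preloaded DiffusionPipeline:

    import torch

    # mem_get_info() returns (free_bytes, total_bytes) for the current device
    free_bytes, _total = torch.cuda.mem_get_info()
    load_type = "cuda" if free_bytes > 24 * 1024**3 else "cpu_model_offload"
    before_generation(pipeline, load_type)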
src/cogkit/api/python/generation/video.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
from functools import partial
5+
from typing import Any, List, Literal
6+
7+
import numpy as np
8+
import torch
9+
from diffusers import DiffusionPipeline
10+
from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
11+
from PIL import Image
12+
13+
from cogkit.logging import get_logger
14+
from cogkit.types import GenerationMode
15+
from cogkit.utils import (
16+
guess_generation_mode,
17+
rand_generator,
18+
)
19+
20+
from .util import before_generation, guess_frames, guess_resolution
21+
22+
_logger = get_logger(__name__)
23+
24+
25+
def _cast_to_pipeline_output(output: Any) -> CogVideoXPipelineOutput:
26+
if isinstance(output, CogVideoXPipelineOutput):
27+
return output
28+
if isinstance(output, tuple):
29+
return CogVideoXPipelineOutput(frames=output[0])
30+
31+
err_msg = f"Cannot cast a `{output.__class__.__name__}` to a `CogVideoXPipelineOutput`."
32+
raise ValueError(err_msg)
33+
34+
35+
def generate_video(
36+
prompt: str,
37+
pipeline: DiffusionPipeline,
38+
num_videos_per_prompt: int = 1,
39+
output_type: Literal["pil", "pt", "np"] = "pil",
40+
input_image: Image.Image | None = None,
41+
# * params for model loading
42+
load_type: Literal["cuda", "cpu_model_offload", "sequential_cpu_offload"] = "cpu_model_offload",
43+
height: int | None = None,
44+
width: int | None = None,
45+
num_frames: int | None = None,
46+
num_inference_steps: int = 50,
47+
guidance_scale: float = 6.0,
48+
seed: int | None = 42,
49+
) -> tuple[List[Image.Image] | torch.Tensor | np.ndarray, int]:
50+
"""Main function for video generation, supporting both text-to-video and image-to-video generation modes.
51+
52+
Args:
53+
- prompt (str): Text prompt describing the desired video content.
54+
- pipeline (DiffusionPipeline): Pre-loaded diffusion model pipeline.
55+
- num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1.
56+
- output_type (Literal, optional): Output type, one of "pil", "pt", or "np". Defaults to "pil".
57+
- input_image (Image.Image | None, optional): Input image for image-to-video generation. Defaults to None.
58+
- load_type (Literal, optional): Model loading type, one of "cuda", "cpu_model_offload", or
59+
"sequential_cpu_offload". Defaults to "cpu_model_offload".
60+
- height (int | None, optional): Height of output video. If None, will be inferred. Defaults to None.
61+
- width (int | None, optional): Width of output video. If None, will be inferred. Defaults to None.
62+
- num_frames (int | None, optional): Number of frames in generated video. If None, will be inferred.
63+
Defaults to None.
64+
- num_inference_steps (int, optional): Number of inference steps. Defaults to 50.
65+
- guidance_scale (float, optional): Classifier guidance scale. Defaults to 6.0.
66+
- seed (int | None, optional): Random seed for generation. Defaults to 42.
67+
68+
Returns:
69+
tuple[torch.Tensor, int]: Returns a tuple containing:
70+
- Generated video tensor with shape (num_videos, num_frames, height, width, 3)
71+
- Video frame rate (fps)
72+
73+
Raises:
74+
ValueError: When provided generation mode is unknown or output cannot be cast to CogVideoXPipelineOutput.
75+
AssertionError: When both pipeline and model_id_or_path are None or both are provided.
76+
77+
Note:
78+
- Either pipeline or model_id_or_path must be provided, but not both.
79+
- If lora_model_id_or_path is provided, LoRA weights will be loaded and applied.
80+
- Height, width, number of frames, and fps will be automatically inferred if not specified.
81+
"""
82+
83+
task = guess_generation_mode(
84+
pipeline=pipeline,
85+
generation_mode=None,
86+
image=input_image,
87+
)
88+
89+
height, width = guess_resolution(pipeline, height, width)
90+
num_frames, fps = guess_frames(pipeline, num_frames)
91+
92+
_logger.info(
93+
f"Generation config: height {height}, width {width}, num_frames {num_frames}, fps {fps}."
94+
)
95+
96+
before_generation(pipeline, load_type)
97+
98+
pipeline_fn = partial(
99+
pipeline,
100+
height=height,
101+
width=width,
102+
prompt=prompt,
103+
num_videos_per_prompt=num_videos_per_prompt,
104+
num_inference_steps=num_inference_steps,
105+
num_frames=num_frames,
106+
use_dynamic_cfg=True,
107+
guidance_scale=guidance_scale,
108+
output_type=output_type,
109+
generator=rand_generator(seed),
110+
)
111+
if task == GenerationMode.TextToVideo:
112+
pipeline_out = pipeline_fn()
113+
elif task == GenerationMode.ImageToVideo:
114+
pipeline_out = pipeline_fn(image=input_image)
115+
else:
116+
err_msg = f"Unknown generation mode: {task.value}"
117+
raise ValueError(err_msg)
118+
119+
batch_video = _cast_to_pipeline_output(pipeline_out).frames
120+
121+
if output_type in ("pt", "np"):
122+
# Dim of a video: (num_videos, num_frames, 3, height, width)
123+
assert batch_video.ndim == 5, f"Expected 5D array, got {batch_video[0].ndim}D array"
124+
assert batch_video.shape[2] == 3, (
125+
f"Expected 3 channels, got {batch_video[0].shape[2]} channels"
126+
)
127+
return batch_video, fps
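Because the mode is inferred from `input_image`, one function covers both tasks. A sketch assuming the signature above; the model id and file names are placeholders:

    from PIL import Image
    from cogkit import generate_video, load_pipeline

    pipeline = load_pipeline("THUDM/CogVideoX1.5-5B")  # placeholder checkpoint

    # text-to-video: no input_image
    frames, fps = generate_video("a sailboat drifting at sunset", pipeline)

    # image-to-video: passing input_image switches the inferred mode
    frames, fps = generate_video(
        "the sailboat starts to move",
        pipeline,
        input_image=Image.open("sailboat.png"),
    )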

src/cogkit/api/services/image_generation.py

Lines changed: 19 additions & 20 deletions
@@ -5,10 +5,11 @@
 import os

 import torch
-from diffusers import CogView4Pipeline

 from cogkit.api.logging import get_logger
 from cogkit.api.settings import APISettings
+from cogkit.api.python import before_generation, generate_image
+from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint, load_pipeline

 _logger = get_logger(__name__)

@@ -17,16 +18,13 @@ class ImageGenerationService(object):
     def __init__(self, settings: APISettings) -> None:
         self._models = {}
         if settings.cogview4_path is not None:
-            cogview4_pl = CogView4Pipeline.from_pretrained(
+            torch_dtype = torch.bfloat16 if settings.dtype == "bfloat16" else torch.float32
+            cogview4_pl = load_pipeline(
                 settings.cogview4_path,
-                torch_dtype=torch.bfloat16 if settings.dtype == "bfloat16" else torch.float32,
+                transformer_path=settings.cogview4_transformer_path,
+                dtype=torch_dtype,
             )
-            if settings.offload_type == "cpu_model_offolad":
-                cogview4_pl.enable_model_cpu_offload()
-            else:
-                cogview4_pl.to("cuda")
-            cogview4_pl.vae.enable_slicing()
-            cogview4_pl.vae.enable_tiling()
+            before_generation(cogview4_pl, settings.offload_type)
             self._models["cogview-4"] = cogview4_pl

         ### Check if loaded models are supported
@@ -58,33 +56,34 @@ def generate(
         if model not in self._models:
             raise ValueError(f"Model {model} not loaded")
         width, height = list(map(int, size.split("x")))
+
+        # TODO: Refactor this to switch by LoRA endpoint API
         if lora_path is not None:
             adapter_name = os.path.basename(lora_path)
-            print(f"Loaded LORA weights from {adapter_name}")
-            self._models[model].load_lora_weights(lora_path)
+            _logger.info(f"Loaded LORA weights from {adapter_name}")
+            load_lora_checkpoint(self._models[model], lora_path)
         else:
-            print("Unloading LORA weights")
-            self._models[model].unload_lora_weights()
+            _logger.info("Unloading LORA weights")
+            unload_lora_checkpoint(self._models[model])

-        image_np = self._models[model](
+        output = generate_image(
             prompt=prompt,
+            pipeline=self._models[model],
+            num_images_per_prompt=num_images,
+            output_type="np",
             height=height,
             width=width,
             num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
-            num_images_per_prompt=num_images,
-            output_type="np",
-        ).images
-        assert image_np.ndim == 4, f"Expected 4D array, got {image_np.ndim}D array"
+        )

-        image_lst = self.postprocess(image_np)
+        image_lst = self.postprocess(output)
         return image_lst

     def is_valid_model(self, model: str) -> bool:
         return model in self._models

     def postprocess(self, image_np: np.ndarray) -> list[np.ndarray]:
-        image_np = (image_np * 255).round().astype("uint8")
         image_lst = np.split(image_np, image_np.shape[0], axis=0)
         image_lst = [img.squeeze(0) for img in image_lst]
         return image_lst
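Since the service backs an OpenAI-compatible images endpoint (the project depends on `fastapi` and `openai`), a stock client call is plausible; the base URL, port, and auth handling here are assumptions, not confirmed by this diff:

    # Hypothetical client-side sketch against a locally served API.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
    result = client.images.generate(
        model="cogview-4",          # key registered in ImageGenerationService
        prompt="a glass lighthouse at dawn",
        size="1024x1024",           # parsed as "{width}x{height}" by generate()
        n=1,
    )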
